This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

[TOD] World, world metrics, script, tests #4178

Merged
merged 33 commits into from Dec 22, 2021
Changes from 32 commits
33 commits
6cb4b86
[TOD] Core converesation structure, serialization, const tokens
Nov 15, 2021
1480def
fix test by adding init folder
Nov 16, 2021
de84801
[Tod] Agents, teacher metrics, and tests for these
Nov 16, 2021
638eb28
[TOD] World, world metrics, script, tests
Nov 16, 2021
0e3f492
hmmm... hoping stacks don't bite me. (change that was kept in upper d…
Nov 16, 2021
0643a62
Merge branch 'simpler_tod_1_core_only' into simpler_tod_2_agents_teac…
Nov 16, 2021
37aced2
minor, remove commented out print
Nov 16, 2021
4f91279
Merge branch 'simpler_tod_2_agents_teachers' into simpler_tod_3_world
Nov 16, 2021
b05930f
comment
Nov 16, 2021
5086e85
more comment updates (not sure if it actually helps clarity..)
Nov 16, 2021
51ed1a9
Merge branch 'main' into simpler_tod_1_core_only
Nov 16, 2021
a6508be
Merge branch 'simpler_tod_1_core_only' into simpler_tod_2_agents_teac…
Nov 16, 2021
eebc36b
Merge branch 'simpler_tod_2_agents_teachers' into simpler_tod_3_world
Nov 16, 2021
3675781
use same version of black as in the pre-commit hook
Nov 16, 2021
086c91c
Merge branch 'simpler_tod_2_agents_teachers' into simpler_tod_3_world
Nov 16, 2021
0bc961e
use same version of black as in the pre-commit hook
Nov 16, 2021
dfc4989
Merge branch 'main' into simpler_tod_2_agents_teachers
Nov 29, 2021
2f15448
address eric comments; add new readme + more documentation
Nov 30, 2021
abd1c7e
Merge branch 'simpler_tod_2_agents_teachers' into simpler_tod_3_world
Nov 30, 2021
5d0197d
minor wording change
Nov 30, 2021
39792a8
Merge branch 'simpler_tod_2_agents_teachers' into simpler_tod_3_world
Nov 30, 2021
76bfa89
add more documtnation to world tests (following comment on teacher te…
Nov 30, 2021
73c5c7a
minor comment update
Nov 30, 2021
7ab9d70
update to respect actual count of episodes (I think this might have i…
Dec 1, 2021
c6c728d
Merge branch 'main' into simpler_tod_2_agents_teachers
Dec 1, 2021
b3283d0
Merge branch 'simpler_tod_2_agents_teachers' into simpler_tod_3_world
Dec 1, 2021
0580ff0
Merge branch 'main' into simpler_tod_2_agents_teachers
Dec 2, 2021
e00accf
Merge branch 'simpler_tod_2_agents_teachers' into simpler_tod_3_world
Dec 2, 2021
7b24acf
Merge branch 'main' into simpler_tod_3_world
Dec 18, 2021
83439b5
update comments to be a bit more descriptive of what's happening
Dec 18, 2021
f7d210e
lint
Dec 20, 2021
55c198c
generate -> setup to be correct
Dec 20, 2021
82c88b6
address paul comments
Dec 22, 2021
5 changes: 2 additions & 3 deletions parlai/core/tod/README.md
@@ -12,6 +12,8 @@ As a convention, files referenced externally to this directory are prefixed with

tl;dr Extend `TodStructuredDataParser` for your particular dataset and implement `setup_episodes()`, which converts the dataset into a list of episodes (`List[TodStructuredEpisode]`). Use multiple inheritance to generate teachers for training models. See files like `parlai/tasks/multiwoz_v22/agents.py` for an example.

See `tod_agents.py` for the classes.

## Overview of usage

For a given dataset, extend `TodStructuredDataParser` and implement `setup_episodes()` and `get_id_task_prefix()`. The former of these is expected to do the data processing to convert a dataset to `List[TodStructuredEpisode]`. From here, multiple inheritance can be used to define Agents and Teachers that utilize the data.
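
The pattern described above can be sketched without ParlAI installed. Everything in the block below is a stand-in: `EpisodeStandIn`, `ParserStandIn`, and `UserSimulatorTeacherStandIn` are hypothetical simplifications of `TodStructuredEpisode`, `TodStructuredDataParser`, and a teacher class, and the field names and `first_examples()` helper are assumptions made for illustration. Only the hook names `setup_episodes()` and `get_id_task_prefix()` come from this README.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class EpisodeStandIn:
    """Stand-in for TodStructuredEpisode; these field names are hypothetical."""

    domain: str
    utterances: List[str] = field(default_factory=list)


class ParserStandIn:
    """Stand-in for TodStructuredDataParser: subclasses supply the two hooks."""

    def __init__(self) -> None:
        self.episodes = self.setup_episodes()

    def setup_episodes(self) -> List[EpisodeStandIn]:
        raise NotImplementedError

    def get_id_task_prefix(self) -> str:
        raise NotImplementedError


class UserSimulatorTeacherStandIn:
    """Stand-in teacher mixin that only consumes what the parser produced."""

    def first_examples(self) -> List[str]:
        prefix = self.get_id_task_prefix()
        return [f"{prefix}: {ep.utterances[0]}" for ep in self.episodes]


class MyDatasetParser(ParserStandIn):
    # Toy raw data standing in for a real dataset on disk.
    RAW = [{"domain": "hotel", "turns": ["I need a room", "Sure, where?"]}]

    def setup_episodes(self) -> List[EpisodeStandIn]:
        return [EpisodeStandIn(d["domain"], list(d["turns"])) for d in self.RAW]

    def get_id_task_prefix(self) -> str:
        return "MyDataset"


# Multiple inheritance: dataset-specific parsing first, generic teacher second.
class MyDatasetUserSimulatorTeacher(MyDatasetParser, UserSimulatorTeacherStandIn):
    pass


teacher = MyDatasetUserSimulatorTeacher()
examples = teacher.first_examples()
```

The dataset-specific class appears first in the base list so that its `setup_episodes()` and `get_id_task_prefix()` are the ones the teacher logic finds through the method resolution order.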
@@ -80,6 +82,3 @@ The world itself is stored in `tod_world.py`. The world follows the same interme

A general class for collecting metrics out of `TODWorld` is stored within `world_metrics.py` with individual 'metric handlers' responsible for calculating a given metric stored in `world_metric_handlers.py`.
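
As a rough illustration of the collector-plus-handlers split, here is a self-contained sketch. It is not the code in `world_metrics.py`: the `HandlerStandIn` interface, the two example handlers, and the speaker strings are assumptions. The only names borrowed from this PR are `handle_message(message, speaker)` and `report()`, which `TodWorld` calls on its metrics object.

```python
from typing import Dict, List


class HandlerStandIn:
    """Hypothetical handler interface: watch messages, expose one number."""

    name = "base"

    def handle(self, message: dict, speaker: str) -> None:
        raise NotImplementedError

    def value(self) -> float:
        raise NotImplementedError


class UserTurnCount(HandlerStandIn):
    name = "user_turns"

    def __init__(self) -> None:
        self.count = 0

    def handle(self, message: dict, speaker: str) -> None:
        if speaker == "user_utt_agent":  # assumed speaker label
            self.count += 1

    def value(self) -> float:
        return self.count


class MeanSystemWords(HandlerStandIn):
    name = "mean_system_words"

    def __init__(self) -> None:
        self.words = 0
        self.turns = 0

    def handle(self, message: dict, speaker: str) -> None:
        if speaker == "system_utt_agent":  # assumed speaker label
            self.words += len(message.get("text", "").split())
            self.turns += 1

    def value(self) -> float:
        return self.words / self.turns if self.turns else 0.0


class CollectorStandIn:
    """Fans every message out to each handler, then gathers their values."""

    def __init__(self, handlers: List[HandlerStandIn]) -> None:
        self.handlers = handlers

    def handle_message(self, message: dict, speaker: str) -> None:
        for handler in self.handlers:
            handler.handle(message, speaker)

    def report(self) -> Dict[str, float]:
        return {handler.name: handler.value() for handler in self.handlers}


collector = CollectorStandIn([UserTurnCount(), MeanSystemWords()])
collector.handle_message({"text": "book a table"}, "user_utt_agent")
collector.handle_message({"text": "for how many"}, "system_utt_agent")
collector.handle_message({"text": "two"}, "user_utt_agent")
collector.handle_message({"text": "done"}, "system_utt_agent")
report = collector.report()
```

Keeping each metric in its own handler means a new metric can be added without touching the collector or the world.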




4 changes: 2 additions & 2 deletions parlai/core/tod/tod_agents.py
@@ -38,7 +38,7 @@ class TodStructuredDataParser(Agent):
     """
     Base class that specifies intermediate representations for Tod conversations.

-    Inherit from this class and implement `generate_episodes()` to implement the intermediate representation for a specific dataset. Use multiple inheritance with classes that implement an `act()` below to use.
+    Inherit from this class and implement `setup_episodes()` to implement the intermediate representation for a specific dataset. Use multiple inheritance with classes that implement an `act()` below to use.

     For example, if we have a `MyDataset_DataParser(TodStructuredDataParser)` and wanted to make a teacher to train a model to generate User Utterances based on a goal prompt, we would do so by defining `class MyDatasetUserSimulatorTeacher(MyDataset_DataParser, TodUserSimulatorTeacher)`.
     """
@@ -374,7 +374,7 @@ def act(self):
         self.already_reset = False
         if tod.STANDARD_API_SCHEMAS in self.observation.get("text", ""):
             return {
-                "text": tod.STANDARD_API_SCHEMAS,  # Default convention for the first turn for NO SCHEMA models, which is fine for evaluation
+                "text": tod.STANDARD_API_SCHEMAS,  # Default convention for the first turn
                 "id": self.id,
                 "domain": self.episode.domain,
                 "episode_done": False,
Binary file not shown.
313 changes: 313 additions & 0 deletions parlai/core/tod/tod_world.py
@@ -0,0 +1,313 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
Class for running task-oriented dialogue chats.

Specifically this class handles:
1. Setting up and running the different conversational agents to fit the TOD Conversational structure (see `tod_core.py`; also reiterated in TodWorld below)
2. Handling the batching logic for the above
3. Recording various metrics associated with running the world.

See long comment on TodWorld for description of the conversation format and more functionality descriptions.

Metrics calculated from these simulations are documented in `world_metrics.py` (for
general usage) and `world_metric_handlers.py` (for specific metric calculations)
"""
from parlai.core.metrics import Metric, LegacyMetric
from parlai.core.message import Message
from parlai.core.opt import Opt
from parlai.core.worlds import World
from parlai.agents.local_human.local_human import LocalHumanAgent
from parlai.utils.misc import display_messages

import parlai.core.tod.tod_core as tod
import parlai.core.tod.world_metrics as tod_metrics

import sys
import copy

# Following needs to be kept consistent with opt settings/tod script
USER_UTT_IDX = 0
API_CALL_IDX = 1
API_RESP_IDX = 2
SYSTEM_UTT_IDX = 3
API_SCHEMA_GROUNDING_IDX = 4
GOAL_GROUNDING_IDX = 5
AGENT_COUNT = 6

SPEAKER_TO_NAME = {
    USER_UTT_IDX: tod.TodAgentType.USER_UTT_AGENT,
    API_CALL_IDX: tod.TodAgentType.API_CALL_AGENT,
    API_RESP_IDX: tod.TodAgentType.API_RESP_AGENT,
    SYSTEM_UTT_IDX: tod.TodAgentType.SYSTEM_UTT_AGENT,
    API_SCHEMA_GROUNDING_IDX: tod.TodAgentType.API_SCHEMA_GROUNDING_AGENT,
    GOAL_GROUNDING_IDX: tod.TodAgentType.GOAL_GROUNDING_AGENT,
}

NAME_TO_IDX = {v: k for k, v in SPEAKER_TO_NAME.items()}


class TodWorld(World):
    """
    Base world for running TOD model-model chats, with the following agents:

    * User utt agent
    * API call agent
        * Currently assumed to be same as system utt agent in script code, though used as if separate in this world.
    * API responder agent
    * System utt agent
    * API schema grounding agent (given to API call + response agents)
    * Goal grounding agent (given to user)

    As is standard for ParlAI, these agents may be models or may be standalone classes that extend the "Agent" class. The models for these *are* expected to have their utterances in a standard format.

    Note that we expect these to be passed in via the opt manually, since some assumptions of regular ParlAI Worlds (ex. task = agent[0], model = agent[1]) are broken here: there is no "task agent", and one agent can fill two "roles" (ex. system agent also making API calls).
    """

    def __init__(self, opt: Opt, agents=None, shared=None):
        super().__init__(opt, agents, shared)
        self.batchsize = opt["batchsize"]
        self.batch_agents = []
        self.batch_acts = []
        self.batch_goals = []  # for case when num_episodes < batchsize
        self.batch_tod_world_metrics = []
        for i in range(self.batchsize):
            here_agents = []
            for j, agent in enumerate(agents):
                if (
                    j == SYSTEM_UTT_IDX
                ):  # handle separately cause we expect it to be same as API_CALL agent
                    here_agents.append(here_agents[API_CALL_IDX])
                    continue
                share = agent.share()
                batch_opt = copy.deepcopy(share["opt"])
                batch_opt["batchindex"] = i
                here_agents.append(share["class"](batch_opt, share))
            self.batch_agents.append(here_agents)
            self.batch_acts.append([Message.padding_example()] * 4)
            self.batch_tod_world_metrics.append(tod_metrics.TodMetrics())
        self.end_episode = [False] * self.batchsize

        self.max_turns = self.opt.get("max_turns", 30)
        self.turns = 0
        self.need_grounding = True

    def grounding(self):
        """
        Preempt the conversation with API schema grounding and goal grounding.

        As a logging hack, we stick the schema grounding in as a user utterance, but
        manually pass the value in to the relevant API call/resp agent, since passing it
        to the API call agent elsewhere is a little awkward. Similarly, we stick the
        goal as a system utterance so that it is captured in logging. However, we do not
        pass it in manually, since getting the user utterance will be the first turn of
        `parley()`.
        """
        self._observe_and_act(
            SYSTEM_UTT_IDX,  # Doesn't matter, empty at this point
            USER_UTT_IDX,  # Hack in to a place that'll look nice when printing
            f"getting API schema grounding. (Must start with `{tod.STANDARD_API_SCHEMAS}`)",
            API_SCHEMA_GROUNDING_IDX,
        )

        self._observe_and_act(
            USER_UTT_IDX,
            API_CALL_IDX,
            "responding to api schema grounding (empty enter is usually fine)",
        )
        self._observe_and_act(
            USER_UTT_IDX,
            API_RESP_IDX,
            "responding to api schema grounding (empty enter is usually fine)",
        )

        self._observe_and_act(
            SYSTEM_UTT_IDX,  # Doesn't matter for the most part, but want something empty
            SYSTEM_UTT_IDX,  # Hack into a place per comment above
            f"getting goal grounding. (Must start with `{tod.STANDARD_GOAL}`)",
            GOAL_GROUNDING_IDX,
        )
        self.batch_goals = [act[SYSTEM_UTT_IDX] for act in self.batch_acts]
        self.turns = 0

    def parley(self):
        if self.need_grounding:
            self.grounding()
            self.need_grounding = False

        else:
            self._observe_and_act(SYSTEM_UTT_IDX, USER_UTT_IDX)
            self._observe_and_act(USER_UTT_IDX, API_CALL_IDX)
            self._observe_and_act(API_CALL_IDX, API_RESP_IDX)
            self._observe_and_act(API_RESP_IDX, SYSTEM_UTT_IDX)

        self.turns += 1
        self.update_counters()

    def _observe_and_act(
        self, observe_idx, act_idx, info="for regular parley", override_act_idx=None
    ):
        # Compare against None so that an override index of 0 would still be honored.
        act_agent_idx = override_act_idx if override_act_idx is not None else act_idx
        act_agent = self.agents[act_agent_idx]
        record_output_idx = act_idx
        if hasattr(act_agent, "batch_act"):
            batch_observations = []
            for i in range(self.batchsize):
                if not self.end_episode[i]:
                    observe = self.batch_acts[i][observe_idx]
                    observe = self.batch_agents[i][act_agent_idx].observe(observe)
                    batch_observations.append(Message(observe))
                else:
                    # We're done with this episode, so just do a pad.
                    # NOTE: This could cause issues with RL down the line
                    batch_observations.append(Message.padding_example())
                    self.batch_acts[i][record_output_idx] = {"text": "", "id": ""}
            batch_actions = act_agent.batch_act(batch_observations)
            for i in range(self.batchsize):
                if self.end_episode[i]:
                    continue
                self.batch_acts[i][record_output_idx] = batch_actions[i]
                self.batch_agents[i][record_output_idx].self_observe(batch_actions[i])
        else:  # Run on agents individually
            for i in range(self.batchsize):
                act_agent = self.batch_agents[i][act_agent_idx]
                if hasattr(act_agent, "episode_done") and act_agent.episode_done():
                    self.end_episode[i] = True
                if self.end_episode[i]:
                    # Following line exists because:
                    # 1. Code for writing conversations is not happy if an "id" does not exist with a sample
                    # 2. Because of the `self.end_episode` code, no agent will see this example anyway.
                    self.batch_acts[i][record_output_idx] = {"text": "", "id": ""}
                    continue
                act_agent.observe(self.batch_acts[i][observe_idx])
                if isinstance(act_agent, LocalHumanAgent):
                    print(
                        f"Getting message for {SPEAKER_TO_NAME[record_output_idx]} for {info} in batch {i}"
                    )
                try:
                    self.batch_acts[i][record_output_idx] = act_agent.act()
                except StopIteration:
                    self.end_episode[i] = True
        for i in range(self.batchsize):
            if self.end_episode[i]:
                continue
            self.batch_tod_world_metrics[i].handle_message(
                self.batch_acts[i][record_output_idx], SPEAKER_TO_NAME[act_agent_idx]
            )
            if tod.STANDARD_DONE in self.batch_acts[i][record_output_idx].get(
                "text", ""
            ):
                # User models trained to output a "DONE" on last turn; same with human agents.
                self.end_episode[i] = True

    def report(self):
        """
        Report all metrics of all subagents + of this world in aggregate.
        """
        metrics_separate = []
        for i in range(self.batchsize):
            here_metrics = self.batch_tod_world_metrics[i].report()
            for name, agent in [
                (SPEAKER_TO_NAME[j], self.batch_agents[i][j])
                for j in [USER_UTT_IDX, API_CALL_IDX, API_RESP_IDX, SYSTEM_UTT_IDX]
            ]:
                name_prefix = name[:-6]  # strip "_agent"
                if hasattr(agent, "report"):
                    m = agent.report()
                    if m is None:
                        continue
                    for k, v in m.items():
                        if not isinstance(v, Metric):
                            v = LegacyMetric(v)
                        here_metrics[f"{name_prefix}_{k}"] = v
            metrics_separate.append(here_metrics)
        metrics = metrics_separate[0]
        for i in range(1, self.batchsize):
            for k, v in metrics_separate[i].items():
                if k not in metrics:
                    metrics[k] = v
                else:
                    metrics[k] = metrics[k] + v
        return metrics

    def reset(self):
        """
        Resets state of world; also sets up episode metrics.
        """
        super().reset()
        self.need_grounding = True
        self.turns = 0

        self.last_batch_episode_metrics = []
        self.batch_acts = []
        for i in range(self.batchsize):
            for agent in self.batch_agents[i]:
                agent.reset()
            self.batch_acts.append([None] * 4)

            self.batch_tod_world_metrics[i].episode_reset()
            metrics = self.batch_tod_world_metrics[i].get_last_episode_metrics()
            if metrics:
                self.last_batch_episode_metrics.append(metrics)
        self.end_episode = [False] * self.batchsize

    def get_last_batch_episode_metrics(self):
        return self.last_batch_episode_metrics

    def get_last_batch_goals(self):
        return self.batch_goals

    def episode_done(self):
        if self.turns >= self.max_turns or all(self.end_episode):
            return True
        for i in range(self.batchsize):
            for j in [USER_UTT_IDX, API_CALL_IDX, API_RESP_IDX, SYSTEM_UTT_IDX]:
                if (
                    self.batch_acts[i][j] is not None
                    and tod.STANDARD_DONE in self.batch_acts[i][j].get("text", "")
                ) or (
                    hasattr(self.batch_agents[i][j], "episode_done")
                    and self.batch_agents[i][j].episode_done()
                ):
                    self.end_episode[i] = True
        return all(self.end_episode)

    def epoch_done(self):
        # Return an explicit bool; the loop form fell through to an implicit None.
        return any(agent.epoch_done() for agent in self.agents)

    def num_episodes(self):
        result = sys.maxsize
        for agent in self.agents:
            if hasattr(agent, "num_episodes") and agent.num_episodes() > 0:
                result = min(result, agent.num_episodes())
        if result == sys.maxsize:
            return 0
        return result

    def get_batch_acts(self):
        return self.batch_acts

    def display(self):
        s = "[--batchsize " + str(self.batchsize) + "--]\n"
        for i in range(self.batchsize):
            s += "[batch " + str(i) + ":]\n"
            s += display_messages(
                self.batch_acts[i],
                ignore_agent_reply=self.opt.get("ignore_agent_reply", False),
                add_fields=self.opt.get("display_add_fields", ""),
                prettify=self.opt.get("display_prettify", False),
                max_len=self.opt.get("max_display_len", 1000),
                verbose=self.opt.get("verbose", False),
            )
            s += "\n"
        s += "[--end of batch--]\n"
        return s
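
Reviewer-style note on the turn order in `parley()`: each round reads one slot of the per-conversation act list and writes another, in the order user, API call, API response, system. The sketch below replays that order for a single conversation with plain callables instead of ParlAI agents. It is a simplified illustration, not the world itself: it has no batching, grounding, or metrics, `run_episode` is a hypothetical helper, and `DONE` is a placeholder for whatever `tod.STANDARD_DONE` holds.

```python
from typing import Callable, List, Tuple

USER_UTT_IDX, API_CALL_IDX, API_RESP_IDX, SYSTEM_UTT_IDX = range(4)
DONE = "[DONE]"  # placeholder; the real token lives in tod.STANDARD_DONE

# (slot to observe, slot that acts), in the same order as parley().
ROUND = [
    (SYSTEM_UTT_IDX, USER_UTT_IDX),
    (USER_UTT_IDX, API_CALL_IDX),
    (API_CALL_IDX, API_RESP_IDX),
    (API_RESP_IDX, SYSTEM_UTT_IDX),
]


def run_episode(
    agents: List[Callable[[str], str]], max_turns: int = 30
) -> List[Tuple[int, str]]:
    """Run one conversation; each agent maps observed text to reply text."""
    acts = [{"text": "", "id": ""} for _ in range(4)]
    log: List[Tuple[int, str]] = []
    done = False
    turns = 0
    while not done and turns < max_turns:
        for observe_idx, act_idx in ROUND:
            if done:
                break  # nobody acts once the episode has ended
            reply = agents[act_idx](acts[observe_idx]["text"])
            acts[act_idx] = {"text": reply, "id": str(act_idx)}
            log.append((act_idx, reply))
            if DONE in reply:
                done = True
        turns += 1
    return log


# Scripted stand-ins for the four roles.
user_script = iter(["find a hotel", f"thanks {DONE}"])
agents = [
    lambda _text: next(user_script),
    lambda _text: "APICALL: search hotels",
    lambda _text: "APIRESP: 3 results",
    lambda _text: "I found 3 hotels",
]
log = run_episode(agents)
```

Because the user slot observes the system slot, the first user turn sees an empty message here; in the world itself that slot holds the goal grounding written by `grounding()`.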