diff --git a/parlai/core/tod/README.md b/parlai/core/tod/README.md index 82a94bf2fbb..f29e5d2778a 100644 --- a/parlai/core/tod/README.md +++ b/parlai/core/tod/README.md @@ -12,6 +12,8 @@ As a convention, files referenced externally to this directory are prefixed with tl;dr Extend `TodStructuredDataParser` for your particular dataset and implement `setup_episodes()` that converts the dataset into a list of episodes (`List[TodStructuredEpisode]`). Use multiple inheritence to generate teachers for training models. See files like `parlai/tasks/multiwoz_v22/agents.py` for an example. +See `tod_agents.py` for the classes. + ## Overview of usage For a given dataset, extend `TodStructuredDataParser` and implement `setup_episodes()` and `get_id_task_prefix()`. The former of these is expected to do the data processing to convert a dataset to `List[TodStructuredEpisode]`. From here, multiple inheritance can be used to define Agents and Teachers that utilize the data. @@ -80,6 +82,3 @@ The world itself is stored in `tod_world.py`. The world follows the same interme A general class for collecting metrics out of `TODWorld` is stored within `world_metrics.py` with individual 'metric handlers' responsible for calculating a given metric stored in `world_metric_handlers.py`. - - - diff --git a/parlai/core/tod/tod_agents.py b/parlai/core/tod/tod_agents.py index 559899e91ff..a50b1adfdf9 100644 --- a/parlai/core/tod/tod_agents.py +++ b/parlai/core/tod/tod_agents.py @@ -38,7 +38,7 @@ class TodStructuredDataParser(Agent): """ Base class that specifies intermediate representations for Tod conversations. - Inherit from this class and implement `generate_episodes()` to implement the intermediate representation for a specific dataset. Use multiple inheritence with classes that implement an `act()` below to use. + Inherit from this class and implement `setup_episodes()` to implement the intermediate representation for a specific dataset. Use multiple inheritence with classes that implement an `act()` below to use. For example, if we have a `MyDataset_DataParser(TodStructuredDataParser)` and wanted to make a teacher to train a model togenerate User Utterances based on a goal prompt, we would do so by defining `class MyDatasetUserSimulatorTeacher(MyDataset_DataParser, TodUserSimulatorTeacher)`. """ @@ -374,7 +374,7 @@ def act(self): self.already_reset = False if tod.STANDARD_API_SCHEMAS in self.observation.get("text", ""): return { - "text": tod.STANDARD_API_SCHEMAS, # Default convention for the first turn for NO SCHEMA models, which is fine for evaluation + "text": tod.STANDARD_API_SCHEMAS, # Default convention for the first turn "id": self.id, "domain": self.episode.domain, "episode_done": False, diff --git a/parlai/core/tod/tod_test_utils/standalone_api_file.pickle b/parlai/core/tod/tod_test_utils/standalone_api_file.pickle new file mode 100644 index 00000000000..b27e3b29741 Binary files /dev/null and b/parlai/core/tod/tod_test_utils/standalone_api_file.pickle differ diff --git a/parlai/core/tod/tod_world.py b/parlai/core/tod/tod_world.py new file mode 100644 index 00000000000..32ff7072889 --- /dev/null +++ b/parlai/core/tod/tod_world.py @@ -0,0 +1,313 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +""" +Class for running task-oriented dialogue chats. + +Specifically this class handles: +1. Setting up and running the different conversational agents to fit the TOD Conversational structure (see `tod_core.py`; also retiterated in TodWorld below) +2. Handle logic for batching when running dialogues +3. Recording various metrics associated with running the world. + +See long comment on TodWorld for description of the conversation format and more functionality descriptions. + +Metrics calculated from these simulations are documented in `world_metrics.py` (for +general usage) and `world_metrics_handlers.py` (for specific metric calculations) +""" +from parlai.core.metrics import Metric, LegacyMetric +from parlai.core.message import Message +from parlai.core.opt import Opt +from parlai.core.worlds import World +from parlai.agents.local_human.local_human import LocalHumanAgent +from parlai.utils.misc import display_messages + +import parlai.core.tod.tod_core as tod +import parlai.core.tod.world_metrics as tod_metrics + +import sys +import copy + +# Following needs to be kept consistent with opt settings/tod script +USER_UTT_IDX = 0 +API_CALL_IDX = 1 +API_RESP_IDX = 2 +SYSTEM_UTT_IDX = 3 +API_SCHEMA_GROUNDING_IDX = 4 +GOAL_GROUNDING_IDX = 5 +AGENT_COUNT = 6 + +SPEAKER_TO_NAME = { + USER_UTT_IDX: tod.TodAgentType.USER_UTT_AGENT, + API_CALL_IDX: tod.TodAgentType.API_CALL_AGENT, + API_RESP_IDX: tod.TodAgentType.API_RESP_AGENT, + SYSTEM_UTT_IDX: tod.TodAgentType.SYSTEM_UTT_AGENT, + API_SCHEMA_GROUNDING_IDX: tod.TodAgentType.API_SCHEMA_GROUNDING_AGENT, + GOAL_GROUNDING_IDX: tod.TodAgentType.GOAL_GROUNDING_AGENT, +} + +NAME_TO_IDX = {v: k for k, v in SPEAKER_TO_NAME.items()} + + +class TodWorld(World): + """ + Base world for running TOD model-model chats. Includes the following agents: + + * User utt agent + * API call agent + * Currently assumed to be same as system utt agent in script code, though used as if separate in this world for clarity + * API responder agent + * System utt agent + * API schema groundinger agent (given to api call + response agent) + * Goal groundinger agent (given to user) + + As is standard for ParlAI, these agents may be models or may be standalone classes that extend the "Agent" class. The models for these *are* expected to have their utterances in a standard format. + + We do expect these agents to be passed in with a set order (see above), since some assumptions of regular ParlAI Worlds (ex. task = agent[0], model = agent[1]) are broken here since there is no "task agent" and one agent can be two "roles" (ex. system agent also making API calls) + """ + + def __init__(self, opt: Opt, agents=None, shared=None): + super().__init__(opt, agents, shared) + self.batchsize = opt["batchsize"] + self.batch_agents = [] + self.batch_acts = [] + self.batch_goals = [] # for case when num_episodes < batchsize + self.batch_tod_world_metrics = [] + for i in range(self.batchsize): + here_agents = [] + for j, agent in enumerate(agents): + if ( + j == SYSTEM_UTT_IDX + ): # handle separately cause we expect it to be same as API_CALL agent + here_agents.append(here_agents[API_CALL_IDX]) + continue + share = agent.share() + batch_opt = copy.deepcopy(share["opt"]) + batch_opt["batchindex"] = i + here_agents.append(share["class"](batch_opt, share)) + self.batch_agents.append(here_agents) + self.batch_acts.append([Message.padding_example()] * 4) + self.batch_tod_world_metrics.append(tod_metrics.TodMetrics()) + self.end_episode = [False] * self.batchsize + + self.max_turns = self.opt.get("max_turns", 30) + self.turns = 0 + self.need_grounding = True + + def grounding(self): + """ + Preempt with goal and schema-based intent schemas. + + As a logging hack, we stick the schema gronding in as a user utterance, but + manually pass the value in to the relevant API call/resp agent, since passing it + to the API call agent elsewhere is a little awkward. Similarly, we stick the + goal as a system utterance so that it is captured in logging. However, we do not + pass it in manually, since getting the user utterance will be the first turn of + `parley()`. + """ + self._observe_and_act( + SYSTEM_UTT_IDX, # Doesn't matter, empty at this point + USER_UTT_IDX, # Hack in to a place that'll look nice when printing + f"getting API schema grounding. (Must start with `{tod.STANDARD_API_SCHEMAS}`)", + API_SCHEMA_GROUNDING_IDX, + ) + + self._observe_and_act( + USER_UTT_IDX, + API_CALL_IDX, + "responding to api schema grounding (empty enter is usually fine) ", + ) + self._observe_and_act( + USER_UTT_IDX, + API_RESP_IDX, + "responding to api schema grounding (empty enter is usually fine)", + ) + + self._observe_and_act( + SYSTEM_UTT_IDX, # Doesn't matter for the most part, but want something empty + SYSTEM_UTT_IDX, # Hack into a place per comment above + f"getting goal grounding. (Must start with `{tod.STANDARD_GOAL}`)", + GOAL_GROUNDING_IDX, + ) + self.batch_goals = [act[SYSTEM_UTT_IDX] for act in self.batch_acts] + self.turns = 0 + + def parley(self): + if self.need_grounding: + self.grounding() + self.need_grounding = False + + else: + self._observe_and_act(SYSTEM_UTT_IDX, USER_UTT_IDX) + self._observe_and_act(USER_UTT_IDX, API_CALL_IDX) + self._observe_and_act(API_CALL_IDX, API_RESP_IDX) + self._observe_and_act(API_RESP_IDX, SYSTEM_UTT_IDX) + + self.turns += 1 + self.update_counters() + + def _observe_and_act( + self, observe_idx, act_idx, info="for regular parley", override_act_idx=None + ): + act_agent_idx = override_act_idx if override_act_idx else act_idx + act_agent = self.agents[act_agent_idx] + record_output_idx = act_idx + if hasattr(act_agent, "batch_act"): + batch_observations = [] + for i in range(self.batchsize): + if not self.end_episode[i]: + observe = self.batch_acts[i][observe_idx] + observe = self.batch_agents[i][act_agent_idx].observe(observe) + batch_observations.append(Message(observe)) + else: + # We're done with this episode, so just do a pad. + # NOTE: This could cause issues with RL down the line + batch_observations.append(Message.padding_example()) + self.batch_acts[i][record_output_idx] = {"text": "", "id": ""} + batch_actions = act_agent.batch_act(batch_observations) + for i in range(self.batchsize): + if self.end_episode[i]: + continue + self.batch_acts[i][record_output_idx] = batch_actions[i] + self.batch_agents[i][record_output_idx].self_observe(batch_actions[i]) + else: # Run on agents individually + for i in range(self.batchsize): + act_agent = ( + self.batch_agents[i][override_act_idx] + if override_act_idx + else self.batch_agents[i][act_idx] + ) + if hasattr(act_agent, "episode_done") and act_agent.episode_done(): + self.end_episode[i] = True + if self.end_episode[i]: + # Following line exists because: + # 1. Code for writing converseations is not hapy if an "id" does not exists with a sample + # 2. Because of the `self.end_episode` code, no agent will see this example anyway. + self.batch_acts[i][record_output_idx] = {"text": "", "id": ""} + continue + act_agent.observe(self.batch_acts[i][observe_idx]) + if isinstance(act_agent, LocalHumanAgent): + print( + f"Getting message for {SPEAKER_TO_NAME[record_output_idx]} for {info} in batch {i}" + ) + try: + self.batch_acts[i][record_output_idx] = act_agent.act() + except StopIteration: + self.end_episode[i] = True + for i in range(self.batchsize): + if self.end_episode[i]: + continue + self.batch_tod_world_metrics[i].handle_message( + self.batch_acts[i][record_output_idx], SPEAKER_TO_NAME[act_agent_idx] + ) + if tod.STANDARD_DONE in self.batch_acts[i][record_output_idx].get( + "text", "" + ): + # User models trained to output a "DONE" on last turn; same with human agents. + self.end_episode[i] = True + + def report(self): + """ + Report all metrics of all subagents + of this world in aggregate. + """ + metrics_separate = [] + for i in range(self.batchsize): + here_metrics = self.batch_tod_world_metrics[i].report() + for name, agent in [ + (SPEAKER_TO_NAME[j], self.batch_agents[i][j]) + for j in [USER_UTT_IDX, API_CALL_IDX, API_RESP_IDX, SYSTEM_UTT_IDX] + ]: + name_prefix = name[:-6] # strip "_agent" + if hasattr(agent, "report"): + m = agent.report() + if m is None: + continue + for k, v in m.items(): + if not isinstance(v, Metric): + v = LegacyMetric(v) + here_metrics[f"{name_prefix}_{k}"] = v + metrics_separate.append(here_metrics) + metrics = metrics_separate[0] + for i in range(1, self.batchsize): + for k, v in metrics_separate[i].items(): + if k not in metrics: + metrics[k] = v + else: + metrics[k] = metrics[k] + v + return metrics + + def reset(self): + """ + Resets state of world; also sets up episode metrics. + """ + super().reset() + self.need_grounding = True + self.turns = 0 + + self.last_batch_episode_metrics = [] + self.batch_acts = [] + for i in range(self.batchsize): + for agent in self.batch_agents[i]: + agent.reset() + self.batch_acts.append([None] * 4) + + self.batch_tod_world_metrics[i].episode_reset() + metrics = self.batch_tod_world_metrics[i].get_last_episode_metrics() + if metrics: + self.last_batch_episode_metrics.append(metrics) + self.end_episode = [False] * self.batchsize + + def get_last_batch_episode_metrics(self): + return self.last_batch_episode_metrics + + def get_last_batch_goals(self): + return self.batch_goals + + def episode_done(self): + if self.turns >= self.max_turns or all(self.end_episode): + return True + for i in range(self.batchsize): + for j in [USER_UTT_IDX, API_CALL_IDX, API_RESP_IDX, SYSTEM_UTT_IDX]: + if ( + self.batch_acts[i][j] is not None + and tod.STANDARD_DONE in self.batch_acts[i][j].get("text", "") + ) or ( + hasattr(self.batch_agents[i][j], "episode_done") + and self.batch_agents[i][j].episode_done() + ): + self.end_episode[i] = True + return all(self.end_episode) + + def epoch_done(self): + for agent in self.agents: + if agent.epoch_done(): + return True + + def num_episodes(self): + result = sys.maxsize + for agent in self.agents: + if hasattr(agent, "num_episodes") and agent.num_episodes() > 0: + result = min(result, agent.num_episodes()) + if result == sys.maxsize: + return 0 + return result + + def get_batch_acts(self): + return self.batch_acts + + def display(self): + s = "[--batchsize " + str(self.batchsize) + "--]\n" + for i in range(self.batchsize): + s += "[batch " + str(i) + ":]\n" + s += display_messages( + self.batch_acts[i], + ignore_agent_reply=self.opt.get("ignore_agent_reply", False), + add_fields=self.opt.get("display_add_fields", ""), + prettify=self.opt.get("display_prettify", False), + max_len=self.opt.get("max_display_len", 1000), + verbose=self.opt.get("verbose", False), + ) + s += "\n" + s += "[--end of batch--]\n" + return s diff --git a/parlai/core/tod/world_metrics.py b/parlai/core/tod/world_metrics.py new file mode 100644 index 00000000000..17d96b8698f --- /dev/null +++ b/parlai/core/tod/world_metrics.py @@ -0,0 +1,113 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +""" +Wrapper object holding metrics for TODWorld. + +This class is in its own file to prevent circular dependencies + monolithic files. +""" + +from parlai.core.message import Message +from parlai.core.metrics import Metrics +from parlai.core.tod.tod_core import ( + TodAgentType, + TOD_AGENT_TYPE_TO_PREFIX, + SerializationHelpers, + STANDARD_GOAL, +) +from typing import Any, Dict +import parlai.core.tod.world_metrics_handlers as world_metrics_handlers + +# Change the following to define which Metrics Handlers are used in TodWorld. +# The ones used below are from `world_metrics_handlers.py` only. However, See `parlai/projects/tod_simulator/world_metrics/extended_world_metrics.py` for others. + +WORLD_METRIC_HANDLERS = [ + world_metrics_handlers.AllGoalApiCallSuccessMetricsHandler, + world_metrics_handlers.UserGeneratedDoneMetricHandler, +] + + +class TodMetrics(Metrics): + """ + Helper container which encapsulates TOD metrics and does some basic prepocessing to + handlers to calculate said metrics. + + This class should generally not need to be changed; add new metrics handlers to + `WORLD_METRIC_HANDLERS` (or otherwise override `self.handlers` of this class) to + change metrics actively being used. + """ + + def __init__(self, shared: Dict[str, Any] = None) -> None: + super().__init__(shared=shared) + self.handlers = [x() for x in WORLD_METRIC_HANDLERS] + self.convo_started = False + self.last_episode_metrics = Metrics() + + def handle_message(self, message: Message, agent_type: TodAgentType): + if "text" not in message: + return + if agent_type == TodAgentType.GOAL_GROUNDING_AGENT and len( + message["text"] + ) > len(STANDARD_GOAL): + # Only count a conversation as started if there is a goal. + self.convo_started = True + for handler in self.handlers: + metrics = self._handle_message_impl(message, agent_type, handler) + if metrics is not None: + for name, metric in metrics.items(): + if metric is not None: + self.add(name, metric) + + def _handle_message_impl( + self, + message: Message, + agent_type: TodAgentType, + handler: world_metrics_handlers.TodMetricsHandler, + ): + prefix_stripped_text = message["text"].replace( + TOD_AGENT_TYPE_TO_PREFIX[agent_type], "" + ) + if agent_type is TodAgentType.API_SCHEMA_GROUNDING_AGENT: + return handler.handle_api_schemas( + message, SerializationHelpers.str_to_api_schemas(prefix_stripped_text) + ) + if agent_type is TodAgentType.GOAL_GROUNDING_AGENT: + return handler.handle_goals( + message, SerializationHelpers.str_to_goals(prefix_stripped_text) + ) + if agent_type is TodAgentType.USER_UTT_AGENT: + return handler.handle_user_utt(message, prefix_stripped_text) + if agent_type is TodAgentType.API_CALL_AGENT: + return handler.handle_api_call( + message, SerializationHelpers.str_to_api_dict(prefix_stripped_text) + ) + if agent_type is TodAgentType.API_RESP_AGENT: + return handler.handle_api_resp( + message, SerializationHelpers.str_to_api_dict(prefix_stripped_text) + ) + if agent_type is TodAgentType.SYSTEM_UTT_AGENT: + return handler.handle_sys_utt(message, prefix_stripped_text) + + def get_last_episode_metrics(self): + """ + This is a bit of a hack so that we can report whether or not a convo has + successfully hit all goals and associate this with each episode for the purposes + of doing filtering. + """ + return self.last_episode_metrics + + def episode_reset(self): + self.last_episode_metrics = None + if self.convo_started: + self.last_episode_metrics = Metrics() + for handler in self.handlers: + metrics = handler.get_episode_metrics() + handler.episode_reset() + if metrics is not None: + for name, metric in metrics.items(): + if metric is not None: + self.add(name, metric) + self.last_episode_metrics.add(name, metric) + self.convo_started = False diff --git a/parlai/core/tod/world_metrics_handlers.py b/parlai/core/tod/world_metrics_handlers.py new file mode 100644 index 00000000000..ffa7343f114 --- /dev/null +++ b/parlai/core/tod/world_metrics_handlers.py @@ -0,0 +1,176 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +""" +Metrics handlers - ie, objects that handle generations from Tod World and calculates metrics from them. + +Note that only metrics handler classes in `WORLD_METRIC_HANDLERS` (of `world_metrics.py`) are actively being recorded as metrics. +""" + +from parlai.core.message import Message +from parlai.core.metrics import Metric, AverageMetric +from parlai.core.tod.tod_core import STANDARD_DONE +from typing import Dict, List, Optional + +METRICS_HANDLER_CLASSES_TEST_REGISTRY = set() # for tests + + +def register_metrics_handler(cls): + METRICS_HANDLER_CLASSES_TEST_REGISTRY.add(cls) + return cls + + +class TodMetricsHandler: + """ + Base class for Tod Metrics handlers. Extend this class then add them to + `WORLD_METRIC_HANDLERS` to use. If you would like the class to be exposed to tests, + add the Metrics Handler to `METRICS_HANDLER_CLASSES_TEST_REGISTRY` via annotating + with `@register_metrics_handler`. + + The `TodMetrics` class will, on this class + 1. call `__init__` (which internally calls `episode_reset()`) to begin with. + 2. call each of the `handle..()` functions as the appropriate turns occur + 3. call `get_episode_metrics()` then `episode_reset()` at the end of the episode + + The `handle..()` should be used to set intermediate state within the class and `episode_reset()` should be used to clear this state. + + The output of the `handle..()` and `get_episode_metrics()` functions are both `Optional[Dict[str, Metric]]`s. Metrics from both of these paths will be aggregated and reported to `TodMetrics`, so which one to use is mostly a matter of preference, though + 1. one should take care to only use one or the other and not both, to avoid double-counting + 2. those from `get_episode_metrics()` will be recorded per-episode and saved to `tod_world_script`'s report as well + + `UserGeneratedDoneMetricHandler` in this file, which collects metrics about frequency of seeing the "[DONE]" token on User utterances and also records conversation length, is a fairly straightforward example of usage. + + Other tried (but not in current active use) Metrics Handers are in `projects/tod_simulator/world_metrics/extended_world_metrics.py`. + """ + + def __init__(self): + self.episode_reset() + + def episode_reset(self): + pass + + def handle_api_schemas( + self, message: Message, api_schemas: List[Dict] + ) -> Optional[Dict[str, Metric]]: + self.api_schemas = api_schemas + + def handle_goals( + self, message: Message, goals: List[Dict] + ) -> Optional[Dict[str, Metric]]: + self.goals = goals + + def handle_user_utt( + self, message: Message, prefix_stripped_text: str + ) -> Optional[Dict[str, Metric]]: + pass + + def handle_api_call( + self, message: Message, api_call: Dict + ) -> Optional[Dict[str, Metric]]: + pass + + def handle_api_resp( + self, message: Message, api_resp: Dict + ) -> Optional[Dict[str, Metric]]: + pass + + def handle_sys_utt( + self, message: Message, prefix_stripped_text: str + ) -> Optional[Dict[str, Metric]]: + pass + + def get_episode_metrics(self) -> Optional[Dict[str, Metric]]: + pass + + +################################ +# Functions and classes associated with calculating statistics between API Calls and Goals. +def goals_hit_helper( + goals: List[Dict], turnDict: List[Dict], permissive=False +) -> (AverageMetric, AverageMetric, AverageMetric): + """ + Helper function that aids in seeing if the API calls the system has attempted to + make manages to meet the goals the conversation has. + + Return values: + * if all goals hit + * # of turns it took to hit all goals (or None) + * fraction of goals hit + """ + goals_left = goals + + def exact_match(goal, turn): # if and only if + return goal == turn + + def permissive_match(goal, turn): # guess is superset + for key in goal: + if turn.get(key, "definitelyNotIn") != goal[key]: + return False + return True + + compare_func = permissive_match if permissive else exact_match + + for i, turn in enumerate(turnDict): + goals_left = [goal for goal in goals_left if not compare_func(goal, turn)] + if len(goals_left) == 0: + return AverageMetric(True), AverageMetric(i + 1), AverageMetric(1) + return ( + AverageMetric(False), + AverageMetric(0), + AverageMetric(len(goals) - len(goals_left), len(goals)), + ) + + +class _ApiCallGoalInteractionHelper(TodMetricsHandler): + """ + Base class for storing details about valid API calls (information about Goals + handled in TodMetricsHandler) + """ + + def episode_reset(self): + self.api_turns = [] + + def handle_api_call( + self, message: Message, api_call: Dict + ) -> Optional[Dict[str, Metric]]: + if len(api_call) > 0: + self.api_turns.append(api_call) + + +@register_metrics_handler +class AllGoalApiCallSuccessMetricsHandler(_ApiCallGoalInteractionHelper): + """ + Calculates synthetic Task Success + related metrics for converseations. + + Test coverage of this class is with `LegacyGoalApiCallInteractionsMetricsHandler` + """ + + def get_episode_metrics(self) -> Optional[Dict[str, Metric]]: + all_goals_hit, _, _ = goals_hit_helper(self.goals, self.api_turns) + call_attempts = len(self.api_turns) + return { + "synthetic_task_success": all_goals_hit, + "api_call_attempts": AverageMetric(call_attempts), + } + + +@register_metrics_handler +class UserGeneratedDoneMetricHandler(TodMetricsHandler): + def episode_reset(self): + self.done_seen = False + self.turn_count = 0 + + def handle_user_utt( + self, message: Message, prefix_stripped_text: str + ) -> Optional[Dict[str, Metric]]: + self.done_seen |= STANDARD_DONE in message["text"] + self.turn_count += 1 + + def get_episode_metrics(self) -> Optional[Dict[str, Metric]]: + result = {"done_seen": AverageMetric(self.done_seen)} + if self.done_seen: + result["round_count_done_seen"] = AverageMetric(self.turn_count) + result["rounds_count_all_conversations"] = AverageMetric(self.turn_count) + return result diff --git a/parlai/scripts/distributed_tod_world_script.py b/parlai/scripts/distributed_tod_world_script.py new file mode 100644 index 00000000000..8dc36a6feb7 --- /dev/null +++ b/parlai/scripts/distributed_tod_world_script.py @@ -0,0 +1,31 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +""" +Distributed script for running TOD model-model chats. + +Not to be called directly; should be called from SLURM +""" + +from parlai.scripts.tod_world_script import TodWorldScript +from parlai.core.script import ParlaiScript +import parlai.utils.distributed as distributed_utils + + +class DistributedTodWorldScript(ParlaiScript): + @classmethod + def setup_args(cls): + parser = TodWorldScript.setup_args() + parser.add_distributed_training_args() + parser.add_argument("--port", type=int, default=61337, help="TCP port number") + return parser + + def run(self): + with distributed_utils.slurm_distributed_context(self.opt) as opt: + return TodWorldScript(opt).run() + + +if __name__ == "__main__": + DistributedTodWorldScript.main() diff --git a/parlai/scripts/tod_world_script.py b/parlai/scripts/tod_world_script.py new file mode 100644 index 00000000000..edfd23bdafc --- /dev/null +++ b/parlai/scripts/tod_world_script.py @@ -0,0 +1,415 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +""" +Base script for running TOD model-model chats. + +For example, to extract gold ground truth data from the holdout version of Google SGD, run + +``` +python -u -m parlai.scripts.tod_world_script --api-schema-grounding-model parlai.tasks.google_sgd_simulation_splits.agents:OutDomainApiSchemaAgent --goal-grounding-model parlai.tasks.google_sgd_simulation_splits.agents:OutDomainGoalAgent --user-model parlai.tasks.google_sgd_simulation_splits.agents:OutDomainUserUttAgent --system-model parlai.tasks.google_sgd_simulation_splits.agents:OutDomainApiCallAndSysUttAgent --api-resp-model parlai.tasks.google_sgd_simulation_splits.agents:OutDomainApiResponseAgent -dt valid --num-episodes -1 --episodes-randomization-seed 42 --world-logs gold-valid +``` + +This file handles +1. Script param setup, including that used for loading agents which may have their own parameters +2. Running the world (including handling batching, until num episodes or length of epoch has been met). +3. File I/O for both reports (for metrics) and conversation logs + logic for displaying prints +""" + +import json +from copy import deepcopy +from shutil import copyfile +import os + +import parlai.utils.logging as logging +import parlai.core.tod.tod_world as tod_world +import parlai.core.tod.tod_agents as tod_world_agents +from parlai.core.agents import create_agent +from parlai.core.metrics import dict_report, aggregate_unnamed_reports +from parlai.core.opt import Opt +from parlai.core.params import ParlaiParser +from parlai.core.script import ParlaiScript, register_script +from parlai.utils.distributed import ( + is_primary_worker, + all_gather_list, + is_distributed, + get_rank, + sync_object, + num_workers, +) +from parlai.utils.io import PathManager +from parlai.utils.misc import TimeLogger, nice_report +from parlai.utils.world_logging import WorldLogger + + +class TodWorldLogger(WorldLogger): + """ + WorldLogger has most of what we need. + + We could if-class this logic in it directly, but inheritence + override here is + neater. + """ + + def _is_batch_world(self, world): + return True + + def _log_batch(self, world): + batch_acts = world.get_batch_acts() + for i, acts in enumerate(batch_acts): + # filter out for empty + acts = [act for act in acts if act["id"] != "" and act["text"] != ""] + if len(acts) > 0: + self._add_msgs(acts, idx=i) + if world.episode_done(): + self.reset_world(idx=i) + + +class TodWorldParser(ParlaiParser): + def add_extra_args(self, args=None): + super().add_extra_args(args) + parsed = vars(self.parse_known_args(args, nohelp=True)[0]) + # Also load extra args options if a file is given. + if parsed.get("init_opt") is not None: + try: + self._load_known_opts(parsed.get("init_opt"), parsed) + except FileNotFoundError: + # don't die if -o isn't found here. See comment in second call + # later on. + pass + parsed = self._infer_datapath(parsed) + + partial = Opt(parsed) + + for model in [ + "system_model", + "user_model", + "api_schema_grounding_model", + "goal_grounding_model", + "api_resp_model", + ]: + if ( + model in partial + and partial[model] is not None + and len(partial[model]) > 0 + ): + self.add_model_subargs(partial[model], partial) + + for model_file_prefix in ["system", "user"]: + key = model_file_prefix + "_model_file" + if key in partial and partial[key] and len(partial[key]) > 0: + model_name = self._get_model_name_from_model_file(key, partial) + self.add_model_subargs(model_name, partial) + + def _get_model_name_from_model_file(self, key, opt): + """ + Get the model name from either `--model` or `--model-file`. + """ + # try to get model name from model opt file + model_file = opt.get(key, None) + optfile = model_file + ".opt" + new_opt = Opt.load(optfile) + model = new_opt.get("model", None) + return model + + +@register_script("tod_world_script") +class TodWorldScript(ParlaiScript): + @classmethod + def setup_tod_args(cls, parser: ParlaiParser): + tod_args = parser.add_argument_group( + "TOD World Script Agent arguments. NOTE: If there are issues with invoking downstream opts of agents specified here sometimes you will have more luck with `python -u -m parlai.scripts.tod_world_script` than `parlai tod_world_script`." + ) + tod_args.add_argument( + "--system-model-file", + default="", + help="Define the system model for the chat. Exactly one of this or system-model must be specified", + ) + + tod_args.add_argument( + "--system-model", + default="", + help="Define the system agent for the chat. Exactly one of this or system-model-file must be specified", + ) + + tod_args.add_argument( + "--user-model-file", + default="", + help="Define the user model for the chat. Exactly one of this user-model must be specified. Currently assumed to be the API Call creation agent as well.", + ) + + tod_args.add_argument( + "--user-model", + default="", + help="Define the user agent for the chat. Exactly one of this or user-model-file must be specified. Currently assumed to be the API Call creation agent as well.", + ) + + tod_args.add_argument( + "--api-resp-model", + default="", + help="Agent used for defining API response values", + ) + + tod_args.add_argument( + "--api-schema-grounding-model", + default="", + help="Agent used in first turn to grounding api call/response agents with api schemas. Will use EmptyApiSchemaAgent if both this and `--api-schemas` not set.", + ) + + tod_args.add_argument( + "--goal-grounding-model", + default="", + help="Agent used in first turn to grounding user agent with goal. Will use EmptyGoalAgent if not set", + ) + + tod_args.add_argument( + "--api-schemas", + default=None, + help="If set and `--api-schema-grounding-model` is empty, will infer `--api-schema-grounding-model` based on this and a regex on `--goal-grounding-model`. If you run into issues with parsing order of opts using this flag, just switch to `--api-schema-grounding-model`.", + ) + + @classmethod + def setup_args(cls): + # Use default parlai args for logging + the like, but don't need model args since we specify those manually via command-line + parser = TodWorldParser( + True, False, "World for chatting with the TOD conversation structure" + ) + # Following params are same as the `eval_model` script + parser.add_argument( + "--report-filename", + type=str, + help="Saves a json file of the evaluation report either as an " + 'extension to the model-file (if begins with a ".") or a whole ' + "file path. Set to the empty string to not save at all.", + ) + parser.add_argument( + "--world-logs", + type=str, + help="Saves a jsonl file containing all of the task examples and " + "model replies.", + ) + parser.add_argument( + "--save-format", + type=str, + default="conversations", + choices=["conversations", "parlai"], + ) + parser.add_argument( + "--num-episodes", + type=int, + default=10, + help="Number of episodes to display. Set to -1 for infinity or the number of examples of the first agent with a non-unlimited number of episodes in the world.", + ) + parser.add_argument("-d", "--display-examples", type="bool", default=False) + parser.add_argument("-ltim", "--log-every-n-secs", type=float, default=10) + TodWorldLogger.add_cmdline_args(parser) + + # Following are specific to TOD World + parser.add_argument( + "--max-turns", + type=int, + default=30, + help="The max number of full turns before chat ends, excluding prompting", + ) + TodWorldScript.setup_tod_args(parser) + + return parser + + def _get_file_or_model_specifiable_agent(self, prefix, opt): + if len(opt.get(f"{prefix}_model_file", "")) > 0: + if len(opt.get(f"{prefix}_model", "")) > 0: + raise KeyError( + "Both `--{prefix}-model-file` and `--{prefix}-model` specified. Exactly one should be." + ) + model = self._make_agent( + opt, + f"{prefix}_model_file", + requireModelExists=True, + opt_key="model_file", + ) + elif len(opt.get(f"{prefix}_model", "")) > 0: + model = self._make_agent(opt, f"{prefix}_model", "") + else: + raise KeyError( + f"Both `--{prefix}-model-file` and `--{prefix}-model` specified. Neither currently set." + ) + return model + + def _get_model_or_default_agent(self, opt, key, default_class): + if len(opt.get(key, "")) > 0: + return self._make_agent(opt, key) + return default_class(opt) + + def _get_tod_agents(self, opt: Opt): + agents = [None] * tod_world.AGENT_COUNT + + agents[tod_world.USER_UTT_IDX] = self._get_file_or_model_specifiable_agent( + "user", opt + ) + + # Get system agent, nothing that api call agent currently same as system agent + system_model = self._get_file_or_model_specifiable_agent("system", opt) + agents[tod_world.SYSTEM_UTT_IDX] = system_model + agents[tod_world.API_CALL_IDX] = system_model + + agents[tod_world.API_RESP_IDX] = self._make_agent(opt, "api_resp_model") + agents[tod_world.GOAL_GROUNDING_IDX] = self._get_model_or_default_agent( + opt, "goal_grounding_model", tod_world_agents.EmptyGoalAgent + ) + + if "api_schema_grounding_model" not in opt and "api_schemas" in opt: + opt["api_schema_grounding_model"] = opt.get( + "goal_grounding_model", "" + ).replace("Goal", "ApiSchema") + + agents[tod_world.API_SCHEMA_GROUNDING_IDX] = self._get_model_or_default_agent( + opt, "api_schema_grounding_model", tod_world_agents.EmptyApiSchemaAgent + ) + + return agents + + def _make_agent(self, opt_raw, name, requireModelExists=False, opt_key="model"): + """ + Hack. + + `create_agent` expects opt[`model`] to specify the model type and we're + specifying multiple models from other opt arguments (ex. + `system_model`/`user_model` etc), so this swaps it in. + """ + opt = deepcopy(opt_raw) + opt[opt_key] = opt[name] + print(opt_key, name) + return create_agent(opt, requireModelExists) + + def _run_episode(self, opt, world, world_logger): + while not world.episode_done(): + world.parley() + world_logger.log(world) + + if opt["display_examples"]: + logging.info(world.display()) + + if opt["display_examples"]: + logging.info("-- end of episode --") + + world.reset() + world_logger.reset_world() # flush this episode + return zip(world.get_last_batch_goals(), world.get_last_batch_episode_metrics()) + + def _save_outputs(self, opt, world, logger, episode_metrics): + if is_distributed(): # flatten everything intelligently if need be + world_report = aggregate_unnamed_reports(all_gather_list(world.report())) + episode_metrics_unflattened = all_gather_list(episode_metrics) + flattened = [] + for rank_elem in episode_metrics_unflattened: + for elem in rank_elem: + flattened.append(elem) + episode_metrics = flattened + else: + world_report = world.report() + logging.report("Final report:\n" + nice_report(world_report)) + + report = dict_report(world_report) + + def get_episode_report(goal, episode_metric): + metrics_dict = dict_report(episode_metric.report()) + metrics_dict["goal"] = goal + return metrics_dict + + report["tod_metrics"] = [get_episode_report(g, e) for g, e in episode_metrics] + + if "report_filename" in opt and opt["report_filename"] is not None: + if len(world_report) == 0: + logging.warning("Report is empty; not saving report") + + report_fname = f"{opt['report_filename']}.json" + # Save report + if not is_distributed() or is_primary_worker(): + with PathManager.open(report_fname, "w") as f: + logging.info(f"Saving model report to {report_fname}") + json.dump({"opt": opt, "report": report}, f, indent=4) + f.write("\n") # for jq + + if "world_logs" in opt and opt["world_logs"] is not None: + if is_distributed(): # Save separately, then aggregate together + rank = get_rank() + log_outfile_part = ( + f"{opt['world_logs']}_{opt['save_format']}_{rank}.jsonl" + ) + logger.write(log_outfile_part, world, file_format=opt["save_format"]) + sync_object(None) + if is_primary_worker(): + log_outfile = f"{opt['world_logs']}_{opt['save_format']}.jsonl" + log_outfile_metadata = ( + f"{opt['world_logs']}_{opt['save_format']}.metadata" + ) + with open(log_outfile, "w+") as outfile: + for rank in range(num_workers()): + log_outfile_part = ( + f"{opt['world_logs']}_{opt['save_format']}_{rank}.jsonl" + ) + with open(log_outfile_part) as infile: + for line in infile: + json_blob = json.loads(line.strip()) + if ( + len(json_blob["dialog"]) < 2 + ): # skip when we don't have generation + continue + json_blob["metadata_path"] = log_outfile_metadata + outfile.write(json.dumps(json_blob)) + outfile.write("\n") + log_output_part_metadata = f"{opt['world_logs']}_{opt['save_format']}_{rank}.metadata" + if rank == 0: + copyfile( + log_output_part_metadata, log_outfile_metadata + ), + os.remove(log_outfile_part) + os.remove(log_output_part_metadata) + else: + log_outfile = f"{opt['world_logs']}_{opt['save_format']}.jsonl" + logger.write(log_outfile, world, file_format=opt["save_format"]) + + return report + + def _setup_world(self): + # setup world, manually finaggling necessary opt info as needed + self.opt["task"] = "TodWorld" + world = tod_world.TodWorld(self.opt, agents=self._get_tod_agents(self.opt)) + return world + + def run(self): + opt = self.opt + + world = self._setup_world() + logger = TodWorldLogger(opt) + + # set up logging + log_every_n_secs = opt.get("log_every_n_secs", -1) + if log_every_n_secs <= 0: + log_every_n_secs = float("inf") + log_time = TimeLogger() + + # episode counter + max_episodes = opt.get("num_episodes", -1) + if max_episodes < 0: + max_episodes = float("inf") + world_num_episodes = world.num_episodes() + if world_num_episodes > 0: + max_episodes = min(max_episodes, world_num_episodes) + + ep_count = 0 + episode_metrics = [] + while not world.epoch_done() and ep_count < max_episodes: + episode_metrics.extend(self._run_episode(opt, world, logger)) + ep_count += opt.get("batchsize", 1) + if log_time.time() > log_every_n_secs: + report = world.report() + text, report = log_time.log(ep_count, max_episodes, report) + logging.info(text) + + return self._save_outputs(opt, world, logger, episode_metrics) + + +if __name__ == "__main__": + TodWorldScript.main() diff --git a/projects/tod_simulator/README.md b/projects/tod_simulator/README.md deleted file mode 100644 index 65fbb41753a..00000000000 --- a/projects/tod_simulator/README.md +++ /dev/null @@ -1 +0,0 @@ -Page to be filled. :) diff --git a/projects/tod_simulator/world_metrics/extended_world_metrics.py b/projects/tod_simulator/world_metrics/extended_world_metrics.py new file mode 100644 index 00000000000..dc7444fc84c --- /dev/null +++ b/projects/tod_simulator/world_metrics/extended_world_metrics.py @@ -0,0 +1,411 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +""" +Metrics handlers - ie, classes that handle generations from Tod World and calculates metrics from them. + +Note that only metrics handler classes in `WORLD_METRIC_HANDLERS` of ` parlai/core/tod/world_metrics_handlers.py` are actively being recorded as metrics. These "extended" metrics were ones we experimented with at one point in the past, found to be inconclusive, and are including primarily to not delete already done work. (Or for testing purposes.) +""" + +from parlai.core.message import Message +from parlai.core.metrics import ( + Metric, + AverageMetric, + normalize_answer, + BleuMetric, + SumMetric, +) +from parlai.core.tod.tod_core import ( + STANDARD_API_NAME_SLOT, + STANDARD_REQUIRED_KEY, + STANDARD_OPTIONAL_KEY, + STANDARD_API_SCHEMAS, +) +from typing import Dict, List, Optional, Tuple +from parlai.core.tod.world_metrics_handlers import ( + TodMetricsHandler, + register_metrics_handler, + _ApiCallGoalInteractionHelper, + goals_hit_helper, +) + +try: + from nltk.translate import bleu_score as nltkbleu +except ImportError: + # User doesn't have nltk installed, so we can't use it for bleu + # We'll just turn off things, but we might want to warn the user + nltkbleu = None + +################################ +# Functions and classes associated with calculating statistics between API Calls and Goals. + + +def get_req_only_goals(goals_list: List[Dict], api_schemas: List[Dict]) -> List[Dict]: + """ + Given a list of goals and a list of api schemas that say if slots are required or + optional, this function filters for the goals to be only the required ones. + + If we have no api schemas or a goal is malformed, we return the empty list. If a + goal is malformed, we print a warning, since this whole req-only goals thing is + experimental at best anyhow. + """ + if len(api_schemas) == 0: + return [] + result = [] + for goal in goals_list: + req_goals = {} + method = goal.get(STANDARD_API_NAME_SLOT, None) + if method is None: + return [] + required = [] + for schema in api_schemas: + if schema.get(STANDARD_API_NAME_SLOT, "") == method: + required = schema.get(STANDARD_REQUIRED_KEY, {}) + print("-".join(required)) + for key in required: + if key not in goal: + print(f"No required key `{key}` in goal `{goal}`") + return [] + req_goals[key] = goal[key] + if len(req_goals) > 0: + req_goals[STANDARD_API_NAME_SLOT] = method # for consistency with all. + result.append(req_goals) + return result + + +def goals_slots_helper( + goals: List[Dict], turnDict: List[Dict] +) -> Tuple[Tuple[int, int], Tuple[int, int]]: + """ + Helper function to see how well the slot keys + slot values match between attempted + API calls and goals. + + Output is precision, recall. + """ + all_call_slots = {k: v for call in turnDict for k, v in call.items()} + all_goal_slots = {k: v for goal in goals for k, v in goal.items()} + goal_in_call = { + k: v + for k, v in all_call_slots.items() + if all_goal_slots.get(k, "definitelyNotInValuexyz") == v + } + call_in_goal = { + k: v + for k, v in all_goal_slots.items() + if all_call_slots.get(k, "definitelyNotInValuexyz") == v + } + + print(goal_in_call, all_call_slots) + + return ( + AverageMetric(len(goal_in_call), len(all_call_slots)), + AverageMetric(len(call_in_goal), len(all_goal_slots)), + ) + + +@register_metrics_handler +class LegacyGoalApiCallInteractionsMetricsHandler(_ApiCallGoalInteractionHelper): + """ + This class was reporting a few too many metrics, but is useful for test purposes, so + we're keeping it around. + + `AllGoalApiCallSuccessMetricsHandler` is the streamlined, less spammy version of + this class. + """ + + def handle_goals( + self, message: Message, goals: List[Dict] + ) -> Optional[Dict[str, Metric]]: + self.goals = goals + self.required_goals = get_req_only_goals(goals, self.api_schemas) + + def get_episode_metrics(self) -> Optional[Dict[str, Metric]]: + all_goals_hit, all_goals_hit_turn_count, all_part_hit = goals_hit_helper( + self.goals, self.api_turns + ) + all_precision, all_recall = goals_slots_helper(self.goals, self.api_turns) + req_goals_hit, req_goals_hit_turn_count, req_part_hit = goals_hit_helper( + self.required_goals, self.api_turns, permissive=True + ) + req_precision, req_recall = goals_slots_helper( + self.required_goals, self.api_turns + ) + call_attempts = len(self.api_turns) + return { + "all_goals_hit": all_goals_hit, + "all_goals_hit_turn_count": all_goals_hit_turn_count, + "all_goals_fractional_hit": all_part_hit, + "all_goals_slot_precision": all_precision, + "all_goals_slot_recall": all_recall, + "req_goals_hit": req_goals_hit, + "req_goals_hit_turn_count": req_goals_hit_turn_count, + "req_goals_fractional_hit": req_part_hit, + "req_goals_slot_precision": req_precision, + "req_goals_slot_recall": req_recall, + "call_attempts": AverageMetric(call_attempts), + } + + +@register_metrics_handler +class UserGoalSlotCoverageMetricHandler(TodMetricsHandler): + """ + How well does our user simulator do at outputting utterances that goes closer to + satisfying relevant groundinged goals? Does it dump out all of the slots at once or + is it more intelligent than that? + + Since this is the user and we don't know the identity of potential slots, we ignore + the short (< 4 chars) goal slots since this tends to be things that are substrings + of other things. (Ex. "2" showing up as # of people in a reservation, but also + showing up as a phone number.) + """ + + def episode_reset(self): + self.mentioned_all_slot_values = set() + self.mentioned_req_slot_values = set() + self.all_goal_slot_values = set() + self.all_req_goal_slot_values = set() + + def handle_goals( + self, message: Message, goals: List[Dict] + ) -> Optional[Dict[str, Metric]]: + """ + Parse out all the slots as a blob, filtering out for short things. + """ + required_goals = get_req_only_goals(goals, self.api_schemas) + + def get_slot_values(goal_list): + result = set() + for goal in goal_list: + for key, value in goal.items(): + if key is not STANDARD_API_NAME_SLOT and len(value) > 3: + result.add(value) + return result + + self.all_goal_slot_values = get_slot_values(goals) + self.all_req_goal_slot_values = get_slot_values(required_goals) + + def handle_user_utt( + self, message: Message, prefix_stripped_text: str + ) -> Optional[Dict[str, Metric]]: + """ + Grab slots out of the user utterance based on an exact match. + """ + utterance = prefix_stripped_text + + def get_slots(utt, options): + results = set() + for option in options: + if option in utt: + results.add(option) + return results + + all_slot_values_here = get_slots(utterance, self.all_goal_slot_values) + req_slot_values_here = get_slots(utterance, self.all_req_goal_slot_values) + + self.mentioned_all_slot_values |= all_slot_values_here + self.mentioned_req_slot_values |= req_slot_values_here + + metrics = {} + metrics["user_utt_avg_any_slot"] = AverageMetric(len(all_slot_values_here)) + metrics["user_utt_avg_req_slot"] = AverageMetric(len(req_slot_values_here)) + return metrics + + def get_episode_metrics(self) -> Optional[Dict[str, Metric]]: + result = { + "user_any_goal_slots_recall": AverageMetric( + len(self.mentioned_all_slot_values), len(self.all_goal_slot_values) + ), + "user_req_goal_slots_recall": AverageMetric( + len(self.mentioned_req_slot_values), len(self.all_req_goal_slot_values) + ), + } + + self.mentioned_all_slot_values = set() + self.mentioned_req_slot_values = set() + return result + + +class _ExactRepeatMetricsHandler(TodMetricsHandler): + """ + Helper class for defining % of episodes where a given agent type has exactly + repeated the same utterance. + """ + + def episode_reset(self): + self.turns = [] + self.repeated = False + + def metric_key(self): + raise NotImplementedError("must implement") + + def handle_message_helper( + self, prefix_stripped_text: str + ) -> Optional[Dict[str, Metric]]: + normalized = normalize_answer(prefix_stripped_text) + if normalized in self.turns: + self.repeated = True + self.turns.append(normalized) + + def get_episode_metrics(self) -> Optional[Dict[str, Metric]]: + repeat = int(self.repeated) + self.repeated = False + self.turns = [] + return {self.metric_key(): AverageMetric(repeat)} + + +@register_metrics_handler +class UserUttRepeatMetricHandler(_ExactRepeatMetricsHandler): + def metric_key(self): + return "user_utt_repeat" + + def handle_user_utt( + self, message: Message, prefix_stripped_text: str + ) -> Optional[Dict[str, Metric]]: + return self.handle_message_helper(prefix_stripped_text) + + +@register_metrics_handler +class SystemUttRepeatMetricHandler(_ExactRepeatMetricsHandler): + def metric_key(self): + return "sys_utt_repeat" + + def handle_sys_utt( + self, message: Message, prefix_stripped_text: str + ) -> Optional[Dict[str, Metric]]: + return self.handle_message_helper(prefix_stripped_text) + + +@register_metrics_handler +class _Bleu3MetricsHandler(TodMetricsHandler): + """ + For a given agent, this calculates the Bleu-3 of a new turn against prior turns. + + This is an alternate metric for repetativeness + """ + + def episode_reset(self): + self.turns = [] + + def metric_key(self): + raise NotImplementedError("must implement") + + def handle_message_helper( + self, prefix_stripped_text: str + ) -> Optional[Dict[str, Metric]]: + here = [normalize_answer(x) for x in prefix_stripped_text.split(" ")] + score = 1 + if len(self.turns) > 0: + score = nltkbleu.corpus_bleu( + [self.turns], + [here], + smoothing_function=nltkbleu.SmoothingFunction(epsilon=1e-12).method1, + weights=[1.0 / 3.0] * 3, + ) + self.turns.append(here) + return {self.metric_key(): BleuMetric(score)} + + +@register_metrics_handler +class UserUttSelfBleu3MetricHandler(_Bleu3MetricsHandler): + def metric_key(self): + return "user_utt_self_bleu3" + + def handle_user_utt( + self, message: Message, prefix_stripped_text: str + ) -> Optional[Dict[str, Metric]]: + return self.handle_message_helper(prefix_stripped_text) + + +@register_metrics_handler +class SystemUttSelfBleu3MetricHandler(_Bleu3MetricsHandler): + def metric_key(self): + return "sys_utt_self_bleu3" + + def handle_sys_utt( + self, message: Message, prefix_stripped_text: str + ) -> Optional[Dict[str, Metric]]: + return self.handle_message_helper(prefix_stripped_text) + + +@register_metrics_handler +class ApiCallMalformedMetricHandler(TodMetricsHandler): + def episode_reset(self): + self.api_schemas = [] + + def handle_api_call( + self, message: Message, api_call: Dict + ) -> Optional[Dict[str, Metric]]: + if STANDARD_API_SCHEMAS in message["text"]: + return # Happens for API call groundingion, so it's fine + if len(api_call) == 0: + return + if STANDARD_API_NAME_SLOT not in api_call: + return { + "apiCall_wellFormed": AverageMetric(0), + "apiCall_hasSlotsButNoApiNameSlot_count": SumMetric(1), + } + method = api_call[STANDARD_API_NAME_SLOT] + + method_found = False + if len(self.api_schemas) > 0: + for schema in self.api_schemas: + if method == schema.get(STANDARD_API_NAME_SLOT, ""): + method_found = True + check = api_call.keys() + required = set(schema.get(STANDARD_REQUIRED_KEY, [])) + required.add(STANDARD_API_NAME_SLOT) + for req in required: + if req not in check: # miissing required + return { + "apiCall_wellFormed": AverageMetric(0), + "apiCall_missingRequiredSlot_count": SumMetric(1), + } + opt_count = 0 + for opt in schema.get(STANDARD_OPTIONAL_KEY, []): + if opt in check: + opt_count += 1 + if opt_count + len(required) != len(check): + # have extra APIs that are not + return { + "apiCall_wellFormed": AverageMetric(0), + "apiCall_hasExtraParams_count": SumMetric(1), + } + break + if method_found: + return { + "apiCall_wellFormed": AverageMetric(1), + "apiCall_wellFormed_count": SumMetric(1), + } + return { + "apiCall_wellFormed": AverageMetric(0), + "apiCall_methodDNE_count": SumMetric(1), + } + + +@register_metrics_handler +class PseudoInformMetricsHandler(TodMetricsHandler): + """ + Pseudo-inform rate. + """ + + def episode_reset(self): + self.api_resp_slots = {} + + def handle_api_resp( + self, message: Message, api_resp: Dict + ) -> Optional[Dict[str, Metric]]: + self.api_resp_slots.update(api_resp) + + def handle_sys_utt( + self, message: Message, prefix_stripped_text: str + ) -> Optional[Dict[str, Metric]]: + count = 0 + for val in self.api_resp_slots.values(): + if val in prefix_stripped_text: + count += 1 + result = {"pseudo_inform_allSysTurns": AverageMetric(count)} + if len(self.api_resp_slots) > 0: + result["pseudo_inform_postApiRespSysTurns"] = AverageMetric(count) + return result diff --git a/tests/tod/test_tod_world_and_script.py b/tests/tod/test_tod_world_and_script.py new file mode 100644 index 00000000000..2d8e8944024 --- /dev/null +++ b/tests/tod/test_tod_world_and_script.py @@ -0,0 +1,251 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +""" +Tests tod world + script, notably for batching, by comparing saved script logs to the +data that should have been generated. + +Metrics are handled in separate files. +""" + +import copy +import unittest + +import parlai.core.tod.tod_test_utils.test_agents as test_agents +import parlai.core.tod.tod_core as tod_core +import parlai.scripts.tod_world_script as tod_world_script +from parlai.core.tod.tod_agents import StandaloneApiAgent + + +class TestTodWorldScript(tod_world_script.TodWorldScript): + """ + Wrap around it to check its logic; also makes it easier to do things w/ underlying + World. + """ + + def _get_tod_agents(self, opt): + """ + Hack so we can separate out logic of making sure agent parsing is correct. + """ + if hasattr(self, "agents"): + return self.agents + return super()._get_tod_agents(opt) + + def _save_outputs(self, opt, world, logger, episode_metrics): + self.world = world + self.logger = logger + + +class TodWorldInScriptTestBase(unittest.TestCase): + def add_tod_world_opts(self, base_opts): + """ + Convenience since we're initing the opt directly without parlai parser. + """ + opts = copy.deepcopy(base_opts) + opts["datatype"] = "DUMMY" + opts["datafile"] = "DUMMY" + opts["episodes_randomization_seed"] = -1 + opts["standalone_api_file"] = test_agents.API_DATABASE_FILE + opts["exact_api_call"] = True + opts["log_keep_fields"] = "all" + opts["display_examples"] = False + opts[ + "include_api_schemas" + ] = True # do this to test_agents.make sure they're done correctly. + return opts + + def setup_agents(self, added_opts): + full_opts = self.add_tod_world_opts(added_opts) + sys = test_agents.ApiCallAndSysUttAgent(full_opts) + agents = [ + test_agents.UserUttAgent(full_opts), + sys, + StandaloneApiAgent(full_opts), + sys, + test_agents.ApiSchemaAgent(full_opts), + test_agents.GoalAgent(full_opts), + ] + return agents, full_opts + + def _test_roundDataCorrect(self): + self._test_roundDataCorrect_helper(test_agents.EPISODE_SETUP__SINGLE_API_CALL) + self._test_roundDataCorrect_helper(test_agents.EPISODE_SETUP__MULTI_ROUND) + self._test_roundDataCorrect_helper(test_agents.EPISODE_SETUP__MULTI_EPISODE) + self._test_roundDataCorrect_helper(test_agents.EPISODE_SETUP__MULTI_EPISODE_BS) + + def _check_correctness_from_script_logs( + self, script, opt, process_round_utts=lambda x: x + ): + """ + Last argument is only relevant for the max_turn test. + """ + max_rounds = opt[test_agents.TEST_NUM_ROUNDS_OPT_KEY] + max_episodes = opt[test_agents.TEST_NUM_EPISODES_OPT_KEY] + # there's something funky with logger.get_log() that inserts a space, but not gonna worry about it for now + logs = [x for x in script.logger.get_logs() if len(x) > 0] + for episode_idx in range(max_episodes): + episode_from_world = logs[episode_idx] + # first round is context + context = episode_from_world[0] + self.assertEquals( + context[0]["text"], + "APIS: " + + tod_core.SerializationHelpers.list_of_maps_to_str( + test_agents.make_api_schemas_machine(max_rounds) + ), + ) + self.assertEquals( + context[3]["text"], + "GOAL: " + + tod_core.SerializationHelpers.list_of_maps_to_str( + test_agents.make_goal_calls_machine(max_rounds) + ), + ) + # Check the rest + world_utts = [[x["text"] for x in turn] for turn in episode_from_world[1:]] + # ... ignore the last DONE turn here cause it's not that important + + self.assertEquals( + world_utts[:-1], + process_round_utts( + test_agents.get_round_utts(episode_idx, max_rounds)[:-1] + ), + ) + + +class TodWorldSingleBatchTest(TodWorldInScriptTestBase): + """ + Checks that saved data is correct with a single batch. + """ + + def _test_roundDataCorrect_helper(self, config): + config["batchsize"] = 1 + config["max_turns"] = 10 + agents, opt = self.setup_agents(config) + script = TestTodWorldScript(opt) + script.agents = agents + script.run() + self._check_correctness_from_script_logs(script, opt) + + def test_roundDataCorrect(self): + self._test_roundDataCorrect() + + def test_max_turn(self): + self._test_max_turn_helper(4) + self._test_max_turn_helper(7) + + def _test_max_turn_helper(self, max_turns): + config = {} + config["batchsize"] = 1 + config["max_turns"] = max_turns + config[test_agents.TEST_NUM_ROUNDS_OPT_KEY] = 10 + config[test_agents.TEST_NUM_EPISODES_OPT_KEY] = 5 # cause why not + agents, opt = self.setup_agents(config) + script = TestTodWorldScript(opt) + script.agents = agents + script.run() + + def filter_round_utt(utts): + # tad imprecise, but more important that it does stop. + # subtract 1 for the context turn, then 1 cause there's an off by one somewhere + return utts[: max_turns - 2] + + self._check_correctness_from_script_logs(script, opt, filter_round_utt) + + +class TodWorldNonSingleBatchTest(TodWorldInScriptTestBase): + """ + Checks saved data is correct with larger batchsizes. + """ + + def _test_roundDataCorrect_helper(self, config): + config["batchsize"] = 4 + config["max_turns"] = 10 + agents, opt = self.setup_agents(config) + script = TestTodWorldScript(opt) + script.agents = agents + script.run() + self._check_correctness_from_script_logs(script, opt) + + def test_roundDataCorrect(self): + self._test_roundDataCorrect() + + +class TodWorldTestSingleDumpAgents(TodWorldInScriptTestBase): + """ + Just to be safe, make sure that the agents with "single" versions (ex goal + api + schema) are correctly aligned. + + (Already tested in the agents test file as well, but to be safe.) + """ + + def setup_agents(self, added_opts, api_agent, goal_agent): + full_opts = self.add_tod_world_opts(added_opts) + full_opts["fixed_response"] = "USER: [DONE]" + sys = test_agents.ApiCallAndSysUttAgent(full_opts) + agents = [ + test_agents.UserUttAgent(full_opts), + sys, + StandaloneApiAgent(full_opts), + sys, + api_agent(full_opts), + goal_agent(full_opts), + ] + return agents, full_opts + + def _test_SingleGoalApiResp_helper(self, batchsize, num_episodes): + config = {} + config["batchsize"] = batchsize + config[test_agents.TEST_NUM_ROUNDS_OPT_KEY] = 10 + config[test_agents.TEST_NUM_EPISODES_OPT_KEY] = num_episodes + single_agents, opt = self.setup_agents( + config, test_agents.SingleApiSchemaAgent, test_agents.SingleGoalAgent + ) + single_script = TestTodWorldScript(opt) + single_script.agents = single_agents + single_script.run() + single_logs = [x for x in single_script.logger.get_logs() if len(x) > 0] + + multi_agents, opt = self.setup_agents( + config, test_agents.ApiSchemaAgent, test_agents.GoalAgent + ) + multi_script = TestTodWorldScript(opt) + multi_script.agents = multi_agents + multi_script.run() + multi_logs = [x for x in single_script.logger.get_logs() if len(x) > 0] + + single_idx = 0 + for multi_log in multi_logs: + context = multi_log[0] + goals = tod_core.SerializationHelpers.str_to_goals( + context[3]["text"][len("GOAL:") :].strip() + ) + for goal in goals: + single_context = single_logs[single_idx][0] + single_goal = tod_core.SerializationHelpers.str_to_goals( + single_context[3]["text"][len("GOAL:") :].strip() + ) + self.assertEqual(len(single_goal), 1) + self.assertEquals(goal, single_goal[0]) + single_des = tod_core.SerializationHelpers.str_to_api_schemas( + single_context[0]["text"][len("APIS:") :].strip() + ) + self.assertEqual(len(single_des), 1) + self.assertEqual(single_goal[0]["api_name"], single_des[0]["api_name"]) + + single_idx += 1 + + def test_SingleGoalApiResp_helper_singleBatch(self): + self._test_SingleGoalApiResp_helper(1, 2) + self._test_SingleGoalApiResp_helper(1, 5) + + def test_SingleGoalApiResp_helper_multiBatch(self): + self._test_SingleGoalApiResp_helper(4, 8) + self._test_SingleGoalApiResp_helper(4, 11) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/tod/test_tod_world_metrics.py b/tests/tod/test_tod_world_metrics.py new file mode 100644 index 00000000000..293c677834f --- /dev/null +++ b/tests/tod/test_tod_world_metrics.py @@ -0,0 +1,244 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +""" +Test world metrics + world metrics handlers against dummy conversations. +""" + +import unittest + +from parlai.core.tod.tod_core import ( + STANDARD_API_NAME_SLOT, + STANDARD_REQUIRED_KEY, + STANDARD_OPTIONAL_KEY, + TodStructuredRound, + TodStructuredEpisode, + TodAgentType, + TOD_AGENT_TYPE_TO_PREFIX, +) +from parlai.core.tod.world_metrics import TodMetrics +from parlai.core.tod.world_metrics_handlers import METRICS_HANDLER_CLASSES_TEST_REGISTRY + +# Ignore lint on following line; want to have registered classes show up for tests +import projects.tod_simulator.world_metrics.extended_world_metrics # noqa: F401 + +GOAL__SINGLE_ONE_KEY = [{STANDARD_API_NAME_SLOT: "name", "a": "1"}] +GOAL__SINGLE_THREE_KEYS = [ + {STANDARD_API_NAME_SLOT: "name", "a": "1", "b": "2", "c": "3"} +] +GOAL__HARD = [ + { + STANDARD_API_NAME_SLOT: "otherName", + "w": "1", + "x": "2", + "y": "3", + "z": "will_be_missing", + "diff": "right", + } +] + +API_CALL__NO_API_NAME_SLOT = {"random": "blah"} +API_CALL__API_NAME_DNE = {STANDARD_API_NAME_SLOT: "not_an_api_name"} +API_CALL__VALID_NAME_BUT_EMPTY = {STANDARD_API_NAME_SLOT: "name"} +API_CALL__SINGLE_ONE_KEY = GOAL__SINGLE_ONE_KEY[0] +API_CALL__SINGLE_ONE_KEY_WITH_OPT = {**GOAL__SINGLE_ONE_KEY[0], **{"c": "3"}} +API_CALL__SINGLE_ONE_KEY_WITH_OPT_AND_NONVALID = { + **GOAL__SINGLE_ONE_KEY[0], + **{"c": "3", "nonExistent": "blah"}, +} +API_CALL__FUNKY_AGAINST_HARD = { + STANDARD_API_NAME_SLOT: "otherName", + "w": "1", + "x": "2", + "y": "3", + "diff": "wrong", +} + +API_SCHEMA__ONE_CALL_ONE_REQ_MATCH_ONE_KEY = [ + { + STANDARD_API_NAME_SLOT: "name", + STANDARD_REQUIRED_KEY: ["a"], + STANDARD_OPTIONAL_KEY: [], + } +] + +API_SCHEMA__ONE_CALL_MATCH_THREE_KEYS = [ + { + STANDARD_API_NAME_SLOT: "name", + STANDARD_REQUIRED_KEY: ["a"], + STANDARD_OPTIONAL_KEY: ["b", "c", "d"], + } +] + +API_SCHEMA__ONE_CALL_HARD = [ + { + STANDARD_API_NAME_SLOT: "otherName", + STANDARD_REQUIRED_KEY: ["w", "x"], + STANDARD_OPTIONAL_KEY: ["y", "z", "diff"], + } +] + + +class TodMetricsTestHelper: + """ + Given a synthetic intermediate converesation, calculates the metrics for said + conversation. + """ + + def __init__(self, e: TodStructuredEpisode): + self.m = TodMetrics() + self.m.handlers = [ + x() for x in METRICS_HANDLER_CLASSES_TEST_REGISTRY + ] # run on ALL + self.e = e + + def _process(self, t: TodAgentType, text: str): + self.m.handle_message({"text": f"{TOD_AGENT_TYPE_TO_PREFIX[t]}{text}"}, t) + + def run(self): + self._process(TodAgentType.API_SCHEMA_GROUNDING_AGENT, self.e.api_schemas_utt) + self._process(TodAgentType.GOAL_GROUNDING_AGENT, self.e.goal_calls_utt) + + for r in self.e.rounds: + self._process(TodAgentType.USER_UTT_AGENT, r.user_utt) + self._process(TodAgentType.API_CALL_AGENT, r.api_call_utt) + self._process(TodAgentType.API_RESP_AGENT, r.api_resp_utt) + self._process(TodAgentType.SYSTEM_UTT_AGENT, r.sys_utt) + + self.m.episode_reset() + + def report(self): + return self.m.report() + + +class TestApiGoalHitMetricsHandler(unittest.TestCase): + def __helper(self, api_schemas_machine, goal_calls_machine, single_turn_api_call): + e = TodStructuredEpisode( + api_schemas_machine=api_schemas_machine, + goal_calls_machine=goal_calls_machine, + rounds=[TodStructuredRound(api_call_machine=single_turn_api_call)], + ) + helper = TodMetricsTestHelper(e) + helper.run() + result = helper.report() + return result + + def test_one_goal_only_req(self): + result = self.__helper( + api_schemas_machine=API_SCHEMA__ONE_CALL_ONE_REQ_MATCH_ONE_KEY, + goal_calls_machine=GOAL__SINGLE_ONE_KEY, + single_turn_api_call=API_CALL__SINGLE_ONE_KEY, + ) + self.assertAlmostEqual(result["all_goals_hit"], 1) + self.assertAlmostEqual(result["all_goals_hit_turn_count"], 1) + self.assertAlmostEqual(result["all_goals_fractional_hit"], 1) + self.assertAlmostEqual(result["all_goals_slot_precision"], 1) + self.assertAlmostEqual(result["all_goals_slot_recall"], 1) + + self.assertAlmostEqual(result["req_goals_hit"], 1) + self.assertAlmostEqual(result["req_goals_hit_turn_count"], 1) + self.assertAlmostEqual(result["req_goals_fractional_hit"], 1) + self.assertAlmostEqual(result["req_goals_slot_precision"], 1) + self.assertAlmostEqual(result["req_goals_slot_recall"], 1) + + def test_one_goal_api_name_missing_slots(self): + result = self.__helper( + api_schemas_machine=API_SCHEMA__ONE_CALL_ONE_REQ_MATCH_ONE_KEY, + goal_calls_machine=GOAL__SINGLE_ONE_KEY, + single_turn_api_call=API_CALL__VALID_NAME_BUT_EMPTY, + ) + self.assertAlmostEqual(result["all_goals_hit"], 0) + self.assertAlmostEqual(result["all_goals_hit_turn_count"], 0) + self.assertAlmostEqual(result["all_goals_fractional_hit"], 0) + self.assertAlmostEqual(result["all_goals_slot_precision"], 1) # api_name + self.assertAlmostEqual(result["all_goals_slot_recall"], 0.5) + + self.assertAlmostEqual(result["req_goals_hit"], 0) + self.assertAlmostEqual(result["req_goals_hit_turn_count"], 0) + self.assertAlmostEqual(result["req_goals_fractional_hit"], 0) + self.assertAlmostEqual(result["req_goals_slot_precision"], 1) + self.assertAlmostEqual(result["req_goals_slot_recall"], 0.5) + + def test_one_goal_with_opts(self): + result = self.__helper( + api_schemas_machine=API_SCHEMA__ONE_CALL_MATCH_THREE_KEYS, + goal_calls_machine=GOAL__SINGLE_THREE_KEYS, + single_turn_api_call=API_CALL__SINGLE_ONE_KEY, + ) + self.assertAlmostEqual(result["all_goals_hit"], 0) + self.assertAlmostEqual(result["all_goals_hit_turn_count"], 0) + self.assertAlmostEqual(result["all_goals_fractional_hit"], 0) + self.assertAlmostEqual(result["all_goals_slot_precision"], 1) + self.assertAlmostEqual(result["all_goals_slot_recall"], 0.5) + + self.assertAlmostEqual(result["req_goals_hit"], 1) + self.assertAlmostEqual(result["req_goals_hit_turn_count"], 1) + self.assertAlmostEqual(result["req_goals_fractional_hit"], 1) + self.assertAlmostEqual(result["req_goals_slot_precision"], 1) + self.assertAlmostEqual(result["req_goals_slot_recall"], 1) + + def test_hard_case(self): + result = self.__helper( + api_schemas_machine=API_SCHEMA__ONE_CALL_HARD, + goal_calls_machine=GOAL__HARD, + single_turn_api_call=API_CALL__FUNKY_AGAINST_HARD, + ) + self.assertAlmostEqual(result["all_goals_hit"], 0) + self.assertAlmostEqual(result["all_goals_hit_turn_count"], 0) + self.assertAlmostEqual(result["all_goals_fractional_hit"], 0) + self.assertAlmostEqual(result["all_goals_slot_precision"], 0.8) + self.assertAlmostEqual(result["all_goals_slot_recall"], 2.0 / 3.0) + + self.assertAlmostEqual(result["req_goals_hit"], 1) + self.assertAlmostEqual(result["req_goals_hit_turn_count"], 1) + self.assertAlmostEqual(result["req_goals_fractional_hit"], 1) + self.assertAlmostEqual(result["req_goals_slot_precision"], 0.6) + self.assertAlmostEqual(result["req_goals_slot_recall"], 1) + + +class TestApiCallMalformedMetricsHandler(unittest.TestCase): + def __helper(self, single_turn_api_call): + e = TodStructuredEpisode( + api_schemas_machine=API_SCHEMA__ONE_CALL_MATCH_THREE_KEYS, + rounds=[TodStructuredRound(api_call_machine=single_turn_api_call)], + ) + helper = TodMetricsTestHelper(e) + helper.run() + return helper.report() + + def test_no_api_name_slot(self): + result = self.__helper(API_CALL__NO_API_NAME_SLOT) + self.assertEqual(result["apiCall_wellFormed"], 0) + self.assertEqual(result["apiCall_hasSlotsButNoApiNameSlot_count"], 1) + + def test_api_name_DNE(self): + result = self.__helper(API_CALL__API_NAME_DNE) + self.assertEqual(result["apiCall_wellFormed"], 0) + self.assertEqual(result["apiCall_methodDNE_count"], 1) + + def test_missing_required_slot(self): + result = self.__helper(API_CALL__VALID_NAME_BUT_EMPTY) + self.assertEqual(result["apiCall_wellFormed"], 0) + self.assertEqual(result["apiCall_missingRequiredSlot_count"], 1) + + def test_has_single_required_slot(self): + result = self.__helper(API_CALL__SINGLE_ONE_KEY) + self.assertEqual(result["apiCall_wellFormed"], 1) + self.assertEqual(result["apiCall_wellFormed_count"], 1) + + def test_has_valid_optional_slot(self): + result = self.__helper(API_CALL__SINGLE_ONE_KEY_WITH_OPT) + self.assertEqual(result["apiCall_wellFormed"], 1) + self.assertEqual(result["apiCall_wellFormed_count"], 1) + + def test_has_invalid_extra_slots(self): + result = self.__helper(API_CALL__SINGLE_ONE_KEY_WITH_OPT_AND_NONVALID) + self.assertEqual(result["apiCall_wellFormed"], 0) + self.assertEqual(result["apiCall_hasExtraParams_count"], 1) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/tod/test_tod_world_metrics_in_script.py b/tests/tod/test_tod_world_metrics_in_script.py new file mode 100644 index 00000000000..19cc0cd20af --- /dev/null +++ b/tests/tod/test_tod_world_metrics_in_script.py @@ -0,0 +1,279 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +""" +Tests tod world metrics in the full script, *including* making the script properly set +up the agents on its own. + +Use a few of the API Call + goal hit metrics as the metric handlers to test proper +functionality. +""" + +import copy +import unittest + +from parlai.core.metrics import dict_report +from parlai.core.opt import Opt +from parlai.core.tod.tod_core import SerializationHelpers +import parlai.core.tod.tod_test_utils.test_agents as test_agents +from parlai.core.tod.world_metrics_handlers import METRICS_HANDLER_CLASSES_TEST_REGISTRY +import parlai.scripts.tod_world_script as tod_world_script + +# Ignore lint on following line; want to have registered classes show up for tests +import projects.tod_simulator.world_metrics.extended_world_metrics # noqa: F401 + +NUM_EPISODES = 35 + +TEST_SETUP = { + "api_schema_grounding_model": "parlai.core.tod.tod_test_utils.test_agents:ApiSchemaAgent", + "goal_grounding_model": "parlai.core.tod.tod_test_utils.test_agents:GoalAgent", + "user_model": "parlai.core.tod.tod_test_utils.test_agents:UserUttAgent", + "system_model": "parlai.core.tod.tod_test_utils.test_agents:ApiCallAndSysUttAgent", + "api_resp_model": "fixed_response", + test_agents.TEST_NUM_EPISODES_OPT_KEY: NUM_EPISODES, +} +TEST_SETUP_BROKEN_USER_SYSTEM = { + "api_schema_grounding_model": "parlai.core.tod.tod_test_utils.test_agents:ApiSchemaAgent", + "goal_grounding_model": "parlai.core.tod.tod_test_utils.test_agents:GoalAgent", + "user_model": "fixed_response", + "system_model": "fixed_response", + "api_resp_model": "fixed_response", + test_agents.TEST_NUM_EPISODES_OPT_KEY: NUM_EPISODES, +} + +TEST_SETUP_EMPTY_APISCHEMA = copy.deepcopy(TEST_SETUP) +TEST_SETUP_EMPTY_APISCHEMA[ + "api_schema_grounding_model" +] = "parlai.core.tod.tod_agents:EmptyApiSchemaAgent" + +TEST_SETUP_BROKEN_USER_SYSTEM_EMPTY_APISCHEMA = copy.deepcopy( + TEST_SETUP_BROKEN_USER_SYSTEM +) +TEST_SETUP_BROKEN_USER_SYSTEM_EMPTY_APISCHEMA[ + "api_schema_grounding_model" +] = "parlai.core.tod.tod_agents:EmptyApiSchemaAgent" + +DATATYPE = "valid" + + +class TestTodWorldScript(tod_world_script.TodWorldScript): + """ + Wrap around it to check its logic; also makes it easier to do things w/ underlying + World. + """ + + def __init__(self, opt: Opt): + opt["datatype"] = DATATYPE + # none of the below matter, but need to set to keep other code happy. + opt["log_keep_fields"] = "all" + opt["display_examples"] = False + + super().__init__(opt) + + def _setup_world(self): + world = super()._setup_world() + for i in range(len(world.batch_tod_world_metrics)): + world.batch_tod_world_metrics[i].handlers = [ + x() for x in METRICS_HANDLER_CLASSES_TEST_REGISTRY + ] + return world + + def _save_outputs(self, opt, world, logger, episode_metrics): + self.world = world + self.logger = logger + self.episode_metrics = episode_metrics + + +class TodMetricsInScriptTests(unittest.TestCase): + def test_all_goals_hit_all_success(self): + """ + For a setup where all the goals should be successfully hit, is it? + """ + self._check_all_goals_hit_by_opt_and_batchsize( + TEST_SETUP, batchsize=1, num_episodes=1, target_all_goals_hit=1 + ) + self._check_all_goals_hit_by_opt_and_batchsize( + TEST_SETUP, batchsize=1, num_episodes=32, target_all_goals_hit=1 + ) + self._check_all_goals_hit_by_opt_and_batchsize( + TEST_SETUP, batchsize=32, num_episodes=8, target_all_goals_hit=1 + ) + self._check_all_goals_hit_by_opt_and_batchsize( + TEST_SETUP, batchsize=32, num_episodes=33, target_all_goals_hit=1 + ) + self._check_all_goals_hit_by_opt_and_batchsize( + TEST_SETUP, + batchsize=32, + num_episodes=-1, + target_all_goals_hit=1, + target_metrics_length=NUM_EPISODES, + ) + + def test_all_goals_hit_all_fail(self): + """ + For a setup where all the goals should *not* be successfully hit, do they fail? + """ + self._check_all_goals_hit_by_opt_and_batchsize( + TEST_SETUP_BROKEN_USER_SYSTEM, + batchsize=1, + num_episodes=1, + target_all_goals_hit=0, + ) + self._check_all_goals_hit_by_opt_and_batchsize( + TEST_SETUP_BROKEN_USER_SYSTEM, + batchsize=1, + num_episodes=32, + target_all_goals_hit=0, + ) + self._check_all_goals_hit_by_opt_and_batchsize( + TEST_SETUP_BROKEN_USER_SYSTEM, + batchsize=32, + num_episodes=32, + target_all_goals_hit=0, + ) + self._check_all_goals_hit_by_opt_and_batchsize( + TEST_SETUP_BROKEN_USER_SYSTEM, + batchsize=32, + num_episodes=33, + target_all_goals_hit=0, + ) + self._check_all_goals_hit_by_opt_and_batchsize( + TEST_SETUP_BROKEN_USER_SYSTEM, + batchsize=32, + num_episodes=-1, + target_all_goals_hit=0, + target_metrics_length=NUM_EPISODES, + ) + + def test_all_goals_hit_all_success_emptySchema(self): + """ + Check to make sure empty API schema doesn't have any impact on goal (Necessary + cause original, more exhaustive implementation of goal success would separate + between required + optional opts using the schema; make sure it doesn't impact + anything broader) + """ + self._check_all_goals_hit_by_opt_and_batchsize( + TEST_SETUP_EMPTY_APISCHEMA, + batchsize=1, + num_episodes=1, + target_all_goals_hit=1, + ) + self._check_all_goals_hit_by_opt_and_batchsize( + TEST_SETUP_EMPTY_APISCHEMA, + batchsize=1, + num_episodes=32, + target_all_goals_hit=1, + ) + self._check_all_goals_hit_by_opt_and_batchsize( + TEST_SETUP_EMPTY_APISCHEMA, + batchsize=32, + num_episodes=32, + target_all_goals_hit=1, + ) + self._check_all_goals_hit_by_opt_and_batchsize( + TEST_SETUP_EMPTY_APISCHEMA, + batchsize=32, + num_episodes=33, + target_all_goals_hit=1, + ) + self._check_all_goals_hit_by_opt_and_batchsize( + TEST_SETUP_EMPTY_APISCHEMA, + batchsize=32, + num_episodes=-1, + target_all_goals_hit=1, + target_metrics_length=NUM_EPISODES, + ) + + def test_all_goals_hit_all_fail_emptySchema(self): + """ + Make sure empty schema has no impact on goal success. + + (Necessary cause original, more exhaustive implementation of goal success would + separate between required + optional opts using the schema; make sure it doesn't + impact anything broader) + """ + self._check_all_goals_hit_by_opt_and_batchsize( + TEST_SETUP_BROKEN_USER_SYSTEM_EMPTY_APISCHEMA, + batchsize=1, + num_episodes=1, + target_all_goals_hit=0, + ) + self._check_all_goals_hit_by_opt_and_batchsize( + TEST_SETUP_BROKEN_USER_SYSTEM_EMPTY_APISCHEMA, + batchsize=1, + num_episodes=32, + target_all_goals_hit=0, + ) + self._check_all_goals_hit_by_opt_and_batchsize( + TEST_SETUP_BROKEN_USER_SYSTEM_EMPTY_APISCHEMA, + batchsize=32, + num_episodes=32, + target_all_goals_hit=0, + ) + self._check_all_goals_hit_by_opt_and_batchsize( + TEST_SETUP_BROKEN_USER_SYSTEM_EMPTY_APISCHEMA, + batchsize=32, + num_episodes=33, + target_all_goals_hit=0, + ) + self._check_all_goals_hit_by_opt_and_batchsize( + TEST_SETUP_BROKEN_USER_SYSTEM_EMPTY_APISCHEMA, + batchsize=32, + num_episodes=-1, + target_all_goals_hit=0, + target_metrics_length=NUM_EPISODES, + ) + + def _check_all_goals_hit_by_opt_and_batchsize( + self, + opt, + batchsize, + num_episodes, + target_all_goals_hit, + target_metrics_length=None, + ): + opt = copy.deepcopy(opt) + opt["batchsize"] = batchsize + opt["num_episodes"] = num_episodes + report, metrics = self._run_opt_get_report(opt) + self.assertEqual(report.get("all_goals_hit"), target_all_goals_hit) + metrics_comp_length = num_episodes + if target_metrics_length: + metrics_comp_length = target_metrics_length + self.assertEqual(len(metrics), metrics_comp_length) + + def _run_opt_get_report(self, opt): + script = TestTodWorldScript(opt) + script.run() + + def get_episode_report(goal, episode_metric): + metrics_dict = dict_report(episode_metric.report()) + metrics_dict["goal"] = goal + return metrics_dict + + return ( + dict_report(script.world.report()), + [get_episode_report(g, e) for g, e in script.episode_metrics], + ) + + def test_apiCallAttempts_usingGold(self): + opt = copy.deepcopy(TEST_SETUP) + opt["batchsize"] = 1 + opt["num_episodes"] = -1 + _, metrics = self._run_opt_get_report(opt) + for metric in metrics: + self.assertEqual( + len( + SerializationHelpers.str_to_goals( + metric["goal"]["text"][len("GOALS: ") :] + ) + ), + metric["call_attempts"], + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/tod/test_tod_world_script_metrics.py b/tests/tod/test_tod_world_script_metrics.py new file mode 100644 index 00000000000..5152a7bfa23 --- /dev/null +++ b/tests/tod/test_tod_world_script_metrics.py @@ -0,0 +1,153 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +""" +Tests tod world metrics in the full script, *without* making the script properly set up +the agents on its own. + +Use a few of the API Call + goal hit metrics as the metric handlers to test proper +functionality. +""" + +import copy +import unittest + +import parlai.core.tod.tod_test_utils.test_agents as test_agents +import parlai.scripts.tod_world_script as tod_world_script +from parlai.core.tod.tod_agents import StandaloneApiAgent +from parlai.core.tod.world_metrics_handlers import METRICS_HANDLER_CLASSES_TEST_REGISTRY +from parlai.core.metrics import dict_report + +# Ignore lint on following line; want to have registered classes show up for tests +import projects.tod_simulator.world_metrics.extended_world_metrics # noqa: F401 + + +class TestTodWorldScript(tod_world_script.TodWorldScript): + """ + Wrap around it to check its logic; also makes it easier to do things w/ underlying + World. + """ + + def _get_tod_agents(self, opt): + """ + Hack so we can separate out logic of making sure agent parsing is correct. + """ + if hasattr(self, "agents"): + return self.agents + return super()._get_tod_agents(opt) + + def _setup_world(self): + world = super()._setup_world() + for i in range(len(world.batch_tod_world_metrics)): + world.batch_tod_world_metrics[i].handlers = [ + x() for x in METRICS_HANDLER_CLASSES_TEST_REGISTRY + ] + return world + + def _save_outputs(self, opt, world, logger, episode_metrics): + self.world = world + self.episode_metrics = episode_metrics + + +class TodWorldInScriptTestBase(unittest.TestCase): + def add_tod_world_opts(self, base_opts): + """ + Convenience since we're initing the opt directly without parlai parser. + """ + opts = copy.deepcopy(base_opts) + opts["datatype"] = "DUMMY" + opts["datafile"] = "DUMMY" + opts["standalone_api_file"] = test_agents.API_DATABASE_FILE + opts["exact_api_call"] = True + opts["log_keep_fields"] = "all" + opts["display_examples"] = False + opts[ + "include_api_schemas" + ] = True # do this to test_agents.make sure they're done correctly. + return opts + + def setup_agents(self, added_opts): + full_opts = self.add_tod_world_opts(added_opts) + sys = test_agents.ApiCallAndSysUttAgent(full_opts) + agents = [ + test_agents.UserUttAgent(full_opts), + sys, + StandaloneApiAgent(full_opts), + sys, + test_agents.ApiSchemaAgent(full_opts), + test_agents.GoalAgent(full_opts), + ] + return agents, full_opts + + def _run_test(self): + self._run_test_helper(test_agents.EPISODE_SETUP__SINGLE_API_CALL) + self._run_test_helper(test_agents.EPISODE_SETUP__MULTI_ROUND) + self._run_test_helper(test_agents.EPISODE_SETUP__MULTI_EPISODE) + self._run_test_helper(test_agents.EPISODE_SETUP__MULTI_EPISODE_BS) + + def _run_test_helper(self, config_base): + config = copy.deepcopy(config_base) + config["use_broken_mock_api_calls"] = True + add = self.config_args() + for key in add: + config[key] = add[key] + agents, opt = self.setup_agents(config) + script = TestTodWorldScript(opt) + script.agents = agents + script.run() + self._check_metrics_correct(script, opt) + + def _check_metrics_correct(self, script, opt): + """ + Last argument is only relevant for the max_turn test. + """ + max_rounds = opt[test_agents.TEST_NUM_ROUNDS_OPT_KEY] + max_episodes = opt[test_agents.TEST_NUM_EPISODES_OPT_KEY] + episode_metrics = script.episode_metrics + for episode_idx, episode in enumerate(episode_metrics): + goal, episode_metric = episode + episode_metric = dict_report(episode_metric.report()) + self.assertAlmostEqual( + episode_metric["all_goals_hit"], + not test_agents.episode_has_broken_api_turn(episode_idx, max_rounds), + ) + broken_episodes = sum( + [ + test_agents.episode_has_broken_api_turn(i, max_rounds) + for i in range(max_episodes) + ] + ) + report = dict_report(script.world.report()) + self.assertAlmostEqual( + report["all_goals_hit"], + float(max_episodes - broken_episodes) / max_episodes, + ) + + +class TodWorldSingleBatchTest(TodWorldInScriptTestBase): + def config_args(self): + config = {} + config["batchsize"] = 1 + config["max_turns"] = 10 + return config + + def test_metricsCorrect(self): + self._run_test() + + +class TodWorldNonSingleBatchTest(TodWorldInScriptTestBase): + def config_args(self): + config = {} + config["batchsize"] = 4 + config["max_turns"] = 10 + return config + + def test_metricsCorrect(self): + self._run_test() + + +if __name__ == "__main__": + unittest.main()