Skip to content
This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

Added DST Teacher for multiwoz_v22 task #4656

Merged
merged 12 commits into from Jul 28, 2022
3 changes: 2 additions & 1 deletion parlai/agents/transformer/modules/decoder.py
Expand Up @@ -432,7 +432,8 @@ def _create_selfattn_mask(self, x):
)
class TransformerDecoderLayer(BaseTransformerDecoderLayer):
"""
Implements a single Transformer decoder layer with cross (encoder) attention as in
Implements a single Transformer decoder layer with cross (encoder) attention as in.

[Vaswani, 2017](https://arxiv.org/abs/1706.03762).

Decoder layers are similar to encoder layers but:
Expand Down
2 changes: 1 addition & 1 deletion parlai/core/tod/tod_agents.py
Expand Up @@ -40,7 +40,7 @@ class TodStructuredDataParser(Agent):

Inherit from this class and implement `setup_episodes()` to implement the intermediate representation for a specific dataset. Use multiple inheritance with the classes below that implement an `act()` to use it.

For example, if we have a `MyDataset_DataParser(TodStructuredDataParser)` and wanted to make a teacher to train a model togenerate User Utterances based on a goal prompt, we would do so by defining `class MyDatasetUserSimulatorTeacher(MyDataset_DataParser, TodUserSimulatorTeacher)`.
For example, if we have a `MyDataset_DataParser(TodStructuredDataParser)` and wanted to make a teacher to train a model to generate User Utterances based on a goal prompt, we would do so by defining `class MyDatasetUserSimulatorTeacher(MyDataset_DataParser, TodUserSimulatorTeacher)`.
"""

@classmethod
Expand Down
4 changes: 2 additions & 2 deletions parlai/core/torch_generator_agent.py
Expand Up @@ -325,8 +325,8 @@ class TorchGeneratorAgent(TorchAgent, ABC):

TorchGeneratorAgent aims to handle much of the bookkeeping and infrastructure work
for any generative models, like seq2seq or transformer. It implements the train_step
and eval_step. The only requirement is that your model *must* implemented the
interface TorchGeneratorModel interface.
and eval_step. The only requirement is that your model *must* be implemented with
the TorchGeneratorModel interface.
"""

@classmethod
Expand Down
243 changes: 234 additions & 9 deletions parlai/tasks/multiwoz_v22/agents.py
Expand Up @@ -5,24 +5,27 @@
# LICENSE file in the root directory of this source tree.

"""
implementation for ParlAI.
Multiwoz 2.2 Dataset implementation for ParlAI.
"""

from parlai.core.params import ParlaiParser
import copy
import json
import os
from typing import Optional

import numpy as np
import pandas as pd
from parlai.core.opt import Opt

import parlai.core.tod.tod_agents as tod_agents
import parlai.core.tod.tod_core as tod
import json
from typing import Optional
import parlai.tasks.multiwoz_v22.build as build_
from parlai.core.message import Message
from parlai.core.metrics import AverageMetric
from parlai.core.opt import Opt
from parlai.core.params import ParlaiParser
from parlai.utils.data import DatatypeHelper
from parlai.utils.io import PathManager

import parlai.tasks.multiwoz_v22.build as build_
import parlai.core.tod.tod_agents as tod_agents


DOMAINS = [
"attraction",
"bus",
Expand All @@ -36,6 +39,14 @@

WELL_FORMATTED_DOMAINS = ["attraction", "bus", "hotel", "restaurant", "train", "taxi"]

# Number of `dialogues_NNN.json` shard files per fold; `_load_data` reads
# shards 1..DATA_LEN[fold] from the fold's directory.
DATA_LEN = {"train": 17, "dev": 2, "test": 2}

# Seed for the NumPy RandomState used to shuffle the generated examples.
SEED = 42


def fold_size(fold: str) -> int:
    """
    Return the shard count recorded in ``DATA_LEN`` for ``fold``.

    :param fold: a key of ``DATA_LEN`` ("train", "dev" or "test").
    :raises KeyError: if ``fold`` is not a key of ``DATA_LEN``.
    """
    return DATA_LEN[fold]


class MultiwozV22Parser(tod_agents.TodStructuredDataParser):
"""
Expand Down Expand Up @@ -373,6 +384,220 @@ def get_id_task_prefix(self):
return "MultiwozV22"


class MultiWOZv22DSTTeacher(tod_agents.TodUserSimulatorTeacher):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It might be better to add a comment saying that this teacher is needed to reproduce the Joint Goal Accuracy values reported in the SimpleTOD and SOLOIST papers.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it better now? @chinnadhurai

"""
This Teacher is responsible for performing the task of Dialogue State Tracking.

It can be used to evaluate LM on JGA (Joint Goal Accuracy) metric (as shown in
[SimpleTOD](https://arxiv.org/abs/2005.00796) and
[Soloist](https://arxiv.org/abs/2005.05298)).
"""

# Separator placed between "domain slot_type slot_value" triples in the label
# string and expected between triples in the model's prediction.
BELIEF_STATE_DELIM = " ; "

# Domains kept when parsing a belief-state string; a triple whose first token
# is not listed here is dropped by `_extract_slot_from_string`.
domains = [
    "attraction",
    "hotel",
    "hospital",
    "restaurant",
    "police",
    "taxi",
    "train",
]

# "domain--slot_type" keys counted as named entities by the named-entity and
# hallucination metrics in `custom_evaluation`.
# NOTE(review): "bus" is not in `domains`, so the two bus entries can never
# match, and "train--destination" is absent while "train--departure" is
# present -- confirm whether this is intentional.
named_entity_slots = {
    "attraction--name",
    "restaurant--name",
    "hotel--name",
    "bus--departure",
    "bus--destination",
    "taxi--departure",
    "taxi--destination",
    "train--departure",
}

# Generator seeded with SEED and used by `setup_data` to shuffle examples.
# It is a class attribute, so every instance shares the same generator state.
rng = np.random.RandomState(SEED)

def __init__(self, opt: Opt, shared=None, *args, **kwargs):
    """
    Record fold and path information, build the data if needed, then init the parent.

    :param opt: options; ``opt["datatype"]`` and ``opt["datapath"]`` are read and
        ``opt["datafile"]`` is overwritten with the fold name.
    :param shared: shared state from another instance; when ``None`` the dataset
        build step is run first.

    Extra positional and keyword arguments are accepted but not used.
    """
    self.opt = opt
    # The fold is whatever precedes the first ":" in the datatype string.
    datatype_head, _, _ = opt["datatype"].partition(":")
    self.fold = datatype_head
    opt["datafile"] = datatype_head

    task_name = "multiwoz_v22"
    self.dpath = os.path.join(opt["datapath"], task_name)
    self.id = task_name

    if shared is None:
        build_.build(opt)
    super().__init__(opt, shared)

def _load_data(self, fold):
    """
    Read every dialogue shard of ``fold`` and return the dialogues as one list.

    :param fold: fold name; "valid" is mapped to the "dev" directory, any other
        value is used as the directory name unchanged.
    :return: concatenation of the JSON content of all shard files, in shard order.
    """
    split_name = fold if fold != "valid" else "dev"
    split_dir = os.path.join(self.dpath, split_name)
    # Shards are numbered from 1 and zero-padded to three digits.
    shard_paths = (
        os.path.join(split_dir, f"dialogues_{shard:03d}.json")
        for shard in range(1, fold_size(split_name) + 1)
    )
    loaded = []
    for shard_path in shard_paths:
        with PathManager.open(shard_path, "r") as handle:
            loaded.extend(json.load(handle))
    return loaded

def _get_curr_belief_states(self, turn):
belief_states = []
for frame in turn["frames"]:
if "state" in frame:
if "slot_values" in frame["state"]:
for domain_slot_type in frame["state"]["slot_values"]:
for slot_value in frame["state"]["slot_values"][
domain_slot_type
]:
domain, slot_type = domain_slot_type.split("-")
belief_state = f"{domain} {slot_type} {slot_value.lower()}"
belief_states.append(belief_state)
return list(set(belief_states))

def _extract_slot_from_string(self, slots_string):
"""
Either ground truth or generated result should be in the format: "dom slot_type
slot_val, dom slot_type slot_val, ..., dom slot_type slot_val," and this
function would reformat the string into list:

["dom--slot_type--slot_val", ... ]
"""

slots_list = []
per_domain_slot_lists = {}
named_entity_slot_lists = []

# split according to ";"
str_split = slots_string.split(self.BELIEF_STATE_DELIM)

if str_split[-1] == "":
str_split = str_split[:-1]

str_split = [slot.strip() for slot in str_split]

for slot_ in str_split:
slot = slot_.split()
if len(slot) > 2 and slot[0] in self.domains:
domain = slot[0]
slot_type = slot[1]
slot_val = " ".join(slot[2:])
if not slot_val == "dontcare":
slots_list.append(domain + "--" + slot_type + "--" + slot_val)
if domain in per_domain_slot_lists:
per_domain_slot_lists[domain].add(slot_type + "--" + slot_val)
else:
per_domain_slot_lists[domain] = {slot_type + "--" + slot_val}
if domain + "--" + slot_type in self.named_entity_slots:
named_entity_slot_lists.append(
domain + "--" + slot_type + "--" + slot_val
)
return slots_list, per_domain_slot_lists, named_entity_slot_lists

def custom_evaluation(
    self, teacher_action: Message, labels, model_response: Message
):
    """
    Record dialogue-state-tracking metrics for one example.

    The joint goal accuracy ("jga") is the share of turns for which the set of
    predicted (domain, slot_type, slot_value) triples equals the gold set.
    Slot-level recall ("slot_r") and precision ("slot_p") are recorded overall
    and per domain, and again for named-entity slots only (the "_ne" keys),
    together with a "hallucination" rate over predicted named-entity slots.

    Nothing is recorded when the model response has no text. A per-domain
    "jga" is only recorded for domains present in both gold and prediction.

    :param teacher_action: the teacher's message (not used here).
    :param labels: gold labels; only ``labels[0]`` is parsed.
    :param model_response: the model's message; its "text" is parsed.
    """
    resp = model_response.get("text")
    if not resp:
        return

    gold, gold_by_domain, gold_entities = self._extract_slot_from_string(labels[0])
    pred, pred_by_domain, pred_entities = self._extract_slot_from_string(resp)

    # Slot recall: is each gold slot found in the prediction?
    for slot in gold:
        domain = slot.split("--")[0]
        self.metrics.add("all/slot_r", AverageMetric(slot in pred))
        self.metrics.add(f"{domain}/slot_r", AverageMetric(slot in pred))

    # Hallucination: predicted named entities that are not in the gold set.
    for slot in pred_entities:
        self.metrics.add("hallucination", AverageMetric(slot not in gold_entities))

    # Slot precision: is each predicted slot part of the gold state?
    for slot in pred:
        domain = slot.split("--")[0]
        self.metrics.add("all/slot_p", AverageMetric(slot in gold))
        self.metrics.add(f"{domain}/slot_p", AverageMetric(slot in gold))

    self.metrics.add("jga", AverageMetric(set(gold) == set(pred)))
    self.metrics.add(
        "named_entities/jga",
        AverageMetric(set(gold_entities) == set(pred_entities)),
    )

    # Named-entity recall and precision, checked against the full slot lists.
    for slot in gold_entities:
        domain = slot.split("--")[0]
        self.metrics.add("all_ne/slot_r", AverageMetric(slot in pred))
        self.metrics.add(f"{domain}_ne/slot_r", AverageMetric(slot in pred))
    for slot in pred_entities:
        domain = slot.split("--")[0]
        self.metrics.add("all_ne/slot_p", AverageMetric(slot in gold))
        self.metrics.add(f"{domain}_ne/slot_p", AverageMetric(slot in gold))

    for domain, gold_domain_slots in gold_by_domain.items():
        if domain not in pred_by_domain:
            continue
        self.metrics.add(
            f"{domain}/jga",
            AverageMetric(gold_domain_slots == pred_by_domain[domain]),
        )

def setup_data(self, fold):
    """
    Yield one single-turn episode per user turn of the fold.

    For each user turn the example text is the lower-cased dialogue history up
    to and including that turn, each utterance prefixed with "<user>" or
    "<system>", and the label is the turn's belief states joined with
    ``BELIEF_STATE_DELIM``.

    :param fold: fold name passed on to ``_load_data``.
    :return: generator of ``(example, True)`` pairs; the flag marks every
        example as the start of a new episode.
    """
    examples = []
    for dialog in self._load_data(fold):
        context = []
        for turn in dialog["turns"]:
            is_user = turn["speaker"].lower() == "user"
            speaker_tag = "<user>" if is_user else "<system>"
            context.append(f"{speaker_tag} {turn['utterance'].lower()}")
            if not is_user:
                # System turns only extend the history; no example is built.
                continue
            belief_states = self._get_curr_belief_states(turn)
            examples.append(
                {
                    "dialogue_id": dialog["dialogue_id"],
                    "turn_num": turn["turn_id"],
                    "text": " ".join(context),
                    # Sorted so the label string is identical from run to
                    # run; iterating a set of strings is not.
                    "labels": self.BELIEF_STATE_DELIM.join(
                        sorted(set(belief_states))
                    ),
                }
            )

    self.rng.shuffle(examples)
    for example in examples:
        yield example, True


class UserSimulatorTeacher(MultiwozV22Parser, tod_agents.TodUserSimulatorTeacher):
    # Defines nothing of its own; all behaviour comes from combining the two
    # base classes.
    pass

Expand Down
4 changes: 4 additions & 0 deletions parlai/tasks/multiwoz_v22/test.py
Expand Up @@ -13,3 +13,7 @@ class TestSystemTeacher(AutoTeacherTest):

class TestUserSimulatorTeacher(AutoTeacherTest):
task = "multiwoz_v22:UserSimulatorTeacher"


class TestMultiWOZv22DSTTeacher(AutoTeacherTest):
task = "multiwoz_v22:MultiWOZv22DSTTeacher"
2 changes: 1 addition & 1 deletion tests/nightly/gpu/test_bart.py
Expand Up @@ -43,7 +43,7 @@ def test_bart_gen(self):

def test_bart_cache_text_vec(self):
"""
Test BART text vec caching
Test BART text vec caching.
"""
opt = ParlaiParser(True, True).parse_args(['--model', 'bart'])
bart = create_agent(opt)
Expand Down