In [1]:
from typing import Iterator, List, Dict
import torch
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
import numpy as np
import pandas as pd
import re
import os
import glob
import json
from collections import defaultdict
import itertools
import random
from tqdm import tqdm
from sklearn.metrics import f1_score

from allennlp.data import Instance
from allennlp.data.fields import TextField, MetadataField, ListField, ArrayField, LabelField, MultiLabelField
from allennlp.data.dataset_readers import DatasetReader
from allennlp.data.token_indexers import TokenIndexer, SingleIdTokenIndexer
from allennlp.data.token_indexers.wordpiece_indexer import PretrainedBertIndexer
from allennlp.data.tokenizers import Token
from allennlp.data.tokenizers.word_splitter import BertBasicWordSplitter
from allennlp.data.vocabulary import Vocabulary
from allennlp.data.dataset import Batch
from allennlp.models import Model
from allennlp.modules.text_field_embedders import TextFieldEmbedder, BasicTextFieldEmbedder
from allennlp.modules.token_embedders import Embedding
from allennlp.modules.seq2seq_encoders import Seq2SeqEncoder, PytorchSeq2SeqWrapper
from allennlp.modules.attention.bilinear_attention import BilinearAttention
from allennlp.nn import util
from allennlp.training.metrics import CategoricalAccuracy, BooleanAccuracy
from allennlp.training.metrics.metric import Metric
from allennlp.data.iterators import BucketIterator
from allennlp.training.trainer import Trainer
from allennlp.modules.token_embedders.bert_token_embedder import BertEmbedder, PretrainedBertEmbedder

import warnings
warnings.filterwarnings("ignore")

In [2]:
# Experiment tag: names the serialization directory, the saved model/vocabulary
# files and the default prediction output directory.
EXP = 'st_baseline2-5'
# GPU index used for training when CUDA is available.
CUDA = 2
# Changes made in this experiment (author's notes):
#   * correct loss function
#   * add frame_change indicator
#   * trainable GRU init state
#   * -2 layer hidden
#     NOTE(review): the BERT embedder is later built with top_layer_only=True;
#     confirm this note still describes the code.

In [3]:
# Pretrained BERT checkpoint; the reader lowercases text only if the name contains 'uncased'.
bert_model = 'bert-base-cased'
# Column of each slot-level system act in the per-slot act matrix.
SLOT_ACTS = {act: index for index, act in enumerate(['INFORM', 'CONFIRM', 'OFFER', 'REQUEST'])}
# Position of each slot-independent system act in the general act vector.
GENERAL_ACTS = {act: index for index, act in enumerate(
    ['NOTIFY_SUCCESS', 'NOTIFY_FAILURE', 'REQ_MORE', 'GOODBYE', 'INFORM_COUNT'])}
class DialogueDatasetReader(DatasetReader):
    """
    Turns a directory of dialogue JSON files into one instance per (user turn, frame).

    ``read(path)`` expects ``path`` to contain ``schema.json`` (a list of services, each
    with ``service_name``, ``intents`` and ``slots``) and one or more ``dialogues*.json``
    files. Each instance holds the user utterance, the text descriptions of every intent
    and slot of the frame's service, and numeric encodings of system actions for that
    service. When the frame's state has an ``active_intent``, intent and requested-slot
    labels are added as well.

    Intent label indices per service are: 0 = 'NONE', 1 = 'UNCHANGED' (same active intent
    as the previous labelled user frame of the dialogue), then the schema intents in order.
    """
    def __init__(self,
                 bert_model: str = bert_model,
                 token_indexers: Dict[str, TokenIndexer] = None) -> None:
        super().__init__(lazy=False)
        do_lowercase = 'uncased' in bert_model
        # Default: BERT wordpiece indexer matching `bert_model`'s casing.
        self.token_indexers = token_indexers or {
            "tokens": PretrainedBertIndexer(bert_model, do_lowercase=do_lowercase, use_starting_offsets=True)}
        self.tokenizer = BertBasicWordSplitter(do_lower_case=do_lowercase)
        # Not read anywhere in this class; kept so existing attribute access keeps working.
        self.intent_cache = {}
        self.slot_cache = {}


    def text_to_instance(self, dialogue_id: str,
                         turn_id: int,
                         user_utterance: str,
                         service: str,
                         general_encoding: np.ndarray,
                         intent_encoding: np.ndarray,
                         slot_encoding: np.ndarray,
                         active_intent: str = None,
                         requested_slots: List[str] = None) -> Instance:
        """
        Build the instance for one user frame.

        The three ``*_encoding`` arrays are the system-act features computed by ``_read``.
        ``active_intent`` / ``requested_slots`` are optional labels; without an
        ``active_intent`` no label fields are added. Descriptions and label indices come
        from the schema loaded by ``_read``, so that must have run first.
        """
        meta_field = MetadataField(dict(dialogue_id=dialogue_id, turn_id=turn_id, service=service))
        user_utterance_field = TextField(self.tokenizer.split_words(user_utterance), self.token_indexers)

        intent_field = ListField(
            [TextField(self.tokenizer.split_words(intent['description']), self.token_indexers)
             for intent in self.intent_schema[service]])
        slot_field = ListField(
            [TextField(self.tokenizer.split_words(slot['description']), self.token_indexers)
             for slot in self.slot_schema[service]])

        fields = {"meta": meta_field,
                  "user_utterance": user_utterance_field,
                  "intent_descriptions": intent_field,
                  "slot_descriptions": slot_field,
                  "general_acts": ArrayField(general_encoding),
                  "intent_acts": ArrayField(intent_encoding),
                  "slot_acts": ArrayField(slot_encoding)}

        if active_intent:
            fields['intent_label'] = ArrayField(np.array(self.intent_to_index[service][active_intent]))
            # One 0/1 flag per schema slot; padded slots get -1.
            fields['slot_label'] = ArrayField(
                np.array([int(slot['name'] in requested_slots) for slot in self.slot_schema[service]]),
                padding_value=-1
            )

        return Instance(fields)


    def _load_schema(self, file_path: str) -> None:
        """Load ``file_path/schema.json`` and build the per-service name <-> index lookups."""
        with open(os.path.join(file_path, 'schema.json'), 'r', encoding='utf-8') as f:
            schema = json.load(f)
        self.intent_schema = {x['service_name']: x['intents'] for x in schema}
        self.slot_schema = {x['service_name']: x['slots'] for x in schema}
        # Indices 0 and 1 are reserved for the special 'NONE' and 'UNCHANGED' labels.
        self.index_to_intent = {
            service: ['NONE', 'UNCHANGED'] + [intent['name'] for intent in intents]
            for service, intents in self.intent_schema.items()
        }
        self.intent_to_index = {
            service: {intent: i for i, intent in enumerate(intents)}
            for service, intents in self.index_to_intent.items()
        }
        self.index_to_slot = {
            service: [slot['name'] for slot in slots]
            for service, slots in self.slot_schema.items()
        }
        self.slot_to_index = {
            service: {slot: i for i, slot in enumerate(slots)}
            for service, slots in self.index_to_slot.items()
        }


    def _zero_encodings(self, service: str):
        """All-zero ``(general, intent, slot)`` system-act encodings for ``service``."""
        return (np.zeros(len(GENERAL_ACTS) + 1),
                np.zeros(len(self.intent_schema[service])),
                np.zeros((len(self.slot_schema[service]), len(SLOT_ACTS))))


    def _encode_system_frame(self, frame):
        """
        Encode the actions of one system frame as ``(general, intent, slot)`` arrays.

        * general: a flag per GENERAL_ACTS entry plus a last cell holding the
          INFORM_COUNT value (0 if that act is absent);
        * intent: a flag per schema intent offered via OFFER_INTENT;
        * slot: a (num_slots, len(SLOT_ACTS)) matrix of slot-act flags.
        """
        service = frame['service']
        general_encoding, intent_encoding, slot_encoding = self._zero_encodings(service)
        for action in frame['actions']:
            act = action['act']
            if act == 'OFFER_INTENT':
                for offered_intent in action['values']:
                    # intent_to_index counts 'NONE'/'UNCHANGED' first; this encoding does not.
                    intent_encoding[self.intent_to_index[service][offered_intent] - 2] = 1
            elif act in SLOT_ACTS:
                slot_encoding[self.slot_to_index[service][action['slot']], SLOT_ACTS[act]] = 1
            else:
                general_encoding[GENERAL_ACTS[act]] = 1
                if act == 'INFORM_COUNT':
                    general_encoding[-1] = int(action['values'][0])
        return general_encoding, intent_encoding, slot_encoding


    @staticmethod
    def _order_frames(frames, last_service):
        """Return ``frames`` with the frame for ``last_service`` moved to the front; others keep their order."""
        ordered = []
        for frame in frames:
            if frame['service'] == last_service:
                ordered = [frame] + ordered
            else:
                ordered.append(frame)
        return ordered


    def _read_dialogue(self, dialog) -> Iterator[Instance]:
        """Yield one instance per (user turn, frame) of ``dialog``, in dialogue order."""
        dialogue_id = dialog['dialogue_id']
        turn_id = 0  # number of SYSTEM turns seen so far
        last_service = None
        last_intent = None
        # Most recent system-act encodings per service.
        # NOTE(review): entries are never cleared, so a user frame can receive the acts of an
        # older system turn after the system has switched services -- confirm this is intended.
        act_dict = {}
        for turn in dialog['turns']:
            if turn['speaker'] == 'SYSTEM':
                for frame in turn['frames']:  # only one frame
                    act_dict[frame['service']] = self._encode_system_frame(frame)
                turn_id += 1
                continue
            if len(turn['frames']) > 1:
                # The frame continuing the previous service goes first.
                turn['frames'] = self._order_frames(turn['frames'], last_service)
            for frame in turn['frames']:
                user_utterance = turn['utterance']
                service = frame['service']
                general_encoding, intent_encoding, slot_encoding = (
                    act_dict.get(service) or self._zero_encodings(service))
                state = frame['state']
                if 'active_intent' in state:
                    if last_intent == state['active_intent']:
                        active_intent = 'UNCHANGED'
                    else:
                        active_intent = state['active_intent']
                        last_intent = active_intent
                    requested_slots = state['requested_slots']
                else:
                    active_intent = None
                    requested_slots = None
                # Extra general feature: 1 when this frame's service differs from the previous frame's.
                new_service = 1 if service != last_service else 0
                general_encoding = np.concatenate([general_encoding, [new_service]])
                yield self.text_to_instance(dialogue_id, turn_id, user_utterance, service,
                                            general_encoding, intent_encoding, slot_encoding,
                                            active_intent, requested_slots)
                last_service = service


    def _read(self, file_path: str) -> Iterator[Instance]:
        """
        Yield instances for every dialogue in ``file_path/dialogues*.json`` (files sorted by name).

        Loads ``file_path/schema.json`` first, replacing any lookups built by an earlier ``read``.
        """
        self._load_schema(file_path)
        for file in sorted(glob.glob(os.path.join(file_path, 'dialogues*.json'))):
            # Explicit UTF-8 so reading does not depend on the platform's default locale.
            with open(file, 'r', encoding='utf-8') as f:
                dialogues = json.load(f)
            for dialog in dialogues:
                yield from self._read_dialogue(dialog)

In [4]:
class DialogueIterator(object):
    """
    A custom iterator providing each dialogue as a batch.

    Offers ``__call__``, ``get_num_batches`` and ``index_with`` so it can be handed to
    AllenNLP's ``Trainer`` in place of a ``DataIterator``. Instances are grouped by
    ``meta['dialogue_id']``; each group, in its original order, becomes one tensor dict.
    Tensor dicts are built once per dataset object and reused across epochs. Cached
    datasets stay referenced (and therefore alive) for the lifetime of the iterator.
    """
    def __init__(self):
        self._epochs = defaultdict(int)
        # Both caches map id(instances) -> (instances, value). Holding the dataset itself
        # guarantees its id() cannot be recycled by a different object while cached.
        self._num_batches_cache = {}
        self._cache = {}
        # Set by index_with(); while None, instances are tensorised without indexing.
        self.vocab = None


    def __call__(self, instances,
                 num_epochs: int = None,
                 shuffle: bool = True):
        """
        Yield one tensor dict per dialogue for ``num_epochs`` epochs (forever when None).

        With ``shuffle`` the order of dialogues (not of turns within a dialogue) is
        shuffled in place each epoch using the ``random`` module.
        """
        key = id(instances)
        starting_epoch = self._epochs[key]

        if num_epochs is None:
            epochs = itertools.count(starting_epoch)
        else:
            epochs = range(starting_epoch, starting_epoch + num_epochs)

        for epoch in epochs:
            tensor_dicts = self._cached_tensor_dicts(instances)
            if shuffle:
                random.shuffle(tensor_dicts)
            yield from tensor_dicts
            self._epochs[key] = epoch + 1


    def _cached_tensor_dicts(self, instances):
        """Return the per-dialogue tensor dicts for ``instances``, building them on first use."""
        key = id(instances)
        if key not in self._cache:
            self._cache[key] = (instances, self._create_tensor_dicts(instances))
        return self._cache[key][1]


    def _create_tensor_dicts(self, instances):
        """Group ``instances`` by dialogue id and turn each group into a padded tensor dict."""
        tensor_dicts = []
        dialog_dict = defaultdict(list)
        for instance in instances:
            dialogue_id = instance.fields['meta'].metadata['dialogue_id']
            dialog_dict[dialogue_id].append(instance)
        for dialog in dialog_dict.values():
            batch = Batch(dialog)
            if self.vocab is not None:
                batch.index_instances(self.vocab)
            padding_lengths = batch.get_padding_lengths()
            tensor_dict = batch.as_tensor_dict(padding_lengths)
            tensor_dicts.append(tensor_dict)
        return tensor_dicts


    def get_num_batches(self, instances) -> int:
        """Number of batches per epoch, i.e. the number of distinct dialogue ids."""
        key = id(instances)
        if key not in self._num_batches_cache:
            num_dialogues = len(
                {instance.fields['meta'].metadata['dialogue_id'] for instance in instances})
            self._num_batches_cache[key] = (instances, num_dialogues)
        return self._num_batches_cache[key][1]


    def index_with(self, vocab: Vocabulary):
        """Use ``vocab`` to index instances before they are converted to tensors."""
        self.vocab = vocab

In [5]:
class LossMetric(Metric):
    """
    Running mean of per-batch loss values, reported through AllenNLP's metric interface.

    Each call records one batch loss; ``get_metric`` returns the mean over the calls
    since the last reset, or 0.0 if nothing has been recorded.
    """
    def __init__(self):
        self.sum = 0.
        self.count = 0

    def __call__(self, loss: torch.Tensor):
        """Record one batch loss (a scalar tensor; plain numbers are accepted too)."""
        # .item() copies the scalar to the host as a Python float: no tensor is kept
        # around and the running sum accumulates in float64 rather than float32.
        self.sum += loss.item() if torch.is_tensor(loss) else float(loss)
        self.count += 1

    def get_metric(self, reset: bool = False):
        """Return the mean recorded loss, optionally resetting the accumulator afterwards."""
        loss = self.sum / self.count if self.count > 0 else 0.
        if reset:
            self.reset()
        return loss

    def reset(self):
        """Forget all recorded losses."""
        self.sum = 0.
        self.count = 0

In [6]:
class StateTracker(Model):
    """
    Dialogue state tracker: predicts the active intent and requested slots of each user turn.

    Dimension 0 of every input is the sequence of (user turn, frame) items of a single
    dialogue; the GRU runs over that axis with batch size 1. For each item the model embeds
    the user utterance and the intent / slot descriptions of the item's service, combines the
    descriptions with the system-act features, and outputs

    * ``intent_logits`` (turns, 2 + max_intents): columns 0/1 are the special 'NONE' and
      'UNCHANGED' labels, the rest score schema intents (padding intents get -inf);
    * ``slot_sigmoids`` (turns, max_slots): independent "slot is requested" probabilities.
    """
    def __init__(self,
                 sentence_embedder: TextFieldEmbedder,
                 hidden_size: int,
                 vocab: Vocabulary,
                 general_act_size: int = len(GENERAL_ACTS) + 2,
                 intent_act_size: int = 1,
                 slot_act_size: int = len(SLOT_ACTS),
                 dropout: float = 0.1) -> None:
        super().__init__(vocab)
        self.sentence_embedder = sentence_embedder
        self.dropout = torch.nn.Dropout(dropout)
        emb_size = sentence_embedder.get_output_dim()
        # Description embedding + its act features -> hidden.
        self.intent_projection = torch.nn.Linear(emb_size + intent_act_size, hidden_size)
        self.slot_projection = torch.nn.Linear(emb_size + slot_act_size, hidden_size)
        # Per-turn input: [utterance emb, general acts, intent summary, slot summary].
        self.gru = torch.nn.GRU(input_size=hidden_size * 2 + emb_size + general_act_size,
                                hidden_size=hidden_size)
        # Trainable initial GRU state.
        self.h_0 = torch.nn.Parameter(torch.randn(1, 1, hidden_size))
        # Scorers see [description input, GRU state, utterance emb, general acts].
        self.intent_output_layer = torch.nn.Linear(hidden_size + emb_size*2 + intent_act_size + general_act_size, 1)
        self.none_and_unchanged_layer = torch.nn.Linear(hidden_size, 2)
        self.slot_output_layer = torch.nn.Linear(hidden_size + emb_size*2 + slot_act_size + general_act_size, 1)
        self.intent_accuracy = CategoricalAccuracy()
        self.slot_accuracy = BooleanAccuracy()
        self.intent_loss = LossMetric()
        self.slot_loss = LossMetric()


    def get_sentence_embedding(self, tensor_dict: Dict[str, torch.Tensor]):
        """
        Embed a batch of token sequences into one vector per sequence.

        Ids > 0 in ``tensor_dict['tokens']`` count as real tokens (0 is padding).
        ``get_final_encoder_states(..., bidirectional=True)`` joins the first half of the
        last real token's vector with the second half of the first token's vector.
        """
        token_embeddings = self.dropout(self.sentence_embedder(tensor_dict))
        mask = (tensor_dict['tokens'] > 0).byte()  # mask for original bert
        return util.get_final_encoder_states(token_embeddings, mask, bidirectional=True)


    def get_valid_description_embeddings(self, descriptions: torch.Tensor):
        """
        Embed the non-padding rows of one turn's description id matrix.

        A row is valid when its first id is > 0; as in the original indexing, padding
        rows are assumed to come after all valid rows, so the valid rows form a prefix.
        """
        num_valid = int((descriptions[:, 0] > 0).sum())
        # Slicing keeps the tensor on its own device (get_device() is -1 for CPU tensors,
        # which is not a valid `device=` argument).
        valid_descriptions = descriptions[:num_valid]
        return self.get_sentence_embedding({'tokens': valid_descriptions})


    def get_desc_emb_and_mask(self, tensor_dict: Dict[str, torch.Tensor]):
        """
        Embed every description of every turn.

        Returns ``(emb, mask)``: ``emb`` is (turns, max_valid_descriptions, emb_size),
        zero-padded; ``mask`` marks non-padding descriptions (first id > 0).
        """
        emb_list = [self.get_valid_description_embeddings(descriptions)
                    for descriptions in tensor_dict['tokens']]
        emb = pad_sequence(emb_list, batch_first=True)
        mask = (tensor_dict['tokens'][:, :, 0] > 0).byte()
        return emb, mask


    def forward(self,
                meta: Dict[str, str],
                user_utterance: Dict[str, torch.Tensor],
                intent_descriptions: Dict[str, torch.Tensor],
                slot_descriptions: Dict[str, torch.Tensor],
                general_acts: torch.Tensor,
                intent_acts: torch.Tensor,
                slot_acts: torch.Tensor,
                intent_label: torch.Tensor = None,
                slot_label: torch.Tensor = None) -> Dict[str, torch.Tensor]:
        """Score intents and slots for every turn; adds ``loss`` when labels are given."""
        user_utterance_emb = self.get_sentence_embedding(user_utterance)
        # user_utterance_emb: (turn_size, emb_size)

        intent_emb, intent_mask = self.get_desc_emb_and_mask(intent_descriptions)
        intent_input = self.dropout(torch.cat([intent_emb, intent_acts.unsqueeze(-1)], dim=2))
        intent_encodings = F.relu(self.intent_projection(intent_input))
        intent_encoding = util.masked_mean(
            intent_encodings, intent_mask.unsqueeze(-1).expand(-1, -1, intent_encodings.size(2)), dim=1)

        slot_emb, slot_mask = self.get_desc_emb_and_mask(slot_descriptions)
        slot_input = self.dropout(torch.cat([slot_emb, slot_acts], dim=2))
        slot_encodings = F.relu(self.slot_projection(slot_input))
        slot_encoding = util.masked_mean(
            slot_encodings, slot_mask.unsqueeze(-1).expand(-1, -1, slot_encodings.size(2)), dim=1)

        general_encoding = torch.cat([user_utterance_emb, general_acts], dim=1)
        # Turns become the GRU's time axis, with batch size 1.
        gru_input = torch.cat([general_encoding, intent_encoding, slot_encoding], dim=1).unsqueeze(1)
        states, _ = self.gru(gru_input, self.h_0)
        # states: (turn_size, 1, hidden_size)

        turn_input = torch.cat([states, general_encoding.unsqueeze(1)], dim=2)

        intent_num = intent_encodings.size(1)
        intent_state = torch.cat([intent_input, turn_input.expand(-1, intent_num, -1)], dim=2)
        intent_logits = self.intent_output_layer(intent_state).squeeze(2)
        # Padding intents can never win the softmax.
        intent_logits = intent_logits.masked_fill(1 - intent_mask, float('-inf'))
        # Logits for the special NONE and UNCHANGED labels (columns 0 and 1).
        special_logits = self.none_and_unchanged_layer(states.squeeze(1))
        intent_logits = torch.cat([special_logits, intent_logits], dim=1)

        slot_num = slot_encodings.size(1)
        slot_state = torch.cat([slot_input, turn_input.expand(-1, slot_num, -1)], dim=2)
        # squeeze(2) only: a bare squeeze() would also drop the turn or slot axis of a
        # one-turn dialogue or a one-slot service.
        slot_logits = self.slot_output_layer(slot_state).squeeze(2)
        slot_sigmoids = torch.sigmoid(slot_logits)

        output = dict(meta=meta, intent_logits=intent_logits, slot_sigmoids=slot_sigmoids)

        if intent_label is not None:
            output['intent_label'] = intent_label
            output['slot_label'] = slot_label
            self.intent_accuracy(intent_logits, intent_label)
            self.slot_accuracy((slot_sigmoids > 0.5).float(), slot_label, slot_mask.float())
            intent_loss = F.cross_entropy(intent_logits, intent_label.long())
            # Same loss as BCE on slot_sigmoids, computed from logits for numerical stability;
            # padded slots (label -1) are excluded by the mask.
            slot_losses = F.binary_cross_entropy_with_logits(slot_logits, slot_label, reduction='none')
            slot_loss = util.masked_mean(slot_losses, slot_mask, dim=1).mean()
            self.intent_loss(intent_loss)
            self.slot_loss(slot_loss)
            output['loss'] = intent_loss + slot_loss
        return output


    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """Accuracy and mean loss for intents and slots since the last reset."""
        return {"intent_acc": self.intent_accuracy.get_metric(reset),
                "slot_acc": self.slot_accuracy.get_metric(reset),
                "intent_loss": self.intent_loss.get_metric(reset),
                "slot_loss": self.slot_loss.get_metric(reset)}

In [7]:
HIDDEN_DIM = 512  # hidden size of the StateTracker's projections and GRU

reader = DialogueDatasetReader()
train_dataset = reader.read('data/train')
validation_dataset = reader.read('data/dev')
# The vocabulary is built from both splits together.
vocab = Vocabulary.from_instances(train_dataset + validation_dataset)

# BERT (last layer only) applied to the "tokens" wordpiece ids.
# NOTE(review): allow_unmatched_keys is presumably needed because the BERT indexer
# emits extra keys besides "tokens" -- confirm.
token_embedding = PretrainedBertEmbedder(bert_model, top_layer_only=True)
sentence_embedder = BasicTextFieldEmbedder(
    {"tokens": token_embedding},
    embedder_to_indexer_map={"tokens": ["tokens"]},
    allow_unmatched_keys=True,
)

175780it [06:54, 423.91it/s]
26077it [01:01, 423.47it/s]
100%|██████████| 201857/201857 [00:06<00:00, 31343.31it/s]


In [8]:
# Seed every RNG used below (parameter init incl. h_0, dropout, DialogueIterator's
# random.shuffle) so reruns start from the same state; cuDNN kernels may still be
# nondeterministic on GPU.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

model = StateTracker(sentence_embedder, HIDDEN_DIM, vocab)
if torch.cuda.is_available():
    cuda_device = CUDA
    model = model.cuda(cuda_device)
else:
    cuda_device = -1
# Move the model before creating the optimizer so it holds the on-device parameters.
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
iterator = DialogueIterator()
iterator.index_with(vocab)
serialization_dir = f'models/tensorboard_{EXP}'
trainer = Trainer(model=model,
                  optimizer=optimizer,
                  iterator=iterator,
                  train_dataset=train_dataset,
                  validation_dataset=validation_dataset,
                  patience=3,
                  num_epochs=1,
                  serialization_dir=serialization_dir,
                  summary_interval=200,
                  grad_norm=5,
                  cuda_device=cuda_device)
trainer.train()

intent_acc: 0.4781, slot_acc: 0.8531, intent_loss: 3.5848, slot_loss: 0.1377, loss: 3.7225 ||:   1%|          | 200/16142 [08:48<2:34:26,  1.72it/s]  



intent_acc: 0.5203, slot_acc: 0.8559, intent_loss: 3.0827, slot_loss: 0.1480, loss: 3.2307 ||:   2%|▏         | 400/16142 [11:04<2:54:28,  1.50it/s]



intent_acc: 0.5505, slot_acc: 0.8573, intent_loss: 2.7788, slot_loss: 0.1376, loss: 2.9164 ||:   4%|▎         | 600/16142 [13:25<3:51:44,  1.12it/s]



intent_acc: 0.5727, slot_acc: 0.8575, intent_loss: 2.6307, slot_loss: 0.1298, loss: 2.7605 ||:   5%|▍         | 800/16142 [15:28<3:11:01,  1.34it/s]



intent_acc: 0.5933, slot_acc: 0.8582, intent_loss: 2.4831, slot_loss: 0.1283, loss: 2.6114 ||:   6%|▌         | 1000/16142 [17:33<2:58:22,  1.41it/s]



intent_acc: 0.6125, slot_acc: 0.8588, intent_loss: 2.3657, slot_loss: 0.1262, loss: 2.4919 ||:   7%|▋         | 1200/16142 [19:36<2:57:10,  1.41it/s]



intent_acc: 0.6274, slot_acc: 0.8577, intent_loss: 2.2842, slot_loss: 0.1240, loss: 2.4082 ||:   9%|▊         | 1400/16142 [21:38<2:15:18,  1.82it/s]



intent_acc: 0.6384, slot_acc: 0.8572, intent_loss: 2.1934, slot_loss: 0.1268, loss: 2.3202 ||:  10%|▉         | 1600/16142 [23:42<2:29:08,  1.63it/s]



intent_acc: 0.6484, slot_acc: 0.8562, intent_loss: 2.1321, slot_loss: 0.1293, loss: 2.2614 ||:  11%|█         | 1800/16142 [25:52<2:50:41,  1.40it/s]



intent_acc: 0.6563, slot_acc: 0.8557, intent_loss: 2.0785, slot_loss: 0.1274, loss: 2.2059 ||:  12%|█▏        | 2000/16142 [27:58<2:54:10,  1.35it/s]



intent_acc: 0.6645, slot_acc: 0.8557, intent_loss: 2.0134, slot_loss: 0.1281, loss: 2.1414 ||:  14%|█▎        | 2200/16142 [30:17<3:19:28,  1.16it/s]



intent_acc: 0.6703, slot_acc: 0.8570, intent_loss: 1.9781, slot_loss: 0.1255, loss: 2.1037 ||:  15%|█▍        | 2400/16142 [32:40<2:48:05,  1.36it/s]



intent_acc: 0.6770, slot_acc: 0.8575, intent_loss: 1.9344, slot_loss: 0.1231, loss: 2.0575 ||:  16%|█▌        | 2600/16142 [34:44<2:36:05,  1.45it/s]



intent_acc: 0.6828, slot_acc: 0.8570, intent_loss: 1.9008, slot_loss: 0.1230, loss: 2.0237 ||:  17%|█▋        | 2800/16142 [37:05<2:58:41,  1.24it/s]



intent_acc: 0.6886, slot_acc: 0.8571, intent_loss: 1.8575, slot_loss: 0.1229, loss: 1.9804 ||:  19%|█▊        | 3000/16142 [39:20<3:04:38,  1.19it/s]



intent_acc: 0.6931, slot_acc: 0.8572, intent_loss: 1.8259, slot_loss: 0.1221, loss: 1.9480 ||:  20%|█▉        | 3200/16142 [41:32<2:34:55,  1.39it/s]



intent_acc: 0.6968, slot_acc: 0.8576, intent_loss: 1.7939, slot_loss: 0.1204, loss: 1.9143 ||:  21%|██        | 3400/16142 [43:41<2:28:12,  1.43it/s]



intent_acc: 0.7007, slot_acc: 0.8576, intent_loss: 1.7678, slot_loss: 0.1213, loss: 1.8891 ||:  22%|██▏       | 3600/16142 [45:45<2:50:47,  1.22it/s]



intent_acc: 0.7048, slot_acc: 0.8575, intent_loss: 1.7363, slot_loss: 0.1216, loss: 1.8579 ||:  24%|██▎       | 3800/16142 [47:50<2:29:35,  1.38it/s]



intent_acc: 0.7084, slot_acc: 0.8570, intent_loss: 1.7123, slot_loss: 0.1229, loss: 1.8352 ||:  25%|██▍       | 4000/16142 [49:54<2:04:34,  1.62it/s]



intent_acc: 0.7102, slot_acc: 0.8572, intent_loss: 1.7016, slot_loss: 0.1214, loss: 1.8230 ||:  26%|██▌       | 4200/16142 [51:55<2:52:59,  1.15it/s]



intent_acc: 0.7138, slot_acc: 0.8570, intent_loss: 1.6763, slot_loss: 0.1214, loss: 1.7977 ||:  27%|██▋       | 4400/16142 [54:02<2:16:38,  1.43it/s]



intent_acc: 0.7160, slot_acc: 0.8575, intent_loss: 1.6689, slot_loss: 0.1207, loss: 1.7896 ||:  28%|██▊       | 4600/16142 [56:05<1:54:31,  1.68it/s]



intent_acc: 0.7187, slot_acc: 0.8578, intent_loss: 1.6479, slot_loss: 0.1193, loss: 1.7671 ||:  30%|██▉       | 4800/16142 [58:11<2:21:40,  1.33it/s]



intent_acc: 0.7210, slot_acc: 0.8578, intent_loss: 1.6338, slot_loss: 0.1187, loss: 1.7525 ||:  31%|███       | 5000/16142 [1:00:14<2:53:52,  1.07it/s]



intent_acc: 0.7231, slot_acc: 0.8581, intent_loss: 1.6143, slot_loss: 0.1178, loss: 1.7321 ||:  32%|███▏      | 5200/16142 [1:02:16<2:00:25,  1.51it/s]



intent_acc: 0.7243, slot_acc: 0.8581, intent_loss: 1.6044, slot_loss: 0.1179, loss: 1.7223 ||:  33%|███▎      | 5400/16142 [1:04:24<2:13:17,  1.34it/s]



intent_acc: 0.7268, slot_acc: 0.8583, intent_loss: 1.5907, slot_loss: 0.1169, loss: 1.7075 ||:  35%|███▍      | 5600/16142 [1:06:33<2:06:24,  1.39it/s]



intent_acc: 0.7291, slot_acc: 0.8577, intent_loss: 1.5728, slot_loss: 0.1181, loss: 1.6909 ||:  36%|███▌      | 5800/16142 [1:08:41<2:35:31,  1.11it/s]



intent_acc: 0.7307, slot_acc: 0.8582, intent_loss: 1.5660, slot_loss: 0.1169, loss: 1.6828 ||:  37%|███▋      | 6000/16142 [1:10:49<1:53:45,  1.49it/s]



intent_acc: 0.7321, slot_acc: 0.8580, intent_loss: 1.5569, slot_loss: 0.1165, loss: 1.6734 ||:  38%|███▊      | 6200/16142 [1:13:03<2:34:44,  1.07it/s]



intent_acc: 0.7336, slot_acc: 0.8582, intent_loss: 1.5458, slot_loss: 0.1159, loss: 1.6617 ||:  40%|███▉      | 6400/16142 [1:15:18<2:10:57,  1.24it/s]



intent_acc: 0.7352, slot_acc: 0.8584, intent_loss: 1.5343, slot_loss: 0.1158, loss: 1.6501 ||:  41%|████      | 6600/16142 [1:17:36<1:40:26,  1.58it/s]



intent_acc: 0.7367, slot_acc: 0.8583, intent_loss: 1.5221, slot_loss: 0.1155, loss: 1.6376 ||:  42%|████▏     | 6800/16142 [1:20:04<3:04:36,  1.19s/it]



intent_acc: 0.7375, slot_acc: 0.8582, intent_loss: 1.5129, slot_loss: 0.1153, loss: 1.6282 ||:  43%|████▎     | 7000/16142 [1:22:22<1:55:11,  1.32it/s]



intent_acc: 0.7388, slot_acc: 0.8582, intent_loss: 1.5048, slot_loss: 0.1148, loss: 1.6196 ||:  45%|████▍     | 7200/16142 [1:24:41<2:06:50,  1.17it/s]



intent_acc: 0.7403, slot_acc: 0.8580, intent_loss: 1.4929, slot_loss: 0.1158, loss: 1.6087 ||:  46%|████▌     | 7400/16142 [1:27:32<2:32:14,  1.04s/it]



intent_acc: 0.7416, slot_acc: 0.8580, intent_loss: 1.4846, slot_loss: 0.1158, loss: 1.6003 ||:  47%|████▋     | 7600/16142 [1:30:59<3:15:40,  1.37s/it]



intent_acc: 0.7425, slot_acc: 0.8584, intent_loss: 1.4806, slot_loss: 0.1148, loss: 1.5954 ||:  48%|████▊     | 7800/16142 [1:33:48<2:37:51,  1.14s/it]



intent_acc: 0.7436, slot_acc: 0.8585, intent_loss: 1.4717, slot_loss: 0.1146, loss: 1.5863 ||:  50%|████▉     | 8000/16142 [1:41:21<1:58:27,  1.15it/s] 



intent_acc: 0.7445, slot_acc: 0.8587, intent_loss: 1.4639, slot_loss: 0.1143, loss: 1.5782 ||:  51%|█████     | 8200/16142 [1:46:15<2:20:23,  1.06s/it] 



intent_acc: 0.7457, slot_acc: 0.8588, intent_loss: 1.4568, slot_loss: 0.1135, loss: 1.5703 ||:  52%|█████▏    | 8400/16142 [1:49:00<2:05:11,  1.03it/s]



intent_acc: 0.7467, slot_acc: 0.8587, intent_loss: 1.4486, slot_loss: 0.1132, loss: 1.5618 ||:  53%|█████▎    | 8600/16142 [1:51:46<1:46:42,  1.18it/s]



intent_acc: 0.7480, slot_acc: 0.8593, intent_loss: 1.4402, slot_loss: 0.1124, loss: 1.5527 ||:  55%|█████▍    | 8800/16142 [1:54:45<2:00:27,  1.02it/s]



intent_acc: 0.7493, slot_acc: 0.8594, intent_loss: 1.4311, slot_loss: 0.1123, loss: 1.5433 ||:  56%|█████▌    | 9000/16142 [1:57:21<1:56:37,  1.02it/s]



intent_acc: 0.7508, slot_acc: 0.8596, intent_loss: 1.4239, slot_loss: 0.1121, loss: 1.5360 ||:  57%|█████▋    | 9200/16142 [1:59:52<1:49:04,  1.06it/s]



intent_acc: 0.7522, slot_acc: 0.8598, intent_loss: 1.4135, slot_loss: 0.1121, loss: 1.5256 ||:  58%|█████▊    | 9400/16142 [2:02:24<1:49:24,  1.03it/s]



intent_acc: 0.7531, slot_acc: 0.8596, intent_loss: 1.4071, slot_loss: 0.1122, loss: 1.5193 ||:  59%|█████▉    | 9600/16142 [2:05:02<1:50:56,  1.02s/it]



intent_acc: 0.7536, slot_acc: 0.8598, intent_loss: 1.4046, slot_loss: 0.1120, loss: 1.5166 ||:  61%|██████    | 9800/16142 [2:07:40<2:22:23,  1.35s/it]



intent_acc: 0.7543, slot_acc: 0.8600, intent_loss: 1.3997, slot_loss: 0.1115, loss: 1.5112 ||:  62%|██████▏   | 10000/16142 [2:10:04<1:16:25,  1.34it/s]



intent_acc: 0.7556, slot_acc: 0.8601, intent_loss: 1.3912, slot_loss: 0.1113, loss: 1.5025 ||:  63%|██████▎   | 10200/16142 [2:13:07<2:07:15,  1.28s/it]



intent_acc: 0.7564, slot_acc: 0.8602, intent_loss: 1.3843, slot_loss: 0.1111, loss: 1.4954 ||:  64%|██████▍   | 10400/16142 [2:16:13<1:32:34,  1.03it/s]



intent_acc: 0.7573, slot_acc: 0.8599, intent_loss: 1.3799, slot_loss: 0.1110, loss: 1.4909 ||:  66%|██████▌   | 10600/16142 [2:19:02<1:27:40,  1.05it/s]



intent_acc: 0.7579, slot_acc: 0.8602, intent_loss: 1.3768, slot_loss: 0.1105, loss: 1.4874 ||:  67%|██████▋   | 10800/16142 [2:22:00<1:31:06,  1.02s/it]



intent_acc: 0.7589, slot_acc: 0.8603, intent_loss: 1.3692, slot_loss: 0.1104, loss: 1.4796 ||:  68%|██████▊   | 11000/16142 [2:25:05<1:22:29,  1.04it/s]



intent_acc: 0.7596, slot_acc: 0.8605, intent_loss: 1.3661, slot_loss: 0.1103, loss: 1.4765 ||:  69%|██████▉   | 11200/16142 [2:28:06<1:27:30,  1.06s/it]



intent_acc: 0.7605, slot_acc: 0.8607, intent_loss: 1.3598, slot_loss: 0.1100, loss: 1.4698 ||:  71%|███████   | 11400/16142 [2:31:07<1:26:27,  1.09s/it]



intent_acc: 0.7611, slot_acc: 0.8607, intent_loss: 1.3549, slot_loss: 0.1099, loss: 1.4648 ||:  72%|███████▏  | 11600/16142 [2:33:45<55:22,  1.37it/s]  



intent_acc: 0.7617, slot_acc: 0.8607, intent_loss: 1.3538, slot_loss: 0.1098, loss: 1.4636 ||:  73%|███████▎  | 11800/16142 [2:36:13<1:15:15,  1.04s/it]



intent_acc: 0.7624, slot_acc: 0.8608, intent_loss: 1.3480, slot_loss: 0.1096, loss: 1.4576 ||:  74%|███████▍  | 12000/16142 [2:39:25<57:09,  1.21it/s]   



intent_acc: 0.7631, slot_acc: 0.8608, intent_loss: 1.3452, slot_loss: 0.1092, loss: 1.4544 ||:  76%|███████▌  | 12200/16142 [2:41:48<59:02,  1.11it/s]  



intent_acc: 0.7636, slot_acc: 0.8608, intent_loss: 1.3423, slot_loss: 0.1093, loss: 1.4516 ||:  77%|███████▋  | 12400/16142 [2:44:13<51:57,  1.20it/s]  



intent_acc: 0.7641, slot_acc: 0.8608, intent_loss: 1.3390, slot_loss: 0.1093, loss: 1.4482 ||:  78%|███████▊  | 12600/16142 [2:46:39<52:57,  1.11it/s]  



intent_acc: 0.7651, slot_acc: 0.8608, intent_loss: 1.3325, slot_loss: 0.1091, loss: 1.4416 ||:  79%|███████▉  | 12800/16142 [2:49:04<47:58,  1.16it/s]  



intent_acc: 0.7657, slot_acc: 0.8609, intent_loss: 1.3268, slot_loss: 0.1090, loss: 1.4358 ||:  81%|████████  | 13000/16142 [2:51:24<43:49,  1.19it/s]



intent_acc: 0.7663, slot_acc: 0.8610, intent_loss: 1.3222, slot_loss: 0.1089, loss: 1.4311 ||:  82%|████████▏ | 13200/16142 [2:53:49<46:59,  1.04it/s]



intent_acc: 0.7666, slot_acc: 0.8610, intent_loss: 1.3206, slot_loss: 0.1089, loss: 1.4295 ||:  83%|████████▎ | 13400/16142 [2:56:11<35:46,  1.28it/s]



intent_acc: 0.7672, slot_acc: 0.8608, intent_loss: 1.3165, slot_loss: 0.1091, loss: 1.4255 ||:  84%|████████▍ | 13600/16142 [2:58:37<41:02,  1.03it/s]



intent_acc: 0.7675, slot_acc: 0.8608, intent_loss: 1.3132, slot_loss: 0.1093, loss: 1.4225 ||:  85%|████████▌ | 13800/16142 [3:01:06<38:51,  1.00it/s]



intent_acc: 0.7681, slot_acc: 0.8608, intent_loss: 1.3078, slot_loss: 0.1094, loss: 1.4172 ||:  87%|████████▋ | 14000/16142 [3:03:32<26:20,  1.36it/s]



intent_acc: 0.7686, slot_acc: 0.8605, intent_loss: 1.3051, slot_loss: 0.1097, loss: 1.4147 ||:  88%|████████▊ | 14200/16142 [3:06:03<33:08,  1.02s/it]



intent_acc: 0.7690, slot_acc: 0.8606, intent_loss: 1.3034, slot_loss: 0.1093, loss: 1.4127 ||:  89%|████████▉ | 14400/16142 [3:08:24<27:56,  1.04it/s]



intent_acc: 0.7694, slot_acc: 0.8607, intent_loss: 1.3013, slot_loss: 0.1091, loss: 1.4104 ||:  90%|█████████ | 14600/16142 [3:10:53<20:57,  1.23it/s]



intent_acc: 0.7701, slot_acc: 0.8610, intent_loss: 1.2958, slot_loss: 0.1087, loss: 1.4045 ||:  92%|█████████▏| 14800/16142 [3:13:16<23:32,  1.05s/it]



intent_acc: 0.7704, slot_acc: 0.8611, intent_loss: 1.2939, slot_loss: 0.1085, loss: 1.4024 ||:  93%|█████████▎| 15000/16142 [3:15:42<19:47,  1.04s/it]



intent_acc: 0.7711, slot_acc: 0.8611, intent_loss: 1.2896, slot_loss: 0.1086, loss: 1.3981 ||:  94%|█████████▍| 15200/16142 [3:18:11<14:33,  1.08it/s]



intent_acc: 0.7715, slot_acc: 0.8610, intent_loss: 1.2866, slot_loss: 0.1087, loss: 1.3953 ||:  95%|█████████▌| 15400/16142 [3:20:40<12:55,  1.04s/it]



intent_acc: 0.7721, slot_acc: 0.8608, intent_loss: 1.2823, slot_loss: 0.1088, loss: 1.3911 ||:  97%|█████████▋| 15600/16142 [3:23:37<10:18,  1.14s/it]



intent_acc: 0.7726, slot_acc: 0.8609, intent_loss: 1.2790, slot_loss: 0.1086, loss: 1.3876 ||:  98%|█████████▊| 15800/16142 [3:26:18<05:44,  1.01s/it]



intent_acc: 0.7729, slot_acc: 0.8609, intent_loss: 1.2771, slot_loss: 0.1086, loss: 1.3856 ||:  99%|█████████▉| 16000/16142 [3:28:43<02:18,  1.03it/s]



intent_acc: 0.7732, slot_acc: 0.8608, intent_loss: 1.2756, slot_loss: 0.1088, loss: 1.3845 ||: 100%|██████████| 16142/16142 [3:30:26<00:00,  1.34it/s]
intent_acc: 0.8809, slot_acc: 0.8820, intent_loss: 0.4040, slot_loss: 0.0607, loss: 0.4647 ||: 100%|██████████| 2482/2482 [28:33<00:00,  1.36it/s]


{'best_epoch': 0,
 'peak_cpu_memory_MB': 8097.164,
 'peak_gpu_0_memory_MB': 27349,
 'peak_gpu_1_memory_MB': 7416,
 'peak_gpu_2_memory_MB': 5558,
 'peak_gpu_3_memory_MB': 19977,
 'training_duration': '3:59:01.512742',
 'training_start_epoch': 0,
 'training_epochs': 0,
 'epoch': 0,
 'training_intent_acc': 0.773165320286722,
 'training_slot_acc': 0.8608487882580498,
 'training_intent_loss': 1.2756425142288208,
 'training_slot_loss': 0.1088167279958725,
 'training_loss': 1.3844616618133105,
 'training_cpu_memory_MB': 8097.164,
 'training_gpu_0_memory_MB': 27349,
 'training_gpu_1_memory_MB': 7416,
 'training_gpu_2_memory_MB': 5558,
 'training_gpu_3_memory_MB': 19977,
 'validation_intent_acc': 0.8809295547800744,
 'validation_slot_acc': 0.8819649499558998,
 'validation_intent_loss': 0.4039567708969116,
 'validation_slot_loss': 0.06074870750308037,
 'validation_loss': 0.46470570491879043,
 'best_validation_intent_acc': 0.8809295547800744,
 'best_validation_slot_acc': 0.8819649499558998,
 'bes

In [9]:
# Persist the trained weights and the vocabulary, tagged with the experiment name.
# NOTE(review): /tmp is often cleared on reboot; the Trainer above was also given
# serialization_dir for its own checkpoints.
torch.save(model.state_dict(), f"/tmp/model_{EXP}.th")
vocab.save_to_files(f"/tmp/vocabulary_{EXP}")

In [10]:
def predict(file_path, output_path=None, slot_threshold=0.5):
    """Predict active intents / requested slots and write annotated dialogue files.

    For every ``dialogues*.json`` file under ``file_path`` a copy is written to
    ``output_path`` in which each user frame's ``state`` holds the predicted
    ``active_intent`` and ``requested_slots`` plus the file's own ``slot_values``.

    Parameters
    ----------
    file_path : str
        Directory containing ``dialogues*.json``. ``'data/dev'`` (any equivalent
        spelling, e.g. with a trailing slash) reuses the in-memory
        ``validation_dataset`` instead of re-reading the files.
    output_path : str, optional
        Directory for the annotated files; created (with parents) if missing.
        Defaults to ``data/predict/{EXP}`` using ``EXP`` as it is at call time.
    slot_threshold : float
        A slot is reported as requested when its sigmoid is strictly above this.

    Returns
    -------
    list
        The raw model output dict of every batch, in iteration order, computed
        without gradient tracking. The stored logits are not modified.
    """
    if output_path is None:
        # Resolved here (not as a default argument) so a later change to EXP is honoured.
        output_path = f'data/predict/{EXP}'
    os.makedirs(output_path, exist_ok=True)  # unlike os.mkdir, also creates data/predict/
    if os.path.normpath(file_path) == os.path.normpath('data/dev'):
        eval_dataset = validation_dataset
    else:
        eval_dataset = reader.read(file_path)

    was_training = model.training
    model.eval()
    try:
        result, all_outputs = _predict_states(eval_dataset, slot_threshold)
    finally:
        model.train(was_training)  # leave the model in the mode we found it in
    _write_predictions(file_path, output_path, result)
    return all_outputs


def _predict_states(eval_dataset, slot_threshold):
    """Run the model over ``eval_dataset``; return ({(dialogue_id, turn_id, service): state}, outputs)."""
    result = {}
    all_outputs = []
    # Inference only: without no_grad every stored output would keep its autograd graph alive.
    with torch.no_grad():
        # shuffle=False: the allennlp iterator shuffles batches by default, which makes
        # the order of all_outputs nondeterministic.
        for inputs in tqdm(iterator(eval_dataset, num_epochs=1, shuffle=False)):
            last_dialogue_id = None
            last_service = None
            last_active_intent = None
            inputs = util.move_to_device(inputs, cuda_device)
            outputs = model(**inputs)
            all_outputs.append(outputs)
            for meta, intent_logits, slot_sigmoids in zip(outputs['meta'],
                                                          outputs['intent_logits'],
                                                          outputs['slot_sigmoids']):
                dialogue_id = meta['dialogue_id']
                turn_id = meta['turn_id']
                service = meta['service']
                if dialogue_id != last_dialogue_id:
                    # A batch may contain several dialogues: never carry the previous
                    # service / intent over from a different dialogue.
                    last_service = None
                    last_active_intent = None
                    last_dialogue_id = dialogue_id
                # Copy before masking so the logits returned in all_outputs stay intact.
                logits = intent_logits.clone()
                if last_service != service:
                    # Service switched: mask indices 0 and 1 to force a concrete intent
                    # (index semantics come from reader.index_to_intent — presumably
                    # UNCHANGED/NONE-style labels; TODO confirm).
                    logits[0] = float('-inf')
                    logits[1] = float('-inf')
                intent_num = len(reader.index_to_intent[service])
                intent_index = logits[:intent_num].argmax().item()
                active_intent = reader.index_to_intent[service][intent_index]
                if active_intent == 'UNCHANGED':
                    active_intent = last_active_intent
                else:
                    last_active_intent = active_intent
                slot_num = len(reader.index_to_slot[service])
                requested_slots = [reader.index_to_slot[service][i]
                                   for i, sigmoid in enumerate(slot_sigmoids[:slot_num])
                                   if sigmoid > slot_threshold]
                result[(dialogue_id, turn_id, service)] = dict(active_intent=active_intent,
                                                               requested_slots=requested_slots)
                last_service = service
    return result, all_outputs


def _write_predictions(file_path, output_path, result):
    """Copy each dialogues*.json under ``file_path`` to ``output_path`` with predicted user states filled in."""
    for file in sorted(glob.glob(os.path.join(file_path, 'dialogues*.json'))):
        with open(file, 'r', encoding='utf-8') as f:
            ds = json.load(f)
        for dialog in ds:
            dialogue_id = dialog['dialogue_id']
            # turn_id counts the SYSTEM turns seen so far; it must match the key the
            # model meta uses (presumably set the same way by the reader — TODO confirm).
            turn_id = 0
            last_service = None
            for turn in dialog['turns']:
                if turn['speaker'] == 'SYSTEM':
                    turn_id += 1
                    continue
                if len(turn['frames']) > 1:
                    # Move the frame of the previously active service to the front.
                    sorted_frames = []
                    for frame in turn['frames']:
                        if frame['service'] == last_service:
                            sorted_frames = [frame] + sorted_frames
                        else:
                            sorted_frames.append(frame)
                    turn['frames'] = sorted_frames
                for frame in turn['frames']:
                    service = frame['service']
                    state = result[(dialogue_id, turn_id, service)]
                    # Predicted intent/slots are combined with the file's own slot_values.
                    state['slot_values'] = frame['state']['slot_values']
                    frame['state'] = state
                if turn['frames']:  # guard: a frameless user turn would raise IndexError
                    last_service = turn['frames'][-1]['service']
        with open(os.path.join(output_path, os.path.basename(file)), 'w', encoding='utf-8') as f:
            json.dump(ds, f, indent=2)

In [11]:
# Predict on the dev split (reuses the in-memory validation_dataset) and keep raw batch outputs.
all_outputs = predict(file_path='data/dev')

2482it [30:36,  1.46it/s]
