In [1]:
import sys
sys.path.append('..')

In [2]:
import json
import random
import numpy as np
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler

from transformers import (BertTokenizer, 
                          BertModel,  
                          AdamW, 
                          get_linear_schedule_with_warmup, 
                          BertConfig)
from transformers.modeling_bert import BertOnlyMLMHead
from data_utils import (load_dataset, 
                        get_examples_from_dialogues, 
                        convert_state_dict, 
                        DSTInputExample, 
                        OpenVocabDSTFeature, 
                        DSTPreprocessor, 
                        WOSDataset)

from inference import inference
from evaluation import _evaluation

In [3]:
train_data_file = "/opt/ml/repo/taepd/input/data/train_dataset/train_dials.json"
slot_meta = json.load(open("/opt/ml/repo/taepd/input/data/train_dataset/slot_meta.json"))
ontology = json.load(open("/opt/ml/repo/taepd/input/data/train_dataset/ontology.json"))
train_data, dev_data, dev_labels = load_dataset(train_data_file)

# dialogue_level=False : SUMBT와 다르게 dialogue context level로 input하므로
train_examples = get_examples_from_dialogues(train_data,
                                             user_first=False,
                                             dialogue_level=False)
dev_examples = get_examples_from_dialogues(dev_data,
                                           user_first=False,
                                           dialogue_level=False)

100%|██████████| 6301/6301 [00:00<00:00, 8709.63it/s] 
100%|██████████| 699/699 [00:00<00:00, 15282.07it/s]


In [4]:
print(len(train_examples))
print(len(dev_examples))

46315
4930


## TRADE Preprocessor

BERT Encoder가 적용된 TRADE의 preprocessor입니다.

In [5]:
class TRADEPreprocessor(DSTPreprocessor):
    def __init__(
        self,
        slot_meta,
        src_tokenizer,
        trg_tokenizer=None,
        ontology=None,
        max_seq_length=512,
    ):
        self.slot_meta = slot_meta
        self.src_tokenizer = src_tokenizer
        self.trg_tokenizer = trg_tokenizer if trg_tokenizer else src_tokenizer
        self.ontology = ontology
        self.gating2id = {"none": 0, "dontcare": 1, "yes": 2, "no": 3, "ptr": 4}
        self.id2gating = {v: k for k, v in self.gating2id.items()}
        self.max_seq_length = max_seq_length

    def _convert_example_to_feature(self, example):
        dialogue_context = " [SEP] ".join(example.context_turns + example.current_turn)

        input_id = self.src_tokenizer.encode(dialogue_context, add_special_tokens=False)
        max_length = self.max_seq_length - 2
        if len(input_id) > max_length:
            gap = len(input_id) - max_length
            input_id = input_id[gap:]

        input_id = (
            [self.src_tokenizer.cls_token_id]
            + input_id
            + [self.src_tokenizer.sep_token_id]
        )
        segment_id = [0] * len(input_id)

        target_ids = []
        gating_id = []
        if not example.label:
            example.label = []

        state = convert_state_dict(example.label)
        for slot in self.slot_meta:
            value = state.get(slot, "none")
            target_id = self.trg_tokenizer.encode(value, add_special_tokens=False) + [
                self.trg_tokenizer.sep_token_id
            ]
            target_ids.append(target_id)
            gating_id.append(self.gating2id.get(value, self.gating2id["ptr"]))
        target_ids = self.pad_ids(target_ids, self.trg_tokenizer.pad_token_id)
        return OpenVocabDSTFeature(
            example.guid, input_id, segment_id, gating_id, target_ids
        )

    def convert_examples_to_features(self, examples):
        return list(map(self._convert_example_to_feature, examples))

    def recover_state(self, gate_list, gen_list):
        assert len(gate_list) == len(self.slot_meta)
        assert len(gen_list) == len(self.slot_meta)

        recovered = []
        for slot, gate, value in zip(self.slot_meta, gate_list, gen_list):
            if self.id2gating[gate] == "none":
                continue

            if self.id2gating[gate] in ["dontcare", "yes", "no"]:
                recovered.append("%s-%s" % (slot, self.id2gating[gate]))
                continue

            token_id_list = []
            for id_ in value:
                if id_ in self.trg_tokenizer.all_special_ids:
                    break

                token_id_list.append(id_)
            value = self.trg_tokenizer.decode(token_id_list, skip_special_tokens=True)

            if value == "none":
                continue

            recovered.append("%s-%s" % (slot, value))
        return recovered

    def collate_fn(self, batch):
        guids = [b.guid for b in batch]
        input_ids = torch.LongTensor(
            self.pad_ids([b.input_id for b in batch], self.src_tokenizer.pad_token_id)
        )
        segment_ids = torch.LongTensor(
            self.pad_ids([b.segment_id for b in batch], self.src_tokenizer.pad_token_id)
        )
        input_masks = input_ids.ne(self.src_tokenizer.pad_token_id)

        gating_ids = torch.LongTensor([b.gating_id for b in batch])
        target_ids = self.pad_id_of_matrix(
            [torch.LongTensor(b.target_ids) for b in batch],
            self.trg_tokenizer.pad_token_id,
        )
        return input_ids, segment_ids, input_masks, gating_ids, target_ids, guids

## Convert_Examples_to_Features

In [6]:
tokenizer = BertTokenizer.from_pretrained('dsksd/bert-ko-small-minimal')
processor = TRADEPreprocessor(slot_meta, tokenizer, max_seq_length=512)

train_features = processor.convert_examples_to_features(train_examples)
dev_features = processor.convert_examples_to_features(dev_examples)

Token indices sequence length is longer than the specified maximum sequence length for this model (537 > 512). Running this sequence through the model will result in indexing errors


In [7]:
print(len(train_features))
print(len(dev_features))

46315
4930


# TRADE

## Model

In [None]:
class TRADE(nn.Module):
    def __init__(self, config, slot_vocab, slot_meta, pad_idx=0):
        super(TRADE, self).__init__()
        self.config = config
        self.slot_meta = slot_meta
        if config.model_name_or_path:
            self.encoder = BertModel.from_pretrained(config.model_name_or_path)
        else:
            self.encoder = BertModel(config)
            
        self.decoder = SlotGenerator(
            config.vocab_size,
            config.hidden_size,
            config.hidden_dropout_prob,
            config.n_gate,
            None,
            pad_idx,
        )
        
        self.decoder.set_slot_idx(slot_vocab)
        
        self.mlm_head = BertOnlyMLMHead(config)
        self.tie_weight()

    def tie_weight(self):
        self.decoder.embed.weight = self.encoder.embeddings.word_embeddings.weight

    def forward(self, input_ids, token_type_ids, attention_mask=None, max_len=10, teacher=None):

        encoder_outputs, pooled_output = self.encoder(input_ids=input_ids)
        all_point_outputs, all_gate_outputs = self.decoder(
            input_ids, encoder_outputs, pooled_output.unsqueeze(0), attention_mask, max_len, teacher
        )

        return all_point_outputs, all_gate_outputs
    
    @staticmethod
    def mask_tokens(inputs, tokenizer, config, mlm_probability=0.15):
        """ Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original. """
        labels = inputs.clone()
        # We sample a few tokens in each sequence for masked-LM training (with probability args.mlm_probability defaults to 0.15 in Bert/RoBERTa)
        probability_matrix = torch.full(labels.shape, mlm_probability).to(device)
        #special_tokens_mask = [tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()]

        probability_matrix.masked_fill_(torch.eq(labels, 0), value=0.0)

        masked_indices = torch.bernoulli(probability_matrix).bool()
        labels[~masked_indices] = -100  # We only compute loss on masked tokens

        # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
        indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).to(device=device, dtype=torch.bool) & masked_indices
        inputs[indices_replaced] = tokenizer.convert_tokens_to_ids(["[MASK]"])[0]

        # 10% of the time, we replace masked input tokens with random word
        indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).to(device=device, dtype=torch.bool) & masked_indices & ~indices_replaced
        random_words = torch.randint(config.vocab_size, labels.shape, device=device, dtype=torch.long)
        inputs[indices_random] = random_words[indices_random].to(device)

        # The rest of the time (10% of the time) we keep the masked input tokens unchanged
        return inputs, labels
    
    def forward_pretrain(self, input_ids, tokenizer):
        input_ids, labels = self.mask_tokens(input_ids, tokenizer, self.config)
        encoder_outputs, _ = self.encoder(input_ids=input_ids)
        mlm_logits = self.mlm_head(encoder_outputs)
        
        return mlm_logits, labels
    
class SlotGenerator(nn.Module):
    def __init__(
        self, vocab_size, hidden_size, dropout, n_gate, proj_dim=None, pad_idx=0
    ):
        super(SlotGenerator, self).__init__()
        self.pad_idx = pad_idx
        self.vocab_size = vocab_size
        self.embed = nn.Embedding(
            vocab_size, hidden_size, padding_idx=pad_idx
        )  # shared with encoder

        if proj_dim:
            self.proj_layer = nn.Linear(hidden_size, proj_dim, bias=False)
        else:
            self.proj_layer = None
        self.hidden_size = proj_dim if proj_dim else hidden_size

        self.gru = nn.GRU(
            self.hidden_size, self.hidden_size, 1, dropout=dropout, batch_first=True
        )
        self.n_gate = n_gate
        self.dropout = nn.Dropout(dropout)
        self.w_gen = nn.Linear(self.hidden_size * 3, 1)
        self.sigmoid = nn.Sigmoid()
        self.w_gate = nn.Linear(self.hidden_size, n_gate)

    def set_slot_idx(self, slot_vocab_idx):
        whole = []
        max_length = max(map(len, slot_vocab_idx))
        for idx in slot_vocab_idx:
            if len(idx) < max_length:
                gap = max_length - len(idx)
                idx.extend([self.pad_idx] * gap)
            whole.append(idx)
        self.slot_embed_idx = whole  # torch.LongTensor(whole)

    def embedding(self, x):
        x = self.embed(x)
        if self.proj_layer:
            x = self.proj_layer(x)
        return x

    def forward(
        self, input_ids, encoder_output, hidden, input_masks, max_len, teacher=None
    ):
        input_masks = input_masks.ne(1)
        # J, slot_meta : key : [domain, slot] ex> LongTensor([1,2])
        # J,2
        batch_size = encoder_output.size(0)
        slot = torch.LongTensor(self.slot_embed_idx).to(input_ids.device)  ##
        slot_e = torch.sum(self.embedding(slot), 1)  # J,d
        J = slot_e.size(0)

        all_point_outputs = torch.zeros(batch_size, J, max_len, self.vocab_size).to(
            input_ids.device
        )
        
        # Parallel Decoding
        w = slot_e.repeat(batch_size, 1).unsqueeze(1)
        hidden = hidden.repeat_interleave(J, dim=1)
        encoder_output = encoder_output.repeat_interleave(J, dim=0)
        input_ids = input_ids.repeat_interleave(J, dim=0)
        input_masks = input_masks.repeat_interleave(J, dim=0)
        for k in range(max_len):
            w = self.dropout(w)
            _, hidden = self.gru(w, hidden)  # 1,B,D

            # B,T,D * B,D,1 => B,T
            attn_e = torch.bmm(encoder_output, hidden.permute(1, 2, 0))  # B,T,1
            attn_e = attn_e.squeeze(-1).masked_fill(input_masks, -1e9)
            attn_history = F.softmax(attn_e, -1)  # B,T

            if self.proj_layer:
                hidden_proj = torch.matmul(hidden, self.proj_layer.weight)
            else:
                hidden_proj = hidden

            # B,D * D,V => B,V
            attn_v = torch.matmul(
                hidden_proj.squeeze(0), self.embed.weight.transpose(0, 1)
            )  # B,V
            attn_vocab = F.softmax(attn_v, -1)

            # B,1,T * B,T,D => B,1,D
            context = torch.bmm(attn_history.unsqueeze(1), encoder_output)  # B,1,D
            p_gen = self.sigmoid(
                self.w_gen(torch.cat([w, hidden.transpose(0, 1), context], -1))
            )  # B,1
            p_gen = p_gen.squeeze(-1)

            p_context_ptr = torch.zeros_like(attn_vocab).to(input_ids.device)
            p_context_ptr.scatter_add_(1, input_ids, attn_history)  # copy B,V
            p_final = p_gen * attn_vocab + (1 - p_gen) * p_context_ptr  # B,V
            _, w_idx = p_final.max(-1)

            if teacher is not None:
                w = self.embedding(teacher[:, :, k]).transpose(0, 1).reshape(batch_size * J, 1, -1)
            else:
                w = self.embedding(w_idx).unsqueeze(1)  # B,1,D
            if k == 0:
                gated_logit = self.w_gate(context.squeeze(1))  # B,3
                all_gate_outputs = gated_logit.view(batch_size, J, self.n_gate)
            all_point_outputs[:, :, k, :] = p_final.view(batch_size, J, self.vocab_size)

        return all_point_outputs, all_gate_outputs

## 모델 및 데이터 로더 정의

In [None]:
slot_vocab = []
for slot in slot_meta:
    slot_vocab.append(
        tokenizer.encode(slot.replace('-', ' '),
                         add_special_tokens=False)
    )
    
config = BertConfig.from_pretrained('dsksd/bert-ko-small-minimal')
config.model_name_or_path = 'dsksd/bert-ko-small-minimal'
config.n_gate = len(processor.gating2id)
config.proj_dim = None
model = TRADE(config, slot_vocab, slot_meta)



In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

train_data = WOSDataset(train_features)
train_sampler = RandomSampler(train_data)
train_loader = DataLoader(train_data, batch_size=16, sampler=train_sampler, collate_fn=processor.collate_fn)

dev_data = WOSDataset(dev_features)
dev_sampler = SequentialSampler(dev_data)
dev_loader = DataLoader(dev_data, batch_size=8, sampler=dev_sampler, collate_fn=processor.collate_fn)

## Optimizer & Scheduler 선언

In [None]:
n_epochs = 10
no_decay = ["bias", "LayerNorm.weight"]
optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": 0.01,
        },
        {
            "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
            "weight_decay": 0.0,
        },
    ]

t_total = len(train_loader) * n_epochs
optimizer = AdamW(optimizer_grouped_parameters, lr=3e-5, eps=1e-8)
scheduler = get_linear_schedule_with_warmup(
    optimizer, num_warmup_steps=0.1, num_training_steps=t_total
)
teacher_forcing = 0.5
model.to(device)

def masked_cross_entropy_for_value(logits, target, pad_idx=0):
    mask = target.ne(pad_idx)
    logits_flat = logits.view(-1, logits.size(-1))
    log_probs_flat = torch.log(logits_flat)
    target_flat = target.view(-1, 1)
    losses_flat = -torch.gather(log_probs_flat, dim=1, index=target_flat)
    losses = losses_flat.view(*target.size())
    losses = losses * mask.float()
    loss = losses.sum() / (mask.sum().float())
    return loss

loss_fnc_1 = masked_cross_entropy_for_value  # generation
loss_fnc_2 = nn.CrossEntropyLoss()  # gating
loss_fnc_pretrain = nn.CrossEntropyLoss()  # MLM pretrain

## Pretraining

In [None]:
MLM_PRE = True

n_pretrain_epochs = 3

def mlm_pretrain(loader, n_epochs):
    model.train()
    for step, batch in enumerate(loader):
        input_ids, segment_ids, input_masks, gating_ids, target_ids, guids = [b.to(device) if not isinstance(b, list) else b for b in batch]

        logits, labels = model.forward_pretrain(input_ids, tokenizer)
        loss = loss_fnc_pretrain(logits.view(-1, config.vocab_size), labels.view(-1))

        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        optimizer.zero_grad()

        if step % 100 == 0:
            print('[%d/%d] [%d/%d] %f' % (epoch, n_epochs, step, len(loader), loss.item()))

if MLM_PRE:
    for epoch in range(n_pretrain_epochs):
        mlm_pretrain(train_loader, n_pretrain_epochs)

[0/3] [0/2894] 10.695285
[0/3] [100/2894] 10.653587
[0/3] [200/2894] 10.636821
[0/3] [300/2894] 10.710348
[0/3] [400/2894] 10.655028
[0/3] [500/2894] 10.701962
[0/3] [600/2894] 10.700084
[0/3] [700/2894] 10.652074
[0/3] [800/2894] 10.669184
[0/3] [900/2894] 10.635167
[0/3] [1000/2894] 10.661515
[0/3] [1100/2894] 10.724567
[0/3] [1200/2894] 10.658038
[0/3] [1300/2894] 10.662727
[0/3] [1400/2894] 10.691187
[0/3] [1500/2894] 10.628413
[0/3] [1600/2894] 10.704060
[0/3] [1700/2894] 10.612057
[0/3] [1800/2894] 10.658079
[0/3] [1900/2894] 10.662621
[0/3] [2000/2894] 10.598007
[0/3] [2100/2894] 10.616374
[0/3] [2200/2894] 10.652863
[0/3] [2300/2894] 10.659072
[0/3] [2400/2894] 10.626670
[0/3] [2500/2894] 10.615333
[0/3] [2600/2894] 10.674631
[0/3] [2700/2894] 10.666609
[0/3] [2800/2894] 10.630229
[1/3] [0/2894] 10.631713
[1/3] [100/2894] 10.700149
[1/3] [200/2894] 10.703929
[1/3] [300/2894] 10.622448
[1/3] [400/2894] 10.626153
[1/3] [500/2894] 10.639113
[1/3] [600/2894] 10.647092
[1/3] [700/28

## 모델 학습

In [None]:
MLM_DURING = True

for epoch in range(n_epochs):
    batch_loss = []
    model.train()
    for step, batch in enumerate(train_loader):
        input_ids, segment_ids, input_masks, gating_ids, target_ids, guids = [b.to(device) if not isinstance(b, list) else b for b in batch]
        if teacher_forcing > 0.0 and random.random() < teacher_forcing:
            tf = target_ids
        else:
            tf = None

        all_point_outputs, all_gate_outputs = model(input_ids, segment_ids, input_masks, target_ids.size(-1))  # gt - length (generation)
        loss_1 = loss_fnc_1(all_point_outputs.contiguous(), target_ids.contiguous().view(-1))
        loss_2 = loss_fnc_2(all_gate_outputs.contiguous().view(-1, 5), gating_ids.contiguous().view(-1))
        loss = loss_1 + loss_2
        batch_loss.append(loss.item())

        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()
        if step % 100 == 0:
            print('[%d/%d] [%d/%d] %f' % (epoch, n_epochs, step, len(train_loader), loss.item()))
            
    if MLM_DURING:
        mlm_pretrain(train_loader, n_epochs)
                
    predictions = inference(model, dev_loader, processor, device)
    eval_result = _evaluation(predictions, dev_labels, slot_meta)
    for k, v in eval_result.items():
        print(f"{k}: {v}")

[0/10] [0/2894] 10.479241
[0/10] [100/2894] 1.787377
[0/10] [200/2894] 1.833614
[0/10] [300/2894] 0.978634
[0/10] [400/2894] 0.868383
[0/10] [500/2894] 0.769875
[0/10] [600/2894] 0.694434
[0/10] [700/2894] 0.709479
[0/10] [800/2894] 0.786163
[0/10] [900/2894] 0.388131
[0/10] [1000/2894] 0.507657
[0/10] [1100/2894] 0.517074
[0/10] [1200/2894] 0.387976
[0/10] [1300/2894] 0.369454
[0/10] [1400/2894] 0.304115
[0/10] [1500/2894] 0.390003
[0/10] [1600/2894] 0.355167
[0/10] [1700/2894] 0.350109
[0/10] [1800/2894] 0.284057
[0/10] [1900/2894] 0.342065
[0/10] [2000/2894] 0.346315
[0/10] [2100/2894] 0.297524
[0/10] [2200/2894] 0.196061
[0/10] [2300/2894] 0.283494
[0/10] [2400/2894] 0.308459
[0/10] [2500/2894] 0.191082
[0/10] [2600/2894] 0.247932
[0/10] [2700/2894] 0.379054
[0/10] [2800/2894] 0.303518
[0/10] [0/2894] 10.290485
[0/10] [100/2894] 3.807715
[0/10] [200/2894] 2.556649
[0/10] [300/2894] 2.512274
[0/10] [400/2894] 1.989473
[0/10] [500/2894] 1.912938
[0/10] [600/2894] 1.531750
[0/10] [700

100%|██████████| 619/619 [01:09<00:00,  8.94it/s]


{'joint_goal_accuracy': 0.001211876388608362, 'turn_slot_accuracy': 0.6366839471262775, 'turn_slot_f1': 0.26452292652714}
joint_goal_accuracy: 0.001211876388608362
turn_slot_accuracy: 0.6366839471262775
turn_slot_f1: 0.26452292652714
[1/10] [0/2894] 2.698599
[1/10] [100/2894] 0.213983
[1/10] [200/2894] 0.266393
[1/10] [300/2894] 0.285038
[1/10] [400/2894] 0.236070
[1/10] [500/2894] 0.199817
[1/10] [600/2894] 0.136715
[1/10] [700/2894] 0.231997
[1/10] [800/2894] 0.192494
[1/10] [900/2894] 0.221587
[1/10] [1000/2894] 0.227034
[1/10] [1100/2894] 0.148866
[1/10] [1200/2894] 0.134970
[1/10] [1300/2894] 0.207710
[1/10] [1400/2894] 0.199708
[1/10] [1500/2894] 0.238266
[1/10] [1600/2894] 0.125722
[1/10] [1700/2894] 0.207651
[1/10] [1800/2894] 0.111876
[1/10] [1900/2894] 0.130565
[1/10] [2000/2894] 0.204489
[1/10] [2100/2894] 0.094385
[1/10] [2200/2894] 0.191717
[1/10] [2300/2894] 0.126738
[1/10] [2400/2894] 0.113447
[1/10] [2500/2894] 0.134092
[1/10] [2600/2894] 0.172208
[1/10] [2700/2894] 0.1

100%|██████████| 619/619 [01:08<00:00,  9.01it/s]


{'joint_goal_accuracy': 0.01979398101393658, 'turn_slot_accuracy': 0.8120469489889746, 'turn_slot_f1': 0.4722378969176701}
joint_goal_accuracy: 0.01979398101393658
turn_slot_accuracy: 0.8120469489889746
turn_slot_f1: 0.4722378969176701
[2/10] [0/2894] 1.636626
[2/10] [100/2894] 0.105757
[2/10] [200/2894] 0.065617
[2/10] [300/2894] 0.117363
[2/10] [400/2894] 0.163464
[2/10] [500/2894] 0.058478
[2/10] [600/2894] 0.081168
[2/10] [700/2894] 0.073699
[2/10] [800/2894] 0.100143
[2/10] [900/2894] 0.097578
[2/10] [1000/2894] 0.055120
[2/10] [1100/2894] 0.124482
[2/10] [1200/2894] 0.107333
[2/10] [1300/2894] 0.128425
[2/10] [1400/2894] 0.084379
[2/10] [1500/2894] 0.067071
[2/10] [1600/2894] 0.058993
[2/10] [1700/2894] 0.046455
[2/10] [1800/2894] 0.092474
[2/10] [1900/2894] 0.083181
[2/10] [2000/2894] 0.055736
[2/10] [2100/2894] 0.069237
[2/10] [2200/2894] 0.135437
[2/10] [2300/2894] 0.118794
[2/10] [2400/2894] 0.098667
[2/10] [2500/2894] 0.086952
[2/10] [2600/2894] 0.107385
[2/10] [2700/2894] 0

100%|██████████| 619/619 [01:08<00:00,  9.00it/s]


{'joint_goal_accuracy': 0.06018986063421531, 'turn_slot_accuracy': 0.9033685675172273, 'turn_slot_f1': 0.6589605144883941}
joint_goal_accuracy: 0.06018986063421531
turn_slot_accuracy: 0.9033685675172273
turn_slot_f1: 0.6589605144883941
[3/10] [0/2894] 0.896906
[3/10] [100/2894] 0.072288
[3/10] [200/2894] 0.095614
[3/10] [300/2894] 0.072446
[3/10] [400/2894] 0.104356
[3/10] [500/2894] 0.091745
[3/10] [600/2894] 0.044922
[3/10] [700/2894] 0.068284
[3/10] [800/2894] 0.064532
[3/10] [900/2894] 0.052464
[3/10] [1000/2894] 0.062923
[3/10] [1100/2894] 0.158064
[3/10] [1200/2894] 0.121072
[3/10] [1300/2894] 0.091380
[3/10] [1400/2894] 0.051162
[3/10] [1500/2894] 0.060938
[3/10] [1600/2894] 0.090415
[3/10] [1700/2894] 0.069550
[3/10] [1800/2894] 0.056941
[3/10] [1900/2894] 0.111795
[3/10] [2000/2894] 0.053701
[3/10] [2100/2894] 0.060008
[3/10] [2200/2894] 0.081920
[3/10] [2300/2894] 0.094612
[3/10] [2400/2894] 0.037036
[3/10] [2500/2894] 0.060814
[3/10] [2600/2894] 0.081578
[3/10] [2700/2894] 0

100%|██████████| 619/619 [01:08<00:00,  9.06it/s]


{'joint_goal_accuracy': 0.12643910321147242, 'turn_slot_accuracy': 0.9348324693103632, 'turn_slot_f1': 0.743428973393871}
joint_goal_accuracy: 0.12643910321147242
turn_slot_accuracy: 0.9348324693103632
turn_slot_f1: 0.743428973393871
[4/10] [0/2894] 0.375657
[4/10] [100/2894] 0.048460
[4/10] [200/2894] 0.100936
[4/10] [300/2894] 0.112913
[4/10] [400/2894] 0.058470
[4/10] [500/2894] 0.076045
[4/10] [600/2894] 0.056451
[4/10] [700/2894] 0.024079
[4/10] [800/2894] 0.032442
[4/10] [900/2894] 0.054822
[4/10] [1000/2894] 0.062688
[4/10] [1100/2894] 0.098679
[4/10] [1200/2894] 0.068616
[4/10] [1300/2894] 0.016514
[4/10] [1400/2894] 0.024989
[4/10] [1500/2894] 0.076372
[4/10] [1600/2894] 0.062778
[4/10] [1700/2894] 0.060029
[4/10] [1800/2894] 0.035235
[4/10] [1900/2894] 0.072962
[4/10] [2000/2894] 0.066529
[4/10] [2100/2894] 0.076374
[4/10] [2200/2894] 0.062498
[4/10] [2300/2894] 0.073810
[4/10] [2400/2894] 0.038563
[4/10] [2500/2894] 0.048721
[4/10] [2600/2894] 0.078993
[4/10] [2700/2894] 0.0

100%|██████████| 619/619 [01:08<00:00,  9.07it/s]


{'joint_goal_accuracy': 0.30478691173500305, 'turn_slot_accuracy': 0.9656500370295695, 'turn_slot_f1': 0.8535739796875405}
joint_goal_accuracy: 0.30478691173500305
turn_slot_accuracy: 0.9656500370295695
turn_slot_f1: 0.8535739796875405
[5/10] [0/2894] 0.289729
[5/10] [100/2894] 0.065920
[5/10] [200/2894] 0.080398
[5/10] [300/2894] 0.035892
[5/10] [400/2894] 0.033895
[5/10] [500/2894] 0.046071
[5/10] [600/2894] 0.021943
[5/10] [700/2894] 0.055971
[5/10] [800/2894] 0.026725
[5/10] [900/2894] 0.025048
[5/10] [1000/2894] 0.045479
[5/10] [1100/2894] 0.047929
[5/10] [1200/2894] 0.022669
[5/10] [1300/2894] 0.082602
[5/10] [1400/2894] 0.042442
[5/10] [1500/2894] 0.076513
[5/10] [1600/2894] 0.047413
[5/10] [1700/2894] 0.036124
[5/10] [1800/2894] 0.041714
[5/10] [1900/2894] 0.024631
[5/10] [2000/2894] 0.039686
[5/10] [2100/2894] 0.039790
[5/10] [2200/2894] 0.037568
[5/10] [2300/2894] 0.035502
[5/10] [2400/2894] 0.072325
[5/10] [2500/2894] 0.025523
[5/10] [2600/2894] 0.034588
[5/10] [2700/2894] 0

100%|██████████| 619/619 [01:08<00:00,  9.05it/s]


{'joint_goal_accuracy': 0.3983033730559483, 'turn_slot_accuracy': 0.9731816243632128, 'turn_slot_f1': 0.8865022788453295}
joint_goal_accuracy: 0.3983033730559483
turn_slot_accuracy: 0.9731816243632128
turn_slot_f1: 0.8865022788453295
[6/10] [0/2894] 0.216665
[6/10] [100/2894] 0.043218
[6/10] [200/2894] 0.057242
[6/10] [300/2894] 0.056206
[6/10] [400/2894] 0.044075
[6/10] [500/2894] 0.039929
[6/10] [600/2894] 0.055174
[6/10] [700/2894] 0.039498
[6/10] [800/2894] 0.030613
[6/10] [900/2894] 0.032366
[6/10] [1000/2894] 0.119888
[6/10] [1100/2894] 0.024544
[6/10] [1200/2894] 0.037841
[6/10] [1300/2894] 0.032522
[6/10] [1400/2894] 0.016029
[6/10] [1500/2894] 0.019364
[6/10] [1600/2894] 0.034022
[6/10] [1700/2894] 0.047358
[6/10] [1800/2894] 0.039442
[6/10] [1900/2894] 0.037178
[6/10] [2000/2894] 0.030836
[6/10] [2100/2894] 0.055623
[6/10] [2200/2894] 0.030512
[6/10] [2300/2894] 0.019074
[6/10] [2400/2894] 0.032463
[6/10] [2500/2894] 0.025576
[6/10] [2600/2894] 0.056333
[6/10] [2700/2894] 0.0

100%|██████████| 619/619 [01:08<00:00,  9.07it/s]


{'joint_goal_accuracy': 0.5406988487174308, 'turn_slot_accuracy': 0.9837967638412062, 'turn_slot_f1': 0.9209380437252567}
joint_goal_accuracy: 0.5406988487174308
turn_slot_accuracy: 0.9837967638412062
turn_slot_f1: 0.9209380437252567
[7/10] [0/2894] 0.058712
[7/10] [100/2894] 0.034744
[7/10] [200/2894] 0.064716
[7/10] [300/2894] 0.046700
[7/10] [400/2894] 0.023303
[7/10] [500/2894] 0.025220
[7/10] [600/2894] 0.029212
[7/10] [700/2894] 0.017783
[7/10] [800/2894] 0.052295
[7/10] [900/2894] 0.016869
[7/10] [1000/2894] 0.010020
[7/10] [1100/2894] 0.082434
[7/10] [1200/2894] 0.029836
[7/10] [1300/2894] 0.024994
[7/10] [1400/2894] 0.027071
[7/10] [1500/2894] 0.033265
[7/10] [1600/2894] 0.027167
[7/10] [1700/2894] 0.018439
[7/10] [1800/2894] 0.050906
[7/10] [1900/2894] 0.085401
[7/10] [2000/2894] 0.061284
[7/10] [2100/2894] 0.055016
[7/10] [2200/2894] 0.043745
[7/10] [2300/2894] 0.046031
[7/10] [2400/2894] 0.058996
[7/10] [2500/2894] 0.014377
[7/10] [2600/2894] 0.037652
[7/10] [2700/2894] 0.0

100%|██████████| 619/619 [01:08<00:00,  9.07it/s]


{'joint_goal_accuracy': 0.5901838012522723, 'turn_slot_accuracy': 0.9861217711349072, 'turn_slot_f1': 0.932642909758438}
joint_goal_accuracy: 0.5901838012522723
turn_slot_accuracy: 0.9861217711349072
turn_slot_f1: 0.932642909758438
[8/10] [0/2894] 0.039473
[8/10] [100/2894] 0.116277
[8/10] [200/2894] 0.030894
[8/10] [300/2894] 0.017684
[8/10] [400/2894] 0.020432
[8/10] [500/2894] 0.063957
[8/10] [600/2894] 0.021164
[8/10] [700/2894] 0.036350
[8/10] [800/2894] 0.054530
[8/10] [900/2894] 0.041291
[8/10] [1000/2894] 0.025891
[8/10] [1100/2894] 0.048146
[8/10] [1200/2894] 0.026129
[8/10] [1300/2894] 0.028268
[8/10] [1400/2894] 0.025651
[8/10] [1500/2894] 0.031688
[8/10] [1600/2894] 0.016329
[8/10] [1700/2894] 0.023373
[8/10] [1800/2894] 0.023661
[8/10] [1900/2894] 0.029619
[8/10] [2000/2894] 0.034652
[8/10] [2100/2894] 0.043870
[8/10] [2200/2894] 0.018598
[8/10] [2300/2894] 0.036970
[8/10] [2400/2894] 0.022005
[8/10] [2500/2894] 0.034227
[8/10] [2600/2894] 0.021552
[8/10] [2700/2894] 0.012

100%|██████████| 619/619 [01:08<00:00,  9.07it/s]


{'joint_goal_accuracy': 0.6154312260149465, 'turn_slot_accuracy': 0.9874054624206189, 'turn_slot_f1': 0.9388318799464411}
joint_goal_accuracy: 0.6154312260149465
turn_slot_accuracy: 0.9874054624206189
turn_slot_f1: 0.9388318799464411
[9/10] [0/2894] 0.042903
[9/10] [100/2894] 0.025247
[9/10] [200/2894] 0.039650
[9/10] [300/2894] 0.034933
[9/10] [400/2894] 0.043661
[9/10] [500/2894] 0.012590
[9/10] [600/2894] 0.023037
[9/10] [700/2894] 0.019565
[9/10] [800/2894] 0.039427
[9/10] [900/2894] 0.014163
[9/10] [1000/2894] 0.018687
[9/10] [1100/2894] 0.026516
[9/10] [1200/2894] 0.043638
[9/10] [1300/2894] 0.034804
[9/10] [1400/2894] 0.026990
[9/10] [1500/2894] 0.029972
[9/10] [1600/2894] 0.020106
[9/10] [1700/2894] 0.019763
[9/10] [1800/2894] 0.031618
[9/10] [1900/2894] 0.020294
[9/10] [2000/2894] 0.009980
[9/10] [2100/2894] 0.073129
[9/10] [2200/2894] 0.023514
[9/10] [2300/2894] 0.013358
[9/10] [2400/2894] 0.013117
[9/10] [2500/2894] 0.026427
[9/10] [2600/2894] 0.026022
[9/10] [2700/2894] 0.0

100%|██████████| 619/619 [01:08<00:00,  9.01it/s]

{'joint_goal_accuracy': 0.6241163401333064, 'turn_slot_accuracy': 0.9877420947507864, 'turn_slot_f1': 0.940597239821952}
joint_goal_accuracy: 0.6241163401333064
turn_slot_accuracy: 0.9877420947507864
turn_slot_f1: 0.940597239821952





## Inference

In [None]:
eval_data = json.load(open(f"/opt/ml/input/data/eval_dataset/eval_dials.json", "r"))

eval_examples = get_examples_from_dialogues(
    eval_data, user_first=False, dialogue_level=False
)

# Extracting Featrues
eval_features = processor.convert_examples_to_features(eval_examples)
eval_data = WOSDataset(eval_features)
eval_sampler = SequentialSampler(eval_data)
eval_loader = DataLoader(
    eval_data,
    batch_size=8,
    sampler=eval_sampler,
    collate_fn=processor.collate_fn,
)

100%|██████████| 1000/1000 [00:00<00:00, 36946.73it/s]


In [None]:
predictions = inference(model, eval_loader, processor, device)

100%|██████████| 944/944 [01:45<00:00,  8.92it/s]


In [None]:
json.dump(predictions, open('predictions.csv', 'w'), indent=2, ensure_ascii=False)