In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
import os
os.chdir('/content/drive/MyDrive/Stage3/code')

In [3]:
!pip install transformers

Collecting transformers
[?25l  Downloading https://files.pythonhosted.org/packages/b0/9e/5b80becd952d5f7250eaf8fc64b957077b12ccfe73e9c03d37146ab29712/transformers-4.6.0-py3-none-any.whl (2.3MB)
[K     |████████████████████████████████| 2.3MB 8.6MB/s 
[?25hCollecting sacremoses
[?25l  Downloading https://files.pythonhosted.org/packages/75/ee/67241dc87f266093c533a2d4d3d69438e57d7a90abb216fa076e7d475d4a/sacremoses-0.0.45-py3-none-any.whl (895kB)
[K     |████████████████████████████████| 901kB 46.3MB/s 
[?25hCollecting tokenizers<0.11,>=0.10.1
[?25l  Downloading https://files.pythonhosted.org/packages/ae/04/5b870f26a858552025a62f1649c20d29d2672c02ff3c3fb4c688ca46467a/tokenizers-0.10.2-cp37-cp37m-manylinux2010_x86_64.whl (3.3MB)
[K     |████████████████████████████████| 3.3MB 43.6MB/s 
Collecting huggingface-hub==0.0.8
  Downloading https://files.pythonhosted.org/packages/a1/88/7b1e45720ecf59c6c6737ff332f41c955963090a18e72acbcbeac6b25e86/huggingface_hub-0.0.8-py3-none-any.whl
Instal

In [4]:
import json
import random
import numpy as np
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler

from transformers import BertModel, BertTokenizer, BertConfig, AdamW, get_linear_schedule_with_warmup
from data_utils import (
    load_dataset, 
    get_examples_from_dialogues, 
    convert_state_dict, 
    DSTInputExample, 
    OpenVocabDSTFeature, 
    DSTPreprocessor, 
    WOSDataset)
    
from inference import inference
from evaluation import _evaluation
from data_utils import set_seed

## Data loading

In [5]:
set_seed(42)

In [6]:
train_data_file = "/content/drive/MyDrive/Stage3/input/data/train_dataset/train_dials.json"
slot_meta = json.load(open("/content/drive/MyDrive/Stage3/input/data/train_dataset/slot_meta.json"))
ontology = json.load(open("/content/drive/MyDrive/Stage3/input/data/train_dataset/ontology.json"))
train_data, dev_data, dev_labels = load_dataset(train_data_file)

train_examples = get_examples_from_dialogues(train_data,
                                             user_first=False,
                                             dialogue_level=False)
dev_examples = get_examples_from_dialogues(dev_data,
                                           user_first=False,
                                           dialogue_level=False)

100%|██████████| 6301/6301 [00:00<00:00, 9419.52it/s]
100%|██████████| 699/699 [00:00<00:00, 13502.10it/s]


In [7]:
print(len(train_examples))
print(len(dev_examples))

46170
5075


## TRADE Preprocessor 

기존의 GRU 기반의 인코더를 BERT-based Encoder로 바꿀 준비를 합시다.

1. 현재 `_convert_example_to_feature`에서는 `max_seq_length`를 핸들하고 있지 않습니다. `input_id`와 `segment_id`가 `max_seq_length`를 넘어가면 좌측부터 truncate시키는 코드를 삽입하세요.

2. hybrid approach에서 얻은 교훈을 바탕으로 gate class를 3개에서 5개로 늘려봅시다.
    - `gating2id`를 수정하세요
    - 이에 따른 `recover_state`를 수정하세요.
    
3. word dropout을 구현하세요.

In [8]:
class TRADEPreprocessor(DSTPreprocessor):
    def __init__(
        self,
        slot_meta,
        src_tokenizer,
        trg_tokenizer=None,
        ontology=None,
        max_seq_length=512,
    ):
        self.slot_meta = slot_meta
        self.src_tokenizer = src_tokenizer
        self.trg_tokenizer = trg_tokenizer if trg_tokenizer else src_tokenizer
        self.ontology = ontology
        self.gating2id = {"none": 0, "dontcare": 1, "yes": 2, "no": 3, "ptr": 4}
        self.id2gating = {v: k for k, v in self.gating2id.items()}
        self.max_seq_length = max_seq_length

    def _convert_example_to_feature(self, example):
        dialogue_context = " [SEP] ".join(example.context_turns + example.current_turn)

        input_id = self.src_tokenizer.encode(dialogue_context, add_special_tokens=False)
        max_length = self.max_seq_length - 2
        if len(input_id) > max_length:
            gap = len(input_id) - max_length
            input_id = input_id[gap:]

        input_id = (
            [self.src_tokenizer.cls_token_id]
            + input_id
            + [self.src_tokenizer.sep_token_id]
        )
        segment_id = [0] * (len(example.context_turns) +2) + [1] * (len(example.current_turn) +1)


        target_ids = []
        gating_id = []
        if not example.label:
            example.label = []

        state = convert_state_dict(example.label)
        for slot in self.slot_meta:
            value = state.get(slot, "none")
            target_id = self.trg_tokenizer.encode(value, add_special_tokens=False) + [
                self.trg_tokenizer.sep_token_id
            ]
            target_ids.append(target_id)
            gating_id.append(self.gating2id.get(value, self.gating2id["ptr"]))
        target_ids = self.pad_ids(target_ids, self.trg_tokenizer.pad_token_id)
        return OpenVocabDSTFeature(
            example.guid, input_id, segment_id, gating_id, target_ids
        )

    def convert_examples_to_features(self, examples):
        return list(map(self._convert_example_to_feature, examples))

    def recover_state(self, gate_list, gen_list):
        assert len(gate_list) == len(self.slot_meta)
        assert len(gen_list) == len(self.slot_meta)

        recovered = []
        for slot, gate, value in zip(self.slot_meta, gate_list, gen_list):
            if self.id2gating[gate] == "none":
                continue

            if self.id2gating[gate] in ["dontcare", "yes", "no"]:
                recovered.append("%s-%s" % (slot, self.id2gating[gate]))
                continue

            token_id_list = []
            for id_ in value:
                if id_ in self.trg_tokenizer.all_special_ids:
                    break

                token_id_list.append(id_)
            value = self.trg_tokenizer.decode(token_id_list, skip_special_tokens=True)

            if value == "none":
                continue

            recovered.append("%s-%s" % (slot, value))
        return recovered

    def collate_fn(self, batch):
        guids = [b.guid for b in batch]
        input_ids = torch.LongTensor(
            self.pad_ids([b.input_id for b in batch], self.src_tokenizer.pad_token_id)
        )
        segment_ids = torch.LongTensor(
            self.pad_ids([b.segment_id for b in batch], self.src_tokenizer.pad_token_id)
        )
        input_masks = input_ids.ne(self.src_tokenizer.pad_token_id)

        gating_ids = torch.LongTensor([b.gating_id for b in batch])
        target_ids = self.pad_id_of_matrix(
            [torch.LongTensor(b.target_ids) for b in batch],
            self.trg_tokenizer.pad_token_id,
        )
        return input_ids, segment_ids, input_masks, gating_ids, target_ids, guids

## Convert_Examples_to_Features 

In [9]:
tokenizer = BertTokenizer.from_pretrained('dsksd/bert-ko-small-minimal')
processor = TRADEPreprocessor(slot_meta, tokenizer, max_seq_length=512)

train_features = processor.convert_examples_to_features(train_examples)
dev_features = processor.convert_examples_to_features(dev_examples)

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=263327.0, style=ProgressStyle(descripti…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=124.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=288.0, style=ProgressStyle(description_…




Token indices sequence length is longer than the specified maximum sequence length for this model (537 > 512). Running this sequence through the model will result in indexing errors


In [10]:
print(len(train_features))
print(len(dev_features))

46170
5075


# Model 

1. `GRUEncoder`를 `BertModel`로 교체하세요. 이에 따라 `tie_weight` 함수가 수정되어야 합니다.

In [11]:
class TRADE(nn.Module):
    def __init__(self, config, slot_vocab, slot_meta, pad_idx=0):
        super(TRADE, self).__init__()
        self.slot_meta = slot_meta
        if config.model_name_or_path:
            self.encoder = BertModel.from_pretrained(config.model_name_or_path)
        else:
            self.encoder = BertModel(config)
            
        self.decoder = SlotGenerator(
            config.vocab_size,
            config.hidden_size,
            config.hidden_dropout_prob,
            config.n_gate,
            None,
            pad_idx,
        )
        
        # init for only subword embedding
        self.decoder.set_slot_idx(slot_vocab)
        self.tie_weight()

    def tie_weight(self):
        self.decoder.embed.weight = self.encoder.embeddings.word_embeddings.weight

    def forward(self, input_ids, token_type_ids, attention_mask=None, max_len=10, teacher=None):

        encoder_outputs, pooled_output = self.encoder(input_ids=input_ids,return_dict = False)
        all_point_outputs, all_gate_outputs = self.decoder(
            input_ids,
            encoder_outputs,
            pooled_output.unsqueeze(0), 
            attention_mask, 
            max_len, 
            teacher
        )

        return all_point_outputs, all_gate_outputs
    
class SlotGenerator(nn.Module):
    def __init__(
        self, vocab_size, hidden_size, dropout, n_gate, proj_dim=None, pad_idx=0
    ):
        super(SlotGenerator, self).__init__()
        self.pad_idx = pad_idx
        
        # 전체 보캡에 대해
        self.vocab_size = vocab_size
        
        
        self.embed = nn.Embedding(
            vocab_size, hidden_size, padding_idx=pad_idx
        )  # shared with encoder

        if proj_dim:
            self.proj_layer = nn.Linear(hidden_size, proj_dim, bias=False)
        else:
            self.proj_layer = None
        self.hidden_size = proj_dim if proj_dim else hidden_size

        
        
        
        #self.gru = nn.GRU(
        #    self.hidden_size, self.hidden_size, 1, dropout=dropout, batch_first=True
        #)
        self.decoder_layer = nn.TransformerDecoderLayer(d_model=self.hidden_size, nhead=4)
        self.transformer = nn.TransformerDecoder(self.decoder_layer, num_layers=6)

        self.n_gate = n_gate
        self.dropout = nn.Dropout(dropout)
        self.w_gen = nn.Linear(self.hidden_size * 3, 1)
        self.sigmoid = nn.Sigmoid()
        self.w_gate = nn.Linear(self.hidden_size, n_gate)

    def set_slot_idx(self, slot_vocab_idx):
        whole = []
        
        #slot_vocab_idx 중 가장 긴거
        max_length = max(map(len, slot_vocab_idx))
        for idx in slot_vocab_idx:
            if len(idx) < max_length:
                gap = max_length - len(idx)
                idx.extend([self.pad_idx] * gap)
            whole.append(idx)
            
        #결국 whole == sumbt에서 slotlookup
        self.slot_embed_idx = whole  # torch.LongTensor(whole)

    def embedding(self, x):
        x = self.embed(x)
        if self.proj_layer:
            x = self.proj_layer(x)
        return x

    def forward(
        self, input_ids, encoder_output, hidden, input_masks, max_len, teacher=None
    ):
        input_masks = input_masks.ne(1)
        # J, slot_meta : key : [domain, slot] ex> LongTensor([1,2])
        # J,2
        batch_size = encoder_output.size(0)
        
        slot = torch.LongTensor(self.slot_embed_idx).to(input_ids.device)  ## J, N, d
        slot_e = torch.sum(self.embedding(slot), 1)  # J,d
        
        J = slot_e.size(0)

        all_point_outputs = torch.zeros(batch_size, J, max_len, self.vocab_size).to(
            input_ids.device
        )
        
        # Parallel Decoding
        w = slot_e.repeat(batch_size, 1).unsqueeze(1)
        hidden = hidden.repeat_interleave(J, dim=1)
        encoder_output = encoder_output.repeat_interleave(J, dim=0)
        input_ids = input_ids.repeat_interleave(J, dim=0)
        input_masks = input_masks.repeat_interleave(J, dim=0)
        
        
        for k in range(max_len):
            w = self.dropout(w)
            #_, hidden = self.gru(w, hidden)  # 1,B,D
            hidden = self.transformer(hidden,w)

            # B,T,D * B,D,1 => B,T
            attn_e = torch.bmm(encoder_output, hidden.permute(1, 2, 0))  # B,T,1
            attn_e = attn_e.squeeze(-1).masked_fill(input_masks, -1e9)
            attn_history = F.softmax(attn_e, -1)  # B,T

            if self.proj_layer:
                hidden_proj = torch.matmul(hidden, self.proj_layer.weight)
            else:
                hidden_proj = hidden

            # B,D * D,V => B,V
            attn_v = torch.matmul(
                hidden_proj.squeeze(0), self.embed.weight.transpose(0, 1)
            )  # B,V
            attn_vocab = F.softmax(attn_v, -1)

            # B,1,T * B,T,D => B,1,D
            context = torch.bmm(attn_history.unsqueeze(1), encoder_output)  # B,1,D
            p_gen = self.sigmoid(
                self.w_gen(torch.cat([w, hidden.transpose(0, 1), context], -1))
            )  # B,1
            p_gen = p_gen.squeeze(-1)

            p_context_ptr = torch.zeros_like(attn_vocab).to(input_ids.device)
            p_context_ptr.scatter_add_(1, input_ids, attn_history)  # copy B,V
            p_final = p_gen * attn_vocab + (1 - p_gen) * p_context_ptr  # B,V
            _, w_idx = p_final.max(-1)

            if teacher is not None:
                w = self.embedding(teacher[:, :, k]).transpose(0, 1).reshape(batch_size * J, 1, -1)
            else:
                w = self.embedding(w_idx).unsqueeze(1)  # B,1,D
            if k == 0:
                gated_logit = self.w_gate(context.squeeze(1))  # B,3
                all_gate_outputs = gated_logit.view(batch_size, J, self.n_gate)
            all_point_outputs[:, :, k, :] = p_final.view(batch_size, J, self.vocab_size)

        return all_point_outputs, all_gate_outputs

# 모델 및 데이터 로더 정의

In [12]:
slot_vocab = []
for slot in slot_meta:
    slot_vocab.append(
        tokenizer.encode(slot.replace('-', ' '),
                         add_special_tokens=False)
    )
    
config = BertConfig.from_pretrained('dsksd/bert-ko-small-minimal')
config.model_name_or_path = 'dsksd/bert-ko-small-minimal'
config.n_gate = len(processor.gating2id)
config.proj_dim = None
model = TRADE(config, slot_vocab, slot_meta)

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=434.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=284118515.0, style=ProgressStyle(descri…




Some weights of the model checkpoint at dsksd/bert-ko-small-minimal were not used when initializing BertModel: ['cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [13]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

train_data = WOSDataset(train_features)
train_sampler = RandomSampler(train_data)
train_loader = DataLoader(train_data, batch_size=8, sampler=train_sampler, collate_fn=processor.collate_fn)

dev_data = WOSDataset(dev_features)
dev_sampler = SequentialSampler(dev_data)
dev_loader = DataLoader(dev_data, batch_size=8, sampler=dev_sampler, collate_fn=processor.collate_fn)

# Optimizer & Scheduler 선언

In [14]:
n_epochs = 50
no_decay = ["bias", "LayerNorm.weight"]
optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": 0.01,
        },
        {
            "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
            "weight_decay": 0.0,
        },
    ]

t_total = len(train_loader) * n_epochs
optimizer = AdamW(optimizer_grouped_parameters, lr=3e-5, eps=1e-8)
scheduler = get_linear_schedule_with_warmup(
    optimizer, num_warmup_steps=0.1, num_training_steps=t_total
)
teacher_forcing = 0.5
model.to(device)

def masked_cross_entropy_for_value(logits, target, pad_idx=0):
    mask = target.ne(pad_idx)
    logits_flat = logits.view(-1, logits.size(-1))
    log_probs_flat = torch.log(logits_flat)
    target_flat = target.view(-1, 1)
    losses_flat = -torch.gather(log_probs_flat, dim=1, index=target_flat)
    losses = losses_flat.view(*target.size())
    losses = losses * mask.float()
    loss = losses.sum() / (mask.sum().float())
    return loss

loss_fnc_1 = masked_cross_entropy_for_value  # generation
loss_fnc_2 = nn.CrossEntropyLoss()  # gating

## Train

In [None]:
fffff
best_score = 0
cnt = 0
for epoch in range(50):
    batch_loss = []
    model.train()
    for step, batch in enumerate(train_loader):
        input_ids, segment_ids, input_masks, gating_ids, target_ids, guids = [b.to(device) if not isinstance(b, list) else b for b in batch]
        if teacher_forcing > 0.0 and random.random() < teacher_forcing:
            tf = target_ids
        else:
            tf = None

        all_point_outputs, all_gate_outputs = model(input_ids, segment_ids, input_masks, target_ids.size(-1))  # gt - length (generation)
        loss_1 = loss_fnc_1(all_point_outputs.contiguous(), target_ids.contiguous().view(-1))
        loss_2 = loss_fnc_2(all_gate_outputs.contiguous().view(-1, 5), gating_ids.contiguous().view(-1))
        loss = loss_1 + loss_2
        batch_loss.append(loss.item())

        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()
        if step % 100 == 0:
            print('[%d/%d] [%d/%d] %f' % (epoch, n_epochs, step, len(train_loader), loss.item()))
       
    predictions = inference(model, dev_loader, processor, device)
    eval_result = _evaluation(predictions, dev_labels, slot_meta)
    score = eval_result['joint_goal_accuracy']
    if score > best_score:
        cnt = 0
        best_score = score
        torch.save(model.state_dict(), "/content/drive/MyDrive/Stage3/model/trade_transformer_best.pt")
    
    #else:
    #    cnt += 1
    #    if cnt == 5:
    #        print('Early stop! Epoch:',str(epoch))
    #        break
    
    for k, v in eval_result.items():
        print(f"{k}: {v}")
torch.save(model.state_dict(), "/content/drive/MyDrive/Stage3/model/trade_transformer_50.pt")

[0/50] [0/5772] 14.580219
[0/50] [100/5772] 1.788384
[0/50] [200/5772] 1.320970
[0/50] [300/5772] 1.062520
[0/50] [400/5772] 1.172823
[0/50] [500/5772] 1.033173
[0/50] [600/5772] 0.942167
[0/50] [700/5772] 0.555394
[0/50] [800/5772] 0.859395
[0/50] [900/5772] 0.593315
[0/50] [1000/5772] 0.490033
[0/50] [1100/5772] 0.735464
[0/50] [1200/5772] 0.363320
[0/50] [1300/5772] 0.498603
[0/50] [1400/5772] 0.307671
[0/50] [1500/5772] 0.320645
[0/50] [1600/5772] 0.396256
[0/50] [1700/5772] 0.316367
[0/50] [1800/5772] 0.248857
[0/50] [1900/5772] 0.446112
[0/50] [2000/5772] 0.400615
[0/50] [2100/5772] 0.309392
[0/50] [2200/5772] 0.197964
[0/50] [2300/5772] 0.239243
[0/50] [2400/5772] 0.372462
[0/50] [2500/5772] 0.225404
[0/50] [2600/5772] 0.330557
[0/50] [2700/5772] 0.317234
[0/50] [2800/5772] 0.299415
[0/50] [2900/5772] 0.412460
[0/50] [3000/5772] 0.271218
[0/50] [3100/5772] 0.143259
[0/50] [3200/5772] 0.230600
[0/50] [3300/5772] 0.459866
[0/50] [3400/5772] 0.209698
[0/50] [3500/5772] 0.338970
[0/

100%|██████████| 635/635 [03:31<00:00,  3.00it/s]


{'joint_goal_accuracy': 0.21517241379310345, 'turn_slot_accuracy': 0.95732019704435, 'turn_slot_f1': 0.8160748145418041}
joint_goal_accuracy: 0.21517241379310345
turn_slot_accuracy: 0.95732019704435
turn_slot_f1: 0.8160748145418041
[1/50] [0/5772] 0.107206
[1/50] [100/5772] 0.113113
[1/50] [200/5772] 0.238200
[1/50] [300/5772] 0.167443
[1/50] [400/5772] 0.179293
[1/50] [500/5772] 0.159053
[1/50] [600/5772] 0.206386
[1/50] [700/5772] 0.181444
[1/50] [800/5772] 0.184207
[1/50] [900/5772] 0.178955
[1/50] [1000/5772] 0.069459
[1/50] [1100/5772] 0.305747
[1/50] [1200/5772] 0.239655
[1/50] [1300/5772] 0.113650
[1/50] [1400/5772] 0.146167
[1/50] [1500/5772] 0.060224
[1/50] [1600/5772] 0.101877
[1/50] [1700/5772] 0.279358
[1/50] [1800/5772] 0.102700
[1/50] [1900/5772] 0.045201
[1/50] [2000/5772] 0.095045
[1/50] [2100/5772] 0.067352
[1/50] [2200/5772] 0.078987
[1/50] [2300/5772] 0.270099
[1/50] [2400/5772] 0.193313
[1/50] [2500/5772] 0.052020
[1/50] [2600/5772] 0.187921
[1/50] [2700/5772] 0.139

100%|██████████| 635/635 [03:32<00:00,  2.98it/s]


{'joint_goal_accuracy': 0.4250246305418719, 'turn_slot_accuracy': 0.9760525451560025, 'turn_slot_f1': 0.8960537621226483}
joint_goal_accuracy: 0.4250246305418719
turn_slot_accuracy: 0.9760525451560025
turn_slot_f1: 0.8960537621226483
[2/50] [0/5772] 0.131891
[2/50] [100/5772] 0.119986
[2/50] [200/5772] 0.066515
[2/50] [300/5772] 0.087678
[2/50] [400/5772] 0.080316
[2/50] [500/5772] 0.121662
[2/50] [600/5772] 0.028528
[2/50] [700/5772] 0.080763
[2/50] [800/5772] 0.105926
[2/50] [900/5772] 0.036895
[2/50] [1000/5772] 0.034231
[2/50] [1100/5772] 0.078428
[2/50] [1200/5772] 0.029447
[2/50] [1300/5772] 0.027037
[2/50] [1400/5772] 0.016031
[2/50] [1500/5772] 0.064721
[2/50] [1600/5772] 0.081879
[2/50] [1700/5772] 0.081467
[2/50] [1800/5772] 0.135337
[2/50] [1900/5772] 0.034207
[2/50] [2000/5772] 0.044568
[2/50] [2100/5772] 0.068475
[2/50] [2200/5772] 0.026458
[2/50] [2300/5772] 0.076097
[2/50] [2400/5772] 0.046471
[2/50] [2500/5772] 0.072169
[2/50] [2600/5772] 0.057945
[2/50] [2700/5772] 0.0

100%|██████████| 635/635 [03:32<00:00,  2.98it/s]


{'joint_goal_accuracy': 0.467192118226601, 'turn_slot_accuracy': 0.9801072796934951, 'turn_slot_f1': 0.9100871635878389}
joint_goal_accuracy: 0.467192118226601
turn_slot_accuracy: 0.9801072796934951
turn_slot_f1: 0.9100871635878389
[3/50] [0/5772] 0.159405
[3/50] [100/5772] 0.030887
[3/50] [200/5772] 0.025884
[3/50] [300/5772] 0.019922
[3/50] [400/5772] 0.144278
[3/50] [500/5772] 0.072006
[3/50] [600/5772] 0.102847
[3/50] [700/5772] 0.049376
[3/50] [800/5772] 0.032671
[3/50] [900/5772] 0.087325
[3/50] [1000/5772] 0.021264
[3/50] [1100/5772] 0.045128
[3/50] [1200/5772] 0.022619
[3/50] [1300/5772] 0.059963
[3/50] [1400/5772] 0.059034
[3/50] [1500/5772] 0.116153
[3/50] [1600/5772] 0.045907
[3/50] [1700/5772] 0.034479
[3/50] [1800/5772] 0.028078
[3/50] [1900/5772] 0.051269
[3/50] [2000/5772] 0.039380
[3/50] [2100/5772] 0.132835
[3/50] [2200/5772] 0.017018
[3/50] [2300/5772] 0.046540
[3/50] [2400/5772] 0.029602
[3/50] [2500/5772] 0.032577
[3/50] [2600/5772] 0.038317
[3/50] [2700/5772] 0.039

100%|██████████| 635/635 [03:32<00:00,  2.98it/s]


{'joint_goal_accuracy': 0.4845320197044335, 'turn_slot_accuracy': 0.9816179529283082, 'turn_slot_f1': 0.9171370625201847}
joint_goal_accuracy: 0.4845320197044335
turn_slot_accuracy: 0.9816179529283082
turn_slot_f1: 0.9171370625201847
[4/50] [0/5772] 0.012152
[4/50] [100/5772] 0.014108
[4/50] [200/5772] 0.015004
[4/50] [300/5772] 0.018516
[4/50] [400/5772] 0.049016
[4/50] [500/5772] 0.042942
[4/50] [600/5772] 0.008636
[4/50] [700/5772] 0.024726
[4/50] [800/5772] 0.034927
[4/50] [900/5772] 0.034680
[4/50] [1000/5772] 0.076988
[4/50] [1100/5772] 0.036638
[4/50] [1200/5772] 0.035157
[4/50] [1300/5772] 0.048419
[4/50] [1400/5772] 0.057011
[4/50] [1500/5772] 0.024379
[4/50] [1600/5772] 0.013077
[4/50] [1700/5772] 0.035915
[4/50] [1800/5772] 0.023194
[4/50] [1900/5772] 0.016011
[4/50] [2000/5772] 0.025617
[4/50] [2100/5772] 0.065243
[4/50] [2200/5772] 0.045565
[4/50] [2300/5772] 0.012438


KeyboardInterrupt: ignored

## Inference 

In [16]:
eval_data = json.load(open(f"/content/drive/MyDrive/Stage3/input/data/eval_dataset/eval_dials.json", "r"))

eval_examples = get_examples_from_dialogues(
    eval_data, user_first=False, dialogue_level=False
)

# Extracting Featrues
eval_features = processor.convert_examples_to_features(eval_examples)
eval_data = WOSDataset(eval_features)
eval_sampler = SequentialSampler(eval_data)
eval_loader = DataLoader(
    eval_data,
    batch_size=8,
    sampler=eval_sampler,
    collate_fn=processor.collate_fn,
)

100%|██████████| 2000/2000 [00:00<00:00, 12765.23it/s]


In [17]:
model.load_state_dict(torch.load("/content/drive/MyDrive/Stage3/model/Trade_segment_transformer_best.pt"))
model= model.eval()
predictions = inference(model, eval_loader, processor, device)

100%|██████████| 1847/1847 [09:50<00:00,  3.13it/s]


In [21]:
json.dump(predictions, open('/content/drive/MyDrive/Stage3/Trade_segment_transformer.csv', 'w'), indent=2, ensure_ascii=False) 

## prediction alalysis

In [None]:
model.load_state_dict(torch.load("/opt/ml/model/Trade50_best.pt"))
model = model.eval()
predictions = inference(model, dev_loader, processor, device)
eval_result = _evaluation(predictions, dev_labels, slot_meta)

In [None]:
cat_ontology = []
noncat_ontology = []
for i in range(45):
    if len(list(ontology.items())[i][1]) <=9:
        cat_ontology.append(list(ontology.items())[i][0])
    else:
        noncat_ontology.append(list(ontology.items())[i][0])

In [None]:
category_pred = []
noncategory_pred = []
category_label = []
noncategory_label = []

In [None]:
pred_keys = list(predictions.keys())
pred_values = list(predictions.values())
labels_keys = list(dev_labels.keys())
labels_values = list(dev_labels.values())

In [None]:
correct = []
wrong = []

In [None]:
# 전체 예측에 대해
last_key = ''
for num_labels in range(len(list(predictions))):
    
    # 마지막 예측이랑 비교해 새로운 예측만 생각
    if labels_keys[num_labels].split(':')[0] != last_key:
        last_pred = set([])
        last_label = set([])
        last_key = labels_keys[num_labels].split(':')[0]
    
    else:
        last_pred = set(pred_values[num_labels-1]) 
        last_label = set(labels_values[num_labels-1])
    
    pred_slots = list(set(pred_values[num_labels]) - last_pred)
    label_slots = list(set(labels_values[num_labels]) - last_label)
                       
    #예측 안에 각각의 예측 slot에 대해
    for label_slot_num in range(len(label_slots)):
        
        # 정답에 예측이 있으면
        if label_slots[label_slot_num] in pred_slots:
            correct.append(label_slots[label_slot_num])
        else:
            wrong.append(label_slots[label_slot_num])

In [None]:
print('len_correct:',len(correct))
print('len_wrong:',len(wrong))

In [None]:
cat_correct = []
noncat_correct = []
cat_wrong = []
noncat_wrong = []
for label in correct:
    if label.split('-')[0] + '-' +label.split('-')[1] in cat_ontology:
        cat_correct.append(label)
    else:
        noncat_correct.append(label)
for label in wrong:
    if label.split('-')[0] + '-' +label.split('-')[1] in cat_ontology:
        cat_wrong.append(label)
    else:
        noncat_wrong.append(label)


In [None]:
print('cat_correct:',len(cat_correct))
print('noncat_correct:',len(noncat_correct))
print('cat_wrong:',len(cat_wrong))
print('noncat_wrong:',len(noncat_wrong))

In [None]:
from collections import defaultdict

In [None]:
cat_correct_dict = defaultdict(int)
noncat_correct_dict = defaultdict(int)
cat_wrong_dict = defaultdict(int)
noncat_wrong_dict = defaultdict(int)

In [None]:
for cat in cat_correct:
    cat_correct_dict[cat.split('-')[0] + '-' +cat.split('-')[1]] +=1
for cat in noncat_correct:
    noncat_correct_dict[cat.split('-')[0] + '-' +cat.split('-')[1]] +=1    
for cat in cat_wrong:
    cat_wrong_dict[cat.split('-')[0] + '-' +cat.split('-')[1]] +=1
for cat in noncat_wrong:
    noncat_wrong_dict[cat.split('-')[0] + '-' +cat.split('-')[1]] +=1

In [None]:
print('cat 정답률')
for cat in cat_ontology:
    percent = cat_correct_dict[cat]/(cat_correct_dict[cat] + cat_wrong_dict[cat])*100
    print( str(cat), format(percent,".2f"),'%', '('+str(cat_correct_dict[cat])+'/'+str(cat_correct_dict[cat] + cat_wrong_dict[cat])+')')


In [None]:
print('cat 오답률')
for cat in cat_ontology:
    percent = cat_wrong_dict[cat]/(cat_correct_dict[cat] + cat_wrong_dict[cat])*100
    print( str(cat), format(percent,".2f"),'%', '('+str(cat_wrong_dict[cat])+'/'+str(cat_correct_dict[cat] + cat_wrong_dict[cat])+')')


In [None]:
print('noncat 정답률')
for cat in noncat_ontology:
    if noncat_correct_dict[cat] ==0:
        print(str(cat),' 0.00 %')
    else:
        percent = noncat_correct_dict[cat]/(noncat_correct_dict[cat] + noncat_wrong_dict[cat])*100
        print( str(cat), format(percent,".2f"),'%', '('+str(noncat_correct_dict[cat])+'/'+str(noncat_correct_dict[cat] + noncat_wrong_dict[cat])+')')


In [None]:
print('noncat 오답률')

for cat in noncat_ontology:
    if noncat_correct_dict[cat] ==0:
        print(str(cat),' 0.00 %')
    else:
        percent = noncat_wrong_dict[cat]/(noncat_correct_dict[cat] + noncat_wrong_dict[cat])*100
        print( str(cat), format(percent,".2f"),'%', '('+str(noncat_wrong_dict[cat])+'/'+str(noncat_correct_dict[cat] + noncat_wrong_dict[cat])+')')