In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
import os
os.chdir('/content/drive/MyDrive/Stage3/code')

In [3]:
!pip install transformers

Collecting transformers
[?25l  Downloading https://files.pythonhosted.org/packages/b0/9e/5b80becd952d5f7250eaf8fc64b957077b12ccfe73e9c03d37146ab29712/transformers-4.6.0-py3-none-any.whl (2.3MB)
[K     |████████████████████████████████| 2.3MB 3.9MB/s 
Collecting tokenizers<0.11,>=0.10.1
[?25l  Downloading https://files.pythonhosted.org/packages/ae/04/5b870f26a858552025a62f1649c20d29d2672c02ff3c3fb4c688ca46467a/tokenizers-0.10.2-cp37-cp37m-manylinux2010_x86_64.whl (3.3MB)
[K     |████████████████████████████████| 3.3MB 31.9MB/s 
[?25hCollecting huggingface-hub==0.0.8
  Downloading https://files.pythonhosted.org/packages/a1/88/7b1e45720ecf59c6c6737ff332f41c955963090a18e72acbcbeac6b25e86/huggingface_hub-0.0.8-py3-none-any.whl
Collecting sacremoses
[?25l  Downloading https://files.pythonhosted.org/packages/75/ee/67241dc87f266093c533a2d4d3d69438e57d7a90abb216fa076e7d475d4a/sacremoses-0.0.45-py3-none-any.whl (895kB)
[K     |████████████████████████████████| 901kB 38.0MB/s 
Installing c

In [4]:
import json
from tqdm import tqdm

from transformers import BertTokenizer
from data_utils import get_examples_from_dialogues, convert_state_dict, load_dataset
from data_utils import OntologyDSTFeature, DSTPreprocessor, _truncate_seq_pair,set_seed

## Data Loading 

In [5]:
set_seed(42)

In [6]:
train_data_file = "/content/drive/MyDrive/Stage3/input/data/train_dataset/train_dials.json"
slot_meta = json.load(open("/content/drive/MyDrive/Stage3/input/data/train_dataset/slot_meta.json"))
ontology = json.load(open("/content/drive/MyDrive/Stage3/input/data/train_dataset/somdst_ontology.json"))
train_data, dev_data, dev_labels = load_dataset(train_data_file)

In [7]:
train_examples = get_examples_from_dialogues(data=train_data,
                                             user_first=True,
                                             dialogue_level=True)

dev_examples = get_examples_from_dialogues(data=dev_data,
                                           user_first=True,
                                           dialogue_level=True)

100%|██████████| 6301/6301 [00:00<00:00, 8664.73it/s]
100%|██████████| 699/699 [00:00<00:00, 12471.58it/s]


In [8]:
len(train_data)

6301

In [9]:
max_turn = max([len(e['dialogue']) for e in train_data])

In [10]:
tokenizer = BertTokenizer.from_pretrained('dsksd/bert-ko-small-minimal')

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=263327.0, style=ProgressStyle(descripti…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=124.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=288.0, style=ProgressStyle(description_…




## TODO-1: SUMBT Preprocessor 정의 

Ontology-based DST model인 SUMBT의 InputFeature를 만들기 위한 Preprocessor를 정의해야 합니다. <br>

1. `_convert_examples_to_features` 함수의 빈칸을 매워 완성하세요.
2. `recover_state` 함수의 빈칸을 매워 완성하세요.

In [11]:
class SUMBTPreprocessor(DSTPreprocessor):
    def __init__(
        self,
        slot_meta,
        src_tokenizer,
        trg_tokenizer=None,
        ontology=None,
        max_seq_length=64,
        max_turn_length=14,
    ):
        self.slot_meta = slot_meta
        self.src_tokenizer = src_tokenizer
        self.trg_tokenizer = trg_tokenizer if trg_tokenizer else src_tokenizer
        self.ontology = ontology
        self.max_seq_length = max_seq_length
        self.max_turn_length = max_turn_length

    def _convert_example_to_feature(self, example):
        guid = example[0].guid.rsplit("-", 1)[0]  # dialogue_idx
        turns = []
        token_types = []
        labels = []
        num_turn = None
        for turn in example[: self.max_turn_length]:
            assert len(turn.current_turn) == 2
            uttrs = []
            for segment_idx, uttr in enumerate(turn.current_turn):
                token = self.src_tokenizer.encode(uttr, add_special_tokens=False)
                uttrs.append(token)

            _truncate_seq_pair(uttrs[0], uttrs[1], self.max_seq_length - 3)
            tokens = (
                [self.src_tokenizer.cls_token_id]
                + uttrs[0]
                + [self.src_tokenizer.sep_token_id]
                + uttrs[1]
                + [self.src_tokenizer.sep_token_id]
            )
            token_type = [0] * (len(uttrs[0]) + 2) + [1] * (len(uttrs[1]) + 1)
            if len(tokens) < self.max_seq_length:
                gap = self.max_seq_length - len(tokens)
                tokens.extend([self.src_tokenizer.pad_token_id] * gap)
                token_type.extend([0] * gap)
            turns.append(tokens)
            token_types.append(token_type)
            label = []
            if turn.label:
                slot_dict = convert_state_dict(turn.label)
            else:
                slot_dict = {}
            for slot_type in self.slot_meta:
                value = slot_dict.get(slot_type, "none")
                # TODO
                # raise Exception('label_idx를 ontology에서 꺼내오는 코드를 작성하세요!')
                if value in self.ontology[slot_type]:
                    label_idx = self.ontology[slot_type].index(value)
                else:
                    label_idx = self.ontology[slot_type].index("none")
                label.append(label_idx)
            labels.append(label)
        num_turn = len(turns)
        if len(turns) < self.max_turn_length:
            gap = self.max_turn_length - len(turns)
            for _ in range(gap):
                dummy_turn = [self.src_tokenizer.pad_token_id] * self.max_seq_length
                turns.append(dummy_turn)
                token_types.append(dummy_turn)
                dummy_label = [-1] * len(self.slot_meta)
                labels.append(dummy_label)
        return OntologyDSTFeature(
            guid=guid,
            input_ids=turns,
            segment_ids=token_types,
            num_turn=num_turn,
            target_ids=labels,
        )

    def convert_examples_to_features(self, examples):
        return list(map(self._convert_example_to_feature, examples))

    def recover_state(self, pred_slots, num_turn):
        states = []
        for pred_slot in pred_slots[:num_turn]:
            state = []
            for s, p in zip(self.slot_meta, pred_slot):
                v = self.ontology[s][p]
                if v != "none":
                    state.append(f"{s}-{v}")
            states.append(state)
        return states

    def collate_fn(self, batch):
        guids = [b.guid for b in batch]
        input_ids = torch.LongTensor([b.input_ids for b in batch])
        segment_ids = torch.LongTensor([b.segment_ids for b in batch])
        input_masks = input_ids.ne(self.src_tokenizer.pad_token_id)
        target_ids = torch.LongTensor([b.target_ids for b in batch])
        num_turns = [b.num_turn for b in batch]
        return input_ids, segment_ids, input_masks, target_ids, num_turns, guids

## Convert_Examples_to_Features 

In [12]:
processor = SUMBTPreprocessor(slot_meta,
                              tokenizer,
                              ontology=ontology,  # predefined ontology
                              max_seq_length=64,  # 각 turn마다 최대 길이
                              max_turn_length=max_turn)  # 각 dialogue의 최대 turn 길이
train_features = processor.convert_examples_to_features(train_examples)
dev_features = processor.convert_examples_to_features(dev_examples)

In [13]:
print(len(train_features))  # 대화 level의 features
print(len(dev_features))

6301
699


## SUMBT 모델 선언 

In [14]:
"""
Most of code is from https://github.com/SKTBrain/SUMBT
"""

import math
import os.path

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import CosineEmbeddingLoss, CrossEntropyLoss
from transformers import BertModel, BertPreTrainedModel


class BertForUtteranceEncoding(BertPreTrainedModel):
    def __init__(self, config):
        super(BertForUtteranceEncoding, self).__init__(config)

        self.config = config
        self.bert = BertModel(config)

    def forward(self, input_ids, token_type_ids, attention_mask):
        return self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            output_attentions=False,
            output_hidden_states=False,
            return_dict=False,
        )


class MultiHeadAttention(nn.Module):
    def __init__(self, heads, d_model, dropout=0.1):
        super().__init__()

        self.d_model = d_model
        self.d_k = d_model // heads
        self.h = heads

        self.q_linear = nn.Linear(d_model, d_model)
        self.v_linear = nn.Linear(d_model, d_model)
        self.k_linear = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)
        self.out = nn.Linear(d_model, d_model)

        self.scores = None

    def attention(self, q, k, v, d_k, mask=None, dropout=None):

        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)

        if mask is not None:
            mask = mask.unsqueeze(1)
            scores = scores.masked_fill(mask == 0, -1e9)
        scores = F.softmax(scores, dim=-1)

        if dropout is not None:
            scores = dropout(scores)

        self.scores = scores
        output = torch.matmul(scores, v)
        return output

    def forward(self, q, k, v, mask=None):
        bs = q.size(0)

        # perform linear operation and split into h heads
        k = self.k_linear(k).view(bs, -1, self.h, self.d_k)
        q = self.q_linear(q).view(bs, -1, self.h, self.d_k)
        v = self.v_linear(v).view(bs, -1, self.h, self.d_k)

        # transpose to get dimensions bs * h * sl * d_model
        k = k.transpose(1, 2)
        q = q.transpose(1, 2)
        v = v.transpose(1, 2)

        scores = self.attention(q, k, v, self.d_k, mask, self.dropout)

        # concatenate heads and put through final linear layer
        concat = scores.transpose(1, 2).contiguous().view(bs, -1, self.d_model)
        output = self.out(concat)
        return output

    def get_scores(self):
        return self.scores


class SUMBT(nn.Module):
    def __init__(self, args, num_labels, device):
        super(SUMBT, self).__init__()

        self.hidden_dim = args.hidden_dim
        self.rnn_num_layers = args.num_rnn_layers
        self.zero_init_rnn = args.zero_init_rnn
        self.max_seq_length = args.max_seq_length
        self.max_label_length = args.max_label_length
        self.num_labels = num_labels
        self.num_slots = len(num_labels)
        self.attn_head = args.attn_head
        self.device = device

        ### Utterance Encoder
        self.utterance_encoder = BertForUtteranceEncoding.from_pretrained(
            args.model_name_or_path
        )
        self.bert_output_dim = self.utterance_encoder.config.hidden_size
        self.hidden_dropout_prob = self.utterance_encoder.config.hidden_dropout_prob
        if args.fix_utterance_encoder:
            for p in self.utterance_encoder.bert.pooler.parameters():
                p.requires_grad = False

        ### slot, slot-value Encoder (not trainable)
        self.sv_encoder = BertForUtteranceEncoding.from_pretrained(
            args.model_name_or_path
        )
        # os.path.join(args.bert_dir, 'bert-base-uncased.model'))
        for p in self.sv_encoder.bert.parameters():
            p.requires_grad = False

        self.slot_lookup = nn.Embedding(self.num_slots, self.bert_output_dim)
        self.value_lookup = nn.ModuleList(
            [nn.Embedding(num_label, self.bert_output_dim) for num_label in num_labels]
        )

        ### Attention layer
        self.attn = MultiHeadAttention(self.attn_head, self.bert_output_dim, dropout=0)

        ### RNN Belief Tracker
        self.nbt = nn.GRU(
            input_size=self.bert_output_dim,
            hidden_size=self.hidden_dim,
            num_layers=self.rnn_num_layers,
            dropout=self.hidden_dropout_prob,
            batch_first=True,
            bidirectional = True
        )
        self.init_parameter(self.nbt)

        if not self.zero_init_rnn:
            self.rnn_init_linear = nn.Sequential(
                nn.Linear(self.bert_output_dim, self.hidden_dim),
                nn.ReLU(),
                nn.Dropout(self.hidden_dropout_prob),
            )

        self.linear = nn.Linear(self.hidden_dim, self.bert_output_dim)
        self.layer_norm = nn.LayerNorm(self.bert_output_dim)

        ### Measure
        self.metric = torch.nn.PairwiseDistance(p=2.0, eps=1e-06, keepdim=False)

        ### Classifier
        self.nll = CrossEntropyLoss(ignore_index=-1)
        #self.nll = torch.nn.HingeEmbeddingLoss(margin = 0.5,reduction='mean')
        ### Etc.
        self.dropout = nn.Dropout(self.hidden_dropout_prob)

    def initialize_slot_value_lookup(self, label_ids, slot_ids):

        self.sv_encoder.eval()

        # Slot encoding
        slot_type_ids = torch.zeros(slot_ids.size(), dtype=torch.long).to(
            slot_ids.device
        )
        slot_mask = slot_ids > 0
        hid_slot, _ = self.sv_encoder(
            slot_ids.view(-1, self.max_label_length),
            slot_type_ids.view(-1, self.max_label_length),
            slot_mask.view(-1, self.max_label_length),
        )
        hid_slot = hid_slot[:, 0, :]
        hid_slot = hid_slot.detach()
        self.slot_lookup = nn.Embedding.from_pretrained(hid_slot, freeze=True)

        for s, label_id in enumerate(label_ids):
            label_type_ids = torch.zeros(label_id.size(), dtype=torch.long).to(
                label_id.device
            )
            label_mask = label_id > 0
            hid_label, _ = self.sv_encoder(
                label_id.view(-1, self.max_label_length),
                label_type_ids.view(-1, self.max_label_length),
                label_mask.view(-1, self.max_label_length),
            )
            hid_label = hid_label[:, 0, :]
            hid_label = hid_label.detach()
            self.value_lookup[s] = nn.Embedding.from_pretrained(hid_label, freeze=True)
            self.value_lookup[s].padding_idx = -1

        print("Complete initialization of slot and value lookup")
        self.sv_encoder = None

    def forward(
        self,
        input_ids,
        token_type_ids,
        attention_mask,
        labels=None,
        n_gpu=1,
        target_slot=None,
    ):
        # input_ids: [B, M, N]
        # token_type_ids: [B, M, N]
        # attention_mask: [B, M, N]
        # labels: [B, M, J]

        # if target_slot is not specified, output values corresponding all slot-types
        if target_slot is None:
            target_slot = list(range(0, self.num_slots))
        
                                # Max sequence (N)
        ds = input_ids.size(0)  # Batch size (B)
        ts = input_ids.size(1)  # Max turn size (M)
        bs = ds * ts            # B*M
        slot_dim = len(target_slot)  # J


        # Utterance encoding
        hidden, _ = self.utterance_encoder(
            input_ids.view(-1, self.max_seq_length),
            token_type_ids.view(-1, self.max_seq_length),
            attention_mask.view(-1, self.max_seq_length),
        )
        hidden = torch.mul(
            hidden,
            attention_mask.view(-1, self.max_seq_length, 1)
            .expand(hidden.size())
            .float(),
        )
        hidden = hidden.repeat(slot_dim, 1, 1)  # [J*M*B, N, H] [6120,64,768]

        hid_slot = self.slot_lookup.weight[
            target_slot, :
        ]  # Select target slot embedding #[45,768]
        hid_slot = hid_slot.repeat(1, bs).view(bs * slot_dim, -1)  # [J*M*B, N, H] [6120,768]

        # Attended utterance vector
        hidden = self.attn(
            hid_slot,  # q^s  [J*M*B, N, H]
            hidden,  # U [J*M*B, N, H]
            hidden,  # U [J*M*B, N, H]
            mask=attention_mask.view(-1, 1, self.max_seq_length).repeat(slot_dim, 1, 1),
        )
        hidden = hidden.squeeze()  # h [J*M*B, H] Aggregated Slot Context
        hidden = hidden.view(slot_dim, ds, ts, -1).view(
            -1, ts, self.bert_output_dim
        )  # [J*B, M, H]

        # NBT
        if self.zero_init_rnn:
            h = torch.zeros(
                self.rnn_num_layers, input_ids.shape[0] * slot_dim, self.hidden_dim
            ).to(
                self.device
            )  # [1, slot_dim*ds, hidden]
        else:
            h = hidden[:, 0, :].unsqueeze(0).repeat(self.rnn_num_layers, 1, 1)
            h = self.rnn_init_linear(h)

        if isinstance(self.nbt, nn.GRU):
            h = h.repeat(2, 1, 1)
            rnn_out, _ = self.nbt(hidden, h)  # [J*B, M, H_GRU]
            
            
        elif isinstance(self.nbt, nn.LSTM):
            c = torch.zeros(
                self.rnn_num_layers, input_ids.shape[0] * slot_dim, self.hidden_dim
            ).to(
                self.device
            )  # [1, slot_dim*ds, hidden]
            rnn_out, _ = self.nbt(hidden, (h, c))  # [slot_dim*ds, turn, hidden]
        
        
        rnn_out = rnn_out[:,:,:self.hidden_dim] +rnn_out[:,:,self.hidden_dim:] 
        rnn_out = self.layer_norm(self.linear(self.dropout(rnn_out)))

        hidden = rnn_out.view(slot_dim, ds, ts, -1)  # [J, B, M, H_GRU]


        # Label (slot-value) encoding
        loss = 0
        loss_slot = []
        pred_slot = []
        output = []
        for s, slot_id in enumerate(target_slot):  ## note: target_slots are successive
            # loss calculation
            hid_label = self.value_lookup[slot_id].weight
            num_slot_labels = hid_label.size(0)

            _hid_label = (
                hid_label.unsqueeze(0)
                .unsqueeze(0)
                .repeat(ds, ts, 1, 1)
                .view(ds * ts * num_slot_labels, -1)
            )
            _hidden = (
                hidden[s, :, :, :]
                .unsqueeze(2)
                .repeat(1, 1, num_slot_labels, 1)
                .view(ds * ts * num_slot_labels, -1)
            )
            _dist = self.metric(_hid_label, _hidden).view(ds, ts, num_slot_labels)
            _dist = -_dist
            _, pred = torch.max(_dist, -1)
            pred_slot.append(pred.view(ds, ts, 1))
            output.append(_dist)

            if labels is not None:
                _loss = self.nll(_dist.view(ds * ts, -1), labels[:, :, s].view(-1))
                loss_slot.append(_loss.item())
                loss += _loss

        pred_slot = torch.cat(pred_slot, 2)
        if labels is None:
            return output, pred_slot

        # calculate joint accuracy
        accuracy = (pred_slot == labels).view(-1, slot_dim)
        acc_slot = (
            torch.sum(accuracy, 0).float()
            / torch.sum(labels.view(-1, slot_dim) > -1, 0).float()
        )
        acc = (
            sum(torch.sum(accuracy, 1) / slot_dim).float()
            / torch.sum(labels[:, :, 0].view(-1) > -1, 0).float()
        )  # joint accuracy

        if n_gpu == 1:
            return loss, loss_slot, acc, acc_slot, pred_slot
        else:
            return (
                loss.unsqueeze(0),
                None,
                acc.unsqueeze(0),
                acc_slot.unsqueeze(0),
                pred_slot.unsqueeze(0),
            )

    @staticmethod
    def init_parameter(module):
        if isinstance(module, nn.Linear):
            torch.nn.init.xavier_normal_(module.weight)
            torch.nn.init.constant_(module.bias, 0.0)
        elif isinstance(module, nn.GRU) or isinstance(module, nn.LSTM):
            torch.nn.init.xavier_normal_(module.weight_ih_l0)
            torch.nn.init.xavier_normal_(module.weight_hh_l0)
            torch.nn.init.constant_(module.bias_ih_l0, 0.0)
            torch.nn.init.constant_(module.bias_hh_l0, 0.0)

## TODO-2: Ontology Pre-Encoding 

Ontology의 slot type들과 이에 속하는 slot_value들을 tokenizing하는 `tokenize_ontology`를 작성하세요. <br>
[CLS] Pooling하여 `slot_lookup` 과 `value_lookup` embedding matrix들을 초기화하는 <br>
`initialize_slot_value_lookup`에 인자로 넘겨주세요. <br>

In [15]:
def tokenize_ontology(ontology, tokenizer, max_seq_length=12):
    slot_types = []
    slot_values = []
    for k, v in ontology.items():
        tokens = tokenizer.encode(k)
        if len(tokens) < max_seq_length:
            gap = max_seq_length - len(tokens)
            tokens.extend([tokenizer.pad_token_id] *  gap)
        slot_types.append(tokens)
        slot_value = []
        for vv in v:
            tokens = tokenizer.encode(vv)
            if len(tokens) < max_seq_length:
                gap = max_seq_length - len(tokens)
                tokens.extend([tokenizer.pad_token_id] *  gap)
            slot_value.append(tokens)
        slot_values.append(torch.LongTensor(slot_value))
    return torch.LongTensor(slot_types), slot_values

In [16]:
slot_type_ids, slot_values_ids = tokenize_ontology(ontology, tokenizer, 12)
num_labels = [len(s) for s in slot_values_ids]  # 각 Slot 별 후보 Values의 갯수

print("Tokenized Slot: ", slot_type_ids.size())
for slot, slot_value_id in zip(slot_meta, slot_values_ids):
    print(f"Tokenized Value of {slot}", slot_value_id.size())

Tokenized Slot:  torch.Size([45, 12])
Tokenized Value of 관광-경치 좋은 torch.Size([4, 12])
Tokenized Value of 관광-교육적 torch.Size([4, 12])
Tokenized Value of 관광-도보 가능 torch.Size([4, 12])
Tokenized Value of 관광-문화 예술 torch.Size([4, 12])
Tokenized Value of 관광-역사적 torch.Size([4, 12])
Tokenized Value of 관광-이름 torch.Size([112, 12])
Tokenized Value of 관광-종류 torch.Size([13, 12])
Tokenized Value of 관광-주차 가능 torch.Size([4, 12])
Tokenized Value of 관광-지역 torch.Size([7, 12])
Tokenized Value of 숙소-가격대 torch.Size([5, 12])
Tokenized Value of 숙소-도보 가능 torch.Size([4, 12])
Tokenized Value of 숙소-수영장 유무 torch.Size([4, 12])
Tokenized Value of 숙소-스파 유무 torch.Size([4, 12])
Tokenized Value of 숙소-예약 기간 torch.Size([12, 12])
Tokenized Value of 숙소-예약 명수 torch.Size([12, 12])
Tokenized Value of 숙소-예약 요일 torch.Size([9, 12])
Tokenized Value of 숙소-이름 torch.Size([186, 12])
Tokenized Value of 숙소-인터넷 가능 torch.Size([4, 12])
Tokenized Value of 숙소-조식 가능 torch.Size([4, 12])
Tokenized Value of 숙소-종류 torch.Size([9, 12])
Tokenized Valu

## Model 선언 

In [17]:
from argparse import Namespace

args = {
    'hidden_dim': 300,
    'num_rnn_layers': 1,
    'zero_init_rnn': False,
    'max_seq_length': 64,
    'max_label_length': 12,
    'attn_head': 4,
    'fix_utterance_encoder': False,
    'task_name': 'sumbtgru',
    'distance_metric': 'euclidean',
    'model_name_or_path': 'dsksd/bert-ko-small-minimal',
    'warmup_ratio': 0.1,
    'learning_rate': 5e-5,
    'weight_decay': 0.01,
    'num_train_epochs': 50
}

args = Namespace(**args)

num_labels = [len(s) for s in slot_values_ids]
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
n_gpu = 1 if torch.cuda.device_count() < 2 else torch.cuda.device_count()
n_epochs = args.num_train_epochs

In [18]:
model = SUMBT(args, num_labels, device)
model.initialize_slot_value_lookup(slot_values_ids, slot_type_ids)  # Tokenized Ontology의 Pre-encoding using BERT_SV
model.to(device)
print()

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=434.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=284118515.0, style=ProgressStyle(descri…




Some weights of the model checkpoint at dsksd/bert-ko-small-minimal were not used when initializing BertForUtteranceEncoding: ['cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias']
- This IS expected if you are initializing BertForUtteranceEncoding from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForUtteranceEncoding from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of the model checkpoint at dsksd/bert-ko-small-minimal 

Complete initialization of slot and value lookup



## 데이터 로더 정의

In [19]:
from data_utils import WOSDataset
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from transformers import AdamW, get_linear_schedule_with_warmup
import random


train_data = WOSDataset(train_features)
train_sampler = RandomSampler(train_data)
train_loader = DataLoader(train_data, batch_size=4, sampler=train_sampler, collate_fn=processor.collate_fn)

dev_data = WOSDataset(dev_features)
dev_sampler = SequentialSampler(dev_data)
dev_loader = DataLoader(dev_data, batch_size=4, sampler=dev_sampler, collate_fn=processor.collate_fn)

## Optimizer & Scheduler 선언 

In [20]:
no_decay = ["bias", "LayerNorm.weight"]
optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": args.weight_decay,
        },
        {
            "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
            "weight_decay": 0.0,
        },
    ]

t_total = len(train_loader) * n_epochs
optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=1e-8)
scheduler = get_linear_schedule_with_warmup(
    optimizer, num_warmup_steps=int(t_total * args.warmup_ratio), num_training_steps=t_total
)

## TODO-3: Inference code 작성 

In [21]:
from evaluation import _evaluation

In [22]:
def inference(model, eval_loader, processor, device):
    model.eval()
    predictions = {}
    for batch in tqdm(eval_loader):
        input_ids, segment_ids, input_masks, target_ids, num_turns, guids = \
        [b.to(device) if not isinstance(b, list) else b for b in batch]

        with torch.no_grad():
            _, pred_slot = model(
                input_ids, segment_ids, input_masks, labels=None, n_gpu=1
            )
        
        batch_size = input_ids.size(0)
        for i in range(batch_size):
            guid = guids[i]
            states = processor.recover_state(pred_slot.tolist()[i], num_turns[i])
            for tid, state in enumerate(states):
                predictions[f"{guid}-{tid}"] = state
    return predictions

## Training 

In [23]:
best_score, best_checkpoint = 0, 0
for epoch in range(n_epochs):
    batch_loss = []
    model.train()
    for step, batch in enumerate(train_loader):
        input_ids, segment_ids, input_masks, target_ids, num_turns, guids  = \
        [b.to(device) if not isinstance(b, list) else b for b in batch]

        # Forward
        if n_gpu == 1:
            loss, loss_slot, acc, acc_slot, _ = model(input_ids, segment_ids, input_masks, target_ids, n_gpu)
        else:
            loss, _, acc, acc_slot, _ = model(input_ids, segment_ids, input_masks, target_ids, n_gpu)
        
        batch_loss.append(loss.item())

        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()
        if step % 100 == 0:
            print('[%d/%d] [%d/%d] %f' % (epoch, n_epochs, step, len(train_loader), loss.item()))
    
    predictions = inference(model, dev_loader, processor, device)
    eval_result = _evaluation(predictions, dev_labels, slot_meta)
    score = eval_result['joint_goal_accuracy']
    if score > best_score:
        cnt = 0
        best_score = score
        torch.save(model.state_dict(), "/content/drive/MyDrive/Stage3/model/sumbt_somtology_50_best.pt")
    
    for k, v in eval_result.items():
        print(f"{k}: {v}")

[0/50] [0/1576] 117.789993
[0/50] [100/1576] 82.087982
[0/50] [200/1576] 45.805374
[0/50] [300/1576] 48.390671
[0/50] [400/1576] 39.113377
[0/50] [500/1576] 39.562935
[0/50] [600/1576] 39.747345
[0/50] [700/1576] 30.635595
[0/50] [800/1576] 36.739212
[0/50] [900/1576] 31.622696
[0/50] [1000/1576] 30.433405
[0/50] [1100/1576] 37.254543
[0/50] [1200/1576] 27.618567
[0/50] [1300/1576] 29.186260
[0/50] [1400/1576] 39.073238
[0/50] [1500/1576] 31.720762


100%|██████████| 175/175 [00:56<00:00,  3.12it/s]


{'joint_goal_accuracy': 0.019901477832512317, 'turn_slot_accuracy': 0.837495347564315, 'turn_slot_f1': 0.18304545652491044}
joint_goal_accuracy: 0.019901477832512317
turn_slot_accuracy: 0.837495347564315
turn_slot_f1: 0.18304545652491044
[1/50] [0/1576] 27.121223
[1/50] [100/1576] 21.453230
[1/50] [200/1576] 33.677727
[1/50] [300/1576] 25.554983
[1/50] [400/1576] 28.526690
[1/50] [500/1576] 23.745035
[1/50] [600/1576] 18.889921
[1/50] [700/1576] 19.856777
[1/50] [800/1576] 20.421585
[1/50] [900/1576] 24.433273
[1/50] [1000/1576] 27.749201
[1/50] [1100/1576] 23.299051
[1/50] [1200/1576] 24.318411
[1/50] [1300/1576] 14.785513
[1/50] [1400/1576] 24.751047
[1/50] [1500/1576] 21.542629


100%|██████████| 175/175 [00:56<00:00,  3.12it/s]


{'joint_goal_accuracy': 0.06305418719211822, 'turn_slot_accuracy': 0.8862309797482204, 'turn_slot_f1': 0.4380856143623322}
joint_goal_accuracy: 0.06305418719211822
turn_slot_accuracy: 0.8862309797482204
turn_slot_f1: 0.4380856143623322
[2/50] [0/1576] 26.249607
[2/50] [100/1576] 19.409462
[2/50] [200/1576] 21.516968
[2/50] [300/1576] 16.518993
[2/50] [400/1576] 18.638720
[2/50] [500/1576] 18.522757
[2/50] [600/1576] 22.761511
[2/50] [700/1576] 14.373045
[2/50] [800/1576] 19.088478
[2/50] [900/1576] 12.653646
[2/50] [1000/1576] 11.396968
[2/50] [1100/1576] 16.804308
[2/50] [1200/1576] 15.240593
[2/50] [1300/1576] 11.158917
[2/50] [1400/1576] 13.974407
[2/50] [1500/1576] 9.056827


100%|██████████| 175/175 [00:56<00:00,  3.12it/s]


{'joint_goal_accuracy': 0.2051231527093596, 'turn_slot_accuracy': 0.9518598795840283, 'turn_slot_f1': 0.7825952212110691}
joint_goal_accuracy: 0.2051231527093596
turn_slot_accuracy: 0.9518598795840283
turn_slot_f1: 0.7825952212110691
[3/50] [0/1576] 8.983657
[3/50] [100/1576] 11.249883
[3/50] [200/1576] 11.808194
[3/50] [300/1576] 12.028056
[3/50] [400/1576] 12.009197
[3/50] [500/1576] 9.327368
[3/50] [600/1576] 10.548391
[3/50] [700/1576] 9.259199
[3/50] [800/1576] 7.510169
[3/50] [900/1576] 7.629751
[3/50] [1000/1576] 5.095890
[3/50] [1100/1576] 3.741469
[3/50] [1200/1576] 4.179442
[3/50] [1300/1576] 8.558231
[3/50] [1400/1576] 7.246936
[3/50] [1500/1576] 2.580445


100%|██████████| 175/175 [00:56<00:00,  3.12it/s]


{'joint_goal_accuracy': 0.44807881773399016, 'turn_slot_accuracy': 0.9763503010399672, 'turn_slot_f1': 0.9006247881355435}
joint_goal_accuracy: 0.44807881773399016
turn_slot_accuracy: 0.9763503010399672
turn_slot_f1: 0.9006247881355435
[4/50] [0/1576] 9.661570
[4/50] [100/1576] 5.455002
[4/50] [200/1576] 4.248937
[4/50] [300/1576] 5.155593
[4/50] [400/1576] 3.921451
[4/50] [500/1576] 7.187084
[4/50] [600/1576] 1.965773
[4/50] [700/1576] 3.984461
[4/50] [800/1576] 4.824460
[4/50] [900/1576] 5.492684
[4/50] [1000/1576] 3.371562
[4/50] [1100/1576] 3.610557
[4/50] [1200/1576] 2.359735
[4/50] [1300/1576] 3.561301
[4/50] [1400/1576] 2.266639
[4/50] [1500/1576] 5.081060


100%|██████████| 175/175 [00:56<00:00,  3.12it/s]


{'joint_goal_accuracy': 0.6666009852216749, 'turn_slot_accuracy': 0.988422550629455, 'turn_slot_f1': 0.9532866146030753}
joint_goal_accuracy: 0.6666009852216749
turn_slot_accuracy: 0.988422550629455
turn_slot_f1: 0.9532866146030753
[5/50] [0/1576] 3.755262
[5/50] [100/1576] 3.031674
[5/50] [200/1576] 2.242636
[5/50] [300/1576] 2.704596
[5/50] [400/1576] 2.449701
[5/50] [500/1576] 5.841527
[5/50] [600/1576] 4.563891
[5/50] [700/1576] 4.436372
[5/50] [800/1576] 4.216320
[5/50] [900/1576] 1.843679
[5/50] [1000/1576] 6.136195
[5/50] [1100/1576] 2.365789
[5/50] [1200/1576] 1.789667
[5/50] [1300/1576] 2.231672
[5/50] [1400/1576] 1.771273
[5/50] [1500/1576] 3.763737


100%|██████████| 175/175 [00:56<00:00,  3.12it/s]


{'joint_goal_accuracy': 0.7101477832512315, 'turn_slot_accuracy': 0.9900470717022546, 'turn_slot_f1': 0.9636403993547725}
joint_goal_accuracy: 0.7101477832512315
turn_slot_accuracy: 0.9900470717022546
turn_slot_f1: 0.9636403993547725
[6/50] [0/1576] 4.482941
[6/50] [100/1576] 1.559965
[6/50] [200/1576] 0.787294
[6/50] [300/1576] 2.190218
[6/50] [400/1576] 2.178063
[6/50] [500/1576] 1.953700
[6/50] [600/1576] 2.341014
[6/50] [700/1576] 1.874679
[6/50] [800/1576] 1.218296
[6/50] [900/1576] 1.218451
[6/50] [1000/1576] 3.506203
[6/50] [1100/1576] 2.311724
[6/50] [1200/1576] 1.033752
[6/50] [1300/1576] 1.883006
[6/50] [1400/1576] 2.628989
[6/50] [1500/1576] 1.151574


100%|██████████| 175/175 [00:56<00:00,  3.12it/s]


{'joint_goal_accuracy': 0.7400985221674877, 'turn_slot_accuracy': 0.9913037766830951, 'turn_slot_f1': 0.9670756524002179}
joint_goal_accuracy: 0.7400985221674877
turn_slot_accuracy: 0.9913037766830951
turn_slot_f1: 0.9670756524002179
[7/50] [0/1576] 2.245846
[7/50] [100/1576] 1.899761
[7/50] [200/1576] 1.782034
[7/50] [300/1576] 1.637957
[7/50] [400/1576] 1.730893
[7/50] [500/1576] 1.982770
[7/50] [600/1576] 0.963551
[7/50] [700/1576] 1.103820
[7/50] [800/1576] 1.373117
[7/50] [900/1576] 2.905252
[7/50] [1000/1576] 4.098300
[7/50] [1100/1576] 5.281775
[7/50] [1200/1576] 1.594005
[7/50] [1300/1576] 1.619616
[7/50] [1400/1576] 0.842084
[7/50] [1500/1576] 1.078067


100%|██████████| 175/175 [00:56<00:00,  3.11it/s]


{'joint_goal_accuracy': 0.7724137931034483, 'turn_slot_accuracy': 0.9928363437329044, 'turn_slot_f1': 0.9706345812659849}
joint_goal_accuracy: 0.7724137931034483
turn_slot_accuracy: 0.9928363437329044
turn_slot_f1: 0.9706345812659849
[8/50] [0/1576] 1.373828
[8/50] [100/1576] 1.064789
[8/50] [200/1576] 1.750282
[8/50] [300/1576] 1.207763
[8/50] [400/1576] 1.008656
[8/50] [500/1576] 1.383081
[8/50] [600/1576] 1.469760
[8/50] [700/1576] 0.863820
[8/50] [800/1576] 0.999965
[8/50] [900/1576] 0.445856
[8/50] [1000/1576] 1.225067
[8/50] [1100/1576] 1.537211
[8/50] [1200/1576] 1.567779
[8/50] [1300/1576] 0.757010
[8/50] [1400/1576] 0.679780
[8/50] [1500/1576] 0.725845


100%|██████████| 175/175 [00:56<00:00,  3.11it/s]


{'joint_goal_accuracy': 0.7850246305418719, 'turn_slot_accuracy': 0.9927662835249069, 'turn_slot_f1': 0.9729804016945475}
joint_goal_accuracy: 0.7850246305418719
turn_slot_accuracy: 0.9927662835249069
turn_slot_f1: 0.9729804016945475
[9/50] [0/1576] 1.135993
[9/50] [100/1576] 1.509543
[9/50] [200/1576] 1.305833
[9/50] [300/1576] 0.867429
[9/50] [400/1576] 2.162632
[9/50] [500/1576] 0.463066
[9/50] [600/1576] 1.585859
[9/50] [700/1576] 1.087725
[9/50] [800/1576] 0.660950
[9/50] [900/1576] 0.746819
[9/50] [1000/1576] 1.704839
[9/50] [1100/1576] 0.886125
[9/50] [1200/1576] 0.669976
[9/50] [1300/1576] 2.137031
[9/50] [1400/1576] 0.682333
[9/50] [1500/1576] 1.493556


100%|██████████| 175/175 [00:56<00:00,  3.11it/s]


{'joint_goal_accuracy': 0.7994088669950739, 'turn_slot_accuracy': 0.9935106732348166, 'turn_slot_f1': 0.9753451327938828}
joint_goal_accuracy: 0.7994088669950739
turn_slot_accuracy: 0.9935106732348166
turn_slot_f1: 0.9753451327938828
[10/50] [0/1576] 0.342926
[10/50] [100/1576] 0.454161
[10/50] [200/1576] 1.392246
[10/50] [300/1576] 0.991418
[10/50] [400/1576] 0.462528
[10/50] [500/1576] 0.927725
[10/50] [600/1576] 0.802626
[10/50] [700/1576] 0.322857
[10/50] [800/1576] 0.836736
[10/50] [900/1576] 0.378224
[10/50] [1000/1576] 0.793440
[10/50] [1100/1576] 0.658747
[10/50] [1200/1576] 3.393567
[10/50] [1300/1576] 4.125965
[10/50] [1400/1576] 0.647573
[10/50] [1500/1576] 1.004234


100%|██████████| 175/175 [00:56<00:00,  3.11it/s]


{'joint_goal_accuracy': 0.8076847290640394, 'turn_slot_accuracy': 0.9938828680897714, 'turn_slot_f1': 0.9765820192528556}
joint_goal_accuracy: 0.8076847290640394
turn_slot_accuracy: 0.9938828680897714
turn_slot_f1: 0.9765820192528556
[11/50] [0/1576] 0.644554
[11/50] [100/1576] 0.703817
[11/50] [200/1576] 0.541686
[11/50] [300/1576] 0.172873
[11/50] [400/1576] 1.451236
[11/50] [500/1576] 0.526461
[11/50] [600/1576] 0.819581
[11/50] [700/1576] 6.009538
[11/50] [800/1576] 0.409001
[11/50] [900/1576] 0.298077
[11/50] [1000/1576] 0.961233
[11/50] [1100/1576] 0.945875
[11/50] [1200/1576] 0.859288
[11/50] [1300/1576] 0.793181
[11/50] [1400/1576] 0.269092
[11/50] [1500/1576] 1.116082


100%|██████████| 175/175 [00:56<00:00,  3.11it/s]


{'joint_goal_accuracy': 0.8027586206896552, 'turn_slot_accuracy': 0.9939135194307662, 'turn_slot_f1': 0.9767012281478247}
joint_goal_accuracy: 0.8027586206896552
turn_slot_accuracy: 0.9939135194307662
turn_slot_f1: 0.9767012281478247
[12/50] [0/1576] 0.337408
[12/50] [100/1576] 1.204474
[12/50] [200/1576] 0.611621
[12/50] [300/1576] 0.628863
[12/50] [400/1576] 0.238226
[12/50] [500/1576] 0.368693
[12/50] [600/1576] 0.638217
[12/50] [700/1576] 0.820922
[12/50] [800/1576] 1.145684
[12/50] [900/1576] 0.352845
[12/50] [1000/1576] 1.507901
[12/50] [1100/1576] 1.165614
[12/50] [1200/1576] 0.590197
[12/50] [1300/1576] 0.466239
[12/50] [1400/1576] 0.341001
[12/50] [1500/1576] 1.683064


100%|██████████| 175/175 [00:56<00:00,  3.11it/s]


{'joint_goal_accuracy': 0.8082758620689655, 'turn_slot_accuracy': 0.9938960043787689, 'turn_slot_f1': 0.9778196943609373}
joint_goal_accuracy: 0.8082758620689655
turn_slot_accuracy: 0.9938960043787689
turn_slot_f1: 0.9778196943609373
[13/50] [0/1576] 0.415534
[13/50] [100/1576] 0.872363
[13/50] [200/1576] 0.788723
[13/50] [300/1576] 0.518390
[13/50] [400/1576] 0.370151
[13/50] [500/1576] 0.644505
[13/50] [600/1576] 1.095670
[13/50] [700/1576] 0.822790
[13/50] [800/1576] 2.854345
[13/50] [900/1576] 0.206863
[13/50] [1000/1576] 0.476161
[13/50] [1100/1576] 0.884161
[13/50] [1200/1576] 0.229285
[13/50] [1300/1576] 0.292152
[13/50] [1400/1576] 0.051549
[13/50] [1500/1576] 1.129906


100%|██████████| 175/175 [00:56<00:00,  3.11it/s]


{'joint_goal_accuracy': 0.8100492610837439, 'turn_slot_accuracy': 0.9940580186097481, 'turn_slot_f1': 0.978241392596467}
joint_goal_accuracy: 0.8100492610837439
turn_slot_accuracy: 0.9940580186097481
turn_slot_f1: 0.978241392596467
[14/50] [0/1576] 0.278062
[14/50] [100/1576] 0.713576
[14/50] [200/1576] 0.143571
[14/50] [300/1576] 0.754636
[14/50] [400/1576] 0.638327
[14/50] [500/1576] 0.525596
[14/50] [600/1576] 0.709667
[14/50] [700/1576] 0.192401
[14/50] [800/1576] 0.684779
[14/50] [900/1576] 0.379286
[14/50] [1000/1576] 0.480153
[14/50] [1100/1576] 0.693224
[14/50] [1200/1576] 0.173617
[14/50] [1300/1576] 5.395926
[14/50] [1400/1576] 0.723558
[14/50] [1500/1576] 0.579838


100%|██████████| 175/175 [00:56<00:00,  3.11it/s]


{'joint_goal_accuracy': 0.7998029556650247, 'turn_slot_accuracy': 0.9939178981937672, 'turn_slot_f1': 0.9769745716780132}
joint_goal_accuracy: 0.7998029556650247
turn_slot_accuracy: 0.9939178981937672
turn_slot_f1: 0.9769745716780132
[15/50] [0/1576] 0.943625
[15/50] [100/1576] 0.592107
[15/50] [200/1576] 0.263217
[15/50] [300/1576] 0.234775
[15/50] [400/1576] 0.496026
[15/50] [500/1576] 0.331095
[15/50] [600/1576] 0.509381
[15/50] [700/1576] 0.268953
[15/50] [800/1576] 0.433296
[15/50] [900/1576] 1.745117
[15/50] [1000/1576] 0.305025
[15/50] [1100/1576] 1.717768
[15/50] [1200/1576] 1.612925
[15/50] [1300/1576] 0.273259
[15/50] [1400/1576] 0.734739
[15/50] [1500/1576] 0.418329


100%|██████████| 175/175 [00:56<00:00,  3.11it/s]


{'joint_goal_accuracy': 0.8126108374384237, 'turn_slot_accuracy': 0.9938653530377735, 'turn_slot_f1': 0.9778613923960868}
joint_goal_accuracy: 0.8126108374384237
turn_slot_accuracy: 0.9938653530377735
turn_slot_f1: 0.9778613923960868
[16/50] [0/1576] 0.428835
[16/50] [100/1576] 0.207663
[16/50] [200/1576] 0.297957
[16/50] [300/1576] 0.443406
[16/50] [400/1576] 0.575846
[16/50] [500/1576] 0.362132
[16/50] [600/1576] 0.461482
[16/50] [700/1576] 1.327822
[16/50] [800/1576] 0.184611
[16/50] [900/1576] 0.172083
[16/50] [1000/1576] 0.541016
[16/50] [1100/1576] 0.360219
[16/50] [1200/1576] 0.424316
[16/50] [1300/1576] 0.329608
[16/50] [1400/1576] 0.739154
[16/50] [1500/1576] 0.977229


100%|██████████| 175/175 [00:56<00:00,  3.12it/s]


{'joint_goal_accuracy': 0.8165517241379311, 'turn_slot_accuracy': 0.9939967159277557, 'turn_slot_f1': 0.9776250712031602}
joint_goal_accuracy: 0.8165517241379311
turn_slot_accuracy: 0.9939967159277557
turn_slot_f1: 0.9776250712031602
[17/50] [0/1576] 0.546311
[17/50] [100/1576] 0.766368
[17/50] [200/1576] 0.447835
[17/50] [300/1576] 0.238039
[17/50] [400/1576] 0.311556
[17/50] [500/1576] 0.395889
[17/50] [600/1576] 0.154439
[17/50] [700/1576] 0.407025
[17/50] [800/1576] 0.339419
[17/50] [900/1576] 0.801396
[17/50] [1000/1576] 0.296527
[17/50] [1100/1576] 0.251109
[17/50] [1200/1576] 0.533438
[17/50] [1300/1576] 0.209360
[17/50] [1400/1576] 0.315541
[17/50] [1500/1576] 0.552668


100%|██████████| 175/175 [00:56<00:00,  3.12it/s]


{'joint_goal_accuracy': 0.818128078817734, 'turn_slot_accuracy': 0.994311986863717, 'turn_slot_f1': 0.9789200514227752}
joint_goal_accuracy: 0.818128078817734
turn_slot_accuracy: 0.994311986863717
turn_slot_f1: 0.9789200514227752
[18/50] [0/1576] 0.189134
[18/50] [100/1576] 0.442700
[18/50] [200/1576] 0.108100
[18/50] [300/1576] 0.335394
[18/50] [400/1576] 0.090166
[18/50] [500/1576] 0.367024
[18/50] [600/1576] 0.059155
[18/50] [700/1576] 0.605674
[18/50] [800/1576] 1.118605
[18/50] [900/1576] 0.303246
[18/50] [1000/1576] 0.123137
[18/50] [1100/1576] 0.473841
[18/50] [1200/1576] 0.396211
[18/50] [1300/1576] 0.426206
[18/50] [1400/1576] 2.602523
[18/50] [1500/1576] 0.098908


100%|██████████| 175/175 [00:56<00:00,  3.12it/s]


{'joint_goal_accuracy': 0.8149753694581281, 'turn_slot_accuracy': 0.9941193212917425, 'turn_slot_f1': 0.9778028536003687}
joint_goal_accuracy: 0.8149753694581281
turn_slot_accuracy: 0.9941193212917425
turn_slot_f1: 0.9778028536003687
[19/50] [0/1576] 0.158045
[19/50] [100/1576] 0.276602
[19/50] [200/1576] 0.289807
[19/50] [300/1576] 0.246236
[19/50] [400/1576] 0.186123
[19/50] [500/1576] 0.121569
[19/50] [600/1576] 0.447219
[19/50] [700/1576] 0.203161
[19/50] [800/1576] 0.741232
[19/50] [900/1576] 0.650604
[19/50] [1000/1576] 0.556895
[19/50] [1100/1576] 0.454578
[19/50] [1200/1576] 0.362467
[19/50] [1300/1576] 0.214383
[19/50] [1400/1576] 0.524350
[19/50] [1500/1576] 0.143372


100%|██████████| 175/175 [00:56<00:00,  3.12it/s]


{'joint_goal_accuracy': 0.8149753694581281, 'turn_slot_accuracy': 0.9942244116037277, 'turn_slot_f1': 0.9779822566922731}
joint_goal_accuracy: 0.8149753694581281
turn_slot_accuracy: 0.9942244116037277
turn_slot_f1: 0.9779822566922731
[20/50] [0/1576] 0.503178
[20/50] [100/1576] 0.191806
[20/50] [200/1576] 0.299021
[20/50] [300/1576] 0.358112
[20/50] [400/1576] 0.272985
[20/50] [500/1576] 0.347354
[20/50] [600/1576] 0.397895
[20/50] [700/1576] 0.440670
[20/50] [800/1576] 1.682811
[20/50] [900/1576] 0.331509
[20/50] [1000/1576] 0.590058
[20/50] [1100/1576] 1.068024
[20/50] [1200/1576] 0.279029
[20/50] [1300/1576] 2.656569
[20/50] [1400/1576] 0.109765
[20/50] [1500/1576] 0.182435


100%|██████████| 175/175 [00:56<00:00,  3.12it/s]


{'joint_goal_accuracy': 0.8171428571428572, 'turn_slot_accuracy': 0.9944389709907002, 'turn_slot_f1': 0.9791442348334063}
joint_goal_accuracy: 0.8171428571428572
turn_slot_accuracy: 0.9944389709907002
turn_slot_f1: 0.9791442348334063
[21/50] [0/1576] 0.363600
[21/50] [100/1576] 0.274074
[21/50] [200/1576] 0.483864
[21/50] [300/1576] 0.246933
[21/50] [400/1576] 0.322369
[21/50] [500/1576] 0.319024
[21/50] [600/1576] 0.256468
[21/50] [700/1576] 0.164872
[21/50] [800/1576] 0.202628
[21/50] [900/1576] 0.159055
[21/50] [1000/1576] 0.379479
[21/50] [1100/1576] 1.604359
[21/50] [1200/1576] 0.300050
[21/50] [1300/1576] 2.369132
[21/50] [1400/1576] 0.245721
[21/50] [1500/1576] 0.186755


100%|██████████| 175/175 [00:56<00:00,  3.12it/s]


{'joint_goal_accuracy': 0.8092610837438423, 'turn_slot_accuracy': 0.9939967159277555, 'turn_slot_f1': 0.9773072520207736}
joint_goal_accuracy: 0.8092610837438423
turn_slot_accuracy: 0.9939967159277555
turn_slot_f1: 0.9773072520207736
[22/50] [0/1576] 0.166636
[22/50] [100/1576] 0.228949
[22/50] [200/1576] 0.218505
[22/50] [300/1576] 0.398515
[22/50] [400/1576] 0.301019
[22/50] [500/1576] 0.497826
[22/50] [600/1576] 0.361228
[22/50] [700/1576] 0.517738
[22/50] [800/1576] 0.192286
[22/50] [900/1576] 0.370650
[22/50] [1000/1576] 0.215971
[22/50] [1100/1576] 0.179416
[22/50] [1200/1576] 0.143951
[22/50] [1300/1576] 0.121865
[22/50] [1400/1576] 0.381628
[22/50] [1500/1576] 0.388931


100%|██████████| 175/175 [00:56<00:00,  3.12it/s]


{'joint_goal_accuracy': 0.8199014778325123, 'turn_slot_accuracy': 0.9942463054187268, 'turn_slot_f1': 0.9789640373203313}
joint_goal_accuracy: 0.8199014778325123
turn_slot_accuracy: 0.9942463054187268
turn_slot_f1: 0.9789640373203313
[23/50] [0/1576] 0.233681
[23/50] [100/1576] 0.206253
[23/50] [200/1576] 0.389384
[23/50] [300/1576] 0.248407
[23/50] [400/1576] 0.124793
[23/50] [500/1576] 0.271780
[23/50] [600/1576] 0.292562
[23/50] [700/1576] 0.157553
[23/50] [800/1576] 0.303680
[23/50] [900/1576] 0.370631
[23/50] [1000/1576] 0.325653
[23/50] [1100/1576] 0.122101
[23/50] [1200/1576] 0.233591
[23/50] [1300/1576] 0.277914
[23/50] [1400/1576] 0.668100
[23/50] [1500/1576] 0.186687


100%|██████████| 175/175 [00:56<00:00,  3.11it/s]


{'joint_goal_accuracy': 0.803743842364532, 'turn_slot_accuracy': 0.9940448823207512, 'turn_slot_f1': 0.9772124939671539}
joint_goal_accuracy: 0.803743842364532
turn_slot_accuracy: 0.9940448823207512
turn_slot_f1: 0.9772124939671539
[24/50] [0/1576] 0.126673
[24/50] [100/1576] 0.302245
[24/50] [200/1576] 0.065702
[24/50] [300/1576] 0.277048
[24/50] [400/1576] 0.420815
[24/50] [500/1576] 1.998721
[24/50] [600/1576] 0.139109
[24/50] [700/1576] 0.236797
[24/50] [800/1576] 0.156023
[24/50] [900/1576] 0.330505
[24/50] [1000/1576] 0.316634
[24/50] [1100/1576] 0.683673
[24/50] [1200/1576] 0.382713
[24/50] [1300/1576] 0.306705
[24/50] [1400/1576] 0.248855
[24/50] [1500/1576] 0.293744


100%|██████████| 175/175 [00:56<00:00,  3.11it/s]


{'joint_goal_accuracy': 0.8258128078817734, 'turn_slot_accuracy': 0.994403940886706, 'turn_slot_f1': 0.9798134936818429}
joint_goal_accuracy: 0.8258128078817734
turn_slot_accuracy: 0.994403940886706
turn_slot_f1: 0.9798134936818429
[25/50] [0/1576] 0.225162
[25/50] [100/1576] 0.395206
[25/50] [200/1576] 0.264974
[25/50] [300/1576] 0.308623
[25/50] [400/1576] 0.273503
[25/50] [500/1576] 0.494310
[25/50] [600/1576] 0.129462
[25/50] [700/1576] 0.179956
[25/50] [800/1576] 0.193206
[25/50] [900/1576] 0.077461
[25/50] [1000/1576] 0.602033
[25/50] [1100/1576] 0.408238
[25/50] [1200/1576] 0.176780
[25/50] [1300/1576] 0.168758
[25/50] [1400/1576] 0.175119
[25/50] [1500/1576] 0.322269


100%|██████████| 175/175 [00:56<00:00,  3.11it/s]


{'joint_goal_accuracy': 0.8151724137931035, 'turn_slot_accuracy': 0.9942112753147296, 'turn_slot_f1': 0.9785944552957331}
joint_goal_accuracy: 0.8151724137931035
turn_slot_accuracy: 0.9942112753147296
turn_slot_f1: 0.9785944552957331
[26/50] [0/1576] 0.384895
[26/50] [100/1576] 0.525450
[26/50] [200/1576] 0.157110
[26/50] [300/1576] 0.238227
[26/50] [400/1576] 0.511335
[26/50] [500/1576] 0.179076
[26/50] [600/1576] 0.201636
[26/50] [700/1576] 0.128039
[26/50] [800/1576] 0.155655
[26/50] [900/1576] 0.200385
[26/50] [1000/1576] 0.144612
[26/50] [1100/1576] 0.248664
[26/50] [1200/1576] 0.127035
[26/50] [1300/1576] 0.136925
[26/50] [1400/1576] 0.587221
[26/50] [1500/1576] 0.387359


100%|██████████| 175/175 [00:56<00:00,  3.11it/s]


{'joint_goal_accuracy': 0.818128078817734, 'turn_slot_accuracy': 0.9943864258347065, 'turn_slot_f1': 0.9788654103359145}
joint_goal_accuracy: 0.818128078817734
turn_slot_accuracy: 0.9943864258347065
turn_slot_f1: 0.9788654103359145
[27/50] [0/1576] 0.179599
[27/50] [100/1576] 0.170776
[27/50] [200/1576] 0.352366
[27/50] [300/1576] 0.260779
[27/50] [400/1576] 0.498539
[27/50] [500/1576] 1.173380
[27/50] [600/1576] 0.145568
[27/50] [700/1576] 0.211061
[27/50] [800/1576] 0.176614
[27/50] [900/1576] 0.149952
[27/50] [1000/1576] 0.158814
[27/50] [1100/1576] 0.321529
[27/50] [1200/1576] 0.860005
[27/50] [1300/1576] 0.149601
[27/50] [1400/1576] 0.164438
[27/50] [1500/1576] 0.146771


100%|██████████| 175/175 [00:56<00:00,  3.11it/s]


{'joint_goal_accuracy': 0.8167487684729065, 'turn_slot_accuracy': 0.9941018062397426, 'turn_slot_f1': 0.9779473273361858}
joint_goal_accuracy: 0.8167487684729065
turn_slot_accuracy: 0.9941018062397426
turn_slot_f1: 0.9779473273361858
[28/50] [0/1576] 0.725197
[28/50] [100/1576] 0.227825
[28/50] [200/1576] 0.218799
[28/50] [300/1576] 0.146131
[28/50] [400/1576] 0.032624
[28/50] [500/1576] 0.153405
[28/50] [600/1576] 0.200820
[28/50] [700/1576] 0.143505
[28/50] [800/1576] 0.336215
[28/50] [900/1576] 0.076423
[28/50] [1000/1576] 1.117327
[28/50] [1100/1576] 0.300711
[28/50] [1200/1576] 0.119112
[28/50] [1300/1576] 0.136977
[28/50] [1400/1576] 0.201342
[28/50] [1500/1576] 0.166576


100%|██████████| 175/175 [00:56<00:00,  3.11it/s]


{'joint_goal_accuracy': 0.8118226600985222, 'turn_slot_accuracy': 0.9940974274767441, 'turn_slot_f1': 0.9785657239369282}
joint_goal_accuracy: 0.8118226600985222
turn_slot_accuracy: 0.9940974274767441
turn_slot_f1: 0.9785657239369282
[29/50] [0/1576] 0.353142
[29/50] [100/1576] 0.359128
[29/50] [200/1576] 0.171177
[29/50] [300/1576] 0.238962
[29/50] [400/1576] 0.139417
[29/50] [500/1576] 0.123090
[29/50] [600/1576] 0.452440
[29/50] [700/1576] 0.392641
[29/50] [800/1576] 0.201601
[29/50] [900/1576] 0.073917
[29/50] [1000/1576] 0.244641
[29/50] [1100/1576] 0.263498
[29/50] [1200/1576] 0.358255
[29/50] [1300/1576] 0.156767
[29/50] [1400/1576] 0.261220
[29/50] [1500/1576] 0.102554


100%|██████████| 175/175 [00:56<00:00,  3.11it/s]


{'joint_goal_accuracy': 0.8135960591133005, 'turn_slot_accuracy': 0.9941412151067389, 'turn_slot_f1': 0.9789246889197559}
joint_goal_accuracy: 0.8135960591133005
turn_slot_accuracy: 0.9941412151067389
turn_slot_f1: 0.9789246889197559
[30/50] [0/1576] 0.091096
[30/50] [100/1576] 0.107516
[30/50] [200/1576] 0.330826
[30/50] [300/1576] 0.224968
[30/50] [400/1576] 0.184235
[30/50] [500/1576] 0.272215
[30/50] [600/1576] 0.204735
[30/50] [700/1576] 0.067695
[30/50] [800/1576] 0.101103
[30/50] [900/1576] 0.165627
[30/50] [1000/1576] 0.182627
[30/50] [1100/1576] 0.169311
[30/50] [1200/1576] 0.250023
[30/50] [1300/1576] 0.413004
[30/50] [1400/1576] 0.189258
[30/50] [1500/1576] 0.682283


100%|██████████| 175/175 [00:56<00:00,  3.11it/s]


{'joint_goal_accuracy': 0.8175369458128079, 'turn_slot_accuracy': 0.9941018062397429, 'turn_slot_f1': 0.9783465470564906}
joint_goal_accuracy: 0.8175369458128079
turn_slot_accuracy: 0.9941018062397429
turn_slot_f1: 0.9783465470564906
[31/50] [0/1576] 0.249785
[31/50] [100/1576] 0.415033
[31/50] [200/1576] 0.154545
[31/50] [300/1576] 0.229554
[31/50] [400/1576] 0.571976
[31/50] [500/1576] 0.254108
[31/50] [600/1576] 0.278035
[31/50] [700/1576] 0.199953
[31/50] [800/1576] 0.332441
[31/50] [900/1576] 0.296605
[31/50] [1000/1576] 0.168163
[31/50] [1100/1576] 0.259439
[31/50] [1200/1576] 0.200769
[31/50] [1300/1576] 0.257144
[31/50] [1400/1576] 0.244028
[31/50] [1500/1576] 0.147309


100%|██████████| 175/175 [00:56<00:00,  3.11it/s]


{'joint_goal_accuracy': 0.8179310344827586, 'turn_slot_accuracy': 0.9942287903667273, 'turn_slot_f1': 0.978653597943107}
joint_goal_accuracy: 0.8179310344827586
turn_slot_accuracy: 0.9942287903667273
turn_slot_f1: 0.978653597943107
[32/50] [0/1576] 0.208481


KeyboardInterrupt: ignored

## Inference

In [24]:
model.load_state_dict(torch.load("/content/drive/MyDrive/Stage3/model/sumbt_somtology_50_best.pt"))
model = model.eval()

In [25]:
eval_data = json.load(open(f"/content/drive/MyDrive/Stage3/input/data/eval_dataset/eval_dials.json", "r"))

eval_examples = get_examples_from_dialogues(
    eval_data, user_first=True, dialogue_level=True
)

# Extracting Featrues
eval_features = processor.convert_examples_to_features(eval_examples)
eval_data = WOSDataset(eval_features)
eval_sampler = SequentialSampler(eval_data)
eval_loader = DataLoader(
    eval_data,
    batch_size=8,
    sampler=eval_sampler,
    collate_fn=processor.collate_fn,
)


100%|██████████| 2000/2000 [00:00<00:00, 13026.58it/s]


In [26]:
predictions = inference(model, eval_loader, processor, device)

100%|██████████| 250/250 [02:38<00:00,  1.58it/s]


In [27]:
json.dump(predictions, open('/content/drive/MyDrive/Stage3/sumbt_somtology_50_best_pred.csv', 'w'), indent=2, ensure_ascii=False) 

In [28]:
#model.load_state_dict(torch.load("/content/drive/MyDrive/Stage3/model/sumbt50_best.pt"))
#model = model.eval()
predictions = inference(model, dev_loader, processor, device)
eval_result = _evaluation(predictions, dev_labels, slot_meta)

100%|██████████| 175/175 [00:56<00:00,  3.11it/s]

{'joint_goal_accuracy': 0.8258128078817734, 'turn_slot_accuracy': 0.994403940886706, 'turn_slot_f1': 0.9798134936818429}





In [29]:
cat_ontology = []
noncat_ontology = []
for i in range(45):
    if len(list(ontology.items())[i][1]) <=9:
        cat_ontology.append(list(ontology.items())[i][0])
    else:
        noncat_ontology.append(list(ontology.items())[i][0])

In [30]:
category_pred = []
noncategory_pred = []
category_label = []
noncategory_label = []

In [31]:
pred_keys = list(predictions.keys())
pred_values = list(predictions.values())
labels_keys = list(dev_labels.keys())
labels_values = list(dev_labels.values())

In [32]:
correct = []
wrong = []

In [33]:
# 전체 예측에 대해
last_key = ''
for num_labels in range(len(list(predictions))):
    
    # 마지막 예측이랑 비교해 새로운 예측만 생각
    if labels_keys[num_labels].split(':')[0] != last_key:
        last_pred = set([])
        last_label = set([])
        last_key = labels_keys[num_labels].split(':')[0]
    
    else:
        last_pred = set(pred_values[num_labels-1]) 
        last_label = set(labels_values[num_labels-1])
    
    pred_slots = list(set(pred_values[num_labels]) - last_pred)
    label_slots = list(set(labels_values[num_labels]) - last_label)
                       
    #예측 안에 각각의 예측 slot에 대해
    for label_slot_num in range(len(label_slots)):
        
        # 정답에 예측이 있으면
        if label_slots[label_slot_num] in pred_slots:
            correct.append(label_slots[label_slot_num])
        else:
            wrong.append(label_slots[label_slot_num])

In [34]:
print('len_correct:',len(correct))
print('len_wrong:',len(wrong))

len_correct: 7649
len_wrong: 288


In [35]:
cat_correct = []
noncat_correct = []
cat_wrong = []
noncat_wrong = []
for label in correct:
    if label.split('-')[0] + '-' +label.split('-')[1] in cat_ontology:
        cat_correct.append(label)
    else:
        noncat_correct.append(label)
for label in wrong:
    if label.split('-')[0] + '-' +label.split('-')[1] in cat_ontology:
        cat_wrong.append(label)
    else:
        noncat_wrong.append(label)


In [36]:
print('cat_correct:',len(cat_correct))
print('noncat_correct:',len(noncat_correct))
print('cat_wrong:',len(cat_wrong))
print('noncat_wrong:',len(noncat_wrong))

cat_correct: 3590
noncat_correct: 4059
cat_wrong: 120
noncat_wrong: 168


In [37]:
from collections import defaultdict

In [38]:
cat_correct_dict = defaultdict(int)
noncat_correct_dict = defaultdict(int)
cat_wrong_dict = defaultdict(int)
noncat_wrong_dict = defaultdict(int)

In [39]:
for cat in cat_correct:
    cat_correct_dict[cat.split('-')[0] + '-' +cat.split('-')[1]] +=1
for cat in noncat_correct:
    noncat_correct_dict[cat.split('-')[0] + '-' +cat.split('-')[1]] +=1    
for cat in cat_wrong:
    cat_wrong_dict[cat.split('-')[0] + '-' +cat.split('-')[1]] +=1
for cat in noncat_wrong:
    noncat_wrong_dict[cat.split('-')[0] + '-' +cat.split('-')[1]] +=1

In [40]:
print('cat 정답률')
for cat in cat_ontology:
    if (cat_correct_dict[cat] + cat_wrong_dict[cat]) == 0:
        print(str(cat),str(0))
        continue
    percent = cat_correct_dict[cat]/(cat_correct_dict[cat] + cat_wrong_dict[cat])*100
    print( str(cat), format(percent,".2f"),'%', '('+str(cat_correct_dict[cat])+'/'+str(cat_correct_dict[cat] + cat_wrong_dict[cat])+')')


cat 정답률
관광-경치 좋은 100.00 % (39/39)
관광-교육적 100.00 % (18/18)
관광-도보 가능 100.00 % (1/1)
관광-문화 예술 93.75 % (15/16)
관광-역사적 98.57 % (69/70)
관광-주차 가능 94.44 % (34/36)
관광-지역 96.10 % (296/308)
숙소-가격대 96.25 % (359/373)
숙소-도보 가능 85.71 % (12/14)
숙소-수영장 유무 83.33 % (5/6)
숙소-스파 유무 92.86 % (39/42)
숙소-예약 요일 94.67 % (355/375)
숙소-인터넷 가능 93.94 % (31/33)
숙소-조식 가능 100.00 % (38/38)
숙소-종류 97.79 % (354/362)
숙소-주차 가능 96.00 % (48/50)
숙소-지역 96.93 % (347/358)
숙소-헬스장 유무 91.89 % (34/37)
숙소-흡연 가능 91.43 % (32/35)
식당-가격대 98.55 % (339/344)
식당-도보 가능 100.00 % (4/4)
식당-야외석 유무 100.00 % (56/56)
식당-예약 요일 98.06 % (405/413)
식당-인터넷 가능 92.31 % (24/26)
식당-주류 판매 100.00 % (41/41)
식당-주차 가능 91.43 % (32/35)
식당-지역 97.06 % (330/340)
식당-흡연 가능 84.00 % (21/25)
택시-종류 98.60 % (212/215)


In [41]:
print('cat 오답률')
for cat in cat_ontology:
    if (cat_correct_dict[cat] + cat_wrong_dict[cat]) == 0:
        print(str(cat),str(0))
        continue
    percent = cat_wrong_dict[cat]/(cat_correct_dict[cat] + cat_wrong_dict[cat])*100
    print( str(cat), format(percent,".2f"),'%', '('+str(cat_wrong_dict[cat])+'/'+str(cat_correct_dict[cat] + cat_wrong_dict[cat])+')')


cat 오답률
관광-경치 좋은 0.00 % (0/39)
관광-교육적 0.00 % (0/18)
관광-도보 가능 0.00 % (0/1)
관광-문화 예술 6.25 % (1/16)
관광-역사적 1.43 % (1/70)
관광-주차 가능 5.56 % (2/36)
관광-지역 3.90 % (12/308)
숙소-가격대 3.75 % (14/373)
숙소-도보 가능 14.29 % (2/14)
숙소-수영장 유무 16.67 % (1/6)
숙소-스파 유무 7.14 % (3/42)
숙소-예약 요일 5.33 % (20/375)
숙소-인터넷 가능 6.06 % (2/33)
숙소-조식 가능 0.00 % (0/38)
숙소-종류 2.21 % (8/362)
숙소-주차 가능 4.00 % (2/50)
숙소-지역 3.07 % (11/358)
숙소-헬스장 유무 8.11 % (3/37)
숙소-흡연 가능 8.57 % (3/35)
식당-가격대 1.45 % (5/344)
식당-도보 가능 0.00 % (0/4)
식당-야외석 유무 0.00 % (0/56)
식당-예약 요일 1.94 % (8/413)
식당-인터넷 가능 7.69 % (2/26)
식당-주류 판매 0.00 % (0/41)
식당-주차 가능 8.57 % (3/35)
식당-지역 2.94 % (10/340)
식당-흡연 가능 16.00 % (4/25)
택시-종류 1.40 % (3/215)


In [42]:
print('noncat 정답률')
for cat in noncat_ontology:
    if noncat_correct_dict[cat] ==0:
        print(str(cat),' 0.00 %')
    else:
        percent = noncat_correct_dict[cat]/(noncat_correct_dict[cat] + noncat_wrong_dict[cat])*100
        print( str(cat), format(percent,".2f"),'%', '('+str(noncat_correct_dict[cat])+'/'+str(noncat_correct_dict[cat] + noncat_wrong_dict[cat])+')')


noncat 정답률
관광-이름 93.82 % (319/340)
관광-종류 98.67 % (297/301)
숙소-예약 기간 98.12 % (366/373)
숙소-예약 명수 98.08 % (358/365)
숙소-이름 96.25 % (359/373)
식당-예약 명수 98.18 % (377/384)
식당-예약 시간 96.60 % (398/412)
식당-이름 97.22 % (385/396)
식당-종류 99.41 % (337/339)
지하철-도착지 95.08 % (58/61)
지하철-출발 시간 81.82 % (9/11)
지하철-출발지 88.52 % (54/61)
택시-도착 시간 89.81 % (97/108)
택시-도착지 94.42 % (203/215)
택시-출발 시간 93.41 % (255/273)
택시-출발지 86.98 % (187/215)


In [43]:
print('noncat 오답률')

for cat in noncat_ontology:
    if noncat_correct_dict[cat] ==0:
        print(str(cat),' 0.00 %')
    else:
        percent = noncat_wrong_dict[cat]/(noncat_correct_dict[cat] + noncat_wrong_dict[cat])*100
        print( str(cat), format(percent,".2f"),'%', '('+str(noncat_wrong_dict[cat])+'/'+str(noncat_correct_dict[cat] + noncat_wrong_dict[cat])+')')

noncat 오답률
관광-이름 6.18 % (21/340)
관광-종류 1.33 % (4/301)
숙소-예약 기간 1.88 % (7/373)
숙소-예약 명수 1.92 % (7/365)
숙소-이름 3.75 % (14/373)
식당-예약 명수 1.82 % (7/384)
식당-예약 시간 3.40 % (14/412)
식당-이름 2.78 % (11/396)
식당-종류 0.59 % (2/339)
지하철-도착지 4.92 % (3/61)
지하철-출발 시간 18.18 % (2/11)
지하철-출발지 11.48 % (7/61)
택시-도착 시간 10.19 % (11/108)
택시-도착지 5.58 % (12/215)
택시-출발 시간 6.59 % (18/273)
택시-출발지 13.02 % (28/215)
