In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
import os
os.chdir('/content/drive/MyDrive/Stage3/code')

In [3]:
!pip install transformers

Collecting transformers
[?25l  Downloading https://files.pythonhosted.org/packages/b0/9e/5b80becd952d5f7250eaf8fc64b957077b12ccfe73e9c03d37146ab29712/transformers-4.6.0-py3-none-any.whl (2.3MB)
[K     |████████████████████████████████| 2.3MB 4.2MB/s 
Collecting huggingface-hub==0.0.8
  Downloading https://files.pythonhosted.org/packages/a1/88/7b1e45720ecf59c6c6737ff332f41c955963090a18e72acbcbeac6b25e86/huggingface_hub-0.0.8-py3-none-any.whl
Collecting sacremoses
[?25l  Downloading https://files.pythonhosted.org/packages/75/ee/67241dc87f266093c533a2d4d3d69438e57d7a90abb216fa076e7d475d4a/sacremoses-0.0.45-py3-none-any.whl (895kB)
[K     |████████████████████████████████| 901kB 48.0MB/s 
[?25hCollecting tokenizers<0.11,>=0.10.1
[?25l  Downloading https://files.pythonhosted.org/packages/ae/04/5b870f26a858552025a62f1649c20d29d2672c02ff3c3fb4c688ca46467a/tokenizers-0.10.2-cp37-cp37m-manylinux2010_x86_64.whl (3.3MB)
[K     |████████████████████████████████| 3.3MB 49.2MB/s 
Installing c

In [4]:
import json
from tqdm import tqdm

from transformers import BertTokenizer
from data_utils import get_examples_from_dialogues, convert_state_dict, load_dataset
from data_utils import OntologyDSTFeature, DSTPreprocessor, _truncate_seq_pair,set_seed

## Data Loading 

In [5]:
set_seed(42)

In [6]:
train_data_file = "/content/drive/MyDrive/Stage3/input/data/train_dataset/train_dials.json"
slot_meta = json.load(open("/content/drive/MyDrive/Stage3/input/data/train_dataset/slot_meta.json"))
ontology = json.load(open("/content/drive/MyDrive/Stage3/input/data/train_dataset/new_ontology.json"))
train_data, dev_data, dev_labels = load_dataset(train_data_file)

In [7]:
for key in list(ontology.keys()):
    print(key,len(ontology[key]))

관광-경치 좋은 4
관광-교육적 4
관광-도보 가능 4
관광-문화 예술 4
관광-역사적 4
관광-이름 315
관광-종류 13
관광-주차 가능 4
관광-지역 7
숙소-가격대 5
숙소-도보 가능 4
숙소-수영장 유무 4
숙소-스파 유무 4
숙소-예약 기간 12
숙소-예약 명수 12
숙소-예약 요일 9
숙소-이름 315
숙소-인터넷 가능 4
숙소-조식 가능 4
숙소-종류 7
숙소-주차 가능 4
숙소-지역 7
숙소-헬스장 유무 4
숙소-흡연 가능 4
식당-가격대 5
식당-도보 가능 4
식당-야외석 유무 4
식당-예약 명수 12
식당-예약 시간 569
식당-예약 요일 9
식당-이름 315
식당-인터넷 가능 4
식당-종류 10
식당-주류 판매 4
식당-주차 가능 4
식당-지역 7
식당-흡연 가능 4
지하철-도착지 61
지하철-출발 시간 12
지하철-출발지 61
택시-도착 시간 190
택시-도착지 315
택시-종류 5
택시-출발 시간 431
택시-출발지 315


In [8]:
train_examples = get_examples_from_dialogues(data=train_data,
                                             user_first=True,
                                             dialogue_level=True)

dev_examples = get_examples_from_dialogues(data=dev_data,
                                           user_first=True,
                                           dialogue_level=True)

100%|██████████| 6301/6301 [00:00<00:00, 9295.94it/s]
100%|██████████| 699/699 [00:00<00:00, 13627.11it/s]


In [9]:
len(train_data)

6301

In [10]:
max_turn = max([len(e['dialogue']) for e in train_data])

In [11]:
tokenizer = BertTokenizer.from_pretrained('dsksd/bert-ko-small-minimal')

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=263327.0, style=ProgressStyle(descripti…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=124.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=288.0, style=ProgressStyle(description_…




## TODO-1: SUMBT Preprocessor 정의 

Ontology-based DST model인 SUMBT의 InputFeature를 만들기 위한 Preprocessor를 정의해야 합니다. <br>

1. `_convert_examples_to_features` 함수의 빈칸을 매워 완성하세요.
2. `recover_state` 함수의 빈칸을 매워 완성하세요.

In [12]:
class SUMBTPreprocessor(DSTPreprocessor):
    def __init__(
        self,
        slot_meta,
        src_tokenizer,
        trg_tokenizer=None,
        ontology=None,
        max_seq_length=64,
        max_turn_length=14,
    ):
        self.slot_meta = slot_meta
        self.src_tokenizer = src_tokenizer
        self.trg_tokenizer = trg_tokenizer if trg_tokenizer else src_tokenizer
        self.ontology = ontology
        self.max_seq_length = max_seq_length
        self.max_turn_length = max_turn_length

    def _convert_example_to_feature(self, example):
        guid = example[0].guid.rsplit("-", 1)[0]  # dialogue_idx
        turns = []
        token_types = []
        labels = []
        num_turn = None
        for turn in example[: self.max_turn_length]:
            assert len(turn.current_turn) == 2
            uttrs = []
            for segment_idx, uttr in enumerate(turn.current_turn):
                token = self.src_tokenizer.encode(uttr, add_special_tokens=False)
                uttrs.append(token)

            _truncate_seq_pair(uttrs[0], uttrs[1], self.max_seq_length - 3)
            tokens = (
                [self.src_tokenizer.cls_token_id]
                + uttrs[0]
                + [self.src_tokenizer.sep_token_id]
                + uttrs[1]
                + [self.src_tokenizer.sep_token_id]
            )
            token_type = [0] * (len(uttrs[0]) + 2) + [1] * (len(uttrs[1]) + 1)
            if len(tokens) < self.max_seq_length:
                gap = self.max_seq_length - len(tokens)
                tokens.extend([self.src_tokenizer.pad_token_id] * gap)
                token_type.extend([0] * gap)
            turns.append(tokens)
            token_types.append(token_type)
            label = []
            if turn.label:
                slot_dict = convert_state_dict(turn.label)
            else:
                slot_dict = {}
            for slot_type in self.slot_meta:
                value = slot_dict.get(slot_type, "none")
                # TODO
                # raise Exception('label_idx를 ontology에서 꺼내오는 코드를 작성하세요!')
                if value in self.ontology[slot_type]:
                    label_idx = self.ontology[slot_type].index(value)
                else:
                    label_idx = self.ontology[slot_type].index("none")
                label.append(label_idx)
            labels.append(label)
        num_turn = len(turns)
        if len(turns) < self.max_turn_length:
            gap = self.max_turn_length - len(turns)
            for _ in range(gap):
                dummy_turn = [self.src_tokenizer.pad_token_id] * self.max_seq_length
                turns.append(dummy_turn)
                token_types.append(dummy_turn)
                dummy_label = [-1] * len(self.slot_meta)
                labels.append(dummy_label)
        return OntologyDSTFeature(
            guid=guid,
            input_ids=turns,
            segment_ids=token_types,
            num_turn=num_turn,
            target_ids=labels,
        )

    def convert_examples_to_features(self, examples):
        return list(map(self._convert_example_to_feature, examples))

    def recover_state(self, pred_slots, num_turn):
        states = []
        for pred_slot in pred_slots[:num_turn]:
            state = []
            for s, p in zip(self.slot_meta, pred_slot):
                v = self.ontology[s][p]
                if v != "none":
                    state.append(f"{s}-{v}")
            states.append(state)
        return states

    def collate_fn(self, batch):
        guids = [b.guid for b in batch]
        input_ids = torch.LongTensor([b.input_ids for b in batch])
        segment_ids = torch.LongTensor([b.segment_ids for b in batch])
        input_masks = input_ids.ne(self.src_tokenizer.pad_token_id)
        target_ids = torch.LongTensor([b.target_ids for b in batch])
        num_turns = [b.num_turn for b in batch]
        return input_ids, segment_ids, input_masks, target_ids, num_turns, guids

## Convert_Examples_to_Features 

In [13]:
processor = SUMBTPreprocessor(slot_meta,
                              tokenizer,
                              ontology=ontology,  # predefined ontology
                              max_seq_length=64,  # 각 turn마다 최대 길이
                              max_turn_length=max_turn)  # 각 dialogue의 최대 turn 길이
train_features = processor.convert_examples_to_features(train_examples)
dev_features = processor.convert_examples_to_features(dev_examples)

In [14]:
print(len(train_features))  # 대화 level의 features
print(len(dev_features))

6301
699


## SUMBT 모델 선언 

In [15]:
"""
Most of code is from https://github.com/SKTBrain/SUMBT
"""

import math
import os.path

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import CosineEmbeddingLoss, CrossEntropyLoss
from transformers import BertModel, BertPreTrainedModel


class BertForUtteranceEncoding(BertPreTrainedModel):
    def __init__(self, config):
        super(BertForUtteranceEncoding, self).__init__(config)

        self.config = config
        self.bert = BertModel(config)

    def forward(self, input_ids, token_type_ids, attention_mask):
        return self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            output_attentions=False,
            output_hidden_states=False,
            return_dict=False,
        )


class MultiHeadAttention(nn.Module):
    def __init__(self, heads, d_model, dropout=0.1):
        super().__init__()

        self.d_model = d_model
        self.d_k = d_model // heads
        self.h = heads

        self.q_linear = nn.Linear(d_model, d_model)
        self.v_linear = nn.Linear(d_model, d_model)
        self.k_linear = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)
        self.out = nn.Linear(d_model, d_model)

        self.scores = None

    def attention(self, q, k, v, d_k, mask=None, dropout=None):

        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)

        if mask is not None:
            mask = mask.unsqueeze(1)
            scores = scores.masked_fill(mask == 0, -1e9)
        scores = F.softmax(scores, dim=-1)

        if dropout is not None:
            scores = dropout(scores)

        self.scores = scores
        output = torch.matmul(scores, v)
        return output

    def forward(self, q, k, v, mask=None):
        bs = q.size(0)

        # perform linear operation and split into h heads
        k = self.k_linear(k).view(bs, -1, self.h, self.d_k)
        q = self.q_linear(q).view(bs, -1, self.h, self.d_k)
        v = self.v_linear(v).view(bs, -1, self.h, self.d_k)

        # transpose to get dimensions bs * h * sl * d_model
        k = k.transpose(1, 2)
        q = q.transpose(1, 2)
        v = v.transpose(1, 2)

        scores = self.attention(q, k, v, self.d_k, mask, self.dropout)

        # concatenate heads and put through final linear layer
        concat = scores.transpose(1, 2).contiguous().view(bs, -1, self.d_model)
        output = self.out(concat)
        return output

    def get_scores(self):
        return self.scores


class SUMBT(nn.Module):
    def __init__(self, args, num_labels, device):
        super(SUMBT, self).__init__()

        self.hidden_dim = args.hidden_dim
        self.rnn_num_layers = args.num_rnn_layers
        self.zero_init_rnn = args.zero_init_rnn
        self.max_seq_length = args.max_seq_length
        self.max_label_length = args.max_label_length
        self.num_labels = num_labels
        self.num_slots = len(num_labels)
        self.attn_head = args.attn_head
        self.device = device

        ### Utterance Encoder
        self.utterance_encoder = BertForUtteranceEncoding.from_pretrained(
            args.model_name_or_path
        )
        self.bert_output_dim = self.utterance_encoder.config.hidden_size
        self.hidden_dropout_prob = self.utterance_encoder.config.hidden_dropout_prob
        if args.fix_utterance_encoder:
            for p in self.utterance_encoder.bert.pooler.parameters():
                p.requires_grad = False

        ### slot, slot-value Encoder (not trainable)
        self.sv_encoder = BertForUtteranceEncoding.from_pretrained(
            args.model_name_or_path
        )
        # os.path.join(args.bert_dir, 'bert-base-uncased.model'))
        for p in self.sv_encoder.bert.parameters():
            p.requires_grad = False

        self.slot_lookup = nn.Embedding(self.num_slots, self.bert_output_dim)
        self.value_lookup = nn.ModuleList(
            [nn.Embedding(num_label, self.bert_output_dim) for num_label in num_labels]
        )

        ### Attention layer
        self.attn = MultiHeadAttention(self.attn_head, self.bert_output_dim, dropout=0)

        ### RNN Belief Tracker
        self.nbt = nn.GRU(
            input_size=self.bert_output_dim,
            hidden_size=self.hidden_dim,
            num_layers=self.rnn_num_layers,
            dropout=self.hidden_dropout_prob,
            batch_first=True,
            bidirectional = True
        )
        self.init_parameter(self.nbt)

        if not self.zero_init_rnn:
            self.rnn_init_linear = nn.Sequential(
                nn.Linear(self.bert_output_dim, self.hidden_dim),
                nn.ReLU(),
                nn.Dropout(self.hidden_dropout_prob),
            )

        self.linear = nn.Linear(self.hidden_dim, self.bert_output_dim)
        self.layer_norm = nn.LayerNorm(self.bert_output_dim)

        ### Measure
        self.metric = torch.nn.PairwiseDistance(p=2.0, eps=1e-06, keepdim=False)

        ### Classifier
        self.nll = CrossEntropyLoss(ignore_index=-1)
        #self.nll = torch.nn.HingeEmbeddingLoss(margin = 0.5,reduction='mean')
        ### Etc.
        self.dropout = nn.Dropout(self.hidden_dropout_prob)

    def initialize_slot_value_lookup(self, label_ids, slot_ids):

        self.sv_encoder.eval()

        # Slot encoding
        slot_type_ids = torch.zeros(slot_ids.size(), dtype=torch.long).to(
            slot_ids.device
        )
        slot_mask = slot_ids > 0
        hid_slot, _ = self.sv_encoder(
            slot_ids.view(-1, self.max_label_length),
            slot_type_ids.view(-1, self.max_label_length),
            slot_mask.view(-1, self.max_label_length),
        )
        hid_slot = hid_slot[:, 0, :]
        hid_slot = hid_slot.detach()
        self.slot_lookup = nn.Embedding.from_pretrained(hid_slot, freeze=True)

        for s, label_id in enumerate(label_ids):
            label_type_ids = torch.zeros(label_id.size(), dtype=torch.long).to(
                label_id.device
            )
            label_mask = label_id > 0
            hid_label, _ = self.sv_encoder(
                label_id.view(-1, self.max_label_length),
                label_type_ids.view(-1, self.max_label_length),
                label_mask.view(-1, self.max_label_length),
            )
            hid_label = hid_label[:, 0, :]
            hid_label = hid_label.detach()
            self.value_lookup[s] = nn.Embedding.from_pretrained(hid_label, freeze=True)
            self.value_lookup[s].padding_idx = -1

        print("Complete initialization of slot and value lookup")
        self.sv_encoder = None

    def forward(
        self,
        input_ids,
        token_type_ids,
        attention_mask,
        labels=None,
        n_gpu=1,
        target_slot=None,
    ):
        # input_ids: [B, M, N]
        # token_type_ids: [B, M, N]
        # attention_mask: [B, M, N]
        # labels: [B, M, J]

        # if target_slot is not specified, output values corresponding all slot-types
        if target_slot is None:
            target_slot = list(range(0, self.num_slots))
        
                                # Max sequence (N)
        ds = input_ids.size(0)  # Batch size (B)
        ts = input_ids.size(1)  # Max turn size (M)
        bs = ds * ts            # B*M
        slot_dim = len(target_slot)  # J


        # Utterance encoding
        hidden, _ = self.utterance_encoder(
            input_ids.view(-1, self.max_seq_length),
            token_type_ids.view(-1, self.max_seq_length),
            attention_mask.view(-1, self.max_seq_length),
        )
        hidden = torch.mul(
            hidden,
            attention_mask.view(-1, self.max_seq_length, 1)
            .expand(hidden.size())
            .float(),
        )
        hidden = hidden.repeat(slot_dim, 1, 1)  # [J*M*B, N, H] [6120,64,768]

        hid_slot = self.slot_lookup.weight[
            target_slot, :
        ]  # Select target slot embedding #[45,768]
        hid_slot = hid_slot.repeat(1, bs).view(bs * slot_dim, -1)  # [J*M*B, N, H] [6120,768]

        # Attended utterance vector
        hidden = self.attn(
            hid_slot,  # q^s  [J*M*B, N, H]
            hidden,  # U [J*M*B, N, H]
            hidden,  # U [J*M*B, N, H]
            mask=attention_mask.view(-1, 1, self.max_seq_length).repeat(slot_dim, 1, 1),
        )
        hidden = hidden.squeeze()  # h [J*M*B, H] Aggregated Slot Context
        hidden = hidden.view(slot_dim, ds, ts, -1).view(
            -1, ts, self.bert_output_dim
        )  # [J*B, M, H]

        # NBT
        if self.zero_init_rnn:
            h = torch.zeros(
                self.rnn_num_layers, input_ids.shape[0] * slot_dim, self.hidden_dim
            ).to(
                self.device
            )  # [1, slot_dim*ds, hidden]
        else:
            h = hidden[:, 0, :].unsqueeze(0).repeat(self.rnn_num_layers, 1, 1)
            h = self.rnn_init_linear(h)

        if isinstance(self.nbt, nn.GRU):
            h = h.repeat(2, 1, 1)
            rnn_out, _ = self.nbt(hidden, h)  # [J*B, M, H_GRU]
            
            
        elif isinstance(self.nbt, nn.LSTM):
            c = torch.zeros(
                self.rnn_num_layers, input_ids.shape[0] * slot_dim, self.hidden_dim
            ).to(
                self.device
            )  # [1, slot_dim*ds, hidden]
            rnn_out, _ = self.nbt(hidden, (h, c))  # [slot_dim*ds, turn, hidden]
        
        
        rnn_out = rnn_out[:,:,:self.hidden_dim] +rnn_out[:,:,self.hidden_dim:] 
        rnn_out = self.layer_norm(self.linear(self.dropout(rnn_out)))

        hidden = rnn_out.view(slot_dim, ds, ts, -1)  # [J, B, M, H_GRU]


        # Label (slot-value) encoding
        loss = 0
        loss_slot = []
        pred_slot = []
        output = []
        for s, slot_id in enumerate(target_slot):  ## note: target_slots are successive
            # loss calculation
            hid_label = self.value_lookup[slot_id].weight
            num_slot_labels = hid_label.size(0)

            _hid_label = (
                hid_label.unsqueeze(0)
                .unsqueeze(0)
                .repeat(ds, ts, 1, 1)
                .view(ds * ts * num_slot_labels, -1)
            )
            _hidden = (
                hidden[s, :, :, :]
                .unsqueeze(2)
                .repeat(1, 1, num_slot_labels, 1)
                .view(ds * ts * num_slot_labels, -1)
            )
            _dist = self.metric(_hid_label, _hidden).view(ds, ts, num_slot_labels)
            _dist = -_dist
            _, pred = torch.max(_dist, -1)
            pred_slot.append(pred.view(ds, ts, 1))
            output.append(_dist)

            if labels is not None:
                _loss = self.nll(_dist.view(ds * ts, -1), labels[:, :, s].view(-1))
                loss_slot.append(_loss.item())
                loss += _loss

        pred_slot = torch.cat(pred_slot, 2)
        if labels is None:
            return output, pred_slot

        # calculate joint accuracy
        accuracy = (pred_slot == labels).view(-1, slot_dim)
        acc_slot = (
            torch.sum(accuracy, 0).float()
            / torch.sum(labels.view(-1, slot_dim) > -1, 0).float()
        )
        acc = (
            sum(torch.sum(accuracy, 1) / slot_dim).float()
            / torch.sum(labels[:, :, 0].view(-1) > -1, 0).float()
        )  # joint accuracy

        if n_gpu == 1:
            return loss, loss_slot, acc, acc_slot, pred_slot
        else:
            return (
                loss.unsqueeze(0),
                None,
                acc.unsqueeze(0),
                acc_slot.unsqueeze(0),
                pred_slot.unsqueeze(0),
            )

    @staticmethod
    def init_parameter(module):
        if isinstance(module, nn.Linear):
            torch.nn.init.xavier_normal_(module.weight)
            torch.nn.init.constant_(module.bias, 0.0)
        elif isinstance(module, nn.GRU) or isinstance(module, nn.LSTM):
            torch.nn.init.xavier_normal_(module.weight_ih_l0)
            torch.nn.init.xavier_normal_(module.weight_hh_l0)
            torch.nn.init.constant_(module.bias_ih_l0, 0.0)
            torch.nn.init.constant_(module.bias_hh_l0, 0.0)

## TODO-2: Ontology Pre-Encoding 

Ontology의 slot type들과 이에 속하는 slot_value들을 tokenizing하는 `tokenize_ontology`를 작성하세요. <br>
[CLS] Pooling하여 `slot_lookup` 과 `value_lookup` embedding matrix들을 초기화하는 <br>
`initialize_slot_value_lookup`에 인자로 넘겨주세요. <br>

In [16]:
def tokenize_ontology(ontology, tokenizer, max_seq_length=12):
    slot_types = []
    slot_values = []
    for k, v in ontology.items():
        tokens = tokenizer.encode(k)
        if len(tokens) < max_seq_length:
            gap = max_seq_length - len(tokens)
            tokens.extend([tokenizer.pad_token_id] *  gap)
        slot_types.append(tokens)
        slot_value = []
        for vv in v:
            tokens = tokenizer.encode(vv)
            if len(tokens) < max_seq_length:
                gap = max_seq_length - len(tokens)
                tokens.extend([tokenizer.pad_token_id] *  gap)
            slot_value.append(tokens)
        slot_values.append(torch.LongTensor(slot_value))
    return torch.LongTensor(slot_types), slot_values

In [17]:
slot_type_ids, slot_values_ids = tokenize_ontology(ontology, tokenizer, 12)
num_labels = [len(s) for s in slot_values_ids]  # 각 Slot 별 후보 Values의 갯수

print("Tokenized Slot: ", slot_type_ids.size())
for slot, slot_value_id in zip(slot_meta, slot_values_ids):
    print(f"Tokenized Value of {slot}", slot_value_id.size())

Tokenized Slot:  torch.Size([45, 12])
Tokenized Value of 관광-경치 좋은 torch.Size([4, 12])
Tokenized Value of 관광-교육적 torch.Size([4, 12])
Tokenized Value of 관광-도보 가능 torch.Size([4, 12])
Tokenized Value of 관광-문화 예술 torch.Size([4, 12])
Tokenized Value of 관광-역사적 torch.Size([4, 12])
Tokenized Value of 관광-이름 torch.Size([315, 12])
Tokenized Value of 관광-종류 torch.Size([13, 12])
Tokenized Value of 관광-주차 가능 torch.Size([4, 12])
Tokenized Value of 관광-지역 torch.Size([7, 12])
Tokenized Value of 숙소-가격대 torch.Size([5, 12])
Tokenized Value of 숙소-도보 가능 torch.Size([4, 12])
Tokenized Value of 숙소-수영장 유무 torch.Size([4, 12])
Tokenized Value of 숙소-스파 유무 torch.Size([4, 12])
Tokenized Value of 숙소-예약 기간 torch.Size([12, 12])
Tokenized Value of 숙소-예약 명수 torch.Size([12, 12])
Tokenized Value of 숙소-예약 요일 torch.Size([9, 12])
Tokenized Value of 숙소-이름 torch.Size([315, 12])
Tokenized Value of 숙소-인터넷 가능 torch.Size([4, 12])
Tokenized Value of 숙소-조식 가능 torch.Size([4, 12])
Tokenized Value of 숙소-종류 torch.Size([7, 12])
Tokenized Valu

## Model 선언 

In [18]:
from argparse import Namespace

args = {
    'hidden_dim': 300,
    'num_rnn_layers': 1,
    'zero_init_rnn': False,
    'max_seq_length': 64,
    'max_label_length': 12,
    'attn_head': 4,
    'fix_utterance_encoder': False,
    'task_name': 'sumbtgru',
    'distance_metric': 'euclidean',
    'model_name_or_path': 'dsksd/bert-ko-small-minimal',
    'warmup_ratio': 0.1,
    'learning_rate': 5e-5,
    'weight_decay': 0.01,
    'num_train_epochs': 50
}

args = Namespace(**args)

num_labels = [len(s) for s in slot_values_ids]
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
n_gpu = 1 if torch.cuda.device_count() < 2 else torch.cuda.device_count()
n_epochs = args.num_train_epochs

In [19]:
model = SUMBT(args, num_labels, device)
model.initialize_slot_value_lookup(slot_values_ids, slot_type_ids)  # Tokenized Ontology의 Pre-encoding using BERT_SV
model.to(device)
print()

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=434.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=284118515.0, style=ProgressStyle(descri…




Some weights of the model checkpoint at dsksd/bert-ko-small-minimal were not used when initializing BertForUtteranceEncoding: ['cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing BertForUtteranceEncoding from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForUtteranceEncoding from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of the model checkpoint at dsksd/bert-ko-small-minimal 

Complete initialization of slot and value lookup



## 데이터 로더 정의

In [20]:
from data_utils import WOSDataset
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from transformers import AdamW, get_linear_schedule_with_warmup
import random


train_data = WOSDataset(train_features)
train_sampler = RandomSampler(train_data)
train_loader = DataLoader(train_data, batch_size=4, sampler=train_sampler, collate_fn=processor.collate_fn)

dev_data = WOSDataset(dev_features)
dev_sampler = SequentialSampler(dev_data)
dev_loader = DataLoader(dev_data, batch_size=4, sampler=dev_sampler, collate_fn=processor.collate_fn)

## Optimizer & Scheduler 선언 

In [21]:
no_decay = ["bias", "LayerNorm.weight"]
optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": args.weight_decay,
        },
        {
            "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
            "weight_decay": 0.0,
        },
    ]

t_total = len(train_loader) * n_epochs
optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=1e-8)
scheduler = get_linear_schedule_with_warmup(
    optimizer, num_warmup_steps=int(t_total * args.warmup_ratio), num_training_steps=t_total
)

## TODO-3: Inference code 작성 

In [22]:
from evaluation import _evaluation

In [23]:
def inference(model, eval_loader, processor, device):
    model.eval()
    predictions = {}
    for batch in tqdm(eval_loader):
        input_ids, segment_ids, input_masks, target_ids, num_turns, guids = \
        [b.to(device) if not isinstance(b, list) else b for b in batch]

        with torch.no_grad():
            _, pred_slot = model(
                input_ids, segment_ids, input_masks, labels=None, n_gpu=1
            )
        
        batch_size = input_ids.size(0)
        for i in range(batch_size):
            guid = guids[i]
            states = processor.recover_state(pred_slot.tolist()[i], num_turns[i])
            for tid, state in enumerate(states):
                predictions[f"{guid}-{tid}"] = state
    return predictions

## Training 

In [24]:
best_score, best_checkpoint = 0, 0
for epoch in range(n_epochs):
    batch_loss = []
    model.train()
    for step, batch in enumerate(train_loader):
        input_ids, segment_ids, input_masks, target_ids, num_turns, guids  = \
        [b.to(device) if not isinstance(b, list) else b for b in batch]

        # Forward
        if n_gpu == 1:
            loss, loss_slot, acc, acc_slot, _ = model(input_ids, segment_ids, input_masks, target_ids, n_gpu)
        else:
            loss, _, acc, acc_slot, _ = model(input_ids, segment_ids, input_masks, target_ids, n_gpu)
        
        batch_loss.append(loss.item())

        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()
        if step % 100 == 0:
            print('[%d/%d] [%d/%d] %f' % (epoch, n_epochs, step, len(train_loader), loss.item()))
    
    predictions = inference(model, dev_loader, processor, device)
    eval_result = _evaluation(predictions, dev_labels, slot_meta)
    score = eval_result['joint_goal_accuracy']
    if score > best_score:
        cnt = 0
        best_score = score
        torch.save(model.state_dict(), "/content/drive/MyDrive/Stage3/model/sumbt_ontology_50_best.pt")
    
    for k, v in eval_result.items():
        print(f"{k}: {v}")

[0/50] [0/1576] 120.948158
[0/50] [100/1576] 83.993591
[0/50] [200/1576] 49.601704
[0/50] [300/1576] 42.578835
[0/50] [400/1576] 38.138355
[0/50] [500/1576] 44.109894
[0/50] [600/1576] 33.930744
[0/50] [700/1576] 33.876492
[0/50] [800/1576] 34.609066
[0/50] [900/1576] 35.289425
[0/50] [1000/1576] 35.967735
[0/50] [1100/1576] 31.085178
[0/50] [1200/1576] 38.174740
[0/50] [1300/1576] 36.902851
[0/50] [1400/1576] 38.409119
[0/50] [1500/1576] 23.106920


100%|██████████| 175/175 [00:56<00:00,  3.10it/s]


{'joint_goal_accuracy': 0.019310344827586208, 'turn_slot_accuracy': 0.8373858784893291, 'turn_slot_f1': 0.189622577751179}
joint_goal_accuracy: 0.019310344827586208
turn_slot_accuracy: 0.8373858784893291
turn_slot_f1: 0.189622577751179
[1/50] [0/1576] 26.354023
[1/50] [100/1576] 33.297401
[1/50] [200/1576] 28.987619
[1/50] [300/1576] 28.825508
[1/50] [400/1576] 26.056416
[1/50] [500/1576] 24.234215
[1/50] [600/1576] 29.772575
[1/50] [700/1576] 27.506531
[1/50] [800/1576] 22.085443
[1/50] [900/1576] 21.196184
[1/50] [1000/1576] 22.520370
[1/50] [1100/1576] 24.395193
[1/50] [1200/1576] 20.529081
[1/50] [1300/1576] 25.055576
[1/50] [1400/1576] 20.322737
[1/50] [1500/1576] 20.760452


100%|██████████| 175/175 [00:56<00:00,  3.10it/s]


{'joint_goal_accuracy': 0.06660098522167487, 'turn_slot_accuracy': 0.8812172961138466, 'turn_slot_f1': 0.41184134338515915}
joint_goal_accuracy: 0.06660098522167487
turn_slot_accuracy: 0.8812172961138466
turn_slot_f1: 0.41184134338515915
[2/50] [0/1576] 27.683670
[2/50] [100/1576] 16.465807
[2/50] [200/1576] 22.589300
[2/50] [300/1576] 18.448269
[2/50] [400/1576] 17.186087
[2/50] [500/1576] 20.051168
[2/50] [600/1576] 17.925510
[2/50] [700/1576] 20.456566
[2/50] [800/1576] 18.852623
[2/50] [900/1576] 17.894077
[2/50] [1000/1576] 14.928161
[2/50] [1100/1576] 12.221363
[2/50] [1200/1576] 16.562441
[2/50] [1300/1576] 18.330002
[2/50] [1400/1576] 12.005337
[2/50] [1500/1576] 13.126920


100%|██████████| 175/175 [00:56<00:00,  3.10it/s]


{'joint_goal_accuracy': 0.2155665024630542, 'turn_slot_accuracy': 0.9409261083743911, 'turn_slot_f1': 0.729808903717254}
joint_goal_accuracy: 0.2155665024630542
turn_slot_accuracy: 0.9409261083743911
turn_slot_f1: 0.729808903717254
[3/50] [0/1576] 13.419750
[3/50] [100/1576] 9.211822
[3/50] [200/1576] 7.515280
[3/50] [300/1576] 10.151665
[3/50] [400/1576] 9.142398
[3/50] [500/1576] 7.220044
[3/50] [600/1576] 7.863735
[3/50] [700/1576] 7.490903
[3/50] [800/1576] 8.146549
[3/50] [900/1576] 4.935915
[3/50] [1000/1576] 6.583706
[3/50] [1100/1576] 3.943975
[3/50] [1200/1576] 10.462150
[3/50] [1300/1576] 8.051355
[3/50] [1400/1576] 7.809869
[3/50] [1500/1576] 5.739918


100%|██████████| 175/175 [00:56<00:00,  3.10it/s]


{'joint_goal_accuracy': 0.42935960591133004, 'turn_slot_accuracy': 0.9741302681992444, 'turn_slot_f1': 0.891371473695091}
joint_goal_accuracy: 0.42935960591133004
turn_slot_accuracy: 0.9741302681992444
turn_slot_f1: 0.891371473695091
[4/50] [0/1576] 7.012291
[4/50] [100/1576] 4.190318
[4/50] [200/1576] 3.450305
[4/50] [300/1576] 10.053491
[4/50] [400/1576] 4.613595
[4/50] [500/1576] 8.702610
[4/50] [600/1576] 8.506054
[4/50] [700/1576] 6.724343
[4/50] [800/1576] 5.470899
[4/50] [900/1576] 6.072900
[4/50] [1000/1576] 4.366434
[4/50] [1100/1576] 2.251687
[4/50] [1200/1576] 5.348655
[4/50] [1300/1576] 4.699886
[4/50] [1400/1576] 4.109402
[4/50] [1500/1576] 2.487442


100%|██████████| 175/175 [00:56<00:00,  3.10it/s]


{'joint_goal_accuracy': 0.6139901477832512, 'turn_slot_accuracy': 0.9862769567597237, 'turn_slot_f1': 0.9471783404891799}
joint_goal_accuracy: 0.6139901477832512
turn_slot_accuracy: 0.9862769567597237
turn_slot_f1: 0.9471783404891799
[5/50] [0/1576] 4.044650
[5/50] [100/1576] 6.169325
[5/50] [200/1576] 5.474092
[5/50] [300/1576] 2.904697
[5/50] [400/1576] 3.255375
[5/50] [500/1576] 4.790680
[5/50] [600/1576] 2.826751
[5/50] [700/1576] 4.346072
[5/50] [800/1576] 4.809063
[5/50] [900/1576] 2.120707
[5/50] [1000/1576] 3.213801
[5/50] [1100/1576] 3.868632
[5/50] [1200/1576] 2.959446
[5/50] [1300/1576] 3.034247
[5/50] [1400/1576] 3.833740
[5/50] [1500/1576] 2.240196


100%|██████████| 175/175 [00:56<00:00,  3.10it/s]


{'joint_goal_accuracy': 0.745024630541872, 'turn_slot_accuracy': 0.9916978653530457, 'turn_slot_f1': 0.9673230513570322}
joint_goal_accuracy: 0.745024630541872
turn_slot_accuracy: 0.9916978653530457
turn_slot_f1: 0.9673230513570322
[6/50] [0/1576] 2.179711
[6/50] [100/1576] 1.620479
[6/50] [200/1576] 1.910037
[6/50] [300/1576] 4.303927
[6/50] [400/1576] 1.509257
[6/50] [500/1576] 1.202067
[6/50] [600/1576] 3.882437
[6/50] [700/1576] 1.465084
[6/50] [800/1576] 2.233501
[6/50] [900/1576] 1.048862
[6/50] [1000/1576] 1.288480
[6/50] [1100/1576] 1.478171
[6/50] [1200/1576] 7.490780
[6/50] [1300/1576] 1.949910
[6/50] [1400/1576] 1.037042
[6/50] [1500/1576] 0.653341


100%|██████████| 175/175 [00:56<00:00,  3.10it/s]


{'joint_goal_accuracy': 0.7617733990147784, 'turn_slot_accuracy': 0.992218938149978, 'turn_slot_f1': 0.9694636392750372}
joint_goal_accuracy: 0.7617733990147784
turn_slot_accuracy: 0.992218938149978
turn_slot_f1: 0.9694636392750372
[7/50] [0/1576] 1.073775
[7/50] [100/1576] 1.395098
[7/50] [200/1576] 1.150813
[7/50] [300/1576] 1.418169
[7/50] [400/1576] 1.277105
[7/50] [500/1576] 1.202208
[7/50] [600/1576] 1.944877
[7/50] [700/1576] 2.314940
[7/50] [800/1576] 0.926401
[7/50] [900/1576] 1.096191
[7/50] [1000/1576] 2.717210
[7/50] [1100/1576] 0.887446
[7/50] [1200/1576] 1.451278
[7/50] [1300/1576] 3.324433
[7/50] [1400/1576] 1.490813
[7/50] [1500/1576] 3.554758


100%|██████████| 175/175 [00:56<00:00,  3.10it/s]


{'joint_goal_accuracy': 0.7785221674876848, 'turn_slot_accuracy': 0.9927969348659077, 'turn_slot_f1': 0.9734844424278735}
joint_goal_accuracy: 0.7785221674876848
turn_slot_accuracy: 0.9927969348659077
turn_slot_f1: 0.9734844424278735
[8/50] [0/1576] 1.958710
[8/50] [100/1576] 1.489389
[8/50] [200/1576] 1.795569
[8/50] [300/1576] 1.186861
[8/50] [400/1576] 1.306414
[8/50] [500/1576] 0.977671
[8/50] [600/1576] 4.807040
[8/50] [700/1576] 2.511413
[8/50] [800/1576] 1.771425
[8/50] [900/1576] 1.717062
[8/50] [1000/1576] 1.779273
[8/50] [1100/1576] 1.063704
[8/50] [1200/1576] 0.612353
[8/50] [1300/1576] 3.638254
[8/50] [1400/1576] 1.179016
[8/50] [1500/1576] 0.974418


100%|██████████| 175/175 [00:56<00:00,  3.10it/s]


{'joint_goal_accuracy': 0.7633497536945812, 'turn_slot_accuracy': 0.9924378762999503, 'turn_slot_f1': 0.9710860874856563}
joint_goal_accuracy: 0.7633497536945812
turn_slot_accuracy: 0.9924378762999503
turn_slot_f1: 0.9710860874856563
[9/50] [0/1576] 0.953971
[9/50] [100/1576] 0.429058
[9/50] [200/1576] 0.583896
[9/50] [300/1576] 1.002380
[9/50] [400/1576] 3.163256
[9/50] [500/1576] 2.782330
[9/50] [600/1576] 0.737215
[9/50] [700/1576] 1.658939
[9/50] [800/1576] 0.643321
[9/50] [900/1576] 0.594602
[9/50] [1000/1576] 1.530064
[9/50] [1100/1576] 0.601877
[9/50] [1200/1576] 1.943365
[9/50] [1300/1576] 1.659296
[9/50] [1400/1576] 0.509267
[9/50] [1500/1576] 1.357778


100%|██████████| 175/175 [00:56<00:00,  3.10it/s]


{'joint_goal_accuracy': 0.8019704433497536, 'turn_slot_accuracy': 0.9938084291187809, 'turn_slot_f1': 0.9762289796812805}
joint_goal_accuracy: 0.8019704433497536
turn_slot_accuracy: 0.9938084291187809
turn_slot_f1: 0.9762289796812805
[10/50] [0/1576] 0.461749
[10/50] [100/1576] 1.933189
[10/50] [200/1576] 1.210933
[10/50] [300/1576] 0.622292
[10/50] [400/1576] 1.096686
[10/50] [500/1576] 2.210538
[10/50] [600/1576] 0.830535
[10/50] [700/1576] 1.076482
[10/50] [800/1576] 0.675930
[10/50] [900/1576] 0.480209
[10/50] [1000/1576] 0.406419
[10/50] [1100/1576] 2.095666
[10/50] [1200/1576] 0.739916
[10/50] [1300/1576] 1.078418
[10/50] [1400/1576] 0.576975
[10/50] [1500/1576] 0.976559


100%|██████████| 175/175 [00:56<00:00,  3.10it/s]


{'joint_goal_accuracy': 0.7992118226600985, 'turn_slot_accuracy': 0.9936683087027975, 'turn_slot_f1': 0.9753889033866485}
joint_goal_accuracy: 0.7992118226600985
turn_slot_accuracy: 0.9936683087027975
turn_slot_f1: 0.9753889033866485
[11/50] [0/1576] 1.421691
[11/50] [100/1576] 1.029471
[11/50] [200/1576] 0.838865
[11/50] [300/1576] 0.462063
[11/50] [400/1576] 1.989713
[11/50] [500/1576] 0.838436
[11/50] [600/1576] 1.227607
[11/50] [700/1576] 0.851277
[11/50] [800/1576] 0.543709
[11/50] [900/1576] 0.446252
[11/50] [1000/1576] 1.264720
[11/50] [1100/1576] 0.716212
[11/50] [1200/1576] 0.310171
[11/50] [1300/1576] 1.178609
[11/50] [1400/1576] 0.564783
[11/50] [1500/1576] 0.325183


100%|██████████| 175/175 [00:56<00:00,  3.10it/s]


{'joint_goal_accuracy': 0.7968472906403941, 'turn_slot_accuracy': 0.993685823754795, 'turn_slot_f1': 0.9766873397810825}
joint_goal_accuracy: 0.7968472906403941
turn_slot_accuracy: 0.993685823754795
turn_slot_f1: 0.9766873397810825
[12/50] [0/1576] 1.229933
[12/50] [100/1576] 0.732216
[12/50] [200/1576] 0.579562
[12/50] [300/1576] 0.624969
[12/50] [400/1576] 1.037171
[12/50] [500/1576] 2.742226
[12/50] [600/1576] 0.199665
[12/50] [700/1576] 1.035104
[12/50] [800/1576] 1.376810
[12/50] [900/1576] 0.187428
[12/50] [1000/1576] 0.537634
[12/50] [1100/1576] 0.253449
[12/50] [1200/1576] 0.681509
[12/50] [1300/1576] 0.658958
[12/50] [1400/1576] 0.565419
[12/50] [1500/1576] 0.337867


100%|██████████| 175/175 [00:56<00:00,  3.10it/s]


{'joint_goal_accuracy': 0.8019704433497536, 'turn_slot_accuracy': 0.9939003831417661, 'turn_slot_f1': 0.9762220328574474}
joint_goal_accuracy: 0.8019704433497536
turn_slot_accuracy: 0.9939003831417661
turn_slot_f1: 0.9762220328574474
[13/50] [0/1576] 0.828101
[13/50] [100/1576] 3.739276
[13/50] [200/1576] 0.445329
[13/50] [300/1576] 0.574834
[13/50] [400/1576] 0.258291
[13/50] [500/1576] 0.262845
[13/50] [600/1576] 1.297638
[13/50] [700/1576] 0.509626
[13/50] [800/1576] 0.455770
[13/50] [900/1576] 0.238720
[13/50] [1000/1576] 0.821735
[13/50] [1100/1576] 0.526534
[13/50] [1200/1576] 0.485354
[13/50] [1300/1576] 0.446046
[13/50] [1400/1576] 0.423734
[13/50] [1500/1576] 0.751579


100%|██████████| 175/175 [00:56<00:00,  3.10it/s]


{'joint_goal_accuracy': 0.8092610837438423, 'turn_slot_accuracy': 0.9940492610837489, 'turn_slot_f1': 0.9765952377166806}
joint_goal_accuracy: 0.8092610837438423
turn_slot_accuracy: 0.9940492610837489
turn_slot_f1: 0.9765952377166806
[14/50] [0/1576] 0.693574
[14/50] [100/1576] 0.450499
[14/50] [200/1576] 0.976082
[14/50] [300/1576] 0.339480
[14/50] [400/1576] 0.375163
[14/50] [500/1576] 0.821353
[14/50] [600/1576] 1.360358
[14/50] [700/1576] 0.715701
[14/50] [800/1576] 1.693540
[14/50] [900/1576] 0.405342
[14/50] [1000/1576] 1.365182
[14/50] [1100/1576] 2.625272
[14/50] [1200/1576] 0.336741
[14/50] [1300/1576] 0.445471
[14/50] [1400/1576] 0.942063
[14/50] [1500/1576] 0.408730


100%|██████████| 175/175 [00:56<00:00,  3.10it/s]


{'joint_goal_accuracy': 0.7980295566502463, 'turn_slot_accuracy': 0.9937383689107899, 'turn_slot_f1': 0.9755651148391359}
joint_goal_accuracy: 0.7980295566502463
turn_slot_accuracy: 0.9937383689107899
turn_slot_f1: 0.9755651148391359
[15/50] [0/1576] 1.640350
[15/50] [100/1576] 0.725890
[15/50] [200/1576] 0.502044
[15/50] [300/1576] 0.215404
[15/50] [400/1576] 0.160497
[15/50] [500/1576] 0.181979
[15/50] [600/1576] 0.285837
[15/50] [700/1576] 1.836848
[15/50] [800/1576] 0.651034
[15/50] [900/1576] 0.254886
[15/50] [1000/1576] 0.275133
[15/50] [1100/1576] 0.421076
[15/50] [1200/1576] 1.003041
[15/50] [1300/1576] 0.399073
[15/50] [1400/1576] 0.372873
[15/50] [1500/1576] 0.511223


100%|██████████| 175/175 [00:56<00:00,  3.10it/s]


{'joint_goal_accuracy': 0.809064039408867, 'turn_slot_accuracy': 0.9942813355227214, 'turn_slot_f1': 0.9778675252979697}
joint_goal_accuracy: 0.809064039408867
turn_slot_accuracy: 0.9942813355227214
turn_slot_f1: 0.9778675252979697
[16/50] [0/1576] 0.365869
[16/50] [100/1576] 0.415417
[16/50] [200/1576] 0.273384
[16/50] [300/1576] 0.492262
[16/50] [400/1576] 0.188497
[16/50] [500/1576] 0.261792
[16/50] [600/1576] 0.828434
[16/50] [700/1576] 0.262185
[16/50] [800/1576] 0.402803
[16/50] [900/1576] 1.621202
[16/50] [1000/1576] 0.294050
[16/50] [1100/1576] 0.088443
[16/50] [1200/1576] 0.770363
[16/50] [1300/1576] 0.134937
[16/50] [1400/1576] 0.302199
[16/50] [1500/1576] 0.307183


100%|██████████| 175/175 [00:56<00:00,  3.10it/s]


{'joint_goal_accuracy': 0.8120197044334976, 'turn_slot_accuracy': 0.9943251231527139, 'turn_slot_f1': 0.9780103186692197}
joint_goal_accuracy: 0.8120197044334976
turn_slot_accuracy: 0.9943251231527139
turn_slot_f1: 0.9780103186692197
[17/50] [0/1576] 0.433918
[17/50] [100/1576] 0.915546
[17/50] [200/1576] 0.285892
[17/50] [300/1576] 0.679107
[17/50] [400/1576] 0.369781
[17/50] [500/1576] 0.281343
[17/50] [600/1576] 0.173833
[17/50] [700/1576] 0.342891
[17/50] [800/1576] 0.407962
[17/50] [900/1576] 0.207259
[17/50] [1000/1576] 0.154809
[17/50] [1100/1576] 0.188722
[17/50] [1200/1576] 0.407620
[17/50] [1300/1576] 0.503170
[17/50] [1400/1576] 0.474650
[17/50] [1500/1576] 0.230840


100%|██████████| 175/175 [00:56<00:00,  3.10it/s]


{'joint_goal_accuracy': 0.8250246305418719, 'turn_slot_accuracy': 0.9945440613026844, 'turn_slot_f1': 0.9789988212258227}
joint_goal_accuracy: 0.8250246305418719
turn_slot_accuracy: 0.9945440613026844
turn_slot_f1: 0.9789988212258227
[18/50] [0/1576] 0.320006
[18/50] [100/1576] 0.252089
[18/50] [200/1576] 0.288009
[18/50] [300/1576] 0.456320
[18/50] [400/1576] 0.264081
[18/50] [500/1576] 0.219789
[18/50] [600/1576] 0.435738
[18/50] [700/1576] 0.204747
[18/50] [800/1576] 0.722991
[18/50] [900/1576] 0.485436
[18/50] [1000/1576] 0.083618
[18/50] [1100/1576] 0.477426
[18/50] [1200/1576] 0.090356
[18/50] [1300/1576] 0.928920
[18/50] [1400/1576] 0.378012
[18/50] [1500/1576] 1.059957


100%|██████████| 175/175 [00:56<00:00,  3.10it/s]


{'joint_goal_accuracy': 0.8269950738916256, 'turn_slot_accuracy': 0.9948155446086531, 'turn_slot_f1': 0.9799211138104555}
joint_goal_accuracy: 0.8269950738916256
turn_slot_accuracy: 0.9948155446086531
turn_slot_f1: 0.9799211138104555
[19/50] [0/1576] 0.326072
[19/50] [100/1576] 0.480597
[19/50] [200/1576] 0.684925
[19/50] [300/1576] 1.334534
[19/50] [400/1576] 0.370833
[19/50] [500/1576] 0.206354
[19/50] [600/1576] 1.306980
[19/50] [700/1576] 0.222490
[19/50] [800/1576] 0.280721
[19/50] [900/1576] 0.183585
[19/50] [1000/1576] 0.649709
[19/50] [1100/1576] 0.543831
[19/50] [1200/1576] 0.247297
[19/50] [1300/1576] 1.395420
[19/50] [1400/1576] 0.240018
[19/50] [1500/1576] 0.495409


100%|██████████| 175/175 [00:56<00:00,  3.10it/s]


{'joint_goal_accuracy': 0.8238423645320198, 'turn_slot_accuracy': 0.9946097427476782, 'turn_slot_f1': 0.9798757288894174}
joint_goal_accuracy: 0.8238423645320198
turn_slot_accuracy: 0.9946097427476782
turn_slot_f1: 0.9798757288894174
[20/50] [0/1576] 0.248599
[20/50] [100/1576] 0.536902
[20/50] [200/1576] 0.386374
[20/50] [300/1576] 0.149928
[20/50] [400/1576] 0.197248
[20/50] [500/1576] 0.216107
[20/50] [600/1576] 0.392212
[20/50] [700/1576] 0.153576
[20/50] [800/1576] 0.279275
[20/50] [900/1576] 0.178350
[20/50] [1000/1576] 0.307957
[20/50] [1100/1576] 0.303643
[20/50] [1200/1576] 0.318088
[20/50] [1300/1576] 0.576437
[20/50] [1400/1576] 0.488790
[20/50] [1500/1576] 0.503650


100%|██████████| 175/175 [00:56<00:00,  3.10it/s]


{'joint_goal_accuracy': 0.8135960591133005, 'turn_slot_accuracy': 0.9941412151067337, 'turn_slot_f1': 0.9771069423255361}
joint_goal_accuracy: 0.8135960591133005
turn_slot_accuracy: 0.9941412151067337
turn_slot_f1: 0.9771069423255361
[21/50] [0/1576] 0.378807
[21/50] [100/1576] 0.171422
[21/50] [200/1576] 0.065027
[21/50] [300/1576] 0.187827
[21/50] [400/1576] 0.522056
[21/50] [500/1576] 2.576512
[21/50] [600/1576] 0.264283
[21/50] [700/1576] 0.314972
[21/50] [800/1576] 0.144170
[21/50] [900/1576] 0.297084
[21/50] [1000/1576] 0.366936
[21/50] [1100/1576] 0.188193
[21/50] [1200/1576] 1.717545
[21/50] [1300/1576] 0.086610
[21/50] [1400/1576] 0.253502
[21/50] [1500/1576] 0.315819


100%|██████████| 175/175 [00:56<00:00,  3.10it/s]


{'joint_goal_accuracy': 0.8149753694581281, 'turn_slot_accuracy': 0.9941587301587365, 'turn_slot_f1': 0.9778339164247343}
joint_goal_accuracy: 0.8149753694581281
turn_slot_accuracy: 0.9941587301587365
turn_slot_f1: 0.9778339164247343
[22/50] [0/1576] 0.211142
[22/50] [100/1576] 0.163825
[22/50] [200/1576] 0.096663
[22/50] [300/1576] 0.378257
[22/50] [400/1576] 0.125833
[22/50] [500/1576] 0.391641
[22/50] [600/1576] 0.132351
[22/50] [700/1576] 0.218598
[22/50] [800/1576] 0.380100
[22/50] [900/1576] 0.139982
[22/50] [1000/1576] 0.054965
[22/50] [1100/1576] 0.098715
[22/50] [1200/1576] 0.508142
[22/50] [1300/1576] 0.122269
[22/50] [1400/1576] 0.255426
[22/50] [1500/1576] 0.258608


100%|██████████| 175/175 [00:56<00:00,  3.10it/s]


{'joint_goal_accuracy': 0.8200985221674877, 'turn_slot_accuracy': 0.9944477285166988, 'turn_slot_f1': 0.9785378070964356}
joint_goal_accuracy: 0.8200985221674877
turn_slot_accuracy: 0.9944477285166988
turn_slot_f1: 0.9785378070964356
[23/50] [0/1576] 0.334965
[23/50] [100/1576] 0.416464
[23/50] [200/1576] 0.252370
[23/50] [300/1576] 0.197997
[23/50] [400/1576] 0.284919
[23/50] [500/1576] 0.350328
[23/50] [600/1576] 0.097096
[23/50] [700/1576] 0.325148
[23/50] [800/1576] 0.315134
[23/50] [900/1576] 0.314574
[23/50] [1000/1576] 0.259850
[23/50] [1100/1576] 0.156169
[23/50] [1200/1576] 0.204071
[23/50] [1300/1576] 0.732141
[23/50] [1400/1576] 0.284427
[23/50] [1500/1576] 0.441117


100%|██████████| 175/175 [00:56<00:00,  3.10it/s]


{'joint_goal_accuracy': 0.8124137931034483, 'turn_slot_accuracy': 0.9942156540777287, 'turn_slot_f1': 0.9782198243269323}
joint_goal_accuracy: 0.8124137931034483
turn_slot_accuracy: 0.9942156540777287
turn_slot_f1: 0.9782198243269323
[24/50] [0/1576] 0.236092
[24/50] [100/1576] 0.220700
[24/50] [200/1576] 0.197208
[24/50] [300/1576] 1.425202
[24/50] [400/1576] 0.123688
[24/50] [500/1576] 0.160618
[24/50] [600/1576] 0.183016
[24/50] [700/1576] 0.178441
[24/50] [800/1576] 0.066615
[24/50] [900/1576] 0.217481
[24/50] [1000/1576] 0.091648
[24/50] [1100/1576] 0.229046
[24/50] [1200/1576] 0.360505
[24/50] [1300/1576] 0.208245
[24/50] [1400/1576] 0.406774
[24/50] [1500/1576] 0.156066


100%|██████████| 175/175 [00:56<00:00,  3.10it/s]


{'joint_goal_accuracy': 0.8177339901477833, 'turn_slot_accuracy': 0.9944433497536996, 'turn_slot_f1': 0.9784896274019887}
joint_goal_accuracy: 0.8177339901477833
turn_slot_accuracy: 0.9944433497536996
turn_slot_f1: 0.9784896274019887
[25/50] [0/1576] 0.371444
[25/50] [100/1576] 0.168810
[25/50] [200/1576] 0.383297
[25/50] [300/1576] 2.043672
[25/50] [400/1576] 0.142595
[25/50] [500/1576] 0.222427
[25/50] [600/1576] 0.149246
[25/50] [700/1576] 0.171590
[25/50] [800/1576] 1.609555
[25/50] [900/1576] 0.088774
[25/50] [1000/1576] 0.078862
[25/50] [1100/1576] 0.259820
[25/50] [1200/1576] 0.207095
[25/50] [1300/1576] 0.153653
[25/50] [1400/1576] 0.182341
[25/50] [1500/1576] 0.171537


100%|██████████| 175/175 [00:56<00:00,  3.10it/s]


{'joint_goal_accuracy': 0.8163546798029556, 'turn_slot_accuracy': 0.9943338806787128, 'turn_slot_f1': 0.9787079442037503}
joint_goal_accuracy: 0.8163546798029556
turn_slot_accuracy: 0.9943338806787128
turn_slot_f1: 0.9787079442037503
[26/50] [0/1576] 0.207245
[26/50] [100/1576] 0.053678
[26/50] [200/1576] 0.138553
[26/50] [300/1576] 0.229016
[26/50] [400/1576] 0.178588
[26/50] [500/1576] 0.103507
[26/50] [600/1576] 0.142094
[26/50] [700/1576] 0.268116
[26/50] [800/1576] 0.265634
[26/50] [900/1576] 0.341495
[26/50] [1000/1576] 0.148718
[26/50] [1100/1576] 0.234927
[26/50] [1200/1576] 0.214193
[26/50] [1300/1576] 0.235058
[26/50] [1400/1576] 0.116750
[26/50] [1500/1576] 0.134993


100%|██████████| 175/175 [00:56<00:00,  3.10it/s]


{'joint_goal_accuracy': 0.8161576354679803, 'turn_slot_accuracy': 0.9943513957307109, 'turn_slot_f1': 0.9780739453802239}
joint_goal_accuracy: 0.8161576354679803
turn_slot_accuracy: 0.9943513957307109
turn_slot_f1: 0.9780739453802239
[27/50] [0/1576] 0.148760
[27/50] [100/1576] 0.150138
[27/50] [200/1576] 0.271657
[27/50] [300/1576] 0.225627
[27/50] [400/1576] 0.154465
[27/50] [500/1576] 0.335600
[27/50] [600/1576] 0.060086
[27/50] [700/1576] 0.413377
[27/50] [800/1576] 0.615357
[27/50] [900/1576] 0.359441
[27/50] [1000/1576] 0.142061
[27/50] [1100/1576] 0.082030
[27/50] [1200/1576] 0.499120
[27/50] [1300/1576] 0.249406
[27/50] [1400/1576] 0.139029
[27/50] [1500/1576] 0.074082


100%|██████████| 175/175 [00:56<00:00,  3.10it/s]


{'joint_goal_accuracy': 0.8130049261083744, 'turn_slot_accuracy': 0.9941455938697362, 'turn_slot_f1': 0.9776344444414066}
joint_goal_accuracy: 0.8130049261083744
turn_slot_accuracy: 0.9941455938697362
turn_slot_f1: 0.9776344444414066
[28/50] [0/1576] 0.410292
[28/50] [100/1576] 0.174011
[28/50] [200/1576] 0.132106
[28/50] [300/1576] 0.113747
[28/50] [400/1576] 0.536520
[28/50] [500/1576] 0.129978
[28/50] [600/1576] 0.102232
[28/50] [700/1576] 0.075487
[28/50] [800/1576] 0.158470
[28/50] [900/1576] 0.048559
[28/50] [1000/1576] 0.218769
[28/50] [1100/1576] 0.102828
[28/50] [1200/1576] 0.389571
[28/50] [1300/1576] 0.292183
[28/50] [1400/1576] 0.283138
[28/50] [1500/1576] 0.181068


100%|██████████| 175/175 [00:56<00:00,  3.10it/s]


{'joint_goal_accuracy': 0.8214778325123153, 'turn_slot_accuracy': 0.994548440065686, 'turn_slot_f1': 0.9796159269840271}
joint_goal_accuracy: 0.8214778325123153
turn_slot_accuracy: 0.994548440065686
turn_slot_f1: 0.9796159269840271
[29/50] [0/1576] 0.306204
[29/50] [100/1576] 0.321805
[29/50] [200/1576] 0.088422
[29/50] [300/1576] 0.250404
[29/50] [400/1576] 0.189016
[29/50] [500/1576] 0.096706
[29/50] [600/1576] 0.107008
[29/50] [700/1576] 0.195690
[29/50] [800/1576] 0.253912
[29/50] [900/1576] 0.072983
[29/50] [1000/1576] 0.118759
[29/50] [1100/1576] 0.330879
[29/50] [1200/1576] 0.294571
[29/50] [1300/1576] 0.384024
[29/50] [1400/1576] 0.194846
[29/50] [1500/1576] 0.150988


100%|██████████| 175/175 [00:56<00:00,  3.10it/s]


{'joint_goal_accuracy': 0.8177339901477833, 'turn_slot_accuracy': 0.9943207443897155, 'turn_slot_f1': 0.9786299011331218}
joint_goal_accuracy: 0.8177339901477833
turn_slot_accuracy: 0.9943207443897155
turn_slot_f1: 0.9786299011331218
[30/50] [0/1576] 0.373552
[30/50] [100/1576] 0.134990
[30/50] [200/1576] 0.197325
[30/50] [300/1576] 0.265051
[30/50] [400/1576] 0.114738
[30/50] [500/1576] 0.248743
[30/50] [600/1576] 0.184980
[30/50] [700/1576] 0.447686
[30/50] [800/1576] 0.078016
[30/50] [900/1576] 0.239097
[30/50] [1000/1576] 0.102064
[30/50] [1100/1576] 0.400734
[30/50] [1200/1576] 0.140376
[30/50] [1300/1576] 0.164155
[30/50] [1400/1576] 0.223355
[30/50] [1500/1576] 0.268079


100%|██████████| 175/175 [00:56<00:00,  3.10it/s]


{'joint_goal_accuracy': 0.8214778325123153, 'turn_slot_accuracy': 0.9944039408867014, 'turn_slot_f1': 0.979115150874295}
joint_goal_accuracy: 0.8214778325123153
turn_slot_accuracy: 0.9944039408867014
turn_slot_f1: 0.979115150874295
[31/50] [0/1576] 0.083921


KeyboardInterrupt: ignored

## Inference

In [25]:
model.load_state_dict(torch.load("/content/drive/MyDrive/Stage3/model/sumbt_ontology_50_best.pt"))
model = model.eval()

In [26]:
eval_data = json.load(open(f"/content/drive/MyDrive/Stage3/input/data/eval_dataset/eval_dials.json", "r"))

eval_examples = get_examples_from_dialogues(
    eval_data, user_first=True, dialogue_level=True
)

# Extracting Featrues
eval_features = processor.convert_examples_to_features(eval_examples)
eval_data = WOSDataset(eval_features)
eval_sampler = SequentialSampler(eval_data)
eval_loader = DataLoader(
    eval_data,
    batch_size=8,
    sampler=eval_sampler,
    collate_fn=processor.collate_fn,
)


100%|██████████| 2000/2000 [00:00<00:00, 17101.39it/s]


In [27]:
predictions = inference(model, eval_loader, processor, device)

100%|██████████| 250/250 [02:39<00:00,  1.57it/s]


In [28]:
json.dump(predictions, open('/content/drive/MyDrive/Stage3/sumbt_ontology_50_best_pred.csv', 'w'), indent=2, ensure_ascii=False) 

In [29]:
#model.load_state_dict(torch.load("/content/drive/MyDrive/Stage3/model/sumbt50_best.pt"))
#model = model.eval()
predictions = inference(model, dev_loader, processor, device)
eval_result = _evaluation(predictions, dev_labels, slot_meta)

100%|██████████| 175/175 [00:56<00:00,  3.10it/s]

{'joint_goal_accuracy': 0.8269950738916256, 'turn_slot_accuracy': 0.9948155446086531, 'turn_slot_f1': 0.9799211138104555}





In [30]:
cat_ontology = []
noncat_ontology = []
for i in range(45):
    if len(list(ontology.items())[i][1]) <=9:
        cat_ontology.append(list(ontology.items())[i][0])
    else:
        noncat_ontology.append(list(ontology.items())[i][0])

In [31]:
category_pred = []
noncategory_pred = []
category_label = []
noncategory_label = []

In [32]:
pred_keys = list(predictions.keys())
pred_values = list(predictions.values())
labels_keys = list(dev_labels.keys())
labels_values = list(dev_labels.values())

In [33]:
correct = []
wrong = []

In [34]:
# 전체 예측에 대해
last_key = ''
for num_labels in range(len(list(predictions))):
    
    # 마지막 예측이랑 비교해 새로운 예측만 생각
    if labels_keys[num_labels].split(':')[0] != last_key:
        last_pred = set([])
        last_label = set([])
        last_key = labels_keys[num_labels].split(':')[0]
    
    else:
        last_pred = set(pred_values[num_labels-1]) 
        last_label = set(labels_values[num_labels-1])
    
    pred_slots = list(set(pred_values[num_labels]) - last_pred)
    label_slots = list(set(labels_values[num_labels]) - last_label)
                       
    #예측 안에 각각의 예측 slot에 대해
    for label_slot_num in range(len(label_slots)):
        
        # 정답에 예측이 있으면
        if label_slots[label_slot_num] in pred_slots:
            correct.append(label_slots[label_slot_num])
        else:
            wrong.append(label_slots[label_slot_num])

In [35]:
print('len_correct:',len(correct))
print('len_wrong:',len(wrong))

len_correct: 7668
len_wrong: 269


In [36]:
cat_correct = []
noncat_correct = []
cat_wrong = []
noncat_wrong = []
for label in correct:
    if label.split('-')[0] + '-' +label.split('-')[1] in cat_ontology:
        cat_correct.append(label)
    else:
        noncat_correct.append(label)
for label in wrong:
    if label.split('-')[0] + '-' +label.split('-')[1] in cat_ontology:
        cat_wrong.append(label)
    else:
        noncat_wrong.append(label)


In [37]:
print('cat_correct:',len(cat_correct))
print('noncat_correct:',len(noncat_correct))
print('cat_wrong:',len(cat_wrong))
print('noncat_wrong:',len(noncat_wrong))

cat_correct: 3595
noncat_correct: 4073
cat_wrong: 115
noncat_wrong: 154


In [38]:
from collections import defaultdict

In [39]:
cat_correct_dict = defaultdict(int)
noncat_correct_dict = defaultdict(int)
cat_wrong_dict = defaultdict(int)
noncat_wrong_dict = defaultdict(int)

In [40]:
for cat in cat_correct:
    cat_correct_dict[cat.split('-')[0] + '-' +cat.split('-')[1]] +=1
for cat in noncat_correct:
    noncat_correct_dict[cat.split('-')[0] + '-' +cat.split('-')[1]] +=1    
for cat in cat_wrong:
    cat_wrong_dict[cat.split('-')[0] + '-' +cat.split('-')[1]] +=1
for cat in noncat_wrong:
    noncat_wrong_dict[cat.split('-')[0] + '-' +cat.split('-')[1]] +=1

In [41]:
print('cat 정답률')
for cat in cat_ontology:
    if (cat_correct_dict[cat] + cat_wrong_dict[cat]) == 0:
        print(str(cat),str(0))
        continue
    percent = cat_correct_dict[cat]/(cat_correct_dict[cat] + cat_wrong_dict[cat])*100
    print( str(cat), format(percent,".2f"),'%', '('+str(cat_correct_dict[cat])+'/'+str(cat_correct_dict[cat] + cat_wrong_dict[cat])+')')


cat 정답률
관광-경치 좋은 100.00 % (39/39)
관광-교육적 100.00 % (18/18)
관광-도보 가능 0.00 % (0/1)
관광-문화 예술 93.75 % (15/16)
관광-역사적 98.57 % (69/70)
관광-주차 가능 97.22 % (35/36)
관광-지역 97.08 % (299/308)
숙소-가격대 97.05 % (362/373)
숙소-도보 가능 85.71 % (12/14)
숙소-수영장 유무 83.33 % (5/6)
숙소-스파 유무 92.86 % (39/42)
숙소-예약 요일 95.20 % (357/375)
숙소-인터넷 가능 93.94 % (31/33)
숙소-조식 가능 94.74 % (36/38)
숙소-종류 98.34 % (356/362)
숙소-주차 가능 96.00 % (48/50)
숙소-지역 97.21 % (348/358)
숙소-헬스장 유무 94.59 % (35/37)
숙소-흡연 가능 91.43 % (32/35)
식당-가격대 98.55 % (339/344)
식당-도보 가능 75.00 % (3/4)
식당-야외석 유무 100.00 % (56/56)
식당-예약 요일 97.58 % (403/413)
식당-인터넷 가능 96.15 % (25/26)
식당-주류 판매 100.00 % (41/41)
식당-주차 가능 91.43 % (32/35)
식당-지역 96.18 % (327/340)
식당-흡연 가능 84.00 % (21/25)
택시-종류 98.60 % (212/215)


In [42]:
print('cat 오답률')
for cat in cat_ontology:
    if (cat_correct_dict[cat] + cat_wrong_dict[cat]) == 0:
        print(str(cat),str(0))
        continue
    percent = cat_wrong_dict[cat]/(cat_correct_dict[cat] + cat_wrong_dict[cat])*100
    print( str(cat), format(percent,".2f"),'%', '('+str(cat_wrong_dict[cat])+'/'+str(cat_correct_dict[cat] + cat_wrong_dict[cat])+')')


cat 오답률
관광-경치 좋은 0.00 % (0/39)
관광-교육적 0.00 % (0/18)
관광-도보 가능 100.00 % (1/1)
관광-문화 예술 6.25 % (1/16)
관광-역사적 1.43 % (1/70)
관광-주차 가능 2.78 % (1/36)
관광-지역 2.92 % (9/308)
숙소-가격대 2.95 % (11/373)
숙소-도보 가능 14.29 % (2/14)
숙소-수영장 유무 16.67 % (1/6)
숙소-스파 유무 7.14 % (3/42)
숙소-예약 요일 4.80 % (18/375)
숙소-인터넷 가능 6.06 % (2/33)
숙소-조식 가능 5.26 % (2/38)
숙소-종류 1.66 % (6/362)
숙소-주차 가능 4.00 % (2/50)
숙소-지역 2.79 % (10/358)
숙소-헬스장 유무 5.41 % (2/37)
숙소-흡연 가능 8.57 % (3/35)
식당-가격대 1.45 % (5/344)
식당-도보 가능 25.00 % (1/4)
식당-야외석 유무 0.00 % (0/56)
식당-예약 요일 2.42 % (10/413)
식당-인터넷 가능 3.85 % (1/26)
식당-주류 판매 0.00 % (0/41)
식당-주차 가능 8.57 % (3/35)
식당-지역 3.82 % (13/340)
식당-흡연 가능 16.00 % (4/25)
택시-종류 1.40 % (3/215)


In [43]:
print('noncat 정답률')
for cat in noncat_ontology:
    if noncat_correct_dict[cat] ==0:
        print(str(cat),' 0.00 %')
    else:
        percent = noncat_correct_dict[cat]/(noncat_correct_dict[cat] + noncat_wrong_dict[cat])*100
        print( str(cat), format(percent,".2f"),'%', '('+str(noncat_correct_dict[cat])+'/'+str(noncat_correct_dict[cat] + noncat_wrong_dict[cat])+')')


noncat 정답률
관광-이름 95.59 % (325/340)
관광-종류 98.01 % (295/301)
숙소-예약 기간 99.20 % (370/373)
숙소-예약 명수 98.36 % (359/365)
숙소-이름 94.91 % (354/373)
식당-예약 명수 97.66 % (375/384)
식당-예약 시간 97.09 % (400/412)
식당-이름 96.72 % (383/396)
식당-종류 99.41 % (337/339)
지하철-도착지 95.08 % (58/61)
지하철-출발 시간 72.73 % (8/11)
지하철-출발지 91.80 % (56/61)
택시-도착 시간 93.52 % (101/108)
택시-도착지 94.88 % (204/215)
택시-출발 시간 93.41 % (255/273)
택시-출발지 89.77 % (193/215)


In [44]:
print('noncat 오답률')

for cat in noncat_ontology:
    if noncat_correct_dict[cat] ==0:
        print(str(cat),' 0.00 %')
    else:
        percent = noncat_wrong_dict[cat]/(noncat_correct_dict[cat] + noncat_wrong_dict[cat])*100
        print( str(cat), format(percent,".2f"),'%', '('+str(noncat_wrong_dict[cat])+'/'+str(noncat_correct_dict[cat] + noncat_wrong_dict[cat])+')')

noncat 오답률
관광-이름 4.41 % (15/340)
관광-종류 1.99 % (6/301)
숙소-예약 기간 0.80 % (3/373)
숙소-예약 명수 1.64 % (6/365)
숙소-이름 5.09 % (19/373)
식당-예약 명수 2.34 % (9/384)
식당-예약 시간 2.91 % (12/412)
식당-이름 3.28 % (13/396)
식당-종류 0.59 % (2/339)
지하철-도착지 4.92 % (3/61)
지하철-출발 시간 27.27 % (3/11)
지하철-출발지 8.20 % (5/61)
택시-도착 시간 6.48 % (7/108)
택시-도착지 5.12 % (11/215)
택시-출발 시간 6.59 % (18/273)
택시-출발지 10.23 % (22/215)
