In [1]:
import sys
sys.path.append('..')

In [2]:
import os
from pathlib import Path
import json
from tqdm import tqdm
import random
from collections import defaultdict

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
import torch.cuda.amp as amp

from transformers import BertModel, BertTokenizer, BertConfig, AdamW, get_linear_schedule_with_warmup, get_cosine_with_hard_restarts_schedule_with_warmup
from data_utils import (
    load_dataset, 
    get_examples_from_dialogues, 
    convert_state_dict, 
    DSTInputExample, 
    OpenVocabDSTFeature, 
    DSTPreprocessor, 
    WOSDataset,
    set_seed)
    
from inference import inference
from evaluation import _evaluation

In [3]:
import wandb
# !wandb login  # run once

In [4]:
# def seed_everything(seed):
#     torch.manual_seed(seed)
#     torch.cuda.manual_seed(seed)
#     torch.cuda.manual_seed_all(seed)  # if use multi-GPU        
#     torch.backends.cudnn.deterministic = True
#     torch.backends.cudnn.benchmark = False
#     np.random.seed(seed)
#     random.seed(seed)

# seed_everything(42)
set_seed(42)

In [5]:
def increment_output_dir(output_path, exist_ok=False):
  path = Path(output_path)
  if (path.exists() and exist_ok) or (not path.exists()):
    return str(path)
  else:
    dirs = glob.glob(f"{path}*")
    matches = [re.search(rf"%s(\d+)" %path.stem, d) for d in dirs]
    i = [int(m.groups()[0]) for m in matches if m]
    n = max(i) + 1 if i else 2
    return f"{path}{n}"

### args setting

In [6]:
from argparse import Namespace

args = {
    'data_dir': 'data/train_dataset',
    'model_dir': 'trade',
    'train_batch_size': 16,
    'eval_batch_size': 32,
    'learning_rate': 1e-4,
    'adam_epsilon': 1e-8,
    'max_grad_norm': 1.0,   
    'num_train_epochs': 30,
    'warmup_ratio': 0.1,
    'random_seed': 42,
    'model_name_or_path': 'monologg/koelectra-base-v3-discriminator',
    'hidden_size': 768,
    'vocab_size': None,
    'hidden_dropout_prob': 0.1,
    'proj_dim': None,
    'teacher_forcing_ratio': 0.5,    
}

args = Namespace(**args)

In [7]:
# wandb sweep 생성 시 parameters에 전달하는 config 설정
# hyperparameter_defaults = dict(
#     batch_size = args.batch_size,
#     learning_rate = args.learning_rate,
#     epochs = args.num_train_epochs,
#     weight_decay = args.weight_decay,
#     attn_head = args.attn_head,
#     distance_metric = args.distance_metric,
    
#     dropout = 0.1,
#     smoothing = 0.2
#     model_name = 'BertForSequenceClassification',
#     tokenizer_name = 'BertTokenizer',
#     )

# wandb.init(config=hyperparameter_defaults, project="TRADE")
wandb.init(project="TRADE")
config = wandb.config

[34m[1mwandb[0m: Currently logged in as: [33mtaepd[0m (use `wandb login --relogin` to force relogin)
[34m[1mwandb[0m: wandb version 0.10.30 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


## Data loading

In [8]:
train_data_file = "/opt/ml/repo/taepd/input/data/train_dataset/train_dials.json"
slot_meta = json.load(open("/opt/ml/repo/taepd/input/data/train_dataset/slot_meta.json"))
ontology = json.load(open("/opt/ml/repo/taepd/input/data/train_dataset/ontology.json"))
train_data, dev_data, dev_labels = load_dataset(train_data_file)

# dialogue_level=False : SUMBT와 다르게 dialogue context level로 input하므로
train_examples = get_examples_from_dialogues(train_data,
                                             user_first=False,
                                             dialogue_level=False)
dev_examples = get_examples_from_dialogues(dev_data,
                                           user_first=False,
                                           dialogue_level=False)

100%|██████████| 6301/6301 [00:00<00:00, 9005.64it/s]
100%|██████████| 699/699 [00:00<00:00, 2925.43it/s]


In [9]:
print(len(train_examples))
print(len(dev_examples))

46170
5075


## TRADE Preprocessor 

기존의 GRU 기반의 인코더를 BERT-based Encoder로 바꿀 준비를 합시다.

1. 현재 `_convert_example_to_feature`에서는 `max_seq_length`를 핸들하고 있지 않습니다. `input_id`와 `segment_id`가 `max_seq_length`를 넘어가면 좌측부터 truncate시키는 코드를 삽입하세요.

2. hybrid approach에서 얻은 교훈을 바탕으로 gate class를 3개에서 5개로 늘려봅시다.
    - WoS 데이터셋의 특성상 (boolean type이 많음) 'yes', 'no' gate를 추가했을 때 더 나은 성능을 보입니다.
    - `gating2id`를 수정하세요
    - 이에 따른 `recover_state`를 수정하세요.

3. word dropout을 구현하세요.

In [20]:
class TRADEPreprocessor(DSTPreprocessor):
    def __init__(
        self,
        slot_meta,
        src_tokenizer,
        trg_tokenizer=None,
        ontology=None,
        max_seq_length=512,
        word_drop = 0.1
    ):
        self.slot_meta = slot_meta
        self.src_tokenizer = src_tokenizer
        self.trg_tokenizer = trg_tokenizer if trg_tokenizer else src_tokenizer
        self.ontology = ontology
        self.gating2id =  {"none": 0, "dontcare": 1, "yes": 2, "no": 3, "ptr": 4}         
        self.id2gating = {v: k for k, v in self.gating2id.items()}
        self.max_seq_length = max_seq_length
        self.word_drop = word_drop
        
    def _convert_example_to_feature(self, example):
        dialogue_context = " [SEP] ".join(example.context_turns + example.current_turn)

        input_id = self.src_tokenizer.encode(dialogue_context, add_special_tokens=False)
        max_length = self.max_seq_length - 2
        
        if len(input_id) > max_length:
            gap = len(input_id) - max_length
            input_id = input_id[gap:]
       
        input_id = (
            [self.src_tokenizer.cls_token_id]
            + input_id
            + [self.src_tokenizer.sep_token_id]
        )
        segment_id = [0] * len(input_id)  # bert기반이 되었을 땐 segment 구분이 필요할 수 있음

        target_ids = []
        gating_id = []
        if not example.label:
            example.label = []

        state = convert_state_dict(example.label)
        for slot in self.slot_meta:
            value = state.get(slot, "none")
            target_id = self.trg_tokenizer.encode(value, add_special_tokens=False) + [
                self.trg_tokenizer.sep_token_id
            ]  # [sep]를 추가하여, AutoRegressive시, EOS token으로 사용
            target_ids.append(target_id)
            gating_id.append(self.gating2id.get(value, self.gating2id["ptr"]))
        target_ids = self.pad_ids(target_ids, self.trg_tokenizer.pad_token_id)   # collate_fn에서 처리해주는 것 같은데?
        return OpenVocabDSTFeature(
            example.guid, input_id, segment_id, gating_id, target_ids
        )

    def convert_examples_to_features(self, examples):
        return list(map(self._convert_example_to_feature, examples))

    def recover_state(self, gate_list, gen_list):
        assert len(gate_list) == len(self.slot_meta)
        assert len(gen_list) == len(self.slot_meta)

        recovered = []
        for slot, gate, value in zip(self.slot_meta, gate_list, gen_list):
            if self.id2gating[gate] == "none":
                continue

            if self.id2gating[gate] == "dontcare":
                recovered.append("%s-%s" % (slot, "dontcare"))
                continue

            token_id_list = []
            for id_ in value:
                if id_ in self.trg_tokenizer.all_special_ids:
                    break

                token_id_list.append(id_)
            value = self.trg_tokenizer.decode(token_id_list, skip_special_tokens=True)

            if value == "none":
                continue

            recovered.append("%s-%s" % (slot, value))

        return recovered

    def collate_fn(self, batch):
        guids = [b.guid for b in batch]
        input_ids = torch.LongTensor(
            self.pad_ids([b.input_id for b in batch], self.src_tokenizer.pad_token_id)
        )
        if self.word_drop > 0.0:
            input_ids = []
            for b in batch:
                drop_mask = (
                    np.array(
                        self.src_tokenizer.get_special_tokens_mask(
                            b.input_id, already_has_special_tokens=True)
                        )
                        == 0
                ).astype(int)
                word_drop = np.random.binomial(drop_mask, self.word_drop)
                input_id = [
                    token_id if word_drop[i] == 0 else self.src_tokenizer.unk_token_id
                    for i, token_id in enumerate(b.input_id)
                ]
                input_ids.append(input_id)
        input_ids = torch.LongTensor(
            self.pad_ids(
                [b for b in input_ids],
                self.src_tokenizer.pad_token_id,
                max_length=512,
            )
        )
    
        segment_ids = torch.LongTensor(
            self.pad_ids([b.segment_id for b in batch], self.src_tokenizer.pad_token_id)
        )
        input_masks = input_ids.ne(self.src_tokenizer.pad_token_id)

        gating_ids = torch.LongTensor([b.gating_id for b in batch])
        target_ids = self.pad_id_of_matrix(
            [torch.LongTensor(b.target_ids) for b in batch],
            self.trg_tokenizer.pad_token_id,
        )
        return input_ids, segment_ids, input_masks, gating_ids, target_ids, guids

## Convert_Examples_to_Features 

In [21]:
tokenizer = BertTokenizer.from_pretrained('dsksd/bert-ko-small-minimal')
processor = TRADEPreprocessor(slot_meta, tokenizer, max_seq_length=512)

train_features = processor.convert_examples_to_features(train_examples)
dev_features = processor.convert_examples_to_features(dev_examples)

# 간이 테스트 시 사용
# train_features = processor.convert_examples_to_features(train_examples[:1000])
# dev_features = processor.convert_examples_to_features(dev_examples)

Token indices sequence length is longer than the specified maximum sequence length for this model (537 > 512). Running this sequence through the model will result in indexing errors


In [22]:
print(len(train_features))
print(len(dev_features))

46170
5075


# Model 

1. `GRUEncoder`를 `BertModel`로 교체하세요. 이에 따라 `tie_weight` 함수가 수정되어야 합니다.

In [23]:
class TRADE(nn.Module):
    def __init__(self, config, slot_vocab, slot_meta, pad_idx=0):
        super(TRADE, self).__init__()
        self.slot_meta = slot_meta
        # https://huggingface.co/dsksd/bert-ko-small-minimal/blob/main/config.json
        if config.model_name_or_path:
            self.encoder = BertModel.from_pretrained(config.model_name_or_path)
        else:
            self.encoder = BertModel(config)    
          
        self.decoder = SlotGenerator(
            config.vocab_size,
            config.hidden_size,
            config.hidden_dropout_prob,
            config.n_gate,
            None,
            pad_idx,
        )
        
        # init for only subword embedding
        self.decoder.set_slot_idx(slot_vocab)
        self.tie_weight()

    # encoder와 decoder의 embedding matrix sharing    
    def tie_weight(self):
#         self.decoder.embed = self.encoder.get_input_embeddings()
        self.decoder.embed.weight = self.encoder.embeddings.word_embeddings.weight
        if self.decoder.proj_layer:
            self.decoder.proj_layer.weight = self.encoder.proj_layer.weight

    def forward(self, 
                input_ids,  # p^history를 vocal size로 변환하기 위해 input_ids받음 
                token_type_ids,
                attention_mask=None, 
                max_len=10,  # maximum length를 target_id에서 받고, inference 때는 고정된 constant를 받음
                teacher=None):  # teacher forcing을 위한 GT Sequence

        encoder_outputs, pooled_output = self.encoder(input_ids=input_ids)
        all_point_outputs, all_gate_outputs = self.decoder(
            input_ids, encoder_outputs, pooled_output.unsqueeze(0), attention_mask, max_len, teacher
        )

        return all_point_outputs, all_gate_outputs
    
class SlotGenerator(nn.Module):
    def __init__(
        self, vocab_size, hidden_size, dropout, n_gate, proj_dim=None, pad_idx=0
    ):
        super(SlotGenerator, self).__init__()
        self.pad_idx = pad_idx
        self.vocab_size = vocab_size
        self.embed = nn.Embedding(
            vocab_size, hidden_size, padding_idx=pad_idx
        )  # shared with encoder

        if proj_dim:
            self.proj_layer = nn.Linear(hidden_size, proj_dim, bias=False)
        else:
            self.proj_layer = None
        self.hidden_size = proj_dim if proj_dim else hidden_size

        self.gru = nn.GRU(
            self.hidden_size, self.hidden_size, 1, dropout=dropout, batch_first=True
        )  # unidirectional GRU
        self.n_gate = n_gate
        self.dropout = nn.Dropout(dropout)
        self.w_gen = nn.Linear(self.hidden_size * 3, 1)  # p_gen을 만들 때 3가지를 concat하므로 *3
        self.sigmoid = nn.Sigmoid()
        self.w_gate = nn.Linear(self.hidden_size, n_gate)

    def set_slot_idx(self, slot_vocab_idx):
        whole = []
        max_length = max(map(len, slot_vocab_idx))   # tokenized slot_meta 길이 최댓값: ex)4
        for idx in slot_vocab_idx:
            if len(idx) < max_length:
                gap = max_length - len(idx)
                idx.extend([self.pad_idx] * gap)
            whole.append(idx)
        self.slot_embed_idx = whole  # torch.LongTensor(whole)  ex) [[6728, 6479, 1, 1],..]

    def embedding(self, x):
        x = self.embed(x)
        if self.proj_layer:
            x = self.proj_layer(x)
        return x

    def forward(
        self, input_ids, encoder_output, hidden, input_masks, max_len, teacher=None
    ):
        input_masks = input_masks.ne(1)
        # J, slot_meta : key : [domain, slot] ex> LongTensor([1,2])
        # J,2
        batch_size = encoder_output.size(0)
        slot = torch.LongTensor(self.slot_embed_idx).to(input_ids.device)  ##
#         slot = torch.LongTensor(self.slot_embed_idx, device=input_ids.device)  ##
        slot_e = torch.sum(self.embedding(slot), 1)  # J,d  J * slot_vocab_idx 를 embedding을 거쳐서 sum해주어 J * hidden 이 됨
        J = slot_e.size(0)

        all_point_outputs = torch.zeros(batch_size, J, max_len, self.vocab_size).to(  # max_len: max gen len
            input_ids.device
        )  # Output Tensor (for Placeholder)
        
        # Parallel Decoding
        w = slot_e.repeat(batch_size, 1).unsqueeze(1)
        hidden = hidden.repeat_interleave(J, dim=1)
        encoder_output = encoder_output.repeat_interleave(J, dim=0)
        input_ids = input_ids.repeat_interleave(J, dim=0)
        input_masks = input_masks.repeat_interleave(J, dim=0)
        for k in range(max_len):
            w = self.dropout(w)
            _, hidden = self.gru(w, hidden)  # 1,B,D

            # B,T,D * B,D,1 => B,T,1  (T가 |X_t|인듯, input sequence length)
            attn_e = torch.bmm(encoder_output, hidden.permute(1, 2, 0))  # B,T,1
#             attn_e = attn_e.squeeze(-1).masked_fill(input_masks, -1e9)
            attn_e = attn_e.squeeze(-1).masked_fill(input_masks, -1e4)  # mixed precision 적용위해 수정
            attn_history = F.softmax(attn_e, -1)  # B,T

            if self.proj_layer:
                hidden_proj = torch.matmul(hidden, self.proj_layer.weight)
            else:
                hidden_proj = hidden

            # B,D * D,V => B,V
            attn_v = torch.matmul(
                hidden_proj.squeeze(0), self.embed.weight.transpose(0, 1)
            )  # B,V
            attn_vocab = F.softmax(attn_v, -1)

            # B,1,T * B,T,D => B,1,D
            context = torch.bmm(attn_history.unsqueeze(1), encoder_output)  # B,1,D
            p_gen = self.sigmoid(
                self.w_gen(torch.cat([w, hidden.transpose(0, 1), context], -1))
            )  # B,1
            p_gen = p_gen.squeeze(-1)

            p_context_ptr = torch.zeros_like(attn_vocab).to(input_ids.device)
            p_context_ptr.scatter_add_(1, input_ids, attn_history)  # copy B,V
            p_final = p_gen * attn_vocab + (1 - p_gen) * p_context_ptr  # B,V
            _, w_idx = p_final.max(-1)

            if teacher is not None:
#                 w = self.embedding(teacher[:, :, k]).transpose(0, 1).reshape(batch_size * J, 1, -1)
                w = self.embedding(teacher[:, :, k]).reshape(batch_size * J, 1, -1)
            else:
                w = self.embedding(w_idx).unsqueeze(1)  # B,1,D
            if k == 0:
                gated_logit = self.w_gate(context.squeeze(1))  # B,3
                all_gate_outputs = gated_logit.view(batch_size, J, self.n_gate)
            all_point_outputs[:, :, k, :] = p_final.view(batch_size, J, self.vocab_size)

        return all_point_outputs, all_gate_outputs

# 모델 및 데이터 로더 정의

In [24]:
slot_vocab = []
for slot in slot_meta:
    slot_vocab.append(
        tokenizer.encode(slot.replace('-', ' '),  # 원래는 domain, slot을 영어와 달리 동일 길이로 토큰화되지 않기에 그냥 concat
                         add_special_tokens=False)
    )
    
config = BertConfig.from_pretrained('dsksd/bert-ko-small-minimal')
config.model_name_or_path = 'dsksd/bert-ko-small-minimal'
config.n_gate = len(processor.gating2id)
config.proj_dim = None
model = TRADE(config, slot_vocab, slot_meta)

wandb.watch(model)

[<wandb.wandb_torch.TorchGraph at 0x7fda32b46a10>]

In [25]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

train_data = WOSDataset(train_features)
train_sampler = RandomSampler(train_data)
train_loader = DataLoader(train_data, batch_size=4, sampler=train_sampler, collate_fn=processor.collate_fn,  num_workers=4)

dev_data = WOSDataset(dev_features)
dev_sampler = SequentialSampler(dev_data)
dev_loader = DataLoader(dev_data, batch_size=16, sampler=dev_sampler, collate_fn=processor.collate_fn, num_workers=4)

# Optimizer & Scheduler 선언

In [26]:
n_epochs = 100
no_decay = ["bias", "LayerNorm.weight"]
optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": 0.01,
        },
        {
            "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
            "weight_decay": 0.0,
        },
    ]

t_total = len(train_loader) * n_epochs
# warmup_steps = int(t_total * cfg.warmup_ratio)
# optimizer = AdamW(optimizer_grouped_parameters, lr=3e-5, eps=1e-8)
optimizer = AdamW(optimizer_grouped_parameters, lr=5e-6, eps=1e-8)
# scheduler = get_linear_schedule_with_warmup(
#     optimizer, num_warmup_steps=0.1, num_training_steps=t_total
# )


scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
    optimizer, num_warmup_steps=0.1, num_training_steps=t_total, num_cycles=20
)
# teacher_forcing = 0.5
teacher_forcing = 1
model.to(device)

def masked_cross_entropy_for_value(logits, target, pad_idx=0):  # logits: (B, J, k, V)
    mask = target.ne(pad_idx)
    logits_flat = logits.view(-1, logits.size(-1))   # B*J*k, V (ex) [900, 35000])
    log_probs_flat = torch.log(logits_flat)
    target_flat = target.view(-1, 1)  # 열벡터로 변환  [900, 1]
    losses_flat = -torch.gather(log_probs_flat, dim=1, index=target_flat)  # [900, 1]
    losses = losses_flat.view(*target.size())   # 행벡터로 변환 [900]
    losses = losses * mask.float()   # mask(padding)은 제외
    loss = losses.sum() / (mask.sum().float())
    return loss

loss_fnc_1 = masked_cross_entropy_for_value  # generation
loss_fnc_2 = nn.CrossEntropyLoss()  # gating

In [27]:
def get_lr(scheduler):
    return scheduler.get_last_lr()[0]

## Train

In [29]:
# for checkpoint management
chk_list = []
output_dir = increment_output_dir(wandb.run.name)
PATIENCE = 0 
if not os.path.exists(f"checkpoint/{output_dir}"):
    os.makedirs(f"checkpoint/{output_dir}")  

best_score, best_checkpoint = 0, 0
epoch_miss_labels = defaultdict(list)
for epoch in range(n_epochs):
    batch_loss = []
    model.train()
    for step, batch in enumerate(tqdm(train_loader)):
        input_ids, segment_ids, input_masks, gating_ids, target_ids,_ = [b.to(device) if not isinstance(b, list) else b for b in batch]
        if teacher_forcing > 0.0 and random.random() < teacher_forcing:
            tf = target_ids
        else:
            tf = None
        with amp.autocast():
            all_point_outputs, all_gate_outputs = model(input_ids, segment_ids, input_masks, target_ids.size(-1), tf)  # gt - length (generation)
            loss_1 = loss_fnc_1(all_point_outputs.contiguous(), target_ids.contiguous().view(-1))
            loss_2 = loss_fnc_2(all_gate_outputs.contiguous().view(-1, 5), gating_ids.contiguous().view(-1))
            loss = loss_1 + loss_2
    #         batch_loss.append(loss.item())

            loss.backward()
        
        nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()
        if step % 100 == 0:                   
            print('[%d/%d] [%d/%d] %f' % (epoch+1, n_epochs, step, len(train_loader), loss.item()))
            # -- train 단계에서 Loss, Accuracy 로그 저장
            wandb.log({
                "train/loss": loss.item(),
                "train/gen_loss": loss_1.item(),
                "train/gate_loss": loss_2.item(),
                "train/epoch": epoch+1,
                "train/learning rate": get_lr(scheduler)
            })
            
    predictions = inference(model, dev_loader, processor, device)
#     eval_result,  = _evaluation(predictions, dev_labels, slot_meta)
    eval_result, batch_miss_labels = _evaluation(predictions, dev_labels, slot_meta)
            
    epoch_miss_labels[epoch].extend(batch_miss_labels)   
    
    wandb.log({
            "eval/loss": loss.item(),
            "Joint Goal Accuracy": eval_result['joint_goal_accuracy'],
            "Turn Slot_Accuracy": eval_result['turn_slot_accuracy'],
            "Turn Slot F1": eval_result['turn_slot_f1']
            })
      
    for k, v in eval_result.items():
        print(f"{k}: {v}")
        
    if best_score < eval_result['joint_goal_accuracy']:
        print("Update Best checkpoint!")
        best_score = eval_result['joint_goal_accuracy']
        best_checkpoint = epoch
        if not os.path.isdir('./checkpoint'):
            os.makedirs('./checkpoint')
        output_path = f"checkpoint/{output_dir}/{epoch}_{step}_{best_score}.pth"
        chk_list.append(output_path)

        torch.save({
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'loss': loss,
                    }, output_path)
        PATIENCE = 0
    else:
        PATIENCE += 1
        if PATIENCE == 10:
            break

#     torch.save(model.state_dict(), f"{args.model_dir}/model-{epoch}.bin")
#     torch.save(model.state_dict(), f"checkpoint/trage/{output_dir}/model-{epoch}.bin")
# print(f"Best checkpoint: {args.model_dir}/model-{best_checkpoint}.bin")    

  0%|          | 2/11543 [00:00<1:53:12,  1.70it/s]

[1/100] [0/11543] 0.022652


  1%|          | 101/11543 [00:19<33:59,  5.61it/s]

[1/100] [100/11543] 0.009611


  2%|▏         | 202/11543 [00:37<31:22,  6.03it/s]

[1/100] [200/11543] 0.039694


  3%|▎         | 301/11543 [00:54<33:10,  5.65it/s]

[1/100] [300/11543] 0.009463


  3%|▎         | 401/11543 [01:12<34:13,  5.43it/s]

[1/100] [400/11543] 0.021971


  4%|▍         | 502/11543 [01:31<32:24,  5.68it/s]

[1/100] [500/11543] 0.042811


  5%|▌         | 602/11543 [01:49<31:09,  5.85it/s]

[1/100] [600/11543] 0.012937


  6%|▌         | 702/11543 [02:07<32:57,  5.48it/s]

[1/100] [700/11543] 0.036543


  7%|▋         | 802/11543 [02:25<28:21,  6.31it/s]

[1/100] [800/11543] 0.105237


  8%|▊         | 901/11543 [02:46<38:43,  4.58it/s]

[1/100] [900/11543] 0.014273


  9%|▊         | 1001/11543 [03:05<39:23,  4.46it/s]

[1/100] [1000/11543] 0.052122


 10%|▉         | 1102/11543 [03:23<29:42,  5.86it/s]

[1/100] [1100/11543] 0.062076


 10%|█         | 1202/11543 [03:41<31:40,  5.44it/s]

[1/100] [1200/11543] 0.045864


 11%|█▏        | 1302/11543 [03:58<27:47,  6.14it/s]

[1/100] [1300/11543] 0.071168


 12%|█▏        | 1402/11543 [04:16<29:34,  5.71it/s]

[1/100] [1400/11543] 0.033931


 13%|█▎        | 1502/11543 [04:33<27:54,  6.00it/s]

[1/100] [1500/11543] 0.014039


 14%|█▍        | 1602/11543 [04:52<33:21,  4.97it/s]

[1/100] [1600/11543] 0.077670


 15%|█▍        | 1702/11543 [05:11<33:29,  4.90it/s]

[1/100] [1700/11543] 0.031041


 16%|█▌        | 1802/11543 [05:30<32:46,  4.95it/s]

[1/100] [1800/11543] 0.039078


 16%|█▋        | 1902/11543 [05:50<27:55,  5.75it/s]

[1/100] [1900/11543] 0.027767


 17%|█▋        | 2002/11543 [06:08<31:02,  5.12it/s]

[1/100] [2000/11543] 0.047656


 18%|█▊        | 2102/11543 [06:26<25:31,  6.16it/s]

[1/100] [2100/11543] 0.020192


 19%|█▉        | 2202/11543 [06:43<29:33,  5.27it/s]

[1/100] [2200/11543] 0.078884


 20%|█▉        | 2301/11543 [07:01<28:34,  5.39it/s]

[1/100] [2300/11543] 0.193545


 21%|██        | 2402/11543 [07:20<27:21,  5.57it/s]

[1/100] [2400/11543] 0.001603


 22%|██▏       | 2502/11543 [07:38<27:49,  5.41it/s]

[1/100] [2500/11543] 0.014340


 23%|██▎       | 2601/11543 [07:57<36:08,  4.12it/s]

[1/100] [2600/11543] 0.034791


 23%|██▎       | 2702/11543 [08:16<24:25,  6.03it/s]

[1/100] [2700/11543] 0.018635


 24%|██▍       | 2802/11543 [08:34<25:31,  5.71it/s]

[1/100] [2800/11543] 0.071383


 25%|██▌       | 2901/11543 [08:52<26:23,  5.46it/s]

[1/100] [2900/11543] 0.032785


 26%|██▌       | 3002/11543 [09:11<27:31,  5.17it/s]

[1/100] [3000/11543] 0.028624


 27%|██▋       | 3102/11543 [09:29<24:18,  5.79it/s]

[1/100] [3100/11543] 0.046108


 28%|██▊       | 3201/11543 [09:50<34:18,  4.05it/s]

[1/100] [3200/11543] 0.019663


 29%|██▊       | 3302/11543 [10:08<23:05,  5.95it/s]

[1/100] [3300/11543] 0.003926


 29%|██▉       | 3402/11543 [10:26<25:26,  5.33it/s]

[1/100] [3400/11543] 0.026515


 30%|███       | 3501/11543 [10:43<25:04,  5.35it/s]

[1/100] [3500/11543] 0.064183


 31%|███       | 3601/11543 [11:01<24:45,  5.35it/s]

[1/100] [3600/11543] 0.008887


 32%|███▏      | 3702/11543 [11:20<28:07,  4.65it/s]

[1/100] [3700/11543] 0.058059


 33%|███▎      | 3801/11543 [11:41<26:14,  4.92it/s]

[1/100] [3800/11543] 0.001081


 34%|███▍      | 3902/11543 [12:00<22:33,  5.65it/s]

[1/100] [3900/11543] 0.027919


 35%|███▍      | 4001/11543 [12:19<31:06,  4.04it/s]

[1/100] [4000/11543] 0.050544


 36%|███▌      | 4101/11543 [12:40<24:13,  5.12it/s]

[1/100] [4100/11543] 0.042712


 36%|███▋      | 4202/11543 [12:59<21:51,  5.60it/s]

[1/100] [4200/11543] 0.026943


 37%|███▋      | 4301/11543 [13:18<25:32,  4.73it/s]

[1/100] [4300/11543] 0.048204


 38%|███▊      | 4401/11543 [13:40<23:31,  5.06it/s]

[1/100] [4400/11543] 0.009611


 39%|███▉      | 4501/11543 [13:59<22:50,  5.14it/s]

[1/100] [4500/11543] 0.026221


 40%|███▉      | 4602/11543 [14:18<19:36,  5.90it/s]

[1/100] [4600/11543] 0.005858


 41%|████      | 4702/11543 [14:35<21:15,  5.36it/s]

[1/100] [4700/11543] 0.081244


 42%|████▏     | 4801/11543 [14:53<19:17,  5.82it/s]

[1/100] [4800/11543] 0.020873


 42%|████▏     | 4901/11543 [15:11<18:28,  5.99it/s]

[1/100] [4900/11543] 0.053263


 43%|████▎     | 5001/11543 [15:29<22:49,  4.78it/s]

[1/100] [5000/11543] 0.010574


 44%|████▍     | 5102/11543 [15:47<18:03,  5.94it/s]

[1/100] [5100/11543] 0.000519


 45%|████▌     | 5202/11543 [16:05<16:54,  6.25it/s]

[1/100] [5200/11543] 0.047248


 46%|████▌     | 5301/11543 [16:23<20:46,  5.01it/s]

[1/100] [5300/11543] 0.023456


 47%|████▋     | 5402/11543 [16:41<18:24,  5.56it/s]

[1/100] [5400/11543] 0.134960


 48%|████▊     | 5502/11543 [16:59<17:43,  5.68it/s]

[1/100] [5500/11543] 0.011380


 49%|████▊     | 5601/11543 [17:17<18:17,  5.41it/s]

[1/100] [5600/11543] 0.049557


 49%|████▉     | 5702/11543 [17:36<17:20,  5.62it/s]

[1/100] [5700/11543] 0.013842


 50%|█████     | 5802/11543 [17:53<16:43,  5.72it/s]

[1/100] [5800/11543] 0.010422


 51%|█████     | 5901/11543 [18:11<18:25,  5.10it/s]

[1/100] [5900/11543] 0.034677


 52%|█████▏    | 6001/11543 [18:30<23:13,  3.98it/s]

[1/100] [6000/11543] 0.056668


 53%|█████▎    | 6101/11543 [18:51<17:45,  5.11it/s]

[1/100] [6100/11543] 0.075240


 54%|█████▎    | 6202/11543 [19:10<16:04,  5.54it/s]

[1/100] [6200/11543] 0.024719


 55%|█████▍    | 6302/11543 [19:27<15:03,  5.80it/s]

[1/100] [6300/11543] 0.072280


 55%|█████▌    | 6402/11543 [19:45<14:03,  6.10it/s]

[1/100] [6400/11543] 0.036398


 56%|█████▋    | 6501/11543 [20:03<14:39,  5.73it/s]

[1/100] [6500/11543] 0.049111


 57%|█████▋    | 6602/11543 [20:21<15:38,  5.26it/s]

[1/100] [6600/11543] 0.039290


 58%|█████▊    | 6701/11543 [20:43<18:45,  4.30it/s]

[1/100] [6700/11543] 0.010772


 59%|█████▉    | 6802/11543 [21:02<13:05,  6.03it/s]

[1/100] [6800/11543] 0.008300


 60%|█████▉    | 6902/11543 [21:20<14:42,  5.26it/s]

[1/100] [6900/11543] 0.015298


 61%|██████    | 7002/11543 [21:39<16:12,  4.67it/s]

[1/100] [7000/11543] 0.017047


 62%|██████▏   | 7102/11543 [21:57<13:58,  5.30it/s]

[1/100] [7100/11543] 0.093169


 62%|██████▏   | 7202/11543 [22:15<12:36,  5.74it/s]

[1/100] [7200/11543] nan


 63%|██████▎   | 7302/11543 [22:33<12:27,  5.67it/s]

[1/100] [7300/11543] nan


 64%|██████▍   | 7402/11543 [22:53<14:08,  4.88it/s]

[1/100] [7400/11543] nan


 65%|██████▍   | 7502/11543 [23:14<13:11,  5.10it/s]

[1/100] [7500/11543] nan


 66%|██████▌   | 7602/11543 [23:32<11:52,  5.53it/s]

[1/100] [7600/11543] nan


 67%|██████▋   | 7701/11543 [23:49<10:55,  5.86it/s]

[1/100] [7700/11543] nan


 68%|██████▊   | 7802/11543 [24:10<11:12,  5.56it/s]

[1/100] [7800/11543] nan


 68%|██████▊   | 7902/11543 [24:28<12:08,  5.00it/s]

[1/100] [7900/11543] nan


 69%|██████▉   | 8002/11543 [24:47<10:46,  5.47it/s]

[1/100] [8000/11543] nan


 70%|██████▉   | 8028/11543 [24:53<10:53,  5.38it/s]


KeyboardInterrupt: 

## Inference 

In [29]:
eval_data = json.load(open(f"/opt/ml/repo/taepd/input/data/eval_dataset/eval_dials.json", "r"))

eval_examples = get_examples_from_dialogues(
    eval_data, user_first=False, dialogue_level=False
)

# Extracting Featrues
eval_features = processor.convert_examples_to_features(eval_examples)
eval_data = WOSDataset(eval_features)
eval_sampler = SequentialSampler(eval_data)
eval_loader = DataLoader(
    eval_data,
    batch_size=8,
    sampler=eval_sampler,
    collate_fn=processor.collate_fn,
)

100%|██████████| 2000/2000 [00:00<00:00, 16252.49it/s]


In [107]:
predictions = inference(model, eval_loader, processor, device)

100%|██████████| 1847/1847 [05:23<00:00,  5.71it/s]


In [108]:
json.dump(predictions, open('predictions.csv', 'w'), indent=2, ensure_ascii=False)

### eda용 missing label exporting

In [109]:
with open('miss_labels.json', 'w') as outfile:
    json.dump(epoch_miss_labels, outfile)

In [28]:
# config = BertConfig.from_pretrained('dsksd/bert-ko-small-minimal')
# config.model_name_or_path = 'dsksd/bert-ko-small-minimal'
# config.n_gate = len(processor.gating2id)
# config.proj_dim = None
# model1 = TRADE(config, slot_vocab, slot_meta)
PATH = 'checkpoint/atomic-eon-15/10_11542_0.7903448275862069.pth'

checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
model.to(device)

optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']


In [76]:
predictions = inference(model1, eval_loader, processor, device)

100%|██████████| 1847/1847 [05:38<00:00,  5.46it/s]


### 꿔 = 바로우 case 후처리

In [110]:
import re

for k, v in predictions.items():
    tmp = []
    for e in v:          
        tmp.append(re.sub('\s(?=[\=\(\)\&])|(?<=[\=\(\)\&])\s', "", e))
    predictions[k] = tmp



In [111]:
json.dump(predictions, open('predictions.csv', 'w'), indent=2, ensure_ascii=False)

In [104]:
predictions

{'wild-bonus-5601:식당_택시_12-0': ['식당-종류-중식당', '식당-주차 가능-yes', '식당-지역-서울 북쪽'],
 'wild-bonus-5601:식당_택시_12-1': ['식당-가격대-dontcare',
  '식당-종류-중식당',
  '식당-주류 판매-yes',
  '식당-주차 가능-yes',
  '식당-지역-서울 북쪽'],
 'wild-bonus-5601:식당_택시_12-2': ['식당-가격대-dontcare',
  '식당-종류-중식당',
  '식당-주류 판매-yes',
  '식당-주차 가능-yes',
  '식당-지역-서울 북쪽'],
 'wild-bonus-5601:식당_택시_12-3': ['식당-가격대-dontcare',
  '식당-예약 명수-9',
  '식당-예약 시간-02:20',
  '식당-예약 요일-월요일',
  '식당-이름-엄중식',
  '식당-종류-중식당',
  '식당-주류 판매-yes',
  '식당-주차 가능-yes',
  '식당-지역-서울 북쪽'],
 'wild-bonus-5601:식당_택시_12-4': ['식당-가격대-dontcare',
  '식당-예약 명수-9',
  '식당-예약 시간-02:20',
  '식당-예약 요일-월요일',
  '식당-이름-엄중식',
  '식당-종류-중식당',
  '식당-주류 판매-yes',
  '식당-주차 가능-yes',
  '식당-지역-서울 북쪽'],
 'wild-bonus-5601:식당_택시_12-5': ['식당-가격대-dontcare',
  '식당-예약 명수-9',
  '식당-예약 시간-03:20',
  '식당-예약 요일-월요일',
  '식당-이름-엄중식',
  '식당-종류-중식당',
  '식당-주류 판매-yes',
  '식당-주차 가능-yes',
  '식당-지역-서울 북쪽',
  '택시-도착지-서울 역사 박물관',
  '택시-출발 시간-12:45',
  '택시-출발지-엄중식'],
 'wild-bonus-5601:식당_택시_12-6': ['식당-가격대-dontcare',
  '식당-예