In [None]:
#pip install prettyprinter

In [None]:
#pip3 install ruamel.yaml  # 안되면 conda install raumel.yaml

In [1]:
import sys
sys.path.append('..')

In [2]:
import argparse
import json
import os
import random
import pickle
import numpy as np
from tqdm import tqdm
from pathlib import Path
import glob
import re
import time

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from torch.cuda.amp import GradScaler, autocast
# # 이상 검출을 위한 코드
# torch.autograd.set_detect_anomaly(True)

from transformers import (BertTokenizer, 
                          BertModel,  
                          BertConfig,
                          AdamW, 
                          get_linear_schedule_with_warmup
                          )
from transformers.modeling_bert import BertOnlyMLMHead



from data_utils import (load_dataset, 
                        get_examples_from_dialogues, 
                        convert_state_dict, 
                        DSTInputExample, 
                        OpenVocabDSTFeature, 
                        DSTPreprocessor, 
                        WOSDataset,
                        custom_to_mask, 
                        custom_get_examples_from_dialogues,
                        set_seed, 
                        YamlConfigManager,
                        )

from inference import inference
from evaluation import _evaluation
from model import TRADE, masked_cross_entropy_for_value
from preprocessor import TRADEPreprocessor

from prettyprinter import cpprint
import wandb

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'Device Name : {device}')

# torch.cuda.empty_cache()

Device Name : cuda


In [4]:
def increment_output_dir(output_path, exist_ok=False):
  path = Path(output_path)
  if (path.exists() and exist_ok) or (not path.exists()):
    return str(path)
  else:
    dirs = glob.glob(f"{path}*")
    matches = [re.search(rf"%s(\d+)" %path.stem, d) for d in dirs]
    i = [int(m.groups()[0]) for m in matches if m]
    n = max(i) + 1 if i else 2
    return f"{path}{n}"

In [5]:
cfg = YamlConfigManager('../config.yml', 'base').values
cpprint(cfg)

easydict.EasyDict({
    'data_dir': '../input/data/train_dataset',
    'model_dir': 'results',
    'train_batch_size': 4,
    'eval_batch_size': 8,
    'learning_rate': 3e-05,
    'adam_epsilon': 1e-08,
    'max_grad_norm': 1.0,
    'num_train_epochs': 30,
    'warmup_ratio': 0.0,
    'random_seed': 42,
    'n_gate': 5,
    'teacher_forcing_ratio': 0.5,
    'model_name_or_path': 'dsksd/bert-ko-small-minimal',
    'proj_dim': 'None',
    'tag': ['trade'],
    'use_kfold': False,
    'num_k': 0,
    'val_ratio': 0.1,
    'scheduler': 'Linear',
    'mask': True
})


In [6]:
import wandb
# !wandb login  # run once

In [7]:
# --wandb initialize with configuration
# wandb.init(project="TRADE")
wandb.init(project='DST', tags=cfg.tag, config=cfg)
config = wandb.config

[34m[1mwandb[0m: Currently logged in as: [33mtaepd[0m (use `wandb login --relogin` to force relogin)
[34m[1mwandb[0m: wandb version 0.10.30 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


In [8]:
# Get current learning rate
def get_lr(scheduler):
    return scheduler.get_last_lr()[0]

In [9]:
# random seed 고정
set_seed(cfg.random_seed)

# Data Loading
train_data_file = "/opt/ml/repo/taepd/input/data/train_dataset/train_dials.json"
slot_meta = json.load(open("/opt/ml/repo/taepd/input/data/train_dataset/slot_meta.json"))
ontology = json.load(open("/opt/ml/repo/taepd/input/data/train_dataset/ontology.json"))
train_data, dev_data, dev_labels = load_dataset(train_data_file, cfg.val_ratio)

# dialogue_level=False : SUMBT와 다르게 dialogue context level로 input하므로
# train_examples = get_examples_from_dialogues(train_data,
#                                              user_first=False,
#                                              dialogue_level=False)
# dev_examples = get_examples_from_dialogues(dev_data,
#                                            user_first=False,
#                                            dialogue_level=False)
train_examples = custom_get_examples_from_dialogues(
    train_data, user_first=False, dialogue_level=False
)
dev_examples = custom_get_examples_from_dialogues(
    dev_data, user_first=False, dialogue_level=False
)


100%|██████████| 6301/6301 [00:00<00:00, 6538.49it/s] 
100%|██████████| 699/699 [00:00<00:00, 10622.53it/s]


In [10]:
print(len(train_examples))
print(len(dev_examples))

46170
5075


In [11]:
train_examples[0]

DSTInputExample(guid='snowy-hat-8324:관광_식당_11-0', context_turns=[], current_turn=['', ' # ', '서울 중앙에 있는 박물관을 찾아주세요', ' * '], label=['관광-종류-박물관', '관광-지역-서울 중앙'])

## TRADE Preprocessor

BERT Encoder가 적용된 TRADE의 preprocessor입니다.

In [12]:
# class TRADEPreprocessor(DSTPreprocessor):
#     def __init__(
#         self,
#         slot_meta,
#         src_tokenizer,
#         trg_tokenizer=None,
#         ontology=None,
#         max_seq_length=512,
#     ):
#         self.slot_meta = slot_meta
#         self.src_tokenizer = src_tokenizer
#         self.trg_tokenizer = trg_tokenizer if trg_tokenizer else src_tokenizer
#         self.ontology = ontology
#         self.gating2id = {"none": 0, "dontcare": 1, "yes": 2, "no": 3, "ptr": 4}
#         self.id2gating = {v: k for k, v in self.gating2id.items()}
#         self.max_seq_length = max_seq_length

#     def _convert_example_to_feature(self, example):
#         dialogue_context = " [SEP] ".join(example.context_turns + example.current_turn)

#         input_id = self.src_tokenizer.encode(dialogue_context, add_special_tokens=False)
#         max_length = self.max_seq_length - 2
#         if len(input_id) > max_length:
#             gap = len(input_id) - max_length
#             input_id = input_id[gap:]

#         input_id = (
#             [self.src_tokenizer.cls_token_id]
#             + input_id
#             + [self.src_tokenizer.sep_token_id]
#         )
#         segment_id = [0] * len(input_id)

#         target_ids = []
#         gating_id = []
#         if not example.label:
#             example.label = []

#         state = convert_state_dict(example.label)
#         for slot in self.slot_meta:
#             value = state.get(slot, "none")
#             target_id = self.trg_tokenizer.encode(value, add_special_tokens=False) + [
#                 self.trg_tokenizer.sep_token_id
#             ]
#             target_ids.append(target_id)
#             gating_id.append(self.gating2id.get(value, self.gating2id["ptr"]))
#         target_ids = self.pad_ids(target_ids, self.trg_tokenizer.pad_token_id)
#         return OpenVocabDSTFeature(
#             example.guid, input_id, segment_id, gating_id, target_ids
#         )

#     def convert_examples_to_features(self, examples):
#         return list(map(self._convert_example_to_feature, examples))

#     def recover_state(self, gate_list, gen_list):
#         assert len(gate_list) == len(self.slot_meta)
#         assert len(gen_list) == len(self.slot_meta)

#         recovered = []
#         for slot, gate, value in zip(self.slot_meta, gate_list, gen_list):
#             if self.id2gating[gate] == "none":
#                 continue

#             if self.id2gating[gate] in ["dontcare", "yes", "no"]:
#                 recovered.append("%s-%s" % (slot, self.id2gating[gate]))
#                 continue

#             token_id_list = []
#             for id_ in value:
#                 if id_ in self.trg_tokenizer.all_special_ids:
#                     break

#                 token_id_list.append(id_)
#             value = self.trg_tokenizer.decode(token_id_list, skip_special_tokens=True)

#             if value == "none":
#                 continue

#             recovered.append("%s-%s" % (slot, value))
#         return recovered

#     def collate_fn(self, batch):
#         guids = [b.guid for b in batch]
#         input_ids = torch.LongTensor(
#             self.pad_ids([b.input_id for b in batch], self.src_tokenizer.pad_token_id)
#         )
#         segment_ids = torch.LongTensor(
#             self.pad_ids([b.segment_id for b in batch], self.src_tokenizer.pad_token_id)
#         )
#         input_masks = input_ids.ne(self.src_tokenizer.pad_token_id)

#         gating_ids = torch.LongTensor([b.gating_id for b in batch])
#         target_ids = self.pad_id_of_matrix(
#             [torch.LongTensor(b.target_ids) for b in batch],
#             self.trg_tokenizer.pad_token_id,
#         )
#         return input_ids, segment_ids, input_masks, gating_ids, target_ids, guids

## Convert_Examples_to_Features

In [13]:
# tokenizer = BertTokenizer.from_pretrained('dsksd/bert-ko-small-minimal')
# processor = TRADEPreprocessor(slot_meta, tokenizer, max_seq_length=512)

# train_features = processor.convert_examples_to_features(train_examples)
# dev_features = processor.convert_examples_to_features(dev_examples)

In [13]:
# Define Preprocessor
tokenizer = BertTokenizer.from_pretrained(cfg.model_name_or_path)

# Dealing with long texts The maximum sequence length of BERT is 512.
processor = TRADEPreprocessor(slot_meta, tokenizer, max_seq_length=512, n_gate=cfg.n_gate)

In [14]:
# # # Extracting Featrues
# cpprint('Extracting Features...')
# train_features = processor.sep_custom_convert_examples_to_features(train_examples)
# dev_features = processor.sep_custom_convert_examples_to_features(dev_examples)

'Extracting Features...'


Token indices sequence length is longer than the specified maximum sequence length for this model (539 > 512). Running this sequence through the model will result in indexing errors


KeyboardInterrupt: 

In [14]:
# 전체 train data InputFeatur 저장
# with open('custom_train_features.pickle', 'wb') as f:
#     pickle.dump(train_features, f)
# with open('custom_dev_features.pickle', 'wb') as f:
#     pickle.dump(dev_features, f)

In [14]:
# 저장된 파일 사용
with open('custom_train_features.pickle', 'rb') as f:
    train_features = pickle.load(f)
with open('custom_dev_features.pickle', 'rb') as f:
    dev_features = pickle.load(f)

In [16]:
print(len(train_features))
print(len(dev_features))

46170
5075


# TRADE

## Model

In [15]:
class TRADE(nn.Module):
    def __init__(self, config, slot_vocab, slot_meta, pad_idx=0):
        super(TRADE, self).__init__()
        self.config = config
        self.slot_meta = slot_meta
        if config.model_name_or_path:
            self.encoder = BertModel.from_pretrained(config.model_name_or_path)
        else:
            self.encoder = BertModel(config)
            
        self.decoder = SlotGenerator(
            config.vocab_size,
            config.hidden_size,
            config.hidden_dropout_prob,
            config.n_gate,
            None,
            pad_idx,
        )
        
        self.decoder.set_slot_idx(slot_vocab)
        
        self.mlm_head = BertOnlyMLMHead(config)
        self.tie_weight()

    def tie_weight(self):
        self.decoder.embed.weight = self.encoder.embeddings.word_embeddings.weight

    def forward(self, input_ids, token_type_ids, attention_mask=None, max_len=10, teacher=None):

        encoder_outputs, pooled_output = self.encoder(input_ids=input_ids)
        all_point_outputs, all_gate_outputs = self.decoder(
            input_ids, encoder_outputs, pooled_output.unsqueeze(0), attention_mask, max_len, teacher
        )

        return all_point_outputs, all_gate_outputs
    
    @staticmethod
    def mask_tokens(inputs, tokenizer, config, mlm_probability=0.15):
        """ Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original. """
        labels = inputs.clone()
        # We sample a few tokens in each sequence for masked-LM training (with probability args.mlm_probability defaults to 0.15 in Bert/RoBERTa)
        probability_matrix = torch.full(labels.shape, mlm_probability).to(device)
        #special_tokens_mask = [tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()]

        probability_matrix.masked_fill_(torch.eq(labels, 0), value=0.0)

        masked_indices = torch.bernoulli(probability_matrix).bool()
        labels[~masked_indices] = -100  # We only compute loss on masked tokens

        # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
        indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).to(device=device, dtype=torch.bool) & masked_indices
        inputs[indices_replaced] = tokenizer.convert_tokens_to_ids(["[MASK]"])[0]

        # 10% of the time, we replace masked input tokens with random word
        indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).to(device=device, dtype=torch.bool) & masked_indices & ~indices_replaced
        random_words = torch.randint(config.vocab_size, labels.shape, device=device, dtype=torch.long)
        inputs[indices_random] = random_words[indices_random].to(device)

        # The rest of the time (10% of the time) we keep the masked input tokens unchanged
        return inputs, labels
    
    def forward_pretrain(self, input_ids, tokenizer):
        input_ids, labels = self.mask_tokens(input_ids, tokenizer, self.config)
        encoder_outputs, _ = self.encoder(input_ids=input_ids)
        mlm_logits = self.mlm_head(encoder_outputs)
        
        return mlm_logits, labels
    
class SlotGenerator(nn.Module):
    def __init__(
        self, vocab_size, hidden_size, dropout, n_gate, proj_dim=None, pad_idx=0
    ):
        super(SlotGenerator, self).__init__()
        self.pad_idx = pad_idx
        self.vocab_size = vocab_size
        self.embed = nn.Embedding(
            vocab_size, hidden_size, padding_idx=pad_idx
        )  # shared with encoder

        if proj_dim:
            self.proj_layer = nn.Linear(hidden_size, proj_dim, bias=False)
        else:
            self.proj_layer = None
        self.hidden_size = proj_dim if proj_dim else hidden_size

        self.gru = nn.GRU(
            self.hidden_size, self.hidden_size, 1, dropout=dropout, batch_first=True
        )
        self.n_gate = n_gate
        self.dropout = nn.Dropout(dropout)
        self.w_gen = nn.Linear(self.hidden_size * 3, 1)
        self.sigmoid = nn.Sigmoid()
        self.w_gate = nn.Linear(self.hidden_size, n_gate)

    def set_slot_idx(self, slot_vocab_idx):
        whole = []
        max_length = max(map(len, slot_vocab_idx))
        for idx in slot_vocab_idx:
            if len(idx) < max_length:
                gap = max_length - len(idx)
                idx.extend([self.pad_idx] * gap)
            whole.append(idx)
        self.slot_embed_idx = whole  # torch.LongTensor(whole)

    def embedding(self, x):
        x = self.embed(x)
        if self.proj_layer:
            x = self.proj_layer(x)
        return x

    def forward(
        self, input_ids, encoder_output, hidden, input_masks, max_len, teacher=None
    ):
        input_masks = input_masks.ne(1)
        # J, slot_meta : key : [domain, slot] ex> LongTensor([1,2])
        # J,2
        batch_size = encoder_output.size(0)
        slot = torch.LongTensor(self.slot_embed_idx).to(input_ids.device)  ##
        slot_e = torch.sum(self.embedding(slot), 1)  # J,d
        J = slot_e.size(0)

        all_point_outputs = torch.zeros(batch_size, J, max_len, self.vocab_size).to(
            input_ids.device
        )
        
        # Parallel Decoding
        w = slot_e.repeat(batch_size, 1).unsqueeze(1)
        hidden = hidden.repeat_interleave(J, dim=1)
        encoder_output = encoder_output.repeat_interleave(J, dim=0)
        input_ids = input_ids.repeat_interleave(J, dim=0)
        input_masks = input_masks.repeat_interleave(J, dim=0)
        for k in range(max_len):
            w = self.dropout(w)
            _, hidden = self.gru(w, hidden)  # 1,B,D

            # B,T,D * B,D,1 => B,T
            attn_e = torch.bmm(encoder_output, hidden.permute(1, 2, 0))  # B,T,1
#             attn_e = attn_e.squeeze(-1).masked_fill(input_masks, -1e9)
            attn_e = attn_e.squeeze(-1).masked_fill(input_masks, -1e4)
            attn_history = F.softmax(attn_e, -1)  # B,T

            if self.proj_layer:
                hidden_proj = torch.matmul(hidden, self.proj_layer.weight)
            else:
                hidden_proj = hidden

            # B,D * D,V => B,V
            attn_v = torch.matmul(
                hidden_proj.squeeze(0), self.embed.weight.transpose(0, 1)
            )  # B,V
            attn_vocab = F.softmax(attn_v, -1)

            # B,1,T * B,T,D => B,1,D
            context = torch.bmm(attn_history.unsqueeze(1), encoder_output)  # B,1,D
            p_gen = self.sigmoid(
                self.w_gen(torch.cat([w, hidden.transpose(0, 1), context], -1))
            )  # B,1
            p_gen = p_gen.squeeze(-1)

            p_context_ptr = torch.zeros_like(attn_vocab).to(input_ids.device)
            p_context_ptr.scatter_add_(1, input_ids, attn_history)  # copy B,V
            p_final = p_gen * attn_vocab + (1 - p_gen) * p_context_ptr  # B,V
            _, w_idx = p_final.max(-1)

            if teacher is not None:
                w = self.embedding(teacher[:, :, k]).transpose(0, 1).reshape(batch_size * J, 1, -1)
            else:
                w = self.embedding(w_idx).unsqueeze(1)  # B,1,D
            if k == 0:
                gated_logit = self.w_gate(context.squeeze(1))  # B,3
                all_gate_outputs = gated_logit.view(batch_size, J, self.n_gate)
            all_point_outputs[:, :, k, :] = p_final.view(batch_size, J, self.vocab_size)

        return all_point_outputs, all_gate_outputs

## 모델 및 데이터 로더 정의

In [16]:
slot_vocab = []
for slot in slot_meta:
    slot_vocab.append(
        tokenizer.encode(slot.replace('-', ' '),
                         add_special_tokens=False)
    )
    
config = BertConfig.from_pretrained('dsksd/bert-ko-small-minimal')
config.model_name_or_path = 'dsksd/bert-ko-small-minimal'
# config.n_gate = len(processor.gating2id)
config.n_gate = cfg.n_gate
config.proj_dim = None

model = TRADE(config, slot_vocab, slot_meta)

model.to(device)
print("Model is initialized")

  "num_layers={}".format(dropout, num_layers))


Model is initialized


In [17]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

train_data = WOSDataset(train_features)
train_sampler = RandomSampler(train_data)
train_loader = DataLoader(train_data, 
                          batch_size=16, 
                          sampler=train_sampler, 
                          collate_fn=processor.collate_fn,
                          num_workers=4,
                          pin_memory=True,
                          )

print("# train:", len(train_data))

dev_data = WOSDataset(dev_features)
dev_sampler = SequentialSampler(dev_data)
dev_loader = DataLoader(dev_data, 
                        batch_size=8, 
                        sampler=dev_sampler, 
                        collate_fn=processor.collate_fn,
                        num_workers=4,
                        pin_memory=True,
                        )

print("# dev:", len(dev_data))

# train: 46170
# dev: 5075


## Optimizer & Scheduler 선언

In [18]:
n_epochs = cfg.num_train_epochs

no_decay = ["bias", "LayerNorm.weight"]
optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": 0.01,
        },
        {
            "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
            "weight_decay": 0.0,
        },
    ]

t_total = len(train_loader) * n_epochs
optimizer = AdamW(optimizer_grouped_parameters, lr=cfg.learning_rate, eps=cfg.adam_epsilon)
warmup_steps = int(t_total * cfg.warmup_ratio)
scheduler = get_linear_schedule_with_warmup(
    optimizer, num_warmup_steps=warmup_steps, num_training_steps=t_total
)
teacher_forcing = cfg.teacher_forcing_ratio
# model.to(device)

def masked_cross_entropy_for_value(logits, target, pad_idx=0):
    mask = target.ne(pad_idx)
    logits_flat = logits.view(-1, logits.size(-1))
    log_probs_flat = torch.log(logits_flat)
    target_flat = target.view(-1, 1)
    losses_flat = -torch.gather(log_probs_flat, dim=1, index=target_flat)
    losses = losses_flat.view(*target.size())
    losses = losses * mask.float()
    loss = losses.sum() / (mask.sum().float())
    return loss

loss_fnc_1 = masked_cross_entropy_for_value  # generation
loss_fnc_2 = nn.CrossEntropyLoss()  # gating
loss_fnc_pretrain = nn.CrossEntropyLoss()  # MLM pretrain

In [19]:
# 모델 저장될 파일 위치 생성
if not os.path.exists(f"{cfg.model_dir}/{wandb.run.name}"):
#     os.mkdir(f"{cfg.model_dir}")
    os.mkdir(f"{cfg.model_dir}/{wandb.run.name}")

In [20]:
json.dump(
    vars(cfg),
    open(f"{cfg.model_dir}/{wandb.run.name}/exp_config.json", "w"),
    indent=2,
    ensure_ascii=False,
)
json.dump(
    slot_meta,
    open(f"{cfg.model_dir}/slot_meta.json", "w"),
    indent=2,
    ensure_ascii=False,
)

## Pretraining

In [21]:
MLM_PRE = True

scaler = GradScaler()
n_pretrain_epochs = 3

def mlm_pretrain(loader, n_epochs):
    model.train()
    for step, batch in enumerate(tqdm(loader)):
        input_ids, segment_ids, input_masks, gating_ids, target_ids, guids = [b.to(device) if not isinstance(b, list) else b for b in batch]
        
        with autocast(): # 밑에 해당하는 코드를 자동으로 mixed precision으로 변환시켜서 실행
            logits, labels = model.forward_pretrain(input_ids, tokenizer)
            loss = loss_fnc_pretrain(logits.view(-1, config.vocab_size), labels.view(-1))

        scaler.scale(loss).backward()
        nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()
        optimizer.zero_grad()

        if step % 100 == 0:
            print('[%d/%d] [%d/%d] %f' % (epoch, n_epochs, step, len(loader), loss.item()))

if MLM_PRE:
    for epoch in range(n_pretrain_epochs):
        mlm_pretrain(train_loader, n_pretrain_epochs)

## 모델 학습

In [22]:
# torch.save({
#             'epoch': epoch,
#             'model_state_dict': model.state_dict(),
#             'optimizer_state_dict': optimizer.state_dict(),
#             'scheduler.state_dict': scheduler.state_dict(),
#         },'tmp.model.bin')

ckpt = torch.load('tmp.model.bin')
model.load_state_dict(ckpt['model_state_dict'])
model.to(device)
optimizer.load_state_dict(ckpt['optimizer_state_dict'])
# scheduler.load_state_dict(ckpt['scheduler.state_dict'])

In [23]:
MLM_DURING = False

# backward pass시 gradient 정보가 손실되지 않게 하려고 사용(loss에 scale factor를 곱해서 gradient 값이 너무 작아지는 것을 방지)
scaler = GradScaler()
best_score, best_checkpoint = 0, 0

for epoch in range(n_epochs):
    batch_loss = []
    model.train()
    for step, batch in enumerate(tqdm(train_loader)):
        input_ids, segment_ids, input_masks, gating_ids, target_ids, guids = [
            b.to(device) if not isinstance(b, list) else b for b in batch
        ]
        # mask
        if cfg.mask:
            change_mask_prop = 0.8
            mask_p = random.random()
            if cfg.mask and mask_p < change_mask_prop:
                input_ids = custom_to_mask(input_ids)
        # teacher forcing
        if teacher_forcing > 0.0 and random.random() < teacher_forcing:
            tf = target_ids
        else:
            tf = None
            
        optimizer.zero_grad()
        
        with autocast():  # 밑에 해당하는 코드를 자동으로 mixed precision으로 변환시켜서 실행
            all_point_outputs, all_gate_outputs = model(input_ids, segment_ids, input_masks, target_ids.size(-1))  # gt - length (generation)
            # generation loss
            loss_1 = loss_fnc_1(all_point_outputs.contiguous(), target_ids.contiguous().view(-1))
            # gating loss
            loss_2 = loss_fnc_2(all_gate_outputs.contiguous().view(-1, 5), gating_ids.contiguous().view(-1))
            loss = loss_1 + loss_2
        
        batch_loss.append(loss.item())

        scaler.scale(loss).backward()
        nn.utils.clip_grad_norm_(model.parameters(), cfg.max_grad_norm)
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()
        
#         optimizer.zero_grad()
        
        # global_step 추가 부분
        wandb.log({"train/learning_rate": get_lr(scheduler),
                   "train/epoch": epoch
                   })
        if step % 100 == 0:
            print(
                 f"[{epoch}/{n_epochs}] [{step}/{len(train_loader)}] loss: {loss.item()} gen: {loss_1.item()} gate: {loss_2.item()}"
            )
            wandb.log({
                "train/loss": loss.item(),
                "train/gen_loss": loss_1.item(),
                "train/gate_loss": loss_2.item(),
            })
            
    if MLM_DURING:
        mlm_pretrain(train_loader, n_epochs)
                
    predictions = inference(model, dev_loader, processor, device, cfg.n_gate)
    eval_result, batch_miss_labels = _evaluation(predictions, dev_labels, slot_meta)
    
#     # -- eval 단계에서 Loss, Accuracy 로그 저장
#     wandb.log({
#         "eval/join_goal_acc": eval_result["joint_goal_accuracy"],
#         "eval/turn_slot_f1": eval_result["turn_slot_f1"],
#         "eval/turn_slot_acc": eval_result["turn_slot_accuracy"],
#     })
    
#     for k, v in eval_result.items():
#         print(f"{k}: {v}")
        
    if best_score < eval_result['joint_goal_accuracy']:
        cpprint(f"--Update Best checkpoint!, epoch: {epoch+1}")
        best_score = eval_result['joint_goal_accuracy']
        best_checkpoint = epoch
        if not os.path.isdir(cfg.model_dir):
            os.makedirs(cfg.model_dir)
        print("--Saving best model checkpoint")
        torch.save(model.state_dict(), f"{cfg.model_dir}/{wandb.run.name}/best.pth")
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler.state_dict': scheduler.state_dict(),
            'loss': loss.item(),
            'gen_loss': loss_1.item(),
            'gate_loss': loss_2.item(),
        }, os.path.join(f"{cfg.model_dir}/{wandb.run.name}", "training_best_checkpoint.bin"))


    torch.save(model.state_dict(), f"{cfg.model_dir}/{wandb.run.name}/last.pth")
#     print(f"time for 1 epoch: {time.time() - start_time}")    

  0%|          | 1/2886 [00:01<51:43,  1.08s/it]

[0/30] [0/2886] loss: 0.0385856069624424 gen: 0.03522958606481552 gate: 0.003356022061780095


  3%|▎         | 101/2886 [01:01<29:13,  1.59it/s]

[0/30] [100/2886] loss: 0.04682090878486633 gen: 0.031805574893951416 gate: 0.015015332959592342


  7%|▋         | 201/2886 [02:01<27:03,  1.65it/s]

[0/30] [200/2886] loss: 0.04550093784928322 gen: 0.04474420100450516 gate: 0.0007567384163849056


 10%|█         | 301/2886 [03:00<24:57,  1.73it/s]

[0/30] [300/2886] loss: 0.02762472629547119 gen: 0.027412667870521545 gate: 0.00021205896337050945


 14%|█▍        | 401/2886 [03:58<21:47,  1.90it/s]

[0/30] [400/2886] loss: 0.01613502763211727 gen: 0.01571708917617798 gate: 0.00041793924174271524


 17%|█▋        | 501/2886 [04:57<24:30,  1.62it/s]

[0/30] [500/2886] loss: 0.030067481100559235 gen: 0.02944207191467285 gate: 0.0006254087202250957


 21%|██        | 601/2886 [05:55<23:51,  1.60it/s]

[0/30] [600/2886] loss: 0.030410708859562874 gen: 0.02914070338010788 gate: 0.001270006294362247


 24%|██▍       | 701/2886 [06:54<22:15,  1.64it/s]

[0/30] [700/2886] loss: 0.05332624539732933 gen: 0.05118023231625557 gate: 0.00214601238258183


 28%|██▊       | 801/2886 [07:54<18:26,  1.88it/s]

[0/30] [800/2886] loss: 0.025891892611980438 gen: 0.024980438873171806 gate: 0.0009114528656937182


 31%|███       | 901/2886 [08:53<17:36,  1.88it/s]

[0/30] [900/2886] loss: 0.01877494528889656 gen: 0.018709994852542877 gate: 6.495110574178398e-05


 35%|███▍      | 1001/2886 [09:52<18:16,  1.72it/s]

[0/30] [1000/2886] loss: 0.030525047332048416 gen: 0.030432933941483498 gate: 9.211351425619796e-05


 38%|███▊      | 1101/2886 [10:50<17:01,  1.75it/s]

[0/30] [1100/2886] loss: 0.016968248412013054 gen: 0.013309621252119541 gate: 0.0036586278583854437


 42%|████▏     | 1201/2886 [11:49<15:53,  1.77it/s]

[0/30] [1200/2886] loss: 0.0171981081366539 gen: 0.017080456018447876 gate: 0.00011765128874685615


 45%|████▌     | 1301/2886 [12:47<14:36,  1.81it/s]

[0/30] [1300/2886] loss: 0.02037067897617817 gen: 0.02032819204032421 gate: 4.2487768951104954e-05


 49%|████▊     | 1401/2886 [13:46<14:47,  1.67it/s]

[0/30] [1400/2886] loss: 0.032428160309791565 gen: 0.032327800989151 gate: 0.00010035950981546193


 52%|█████▏    | 1501/2886 [14:45<12:54,  1.79it/s]

[0/30] [1500/2886] loss: 0.025729717686772346 gen: 0.025449499487876892 gate: 0.00028021863545291126


 55%|█████▌    | 1601/2886 [15:43<13:25,  1.60it/s]

[0/30] [1600/2886] loss: 0.015783799812197685 gen: 0.015704961493611336 gate: 7.883746729930863e-05


 59%|█████▉    | 1701/2886 [16:41<12:24,  1.59it/s]

[0/30] [1700/2886] loss: 0.028908835723996162 gen: 0.02851296029984951 gate: 0.0003958755114581436


 62%|██████▏   | 1801/2886 [17:40<10:38,  1.70it/s]

[0/30] [1800/2886] loss: 0.03069692850112915 gen: 0.030636005103588104 gate: 6.092256444389932e-05


 66%|██████▌   | 1901/2886 [18:39<10:00,  1.64it/s]

[0/30] [1900/2886] loss: 0.013830875046551228 gen: 0.01379421167075634 gate: 3.66633539670147e-05


 69%|██████▉   | 2001/2886 [19:37<08:21,  1.76it/s]

[0/30] [2000/2886] loss: 0.016865618526935577 gen: 0.016823071986436844 gate: 4.254604209563695e-05


 73%|███████▎  | 2101/2886 [20:35<07:39,  1.71it/s]

[0/30] [2100/2886] loss: 0.019694289192557335 gen: 0.01956585794687271 gate: 0.00012843142030760646


 76%|███████▋  | 2201/2886 [21:36<06:44,  1.69it/s]

[0/30] [2200/2886] loss: 0.02744816616177559 gen: 0.027282744646072388 gate: 0.0001654222869547084


 80%|███████▉  | 2301/2886 [22:34<05:17,  1.84it/s]

[0/30] [2300/2886] loss: 0.02487608790397644 gen: 0.024702219292521477 gate: 0.00017386891704518348


 83%|████████▎ | 2401/2886 [23:33<04:45,  1.70it/s]

[0/30] [2400/2886] loss: 0.02906790003180504 gen: 0.02841595932841301 gate: 0.000651939888484776


 87%|████████▋ | 2501/2886 [24:31<03:42,  1.73it/s]

[0/30] [2500/2886] loss: 0.01771087944507599 gen: 0.01763898693025112 gate: 7.189202005974948e-05


 90%|█████████ | 2601/2886 [25:32<02:38,  1.79it/s]

[0/30] [2600/2886] loss: 0.03462297096848488 gen: 0.034576211124658585 gate: 4.6761597332078964e-05


 94%|█████████▎| 2701/2886 [26:30<01:53,  1.64it/s]

[0/30] [2700/2886] loss: 0.030892521142959595 gen: 0.012497988529503345 gate: 0.018394531682133675


 97%|█████████▋| 2801/2886 [27:28<00:49,  1.72it/s]

[0/30] [2800/2886] loss: 0.030797092244029045 gen: 0.026598118245601654 gate: 0.004198973998427391


100%|██████████| 2886/2886 [28:18<00:00,  1.70it/s]
100%|██████████| 635/635 [02:00<00:00,  5.26it/s]


{'joint_goal_accuracy': 0.6108374384236454, 'turn_slot_accuracy': 0.9870695128626251, 'turn_slot_f1': 0.9413211894206762}
'--Update Best checkpoint!, epoch: 1'
--Saving best model checkpoint


  0%|          | 1/2886 [00:01<49:38,  1.03s/it]

[1/30] [0/2886] loss: 0.045541439205408096 gen: 0.02865423820912838 gate: 0.016887200996279716


  3%|▎         | 101/2886 [01:01<27:42,  1.67it/s]

[1/30] [100/2886] loss: 0.021033475175499916 gen: 0.01873994991183281 gate: 0.002293525729328394


  7%|▋         | 201/2886 [02:00<25:06,  1.78it/s]

[1/30] [200/2886] loss: 0.011448057368397713 gen: 0.009432303719222546 gate: 0.0020157531835138798


 10%|█         | 301/2886 [03:00<24:57,  1.73it/s]

[1/30] [300/2886] loss: 0.027593042701482773 gen: 0.02735261060297489 gate: 0.0002404312981525436


 14%|█▍        | 401/2886 [03:58<23:09,  1.79it/s]

[1/30] [400/2886] loss: 0.04968748241662979 gen: 0.02602601796388626 gate: 0.02366146445274353


 17%|█▋        | 501/2886 [04:56<20:54,  1.90it/s]

[1/30] [500/2886] loss: 0.045130111277103424 gen: 0.043976079672575 gate: 0.001154032303020358


 21%|██        | 601/2886 [05:55<23:24,  1.63it/s]

[1/30] [600/2886] loss: 0.041153766214847565 gen: 0.02801438421010971 gate: 0.013139383867383003


 24%|██▍       | 701/2886 [06:53<22:16,  1.63it/s]

[1/30] [700/2886] loss: 0.010424511507153511 gen: 0.010145289823412895 gate: 0.0002792217128444463


 28%|██▊       | 801/2886 [07:52<21:21,  1.63it/s]

[1/30] [800/2886] loss: 0.018742863088846207 gen: 0.016397370025515556 gate: 0.0023454935289919376


 31%|███       | 901/2886 [08:50<19:52,  1.66it/s]

[1/30] [900/2886] loss: 0.04873919486999512 gen: 0.03398391976952553 gate: 0.014755276963114738


 35%|███▍      | 1001/2886 [09:49<19:00,  1.65it/s]

[1/30] [1000/2886] loss: 0.05258360877633095 gen: 0.04790079966187477 gate: 0.004682808183133602


 38%|███▊      | 1101/2886 [10:47<17:59,  1.65it/s]

[1/30] [1100/2886] loss: 0.026704492047429085 gen: 0.025701720267534256 gate: 0.0010027713142335415


 42%|████▏     | 1201/2886 [11:45<15:05,  1.86it/s]

[1/30] [1200/2886] loss: 0.03163466230034828 gen: 0.030045390129089355 gate: 0.0015892722876742482


 45%|████▌     | 1301/2886 [12:45<15:04,  1.75it/s]

[1/30] [1300/2886] loss: 0.04021570459008217 gen: 0.03553926199674606 gate: 0.00467644352465868


 49%|████▊     | 1401/2886 [13:44<14:35,  1.70it/s]

[1/30] [1400/2886] loss: 0.01624140702188015 gen: 0.01584966853260994 gate: 0.0003917384019587189


 52%|█████▏    | 1501/2886 [14:43<13:42,  1.68it/s]

[1/30] [1500/2886] loss: 0.016415314748883247 gen: 0.014272852800786495 gate: 0.0021424617152661085


 55%|█████▌    | 1601/2886 [15:42<12:27,  1.72it/s]

[1/30] [1600/2886] loss: 0.05150020122528076 gen: 0.044073570519685745 gate: 0.0074266307055950165


 59%|█████▉    | 1701/2886 [16:42<12:42,  1.55it/s]

[1/30] [1700/2886] loss: 0.053172096610069275 gen: 0.03451456502079964 gate: 0.018657531589269638


 62%|██████▏   | 1801/2886 [17:40<10:45,  1.68it/s]

[1/30] [1800/2886] loss: 0.05333475396037102 gen: 0.03993688523769379 gate: 0.013397869653999805


 66%|██████▌   | 1901/2886 [18:39<08:32,  1.92it/s]

[1/30] [1900/2886] loss: 0.04136032983660698 gen: 0.0401897169649601 gate: 0.001170613570138812


 67%|██████▋   | 1939/2886 [19:02<09:18,  1.70it/s]


RuntimeError: CUDA out of memory. Tried to allocate 1.05 GiB (GPU 0; 31.72 GiB total capacity; 24.97 GiB already allocated; 693.56 MiB free; 29.83 GiB reserved in total by PyTorch)

## Inference

In [None]:
eval_data = json.load(open(f"/opt/ml/input/data/eval_dataset/eval_dials.json", "r"))

eval_examples = get_examples_from_dialogues(
    eval_data, user_first=False, dialogue_level=False
)

# Extracting Featrues
eval_features = processor.convert_examples_to_features(eval_examples)
eval_data = WOSDataset(eval_features)
eval_sampler = SequentialSampler(eval_data)
eval_loader = DataLoader(
    eval_data,
    batch_size=8,
    sampler=eval_sampler,
    collate_fn=processor.collate_fn,
)

100%|██████████| 1000/1000 [00:00<00:00, 36946.73it/s]


In [None]:
predictions = inference(model, eval_loader, processor, device)

100%|██████████| 944/944 [01:45<00:00,  8.92it/s]


In [None]:
json.dump(predictions, open('predictions.csv', 'w'), indent=2, ensure_ascii=False)