In [1]:
import json
import random
import numpy as np
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler

from transformers import BertModel, BertTokenizer, BertConfig, AdamW, get_linear_schedule_with_warmup
from data_utils import (
    load_dataset, 
    get_examples_from_dialogues, 
    convert_state_dict, 
    DSTInputExample, 
    OpenVocabDSTFeature, 
    DSTPreprocessor, 
    WOSDataset)
    
from inference import inference
from evaluation import _evaluation

## Data loading

In [2]:
train_data_file = "/opt/ml/input/data/train_dataset/train_dials.json"
slot_meta = json.load(open("/opt/ml/input/data/train_dataset/slot_meta.json"))
ontology = json.load(open("/opt/ml/input/data/train_dataset/ontology.json"))
train_data, dev_data, dev_labels = load_dataset(train_data_file)

train_examples = get_examples_from_dialogues(train_data,
                                             user_first=False,
                                             dialogue_level=False)
dev_examples = get_examples_from_dialogues(dev_data,
                                           user_first=False,
                                           dialogue_level=False)

100%|██████████| 6301/6301 [00:00<00:00, 8438.29it/s]
100%|██████████| 699/699 [00:00<00:00, 14980.50it/s]


In [3]:
print(len(train_examples))
print(len(dev_examples))

46258
4987


## TRADE Preprocessor 

기존의 GRU 기반의 인코더를 BERT-based Encoder로 바꿀 준비를 합시다.

1. 현재 `_convert_example_to_feature`에서는 `max_seq_length`를 핸들하고 있지 않습니다. `input_id`와 `segment_id`가 `max_seq_length`를 넘어가면 좌측부터 truncate시키는 코드를 삽입하세요.

2. `recover_state`를 구현해 보세요.

In [4]:
class TRADEPreprocessor(DSTPreprocessor):
    def __init__(
        self,
        slot_meta,
        src_tokenizer,
        trg_tokenizer=None,
        ontology=None,
        max_seq_length=512,
    ):
        self.slot_meta = slot_meta
        self.src_tokenizer = src_tokenizer
        self.trg_tokenizer = trg_tokenizer if trg_tokenizer else src_tokenizer
        self.ontology = ontology     
        self.gating2id = {"none": 0, "dontcare": 1, "ptr": 2}     
        self.id2gating = {v: k for k, v in self.gating2id.items()}
        self.max_seq_length = max_seq_length

    def _convert_example_to_feature(self, example):
        '''
        
        '''
        #context_turns는 t-1까지의 모든 대화, current_turn은 현재 대화 이므로 현재까지의 모든대화를 dialogue context에 넣는다.
        #논문에서는 slide window를 이용한다고 하는데 여기서 k값이 문장 전체라고 보면 될 것 같다.
        dialogue_context = " [SEP] ".join(example.context_turns + example.current_turn)
        print(dialogue_context)
        #tokenizing 
        input_id = self.src_tokenizer.encode(dialogue_context, add_special_tokens=False)
        print(input_id)
        
        #max_length보다 크면 왼쪽부터 truncating
        max_length = self.max_seq_length - 2
        
        if len(input_id) > max_length:
            gap = len(input_id) - max_length
            input_id = input_id[gap:]
        #cls 토큰 및 sep 토큰 부착
        input_id = (
            [self.src_tokenizer.cls_token_id]
            + input_id
            + [self.src_tokenizer.sep_token_id]
        )
        #segment_id는 문장을 구분해주는 id
        segment_id = [0] * len(input_id)
        
        '''
        05.12 15시 기준
        gating_id, target_ids의 개념을 현재 잘 이해하지 못하고 있음
        공부 후 다시 재도전
        '''
        
        target_ids = []
        gating_id = []
        #label이 none일수도 있는지 없는경우 list를 할당해준다.
        if not example.label:
            example.label = []

        state = convert_state_dict(example.label)
        #print(state)
        #print(slot_meta)
        for slot in self.slot_meta:
            value = state.get(slot, "none")
            #slot의 value값을 tokenizing
            target_id = self.trg_tokenizer.encode(value, add_special_tokens=False) + [
                self.trg_tokenizer.sep_token_id
            ]
            #print(self.trg_tokenizer.decode([21832,11764]),target_id)

            target_ids.append(target_id)
            #value가 있으면 value, 아니면 ptr이 gating_id로 들어감
            gating_id.append(self.gating2id.get(value, self.gating2id["ptr"]))
            #print(gating_id)
            #gate idx를 gate_id에 input
        target_ids = self.pad_ids(target_ids, self.trg_tokenizer.pad_token_id)
        #openvacabDSTFeature에는 guid, input_id, segment_id(bert 사용 시), gating_id(gate 사용), target_ids의 정보를 갖고 있음
        return OpenVocabDSTFeature(
            example.guid, input_id, segment_id, gating_id, target_ids
        )
    
    def convert_examples_to_features(self, examples):
        return list(map(self._convert_example_to_feature, examples))

    def recover_state(self, gate_list, gen_list):
        # problem 2.
        # Your code here!
        raise Exception('TRADE의 아웃풋을 prediction form으로 바꾸는 코드를 작성하세요!')
        
        return recovered

    def collate_fn(self, batch):
        guids = [b.guid for b in batch]
        input_ids = torch.LongTensor(
            self.pad_ids([b.input_id for b in batch], self.src_tokenizer.pad_token_id)
        )
        segment_ids = torch.LongTensor(
            self.pad_ids([b.segment_id for b in batch], self.src_tokenizer.pad_token_id)
        )
        input_masks = input_ids.ne(self.src_tokenizer.pad_token_id)

        gating_ids = torch.LongTensor([b.gating_id for b in batch])
        target_ids = self.pad_id_of_matrix(
            [torch.LongTensor(b.target_ids) for b in batch],
            self.trg_tokenizer.pad_token_id,
        )
        return input_ids, segment_ids, input_masks, gating_ids, target_ids, guids

## Convert_Examples_to_Features 

In [None]:
tokenizer = BertTokenizer.from_pretrained('dsksd/bert-ko-small-minimal')
processor = TRADEPreprocessor(slot_meta, tokenizer, max_seq_length=512)

train_features = processor.convert_examples_to_features([train_examples[0]])
#dev_features = processor.convert_examples_to_features(dev_examples)

In [51]:
#1번째 턴의 inpur_id
input_id = train_features[0].input_id
tokenizer.convert_ids_to_tokens(input_id)
train_examples[0]

DSTInputExample(guid='snowy-hat-8324:관광_식당_11-0', context_turns=[], current_turn=['', '서울 중앙에 있는 박물관을 찾아주세요'], label=['관광-종류-박물관', '관광-지역-서울 중앙'])

In [19]:
print(len(train_features))
print(len(dev_features))

46245
5000


# Model 

In [None]:
class TRADE(nn.Module):
    def __init__(self, config, slot_vocab, slot_meta, pad_idx=0):
        super(TRADE, self).__init__()
        self.slot_meta = slot_meta
        
        self.encoder = GRUEncoder(
            config.vocab_size,
            config.hidden_size,
            1,
            config.hidden_dropout_prob,
            config.proj_dim,
            pad_idx,
        )
        
        self.decoder = SlotGenerator(
            config.vocab_size,
            config.hidden_size,
            config.hidden_dropout_prob,
            config.n_gate,
            None,
            pad_idx,
        )
        
        # init for only subword embedding
        self.decoder.set_slot_idx(slot_vocab)
        self.tie_weight()

        
    def tie_weight(self):
        self.decoder.embed.weight = self.encoder.embed.weight
        if self.decoder.proj_layer:
            self.decoder.proj_layer.weight = self.encoder.proj_layer.weight

    def forward(self, input_ids, token_type_ids, attention_mask=None, max_len=10, teacher=None):

        encoder_outputs, pooled_output = self.encoder(input_ids=input_ids)
        all_point_outputs, all_gate_outputs = self.decoder(
            input_ids, encoder_outputs, pooled_output.unsqueeze(0), attention_mask, max_len, teacher
        )

        return all_point_outputs, all_gate_outputs
    

class GRUEncoder(nn.Module):
    def __init__(self, vocab_size, d_model, n_layer, dropout, proj_dim=None, pad_idx=0):
        super(GRUEncoder, self).__init__()
        self.pad_idx = pad_idx
        self.embed = nn.Embedding(vocab_size, d_model, padding_idx=pad_idx)
        if proj_dim:
            self.proj_layer = nn.Linear(d_model, proj_dim, bias=False)
        else:
            self.proj_layer = None

        self.d_model = proj_dim if proj_dim else d_model
        self.gru = nn.GRU(
            self.d_model,
            self.d_model,
            n_layer,
            dropout=dropout,
            batch_first=True,
            bidirectional=True,
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, input_ids):
        mask = input_ids.eq(self.pad_idx).unsqueeze(-1)
        x = self.embed(input_ids)
        if self.proj_layer:
            x = self.proj_layer(x)
        x = self.dropout(x)
        o, h = self.gru(x)
        o = o.masked_fill(mask, 0.0)
        output = o[:, :, : self.d_model] + o[:, :, self.d_model :]
        hidden = h[0] + h[1]  # n_layer 고려
        return output, hidden
    
    
class SlotGenerator(nn.Module):
    def __init__(
        self, vocab_size, hidden_size, dropout, n_gate, proj_dim=None, pad_idx=0
    ):
        super(SlotGenerator, self).__init__()
        self.pad_idx = pad_idx
        self.vocab_size = vocab_size
        self.embed = nn.Embedding(
            vocab_size, hidden_size, padding_idx=pad_idx
        )  # shared with encoder

        if proj_dim:
            self.proj_layer = nn.Linear(hidden_size, proj_dim, bias=False)
        else:
            self.proj_layer = None
        self.hidden_size = proj_dim if proj_dim else hidden_size

        self.gru = nn.GRU(
            self.hidden_size, self.hidden_size, 1, dropout=dropout, batch_first=True
        )
        self.n_gate = n_gate
        self.dropout = nn.Dropout(dropout)
        self.w_gen = nn.Linear(self.hidden_size * 3, 1)
        self.sigmoid = nn.Sigmoid()
        self.w_gate = nn.Linear(self.hidden_size, n_gate)

    def set_slot_idx(self, slot_vocab_idx):
        whole = []
        max_length = max(map(len, slot_vocab_idx))
        for idx in slot_vocab_idx:
            if len(idx) < max_length:
                gap = max_length - len(idx)
                idx.extend([self.pad_idx] * gap)
            whole.append(idx)
        self.slot_embed_idx = whole  # torch.LongTensor(whole)

    def embedding(self, x):
        x = self.embed(x)
        if self.proj_layer:
            x = self.proj_layer(x)
        return x

    def forward(
        self, input_ids, encoder_output, hidden, input_masks, max_len, teacher=None
    ):
        input_masks = input_masks.ne(1)
        # J, slot_meta : key : [domain, slot] ex> LongTensor([1,2])
        # J,2
        batch_size = encoder_output.size(0)
        slot = torch.LongTensor(self.slot_embed_idx).to(input_ids.device)  ##
        slot_e = torch.sum(self.embedding(slot), 1)  # J,d
        J = slot_e.size(0)

        all_point_outputs = torch.zeros(batch_size, J, max_len, self.vocab_size).to(
            input_ids.device
        )
        
        # Parallel Decoding
        w = slot_e.repeat(batch_size, 1).unsqueeze(1)
        hidden = hidden.repeat_interleave(J, dim=1)
        encoder_output = encoder_output.repeat_interleave(J, dim=0)
        input_ids = input_ids.repeat_interleave(J, dim=0)
        input_masks = input_masks.repeat_interleave(J, dim=0)
        for k in range(max_len):
            w = self.dropout(w)
            _, hidden = self.gru(w, hidden)  # 1,B,D

            # B,T,D * B,D,1 => B,T
            attn_e = torch.bmm(encoder_output, hidden.permute(1, 2, 0))  # B,T,1
            attn_e = attn_e.squeeze(-1).masked_fill(input_masks, -1e9)
            attn_history = F.softmax(attn_e, -1)  # B,T

            if self.proj_layer:
                hidden_proj = torch.matmul(hidden, self.proj_layer.weight)
            else:
                hidden_proj = hidden

            # B,D * D,V => B,V
            attn_v = torch.matmul(
                hidden_proj.squeeze(0), self.embed.weight.transpose(0, 1)
            )  # B,V
            attn_vocab = F.softmax(attn_v, -1)

            # B,1,T * B,T,D => B,1,D
            context = torch.bmm(attn_history.unsqueeze(1), encoder_output)  # B,1,D
            p_gen = self.sigmoid(
                self.w_gen(torch.cat([w, hidden.transpose(0, 1), context], -1))
            )  # B,1
            p_gen = p_gen.squeeze(-1)

            p_context_ptr = torch.zeros_like(attn_vocab).to(input_ids.device)
            p_context_ptr.scatter_add_(1, input_ids, attn_history)  # copy B,V
            p_final = p_gen * attn_vocab + (1 - p_gen) * p_context_ptr  # B,V
            _, w_idx = p_final.max(-1)

            if teacher is not None:
                w = self.embedding(teacher[:, :, k]).transpose(0, 1).reshape(batch_size * J, 1, -1)
            else:
                w = self.embedding(w_idx).unsqueeze(1)  # B,1,D
            if k == 0:
                gated_logit = self.w_gate(context.squeeze(1))  # B,3
                all_gate_outputs = gated_logit.view(batch_size, J, self.n_gate)
            all_point_outputs[:, :, k, :] = p_final.view(batch_size, J, self.vocab_size)

        return all_point_outputs, all_gate_outputs

# 모델 및 데이터 로더 정의

In [None]:
slot_vocab = []
for slot in slot_meta:
    slot_vocab.append(
        tokenizer.encode(slot.replace('-', ' '),
                         add_special_tokens=False)
    )
    
config = BertConfig.from_pretrained('dsksd/bert-ko-small-minimal')
config.model_name_or_path = 'dsksd/bert-ko-small-minimal'
config.n_gate = len(processor.gating2id)
config.proj_dim = None
model = TRADE(config, slot_vocab, slot_meta)



In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

train_data = WOSDataset(train_features)
train_sampler = RandomSampler(train_data)
train_loader = DataLoader(train_data, batch_size=4, sampler=train_sampler, collate_fn=processor.collate_fn)

dev_data = WOSDataset(dev_features)
dev_sampler = SequentialSampler(dev_data)
dev_loader = DataLoader(dev_data, batch_size=8, sampler=dev_sampler, collate_fn=processor.collate_fn)

# Optimizer & Scheduler 선언

In [None]:
n_epochs = 10
no_decay = ["bias", "LayerNorm.weight"]
optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": 0.01,
        },
        {
            "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
            "weight_decay": 0.0,
        },
    ]

t_total = len(train_loader) * n_epochs
optimizer = AdamW(optimizer_grouped_parameters, lr=3e-5, eps=1e-8)
scheduler = get_linear_schedule_with_warmup(
    optimizer, num_warmup_steps=0.1, num_training_steps=t_total
)
teacher_forcing = 0.5
model.to(device)

def masked_cross_entropy_for_value(logits, target, pad_idx=0):
    mask = target.ne(pad_idx)
    logits_flat = logits.view(-1, logits.size(-1))
    log_probs_flat = torch.log(logits_flat)
    target_flat = target.view(-1, 1)
    losses_flat = -torch.gather(log_probs_flat, dim=1, index=target_flat)
    losses = losses_flat.view(*target.size())
    losses = losses * mask.float()
    loss = losses.sum() / (mask.sum().float())
    return loss

loss_fnc_1 = masked_cross_entropy_for_value  # generation
loss_fnc_2 = nn.CrossEntropyLoss()  # gating

## Train

In [None]:
for epoch in range(n_epochs):
    batch_loss = []
    model.train()
    for step, batch in enumerate(train_loader):
        input_ids, segment_ids, input_masks, gating_ids, target_ids, _ = [b.to(device) if not isinstance(b, list) else b for b in batch]
        if teacher_forcing > 0.0 and random.random() < teacher_forcing:
            tf = target_ids
        else:
            tf = None

        all_point_outputs, all_gate_outputs = model(input_ids, segment_ids, input_masks, target_ids.size(-1))  # gt - length (generation)
        loss_1 = loss_fnc_1(all_point_outputs.contiguous(), target_ids.contiguous().view(-1))
        loss_2 = loss_fnc_2(all_gate_outputs.contiguous().view(-1, 3), gating_ids.contiguous().view(-1))
        loss = loss_1 + loss_2
        batch_loss.append(loss.item())

        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()
        if step % 100 == 0:            
            print('[%d/%d] [%d/%d] %f' % (epoch, n_epochs, step, len(train_loader), loss.item()))

predictions = inference(model, dev_loader, processor, device)
eval_result = _evaluation(predictions, dev_labels, slot_meta)
for k, v in eval_result.items():
    print(f"{k}: {v}")

[0/10] [0/11550] 28.629761
[1/10] [0/11550] nan
[2/10] [0/11550] nan
[3/10] [0/11550] nan
[4/10] [0/11550] nan
[5/10] [0/11550] nan
[6/10] [0/11550] nan
[7/10] [0/11550] nan
[8/10] [0/11550] nan


  0%|          | 1/631 [00:00<01:04,  9.84it/s]

[9/10] [0/11550] nan


100%|██████████| 631/631 [01:05<00:00,  9.62it/s]

{'joint_goal_accuracy': 0.018034086405073327, 'turn_slot_accuracy': 0.8214691504822179, 'turn_slot_f1': 0.018034086405073327}
joint_goal_accuracy: 0.018034086405073327
turn_slot_accuracy: 0.8214691504822179
turn_slot_f1: 0.018034086405073327





## Inference 

In [None]:
eval_data = json.load(open(f"/opt/ml/input/data/eval/eval_dials.json", "r"))

eval_examples = get_examples_from_dialogues(
    eval_data, user_first=False, dialogue_level=False
)

# Extracting Featrues
eval_features = processor.convert_examples_to_features(eval_examples)
eval_data = WOSDataset(eval_features)
eval_sampler = SequentialSampler(eval_data)
eval_loader = DataLoader(
    eval_data,
    batch_size=8,
    sampler=eval_sampler,
    collate_fn=processor.collate_fn,
)

In [None]:
predictions = inference(model, eval_loader, processor, device)

In [None]:
json.dumps(predictions, open('predictions.csv', 'w'), indent=2, ensure_ascii=False) 