In [1]:
import sys
from model import TransformerDST
from pytorch_transformers import AdamW, WarmupLinearSchedule, BertConfig
from transformers import BertTokenizer
from transformer_dst_utils.data_utils import prepare_dataset, MultiWozDataset
from transformer_dst_utils.data_utils import make_slot_meta, domain2id, OP_SET, make_turn_label, postprocessing
from transformer_dst_utils.eval_utils import compute_prf, compute_acc, per_domain_join_accuracy
from transformer_dst_utils.ckpt_utils import download_ckpt, convert_ckpt_compatible
from transformer_dst_utils.evaluation import model_evaluation

import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
import numpy as np
import argparse
import random
import os
import json
import time

## Data loading

In [2]:
def masked_cross_entropy_for_value(logits, target, pad_idx=0):
    mask = target.ne(pad_idx)
    logits_flat = logits.view(-1, logits.size(-1))
    log_probs_flat = torch.log(logits_flat)
    target_flat = target.view(-1, 1)
    losses_flat = -torch.gather(log_probs_flat, dim=1, index=target_flat)
    losses = losses_flat.view(*target.size())
    losses = losses * mask.float()
    loss = losses.sum() / (mask.sum().float())
    return loss


def save(args, epoch, model, enc_optimizer, dec_optimizer=None):
    model_to_save = model.module if hasattr(
        model, 'module') else model  # Only save the model it-self
    model_file = os.path.join(
        args.save_dir, "model.e{:}.bin".format(epoch))
    torch.save(model_to_save.state_dict(), model_file)

    # enc_optim_file = os.path.join(
    #     args.save_dir, "enc_optim.e{:}.bin".format(epoch))
    # torch.save(enc_optimizer.state_dict(), enc_optim_file)
    #
    # if dec_optimizer is not None:
    #     dec_optim_file = os.path.join(
    #         args.save_dir, "dec_optim.e{:}.bin".format(epoch))
    #     torch.save(dec_optimizer.state_dict(), dec_optim_file)


def load(args, epoch):
    model_file = os.path.join(
        args.save_dir, "model.e{:}.bin".format(epoch))
    model_recover = torch.load(model_file, map_location='cpu')

    enc_optim_file = os.path.join(
        args.save_dir, "enc_optim.e{:}.bin".format(epoch))
    enc_recover = torch.load(enc_optim_file, map_location='cpu')
    if hasattr(enc_recover, 'state_dict'):
        enc_recover = enc_recover.state_dict()

    dec_optim_file = os.path.join(
        args.save_dir, "dec_optim.e{:}.bin".format(epoch))
    dec_recover = torch.load(dec_optim_file, map_location='cpu')
    if hasattr(dec_recover, 'state_dict'):
        dec_recover = dec_recover.state_dict()

    return model_recover, enc_recover, dec_recover

In [3]:
from argparse import Namespace

args ={
    "data_root":'/opt/ml/input/data/train_dataset',
    "train_data":"train_dials.json",
    "dev_data":"dev_dials.json",
    "test_data":"test_dials.json",
    "save_dir":'outputs',
    "ontology_data":'/opt/ml/input/data/train_dataset/ontology.json',
    "vocab_path":'assets/vocab.txt',
    "bert_config_path":'./assets/bert_config_base_uncased.json',
    "bert_ckpt_path":'./assets/bert-base-uncased-pytorch_model.bin',
    "random_seed":42,
    "batch_size":16,
    "n_epochs":30,
    "eval_epoch":1,
    "op_code":"4",
    "slot_token":"[SLOT]",
    "hidden_dropout_prob":0.1,
    "decoder_teacher_forcing":1,
    "word_dropout":0.1,
    "shuffle_p":0.5,
    "n_history":1,
    "max_seq_length":256,
    "msg":None,
    "exclude_domain":False,
    "beam_size":1,
    "min_len":1,
    "length_penalty":0,
    "ngram_size":2,
    "shuffle_state":False,
    # By default, "decoder" only attend on a specific [SLOT] position.
    # If using this option, the "decoder" can access to this group of "[SLOT] domain slot - value".
    "use_full_slot":False,
    # Using only D_t in generation
    "use_dt_only":True,
    # w/o re-using dialogue
    "no_dial":False,
    # Using only [CLS]
    "use_cls_only":False,
    "dropout":0.1,
    "hidden_dropout_prob":0.1,
    "attention_probs_dropout_prob":0.1,
    "forbid_duplicate_ngrams":True,
    "forbid_ignore_word":None,
    "use_one_optim":True,   # 논문에서는 optim 1개 사용
    "enc_lr":3e-5,
    "enc_warmup":0.1,
    "num_workers":0,
    "only_pred_op":False
}

args = Namespace(**args)

args.train_data_path = os.path.join(args.data_root, args.train_data)
args.dev_data_path = os.path.join(args.data_root, args.dev_data)
args.test_data_path = os.path.join(args.data_root, args.test_data)
args.ontology_data = os.path.join(args.data_root, args.ontology_data)

In [None]:
if not os.path.exists(args.save_dir):
    os.mkdir(args.save_dir)
    print("### mkdir {:}".format(args.save_dir))

def worker_init_fn(worker_id):
    np.random.seed(args.random_seed + worker_id)

n_gpu = 0
n_gpu = torch.cuda.device_count()
device = torch.device('cuda')

if args.random_seed < 0:
    print("### Pick a random seed")
    args.random_seed = random.sample(list(range(0, 100000)), 1)[0]

print("### Random Seed: {:}".format(args.random_seed))
np.random.seed(args.random_seed)
random.seed(args.random_seed)
rng = random.Random(args.random_seed)
torch.manual_seed(args.random_seed)

if n_gpu > 0:
    if args.random_seed >= 0:
        torch.cuda.manual_seed(args.random_seed)
        torch.cuda.manual_seed_all(args.random_seed)

    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

if not os.path.exists(args.save_dir):
    os.mkdir(args.save_dir)

ontology = json.load(open(args.ontology_data))
slot_meta, ontology = make_slot_meta(ontology)
op2id = OP_SET[args.op_code]
print(op2id)

tokenizer = BertTokenizer.from_pretrained("dsksd/bert-ko-small-minimal")

train_path = os.path.join(args.data_root, "train.pt")
dev_path = os.path.join(args.data_root, "dev.pt")
test_path = os.path.join(args.data_root, "test.pt")

# train_path = args.train_data_path
# dev_path = args.dev_data_path
# test_path = args.test_data_path

# if not os.path.exists(test_path):
#     test_data_raw = prepare_dataset(data_path=args.test_data_path,
#                                     tokenizer=tokenizer,
#                                     slot_meta=slot_meta,
#                                     n_history=args.n_history,
#                                     max_seq_length=args.max_seq_length,
#                                     op_code=args.op_code)
#     torch.save(test_data_raw, test_path)
# else:
#     test_data_raw = torch.load(test_path)

# print("# test examples %d" % len(test_data_raw))

if not os.path.exists(train_path):
    train_data_raw = prepare_dataset(data_path=args.train_data_path,
                                      tokenizer=tokenizer,
                                      slot_meta=slot_meta,
                                      n_history=args.n_history,
                                      max_seq_length=args.max_seq_length,
                                      op_code=args.op_code)

    torch.save(train_data_raw, train_path)
else:
    train_data_raw = torch.load(train_path)

### mkdir outputs
### Random Seed: 42
{'delete': 0, 'update': 1, 'dontcare': 2, 'carryover': 3}


HBox(children=(FloatProgress(value=0.0, description='Downloading', max=263327.0, style=ProgressStyle(descripti…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=124.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=288.0, style=ProgressStyle(description_…




In [None]:
tokenizer.decode(train_data_raw[12].input_id_p)

'[CLS] [SEP] [SEP] [UNK] 관광 경치 좋은 - [UNK] [UNK] 관광 교육적 - [UNK] [UNK] 관광 도보 가능 - [UNK] [UNK] 관광 문화 예술 - [UNK] [UNK] 관광 역사적 - [UNK] [UNK] 관광 이름 - 노량진 수산물 도매시장 [UNK] 관광 종류 - 쇼핑 [UNK] 관광 주차 가능 - [UNK] [UNK] 관광 지역 - 서울 서쪽 [UNK] 숙소 가격대 - [UNK] [UNK] 숙소 도보 가능 - [UNK] [UNK] 숙소 수영장 유무 - [UNK] [UNK] 숙소 스파 유무 - [UNK] [UNK] 숙소 예약 기간 - [UNK] [UNK] 숙소 예약 명수 - [UNK] [UNK] 숙소 예약 요일 - [UNK] [UNK] 숙소 이름 - [UNK] [UNK] 숙소 인터넷 가능 - [UNK] [UNK] 숙소 조식 가능 - [UNK] [UNK] 숙소 종류 - [UNK] [UNK] 숙소 주차 가능 - [UNK] [UNK] 숙소 지역 - [UNK] [UNK] 숙소 헬스장 유무 - [UNK] [UNK] 숙소 흡연 가능 - [UNK] [UNK] 식당 가격대 - [UNK] [UNK] 식당 도보 가능 - [UNK] [UNK] 식당 야외석 유무 - [UNK] [UNK] 식당 예약 명수 - [UNK] [UNK] 식당 예약 시간 - [UNK] [UNK] 식당 예약 요일 - [UNK] [UNK] 식당 이름 - [UNK] [UNK] 식당 인터넷 가능 - [UNK] [UNK] 식당 종류 - [UNK] [UNK] 식당 주류 판매 - [UNK] [UNK] 식당 주차 가능 - [UNK] [UNK] 식당 지역 - [UNK] [UNK] 식당 흡연 가능 - [UNK] [UNK] 지하철 도착지 - [UNK] [UNK] 지하철 출발 시간 - [UNK] [UNK] 지하철 출발지 - [UNK] [UNK] 택시 도착 시간 - [UNK] [UNK] 택시 도착지 - [UNK] [UNK] 택시 종류 - [UNK] [UNK] 택시 출발 시간 - [UNK] [

In [None]:
train_data = MultiWozDataset(train_data_raw,
                              tokenizer,
                              slot_meta,
                              args.max_seq_length,
                              rng,
                              ontology,
                              args.word_dropout,
                              args.shuffle_state,
                              args.shuffle_p, pad_id=tokenizer.convert_tokens_to_ids(['[PAD]'])[0],
                              slot_id=tokenizer.convert_tokens_to_ids(['[SLOT]'])[0],
                              decoder_teacher_forcing=args.decoder_teacher_forcing,
                              use_full_slot=args.use_full_slot,
                              use_dt_only=args.use_dt_only, no_dial=args.no_dial,
                              use_cls_only=args.use_cls_only)

print("# train examples %d" % len(train_data_raw))

if not os.path.exists(dev_path):
    dev_data_raw = prepare_dataset(data_path=args.dev_data_path,
                                    tokenizer=tokenizer,
                                    slot_meta=slot_meta,
                                    n_history=args.n_history,
                                    max_seq_length=args.max_seq_length,
                                    op_code=args.op_code)
    torch.save(dev_data_raw,  dev_path)
else:
    dev_data_raw = torch.load(dev_path)

print("# dev examples %d" % len(dev_data_raw))

### decoder_teacher_forcing: 1
# train examples 54984
# dev examples 7371


In [None]:
vars(train_data.data[250])

{'diag_1_len': 31,
 'diag_len': 52,
 'dialog_history': ['i',
  'am',
  'sorry',
  ',',
  'your',
  'booking',
  'was',
  'unsuccessful',
  '.',
  'would',
  'you',
  'like',
  'to',
  'book',
  'another',
  'day',
  'or',
  'a',
  'shorter',
  'stay',
  '?',
  ';',
  'could',
  'you',
  'try',
  'wednesday',
  ',',
  'instead',
  '?'],
 'domain_id': 0,
 'gen_max_len': 0,
 'generate_ids': [],
 'generate_y': [],
 'gold_p_state': {'hotel-book day': 'wednesday',
  'hotel-book people': '8',
  'hotel-book stay': '2',
  'hotel-name': 'autumn house'},
 'gold_state': ['hotel-book day-wednesday',
  'hotel-book people-8',
  'hotel-book stay-2',
  'hotel-name-autumn house'],
 'i_dslen_map': {0: 2,
  1: 2,
  2: 2,
  3: 2,
  4: 3,
  5: 3,
  6: 3,
  7: 2,
  8: 2,
  9: 2,
  10: 4,
  11: 2,
  12: 2,
  13: 2,
  14: 3,
  15: 3,
  16: 3,
  17: 2,
  18: 2,
  19: 4,
  20: 3,
  21: 2,
  22: 2,
  23: 3,
  24: 3,
  25: 3,
  26: 2,
  27: 2,
  28: 2,
  29: 3},
 'i_to_update': set(),
 'id': 'SNG01722.json',
 'inp

In [None]:
model_config = BertConfig.from_json_file(args.bert_config_path)
model_config.dropout = args.dropout
model_config.attention_probs_dropout_prob = args.attention_probs_dropout_prob
model_config.hidden_dropout_prob = args.hidden_dropout_prob

type_vocab_size = 4
dec_config = args
model = TransformerDST(model_config, dec_config, len(op2id), len(domain2id),
                        op2id['update'],
                        tokenizer.convert_tokens_to_ids(['[MASK]'])[0],
                        tokenizer.convert_tokens_to_ids(['[SEP]'])[0],
                        tokenizer.convert_tokens_to_ids(['[PAD]'])[0],
                        tokenizer.convert_tokens_to_ids(['-'])[0],
                        type_vocab_size, args.exclude_domain)

### word index of '-',  1011


In [None]:
if not os.path.exists(args.bert_ckpt_path):
    args.bert_ckpt_path = download_ckpt(args.bert_ckpt_path, args.bert_config_path, 'assets')

state_dict = torch.load(args.bert_ckpt_path, map_location='cpu')
_k = 'embeddings.token_type_embeddings.weight'
print("config.type_vocab_size != state_dict[bert.embeddings.token_type_embeddings.weight] ({0} != {1})".format(
        type_vocab_size, state_dict[_k].shape[0]))
# state_dict[_k].repeat(
#     type_vocab_size, state_dict[_k].shape[1])
state_dict[_k] = state_dict[_k].repeat(int(type_vocab_size/state_dict[_k].shape[0]), 1)
state_dict[_k].data[2, :].copy_(state_dict[_k].data[0, :])
state_dict[_k].data[3, :].copy_(state_dict[_k].data[0, :])
model.bert.load_state_dict(state_dict)
print("\n### Done Load BERT")
sys.stdout.flush()

config.type_vocab_size != state_dict[bert.embeddings.token_type_embeddings.weight] (4 != 2)

### Done Load BERT


In [None]:
# re-initialize added special tokens ([SLOT], [NULL], [EOS])
model.bert.embeddings.word_embeddings.weight.data[1].normal_(mean=0.0, std=0.02)
model.bert.embeddings.word_embeddings.weight.data[2].normal_(mean=0.0, std=0.02)
model.bert.embeddings.word_embeddings.weight.data[3].normal_(mean=0.0, std=0.02)

# re-initialize seg-2, seg-3
model.bert.embeddings.token_type_embeddings.weight.data[2].normal_(mean=0.0, std=0.02)
model.bert.embeddings.token_type_embeddings.weight.data[3].normal_(mean=0.0, std=0.02)
model.to(device)

TransformerDST(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(4, 768)
      (LayerNorm): BertLayerNorm()
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): BertLayerNorm()
              (dropout): Dropout(p=0.1, inplace=False)
            )
    

In [None]:
num_train_steps = int(len(train_data_raw) / args.batch_size * args.n_epochs)

if args.use_one_optim:
    print("### Use One Optim")
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(
            nd in n for nd in no_decay)], 'weight_decay': 0.01},
        {'params': [p for n, p in param_optimizer if any(
            nd in n for nd in no_decay)], 'weight_decay': 0.0}
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=args.enc_lr)
    scheduler = WarmupLinearSchedule(optimizer, int(num_train_steps * args.enc_warmup),
                                          t_total=num_train_steps)
else:
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    enc_param_optimizer = list(model.bert.named_parameters())  # TODO: For BERT only
    print('### Optim BERT: {:}'.format(len(enc_param_optimizer)))
    enc_optimizer_grouped_parameters = [
        {'params': [p for n, p in enc_param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
        {'params': [p for n, p in enc_param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
        ]

    enc_optimizer = AdamW(enc_optimizer_grouped_parameters, lr=args.enc_lr)
    enc_scheduler = WarmupLinearSchedule(enc_optimizer, int(num_train_steps * args.enc_warmup),
                                          t_total=num_train_steps)

    dec_param_optimizer = list(model.named_parameters())  # TODO:  For other parameters
    print('### Optim All: {:}'.format(len(dec_param_optimizer)))
    dec_param_optimizer = [p for (n, p) in dec_param_optimizer if 'bert' not in n]
    print('### Optim OTH: {:}'.format(len(dec_param_optimizer)))
    dec_optimizer = AdamW(dec_param_optimizer, lr=args.dec_lr)
    dec_scheduler = WarmupLinearSchedule(dec_optimizer, int(num_train_steps * args.dec_warmup),
                                          t_total=num_train_steps)

### Use One Optim


In [None]:
if n_gpu > 1:
    model = torch.nn.DataParallel(model)

train_sampler = RandomSampler(train_data)
train_dataloader = DataLoader(train_data,
                              sampler=train_sampler,
                              batch_size=args.batch_size,
                              collate_fn=train_data.collate_fn,
                              num_workers=args.num_workers,
                              worker_init_fn=worker_init_fn)

loss_fnc = nn.CrossEntropyLoss()
best_score = {'epoch': 0, 'joint_acc': 0, 'op_acc': 0, 'final_slot_f1': 0}

In [None]:
start_time = time.time()

for epoch in range(args.n_epochs):
    batch_loss = []
    model.train()
    for step, batch in enumerate(train_dataloader):

        batch = [b.to(device) if (not isinstance(b, int)) and (not isinstance(b, dict) and (not isinstance(b, list)) and (not isinstance(b, np.ndarray))) else b for b in batch]

        input_ids_p, segment_ids_p, input_mask_p, \
        state_position_ids, op_ids, domain_ids, input_ids_g, segment_ids_g, position_ids_g, input_mask_g, \
        masked_pos, masked_weights, lm_label_ids, id_n_map, gen_max_len, n_total_pred = batch

        domain_scores, state_scores, loss_g = model(input_ids_p, segment_ids_p, input_mask_p, state_position_ids,
            input_ids_g, segment_ids_g, position_ids_g, input_mask_g,
            masked_pos, masked_weights, lm_label_ids, id_n_map, gen_max_len, only_pred_op=args.only_pred_op, n_gpu=n_gpu)

        if n_total_pred > 0:
            loss_g = loss_g.sum() / n_total_pred
        else:
            loss_g = 0

        loss_s = loss_fnc(state_scores.view(-1, len(op2id)), op_ids.view(-1))

        if args.only_pred_op:
            loss = loss_s
        else:
            loss = loss_s + loss_g

        if args.exclude_domain is not True:
            loss_d = loss_fnc(domain_scores.view(-1, len(domain2id)), domain_ids.view(-1))
            loss = loss + loss_d

        batch_loss.append(loss.item())

        loss.backward()

        if args.use_one_optim:
            optimizer.step()
            scheduler.step()
        else:
            enc_optimizer.step()
            enc_scheduler.step()
            dec_optimizer.step()
            dec_scheduler.step()

        model.zero_grad()

        if step % 100 == 0:
            try:
                loss_g = loss_g.item()
            except AttributeError:
                loss_g = loss_g

            if args.exclude_domain is not True:
                print("time %.1f min, [%d/%d] [%d/%d] mean_loss : %.3f, state_loss : %.3f, gen_loss : %.3f, dom_loss : %.3f" \
                      % ((time.time()-start_time)/60, epoch+1, args.n_epochs, step,
                          len(train_dataloader), np.mean(batch_loss),
                          loss_s.item(), loss_g, loss_d.item()))
            else:
                print("time %.1f min, [%d/%d] [%d/%d] mean_loss : %.3f, state_loss : %.3f, gen_loss : %.3f" \
                      % ((time.time()-start_time)/60, epoch+1, args.n_epochs, step,
                          len(train_dataloader), np.mean(batch_loss),
                          loss_s.item(), loss_g))

            sys.stdout.flush()
            batch_loss = []

    if args.use_one_optim:
        save(args, epoch + 1, model, optimizer)
    else:
        save(args, epoch + 1, model, enc_optimizer, dec_optimizer)

    if ((epoch+1) % args.eval_epoch == 0) and (epoch+1 >= 8):
        eval_res = model_evaluation(model, dev_data_raw, tokenizer, slot_meta, epoch+1, args.op_code,
                                    use_full_slot=args.use_full_slot, use_dt_only=args.use_dt_only, no_dial=args.no_dial, use_cls_only=args.use_cls_only, n_gpu=n_gpu)
        print("### Epoch {:} Score : ".format(epoch+1), eval_res)

        if eval_res['joint_acc'] > best_score['joint_acc']:
            best_score = eval_res
            print("### Best Joint Acc: {:} ###".format(best_score['joint_acc']))
            print('\n')

            if epoch+1 >= 8:  # To speed up
                eval_res_test = model_evaluation(model, test_data_raw, tokenizer, slot_meta, epoch + 1, args.op_code,
                                                  use_full_slot=args.use_full_slot, use_dt_only=args.use_dt_only, no_dial=args.no_dial, use_cls_only=args.use_cls_only, n_gpu=n_gpu)
                print("### Epoch {:} Test Score : ".format(epoch + 1), eval_res_test)

RuntimeError: ignored

In [None]:
train_data_file = "/content/drive/MyDrive/WOS/train_dials.json"
slot_meta = json.load(open("/content/drive/MyDrive/WOS/slot_meta.json"))
ontology = json.load(open("/content/drive/MyDrive/WOS/ontology.json"))
train_data, dev_data, dev_labels = load_dataset(train_data_file)

train_examples = get_examples_from_dialogues(train_data,
                                             user_first=False,
                                             dialogue_level=False)
dev_examples = get_examples_from_dialogues(dev_data,
                                           user_first=False,
                                           dialogue_level=False)

100%|██████████| 6301/6301 [00:00<00:00, 11783.13it/s]
100%|██████████| 699/699 [00:00<00:00, 16536.95it/s]


In [None]:
print(len(train_examples))
print(len(dev_examples))

46294
4951


In [None]:
train_examples[0]

DSTInputExample(guid='snowy-hat-8324:관광_식당_11-0', context_turns=[], current_turn=['', '서울 중앙에 있는 박물관을 찾아주세요'], label=['관광-종류-박물관', '관광-지역-서울 중앙'])

## TRADE Preprocessor 

기존의 GRU 기반의 인코더를 BERT-based Encoder로 바꿀 준비를 합시다.

1. 현재 `_convert_example_to_feature`에서는 `max_seq_length`를 핸들하고 있지 않습니다. `input_id`와 `segment_id`가 `max_seq_length`를 넘어가면 좌측부터 truncate시키는 코드를 삽입하세요.

2. hybrid approach에서 얻은 교훈을 바탕으로 gate class를 3개에서 5개로 늘려봅시다.
    - `gating2id`를 수정하세요
    - 이에 따른 `recover_state`를 수정하세요.
    
3. word dropout을 구현하세요.

In [None]:
class TRADEPreprocessor(DSTPreprocessor):
    def __init__(
        self,
        slot_meta,
        src_tokenizer,
        trg_tokenizer=None,
        ontology=None,
        max_seq_length=512,
    ):
        self.slot_meta = slot_meta
        self.src_tokenizer = src_tokenizer
        self.trg_tokenizer = trg_tokenizer if trg_tokenizer else src_tokenizer
        self.ontology = ontology
        self.gating2id = {"none": 0, "dontcare": 1, "yes": 2, "no": 3, "ptr": 4}
        self.id2gating = {v: k for k, v in self.gating2id.items()}
        self.max_seq_length = max_seq_length

    def _convert_example_to_feature(self, example):
        dialogue_context = " [SEP] ".join(example.context_turns + example.current_turn)

        input_id = self.src_tokenizer.encode(dialogue_context, add_special_tokens=False)
        max_length = self.max_seq_length - 2
        if len(input_id) > max_length:
            gap = len(input_id) - max_length
            input_id = input_id[gap:]

        input_id = (
            [self.src_tokenizer.cls_token_id]
            + input_id
            + [self.src_tokenizer.sep_token_id]
        )
        segment_id = [0] * len(input_id)

        target_ids = []
        gating_id = []
        if not example.label:
            example.label = []

        state = convert_state_dict(example.label)
        for slot in self.slot_meta:
            value = state.get(slot, "none")
            target_id = self.trg_tokenizer.encode(value, add_special_tokens=False) + [
                self.trg_tokenizer.sep_token_id
            ]
            target_ids.append(target_id)
            gating_id.append(self.gating2id.get(value, self.gating2id["ptr"]))
        target_ids = self.pad_ids(target_ids, self.trg_tokenizer.pad_token_id)
        return OpenVocabDSTFeature(
            example.guid, input_id, segment_id, gating_id, target_ids
        )

    def convert_examples_to_features(self, examples):
        return list(map(self._convert_example_to_feature, tqdm(examples)))

    def recover_state(self, gate_list, gen_list):
        assert len(gate_list) == len(self.slot_meta)
        assert len(gen_list) == len(self.slot_meta)

        recovered = []
        for slot, gate, value in zip(self.slot_meta, gate_list, gen_list):
            if self.id2gating[gate] == "none":
                continue

            if self.id2gating[gate] in ["dontcare", "yes", "no"]:
                recovered.append("%s-%s" % (slot, self.id2gating[gate]))
                continue

            token_id_list = []
            for id_ in value:
                if id_ in self.trg_tokenizer.all_special_ids:
                    break

                token_id_list.append(id_)
            value = self.trg_tokenizer.decode(token_id_list, skip_special_tokens=True)

            if value == "none":
                continue

            recovered.append("%s-%s" % (slot, value))
        return recovered

    def collate_fn(self, batch):
        guids = [b.guid for b in batch]
        input_ids = torch.LongTensor(
            self.pad_ids([b.input_id for b in batch], self.src_tokenizer.pad_token_id)
        )
        segment_ids = torch.LongTensor(
            self.pad_ids([b.segment_id for b in batch], self.src_tokenizer.pad_token_id)
        )
        input_masks = input_ids.ne(self.src_tokenizer.pad_token_id)

        gating_ids = torch.LongTensor([b.gating_id for b in batch])
        target_ids = self.pad_id_of_matrix(
            [torch.LongTensor(b.target_ids) for b in batch],
            self.trg_tokenizer.pad_token_id,
        )
        return input_ids, segment_ids, input_masks, gating_ids, target_ids, guids

## Convert_Examples_to_Features 

In [None]:
tokenizer = BertTokenizer.from_pretrained('dsksd/bert-ko-small-minimal')
processor = TRADEPreprocessor(slot_meta, tokenizer, max_seq_length=512)

train_features = processor.convert_examples_to_features(train_examples)
dev_features = processor.convert_examples_to_features(dev_examples)

  0%|          | 54/46294 [00:00<04:22, 176.14it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (537 > 512). Running this sequence through the model will result in indexing errors
100%|██████████| 46294/46294 [04:37<00:00, 166.83it/s]
100%|██████████| 4951/4951 [00:28<00:00, 172.46it/s]


In [None]:
print(len(train_features))
print(len(dev_features))

46255
4990


In [None]:
train_features[10]

OpenVocabDSTFeature(guid='polished-poetry-0057:관광_9-2', input_id=[2, 3, 7596, 4292, 3755, 4228, 18781, 6265, 10806, 4073, 3249, 11649, 4150, 35, 3, 6265, 10806, 4073, 7596, 4007, 6259, 4283, 2084, 4007, 24874, 28060, 16301, 15550, 12178, 4007, 3249, 4576, 6216, 18, 3, 3158, 2279, 7149, 9068, 3305, 6449, 4076, 8553, 18, 3, 28060, 16301, 15550, 12178, 4234, 9068, 4034, 6265, 27439, 10732, 11684, 4096, 10561, 18, 3, 6449, 4076, 4114, 4034, 4396, 4073, 20025, 4294, 18790, 4086, 3305, 6449, 4076, 8553, 18, 3], segment_id=[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], gating_id=[0, 0, 0, 0, 0, 4, 4, 0, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], target_ids=[[21832, 11764, 3, 0, 0], [21832, 11764, 3, 0, 0], [21832, 11764, 3, 0, 0], [21832, 11764,

In [None]:
lst = [tokenizer.decode(x) for x in train_features[10].target_ids]
print(lst)

['none [SEP] [PAD] [PAD]', 'none [SEP] [PAD] [PAD]', 'none [SEP] [PAD] [PAD]', 'none [SEP] [PAD] [PAD]', 'none [SEP] [PAD] [PAD]', '노량진 수산물 도매시장 [SEP]', '쇼핑 [SEP] [PAD] [PAD] [PAD]', 'none [SEP] [PAD] [PAD]', '서울 서쪽 [SEP] [PAD] [PAD]', 'none [SEP] [PAD] [PAD]', 'none [SEP] [PAD] [PAD]', 'none [SEP] [PAD] [PAD]', 'none [SEP] [PAD] [PAD]', 'none [SEP] [PAD] [PAD]', 'none [SEP] [PAD] [PAD]', 'none [SEP] [PAD] [PAD]', 'none [SEP] [PAD] [PAD]', 'none [SEP] [PAD] [PAD]', 'none [SEP] [PAD] [PAD]', 'none [SEP] [PAD] [PAD]', 'none [SEP] [PAD] [PAD]', 'none [SEP] [PAD] [PAD]', 'none [SEP] [PAD] [PAD]', 'none [SEP] [PAD] [PAD]', 'none [SEP] [PAD] [PAD]', 'none [SEP] [PAD] [PAD]', 'none [SEP] [PAD] [PAD]', 'none [SEP] [PAD] [PAD]', 'none [SEP] [PAD] [PAD]', 'none [SEP] [PAD] [PAD]', 'none [SEP] [PAD] [PAD]', 'none [SEP] [PAD] [PAD]', 'none [SEP] [PAD] [PAD]', 'none [SEP] [PAD] [PAD]', 'none [SEP] [PAD] [PAD]', 'none [SEP] [PAD] [PAD]', 'none [SEP] [PAD] [PAD]', 'none [SEP] [PAD] [PAD]', 'none [SEP

# Model 

1. `GRUEncoder`를 `BertModel`로 교체하세요. 이에 따라 `tie_weight` 함수가 수정되어야 합니다.

In [None]:
class TRADE(nn.Module):
    def __init__(self, config, slot_vocab, slot_meta, pad_idx=0):
        super(TRADE, self).__init__()
        self.slot_meta = slot_meta
        if config.model_name_or_path:
            self.encoder = BertModel.from_pretrained(config.model_name_or_path)
        else:
            self.encoder = BertModel(config)
            
        self.decoder = SlotGenerator(
            config.vocab_size,
            config.hidden_size,
            config.hidden_dropout_prob,
            config.n_gate,
            None,
            pad_idx,
        )
        
        # init for only subword embedding
        self.decoder.set_slot_idx(slot_vocab)
        self.tie_weight()

    def tie_weight(self):
        self.decoder.embed.weight = self.encoder.embeddings.word_embeddings.weight

    def forward(self, input_ids, token_type_ids, attention_mask=None, max_len=10, teacher=None):

        encoder_outputs, pooled_output = self.encoder(input_ids=input_ids)
        all_point_outputs, all_gate_outputs = self.decoder(
            input_ids, encoder_outputs, pooled_output.unsqueeze(0), attention_mask, max_len, teacher
        )

        return all_point_outputs, all_gate_outputs
    
class SlotGenerator(nn.Module):
    def __init__(
        self, vocab_size, hidden_size, dropout, n_gate, proj_dim=None, pad_idx=0
    ):
        super(SlotGenerator, self).__init__()
        self.pad_idx = pad_idx
        self.vocab_size = vocab_size
        self.embed = nn.Embedding(
            vocab_size, hidden_size, padding_idx=pad_idx
        )  # shared with encoder

        if proj_dim:
            self.proj_layer = nn.Linear(hidden_size, proj_dim, bias=False)
        else:
            self.proj_layer = None
        self.hidden_size = proj_dim if proj_dim else hidden_size

        self.gru = nn.GRU(
            self.hidden_size, self.hidden_size, 1, dropout=dropout, batch_first=True
        )
        self.n_gate = n_gate
        self.dropout = nn.Dropout(dropout)
        self.w_gen = nn.Linear(self.hidden_size * 3, 1)
        self.sigmoid = nn.Sigmoid()
        self.w_gate = nn.Linear(self.hidden_size, n_gate)

    def set_slot_idx(self, slot_vocab_idx):
        whole = []
        max_length = max(map(len, slot_vocab_idx))
        for idx in slot_vocab_idx:
            if len(idx) < max_length:
                gap = max_length - len(idx)
                idx.extend([self.pad_idx] * gap)
            whole.append(idx)
        self.slot_embed_idx = whole  # torch.LongTensor(whole)

    def embedding(self, x):
        x = self.embed(x)
        if self.proj_layer:
            x = self.proj_layer(x)
        return x

    def forward(
        self, input_ids, encoder_output, hidden, input_masks, max_len, teacher=None
    ):
        input_masks = input_masks.ne(1)
        # J, slot_meta : key : [domain, slot] ex> LongTensor([1,2])
        # J,2
        batch_size = encoder_output.size(0)
        slot = torch.LongTensor(self.slot_embed_idx).to(input_ids.device)  ##
        slot_e = torch.sum(self.embedding(slot), 1)  # J,d
        J = slot_e.size(0)

        all_point_outputs = torch.zeros(batch_size, J, max_len, self.vocab_size).to(
            input_ids.device
        )
        
        # Parallel Decoding
        w = slot_e.repeat(batch_size, 1).unsqueeze(1)
        hidden = hidden.repeat_interleave(J, dim=1)
        encoder_output = encoder_output.repeat_interleave(J, dim=0)
        input_ids = input_ids.repeat_interleave(J, dim=0)
        input_masks = input_masks.repeat_interleave(J, dim=0)
        for k in range(max_len):
            w = self.dropout(w)
            _, hidden = self.gru(w, hidden)  # 1,B,D

            # B,T,D * B,D,1 => B,T
            attn_e = torch.bmm(encoder_output, hidden.permute(1, 2, 0))  # B,T,1
            attn_e = attn_e.squeeze(-1).masked_fill(input_masks, -1e9)
            attn_history = F.softmax(attn_e, -1)  # B,T

            if self.proj_layer:
                hidden_proj = torch.matmul(hidden, self.proj_layer.weight)
            else:
                hidden_proj = hidden

            # B,D * D,V => B,V
            attn_v = torch.matmul(
                hidden_proj.squeeze(0), self.embed.weight.transpose(0, 1)
            )  # B,V
            attn_vocab = F.softmax(attn_v, -1)

            # B,1,T * B,T,D => B,1,D
            context = torch.bmm(attn_history.unsqueeze(1), encoder_output)  # B,1,D
            p_gen = self.sigmoid(
                self.w_gen(torch.cat([w, hidden.transpose(0, 1), context], -1))
            )  # B,1
            p_gen = p_gen.squeeze(-1)

            p_context_ptr = torch.zeros_like(attn_vocab).to(input_ids.device)
            p_context_ptr.scatter_add_(1, input_ids, attn_history)  # copy B,V
            p_final = p_gen * attn_vocab + (1 - p_gen) * p_context_ptr  # B,V
            _, w_idx = p_final.max(-1)

            if teacher is not None:
                w = self.embedding(teacher[:, :, k]).transpose(0, 1).reshape(batch_size * J, 1, -1)
            else:
                w = self.embedding(w_idx).unsqueeze(1)  # B,1,D
            if k == 0:
                gated_logit = self.w_gate(context.squeeze(1))  # B,3
                all_gate_outputs = gated_logit.view(batch_size, J, self.n_gate)
            all_point_outputs[:, :, k, :] = p_final.view(batch_size, J, self.vocab_size)

        return all_point_outputs, all_gate_outputs

# 모델 및 데이터 로더 정의

In [None]:
slot_vocab = []
for slot in slot_meta:
    slot_vocab.append(
        tokenizer.encode(slot.replace('-', ' '),
                         add_special_tokens=False)
    )
    
config = BertConfig.from_pretrained('dsksd/bert-ko-small-minimal')
config.model_name_or_path = 'dsksd/bert-ko-small-minimal'
config.n_gate = len(processor.gating2id)
config.proj_dim = None
model = TRADE(config, slot_vocab, slot_meta)



In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

train_data = WOSDataset(train_features)
train_sampler = RandomSampler(train_data)
train_loader = DataLoader(train_data, batch_size=16, sampler=train_sampler, collate_fn=processor.collate_fn)

dev_data = WOSDataset(dev_features)
dev_sampler = SequentialSampler(dev_data)
dev_loader = DataLoader(dev_data, batch_size=8, sampler=dev_sampler, collate_fn=processor.collate_fn)

# Optimizer & Scheduler 선언

In [None]:
n_epochs = 10
no_decay = ["bias", "LayerNorm.weight"]
optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": 0.01,
        },
        {
            "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
            "weight_decay": 0.0,
        },
    ]

t_total = len(train_loader) * n_epochs
optimizer = AdamW(optimizer_grouped_parameters, lr=3e-5, eps=1e-8)
scheduler = get_linear_schedule_with_warmup(
    optimizer, num_warmup_steps=0.1, num_training_steps=t_total
)
teacher_forcing = 0.5
model.to(device)

def masked_cross_entropy_for_value(logits, target, pad_idx=0):
    mask = target.ne(pad_idx)
    logits_flat = logits.view(-1, logits.size(-1))
    log_probs_flat = torch.log(logits_flat)
    target_flat = target.view(-1, 1)
    losses_flat = -torch.gather(log_probs_flat, dim=1, index=target_flat)
    losses = losses_flat.view(*target.size())
    losses = losses * mask.float()
    loss = losses.sum() / (mask.sum().float())
    return loss

loss_fnc_1 = masked_cross_entropy_for_value  # generation
loss_fnc_2 = nn.CrossEntropyLoss()  # gating

## Train

In [None]:
 for epoch in range(n_epochs):
    batch_loss = []
    model.train()
    for step, batch in enumerate(train_loader):
        input_ids, segment_ids, input_masks, gating_ids, target_ids, guids = [b.to(device) if not isinstance(b, list) else b for b in batch]
        if teacher_forcing > 0.0 and random.random() < teacher_forcing:
            tf = target_ids
        else:
            tf = None

        all_point_outputs, all_gate_outputs = model(input_ids, segment_ids, input_masks, target_ids.size(-1))  # gt - length (generation)
        loss_1 = loss_fnc_1(all_point_outputs.contiguous(), target_ids.contiguous().view(-1))
        loss_2 = loss_fnc_2(all_gate_outputs.contiguous().view(-1, 5), gating_ids.contiguous().view(-1))
        loss = loss_1 + loss_2
        batch_loss.append(loss.item())

        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()
        if step % 100 == 0:
            print('[%d/%d] [%d/%d] %f' % (epoch, n_epochs, step, len(train_loader), loss.item()))
                
    predictions = inference(model, dev_loader, processor, device)
    eval_result = _evaluation(predictions, dev_labels, slot_meta)
    for k, v in eval_result.items():
        print(f"{k}: {v}")

[0/10] [0/2891] 10.265509
[0/10] [100/2891] 1.779430
[0/10] [200/2891] 1.638214
[0/10] [300/2891] 1.447321
[0/10] [400/2891] 1.126810
[0/10] [500/2891] 1.047267
[0/10] [600/2891] 1.076839
[0/10] [700/2891] 0.928714
[0/10] [800/2891] 0.688376
[0/10] [900/2891] 0.605567
[0/10] [1000/2891] 0.638915
[0/10] [1100/2891] 0.530238
[0/10] [1200/2891] 0.453068
[0/10] [1300/2891] 0.450586
[0/10] [1400/2891] 0.394243
[0/10] [1500/2891] 0.290189
[0/10] [1600/2891] 0.471834
[0/10] [1700/2891] 0.314694
[0/10] [1800/2891] 0.270238
[0/10] [1900/2891] 0.414060
[0/10] [2000/2891] 0.290958
[0/10] [2100/2891] 0.350836
[0/10] [2200/2891] 0.272087
[0/10] [2300/2891] 0.359535
[0/10] [2400/2891] 0.315303
[0/10] [2500/2891] 0.296642
[0/10] [2600/2891] 0.198743
[0/10] [2700/2891] 0.247285
[0/10] [2800/2891] 0.248392


100%|██████████| 624/624 [01:09<00:00,  9.02it/s]


{'joint_goal_accuracy': 0.30280561122244487, 'turn_slot_accuracy': 0.9667022934758523, 'turn_slot_f1': 0.8535316828673348}
joint_goal_accuracy: 0.30280561122244487
turn_slot_accuracy: 0.9667022934758523
turn_slot_f1: 0.8535316828673348
[1/10] [0/2891] 0.168331
[1/10] [100/2891] 0.183329
[1/10] [200/2891] 0.206472
[1/10] [300/2891] 0.219990
[1/10] [400/2891] 0.209132
[1/10] [500/2891] 0.169938
[1/10] [600/2891] 0.149240
[1/10] [700/2891] 0.139512
[1/10] [800/2891] 0.182836
[1/10] [900/2891] 0.213563
[1/10] [1000/2891] 0.205337
[1/10] [1100/2891] 0.172299
[1/10] [1200/2891] 0.110077
[1/10] [1300/2891] 0.115137
[1/10] [1400/2891] 0.226755
[1/10] [1500/2891] 0.158137
[1/10] [1600/2891] 0.102026
[1/10] [1700/2891] 0.170833
[1/10] [1800/2891] 0.142623
[1/10] [1900/2891] 0.116227
[1/10] [2000/2891] 0.121230
[1/10] [2100/2891] 0.161562


KeyboardInterrupt: 

## Inference 

In [None]:
eval_data = json.load(open(f"/opt/ml/input/data/eval_dataset/eval_dials.json", "r"))

eval_examples = get_examples_from_dialogues(
    eval_data, user_first=False, dialogue_level=False
)

# Extracting Featrues
eval_features = processor.convert_examples_to_features(eval_examples)
eval_data = WOSDataset(eval_features)
eval_sampler = SequentialSampler(eval_data)
eval_loader = DataLoader(
    eval_data,
    batch_size=8,
    sampler=eval_sampler,
    collate_fn=processor.collate_fn,
)

100%|██████████| 1000/1000 [00:00<00:00, 1936.46it/s]


In [None]:
predictions = inference(model, eval_loader, processor, device)

100%|██████████| 944/944 [01:45<00:00,  8.92it/s]


In [None]:
json.dump(predictions, open('predictions.csv', 'w'), indent=2, ensure_ascii=False) 