# Train ChatModel

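Fine-tune a pretrained Japanese GPT-2 model (`GPT2DoubleHeadsModel`) as a chat model. The language-modeling head learns to generate the response to a context. The multiple-choice head learns to pick the true response among randomly sampled distractors.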

## Install dependent libraries
This section installs the required packages. Versions are pinned for reproducibility.

In [None]:
! pip install transformers==2.6.0
! pip install tqdm==4.43.0
! pip install mecab-python3==0.996.2
! pip install attrdict==2.0.1
! pip install tensorboard==2.1.1

## Test library

Test all the libraries used in this notebook.

## Parameters
Declare the parameters that `papermill` can override.

In [None]:
# Parameters that papermill may override (papermill finds this cell by its
# "parameters" tag).

# --- general ---
name = "model"                           # run name; sub-directory of output_dir
data_dir = "data_sample"                 # holds train.tsv / valid.tsv / test.tsv
pretrained_dir = "../gpt/output/models"  # pretrained model and tokenizer
output_dir = "output"                    # root for saved models and TensorBoard runs

# --- training ---
seed = 1234
num_distructors = 2     # distractor responses attached to each example
num_epochs = 10
batch_size = 2
learning_rate = 5e-5
max_grad_norm = 1.0     # gradient-norm clipping threshold
warmup_rate = 0.1       # fraction of total training steps used for LR warmup
patience = 3            # non-improving validation epochs tolerated before stopping

In [None]:
import attrdict

# Bundle every notebook parameter into one attribute-accessible object so the
# rest of the notebook reads them as `_params.<name>`.
_params = attrdict.AttrDict(dict(
    name=name,
    data_dir=data_dir,
    pretrained_dir=pretrained_dir,
    output_dir=output_dir,
    seed=seed,
    num_distructors=num_distructors,
    num_epochs=num_epochs,
    batch_size=batch_size,
    learning_rate=learning_rate,
    max_grad_norm=max_grad_norm,
    warmup_rate=warmup_rate,
    patience=patience,
))

# Drop the loose globals so later cells cannot accidentally read them directly.
# NOTE: this makes the cell non-idempotent; re-running it alone raises NameError.
del (name, data_dir, output_dir, pretrained_dir, seed, num_distructors,
     num_epochs, batch_size, learning_rate, max_grad_norm, warmup_rate, patience)

## Define preprocessor and tokenizer

In [None]:
import transformers

def build_tokenizer(model_dir):
    """Load the Japanese BERT tokenizer from ``model_dir`` and register the
    chat special tokens.

    The tokens are registered in the order ``<SEP>``, ``<BOS>``, ``<EOS>``.
    ``convert_to_model_inputs`` relies on this order when it unpacks
    ``tokenizer.additional_special_tokens``.
    """
    # separator, begin of sentence, end of sentence
    chat_tokens = ["<SEP>", "<BOS>", "<EOS>"]
    loaded = transformers.BertJapaneseTokenizer.from_pretrained(model_dir)
    loaded.add_special_tokens({"additional_special_tokens": chat_tokens})
    return loaded
    

## Define dataset

Define a function that adds **distractors** (responses randomly sampled from the dataset) to each conversation. The code spells them `distructors`.

In [None]:
import numpy


def add_distructors(conversations, num_distructors):
    """Attach randomly sampled distractor responses to each conversation.

    Args:
        conversations: iterable of ``(context, response)`` pairs.
        num_distructors: number of distractors per conversation (> 0).

    Returns:
        A list with one ``[context, response, distractor_1, ...]`` list per
        conversation. Distractors are responses of *other* conversations,
        never the conversation's own response. They are distinct from each
        other whenever enough candidates exist. Sampling uses the global
        NumPy RNG, so ``set_seed`` makes it reproducible.

    Raises:
        Exception: if ``num_distructors`` is not positive.
        ValueError: if there is exactly one conversation, because no other
            response exists to use as a distractor.
    """
    if num_distructors <= 0:
        raise Exception("num_distractors should be larger than 0")

    contexts = []
    responses = []
    for context, response in conversations:
        contexts.append(context)
        responses.append(response)

    num_items = len(responses)
    if num_items == 0:
        return []
    if num_items < 2:
        raise ValueError("at least 2 conversations are required to sample distractors")

    # Sample without replacement when there are enough other responses, so one
    # item's distractors differ; otherwise repeats are unavoidable.
    replace = num_items - 1 < num_distructors

    dist = []
    for i in range(num_items):
        offsets = numpy.random.choice(num_items - 1, size=num_distructors, replace=replace)
        # Map 0..n-2 onto every index except i, so the gold response is never
        # used as its own distractor (which would contradict MC label 0).
        others = [j if j < i else j + 1 for j in offsets]
        dist.append([contexts[i], responses[i]] + [responses[j] for j in others])

    return dist

In [None]:
import torch


def convert_to_model_inputs(tokenizer, context, response, is_distructor):
    """Encode one (context, candidate response) pair.

    The encoded sequence is ``<BOS> context <SEP> response <EOS>``.

    Args:
        tokenizer: tokenizer whose ``additional_special_tokens`` are, in
            order, the SEP, BOS and EOS tokens (as registered by
            ``build_tokenizer``).
        context: context text.
        response: candidate response text.
        is_distructor: True if ``response`` is a distractor. All of its LM
            targets are then ignored.

    Returns:
        ``(token_ids, segment_ids, target_ids, mc_id)``:
        - token_ids: ids of the whole sequence.
        - segment_ids: 0 for the ``<BOS> context`` part and 1 for the
          ``<SEP> response <EOS>`` part.
        - target_ids: LM labels aligned with ``token_ids``. -100 marks
          ignored positions (the first part, the ``<SEP>`` token, and
          everything for a distractor).
        - mc_id: index of the last token (``<EOS>``).
    """
    SEP, BOS, EOS = tokenizer.additional_special_tokens

    seq = [
        [BOS] + tokenizer.tokenize(context),
        [SEP] + tokenizer.tokenize(response) + [EOS],
    ]
    tokens = seq[0] + seq[1]
    token_ids = tokenizer.convert_tokens_to_ids(tokens)

    # Segment ids are the raw values 0 (part A) and 1 (part B). They must not
    # go through convert_tokens_to_ids: 0/1 are not vocabulary tokens, so every
    # position would map to the same (unknown-token) id and the two parts
    # would become indistinguishable.
    segment_ids = [0] * len(seq[0]) + [1] * len(seq[1])

    # -100 marks positions to ignore in the LM loss.
    target_ids = [-100] * len(seq[0])
    if is_distructor:
        # A distractor is not a reply to learn; ignore all its targets.
        target_ids += [-100] * len(seq[1])
    else:
        # Ignore <SEP>; the response tokens and <EOS> are the targets.
        target_ids += [-100] + token_ids[len(seq[0]) + 1:]

    # Position of the last token (<EOS>).
    mc_id = len(token_ids) - 1

    return token_ids, segment_ids, target_ids, mc_id



class ChatDataset(torch.utils.data.Dataset):
    """In-memory dataset of encoded (context, response, distractors) examples.

    All examples are encoded eagerly in ``__init__``. Each item is
    ``(token_ids, segment_ids, target_ids, mc_ids, label)``. The first four
    are per-candidate lists with the gold response as candidate 0, so
    ``label`` is always 0.
    """

    def __init__(self, tokenizer, inputs):
        self._inputs = [
            self._build(tokenizer, ctx, gold, wrong)
            for ctx, gold, wrong in inputs
        ]

    def __len__(self):
        return len(self._inputs)

    def __getitem__(self, idx):
        return self._inputs[idx]

    def _build(self, tokenizer, context, response, distructors):
        # Encode the gold response followed by every distractor, collecting
        # the four outputs column by column.
        columns = ([], [], [], [])
        candidates = [response] + distructors
        for position, candidate in enumerate(candidates):
            encoded = convert_to_model_inputs(tokenizer, context, candidate, position > 0)
            token_ids, segment_ids, target_ids, _ = encoded
            assert len(token_ids) == len(segment_ids) == len(target_ids)
            for column, value in zip(columns, encoded):
                column.append(value)

        all_token_ids, all_segment_ids, all_target_ids, all_mc_ids = columns
        return all_token_ids, all_segment_ids, all_target_ids, all_mc_ids, 0


class PaddingCollation:
    """Collate ``ChatDataset`` items into padded tensors.

    Args:
        padding_value: value used to pad token ids and segment ids.
        label_padding_value: value used to pad LM target ids. Defaults to
            -100, the ignore index used for targets in this notebook, so
            padded positions do not contribute to the LM loss.
    """

    # Position of the LM target ids among the per-candidate sequence fields.
    _TARGET_FIELD = 2

    def __init__(self, padding_value, label_padding_value=-100):
        self._padding_value = padding_value
        self._label_padding_value = label_padding_value

    def apply(self, batch):
        """Pad and stack a list of dataset items.

        Every item must have the same number of candidates.

        Returns:
            2 -> batch_size
            3 -> 1 + num_distructors
            7 -> max_length
            [
                torch.Size([2, 3, 7]),  # token ids      (padded with padding_value)
                torch.Size([2, 3, 7]),  # segment ids    (padded with padding_value)
                torch.Size([2, 3, 7]),  # LM target ids  (padded with label_padding_value)
                torch.Size([2, 3]),     # MC token ids
                torch.Size([2])         # MC labels
            ]
        """
        fields = list(zip(*batch))
        seq_fields = fields[:-2]
        mc_ids, mc_labels = [torch.tensor(x) for x in fields[-2:]]

        padded = []
        for field_idx, per_item in enumerate(seq_fields):
            if field_idx == self._TARGET_FIELD:
                pad = self._label_padding_value
            else:
                pad = self._padding_value
            # Flatten (batch, candidate) into one list so that every candidate
            # of every item is padded to the same maximum length.
            flat = [torch.tensor(seq) for candidates in per_item for seq in candidates]
            x = torch.nn.utils.rnn.pad_sequence(flat, batch_first=True, padding_value=pad)
            padded.append(x.reshape(len(batch), -1, x.size(1)))

        return [*padded, mc_ids, mc_labels]


In [None]:
def build_dataloader(tokenizer, conversation_with_distructors, batch_size, shuffle):
    """Wrap ``[context, response, *distractors]`` rows in a padded DataLoader.

    Args:
        tokenizer: tokenizer used for encoding; its ``pad_token_id`` is used
            for padding.
        conversation_with_distructors: rows as produced by ``add_distructors``.
        batch_size: examples per batch.
        shuffle: whether to reshuffle every epoch.
    """
    # Split each row into (context, response, distractors).
    triples = [(row[0], row[1], row[2:]) for row in conversation_with_distructors]

    dataset = ChatDataset(tokenizer, triples)
    collation = PaddingCollation(padding_value=tokenizer.pad_token_id)
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        collate_fn=collation.apply,
    )

## Define model

In [None]:
from transformers import GPT2DoubleHeadsModel

def build_model(model_dir, tokenizer):
    """Load a pretrained ``GPT2DoubleHeadsModel`` from ``model_dir``.

    The token embedding matrix is resized to ``len(tokenizer)``, so that tokens
    added to the tokenizer (e.g. by ``build_tokenizer``) get embeddings.
    """
    net = GPT2DoubleHeadsModel.from_pretrained(model_dir)
    vocab_size = len(tokenizer)
    net.resize_token_embeddings(vocab_size)
    return net

## Build and save vocabulary

In [None]:
import os


# Create the per-model output directory and save the tokenizer (vocabulary)
# there; the training and evaluation cells reload it with `from_pretrained`.
# `makedirs` also creates missing parents such as `_params.output_dir` itself
# (plain `os.mkdir` failed if that did not exist yet).
_model_output_dir = os.path.join(_params.output_dir, _params.name)
os.makedirs(_model_output_dir, exist_ok=True)

_tokenizer = build_tokenizer(_params.pretrained_dir)
_tokenizer.save_pretrained(_model_output_dir)

## Train and save model

In [None]:
import torch
import numpy as np
import random


def set_seed(seed):
    """Seed Python's ``random``, NumPy, and PyTorch (CPU and all CUDA devices).

    Also forces cuDNN to use deterministic algorithms and disables its
    benchmark-based algorithm selection, for reproducible runs.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


def calc_ppl(loss):
    """Return the perplexity ``exp(loss)`` for a mean softmax cross-entropy loss.

    ``torch.exp`` returns ``inf`` on overflow rather than raising.
    """
    return torch.tensor(loss).exp().item()


In [None]:
import tqdm


def train_model(model_output_dir, net, dataloader_dict, train_config):
    """Train ``net`` with early stopping, or only evaluate it.

    Epoch 0 runs only the validation phase, to measure the model before any
    fine-tuning. Epochs 1..num_epochs each run a training phase followed by a
    validation phase.

    The optimized loss is ``2 * lm_loss + mc_loss``. Model selection and
    early stopping use its per-example average on the validation set. The
    reported perplexity is ``exp`` of the averaged LM loss alone.

    Args:
        model_output_dir: directory the best model is saved to with
            ``save_pretrained``; falsy disables saving.
        net: model called with ``input_ids``, ``token_type_ids``,
            ``lm_labels``, ``mc_token_ids`` and ``mc_labels``; its first two
            outputs are taken as the LM and MC losses.
        dataloader_dict: mapping with a "val" loader and, when
            ``num_epochs > 0``, a "train" loader. Batches are sequences of 5
            tensors (input ids, segment ids, LM targets, MC token ids, MC labels).
        train_config: AttrDict with ``device``, ``num_epochs``, ``patience``
            and ``writer``. ``writer`` is a SummaryWriter, or None to disable
            logging, model saving and early stopping. Training also needs
            ``optimizer``, ``scheduler`` and ``max_grad_norm``.

    Returns:
        None.
    """
    PHASE_TRAIN = "train"
    PHASE_VAL = "val"
    # Weights of the language-modeling and multiple-choice losses.
    lm_weight = 2.0
    mc_weight = 1.0

    # Best validation result so far. "model" is a reference to the live
    # `net`, not a snapshot; the checkpoint saved on disk is the actual best.
    best = {"model": None, "epoch": 0, "loss": float("infinity"), "ppl": float("infinity")}

    # Number of optimizer steps taken; x-axis of the TensorBoard plots.
    num_iters = 0

    # Number of consecutive validation epochs without improvement.
    num_patience = 0

    net.to(train_config.device)

    for epoch in range(train_config.num_epochs + 1):
        print("Epoch {}/{}".format(epoch, train_config.num_epochs))
        # Training phase, then validation phase.
        for phase in [PHASE_TRAIN, PHASE_VAL]:
            # Switch the network mode; this affects layers such as Dropout.
            if phase == PHASE_TRAIN:
                net.train()
            elif phase == PHASE_VAL:
                net.eval()
            else:
                raise Exception("got {} expected one of {}".format(phase, [PHASE_TRAIN, PHASE_VAL]))

            # Sums of per-batch mean losses, weighted by batch size.
            epoch_loss = 0
            epoch_lm_loss = 0

            # Epoch 0 skips training to measure validation performance before
            # any fine-tuning.
            if epoch == 0 and phase == PHASE_TRAIN:
                continue

            for batch in tqdm.tqdm(dataloader_dict[phase], disable=True):
                # Move the batch to the configured device.
                inputs = [x.to(train_config.device) for x in batch]

                if phase == PHASE_TRAIN:
                    train_config.optimizer.zero_grad()

                # Build the autograd graph only during training.
                with torch.set_grad_enabled(phase == PHASE_TRAIN):
                    model_out = net(
                        input_ids=inputs[0],
                        token_type_ids=inputs[1],
                        lm_labels=inputs[2],
                        mc_token_ids=inputs[3],
                        mc_labels=inputs[4],
                    )
                    lm_loss, mc_loss = model_out[:2]
                    loss = lm_loss * lm_weight + mc_loss * mc_weight

                    if phase == PHASE_TRAIN:
                        loss.backward()
                        torch.nn.utils.clip_grad_norm_(net.parameters(), train_config.max_grad_norm)
                        train_config.optimizer.step()
                        train_config.scheduler.step()

                        num_iters += 1

                    batch_size = inputs[0].size()[0]
                    epoch_loss += loss.item() * batch_size
                    epoch_lm_loss += lm_loss.item() * batch_size

                    # Per-step TensorBoard logging, training phase only.
                    if phase == PHASE_TRAIN and train_config.writer:
                        train_config.writer.add_scalars("train/loss", {phase: loss.item()}, num_iters)
                        train_config.writer.add_scalars("train/lm_loss", {phase: lm_loss.item()}, num_iters)
                        train_config.writer.add_scalars("train/mc_loss", {phase: mc_loss.item()}, num_iters)

                        train_config.writer.add_scalars("train/lr", {phase: train_config.scheduler.get_lr()[0]}, num_iters)

            num_examples = len(dataloader_dict[phase].dataset)
            epoch_loss = epoch_loss / num_examples
            epoch_lm_loss = epoch_lm_loss / num_examples
            # Perplexity is exp of the mean LM cross-entropy only. The combined
            # loss also contains the weighted MC loss and is not an LM log-loss.
            epoch_ppl = calc_ppl(epoch_lm_loss)
            print("phase {}, loss: {:.4f}, ppl: {:.4f}".format(phase, epoch_loss, epoch_ppl))

            if train_config.writer and phase == PHASE_VAL:
                train_config.writer.add_scalars("train/loss", {phase: epoch_loss}, num_iters)
                train_config.writer.add_scalars("train/lm_loss", {phase: epoch_lm_loss}, num_iters)
                train_config.writer.add_scalars("metric/ppl", {phase: epoch_ppl}, num_iters)

                if best["loss"] > epoch_loss:
                    best = {"model": net, "epoch": epoch, "loss": epoch_loss, "ppl": epoch_ppl}
                    num_patience = 0
                    # save model
                    if model_output_dir:
                        print("Save model, epoch:", epoch)
                        net.save_pretrained(model_output_dir)
                else:
                    num_patience += 1
                    print("Patience {}, epoch: {}".format(num_patience, epoch))

                # Early stopping once the tolerated number of non-improving
                # validation epochs is exceeded.
                if num_patience > train_config.patience:
                    return

In [None]:
import os
from torch.utils.tensorboard import SummaryWriter


def get_texts(filepath, encoding="utf-8"):
    """Read a tab-separated conversation file.

    Args:
        filepath: path of the TSV file.
        encoding: text encoding of the file. Defaults to UTF-8, so reading does
            not depend on the system locale (assumes the data files are UTF-8).

    Returns:
        One list of tab-separated fields per non-empty line (e.g.
        ``[context, response]``), with the trailing newline removed. Blank
        lines are skipped; they would otherwise become malformed ``[""]`` rows.
    """
    # `with` closes the file handle once reading is done.
    with open(filepath, encoding=encoding) as f:
        return [line.strip("\n").split("\t") for line in f if line.strip("\n")]


def train(model_output_dir, params):
    """Fine-tune the pretrained model and save the best checkpoint.

    Steps:
    - Load the tokenizer previously saved to ``model_output_dir``.
    - Build the train/val loaders from ``<data_dir>/train.tsv`` and
      ``<data_dir>/valid.tsv``.
    - Run ``train_model`` with Adam and a linear warmup/decay schedule.
    - Log to TensorBoard under ``<output_dir>/runs/<name>``.

    Args:
        model_output_dir: directory holding the saved tokenizer; the best
            model is saved here too.
        params: AttrDict of notebook parameters (see the Parameters section).
    """
    # Fix seed for reproducibility
    set_seed(seed=params.seed)

    # Dataset and dataloader
    tokenizer = transformers.BertJapaneseTokenizer.from_pretrained(model_output_dir)

    def _build_split_loader(filename, shuffle):
        # Read one split, attach sampled distractors and wrap it in a DataLoader.
        conversations = get_texts(os.path.join(params.data_dir, filename))
        with_distructors = add_distructors(conversations, params.num_distructors)
        return build_dataloader(tokenizer, with_distructors, params.batch_size, shuffle=shuffle)

    dataloader_dict = {
        "train": _build_split_loader("train.tsv", shuffle=True),
        "val": _build_split_loader("valid.tsv", shuffle=False),
    }

    # Model
    net = build_model(params.pretrained_dir, tokenizer)

    # Optimizer and a linear warmup/decay schedule over all training steps.
    optimizer = torch.optim.Adam(net.parameters(), lr=params.learning_rate)
    total_steps = len(dataloader_dict["train"]) * params.num_epochs
    scheduler = transformers.get_linear_schedule_with_warmup(
        optimizer,
        # The scheduler counts whole steps, so pass an integer.
        num_warmup_steps=int(total_steps * params.warmup_rate),
        num_training_steps=total_steps,
    )
    writer = SummaryWriter(log_dir=os.path.join(params.output_dir, "runs", params.name))
    train_config = attrdict.AttrDict({
        "optimizer": optimizer,
        "scheduler": scheduler,
        "writer": writer,
        "num_epochs": params.num_epochs,
        "max_grad_norm": params.max_grad_norm,
        "patience": params.patience,
        "device": torch.device("cuda:0" if torch.cuda.is_available() else "cpu"),
    })

    try:
        train_model(model_output_dir, net, dataloader_dict, train_config)
    finally:
        # Flush pending events to disk and release the log file, even when
        # training stops early or fails.
        writer.close()

In [None]:
# Fine-tune and save the best checkpoint to `_model_output_dir`.
train(model_output_dir=_model_output_dir, params=_params)

## Evaluate the best model

In [None]:
def evaluate(model_output_dir, params):
    """Evaluate the model saved in ``model_output_dir`` on ``<data_dir>/test.tsv``.

    Reuses ``train_model`` with zero training epochs, so only the epoch-0
    validation pass runs. The test loader is therefore registered under the
    "val" key, and its results are printed as phase "val". With ``writer``
    set to None, nothing is logged or saved.
    """
    tokenizer = transformers.BertJapaneseTokenizer.from_pretrained(model_output_dir)
    net = transformers.GPT2DoubleHeadsModel.from_pretrained(model_output_dir)

    test_rows = add_distructors(get_texts(params.data_dir + "/test.tsv"), params.num_distructors)
    test_loader = build_dataloader(tokenizer, test_rows, params.batch_size, shuffle=False)

    eval_config = attrdict.AttrDict(dict(
        writer=None,
        num_epochs=0,
        patience=1,
        device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu"),
    ))

    return train_model(None, net, {"val": test_loader}, eval_config)

In [None]:
# Evaluate the saved best checkpoint on the test split.
evaluate(model_output_dir=_model_output_dir, params=_params)