### ContextTransformer Implementation

### Import libraries

In [None]:
import numpy as np
import torch
from torch import nn
import torch
import torch.nn as nn
import time
import math
from torch import Tensor
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader
from torch.nn import (
    TransformerEncoder, TransformerDecoder,
    TransformerEncoderLayer, TransformerDecoderLayer
)
from janome.tokenizer import Tokenizer
import numpy as np
from tqdm import tqdm
from sklearn.model_selection import train_test_split
import random
from DialogueQualityEvaluator.dialoguequalityevaluator import *

# Seed every RNG the notebook draws from so runs are reproducible:
# Python's `random` (random.shuffle of the training pairs), NumPy's global RNG
# (sklearn's train_test_split uses it when no random_state is given) and
# PyTorch (parameter initialisation, dropout, DataLoader shuffling).
SEED = 20060317
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

### BASELINE

In [None]:
class Seq2SeqTransformer(nn.Module):
    """Encoder-decoder Transformer that maps source token ids to target-vocabulary logits.

    The encoder/decoder layers are built with PyTorch's default
    ``batch_first=False``, so all tensors are sequence-first:
    token ids are ``(seq_len, batch)`` and embeddings ``(seq_len, batch, embedding_size)``.

    Args:
        num_encoder_layers: number of stacked encoder layers.
        num_decoder_layers: number of stacked decoder layers.
        embedding_size: model width (d_model); must be divisible by ``nhead``.
        vocab_size_src: size of the source vocabulary.
        vocab_size_tgt: size of the target vocabulary (also the logit dimension).
        dim_feedforward: hidden size of each layer's feed-forward sub-network.
        dropout: dropout probability used by the positional encoding AND by
            every encoder/decoder layer.
        nhead: number of attention heads.
    """

    def __init__(
        self, num_encoder_layers: int, num_decoder_layers: int,
        embedding_size: int, vocab_size_src: int, vocab_size_tgt: int,
        dim_feedforward: int = 512, dropout: float = 0.1, nhead: int = 8
    ):
        super(Seq2SeqTransformer, self).__init__()

        # Submodules are created in the same order as before so parameter
        # initialisation consumes the RNG identically.
        self.token_embedding_src = TokenEmbedding(vocab_size_src, embedding_size)
        self.positional_encoding = PositionalEncoding(embedding_size, dropout=dropout)
        # Pass `dropout` through: previously the layers silently used
        # PyTorch's own default regardless of the argument.
        encoder_layer = TransformerEncoderLayer(
            d_model=embedding_size, nhead=nhead,
            dim_feedforward=dim_feedforward, dropout=dropout
        )
        self.transformer_encoder = TransformerEncoder(encoder_layer, num_layers=num_encoder_layers)

        self.token_embedding_tgt = TokenEmbedding(vocab_size_tgt, embedding_size)
        decoder_layer = TransformerDecoderLayer(
            d_model=embedding_size, nhead=nhead,
            dim_feedforward=dim_feedforward, dropout=dropout
        )
        self.transformer_decoder = TransformerDecoder(decoder_layer, num_layers=num_decoder_layers)

        # Projects decoder states to target-vocabulary logits.
        self.output = nn.Linear(embedding_size, vocab_size_tgt)

    def forward(
        self, src: Tensor, tgt: Tensor,
        mask_src: Tensor, mask_tgt: Tensor,
        padding_mask_src: Tensor, padding_mask_tgt: Tensor,
        memory_key_padding_mask: Tensor
    ):
        """Teacher-forced pass: encode ``src``, decode ``tgt`` and return logits
        of shape ``(tgt_len, batch, vocab_size_tgt)``."""
        memory = self.encode(src, mask_src, padding_mask_src)
        outs = self.decode(tgt, memory, mask_tgt, padding_mask_tgt, memory_key_padding_mask)
        return self.output(outs)

    def encode(self, src: Tensor, mask_src: Tensor, padding_mask_src=None):
        """Embed ``src`` and run the encoder stack; returns the encoder memory.

        ``padding_mask_src`` (optional, ``(batch, src_len)`` bool, True at PAD)
        lets callers hide padding from self-attention.
        """
        embedding_src = self.positional_encoding(self.token_embedding_src(src))
        return self.transformer_encoder(
            embedding_src, mask=mask_src, src_key_padding_mask=padding_mask_src
        )

    def decode(self, tgt: Tensor, memory: Tensor, mask_tgt: Tensor,
               padding_mask_tgt=None, memory_key_padding_mask=None):
        """Embed ``tgt`` and run the decoder stack over ``memory``.

        Returns decoder hidden states (not logits); apply ``self.output`` to
        obtain vocabulary scores. The optional padding masks are forwarded to
        the decoder's self- and cross-attention.
        """
        embedding_tgt = self.positional_encoding(self.token_embedding_tgt(tgt))
        return self.transformer_decoder(
            embedding_tgt, memory,
            tgt_mask=mask_tgt,
            tgt_key_padding_mask=padding_mask_tgt,
            memory_key_padding_mask=memory_key_padding_mask,
        )

In [None]:
class TokenEmbedding(nn.Module):
    """Embedding lookup whose output is multiplied by sqrt(embedding_size).

    Input token ids of any shape are cast to ``long`` before the lookup, so
    float id tensors are accepted too; the result has one extra trailing
    dimension of size ``embedding_size``.
    """

    def __init__(self, vocab_size, embedding_size):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_size)
        self.embedding_size = embedding_size

    def forward(self, tokens: Tensor):
        vectors = self.embedding(tokens.long())
        scale = math.sqrt(self.embedding_size)
        return vectors * scale

In [None]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to sequence-first embeddings, then apply dropout.

    Even feature columns hold ``sin(pos * 10000^(-2i/d))`` and odd columns the
    matching cosine. The table is precomputed for ``maxlen`` positions and
    stored as a buffer of shape ``(maxlen, 1, embedding_size)``, so it
    broadcasts over the batch axis of a ``(seq_len, batch, embedding_size)`` input.

    Args:
        embedding_size: feature dimension ``d`` (odd values are supported).
        dropout: dropout probability applied after the addition.
        maxlen: longest sequence the table covers.
    """

    def __init__(self, embedding_size: int, dropout: float, maxlen: int = 5000):
        super(PositionalEncoding, self).__init__()

        # Inverse frequencies 10000^(-2i/d), one per even column index 2i.
        den = torch.exp(-torch.arange(0, embedding_size, 2) * math.log(10000) / embedding_size)
        pos = torch.arange(0, maxlen).reshape(maxlen, 1)
        embedding_pos = torch.zeros((maxlen, embedding_size))
        embedding_pos[:, 0::2] = torch.sin(pos * den)
        # With an odd embedding_size there is one fewer odd column than even
        # column, so only the first embedding_size // 2 frequencies get a cosine
        # (for even sizes this slice is all of `den`).
        embedding_pos[:, 1::2] = torch.cos(pos * den[: embedding_size // 2])
        embedding_pos = embedding_pos.unsqueeze(-2)

        self.dropout = nn.Dropout(dropout)
        self.register_buffer('embedding_pos', embedding_pos)

    def forward(self, token_embedding: Tensor):
        """Return ``dropout(token_embedding + PE[:seq_len])``.

        Raises:
            ValueError: if the sequence is longer than the precomputed ``maxlen``.
        """
        seq_len = token_embedding.size(0)
        maxlen = self.embedding_pos.size(0)
        if seq_len > maxlen:
            raise ValueError(
                f"sequence length {seq_len} exceeds PositionalEncoding maxlen {maxlen}"
            )
        return self.dropout(token_embedding + self.embedding_pos[:seq_len, :])

In [None]:
k = 300000  # number of lines of input/train.txt to use
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The corpus is tokenised with janome (Japanese), so read it as UTF-8
# explicitly instead of relying on the platform's default text encoding.
with open("input/train.txt", "r", encoding="utf-8") as f:
    dialog = f.read().split("\n")[0:k]

tokenizer = Tokenizer()

# Collect distinct surface forms in a set instead of concatenating every token
# of every sentence into one huge list. sorted() orders strings by code point,
# the same order np.unique used, so the resulting ids are unchanged.
vocab_set = set()
for sentence in tqdm(dialog):
    vocab_set.update(x.surface for x in tokenizer.tokenize(sentence))
vocabs = sorted(vocab_set)

# Special tokens take the last four ids.
vocabs += ["<PAD>", "<EOS>", "<UNK>", "<BEGIN>"]
word2id = {word: idx for idx, word in enumerate(vocabs)}
id2word = dict(enumerate(vocabs))
len(word2id)  # notebook display: vocabulary size

In [None]:
def sentence2id(sentence):
    """Tokenise ``sentence`` with the module-level ``tokenizer`` and map each
    surface form to its id in ``word2id``.

    Surface forms missing from the vocabulary map to the ``"<UNK>"`` id.
    Previously they mapped to ``"<EOS>"``, which injected end-of-sentence
    markers into the middle of the input.
    """
    unk_id = word2id["<UNK>"]
    return [word2id.get(token.surface, unk_id) for token in tokenizer.tokenize(sentence)]
            
def id2sentence(ids):
    """Translate an iterable of token ids back into their surface strings.

    Raises KeyError for any id missing from ``id2word``.
    """
    return list(map(id2word.__getitem__, ids))

In [None]:
max_len = 60  # every padded source and target tensor has exactly this many tokens
train_data = []

# Lines come in (utterance, reply) pairs: line i is the model input and line
# i + 1 is the response it should learn to produce. The range stops one short
# of the end so an unpaired trailing line (e.g. the empty string left by a
# final newline) does not raise IndexError.
for i in tqdm(range(0, len(dialog) - 1, 2)):
    x = sentence2id(dialog[i])
    y = sentence2id(dialog[i + 1])

    # Skip pairs that would not fit: the source needs room for <EOS>, the
    # target for <BEGIN> and <EOS>, and each keeps at least one <PAD>.
    if len(x) >= max_len - 1 or len(y) >= max_len - 2:
        continue

    # src: x <EOS> <PAD>...            -> max_len tokens
    train_X = torch.tensor(
        x + [word2id["<EOS>"]] + [word2id["<PAD>"]] * (max_len - len(x) - 1)
    )
    # tgt: <BEGIN> y <EOS> <PAD>...    -> max_len tokens
    train_Y = torch.tensor(
        [word2id["<BEGIN>"]] + y + [word2id["<EOS>"]] + [word2id["<PAD>"]] * (max_len - len(y) - 2)
    )
    train_data.append((train_X, train_Y))

random.shuffle(train_data)
train_data, test_data = train_test_split(train_data, train_size=0.9)

In [None]:
# --- training schedule ---
epoch = 1            # maximum number of passes over train_data
batch_size = 128

# --- model architecture ---
embedding_size = 128
nhead = 8
dim_feedforward = 128
num_encoder_layers = num_decoder_layers = 6
dropout = 0.1

# Source and target share one vocabulary. This is one larger than the number
# of ids word2id assigns; id len(word2id) is never produced by the lookup tables.
vocab_size = len(word2id) + 1

# Ids of the special tokens used for padding, decoder start and end of sequence.
PAD_IDX, START_IDX, EOS_IDX = (word2id[tok] for tok in ("<PAD>", "<BEGIN>", "<EOS>"))

In [None]:
def create_mask(src, tgt, PAD_IDX):
    """Build the attention and padding masks for one (src, tgt) batch.

    ``src`` and ``tgt`` are sequence-first ``(seq_len, batch)`` token-id tensors.
    All masks are boolean (True = position may NOT be attended to) and are
    created on the inputs' device.

    Returns:
        mask_src: ``(S, S)`` all-False mask; encoder tokens see every position.
        mask_tgt: ``(T, T)`` causal mask; True strictly above the diagonal hides
            future target positions.
        padding_mask_src: ``(batch, S)``; True where ``src == PAD_IDX``.
        padding_mask_tgt: ``(batch, T)``; True where ``tgt == PAD_IDX``.
    """
    seq_len_src = src.shape[0]
    seq_len_tgt = tgt.shape[0]

    mask_src = torch.zeros((seq_len_src, seq_len_src), device=src.device, dtype=torch.bool)
    # Boolean causal mask, equivalent to the -inf/0 float mask. It uses the same
    # dtype as the boolean padding masks, because PyTorch deprecates mixing a
    # bool key_padding_mask with a float attn_mask.
    mask_tgt = torch.triu(
        torch.ones((seq_len_tgt, seq_len_tgt), device=tgt.device, dtype=torch.bool),
        diagonal=1,
    )

    # Padding masks are batch-first, as nn.Transformer* key_padding_mask expects.
    padding_mask_src = (src == PAD_IDX).transpose(0, 1)
    padding_mask_tgt = (tgt == PAD_IDX).transpose(0, 1)

    return mask_src, mask_tgt, padding_mask_src, padding_mask_tgt


def generate_square_subsequent_mask(seq_len):
    """Return a ``(seq_len, seq_len)`` float32 causal mask on the global ``device``.

    Entry ``[i, j]`` is 0.0 when ``j <= i`` (position i may attend to j) and
    ``-inf`` when ``j > i``, which hides future positions from attention.
    """
    visible = torch.tril(torch.ones((seq_len, seq_len), device=device)) == 1
    blocked = ~visible
    return torch.zeros((seq_len, seq_len), device=device).masked_fill(blocked, float('-inf'))

def generate_batch(data_batch):
    """DataLoader collate function: stack (src, tgt) pairs into two tensors.

    Uses ``pad_sequence`` with its default ``batch_first=False``, so both
    results are ``(max_seq_len, batch)``; shorter sequences are padded with
    ``PAD_IDX``.
    """
    sources = [src for src, _ in data_batch]
    targets = [tgt for _, tgt in data_batch]
    return (
        pad_sequence(sources, padding_value=PAD_IDX),
        pad_sequence(targets, padding_value=PAD_IDX),
    )

In [None]:
# Training batches are reshuffled every epoch. The validation set is read in a
# fixed order: its loss is a mean of per-batch means (the last batch may be
# smaller), so shuffling would make it vary between epochs and add noise to
# the early-stopping comparison.
train_iter = DataLoader(train_data, batch_size=batch_size, shuffle=True, collate_fn=generate_batch)
test_iter = DataLoader(test_data, batch_size=batch_size, shuffle=False, collate_fn=generate_batch)

In [None]:
def train(model, data, optimizer, criterion, PAD_IDX):
    """Run one teacher-forced training pass over ``data`` and return the mean batch loss.

    The decoder is fed every target position except the last and is scored on
    predicting the target shifted left by one.

    Args:
        model: the Seq2SeqTransformer being trained (switched to train mode).
        data: sized iterable of ``(src, tgt)`` batches, each ``(seq_len, batch)``.
        optimizer: optimizer over ``model``'s parameters.
        criterion: loss taking ``(N, vocab)`` logits and ``(N,)`` target ids.
        PAD_IDX: padding-token id used to build the padding masks.

    Returns:
        The sum of per-batch losses divided by ``len(data)``.
    """
    model.train()
    running_loss = 0
    for batch_src, batch_tgt in tqdm(data):
        batch_src = batch_src.to(device)
        batch_tgt = batch_tgt.to(device)

        decoder_input = batch_tgt[:-1, :]   # all but the last position
        decoder_target = batch_tgt[1:, :]   # shifted left by one

        mask_src, mask_tgt, padding_mask_src, padding_mask_tgt = create_mask(
            batch_src, decoder_input, PAD_IDX
        )

        logits = model(
            src=batch_src, tgt=decoder_input,
            mask_src=mask_src, mask_tgt=mask_tgt,
            padding_mask_src=padding_mask_src, padding_mask_tgt=padding_mask_tgt,
            memory_key_padding_mask=padding_mask_src,
        )

        optimizer.zero_grad()
        flat_logits = logits.reshape(-1, logits.shape[-1])
        loss = criterion(flat_logits, decoder_target.reshape(-1))
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    return running_loss / len(data)

def evaluate(model, data, criterion, PAD_IDX):
    """Compute the mean per-batch teacher-forced loss of ``model`` over ``data``.

    The model is switched to eval mode and gradients are disabled, so no
    parameters change and no autograd graph is kept.

    Args:
        model: the Seq2SeqTransformer to score.
        data: sized iterable of ``(src, tgt)`` batches, each ``(seq_len, batch)``.
        criterion: loss taking ``(N, vocab)`` logits and ``(N,)`` target ids.
        PAD_IDX: padding-token id used to build the padding masks.

    Returns:
        The sum of per-batch losses divided by ``len(data)``.
    """
    model.eval()
    losses = 0
    # Validation needs no gradients. Without no_grad, every forward pass would
    # build and hold a computation graph for nothing.
    with torch.no_grad():
        for src, tgt in data:
            src = src.to(device)
            tgt = tgt.to(device)

            input_tgt = tgt[:-1, :]
            output_tgt = tgt[1:, :]

            mask_src, mask_tgt, padding_mask_src, padding_mask_tgt = create_mask(src, input_tgt, PAD_IDX)

            logits = model(
                src=src, tgt=input_tgt,
                mask_src=mask_src, mask_tgt=mask_tgt,
                padding_mask_src=padding_mask_src, padding_mask_tgt=padding_mask_tgt,
                memory_key_padding_mask=padding_mask_src
            )

            loss = criterion(logits.reshape(-1, logits.shape[-1]), output_tgt.reshape(-1))
            losses += loss.item()

    return losses / len(data)

In [None]:
model = Seq2SeqTransformer(
    num_encoder_layers=num_encoder_layers,
    num_decoder_layers=num_decoder_layers,
    embedding_size=embedding_size,
    vocab_size_src=vocab_size,
    vocab_size_tgt=vocab_size,
    dim_feedforward=dim_feedforward,
    dropout=dropout,
    nhead=nhead,
)

# Re-initialise every parameter with two or more dimensions (weight matrices,
# embedding tables) with Xavier-uniform; 1-D parameters such as biases and
# LayerNorm scales keep their default initialisation.
for weight in filter(lambda p: p.dim() > 1, model.parameters()):
    nn.init.xavier_uniform_(weight)

model = model.to(device)

# PAD positions in the target do not contribute to the loss.
criterion = torch.nn.CrossEntropyLoss(ignore_index=PAD_IDX)
optimizer = torch.optim.Adam(model.parameters())

# Early-stopping bookkeeping used by the training loop.
best_loss, best_model = float('Inf'), None
patience, counter = 10, 0

In [None]:
import copy

for loop in range(1, epoch + 1):
    start_time = time.time()
    loss_train = train(
        model=model, data=train_iter, optimizer=optimizer,
        criterion=criterion, PAD_IDX=PAD_IDX
    )
    elapsed_time = time.time() - start_time
    loss_valid = evaluate(
        model=model, data=test_iter, criterion=criterion, PAD_IDX=PAD_IDX
    )

    print('[{}/{}] train loss: {:.2f}, valid loss: {:.2f}  [{}{:.0f}s] counter: {} {}'.format(
        loop, epoch,
        loss_train, loss_valid,
        str(int(math.floor(elapsed_time / 60))) + 'm' if math.floor(elapsed_time / 60) > 0 else '',
        elapsed_time % 60,
        counter,
        '**' if best_loss > loss_valid else ''
    ))

    if best_loss > loss_valid:
        best_loss = loss_valid
        # Snapshot the weights. A plain `best_model = model` only aliases the
        # live model, so later epochs would overwrite the "best" checkpoint.
        best_model = copy.deepcopy(model)
        counter = 0

    if counter > patience:
        break

    counter += 1

# If validation never improved (e.g. the loss was NaN every epoch), fall back
# to the current weights instead of crashing on `None`.
if best_model is None:
    best_model = copy.deepcopy(model)

# best_model is a separate copy, so moving it to the CPU for saving leaves the
# live `model` on its training device for the inference cells below.
torch.save(best_model.to('cpu'), 'transformer.pth')
# Load the best weights back into `model` so inference uses the same weights
# that were saved. load_state_dict copies across devices.
model.load_state_dict(best_model.state_dict())

In [None]:
class TransformerResponder():
    """Greedy-decoding wrapper that turns a trained Seq2SeqTransformer into a responder.

    Args:
        model: module exposing ``encode``, ``decode`` and ``output`` (e.g.
            Seq2SeqTransformer); it is put into eval mode.
        start_idx: token id the decoder is primed with.
        eos_idx: token id that stops decoding once generated.
        max_len: cap on the decoded length, including the start token.
    """

    def __init__(self, model, start_idx, eos_idx, max_len):
        self.model = model
        self.model.eval()
        # Run inference on the device the model's weights actually live on.
        # Hard-coding "cuda if available" breaks when the model has been moved
        # to the CPU (for instance, for saving).
        self.device = self._model_device(model)
        self.start_idx = start_idx
        self.end_idx = eos_idx
        self.seq_len_tgt = max_len

    @staticmethod
    def _model_device(model):
        """Device of the model's first parameter; falls back to cuda/cpu if it has none."""
        try:
            return next(model.parameters()).device
        except StopIteration:
            return torch.device("cuda" if torch.cuda.is_available() else "cpu")

    @torch.no_grad()
    def greedy_decode(self, src, mask_src):
        """Greedily generate a response for a single source sequence.

        Args:
            src: ``(num_tokens, 1)`` token ids.
            mask_src: ``(num_tokens, num_tokens)`` encoder attention mask.

        Returns:
            A ``(generated_len, 1)`` long tensor that starts with ``start_idx``
            and ends with ``end_idx``, unless ``seq_len_tgt - 1`` tokens were
            generated first.
        """
        src = src.to(self.device)
        mask_src = mask_src.to(self.device)

        # Encode once. The memory is reused at every decoding step (the
        # original code encoded twice and discarded the first result).
        memory = self.model.encode(src, mask_src)
        ys = torch.full((1, 1), self.start_idx, dtype=torch.long, device=self.device)

        for _ in range(self.seq_len_tgt - 1):
            # Float -inf/0 mask converted to bool: -inf -> True (blocked).
            mask_tgt = self.generate_square_subsequent_mask(ys.size(0)).type(torch.bool)

            out = self.model.decode(ys, memory, mask_tgt)
            # Score only the newest position: out[-1] is (1, embedding_size).
            logits = self.model.output(out[-1])
            _, next_word = torch.max(logits, dim=1)
            next_word = next_word.item()

            next_token = torch.full((1, 1), next_word, dtype=torch.long, device=self.device)
            ys = torch.cat([ys, next_token], dim=0)
            if next_word == self.end_idx:
                break
        return ys

    def generate_square_subsequent_mask(self, seq_len):
        """Return a ``(seq_len, seq_len)`` float causal mask on ``self.device``:
        0.0 on and below the diagonal, ``-inf`` above it."""
        mask = (torch.triu(torch.ones((seq_len, seq_len), device=self.device)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask

    def predict(self, x):
        """Decode a reply to the id sequence ``x`` and return it as a list of words.

        The returned words come from the global ``id2word`` and include the
        start token (and the end token, if it was generated).
        """
        self.model.eval()
        num_tokens = len(x)
        src = torch.LongTensor(x).reshape(num_tokens, 1)
        # All-False mask: every source token may attend to every other.
        mask_src = (torch.zeros(num_tokens, num_tokens)).type(torch.bool)

        predicts = self.greedy_decode(src, mask_src).flatten()
        return [id2word[x.item()] for x in predicts]

    def predict_sentence(self, sentence):
        """Tokenise ``sentence`` with the global ``sentence2id``, append the end
        token and return the decoded reply words."""
        e = sentence2id(sentence) + [self.end_idx]
        return self.predict(e)

In [None]:
predictor = TransformerResponder(model, START_IDX, EOS_IDX, max_len)


def pseudo_model(input_logs):
    """Adapter for evaluate_dialogue_quality_of_model: reply to ``input_logs[0]``.

    Returns a one-element list holding the predicted word list produced by
    ``predictor.predict_sentence`` (it begins with the start token).
    """
    first_utterance = input_logs[0]
    reply_words = predictor.predict_sentence(first_utterance)
    return [reply_words]


# Run test.
score = evaluate_dialogue_quality_of_model("./input/test.txt", pseudo_model)
print(f"Score: {score}")