In [1]:
import torch
from transformer import Transformer, AdamWithWarpUp, LossCE
from torch.utils.data import Dataset
import torch.utils.data
from transformers import GPT2Tokenizer

In [None]:
import ast
import json
from pathlib import Path

In [None]:
# Corpus inputs and the cache of encoded (question, answer) id pairs,
# all relative to the notebook's working directory.
convs_path, lines_path, encoded_qa_path = (
    './conversations.txt',
    './lines.txt',
    './pairs_encoded.json',
)

In [None]:
seq_len = 30
encoder_layers = 6
decoder_layers = 6

embed_dim = 512
num_heads = 8

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("device:" + str(device))
epochs = 2

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
pad_token = 0

vocab_size = len(tokenizer)  # FROM GPT2
print(f"vocab_tokenizer size:{vocab_size}")
smoothing_rate = 0.1

In [None]:
# Load the raw conversation index, one conversation record per line.
# Raise instead of calling exit(): inside a Jupyter kernel exit() does not
# reliably stop the cell, leaving `convs` undefined for later cells.
# NOTE(review): the platform default encoding is used; if this file comes from
# the Cornell Movie-Dialogs corpus it may need encoding='iso-8859-1' -- confirm.
if not Path(convs_path).exists():
    raise FileNotFoundError(f"Not found conversation file: {convs_path}")
with open(convs_path, 'r') as conv_file:
    # readlines() keeps trailing newlines; later cells split/strip as needed.
    convs = conv_file.readlines()


In [None]:
# Load the raw utterance file, one "<id> +++$+++ ... +++$+++ <text>" record per line.
# Raise instead of calling exit(): inside a Jupyter kernel exit() does not
# reliably stop the cell, leaving `lines` undefined for later cells.
# NOTE(review): the platform default encoding is used; if this file comes from
# the Cornell Movie-Dialogs corpus it may need encoding='iso-8859-1' -- confirm.
if not Path(lines_path).exists():
    raise FileNotFoundError(f"Not found lines file: {lines_path}")
with open(lines_path, 'r') as lines_file:
    lines = lines_file.readlines()

In [None]:
# Map each utterance id (first " +++$+++ " field) to its text (last field).
# The text still carries the trailing newline from readlines(); it is stripped
# when the question/answer pairs are built. Later duplicates of an id win.
lines_id_data_dict = {
    fields[0]: fields[-1]
    for fields in (raw_line.split(" +++$+++ ") for raw_line in lines)
}

In [None]:
# Characters stripped by remove_punctuations. A raw string keeps the backslash
# as a literal character instead of relying on the invalid "\," escape
# (which triggers a DeprecationWarning / SyntaxWarning on newer Pythons).
_DEFAULT_PUNCTUATIONS = r'''<>./?@#$%^&*_~!()-[]{};:'"\,'''


def remove_punctuations(p_str, punctuations=_DEFAULT_PUNCTUATIONS):
    """Delete punctuation characters from ``p_str`` and lower-case the result.

    Whitespace is deliberately kept so word boundaries survive for the tokenizer.

    Args:
        p_str: text to clean.
        punctuations: characters to delete; defaults to the set used for the corpus.

    Returns:
        The cleaned, lower-cased string (lower-casing avoids case-only differences).
    """
    # One-pass deletion via a translation table instead of repeated string
    # concatenation, which is quadratic in the length of the input.
    cleaned = p_str.translate(str.maketrans('', '', punctuations))
    return cleaned.lower()

In [None]:
# Build consecutive (question, answer) pairs: inside each conversation every
# utterance is treated as the "question" for the utterance that follows it.
question_answer_pairs = []
for conv in convs:
    # The last field is presumably a Python list literal of utterance ids
    # (e.g. "['L194', 'L195']"). ast.literal_eval only accepts literals, whereas
    # eval would execute arbitrary code found in the data file.
    line_ids_str = conv.split(" +++$+++ ")[-1]
    line_ids = ast.literal_eval(line_ids_str.strip())
    for curr_id, next_id in zip(line_ids, line_ids[1:]):
        # KeyError here means the conversation references an id missing from lines.txt.
        question = remove_punctuations(lines_id_data_dict[curr_id].strip())
        answer = remove_punctuations(lines_id_data_dict[next_id].strip())
        question_answer_pairs.append([question, answer])

In [None]:
def encode_question_tokenizer(question_words, seq_len, tokenizer, pad_token):
    """Encode a question as ``seq_len`` token ids: truncate, then right-pad.

    Args:
        question_words: cleaned question text.
        seq_len: target length; passed to the tokenizer as ``max_length`` with
            truncation enabled, and any shortfall is filled with ``pad_token``.
        tokenizer: object exposing a Hugging Face style ``encode`` method.
        pad_token: id used as padding.

    Returns:
        List of ids of length ``seq_len`` (assuming the tokenizer honours ``max_length``).
    """
    token_ids = tokenizer.encode(
        question_words,
        add_special_tokens=True,
        truncation=True,
        max_length=seq_len,
    )
    missing = seq_len - len(token_ids)
    return token_ids + missing * [pad_token]


def encode_answer_tokenizer(answer_words, seq_len, tokenizer, pad_token):
    """Encode an answer as ``[bos] + ids + [eos] + padding``.

    The text is truncated to ``seq_len`` tokens before bos/eos are added, so
    the result is ``seq_len + 2`` ids long (one more than a question after the
    decoder input/target shift in the training loop).
    NOTE(review): for GPT-2, bos and eos are believed to be the same id
    (<|endoftext|>) -- confirm that is intended.

    Args:
        answer_words: cleaned answer text.
        seq_len: maximum number of text tokens kept.
        tokenizer: object exposing ``encode``, ``bos_token_id`` and ``eos_token_id``.
        pad_token: id used as padding.

    Returns:
        List of token ids.
    """
    body_ids = tokenizer.encode(answer_words, add_special_tokens=True, truncation=True, max_length=seq_len)
    head = [tokenizer.bos_token_id]
    tail = [tokenizer.eos_token_id] + [pad_token] * (seq_len - len(body_ids))
    return head + body_ids + tail

In [None]:
# Turn every text pair into fixed-length id lists and cache them as JSON so the
# Dataset below can reload them without re-tokenizing.
encoded_qa = [
    [
        encode_question_tokenizer(question_text, seq_len, tokenizer, pad_token),
        encode_answer_tokenizer(answer_text, seq_len, tokenizer, pad_token),
    ]
    for question_text, answer_text in question_answer_pairs
]

with open(encoded_qa_path, 'w') as qa_file:
    json.dump(encoded_qa, qa_file)

In [None]:
# Warm-up length handed to the optimizer wrapper.
warmup_steps = 4000

transformer = Transformer(embed_dim, num_heads, encoder_layers, decoder_layers, vocab_size, seq_len)
transformer = transformer.to(device)

# lr=0 here: the AdamWithWarpUp wrapper presumably sets the learning rate on
# every step from embed_dim and warmup_steps -- TODO confirm in transformer.py.
adam_opt = torch.optim.Adam(
    transformer.parameters(),
    lr=0,
    betas=(0.9, 0.98),
    eps=1e-9,
)
adam_warmup_opt = AdamWithWarpUp(adam_opt, embed_dim, warmup_steps)

# Cross-entropy loss with label smoothing factor `smoothing_rate`.
ce_loss = LossCE(smoothing_rate)

In [None]:
def create_mask(input_seq, mask_flag='q', pad_token=0):
    """Build the attention / loss mask for a batch of token-id sequences.

    Args:
        input_seq: integer tensor of token ids whose last dimension is the
            sequence length; presumably ``[batch_size, seq_len]`` (the shapes
            below assume that).
        mask_flag: which mask to build:
            ``'q'``: encoder padding mask, ``[batch_size, 1, 1, seq_len]``. No
                future masking is needed because it goes through the encoder.
            ``'a'``: decoder padding + causal (lower-triangular) mask,
                ``[batch_size, 1, seq_len, seq_len]``.
            ``'t'``: target padding mask for the loss, ``[batch_size, seq_len]``.
        pad_token: id treated as padding (default 0, the notebook's pad id).

    Returns:
        A bool tensor on ``input_seq``'s device; True marks real (non-pad)
        positions that may be attended to or counted in the loss.

    Raises:
        ValueError: if ``mask_flag`` is not one of ``'q'``, ``'a'``, ``'t'``.
    """
    curr_seq_len = input_seq.size(-1)
    input_mask = input_seq != pad_token

    if mask_flag == 'q':
        return input_mask.unsqueeze(1).unsqueeze(1)

    if mask_flag == 'a':
        # Position i may only attend to positions <= i. The mask is built
        # directly on the input's device, so no host->device copy is needed.
        causal = torch.tril(
            torch.ones(curr_seq_len, curr_seq_len, dtype=torch.bool, device=input_seq.device)
        )
        # [B, 1, L] & [1, L, L] -> [B, L, L], then add the head dimension.
        return (input_mask.unsqueeze(1) & causal.unsqueeze(0)).unsqueeze(1)

    if mask_flag == 't':
        return input_mask

    raise ValueError(f"unknown mask_flag {mask_flag!r}; expected 'q', 'a' or 't'")

In [None]:
class QuestionAnswerDataset(Dataset):
    """Map-style dataset over the pre-encoded question/answer id pairs.

    The JSON file holds a list of ``[question_ids, answer_ids]`` pairs, each a
    list of ints, as written by the encoding cell.
    """

    def __init__(self, qa_path=None):
        """Load every encoded pair into memory.

        Args:
            qa_path: path to the encoded JSON file; when omitted, the
                notebook-level ``encoded_qa_path`` is used.

        Raises:
            FileNotFoundError: if the file does not exist. This is raised rather
                than calling exit(), which inside a notebook kernel does not
                reliably stop execution and left the object without ``qa_pairs``.
        """
        path = Path(encoded_qa_path if qa_path is None else qa_path)
        if not path.exists():
            raise FileNotFoundError(f"Not found encoded question-answer pairs file: {path}")
        with open(path, 'r') as qa_r:
            self.qa_pairs = json.load(qa_r)

    def __len__(self):
        """Number of encoded pairs."""
        return len(self.qa_pairs)

    def __getitem__(self, idx):
        """Return ``(question, answer)`` as int64 tensors for pair ``idx``."""
        pair = self.qa_pairs[idx]
        question_t = torch.tensor(pair[0], dtype=torch.long)
        answer_t = torch.tensor(pair[1], dtype=torch.long)
        return question_t, answer_t

In [None]:
# Shuffled mini-batches of encoded (question, answer) pairs.
# Pinned host memory only speeds up host->GPU copies, so enable it only when
# training on CUDA; on CPU-only machines it is useless and recent PyTorch
# versions emit a warning about it.
batch_size = 1000
train_loader = torch.utils.data.DataLoader(
    QuestionAnswerDataset(),
    batch_size=batch_size,
    shuffle=True,
    pin_memory=device.type == 'cuda',
)

In [None]:
for epoch in range(epochs):

    transformer.train()

    # Running sums for a sample-weighted average loss over the epoch so far.
    batches_total_loss = 0
    batches_total_size = 0

    for i, (question, answer) in enumerate(train_loader):

        adam_warmup_opt.optimizer.zero_grad()

        # Size of *this* batch (the last one may be smaller). A separate name
        # keeps the notebook-level `batch_size` config from being overwritten.
        curr_batch_size = question.shape[0]

        question = question.to(device)
        answer = answer.to(device)

        # Teacher forcing: the decoder reads answer[:, :-1] and is trained to
        # predict the next token, answer[:, 1:].
        answer_input = answer[:, :-1]
        answer_target = answer[:, 1:]

        question_mask = create_mask(question, mask_flag='q').to(device)
        answer_input_mask = create_mask(answer_input, mask_flag='a').to(device)
        answer_target_mask = create_mask(answer_target, mask_flag='t').to(device)

        preds = transformer(question, question_mask, answer_input, answer_input_mask)

        # The target mask is passed so padded positions can be excluded from the loss
        # (presumably what LossCE does with it -- confirm in transformer.py).
        loss = ce_loss(preds, answer_target, answer_target_mask)

        loss.backward()
        # The wrapper presumably updates the warm-up learning rate, then steps Adam.
        adam_warmup_opt.step()

        batches_total_loss += loss.item() * curr_batch_size
        batches_total_size += curr_batch_size

        if i % 100 == 0:
            print(
                f"Epoch: [{epoch}] Batch:[{i}/{len(train_loader)}]\tLoss: {batches_total_loss / batches_total_size:.3f}")

    # Checkpoint after every epoch. Whole objects are pickled (not state_dicts),
    # so loading requires the `transformer` module to be importable.
    state = {'epoch': epoch, 'model': transformer, 'optimizer': adam_warmup_opt}
    torch.save(state, f'model_epoch_{epoch}.pth')
    print("saved model on epoch: " + str(epoch))