In [1]:
import torch
from transformer import Transformer
from torch.utils.data import Dataset
import torch.utils.data
from transformers import GPT2Tokenizer

In [2]:
import json
from pathlib import Path

In [4]:
# --- Sequence / architecture settings -------------------------------------
# NOTE(review): presumably these have to match the values the checkpoint was
# trained with -- confirm against the training notebook.
seq_len = 30
encoder_layers, decoder_layers = 6, 6
embed_dim, num_heads = 512, 8

# --- Runtime device: use CUDA when torch reports it as available ----------
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"device:{device}")

# --- Tokenizer --------------------------------------------------------------
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
pad_token = 0  # id used to right-pad encoded sequences

vocab_size = len(tokenizer)  # taken from the GPT-2 tokenizer
print(f"vocab_tokenizer size:{vocab_size}")


device:cuda
vocab_tokenizer size:50257


In [5]:
def encode_question_tokenizer(question_words,seq_len, tokenizer, pad_token):
    """Encode a question into a fixed-length list of token ids.

    The text is passed to ``tokenizer.encode`` with truncation enabled and
    ``max_length=seq_len``; the resulting ids are then right-padded with
    ``pad_token`` until the list holds ``seq_len`` entries.

    Args:
        question_words: Question text to encode.
        seq_len: Target length of the returned list.
        tokenizer: Object exposing ``encode(text, add_special_tokens=...,
            truncation=..., max_length=...)`` that returns a list of ids.
        pad_token: Id appended as padding.

    Returns:
        The encoded ids followed by padding ids.
    """
    token_ids = tokenizer.encode(
        question_words,
        add_special_tokens=True,
        truncation=True,
        max_length=seq_len,
    )
    # A non-positive count simply yields no padding.
    missing = seq_len - len(token_ids)
    return token_ids + [pad_token] * missing

def encode_answer_tokenizer(answer_words,seq_len, tokenizer, pad_token):
    """Encode an answer as ``[bos] + ids + [eos] + padding``.

    The text is encoded with truncation enabled and ``max_length=seq_len``.
    The padding count is computed from the length of the encoded text alone,
    i.e. before the bos/eos ids are added, so whenever the encoded text has
    at most ``seq_len`` ids the returned list has ``seq_len + 2`` entries.

    Args:
        answer_words: Answer text to encode.
        seq_len: Maximum number of ids kept from the encoded text.
        tokenizer: Object exposing ``encode(...)`` plus ``bos_token_id`` and
            ``eos_token_id`` attributes.
        pad_token: Id appended as padding.

    Returns:
        List of ids: bos, the encoded text, eos, then padding.
    """
    body_ids = tokenizer.encode(answer_words, add_special_tokens=True, truncation=True, max_length=seq_len)
    pad_count = seq_len - len(body_ids)
    padding = [pad_token] * pad_count
    return [tokenizer.bos_token_id, *body_ids, tokenizer.eos_token_id, *padding]

In [6]:
def create_mask(input_seq, mask_flag='q', pad_token=0):
    """Build a boolean attention/loss mask for a batch of token ids.

    Positions whose id equals ``pad_token`` are masked out (``False``).

    Args:
        input_seq: Integer tensor of token ids; the last dimension is the
            sequence axis. The shape comments below assume a
            ``[batch_size, seq_len]`` input.
        mask_flag: Selects the kind of mask:
            ``'q'`` -- padding mask only (encoder side, no future masking),
            shape ``[batch_size, 1, 1, seq_len]``;
            ``'a'`` -- padding mask combined with a lower-triangular mask so
            that position ``i`` cannot see positions ``> i`` (decoder side),
            shape ``[batch_size, 1, seq_len, seq_len]``;
            any other value (``'t'`` by convention) -- plain padding mask,
            shape ``[batch_size, seq_len]``, used for the loss.
        pad_token: Id treated as padding. Defaults to ``0``, which was the
            previously hard-coded value.

    Returns:
        A ``torch.bool`` tensor on the same device as ``input_seq``.
    """
    not_pad = input_seq != pad_token

    if mask_flag == 'q':
        # [batch_size, seq_len] -> [batch_size, 1, 1, seq_len]
        return not_pad.unsqueeze(1).unsqueeze(1)

    if mask_flag == 'a':
        steps = input_seq.size(-1)
        # Lower-triangular (incl. diagonal) mask, created directly on the
        # input's device; tril is applied on the float tensor and cast after.
        causal = torch.tril(torch.ones(steps, steps, device=input_seq.device)).bool()
        # [batch_size, 1, seq_len] & [1, seq_len, seq_len]
        #   -> [batch_size, seq_len, seq_len] -> [batch_size, 1, seq_len, seq_len]
        return (not_pad.unsqueeze(1) & causal.unsqueeze(0)).unsqueeze(1)

    # Fallback ('t'): padding mask only, [batch_size, seq_len]
    return not_pad

In [7]:
# Which saved epoch to restore; the checkpoint file name is derived from it.
epoch_model_to_load = 1
checkpoint_path = f"model_epoch_{epoch_model_to_load}.pth"

# NOTE(review): torch.load unpickles the file, so only load checkpoints from
# a trusted source. map_location remaps stored tensors onto `device`.
load_model = torch.load(checkpoint_path, map_location=device)
model = load_model['model']  # the checkpoint is indexed by the 'model' key
print(f"load model of epoch: {epoch_model_to_load}")

load model of epoch: 1


In [8]:
def testing(model, q, seq_len ,req_len, tokenizer, pad_token ,device):
    """Greedily generate an answer string for question ``q``.

    The question is encoded and padded to ``seq_len``. The target sequence
    starts as a single bos token and is extended one token at a time with the
    highest-scoring prediction, until either eos is produced or the loop has
    run ``req_len - 1`` times.

    Args:
        model: Callable invoked as
            ``model(encoded_question, question_mask, target_seq, target_mask)``;
            must also provide ``eval()``.
        q: Question text.
        seq_len: Length the encoded question is truncated/padded to.
        req_len: Upper bound on the target length including the bos token;
            at most ``req_len - 1`` tokens are generated.
        tokenizer: Tokenizer providing ``encode``, ``decode``,
            ``bos_token_id`` and ``eos_token_id``.
        pad_token: Id used to pad the encoded question.
        device: Torch device the input tensors are created on.

    Returns:
        The decoded answer, lower-cased with its first character capitalised;
        an empty string when no token (other than eos) was generated.
    """

    def answer_creation(target_seq,tokenizer):
        """Convert the generated id tensor into the displayed answer text."""
        if target_seq.dim() > 1: # squeeze batch dim
            target_seq = target_seq.squeeze(0)

        token_ids = target_seq.tolist()[1:] # drop the leading bos token
        # Guard against an empty list (req_len <= 1) before peeking at the end.
        if token_ids and token_ids[-1] == tokenizer.eos_token_id:
            token_ids = token_ids[:-1] # drop the trailing eos token
        if not token_ids:
            return ''

        # Decode the whole sequence at once: decoding id by id and joining
        # with spaces would split words made of several sub-word tokens and
        # detach punctuation from the preceding word.
        text = tokenizer.decode(token_ids).strip()
        # str.capitalize upper-cases the first character and lower-cases the rest.
        return text.capitalize()

    model.eval()

    question_ids = encode_question_tokenizer(q,seq_len, tokenizer, pad_token)
    encoded_question = torch.tensor([question_ids], dtype=torch.long, device=device)
    question_mask = create_mask(encoded_question, mask_flag='q').to(device)

    # Decoding starts from the bos token.
    target_seq = torch.tensor([[tokenizer.bos_token_id]], dtype=torch.long, device=device)

    # Inference only: skip autograd bookkeeping while generating.
    with torch.no_grad():
        for _step in range(1, req_len):
            target_mask = create_mask(target_seq, mask_flag='a').to(device)
            predictions = model(encoded_question, question_mask, target_seq, target_mask)
            # NOTE(review): assumes `predictions` is 2-D, (target positions, vocab),
            # so the arg-max over dim 1 gives one id per position -- confirm
            # against the Transformer implementation.
            _, curr_words = torch.max(predictions, dim=1)
            next_w_idx = curr_words[-1].item() # prediction for the last position
            next_token = torch.tensor([[next_w_idx]], dtype=torch.long, device=device)
            target_seq = torch.cat([target_seq, next_token], dim=1)
            if next_w_idx == tokenizer.eos_token_id:
                break

    return answer_creation(target_seq,tokenizer)

In [9]:
# Interactive loop: keep answering questions until the user enters 'q'.
while True:
    question = input("Welcome to Q-A system!\nTap q to finish.\nEnter your question: ")
    if question == 'q':
        break
    # Draw a random cap in [5, seq_len) for the generated sequence length.
    req_len = torch.randint(5,seq_len,(1,))
    answer = testing(
        model,
        question,
        seq_len,
        req_len.item(),
        tokenizer,
        pad_token,
        device,
    )
    print(answer)

Welcome to Q-A system!
Tap q to finish.
Enter your question:  what is your name?


I dont know


Welcome to Q-A system!
Tap q to finish.
Enter your question:  how good is your day


I dont know


Welcome to Q-A system!
Tap q to finish.
Enter your question:  are you a guy


I dont know


Welcome to Q-A system!
Tap q to finish.
Enter your question:  q
