In [1]:
import torch
import torch.nn as nn
import torch.optim as optim

import torchtext
from torchtext.data import Field, BucketIterator

import spacy
import numpy as np

import random
import math
import time
from torchtext import data
from transformers import T5Tokenizer, T5Model
import wikipedia

In [None]:
def filter_para(x):
    """Decide whether a line of article text should be kept as a paragraph.

    A line is rejected when it is shorter than 20 characters or contains
    ``'=='`` (presumably Wikipedia section headings such as ``== History ==``).

    Returns:
        True to keep the line, False to drop it.
    """
    long_enough = len(x) >= 20
    return long_enough and '==' not in x

In [None]:
def wiki_results(query):
    """Fetch the top Wikipedia hit for ``query`` and return its body paragraphs.

    The article text is split on newlines and filtered with ``filter_para``
    (drops short lines and lines containing ``'=='``).

    Args:
        query: free-text search string passed to ``wikipedia.search``.

    Returns:
        List of paragraph strings from the first search result, or an empty
        list when the search returns no results.
    """
    search_results = wikipedia.search(query, results=4)
    if not search_results:
        # Previously this raised an opaque IndexError on search_results[0].
        return []

    # The title comes straight from the search results, so disable
    # auto_suggest: otherwise the library may re-search and resolve the
    # title to a different "suggested" page.
    wiki_obj = wikipedia.page(search_results[0], auto_suggest=False)
    paras = wiki_obj.content.split('\n')
    return [para for para in paras if filter_para(para)]

In [2]:
# SentencePiece tokenizer matching the pretrained t5-small checkpoint.
_checkpoint = 't5-small'
tokenizer = T5Tokenizer.from_pretrained(_checkpoint)

In [3]:
# T5 has no dedicated start-of-sequence token; its decoder is started from
# the pad token, so the pad token doubles as the "init" token here.
init_token, eos_token, pad_token, unk_token = (
    tokenizer.pad_token,
    tokenizer.eos_token,
    tokenizer.pad_token,
    tokenizer.unk_token,
)

print(init_token, eos_token, pad_token, unk_token)

<pad> </s> <pad> <unk>


In [4]:
# Look up the integer vocabulary id of each special token in one call
# (convert_tokens_to_ids returns a list when given a list).
init_token_idx, eos_token_idx, pad_token_idx, unk_token_idx = tokenizer.convert_tokens_to_ids(
    [init_token, eos_token, pad_token, unk_token]
)

print(init_token_idx, eos_token_idx, pad_token_idx, unk_token_idx)

0 1 0 2


In [5]:
# Per-instance length limit set by from_pretrained (512 for t5-small).
# It is preferred over the class-level max_model_input_sizes['t5-small']
# mapping, which is keyed by checkpoint shortcut name and is not populated
# in newer transformers releases.
# NOTE(review): confirm against the pinned transformers version.
max_input_length = tokenizer.model_max_length

print(max_input_length)

512


In [6]:
def tokenize_and_cut(sentence):
    """Tokenize ``sentence`` with the module-level ``tokenizer`` and truncate it.

    At most ``max_input_length - 2`` tokens are kept. The start and end
    tokens added later (by the torchtext fields / ``translate_sentence``)
    then still fit within ``max_input_length``.
    """
    budget = max_input_length - 2
    return tokenizer.tokenize(sentence)[:budget]

In [7]:
def _make_t5_field():
    """Build a torchtext ``Field`` that yields pre-numericalized T5 token ids.

    ``use_vocab=False`` because ``preprocessing`` already turns the tokens from
    ``tokenize_and_cut`` into ids. The special tokens are therefore passed as
    ids rather than strings.
    """
    return data.Field(
        batch_first=True,
        use_vocab=False,
        tokenize=tokenize_and_cut,
        preprocessing=tokenizer.convert_tokens_to_ids,
        init_token=init_token_idx,
        eos_token=eos_token_idx,
        pad_token=pad_token_idx,
        unk_token=unk_token_idx,
    )


# Source and target are configured identically but must be separate objects.
SRC = _make_t5_field()
TRG = _make_t5_field()

In [8]:
class T5Network(nn.Module):
    """Pretrained t5-small encoder-decoder followed by a linear vocabulary head.

    The head maps each decoder hidden state (width ``d_model``) to one score
    per vocabulary entry. It is created here with default initialisation.
    """

    def __init__(self):
        super().__init__()

        backbone = T5Model.from_pretrained('t5-small')
        self.t5 = backbone

        cfg = backbone.config.to_dict()
        self.out = nn.Linear(cfg['d_model'], cfg['vocab_size'])

    def forward(self, src, trg):
        """Run encoder on ``src`` and decoder on ``trg``, then project to vocab scores.

        Args:
            src: encoder input ids.
            trg: decoder input ids.

        Returns:
            ``self.out`` applied to the first element of the T5 output (the
            decoder's final hidden states).
        """
        decoder_states = self.t5(input_ids=src, decoder_input_ids=trg)[0]
        return self.out(decoder_states)

In [9]:
# Fall back to CPU when no GPU is available. map_location lets a checkpoint
# saved from GPU tensors load onto whichever device was chosen.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = T5Network().to(device)
model.load_state_dict(torch.load('model_3.pt', map_location=device))

Some weights of T5Model were not initialized from the model checkpoint at t5-small and are newly initialized: ['encoder.embed_tokens.weight', 'decoder.embed_tokens.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


<All keys matched successfully>

In [10]:
def count_parameters(model):
    """Return the total number of elements in ``model``'s trainable parameters.

    Only parameters with ``requires_grad=True`` are counted.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

n_trainable = count_parameters(model)
print(f'The model has {n_trainable:,} trainable parameters')

The model has 76,988,544 trainable parameters


In [11]:
def translate_sentence(sentence, src_field, trg_field, model, max_len = 50,
                       init_idx=None, eos_idx=None):
    """Greedily decode an output id sequence for one tokenized input.

    Args:
        sentence: iterable of source token ids *without* start/end markers.
            They are added here.
        src_field, trg_field: unused; kept so existing call sites still work.
        model: module called as ``model(src, trg)``. Its output is read as
            ``output.argmax(2)[:, -1]``, so the vocabulary axis is dim 2 and
            the newest decoder position is last on dim 1 (assumed
            ``(batch, trg_len, vocab)``).
        max_len: maximum number of tokens generated after the start token.
        init_idx: start-token id; defaults to the global ``init_token_idx``.
        eos_idx: end-token id; defaults to the global ``eos_token_idx``.

    Returns:
        List of ids beginning with ``init_idx``. It ends with ``eos_idx`` if
        that token was produced within ``max_len`` steps.

    Side effect: puts ``model`` into eval mode.
    """
    if init_idx is None:
        init_idx = init_token_idx
    if eos_idx is None:
        eos_idx = eos_token_idx

    model.eval()

    # Build tensors on the device the model lives on instead of forcing CUDA.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    src_indexes = [init_idx] + list(sentence) + [eos_idx]
    src_tensor = torch.tensor(src_indexes, dtype=torch.long, device=device).unsqueeze(0)

    trg_indexes = [init_idx]

    with torch.no_grad():
        for _ in range(max_len):
            trg_tensor = torch.tensor(trg_indexes, dtype=torch.long, device=device).unsqueeze(0)
            output = model(src_tensor, trg_tensor)

            # Greedy choice: highest-scoring token at the newest decoder position.
            pred_token = output.argmax(2)[:, -1].item()
            trg_indexes.append(pred_token)

            if pred_token == eos_idx:
                break

    return trg_indexes

In [12]:
CONTEXT = "harry is playing with his dog. the dog is twenty years old."
QUERY = "how old is the dog ?"

# Prompt layout: "context : <context> question : <question>", lower-cased.
# NOTE(review): presumably the format model_3.pt was fine-tuned on; confirm.
text = f"context : {CONTEXT.lower()} question : {QUERY.lower()}"
tokens = tokenizer.tokenize(text)

In [13]:
# Show the prompt, its sub-word tokens and their ids, separated by blank lines
# (sep='\n' * 3 reproduces the original print / print('\n') sequence exactly).
print(text, tokens, tokenizer.convert_tokens_to_ids(tokens), sep='\n' * 3)

context : harry is playing with his dog. the dog is twenty years old. question : how old is the dog ?


['▁context', '▁', ':', '▁', 'har', 'ry', '▁is', '▁playing', '▁with', '▁his', '▁dog', '.', '▁the', '▁dog', '▁is', '▁twenty', '▁years', '▁old', '.', '▁question', '▁', ':', '▁how', '▁old', '▁is', '▁the', '▁dog', '▁', '?']


[2625, 3, 10, 3, 3272, 651, 19, 1556, 28, 112, 1782, 5, 8, 1782, 19, 6786, 203, 625, 5, 822, 3, 10, 149, 625, 19, 8, 1782, 3, 58]


In [14]:
src_ids = tokenizer.convert_tokens_to_ids(tokens)
pred_tokens = translate_sentence(src_ids, SRC, TRG, model)

In [15]:
# Raw predicted ids, then the same ids mapped back to sub-word token strings.
pred_token_strings = tokenizer.convert_ids_to_tokens(pred_tokens)
print(pred_tokens)
print(' '.join(pred_token_strings))

[0, 6786, 203, 625, 1]
<pad> ▁twenty ▁years ▁old </s>
