In [1]:
import torch
import torch.nn as nn

from transformers import T5Tokenizer
from torch import jit

In [2]:
tokenizer = T5Tokenizer.from_pretrained('t5-small')

In [3]:
tokenizer.vocab_size

32100

In [4]:
tokens = tokenizer.tokenize('Hello world how are you?')

print(tokens)

['▁Hello', '▁world', '▁how', '▁are', '▁you', '?']


In [5]:
indexes = tokenizer.convert_tokens_to_ids(tokens)

print(indexes)

[8774, 296, 149, 33, 25, 58]


In [6]:
init_token = tokenizer.pad_token
eos_token = tokenizer.eos_token
pad_token = tokenizer.pad_token
unk_token = tokenizer.unk_token

print(init_token, eos_token, pad_token, unk_token)

<pad> </s> <pad> <unk>


In [7]:
init_token_idx = tokenizer.convert_tokens_to_ids(init_token)
eos_token_idx = tokenizer.convert_tokens_to_ids(eos_token)
pad_token_idx = tokenizer.convert_tokens_to_ids(pad_token)
unk_token_idx = tokenizer.convert_tokens_to_ids(unk_token)

print(init_token_idx, eos_token_idx, pad_token_idx, unk_token_idx)

0 1 0 2


In [8]:
max_input_length = tokenizer.max_model_input_sizes['t5-small']

print(max_input_length)

512


In [9]:
new_model = jit.load('t5_ts_qa_model.zip')

In [10]:
def translate_sentence2(sentence, eval_model, max_len = 50):
    
    eval_model.eval()
    eval_model = eval_model.float()

    src_indexes = [init_token_idx] + sentence + [eos_token_idx]

    src_tensor = torch.LongTensor(src_indexes).unsqueeze(0)

    trg_indexes = [init_token_idx]

    for i in range(max_len):

        trg_tensor = torch.LongTensor(trg_indexes).unsqueeze(0)
        
        with torch.no_grad():
            
            output = eval_model(src_tensor, trg_tensor)
        
        pred_token = output.argmax(2)[:,-1].item()
        
        trg_indexes.append(pred_token)

        if pred_token == eos_token_idx:
            break

    return trg_indexes[1:-1]

In [11]:
def return_answer(context,query):
    txt = 'context : ' + context.lower() + ' question : ' + query.lower()
    txt_tokens = tokenizer.tokenize(txt)
    txt_ids = tokenizer.convert_tokens_to_ids(txt_tokens)
    pred = translate_sentence2(txt_ids, new_model)
    pred_tokens = tokenizer.convert_ids_to_tokens(pred)
    
    return ''.join(pred_tokens)

In [24]:
CONTEXT = "The COVID-19 pandemic, also known as the coronavirus pandemic, is an ongoing pandemic of coronavirus disease 2019 (COVID-19) caused by severe acute respiratory syndrome coronavirus 2 (SARS-CoV-2), first identified in December 2019 in Wuhan, China. The outbreak was declared a Public Health Emergency of International Concern in January 2020 and a pandemic in March 2020. As of 9 November 2020, more than 50.4 million cases have been confirmed, with more than 1.25 million deaths attributed to COVID-19, and more than 32.8 million recovered."
QUERY = "when was covid-19 declared a pandemic ?"

In [25]:
return_answer(CONTEXT,QUERY)

'▁march▁2020'