In [1]:
import torch
import torch.nn as nn

import torchtext
from torchtext.data import Field

from torchtext import data
from transformers import T5Tokenizer, T5Model

In [2]:
# Tokenizer for the t5-small checkpoint (the same checkpoint T5Network loads).
tokenizer_checkpoint = 't5-small'
tokenizer = T5Tokenizer.from_pretrained(tokenizer_checkpoint)

In [3]:
# T5 has no dedicated beginning-of-sequence token. By T5 convention <pad>
# doubles as the decoder start token, so it is reused as the init token here.
pad_token = tokenizer.pad_token
init_token = pad_token
eos_token = tokenizer.eos_token
unk_token = tokenizer.unk_token

print(init_token, eos_token, pad_token, unk_token)

<pad> </s> <pad> <unk>


In [4]:
# Look up the vocabulary id of every special token string in one pass.
init_token_idx, eos_token_idx, pad_token_idx, unk_token_idx = map(
    tokenizer.convert_tokens_to_ids,
    (init_token, eos_token, pad_token, unk_token),
)

print(init_token_idx, eos_token_idx, pad_token_idx, unk_token_idx)

0 1 0 2


In [5]:
# Largest input length (in tokens) the tokenizer registers for t5-small;
# tokenize_and_cut truncates against this value.
input_size_limits = tokenizer.max_model_input_sizes
max_input_length = input_size_limits['t5-small']

print(max_input_length)

512


In [6]:
def tokenize_and_cut(sentence):
    """Split ``sentence`` into tokenizer pieces, truncated to fit the model.

    Two of the ``max_input_length`` positions are held back for the init and
    eos ids that the Fields add around every sequence.
    """
    pieces = tokenizer.tokenize(sentence)
    budget = max_input_length - 2
    return pieces[:budget]

In [7]:
# Source and target are encoded identically: tokenizer pieces are mapped
# straight to ids (use_vocab=False) and framed with the special-token ids
# looked up above, batch-first.
field_kwargs = dict(
    batch_first=True,
    use_vocab=False,
    tokenize=tokenize_and_cut,
    preprocessing=tokenizer.convert_tokens_to_ids,
    init_token=init_token_idx,
    eos_token=eos_token_idx,
    pad_token=pad_token_idx,
    unk_token=unk_token_idx,
)

SRC = data.Field(**field_kwargs)
TRG = data.Field(**field_kwargs)

In [8]:
class T5Network(nn.Module):
    """Pretrained T5 encoder-decoder plus a linear head over the vocabulary.

    Args:
        model_name: checkpoint passed to ``T5Model.from_pretrained``.
            Defaults to 't5-small', the checkpoint the tokenizer uses.

    The submodule names ``t5`` and ``out`` are part of the state_dict keys,
    so they must not change if saved checkpoints are to keep loading.
    """

    def __init__(self, model_name='t5-small'):
        super().__init__()

        self.t5 = T5Model.from_pretrained(model_name)

        # Project each decoder hidden state (d_model wide) onto vocab logits.
        config = self.t5.config
        self.out = nn.Linear(config.d_model, config.vocab_size)

    def forward(self, src, trg):
        """Return vocabulary logits for every decoder position.

        Args:
            src: encoder input ids.
            trg: decoder input ids (the prefix decoded so far).
            Both are assumed batch-first (batch, seq_len), as produced by
            the batch_first Fields. TODO: confirm for other callers.
        """
        outputs = self.t5(input_ids=src, decoder_input_ids=trg)

        # Element 0 of T5Model's output is the decoder's final hidden states.
        decoder_states = outputs[0]

        return self.out(decoder_states)

In [9]:
# Load the saved checkpoints model_1.pt ... model_4.pt, each into a fresh
# T5Network on the GPU.
models = []

for i in range(4):
    checkpoint_path = f'model_{i+1}.pt'
    new_model = T5Network().cuda()
    new_model.load_state_dict(torch.load(checkpoint_path))
    models += [new_model]

Some weights of T5Model were not initialized from the model checkpoint at t5-small and are newly initialized: ['encoder.embed_tokens.weight', 'decoder.embed_tokens.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of T5Model were not initialized from the model checkpoint at t5-small and are newly initialized: ['encoder.embed_tokens.weight', 'decoder.embed_tokens.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of T5Model were not initialized from the model checkpoint at t5-small and are newly initialized: ['encoder.embed_tokens.weight', 'decoder.embed_tokens.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of T5Model were not initialized from the model checkpoint at t5-small and are newly initialized: ['encoder.embed_tokens.weight', 'decod

In [10]:
def translate_sentence(sentence, src_field, trg_field, model, max_len = 50):
    """Greedily generate an output id sequence for one tokenized input.

    Args:
        sentence: sequence of source token ids *without* the init/eos ids;
            both are added here.
        src_field, trg_field: not used; kept so existing call sites work.
        model: network called as ``model(src, trg)``. Its output is indexed
            below as (batch, trg_len, vocab) logits.
        max_len: maximum number of tokens to generate after the start id.

    Returns:
        List of ids beginning with ``init_token_idx``. It ends with
        ``eos_token_idx`` if the model produced that id within
        ``max_len`` steps.

    Side effect: switches ``model`` to eval mode.
    """
    model.eval()

    # Create the inputs on the model's own device instead of hard-coding
    # .cuda(). This works on CPU-only machines and on a non-default GPU.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    src_indexes = [init_token_idx] + list(sentence) + [eos_token_idx]
    src_tensor = torch.tensor(src_indexes, dtype=torch.long, device=device).unsqueeze(0)

    trg_indexes = [init_token_idx]

    with torch.no_grad():
        for _ in range(max_len):
            trg_tensor = torch.tensor(trg_indexes, dtype=torch.long, device=device).unsqueeze(0)

            # NOTE: the whole model (encoder included) is re-run each step.
            output = model(src_tensor, trg_tensor)

            # Only the logits at the last decoder position choose the next token.
            pred_token = output[0, -1].argmax().item()

            trg_indexes.append(pred_token)

            if pred_token == eos_token_idx:
                break

    return trg_indexes

In [11]:
# Example passages, each paired with the questions to ask about it.
# Previously every CONTEXT/QUERY assignment was commented out, so this cell
# raised NameError on a fresh kernel; now one example is selected explicitly.
EXAMPLES = {
    'covid_pandemic': (
        "The COVID‑19 pandemic, also known as the coronavirus pandemic, is an ongoing global pandemic of coronavirus disease 2019 (COVID‑19), caused by severe acute respiratory syndrome coronavirus 2 (SARS‑CoV‑2). The outbreak was first identified in December 2019 in Wuhan, China. The World Health Organization declared the outbreak a Public Health Emergency of International Concern on 30 January 2020 and a pandemic on 11 March. As of 23 August 2020, more than 23.1 million cases of COVID‑19 have been reported in more than 188 countries and territories, resulting in more than 802,000 deaths; more than 14.8 million people have recovered.",
        [
            "how many people have died of covid-19 ?",
            "how many people have recovered from covid-19 ?",
            "where was the outbreak first identified ?",
            "when did who declare emergency ?",
        ],
    ),
    'covid_symptoms': (
        "Common symptoms include fever, cough, fatigue, shortness of breath, and loss of sense of smell. Complications may include pneumonia and acute respiratory distress syndrome. The time from exposure to onset of symptoms is typically around five days but may range from two to fourteen days. There are several vaccine candidates in development, although none have completed clinical trials to prove their safety and efficacy. There is no known specific antiviral medication, so primary treatment is currently symptomatic.",
        [
            "what are some symptoms ?",
            "when do symptoms start to appear ?",
            "are there any medications ?",
        ],
    ),
    'coronavirus_biology': (
        "Coronaviruses constitute the subfamily Orthocoronavirinae, in the family Coronaviridae, order Nidovirales, and realm Riboviria. They are enveloped viruses with a positive-sense single-stranded RNA genome and a nucleocapsid of helical symmetry. The genome size of coronaviruses ranges from approximately 26 to 32 kilobases, one of the largest among RNA viruses. They have characteristic club-shaped spikes that project from their surface, which in electron micrographs create an image reminiscent of the solar corona, from which their name derives.",
        [
            "what is genome size of coronavirus ?",
            "how do coronavirus look like ?",
            "what family does coronavirus belong to ?",
        ],
    ),
    'sars': (
        "SARS was a relatively rare disease; at the end of the epidemic in June 2003, the incidence was 8,422 cases with a case fatality rate (CFR) of 11%. No cases of SARS-CoV have been reported worldwide since 2004.",
        [
            "how many cases of sars ?",
            "what is the case fatality rate ?",
        ],
    ),
}

# The recorded outputs in this notebook use the SARS passage with its
# second question; change these two values to try another example.
EXAMPLE_NAME = 'sars'
QUESTION_INDEX = 1

CONTEXT, questions = EXAMPLES[EXAMPLE_NAME]
QUERY = questions[QUESTION_INDEX]

# Model input layout: "context : <passage> question : <question>", lowercased.
text = "context : " + CONTEXT.lower() + " question : " + QUERY.lower()
tokens = tokenizer.tokenize(text)

In [12]:
# Show the model input three ways: raw text, tokenizer pieces, and their ids.
for shown in (text, tokens):
    print(shown)
    print('\n')

print(tokenizer.convert_tokens_to_ids(tokens))

context : sars was a relatively rare disease; at the end of the epidemic in june 2003, the incidence was 8,422 cases with a case fatality rate (cfr) of 11%. no cases of sars-cov have been reported worldwide since 2004. question : what is the case fatality rate ?


['▁context', '▁', ':', '▁', 's', 'ar', 's', '▁was', '▁', 'a', '▁relatively', '▁rare', '▁disease', ';', '▁at', '▁the', '▁end', '▁of', '▁the', '▁epidemic', '▁in', '▁', 'jun', 'e', '▁2003', ',', '▁the', '▁incidence', '▁was', '▁8,', '42', '2', '▁cases', '▁with', '▁', 'a', '▁case', '▁fatal', 'ity', '▁rate', '▁(', 'c', 'f', 'r', ')', '▁of', '▁1', '1%', '.', '▁no', '▁cases', '▁of', '▁', 's', 'ar', 's', '-', 'cov', '▁have', '▁been', '▁reported', '▁worldwide', '▁since', '▁2004', '.', '▁question', '▁', ':', '▁what', '▁is', '▁the', '▁case', '▁fatal', 'ity', '▁rate', '▁', '?']


[2625, 3, 10, 3, 7, 291, 7, 47, 3, 9, 4352, 3400, 1994, 117, 44, 8, 414, 13, 8, 24878, 16, 3, 6959, 15, 3888, 6, 8, 20588, 47, 9478, 4165, 357, 1488, 28, 3, 9, 4

In [13]:
# Greedy-decode the same input with each of the four checkpoints.
# The id list is computed once; translate_sentence does not modify it.
src_ids = tokenizer.convert_tokens_to_ids(tokens)

pred_tokens1, pred_tokens2, pred_tokens3, pred_tokens4 = (
    translate_sentence(src_ids, SRC, TRG, candidate) for candidate in models[:4]
)

In [14]:
# Render each prediction as space-separated tokenizer pieces, special tokens included.
for prediction in (pred_tokens1, pred_tokens2, pred_tokens3, pred_tokens4):
    pieces = tokenizer.convert_ids_to_tokens(prediction)
    print(' '.join(pieces))

<pad> ▁1 1% </s>
<pad> ▁1 1% </s>
<pad> ▁1 1% </s>
<pad> ▁1 1% </s>
