In [1]:
import torch
import torch.nn as nn

import torchtext
from torchtext.data import Field

from torchtext import data
from transformers import T5Tokenizer, T5Model

In [2]:
# Load the tokenizer for the 't5-small' checkpoint; the later cells use it to
# turn text into tokens / token ids and to turn predicted ids back into tokens.
tokenizer = T5Tokenizer.from_pretrained('t5-small')

In [3]:
# Special token strings taken from the tokenizer. The pad token is deliberately
# used twice: once as the padding token and once as the sequence-start (init) token.
init_token, eos_token, pad_token, unk_token = (
    tokenizer.pad_token,
    tokenizer.eos_token,
    tokenizer.pad_token,
    tokenizer.unk_token,
)

print(init_token, eos_token, pad_token, unk_token)

<pad> </s> <pad> <unk>


In [4]:
# Vocabulary ids of the special tokens, looked up with a single tokenizer call
# (a list of tokens in, a list of ids out, in the same order).
init_token_idx, eos_token_idx, pad_token_idx, unk_token_idx = tokenizer.convert_tokens_to_ids(
    [init_token, eos_token, pad_token, unk_token]
)

print(init_token_idx, eos_token_idx, pad_token_idx, unk_token_idx)

0 1 0 2


In [5]:
# Maximum input size the tokenizer records for the 't5-small' checkpoint.
# tokenize_and_cut uses this value to cap how many tokens it keeps.
max_input_length = tokenizer.max_model_input_sizes['t5-small']

print(max_input_length)

512


In [6]:
def tokenize_and_cut(sentence):
    """Tokenize ``sentence`` and keep at most ``max_input_length - 2`` tokens.

    The budget is two short of ``max_input_length``, which leaves room for the
    init and EOS tokens that the SRC/TRG fields are configured to add.

    Args:
        sentence: Text to pass to ``tokenizer.tokenize``.

    Returns:
        The leading tokens of the tokenized sentence, truncated to the budget.
    """
    token_budget = max_input_length - 2
    return tokenizer.tokenize(sentence)[:token_budget]

In [7]:
# Source and target fields use exactly the same configuration: text is split by
# tokenize_and_cut, mapped to ids by the tokenizer (so no torchtext vocabulary
# is built), and the special-token ids looked up earlier are used as-is.
_field_options = dict(
    batch_first=True,
    use_vocab=False,
    tokenize=tokenize_and_cut,
    preprocessing=tokenizer.convert_tokens_to_ids,
    init_token=init_token_idx,
    eos_token=eos_token_idx,
    pad_token=pad_token_idx,
    unk_token=unk_token_idx,
)

SRC = data.Field(**_field_options)

TRG = data.Field(**_field_options)

In [8]:
class T5Network(nn.Module):
    """Pretrained 't5-small' ``T5Model`` followed by a linear output layer.

    The linear layer maps from the model's ``d_model`` size to its
    ``vocab_size``, both read from the loaded model's configuration.
    """

    def __init__(self):
        super().__init__()

        # Attribute names are part of the state_dict keys used by saved checkpoints.
        self.t5 = T5Model.from_pretrained('t5-small')

        t5_config = self.t5.config.to_dict()
        self.out = nn.Linear(t5_config['d_model'], t5_config['vocab_size'])

    def forward(self, src, trg):
        """Run T5 on the inputs and project the first element of its output.

        Args:
            src: Passed to the T5 model as ``input_ids``.
            trg: Passed to the T5 model as ``decoder_input_ids``.

        Returns:
            ``self.out`` applied to element 0 of the T5 output
            (presumably per-position vocabulary scores; the decoding code
            takes ``argmax`` over dimension 2 of this result).
        """
        t5_outputs = self.t5(input_ids=src, decoder_input_ids=trg)
        first_output = t5_outputs[0]
        return self.out(first_output)

In [9]:
# Build the ensemble: one T5Network per checkpoint file model_1.pt .. model_4.pt.
# NOTE: torch.load unpickles the file, so only load checkpoints from a trusted source.
models = []

for checkpoint_number in range(1, 5):
    ensemble_member = T5Network().cuda()
    checkpoint_state = torch.load(f'model_{checkpoint_number}.pt')
    ensemble_member.load_state_dict(checkpoint_state)
    models.append(ensemble_member)

Some weights of T5Model were not initialized from the model checkpoint at t5-small and are newly initialized: ['encoder.embed_tokens.weight', 'decoder.embed_tokens.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of T5Model were not initialized from the model checkpoint at t5-small and are newly initialized: ['encoder.embed_tokens.weight', 'decoder.embed_tokens.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of T5Model were not initialized from the model checkpoint at t5-small and are newly initialized: ['encoder.embed_tokens.weight', 'decoder.embed_tokens.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of T5Model were not initialized from the model checkpoint at t5-small and are newly initialized: ['encoder.embed_tokens.weight', 'decod

In [10]:
def model_ensemble_output(models, src_tensor, trg_tensor):
    """Sum the outputs of every model in ``models`` for the same inputs.

    Works for any non-empty ensemble size instead of exactly four models.
    For four models the result is identical to
    ``m0(...) + m1(...) + m2(...) + m3(...)``: models are called and their
    outputs added strictly left to right.

    Args:
        models: Non-empty indexable sequence of callables, each invoked as
            ``model(src_tensor, trg_tensor)``.
        src_tensor: First argument forwarded unchanged to every model.
        trg_tensor: Second argument forwarded unchanged to every model.

    Returns:
        The (un-normalised) sum of all model outputs.

    Raises:
        ValueError: If ``models`` is empty.
    """
    if len(models) == 0:
        raise ValueError("model_ensemble_output requires at least one model")

    total = models[0](src_tensor, trg_tensor)
    for model in models[1:]:
        # Out-of-place add so no model's own output object is modified.
        total = total + model(src_tensor, trg_tensor)
    return total

In [11]:
def translate_sentence(sentence, src_field, trg_field, models, max_len = 50):
    """Greedily decode an output id sequence for ``sentence`` with the ensemble.

    Args:
        sentence: List of source token ids without the init/EOS ids; both are
            added here before the ids are sent to the models.
        src_field: Accepted for interface compatibility; not used.
        trg_field: Accepted for interface compatibility; not used.
        models: Models handed to ``model_ensemble_output``; each one is put
            into eval mode first.
        max_len: Maximum number of decoding steps.

    Returns:
        List of ids that starts with ``init_token_idx`` and ends with
        ``eos_token_idx``, unless ``max_len`` steps ran without predicting it.
    """
    for member in models:
        member.eval()

    encoder_ids = [init_token_idx] + sentence + [eos_token_idx]
    encoder_input = torch.LongTensor(encoder_ids).unsqueeze(0).cuda()

    decoded_ids = [init_token_idx]

    for _ in range(max_len):
        # The whole prefix decoded so far is fed back in on every step.
        decoder_input = torch.LongTensor(decoded_ids).unsqueeze(0).cuda()

        with torch.no_grad():
            scores = model_ensemble_output(models, encoder_input, decoder_input)

        # Greedy choice: highest-scoring id at the last target position.
        next_id = scores.argmax(2)[0, -1].item()
        decoded_ids.append(next_id)

        if next_id == eos_token_idx:
            break

    return decoded_ids

In [12]:
# Active passage (CONTEXT) and the questions (QUERIES) asked about it. The final
# loop lower-cases both and joins each query with the passage into one model input.
# Alternative passage/question sets are kept commented out below.
CONTEXT = 'The COVID‑19 pandemic, also known as the coronavirus pandemic, is an ongoing global pandemic of coronavirus disease 2019 (COVID‑19), caused by severe acute respiratory syndrome coronavirus 2 (SARS‑CoV‑2). The outbreak was first identified in December 2019 in Wuhan, China. The World Health Organization declared the outbreak a Public Health Emergency of International Concern on 30 January 2020 and a pandemic on 11 March. As of 24 August 2020, more than 23.4 million cases of COVID‑19 have been reported in more than 188 countries and territories, resulting in more than 808,000 deaths; more than 15.1 million people have recovered.'
QUERIES = ['where was the outbreak first identified ?',
           'when was the outbreak first identified ?',
           'how many people have died from covid-19 ?',
           'how many people have recovered from covid-19 ?',
           'when did the world health organization declare an emergency ?',
           'when did the world health organization declare a pandemic ?']

# CONTEXT = 'Common symptoms include fever, cough, fatigue, shortness of breath, and loss of sense of smell. Complications may include pneumonia and acute respiratory distress syndrome. The time from exposure to onset of symptoms is typically around five days but may range from two to fourteen days. There are several vaccine candidates in development, although none have completed clinical trials to prove their safety and efficacy. There is no known specific antiviral medication, so primary treatment is currently symptomatic.'
# QUERIES = ['what are some symptoms ?',
#            'how many days to show symptoms ?',
#            'how many vaccines have completed clinical trials ?',
#            'are there any medication ?']

# CONTEXT = 'Coronaviruses are a group of related RNA viruses that cause diseases in mammals and birds. In humans and birds, they cause respiratory tract infections that can range from mild to lethal. Mild illnesses in humans include some cases of the common cold (which is also caused by other viruses, predominantly rhinoviruses), while more lethal varieties can cause SARS, MERS, and COVID-19. In cows and pigs they cause diarrhea, while in mice they cause hepatitis and encephalomyelitis. There are as yet no vaccines or antiviral drugs to prevent or treat human coronavirus infections.'
# QUERIES = ['what group does coronavirus belong to ?',
#            'what are lethal varities of coronavirus ?',
#            'how does coronavirus affect cows ?',
#            'are there any vaccines ?',
#            'how does coronavirus affect humans ?']

# CONTEXT = 'Coronaviruses constitute the subfamily Orthocoronavirinae, in the family Coronaviridae, order Nidovirales, and realm Riboviria. They are enveloped viruses with a positive-sense single-stranded RNA genome and a nucleocapsid of helical symmetry. The genome size of coronaviruses ranges from approximately 26 to 32 kilobases, one of the largest among RNA viruses. They have characteristic club-shaped spikes that project from their surface, which in electron micrographs create an image reminiscent of the solar corona, from which their name derives.'
# QUERIES = ['what is the genome size of coronavirus ?',
#            'what family does coronavirus belong to ?',
#            'what subfamily does coronavirus belong to ?']

In [13]:
def str_result(tokens):
    """Turn a list of predicted token strings into a plain-text answer.

    The first token (the init token) and the last token (expected to be the
    EOS token) are always dropped. The remaining pieces are concatenated and
    the SentencePiece word marker U+2581 is converted to spaces, with one
    leading marker removed.

    The marker is fixed rather than guessed from the first predicted token:
    guessing corrupted the text whenever that token did not start with the
    marker (its first character was dropped and treated as the separator).
    Token lists shorter than two entries yield an empty string instead of
    raising ``IndexError``.

    Args:
        tokens: Sequence of token strings, including the leading init token
            and a trailing token that is discarded.

    Returns:
        The decoded text.
    """
    word_marker = '\u2581'

    text = ''.join(tokens[1:-1])
    # Drop a single leading marker so the text does not start with a space.
    if text.startswith(word_marker):
        text = text[len(word_marker):]
    return text.replace(word_marker, ' ')

In [14]:
# Ask every query against the active context and print the ensemble's answer.
for query in QUERIES:
    # Model input: lower-cased context followed by the lower-cased query.
    context_part = "context : " + CONTEXT.lower()
    query_part = " query : " + query.lower()
    text = context_part + query_part

    tokens = tokenizer.tokenize(text)
    token_ids = tokenizer.convert_tokens_to_ids(tokens)

    print(f"INPUT TEXT\n{text}\n")
    print(f"INPUT TOKENS\n{tokens}\n")
    print(token_ids)

    pred_tokens = translate_sentence(token_ids, SRC, TRG, models)

    print("\nPREDICTIONS")
    predicted_pieces = tokenizer.convert_ids_to_tokens(pred_tokens)
    final_result = str_result(predicted_pieces)
    print(final_result)
    print('\n\n')

INPUT TEXT
context : the covid‑19 pandemic, also known as the coronavirus pandemic, is an ongoing global pandemic of coronavirus disease 2019 (covid‑19), caused by severe acute respiratory syndrome coronavirus 2 (sars‑cov‑2). the outbreak was first identified in december 2019 in wuhan, china. the world health organization declared the outbreak a public health emergency of international concern on 30 january 2020 and a pandemic on 11 march. as of 24 august 2020, more than 23.4 million cases of covid‑19 have been reported in more than 188 countries and territories, resulting in more than 808,000 deaths; more than 15.1 million people have recovered. query : where was the outbreak first identified ?

INPUT TOKENS
['▁context', '▁', ':', '▁the', '▁co', 'vid', '‐', '19', '▁pan', 'de', 'mic', ',', '▁also', '▁known', '▁as', '▁the', '▁cor', 'on', 'a', 'virus', '▁pan', 'de', 'mic', ',', '▁is', '▁an', '▁ongoing', '▁global', '▁pan', 'de', 'mic', '▁of', '▁cor', 'on', 'a', 'virus', '▁disease', '▁2019