In [1]:
import torch
import torch.nn as nn

import torchtext
from torchtext.data import Field

from torchtext import data
from transformers import T5Tokenizer, T5Model

In [2]:
# Load the pretrained tokenizer for the 't5-small' checkpoint. Every later cell
# uses this object for token <-> id conversion.
tokenizer = T5Tokenizer.from_pretrained('t5-small')

In [3]:
# Special-token strings taken from the tokenizer. The pad token is reused as
# the sequence-start ("init") token, so both names refer to the same string.
init_token = pad_token = tokenizer.pad_token
eos_token, unk_token = tokenizer.eos_token, tokenizer.unk_token

print(init_token, eos_token, pad_token, unk_token)

<pad> </s> <pad> <unk>


In [4]:
# Vocabulary ids of the special tokens, looked up one token at a time in the
# same order as they are printed below.
init_token_idx, eos_token_idx, pad_token_idx, unk_token_idx = (
    tokenizer.convert_tokens_to_ids(special)
    for special in (init_token, eos_token, pad_token, unk_token)
)

print(init_token_idx, eos_token_idx, pad_token_idx, unk_token_idx)

0 1 0 2


In [5]:
# Input-size entry that the tokenizer's size table records for 't5-small';
# tokenize_and_cut() uses it as the truncation bound.
max_input_length = tokenizer.max_model_input_sizes['t5-small']

print(max_input_length)

512


In [6]:
def tokenize_and_cut(sentence):
    """Tokenize ``sentence`` and truncate the result.

    Args:
        sentence: Raw text to split into tokens with the global ``tokenizer``.

    Returns:
        The token list, cut to at most ``max_input_length - 2`` entries. The
        two reserved slots presumably leave room for the init and EOS tokens
        that the torchtext fields add.
    """
    token_budget = max_input_length - 2
    return tokenizer.tokenize(sentence)[:token_budget]

In [7]:
# Source and target fields are configured identically: text is tokenized (and
# truncated) by tokenize_and_cut, mapped straight to ids by the T5 tokenizer
# (so no torchtext vocabulary is built), and wrapped in / padded with the
# special-token ids computed above.
_field_options = dict(
    batch_first = True,
    use_vocab = False,
    tokenize = tokenize_and_cut,
    preprocessing = tokenizer.convert_tokens_to_ids,
    init_token = init_token_idx,
    eos_token = eos_token_idx,
    pad_token = pad_token_idx,
    unk_token = unk_token_idx,
)

SRC = data.Field(**_field_options)

TRG = data.Field(**_field_options)

In [8]:
class T5Network(nn.Module):
    """Pretrained 't5-small' backbone followed by a linear vocabulary projection.

    The attribute names ``t5`` and ``out`` become the prefixes of the
    state_dict keys, so they must stay as they are for saved checkpoints to
    load.
    """

    def __init__(self):
        super().__init__()

        self.t5 = T5Model.from_pretrained('t5-small')

        # Size the output layer from the backbone's own configuration.
        cfg = self.t5.config.to_dict()
        self.out = nn.Linear(cfg['d_model'], cfg['vocab_size'])

    def forward(self, src, trg):
        """Return the projection of the backbone output for ``src``/``trg``.

        Args:
            src: Token ids passed to the backbone as ``input_ids``.
            trg: Token ids passed to the backbone as ``decoder_input_ids``.

        Returns:
            ``self.out`` applied to the first element of the backbone output
            (presumably the decoder hidden states, giving per-position
            vocabulary logits -- confirm against the transformers docs).
        """
        hidden = self.t5(input_ids=src, decoder_input_ids=trg)[0]
        return self.out(hidden)

In [9]:
# Use the GPU when one is available and fall back to the CPU otherwise, so the
# notebook also runs on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

models = []

for i in range(4):
    new_model = T5Network().to(device)
    # map_location remaps tensors saved on another device (e.g. a GPU that is
    # not present here) onto `device` instead of failing during deserialization.
    # NOTE: torch.load unpickles the file -- only load checkpoints you trust.
    state_dict = torch.load(f'saved_models/marco_model_{i+1}.pt', map_location=device)
    new_model.load_state_dict(state_dict)
    models.append(new_model)

Some weights of T5Model were not initialized from the model checkpoint at t5-small and are newly initialized: ['encoder.embed_tokens.weight', 'decoder.embed_tokens.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of T5Model were not initialized from the model checkpoint at t5-small and are newly initialized: ['encoder.embed_tokens.weight', 'decoder.embed_tokens.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of T5Model were not initialized from the model checkpoint at t5-small and are newly initialized: ['encoder.embed_tokens.weight', 'decoder.embed_tokens.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of T5Model were not initialized from the model checkpoint at t5-small and are newly initialized: ['encoder.embed_tokens.weight', 'decod

In [10]:
def translate_sentence(sentence, src_field, trg_field, model, max_len = 50):
    """Greedily decode an output id sequence for ``sentence`` with ``model``.

    Args:
        sentence: Sequence of source token ids WITHOUT special tokens; the
            init and EOS ids are added here.
        src_field: Unused; kept so existing call sites keep working.
        trg_field: Unused; kept so existing call sites keep working.
        model: Module called as ``model(src, trg)`` that returns a
            ``(batch, trg_len, vocab)`` score tensor.
        max_len: Maximum number of tokens to generate after the init token.

    Returns:
        List of ids starting with ``init_token_idx``. It ends with
        ``eos_token_idx`` only if the model produced EOS within ``max_len``
        steps.
    """
    model.eval()

    # Build the tensors on whatever device the model lives on, rather than
    # assuming CUDA, so CPU-resident models work as well.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    src_indexes = [init_token_idx] + list(sentence) + [eos_token_idx]
    src_tensor = torch.tensor(src_indexes, dtype=torch.long, device=device).unsqueeze(0)

    trg_indexes = [init_token_idx]

    # Inference only: one no_grad context covers the whole decoding loop.
    with torch.no_grad():
        for _ in range(max_len):
            trg_tensor = torch.tensor(trg_indexes, dtype=torch.long, device=device).unsqueeze(0)

            output = model(src_tensor, trg_tensor)

            # Greedy choice: best-scoring vocabulary entry at the last position.
            pred_token = output[:, -1].argmax(-1).item()

            trg_indexes.append(pred_token)

            if pred_token == eos_token_idx:
                break

    return trg_indexes

In [11]:
def str_result(tokens):
    """Turn a decoded id sequence into a readable string.

    Args:
        tokens: Ids as returned by ``translate_sentence``: the init token
            first, then the generated ids, optionally ending with EOS.

    Returns:
        The generated text with word-boundary markers replaced by spaces, or
        ``''`` when nothing was generated besides the special tokens.
    """
    # T5's tokenizer prefixes word-initial pieces with U+2581 ('\u2581').
    word_marker = '\u2581'

    # Always drop the leading init token. Drop the last id only when it is
    # really EOS: decoding that stopped at max_len ends with a real token.
    body = list(tokens[1:])
    if body and body[-1] == tokenizer.eos_token_id:
        body = body[:-1]
    if not body:
        return ''

    joined = ''.join(tokenizer.convert_ids_to_tokens(body))
    # Remove a single leading marker so the text does not start with a space.
    if joined.startswith(word_marker):
        joined = joined[len(word_marker):]

    return joined.replace(word_marker, ' ')

In [12]:
# Demo passage that every query below is answered against.
# NOTE(review): the hyphens in 'COVID‑19' / 'SARS‑CoV‑2' appear to be a non-ASCII
# hyphen character, unlike the ASCII '-' in the 'covid-19' queries -- confirm
# this is intended, since the two may tokenize differently.
CONTEXT = 'The COVID‑19 pandemic, also known as the coronavirus pandemic, is an ongoing global pandemic of coronavirus disease 2019 (COVID‑19), caused by severe acute respiratory syndrome coronavirus 2 (SARS‑CoV‑2). The outbreak was first identified in December 2019 in Wuhan, China. The World Health Organization declared the outbreak a Public Health Emergency of International Concern on 30 January 2020 and a pandemic on 11 March. As of 24 August 2020, more than 23.4 million cases of COVID‑19 have been reported in more than 188 countries and territories, resulting in more than 808,000 deaths; more than 15.1 million people have recovered.'
# Questions about the passage; each one is paired with CONTEXT in the inference loop.
QUERIES = ['where was the outbreak first identified ?',
           'when was the outbreak first identified ?',
           'how many people have died from covid-19 ?',
           'how many people have recovered from covid-19 ?',
           'when did the world health organization declare an emergency ?',
           'when did the world health organization declare a pandemic ?']

In [13]:
# Ask every query against CONTEXT and print one answer per loaded model.
for query in QUERIES:
    # Input format: lower-cased "context : ... query : ..." string.
    text = "context : " + CONTEXT.lower() + " query : " + query.lower()
    tokens = tokenizer.tokenize(text)
    # Convert to ids once per query instead of once per model.
    src_ids = tokenizer.convert_tokens_to_ids(tokens)

    print("TEXT")
    print(text)

    # One prediction per model, in list order, however many models are loaded.
    predictions = [translate_sentence(src_ids, SRC, TRG, model) for model in models]

    print("\nPREDICTIONS")
    for pred_tokens in predictions:
        print(str_result(pred_tokens))
    print('\n\n')

TEXT
context : the covid‑19 pandemic, also known as the coronavirus pandemic, is an ongoing global pandemic of coronavirus disease 2019 (covid‑19), caused by severe acute respiratory syndrome coronavirus 2 (sars‑cov‑2). the outbreak was first identified in december 2019 in wuhan, china. the world health organization declared the outbreak a public health emergency of international concern on 30 january 2020 and a pandemic on 11 march. as of 24 august 2020, more than 23.4 million cases of covid‑19 have been reported in more than 188 countries and territories, resulting in more than 808,000 deaths; more than 15.1 million people have recovered. query : where was the outbreak first identified ?

PREDICTIONS
wuhan, china
in wuhan, china.
wuhan, china
wuhan, china



TEXT
context : the covid‑19 pandemic, also known as the coronavirus pandemic, is an ongoing global pandemic of coronavirus disease 2019 (covid‑19), caused by severe acute respiratory syndrome coronavirus 2 (sars‑cov‑2). the outbr