In [1]:
import torch
import torch.nn as nn

from transformers import T5Model,T5Tokenizer
from torch import jit

In [2]:
tokenizer = T5Tokenizer.from_pretrained('t5-small')

In [3]:
tokenizer.vocab_size

32100

In [4]:
tokens = tokenizer.tokenize('Hello world how are you?')

print(tokens)

['▁Hello', '▁world', '▁how', '▁are', '▁you', '?']


In [5]:
indexes = tokenizer.convert_tokens_to_ids(tokens)

print(indexes)

[8774, 296, 149, 33, 25, 58]


In [6]:
init_token = tokenizer.pad_token
eos_token = tokenizer.eos_token
pad_token = tokenizer.pad_token
unk_token = tokenizer.unk_token

print(init_token, eos_token, pad_token, unk_token)

<pad> </s> <pad> <unk>


In [7]:
init_token_idx = tokenizer.convert_tokens_to_ids(init_token)
eos_token_idx = tokenizer.convert_tokens_to_ids(eos_token)
pad_token_idx = tokenizer.convert_tokens_to_ids(pad_token)
unk_token_idx = tokenizer.convert_tokens_to_ids(unk_token)

print(init_token_idx, eos_token_idx, pad_token_idx, unk_token_idx)

0 1 0 2


In [8]:
max_input_length = tokenizer.max_model_input_sizes['t5-small']

print(max_input_length)

512


In [9]:
class T5Network(nn.Module):
    def __init__(self):
        
        super().__init__()
        
        self.t5 = T5Model.from_pretrained('t5-small')
        
        self.out = nn.Linear(self.t5.config.to_dict()['d_model'],
                             self.t5.config.to_dict()['vocab_size'])
                
    def forward(self, src, trg):
        
        embedded = self.t5(input_ids=src,decoder_input_ids=trg) 
        
        output = self.out(embedded[0])
        
        return output

In [10]:
model = T5Network()

Some weights of T5Model were not initialized from the model checkpoint at t5-small and are newly initialized: ['encoder.embed_tokens.weight', 'decoder.embed_tokens.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [11]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f'The model has {count_parameters(model):,} trainable parameters')

The model has 76,988,544 trainable parameters


In [12]:
# CONVERT ALL MODEL WEIGHTS AND BIASES TO HALF PRECISION
# MODEL SIZE WILL REDUCE
model = model.half()

In [13]:
model.load_state_dict(torch.load('t5_qa_model.pt'))

<All keys matched successfully>

In [14]:
def translate_sentence2(sentence, eval_model, max_len = 50):
    
    eval_model.eval()
    eval_model = eval_model.float()

    src_indexes = [init_token_idx] + sentence + [eos_token_idx]

    src_tensor = torch.LongTensor(src_indexes).unsqueeze(0)

    trg_indexes = [init_token_idx]

    for i in range(max_len):

        trg_tensor = torch.LongTensor(trg_indexes).unsqueeze(0)
        
        with torch.no_grad():
            
            print(src_tensor)
            print(src_tensor.shape)
            print(trg_tensor)
            print(trg_tensor.shape)
            print("\n\n")
            
            output = eval_model(src_tensor, trg_tensor)
        
        pred_token = output.argmax(2)[:,-1].item()
        
        trg_indexes.append(pred_token)

        if pred_token == eos_token_idx:
            break

    return trg_indexes[1:-1]

In [15]:
def return_answer(context,query):
    txt = 'context : ' + context.lower() + ' question : ' + query.lower()
    txt_tokens = tokenizer.tokenize(txt)
    txt_ids = tokenizer.convert_tokens_to_ids(txt_tokens)
    pred = translate_sentence2(txt_ids, model)
    pred_tokens = tokenizer.convert_ids_to_tokens(pred)
    
    return ''.join(pred_tokens)

In [16]:
CONTEXT = "Common symptoms include fever, cough, fatigue, breathing difficulties, and loss of smell and taste. Complications may include pneumonia and acute respiratory distress syndrome. The incubation period is typically around five days but may range from one to 14 days. There are several vaccine candidates in development, although none have completed clinical trials. There is no known specific antiviral medication, so primary treatment is currently symptomatic."
QUERY = "what are some symptoms ?"

In [17]:
return_answer(CONTEXT,QUERY)

tensor([[    0,  2625,     3,    10,  1017,  3976,   560, 17055,     6, 19222,
             6, 13034,     6, 10882, 10308,     6,    11,  1453,    13,  5949,
            11,  2373,     5, 14497,   164,   560, 30195,    11, 12498, 19944,
         19285, 12398,     5,     8,    16, 16377,  1575,  1059,    19,  3115,
           300,   874,   477,    68,   164,   620,    45,    80,    12,   968,
           477,     5,   132,    33,   633, 12956,  4341,    16,   606,     6,
          2199,  5839,    43,  2012,  3739, 10570,     5,   132,    19,   150,
           801,   806,  1181,  5771,   138,  7757,     6,    78,  2329,  1058,
            19,  1083,     3, 18018,  6049,     5,   822,     3,    10,   125,
            33,   128,  3976,     3,    58,     1]])
torch.Size([1, 96])
tensor([[0]])
torch.Size([1, 1])



tensor([[    0,  2625,     3,    10,  1017,  3976,   560, 17055,     6, 19222,
             6, 13034,     6, 10882, 10308,     6,    11,  1453,    13,  5949,
            11,  2373,

'▁fever,▁cough,▁fatigue,▁breathing▁difficulties,▁and▁loss▁of▁smell▁and▁taste'

In [18]:
x = torch.ones(1, 100).long()
y = torch.ones(1, 20).long()

In [19]:
net_trace = jit.trace(model, [x, y])

  if causal_mask.shape[1] < attention_mask.shape[1]:


In [20]:
jit.save(net_trace.half(), 't5_ts_qa_model.zip')