In [1]:
import torch
import torch.nn as nn

from transformers import T5Model,T5Tokenizer
from torch import jit

In [2]:
tokenizer = T5Tokenizer.from_pretrained('t5-small')

In [3]:
tokenizer.vocab_size

32100

In [4]:
tokens = tokenizer.tokenize('Hello world how are you?')

print(tokens)

['▁Hello', '▁world', '▁how', '▁are', '▁you', '?']


In [5]:
indexes = tokenizer.convert_tokens_to_ids(tokens)

print(indexes)

[8774, 296, 149, 33, 25, 58]


In [6]:
init_token = tokenizer.pad_token
eos_token = tokenizer.eos_token
pad_token = tokenizer.pad_token
unk_token = tokenizer.unk_token

print(init_token, eos_token, pad_token, unk_token)

<pad> </s> <pad> <unk>


In [7]:
init_token_idx = tokenizer.convert_tokens_to_ids(init_token)
eos_token_idx = tokenizer.convert_tokens_to_ids(eos_token)
pad_token_idx = tokenizer.convert_tokens_to_ids(pad_token)
unk_token_idx = tokenizer.convert_tokens_to_ids(unk_token)

print(init_token_idx, eos_token_idx, pad_token_idx, unk_token_idx)

0 1 0 2


In [8]:
max_input_length = tokenizer.max_model_input_sizes['t5-small']

print(max_input_length)

512


In [9]:
class T5Network(nn.Module):
    def __init__(self):
        
        super().__init__()
        
        self.t5 = T5Model.from_pretrained('t5-small')
        
        self.out = nn.Linear(self.t5.config.to_dict()['d_model'],
                             self.t5.config.to_dict()['vocab_size'])
                
    def forward(self, src, trg):
        
        embedded = self.t5(input_ids=src,decoder_input_ids=trg) 
        
        output = self.out(embedded[0])
        
        return output

In [10]:
models = []
for i in range(4):
    new_model = T5Network()
    new_model = new_model.half()
    new_model.load_state_dict(torch.load(f't5_summ_model_{i+1}.pt'))
    models.append(new_model)

Some weights of T5Model were not initialized from the model checkpoint at t5-small and are newly initialized: ['encoder.embed_tokens.weight', 'decoder.embed_tokens.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of T5Model were not initialized from the model checkpoint at t5-small and are newly initialized: ['encoder.embed_tokens.weight', 'decoder.embed_tokens.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of T5Model were not initialized from the model checkpoint at t5-small and are newly initialized: ['encoder.embed_tokens.weight', 'decoder.embed_tokens.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of T5Model were not initialized from the model checkpoint at t5-small and are newly initialized: ['encoder.embed_tokens.weight', 'decod

In [11]:
def translate_sentence2(sentence, eval_model, max_len = 50):
    
    eval_model.eval()
    eval_model = eval_model.float()

    src_indexes = [init_token_idx] + sentence + [eos_token_idx]

    src_tensor = torch.LongTensor(src_indexes).unsqueeze(0)

    trg_indexes = [init_token_idx]

    for i in range(max_len):

        trg_tensor = torch.LongTensor(trg_indexes).unsqueeze(0)
        
        with torch.no_grad():
            
            print(src_tensor)
            print(src_tensor.shape)
            print(trg_tensor)
            print(trg_tensor.shape)
            print("\n\n")
            
            output = eval_model(src_tensor, trg_tensor)
        
        pred_token = output.argmax(2)[:,-1].item()
        
        trg_indexes.append(pred_token)

        if pred_token == eos_token_idx:
            break

    return trg_indexes[1:-1]

In [12]:
def return_summary(txt,test_model):
    txt_tokens = tokenizer.tokenize(txt)
    txt_ids = tokenizer.convert_tokens_to_ids(txt_tokens)
    pred = translate_sentence2(txt_ids,test_model)
    pred_tokens = tokenizer.convert_ids_to_tokens(pred)
    
    return ''.join(pred_tokens)

In [13]:
TEXT = "israeli forces killed two palestinians near the southern west bank town of hebron on wednesday , palestinian hospital officials said ."

In [14]:
return_summary(TEXT,models[3])

tensor([[    0,     3, 30178,    23,  3859,  4792,   192,  7692,   222,    77,
          7137,  1084,     8,  7518,  4653,  2137,  1511,    13,     3,    88,
         13711,    30,    62,    26,  1496,  1135,     3,     6,  7692,  3340,
         15710,  2833,  4298,   243,     3,     5,     1]])
torch.Size([1, 37])
tensor([[0]])
torch.Size([1, 1])



tensor([[    0,     3, 30178,    23,  3859,  4792,   192,  7692,   222,    77,
          7137,  1084,     8,  7518,  4653,  2137,  1511,    13,     3,    88,
         13711,    30,    62,    26,  1496,  1135,     3,     6,  7692,  3340,
         15710,  2833,  4298,   243,     3,     5,     1]])
torch.Size([1, 37])
tensor([[0, 3]])
torch.Size([1, 2])



tensor([[    0,     3, 30178,    23,  3859,  4792,   192,  7692,   222,    77,
          7137,  1084,     8,  7518,  4653,  2137,  1511,    13,     3,    88,
         13711,    30,    62,    26,  1496,  1135,     3,     6,  7692,  3340,
         15710,  2833,  4298,   243,     3,     5,    

'▁israeli▁forces▁kill▁two▁palestinians▁near▁hebron'

In [15]:
x = torch.ones(1, 80).long()
y = torch.ones(1, 20).long()

In [16]:
net_trace = jit.trace(models[3], [x, y])

  if causal_mask.shape[1] < attention_mask.shape[1]:


In [17]:
jit.save(net_trace.half(), 't5_ts_summ_model.zip')