In [1]:
import torch
import torch.nn as nn
import torch.optim as optim

import torchtext
from torchtext.data import Field, BucketIterator

import numpy as np
from torchtext import data

import random
import math
import time

from transformers import T5Model,T5Tokenizer

In [2]:
tokenizer = T5Tokenizer.from_pretrained('t5-small')

In [3]:
tokenizer.vocab_size

32100

In [4]:
tokens = tokenizer.tokenize('Hello world how are you?')

print(tokens)

['▁Hello', '▁world', '▁how', '▁are', '▁you', '?']


In [5]:
indexes = tokenizer.convert_tokens_to_ids(tokens)

print(indexes)

[8774, 296, 149, 33, 25, 58]


In [6]:
init_token = tokenizer.pad_token
eos_token = tokenizer.eos_token
pad_token = tokenizer.pad_token
unk_token = tokenizer.unk_token

print(init_token, eos_token, pad_token, unk_token)

<pad> </s> <pad> <unk>


In [7]:
init_token_idx = tokenizer.convert_tokens_to_ids(init_token)
eos_token_idx = tokenizer.convert_tokens_to_ids(eos_token)
pad_token_idx = tokenizer.convert_tokens_to_ids(pad_token)
unk_token_idx = tokenizer.convert_tokens_to_ids(unk_token)

print(init_token_idx, eos_token_idx, pad_token_idx, unk_token_idx)

0 1 0 2


In [8]:
max_input_length = tokenizer.max_model_input_sizes['t5-small']

print(max_input_length)

512


In [9]:
def tokenize_and_cut(sentence):
    tokens = tokenizer.tokenize(sentence) 
    tokens = tokens[:max_input_length-2]
    return tokens

In [10]:
SRC = Field(batch_first = True,
          use_vocab = False,
          tokenize = tokenize_and_cut,
          preprocessing = tokenizer.convert_tokens_to_ids,
          init_token = init_token_idx,
          eos_token = eos_token_idx,
          pad_token = pad_token_idx,
          unk_token = unk_token_idx)

TRG = Field(batch_first = True,
          use_vocab = False,
          tokenize = tokenize_and_cut,
          preprocessing = tokenizer.convert_tokens_to_ids,
          init_token = init_token_idx,
          eos_token = eos_token_idx,
          pad_token = pad_token_idx,
          unk_token = unk_token_idx)



In [11]:
fields = [('src', SRC), ('trg', TRG)]

In [12]:
train_data = data.TabularDataset.splits(
                path = '',
                train = 'news_train.csv',
                format = 'csv',
                fields = fields,
                skip_header = True)

train_data , valid_data = train_data[0].split(split_ratio=0.98,
                                             random_state = random.seed(4321))



In [13]:
print(len(train_data.examples))
print(len(valid_data.examples))

86765
1771


In [14]:
print(vars(train_data.examples[4000]))

{'src': [3, 76, 17, 2046, 3, 5319, 1395, 107, 21108, 5752, 6323, 176, 15, 7, 107, 3, 7, 17178, 9, 30, 3, 189, 3589, 1135, 243, 24, 1181, 18, 11956, 32, 12025, 16, 8, 538, 65, 118, 3759, 12, 766, 8, 1455, 13, 887, 11, 59, 12, 31738, 16679, 151, 5, 96, 29, 32, 6965, 56, 36, 3169, 26, 406, 136, 1053, 5, 62, 33, 464, 21, 8, 394, 297, 13, 151, 5, 3, 99, 10843, 56, 6136, 120, 31738, 13112, 6, 62, 56, 59, 8179, 135, 976, 3, 88, 974, 5], 'trg': [1181, 18, 11956, 32, 12025, 59, 3759, 12, 31738, 16679, 7, 10, 95, 3, 26, 63, 2446]}


In [15]:
src_tokens = tokenizer.convert_ids_to_tokens(vars(train_data.examples[4000])['src'])
trg_tokens = tokenizer.convert_ids_to_tokens(vars(train_data.examples[4000])['trg'])

print(src_tokens)
print(trg_tokens)

['▁', 'u', 't', 'tar', '▁', 'pra', 'des', 'h', '▁deputy', '▁chief', '▁minister', '▁din', 'e', 's', 'h', '▁', 's', 'harm', 'a', '▁on', '▁', 'th', 'urs', 'day', '▁said', '▁that', '▁anti', '-', 'rome', 'o', '▁squad', '▁in', '▁the', '▁state', '▁has', '▁been', '▁launched', '▁to', '▁ensure', '▁the', '▁safety', '▁of', '▁women', '▁and', '▁not', '▁to', '▁disturb', '▁innocent', '▁people', '.', '▁"', 'n', 'o', 'body', '▁will', '▁be', '▁trouble', 'd', '▁without', '▁any', '▁reason', '.', '▁we', '▁are', '▁working', '▁for', '▁the', '▁better', 'ment', '▁of', '▁people', '.', '▁', 'if', '▁somebody', '▁will', '▁false', 'ly', '▁disturb', '▁anybody', ',', '▁we', '▁will', '▁not', '▁spare', '▁them', ',"', '▁', 'he', '▁added', '.']
['▁anti', '-', 'rome', 'o', '▁squad', '▁not', '▁launched', '▁to', '▁disturb', '▁innocent', 's', ':', '▁up', '▁', 'd', 'y', '▁cm']


In [16]:
device = torch.device('cuda')

BATCH_SIZE = 32

train_iterator, valid_iterator = BucketIterator.splits(
                                 (train_data, valid_data), 
                                 batch_size = BATCH_SIZE,
                                 device = device,
                                 sort_key=lambda x: len(x.src))



In [17]:
class T5Network(nn.Module):
    def __init__(self):
        
        super().__init__()
        
        self.t5 = T5Model.from_pretrained('t5-small')
        
        self.out = nn.Linear(self.t5.config.to_dict()['d_model'],
                             self.t5.config.to_dict()['vocab_size'])
                
    def forward(self, src, trg):
        
        embedded = self.t5(input_ids=src,decoder_input_ids=trg) 
        
        output = self.out(embedded[0])
        
        return output

In [18]:
model = T5Network().cuda()

Some weights of T5Model were not initialized from the model checkpoint at t5-small and are newly initialized: ['encoder.embed_tokens.weight', 'decoder.embed_tokens.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [19]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f'The model has {count_parameters(model):,} trainable parameters')

The model has 76,988,544 trainable parameters


In [20]:
LEARNING_RATE = 0.0004

optimizer = torch.optim.Adam(model.parameters(), lr = LEARNING_RATE)

In [21]:
criterion = nn.CrossEntropyLoss(ignore_index = pad_token_idx)

In [22]:
N_EPOCHS = 4
CLIP = 1

best_valid_loss = float('inf')

In [23]:
for epoch in range(N_EPOCHS):
    
    start_time = time.time()
    
    model = model.float()
    # TRAINING
    ##############################################################################
    model.train()
    epoch_loss = 0
    
    for i, batch in enumerate(train_iterator):
        
        src = batch.src
        trg = batch.trg
        
        optimizer.zero_grad()
        
        output = model(src, trg[:,:-1])

        output_dim = output.shape[-1]
            
        output = output.contiguous().view(-1, output_dim)
        trg = trg[:,1:].contiguous().view(-1)

        loss = criterion(output, trg)
        
        loss.backward()
        
        torch.nn.utils.clip_grad_norm_(model.parameters(), CLIP)
        
        optimizer.step()
        
        epoch_loss += loss.item()
        
    train_loss = epoch_loss / len(train_iterator)
    ##############################################################################
    
    # VALIDATION
    ##############################################################################
    model.eval()
    epoch_loss = 0
    
    with torch.no_grad():
    
        for i, batch in enumerate(valid_iterator):

            src = batch.src
            trg = batch.trg

            output = model(src, trg[:,:-1])
            
            output_dim = output.shape[-1]
            
            output = output.contiguous().view(-1, output_dim)
            trg = trg[:,1:].contiguous().view(-1)
            
            loss = criterion(output, trg)

            epoch_loss += loss.item()
        
    valid_loss = epoch_loss / len(valid_iterator)
    ##############################################################################
    model = model.half()
    
    end_time = time.time()
    
    print(f"EPOCH : {epoch+1}\tTRAIN LOSS : {train_loss:.2f}\tVALID LOSS : {valid_loss:.2f}\tTIME : {end_time-start_time:.2f}\n")
    torch.save(model.state_dict(), f't5_summ_model_{epoch+1}.pt')



EPOCH : 1	TRAIN LOSS : 2.58	VALID LOSS : 1.38	TIME : 1269.99

EPOCH : 2	TRAIN LOSS : 1.31	VALID LOSS : 1.25	TIME : 1311.11

EPOCH : 3	TRAIN LOSS : 1.11	VALID LOSS : 1.20	TIME : 1782.60

EPOCH : 4	TRAIN LOSS : 0.99	VALID LOSS : 1.21	TIME : 1279.54



In [24]:
# # CONVERT ALL MODEL WEIGHTS AND BIASES TO HALF PRECISION
# # MODEL SIZE WILL REDUCE
# model = model.half()

In [25]:
# torch.save(model.state_dict(), 't5_summ_model.pt')

In [26]:
def translate_sentence2(sentence, eval_model, device, max_len = 50):
    
    eval_model.eval()

    src_indexes = [init_token_idx] + sentence + [eos_token_idx]

    src_tensor = torch.LongTensor(src_indexes).unsqueeze(0).to(device)

    trg_indexes = [init_token_idx]

    for i in range(max_len):

        trg_tensor = torch.LongTensor(trg_indexes).unsqueeze(0).to(device)
        
        with torch.no_grad():
            output = eval_model(src_tensor, trg_tensor)
        
        pred_token = output.argmax(2)[:,-1].item()
        
        trg_indexes.append(pred_token)

        if pred_token == eos_token_idx:
            break

    return trg_indexes[1:]

In [27]:
idxs = random.sample(range(0, len(valid_data.examples)), 20)
for i in idxs:
    src = vars(valid_data.examples[i])['src']
    #trg = vars(valid_data.examples[i])['trg']
    translation = translate_sentence2(src, model, device)

    print(f"SRC : {''.join(tokenizer.convert_ids_to_tokens(src))}")
    #print(f"TRG : {''.join(tokenizer.convert_ids_to_tokens(trg))}")
    print(f"PRED : {''.join(tokenizer.convert_ids_to_tokens(translation))}\n")

SRC : ▁more▁than▁100▁state▁and▁district▁roads▁in▁uttar▁pradesh▁will▁be▁converted▁into▁national▁highways,▁deputy▁chief▁minister▁keshav▁prasad▁maurya▁said▁on▁sunday.▁maurya▁further▁claimed▁that▁the▁bjp▁has▁already▁succeeded▁in▁making▁a▁major▁part▁of▁the▁state's▁roads▁free▁of▁potholes▁and▁that▁the▁city▁of▁allahabad▁would▁witness▁rapid▁development▁in▁near▁future.n
PRED : ▁100▁up▁state,▁district▁roads▁to▁be▁converted▁into▁highways</s>

SRC : ▁harvard▁and▁cornell▁university▁researchers▁have▁documented▁the▁"dance▁routine"▁of▁vogelkop▁superb▁bird-of-paradise,▁which▁helped▁confirm▁it▁as▁a▁new▁species.▁the▁video▁revealed▁the▁male▁dances▁with▁crescent-shaped▁wings▁to▁lure▁its▁female▁counterpart.▁the▁species▁was▁earlier▁confused▁with▁a▁similar▁looking▁bird▁which▁is▁also▁endemic▁to▁the▁island▁nation▁of▁new▁guinea,▁off▁australia's▁coast.
PRED : ▁bird-of-paradise▁dances▁with▁crescent▁wings▁to▁lure▁its▁female</s>

SRC : ▁two▁security▁personnel▁have▁committed▁suicide▁in▁jaisalmer▁within▁a▁span▁of▁two▁d