#  Mount google drive, install dependencies and import required packages

In [None]:
# Run this cell if running on Google Colab
# Mounts Google Drive under /content/drive/ and then makes the project folder
# the working directory, so the relative paths used by the later cells
# (train/, val/, test/, models/, src/) resolve inside that folder.
from google.colab import drive
drive.mount('/content/drive/')
%cd '/content/drive/My Drive/Transformer/'

In [None]:
# Install transformers
# The version is pinned. Later cells call the model with `lm_labels=` and the
# tokenizer with `pad_to_max_length=True`.
# NOTE(review): those argument names are presumably tied to this release
# line -- confirm they still exist before bumping the version.
!pip install transformers==2.11.0

In [None]:
# Import required packages
import torch
import os.path
from os import path
import matplotlib.pyplot as plt
from torch.utils import data
from torch.utils.data import Dataset
from tqdm.notebook import tqdm
import random
from src import utils
from transformers import T5Tokenizer, T5ForConditionalGeneration

# Retrieve training and validation data from google drive

In [None]:
# Read the parallel training corpus: line i of the .sbt file is the model
# input and line i of the .nl file is the matching target text.
X_train, Y_train = [], []

# `with` guarantees both handles are closed, even if reading fails part-way.
with open("train/train.token.sbt") as fileX, open("train/train.token.nl") as fileY:
  # zip pairs the files line by line and stops at the shorter one.
  # Lines are stored as read, i.e. including their trailing newline.
  for lineX, lineY in zip(fileX, fileY):
    X_train.append(lineX)
    Y_train.append(lineY)

print("Number of Training Example: ", len(X_train))

In [None]:
# Read the parallel validation corpus: line i of the .sbt file is the model
# input and line i of the .nl file is the matching target text.
X_valid, Y_valid = [], []

# `with` guarantees both handles are closed, even if reading fails part-way.
with open("val/valid.token.sbt") as fileX, open("val/valid.token.nl") as fileY:
  # zip pairs the files line by line and stops at the shorter one.
  # Lines are stored as read, i.e. including their trailing newline.
  for lineX, lineY in zip(fileX, fileY):
    X_valid.append(lineX)
    Y_valid.append(lineY)

print("Number of validation Example: ", len(X_valid))

# Dataset Class to load data while training and validation

In [None]:
class Dataset(data.Dataset):
  """Map-style dataset that pairs each input item with its target item."""

  def __init__(self, X_item, Y_item):
    """Keep references to the parallel input and target sequences."""
    self.X_item, self.Y_item = X_item, Y_item

  def __len__(self):
    """Return the number of examples, measured on the input sequence."""
    return len(self.X_item)

  def __getitem__(self, index):
    """Return the ``(input, target)`` pair stored at position ``index``."""
    return self.X_item[index], self.Y_item[index]

# Take GPU into action and define batch size and num of workers

In [None]:
# Use the GPU when torch can see one, otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda") if use_cuda else torch.device("cpu")
print(">>> ", device)

# DataLoader settings. Only the training loader shuffles, so validation (and
# anything else built from params_valid) keeps the order of the source files.
params = dict(batch_size=4, shuffle=True, num_workers=0)
params_valid = dict(batch_size=4, shuffle=False, num_workers=0)

# Wrap the raw line lists so they can be batched.
training_set = Dataset(X_train, Y_train)
validation_set = Dataset(X_valid, Y_valid)

training_generator = data.DataLoader(training_set, **params)
validation_generator = data.DataLoader(validation_set, **params_valid)


## Load Initial Model and Tokenizer

In [None]:
#load pre-trained model and tokenizer
tokenizer = T5Tokenizer.from_pretrained('t5-small')
model = T5ForConditionalGeneration.from_pretrained('t5-small')
model = model.to(device)

#load optimizer
# One parameter group is built per trainable tensor. Tensors whose name
# matches an entry of `no_decay` (biases and normalisation weights) are
# exempt from weight decay; every other tensor decays at 0.01.
# NOTE(review): these patterns follow BERT-style parameter names. T5
# presumably names its normalisation weights differently (for example
# `layer_norm.weight`) -- check model.named_parameters() and extend
# `no_decay` if so.
no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
lr = 1e-4
optimizer_grouped_parameters = []
for key, value in model.named_parameters():
    if not value.requires_grad:
        continue  # frozen tensors are not handed to the optimizer
    weight_decay = 0.0 if any(nd in key for nd in no_decay) else 0.01
    optimizer_grouped_parameters.append(
        {"params": [value], "lr": lr, "weight_decay": weight_decay}
    )

optimizer = utils.BertAdam(
            optimizer_grouped_parameters,
            lr=lr,
            warmup=0.1,
            t_total=100,
            schedule='warmup_constant',
        )

# Training variables

loss_train = []   # mean training loss of every finished epoch
loss_valid = []   # mean validation loss of every finished epoch
trainloss = 0
validloss = 0
epoch = 0         # number of epochs already completed
num_epoch = 35    # training stops once `epoch` reaches this value
model_name = "checkpoint.pth"
model_location = "models/" + model_name


## Resume the model, optimizer, epoch number and loss history from the saved checkpoint if one exists; otherwise continue with the freshly downloaded model.

In [None]:
# Resume training state from the regular checkpoint when one has been saved;
# otherwise keep the freshly downloaded pre-trained weights.
if path.exists(model_location):

  #load saved model from the drive
  # map_location moves the stored tensors onto the device in use now, so a
  # checkpoint written on a GPU can still be restored on a CPU-only runtime.
  checkpoint = torch.load(model_location, map_location=device)
  epoch = checkpoint['epoch']
  model.load_state_dict(checkpoint["state_dict"])
  optimizer.load_state_dict(checkpoint['optimizer'])
  loss_train = checkpoint['trainloss']
  loss_valid = checkpoint['validloss']
  print(">>> loaded saved checkpoint from epoch ", epoch)
  model = model.to(device)

else:
  print(">>> loaded from downloaded model ")

## Plot train and valid loss to see the curve

In [None]:
# Draw the loss history recorded so far: training in green, validation in red.
for history, colour in ((loss_train, "green"), (loss_valid, "red")):
  plt.plot(history, color=colour)
plt.show()

In [None]:
# Establish the best F1 reached so far; the training loop only overwrites
# models/best.pth when an epoch beats this value.
try:
  best_model = checkpoint
  best_f1 = checkpoint['F1']
except NameError:
  # `checkpoint` is only bound when a saved checkpoint was loaded. On a fresh
  # run there is nothing to beat yet, so any first score counts as the best.
  best_model = None
  best_f1 = float("-inf")

# A previously saved best model takes precedence over the regular checkpoint.
if path.exists("models/best.pth"):
  # map_location lets a GPU-saved file load on a CPU-only runtime.
  best_model = torch.load("models/best.pth", map_location=device)
  best_f1 = best_model['F1']
print(best_f1)

# Training and validation loop

In [None]:
# Train until `epoch` reaches `num_epoch`. Every epoch runs one pass over the
# training set, one pass over the validation set (loss + generated text),
# reports metrics, plots the loss curves and saves a checkpoint.

# torch.save does not create missing directories.
os.makedirs("models", exist_ok=True)

while epoch < num_epoch:

  trainloss = 0
  validloss = 0
  print("Running EPOCH : ", epoch+1)

  # ---- Training ----
  model.train()
  for local_batch, local_labels in tqdm(training_generator):

    # Tokenize the raw strings of this batch into padded tensors.
    input_ids = tokenizer.batch_encode_plus(local_batch, return_tensors="pt",pad_to_max_length=True)
    label = tokenizer.batch_encode_plus(local_labels, return_tensors="pt",pad_to_max_length=True)

    # Forward pass; the first element of the output is the loss.
    outputs = model(input_ids=input_ids['input_ids'].to(device),
                    attention_mask=input_ids['attention_mask'].to(device),
                    lm_labels=label['input_ids'].to(device))
    loss = outputs[0]

    # Backward pass and parameter update.
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    trainloss += round(loss.item(), 4)

  # Mean loss per batch for this epoch.
  trainloss = trainloss/len(training_generator)
  loss_train.append(trainloss)

  # ---- Validation ----
  model.eval()
  refs_list = []
  hyp_list = []
  with torch.no_grad():
    for local_batch, local_labels in tqdm(validation_generator):
      input_ids = tokenizer.batch_encode_plus(local_batch, return_tensors="pt",pad_to_max_length=True)
      label = tokenizer.batch_encode_plus(local_labels, return_tensors="pt",pad_to_max_length=True)

      source_ids = input_ids['input_ids'].to(device)
      # The mask keeps padded positions out of attention, exactly as in the
      # training pass; without it padding influences loss and generation.
      source_mask = input_ids['attention_mask'].to(device)

      outputs = model(input_ids=source_ids,
                      attention_mask=source_mask,
                      lm_labels=label['input_ids'].to(device))

      predY = model.generate(input_ids=source_ids, attention_mask=source_mask)

      # Collect generated text next to its reference, in dataset order.
      for i,j in zip(predY, local_labels):
        hyp_list.append(tokenizer.decode(i))
        refs_list.append(j)

      validloss += round(outputs[0].item(), 4)

  validloss = validloss / len(validation_generator)
  loss_valid.append(validloss)

  precision, recall, f1, accuracy = utils.calculate_results(refs_list, hyp_list)

  plt.xlabel("epochs")
  plt.ylabel("loss")
  plt.plot(loss_train, color='green', label="Train Loss")
  plt.plot(loss_valid, color='red', label="Valid Loss")

  plt.legend()
  plt.show()

  print("\nEpoch ", epoch+1, " completed! \nTrain loss is: ", trainloss, "\nValid loss is: ", validloss)

  print("For Epoch:",epoch+1)
  print("Precision: ", precision)
  print("Recall: ", recall)
  print("f1: ", f1)
  print("Accuracy: ", accuracy)

  # ---- Save states ----
  # 'epoch' stores the number of completed epochs, so a resumed run continues
  # with the next one.
  states = {
          'epoch': epoch + 1,
          'state_dict': model.state_dict(),
          'optimizer': optimizer.state_dict(),
          'trainloss': loss_train,
          'validloss': loss_valid,
          'F1': f1
      }
  if f1>best_f1:
    best_f1=f1
    print("Saving the best model")
    torch.save(states, "models/best.pth")

  print("Saving the regular model")
  torch.save(states, "models/" + model_name)

  # Show one random validation example. The index is drawn from the actual
  # number of collected examples so it can never run past the end.
  if refs_list:
    rand = random.randint(0, len(refs_list) - 1)

    print("Random example from valid set:\n")
    print("TARGET: " + refs_list[rand] + "\n")
    print("Prediction: " + hyp_list[rand] + "\n")
  print("\n", "_"*100 , "\n\n")

  epoch += 1

## Load Test Data

In [None]:
# Labelled loss curves: validation in red, training in blue.
plt.xlabel("epochs")
plt.ylabel("loss")
for series, colour, caption in (
    (loss_valid, 'red', "Valid Loss"),
    (loss_train, 'blue', "Train Loss"),
):
  plt.plot(series, color=colour, label=caption)
plt.legend()
plt.show()

In [None]:
# Read the parallel test corpus: line i of the .sbt file is the model input
# and line i of the .nl file is the matching target text.
X_test, Y_test = [], []

# `with` guarantees both handles are closed. The files are paired line by
# line with zip, the same way the training and validation sets are read, so
# inputs and targets can never end up with different lengths.
with open("test/test.token.sbt") as fileX, open("test/test.token.nl") as fileY:
  for lineX, lineY in zip(fileX, fileY):
    X_test.append(lineX)
    Y_test.append(lineY)

print(len(X_test))

# Data loader for test set
# params_valid disables shuffling, so predictions keep the file order.
test_set = Dataset(X_test,Y_test)
test_generator = data.DataLoader(test_set, **params_valid)

## Make target and predicted list


In [None]:
# Generate a prediction for every test example and keep it next to its
# reference text, in dataset order.
refs_list = []
hyp_list = []

model.eval()
with torch.no_grad():
  for local_batch, local_labels in tqdm(test_generator):
    input_ids = tokenizer.batch_encode_plus(local_batch, return_tensors="pt",pad_to_max_length=True)

    # Pass the attention mask so padded positions are ignored while
    # generating; shorter inputs in a batch are otherwise decoded with
    # their padding attended to.
    predY = model.generate(
        input_ids=input_ids['input_ids'].to(device),
        attention_mask=input_ids['attention_mask'].to(device),
    )

    for i,j in zip(predY, local_labels):
      hyp_list.append(tokenizer.decode(i))
      refs_list.append(j)

## Calculate Precision, Recall, f1 and Accuracy on the test set

In [None]:
# Score the test predictions and print a random window of examples.
precision, recall, f1, accuracy = utils.calculate_results(refs_list, hyp_list)

# Show up to 40 consecutive examples. The start offset is drawn so that the
# whole window lies inside the lists, whatever the size of the test set.
num_examples = min(40, len(refs_list))
rand = random.randint(0, len(refs_list) - num_examples)

for i in range(num_examples):
  print("TARGET: " + refs_list[rand+i])
  print("Prediction: " + hyp_list[rand+i] + "\n\n")

print("Precision: ", precision)
print("Recall: ", recall)
print("f1: ", f1)
# Reported for parity with the per-epoch validation report.
print("Accuracy: ", accuracy)

## Calculate BLEU Score

In [None]:
# Write references and hypotheses to text files, one example per line, so an
# external BLEU script can compare them line by line.
# NOTE(review): the first two examples are left out of BOTH files, so the
# files stay aligned; the reason for skipping them is not visible here --
# confirm it is intentional.
with open("src/reference.txt", "w") as f:
  for reference in refs_list[2:]:
    # References normally still carry the newline they were read with; make
    # sure every one ends with exactly that single line break so a final
    # line without a newline cannot merge with its neighbour.
    f.write(reference.rstrip("\n") + "\n")

with open("src/hypothesis.txt", "w") as f:
  for hypothesis in hyp_list[2:]:
    f.write(hypothesis + "\n")

In [None]:
%cd src
%%perl multi-bleu.perl reference.txt hypothesis.txt

In [None]:
from nltk.translate.bleu_score import sentence_bleu
from nltk.translate.bleu_score import corpus_bleu
from nltk.translate.bleu_score import SmoothingFunction
# Optional smoothing: pass `smoothing_function=smoothie` to corpus_bleu to
# avoid zero scores when some n-gram order has no matches.
smoothie = SmoothingFunction().method7
bleu_score = 0.0
refs=[]
hyp=[]
for ref_text, hyp_text in zip(refs_list, hyp_list):
  # corpus_bleu expects, for every hypothesis, a LIST of tokenized
  # references. There is one reference per example here, so it is wrapped in
  # a one-element list. Passing the bare token list would make NLTK treat
  # each token string as a separate reference.
  refs.append([ref_text.strip().split()])
  hyp.append(hyp_text.strip().split())
print(hyp)
bleu_score2 = corpus_bleu(refs, hyp)
print("The bleu score is: "+str(bleu_score2*100))