In [19]:
# !pip install wandb
# !pip install transformers
# !pip install sentencepiece
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "1,2,3" 
import pandas as pd
import numpy as np
import re
import torch
import torchvision
from torch.utils.data import Dataset, DataLoader
import math
import torch.optim as optim
import wandb


wandb.login()

# os.environ["WANDB_DISABLED"] = "true"



# device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device = torch.device("cuda:2") if torch.cuda.is_available() else torch.device("cpu")


In [20]:
# Create torch dataset
# https://towardsdatascience.com/fine-tuning-pretrained-nlp-models-with-huggingfaces-trainer-6326a4456e7b

class Dataset(torch.utils.data.Dataset):
    def __init__(self, encodings, labels=None):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        if self.labels:
            item["labels"] = torch.tensor(self.labels['input_ids'][idx])
        return item

    def __len__(self):
        return len(self.encodings["input_ids"])

In [21]:
max_source_length = 1024
max_target_length = 128

In [22]:
from transformers import T5TokenizerFast, T5ForConditionalGeneration 
from transformers import EarlyStoppingCallback

tokenizer = T5TokenizerFast.from_pretrained("t5-base")
model = T5ForConditionalGeneration.from_pretrained("t5-base").to(device)

In [23]:
training_data = pd.read_csv('./2-NEL_Data/2-csv_format_2/training_data_shuffled.csv')
training_data = training_data.sample(frac=1, random_state=1)
training_data

Unnamed: 0,qid,question,entity,wikidata_reply
19399,Q823600,**Who is someone that was born in chesterfield?**,**chesterfield**,"**[[Q2414206, Chesterfield, town in Massachuse..."
14726,Q3660532,**what gender is carole hayman**,**carole hayman**,"**[[Q3660532, Carole Hayman, British writer]]**"
20385,Q1290008,**Which country is marcel landers from?**,**marcel landers**,"**[[Q1290008, Marcel Landers, German footballe..."
20463,Q238440,**where did the battle of navarino take place**,**battle of navarino**,"**[[Q238440, Battle of Navarino, 1827 naval ba..."
5886,Q65598,**what nationality is jürgen röber**,**jürgen röber**,"**[[Q65598, Jürgen Röber, football player and ..."
...,...,...,...,...
7813,Q3534673,**what is m. k. muthu known for being**,**m. k. muthu**,"**[[Q6712758, M. K. Muthukaruppannasamy, Membe..."
32511,Q381178,**where did bob barker grow up?**,**bob barker**,"**[[Q55762050, Bob Barker Company, American co..."
5192,Q1143438,**is the mysterious mr quin fiction or nonfict...,**the mysterious mr quin**,"**[[Q1143438, The Mysterious Mr Quin, book]]**"
12172,Q3086842,**Which film is fred guiol a director for**,**fred guiol**,"**[[Q3086842, Fred Guiol, Film director; scree..."


In [24]:
input_text = list(training_data['question'] + ',' + training_data['entity'] + ',' + training_data['wikidata_reply'])
input_text[0]

'**Who is someone that was born in chesterfield?**,**chesterfield**,**[[Q2414206, Chesterfield, town in Massachusetts, USA], [Q1924516, Chesterfield, human settlement in Indiana, United States of America], [Q823600, Chesterfield, market town and unparished area in Derbyshire, England], [Q959443, Chesterfield, city in St. Louis County, Missouri, United States], [Q2310682, Chesterfield, town in New Hampshire], [Q2063962, Chesterfield, town in Chesterfield County, South Carolina, United States], [Q48935, Chesterfield F.C., association football club in Chesterfield, England]]**'

In [25]:
target_text = list(training_data['qid'])
target_text[0]

'Q823600'

In [26]:
X_train_tokenized = tokenizer(['nel: ' + sequence for sequence in input_text], 
                              padding=True, 
                              truncation=True, 
                              max_length=max_source_length)

y_train_tokenized = tokenizer(target_text, 
                              padding=True, 
                              truncation=True, 
                              max_length=max_target_length)

print(len(training_data))
# print(len(training_sample))

34300


In [27]:
val_data = pd.read_csv('./2-NEL_Data/2-csv_format_2/val_data_shuffled.csv')
val_data = val_data.sample(frac=1, random_state=1)
val_data

Unnamed: 0,qid,question,entity,wikidata_reply
910,Q149941,**what country is alaa abd el-fattah from**,**alaa abd el-fattah**,"**[[Q149941, Alaa Abd El-Fattah, Egyptian huma..."
1715,Q11425,**What's an example of an animation program?**,**animation**,"**[[Q11425, animation, method of creating movi..."
4700,Q131433,**who is shania twain's husband?**,**shania twain**,"**[[Q1143593, Shania Twain discography, Wikime..."
2776,Q1324387,**where was el medico born**,**el medico**,"**[[Q1324387, El Médico, Cuban musician], [Q27..."
4284,Q3282637,**who is film producer**,**film producer**,"**[[Q111316788, Chandni Soni, Film Producer], ..."
...,...,...,...,...
2895,Q1868197,**Where was wisdom agblexo born**,**wisdom agblexo**,"**[[Q1868197, Wisdom Agblexo, Ghanaian footbal..."
2763,Q93196,**what is the name of a bollywood movice**,**bollywood**,"**[[Q110592757, Music Videos > Indian > Bollyw..."
905,Q691225,**What does a lineman (occupation) specialize ...,**lineman**,"**[[Q269359, ""lineman's pliers"", multifunction..."
3980,Q15998685,**Where was william j. heffernan bown**,**william j. heffernan**,"**[[Q15998685, William J. Heffernan, American ..."


In [28]:
input_text_val = list(val_data['question'] + ',' + val_data['entity'] + ',' + val_data['wikidata_reply'])
input_text_val[0]

'**what country is alaa abd el-fattah from**,**alaa abd el-fattah**,**[[Q149941, Alaa Abd El-Fattah, Egyptian human rights activist]]**'

In [29]:
target_text_val = list(val_data['qid'])
target_text_val[0]

'Q149941'

In [30]:
X_val_tokenized = tokenizer(['nel: ' + sequence for sequence in input_text_val], 
                              padding=True, 
                              truncation=True, 
                              max_length=max_source_length)

y_val_tokenized = tokenizer(target_text_val, 
                              padding=True, 
                              truncation=True, 
                              max_length=max_target_length)

print(len(val_data))
# print(len(training_sample))

4846


In [31]:
train_dataset = Dataset(X_train_tokenized, y_train_tokenized)

In [32]:
val_dataset = Dataset(X_val_tokenized, y_val_tokenized)

In [33]:
from transformers import Seq2SeqTrainingArguments

training_args = Seq2SeqTrainingArguments(
    "NEL_model_normal_shuffled",
    evaluation_strategy ='steps',
    eval_steps = 500, # Evaluation and Save happens every 50 steps
    logging_steps = 500,
    save_total_limit = 5, # Only last 5 models are saved. Older ones are deleted.
    per_device_train_batch_size = 4,
    per_device_eval_batch_size = 4,
    learning_rate = 1e-3,
    adam_epsilon = 1e-8,
    num_train_epochs = 5,
    report_to="wandb",
#     metric_for_best_model = 'f1',
    load_best_model_at_end=True
)

In [34]:
from transformers import Seq2SeqTrainer

trainer = Seq2SeqTrainer(
    model=model, 
    args=training_args, 
    train_dataset= train_dataset,
    eval_dataset = val_dataset,
    callbacks = [EarlyStoppingCallback(early_stopping_patience=3)]
)

In [35]:
trainer.train()

***** Running training *****
  Num examples = 34300
  Num Epochs = 5
  Instantaneous batch size per device = 4
  Total train batch size (w. parallel, distributed & accumulation) = 12
  Gradient Accumulation steps = 1
  Total optimization steps = 14295
Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"




Step,Training Loss,Validation Loss
500,0.6942,0.463458
1000,0.4655,0.355211
1500,0.3652,0.284501
2000,0.3369,0.278281
2500,0.3055,0.267376
3000,0.2785,0.234602
3500,0.2453,0.250234
4000,0.2454,0.242878
4500,0.2169,0.216075
5000,0.2359,0.225882


***** Running Evaluation *****
  Num examples = 4846
  Batch size = 12
Saving model checkpoint to NEL_model_normal_shuffled/checkpoint-500
Configuration saved in NEL_model_normal_shuffled/checkpoint-500/config.json
Model weights saved in NEL_model_normal_shuffled/checkpoint-500/pytorch_model.bin
***** Running Evaluation *****
  Num examples = 4846
  Batch size = 12
Saving model checkpoint to NEL_model_normal_shuffled/checkpoint-1000
Configuration saved in NEL_model_normal_shuffled/checkpoint-1000/config.json
Model weights saved in NEL_model_normal_shuffled/checkpoint-1000/pytorch_model.bin
***** Running Evaluation *****
  Num examples = 4846
  Batch size = 12
Saving model checkpoint to NEL_model_normal_shuffled/checkpoint-1500
Configuration saved in NEL_model_normal_shuffled/checkpoint-1500/config.json
Model weights saved in NEL_model_normal_shuffled/checkpoint-1500/pytorch_model.bin
***** Running Evaluation *****
  Num examples = 4846
  Batch size = 12
Saving model checkpoint to NEL_m

***** Running Evaluation *****
  Num examples = 4846
  Batch size = 12
Saving model checkpoint to NEL_model_normal_shuffled/checkpoint-7000
Configuration saved in NEL_model_normal_shuffled/checkpoint-7000/config.json
Model weights saved in NEL_model_normal_shuffled/checkpoint-7000/pytorch_model.bin
Deleting older checkpoint [NEL_model_normal_shuffled/checkpoint-4500] due to args.save_total_limit
***** Running Evaluation *****
  Num examples = 4846
  Batch size = 12
Saving model checkpoint to NEL_model_normal_shuffled/checkpoint-7500
Configuration saved in NEL_model_normal_shuffled/checkpoint-7500/config.json
Model weights saved in NEL_model_normal_shuffled/checkpoint-7500/pytorch_model.bin
Deleting older checkpoint [NEL_model_normal_shuffled/checkpoint-5000] due to args.save_total_limit
***** Running Evaluation *****
  Num examples = 4846
  Batch size = 12
Saving model checkpoint to NEL_model_normal_shuffled/checkpoint-8000
Configuration saved in NEL_model_normal_shuffled/checkpoint-80

TrainOutput(global_step=8500, training_loss=0.27528705731560205, metrics={'train_runtime': 6064.0706, 'train_samples_per_second': 28.281, 'train_steps_per_second': 2.357, 'total_flos': 4.742704873193472e+16, 'train_loss': 0.27528705731560205, 'epoch': 2.97})