In [1]:
# !pip install wandb
# !pip install transformers
# !pip install sentencepiece
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "0,2,3" 
import pandas as pd
import numpy as np
import re
import torch
import torchvision
from torch.utils.data import Dataset, DataLoader
import math
import torch.optim as optim
import wandb


wandb.login()
%env WANDB_PROJECT=NEL
# wandb.init(project="NEL")

# os.environ["WANDB_DISABLED"] = "true"


device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# device = torch.device("cuda:2") if torch.cuda.is_available() else torch.device("cpu")


[34m[1mwandb[0m: Currently logged in as: [33mhodz199[0m. Use [1m`wandb login --relogin`[0m to force relogin


env: WANDB_PROJECT=NEL


In [2]:
# Create torch dataset
# https://towardsdatascience.com/fine-tuning-pretrained-nlp-models-with-huggingfaces-trainer-6326a4456e7b

class Dataset(torch.utils.data.Dataset):
    def __init__(self, encodings, labels=None):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        if self.labels:
            item["labels"] = torch.tensor(self.labels['input_ids'][idx])
        return item

    def __len__(self):
        return len(self.encodings["input_ids"])

In [3]:
max_source_length = 1024
max_target_length = 128

In [4]:
from transformers import T5TokenizerFast, T5ForConditionalGeneration 
from transformers import EarlyStoppingCallback

tokenizer = T5TokenizerFast.from_pretrained("t5-base")
model = T5ForConditionalGeneration.from_pretrained("t5-base").to(device)

In [5]:
training_data = pd.read_csv('./2-NEL_Data/2-csv_format_2/training_data_shuffled.csv')
training_data = training_data.sample(frac=1, random_state=1)
training_data

Unnamed: 0,qid,question,entity,wikidata_reply,qid_in_reply,input_len
4312,Q7492362,**Who was involved in the sheemore ambush?**,**sheemore ambush**,"**[[Q7492362, Sheemore ambush, ambush during t...",True,143
12486,Q82955,**name a professional politician.**,**politician**,"**[[Q51556674, Politician, song by Cream], [Q1...",True,531
92904,Q738653,** white can play either a3 or a4 (see algebra...,**algebraic notation**,"**[[Q60418499, Algebraic Notation of Kinship, ...",True,622
20220,Q4776569,**Where in europe was antonio demo born**,**antonio demo**,"**[[Q4776569, Antonio Demo, Italian-American p...",True,114
105441,Q1297,"** Filming. Prior to filming, Mendes sought to...",**Chicago**,"**[[Q2233885, Willard, city in Huron County, O...",True,629
...,...,...,...,...,...,...
98047,Q816704,** the Daily Mirror's new magazine. Their Depu...,**pâté**,"**[[Q1044124, paten, small plate used to hold ...",True,790
5192,Q2247706,**who directed the crowd roars**,**the crowd roars**,"**[[Q2247706, The Crowd Roars, 1938 film by Ri...",True,224
77708,Q100047,**AmIRC is an MUI-based IRC client for the Ami...,**Amiga**,"**[[Q20049564, Amiga, song by Eliana], [Q32389...",True,470
98539,Q2717,"** place). Births. March 25 - Matthew Barney, ...",**July 29**,"**[[Q17982661, 29 July 2013, date], [Q17982659...",True,515


In [6]:
input_text = list(training_data['question'] + ',' + training_data['entity'] + ',' + training_data['wikidata_reply'])
input_text[0]

'**Who was involved in the sheemore ambush?**,**sheemore ambush**,**[[Q7492362, Sheemore ambush, ambush during the Irish War of Independence]]**'

In [7]:
target_text = list(training_data['qid'])
target_text[0]

'Q7492362'

In [8]:
X_train_tokenized = tokenizer(['nel: ' + sequence for sequence in input_text], 
                              padding=True, 
                              truncation=True, 
                              max_length=max_source_length)

y_train_tokenized = tokenizer(target_text, 
                              padding=True, 
                              truncation=True, 
                              max_length=max_target_length)

print(len(training_data))
# print(len(training_sample))

130395


In [9]:
val_data = pd.read_csv('./2-NEL_Data/2-csv_format_2/val_data_shuffled.csv')
val_data = val_data.sample(frac=1, random_state=1)
val_data

Unnamed: 0,qid,question,entity,wikidata_reply,qid_in_reply,input_len
13761,Q497,"** arises from the anus, the distal orifice of...",**anus**,"**[[Q31785909, Anus, mountain in Namibia], [Q4...",True,532
12966,Q188,**Hiwi is a German abbreviation. It has two me...,**German**,"**[[Q188, German, West Germanic language spoke...",True,459
135,Q11865151,**what country is ismo kallio from**,**ismo kallio**,"**[[Q11865151, Ismo Kallio, Finnish actor (193...",True,110
13518,Q165654,** The fire was first spotted at 8:04 p.m. by ...,**constable**,"**[[Q159297, John Constable, English painter (...",True,592
10510,Q81136,** the British Admiralty lost interest in the ...,**Northwest Passage**,"**[[Q17114554, Northwest Passage, former bi-we...",True,659
...,...,...,...,...,...,...
905,Q7697846,**what type of film is telstar: the joe meek s...,**telstar: the joe meek story**,"**[[Q7697846, Telstar: The Joe Meek Story, 200...",True,156
5192,Q252,**Bunguran is a small archipelago of Indonesia...,**Indonesia**,"**[[Q96708780, Indonesia, scientific journal],...",True,507
12172,Q5758978,**American High is a documentary television sh...,**Highland Park High School**,"**[[Q5758983, Highland Park High School, publi...",True,661
235,Q2605094,**What type of music is the album hours?**,**hours**,"**[[Q25235, hour, unit of time], [Q157044, The...",True,466


In [10]:
input_text_val = list(val_data['question'] + ',' + val_data['entity'] + ',' + val_data['wikidata_reply'])
input_text_val[0]

'** arises from the anus, the distal orifice of the gastrointestinal tract. It is a distinct entity from the more common colorectal cancer. The etiology, risk factors, clinical progression, staging, **,**anus**,**[[Q31785909, Anus, mountain in Namibia], [Q497, anus, digestive track waste expulsion opening], [Q23855, Anus, Oceanic language spoken in Indonesia], [Q20685927, Anus, album by Alaska Thunderfuck 5000], [Q25016777, Pulau Anus, island in Papua, Indonesia], [Q26235245, Anus, village in Sarmi Regency, Papua, Indonesia]]**'

In [11]:
target_text_val = list(val_data['qid'])
target_text_val[0]

'Q497'

In [12]:
X_val_tokenized = tokenizer(['nel: ' + sequence for sequence in input_text_val], 
                              padding=True, 
                              truncation=True, 
                              max_length=max_source_length)

y_val_tokenized = tokenizer(target_text_val, 
                              padding=True, 
                              truncation=True, 
                              max_length=max_target_length)

print(len(val_data))
# print(len(training_sample))

14406


In [13]:
train_dataset = Dataset(X_train_tokenized, y_train_tokenized)

In [14]:
val_dataset = Dataset(X_val_tokenized, y_val_tokenized)

In [15]:
from transformers import Seq2SeqTrainingArguments

training_args = Seq2SeqTrainingArguments(
    "NEL_model_shuffled_add_spaces_in_input",
    evaluation_strategy ='steps',
    eval_steps = 1000, # Evaluation and Save happens every 50 steps
    logging_steps = 1000,
    save_steps = 1000,
    save_total_limit = 5, # Only last 5 models are saved. Older ones are deleted.
    per_device_train_batch_size = 4,
    per_device_eval_batch_size = 4,
    learning_rate = 1e-3,
    adam_epsilon = 1e-8,
    num_train_epochs = 5,
    report_to="wandb",
#     metric_for_best_model = 'f1',
    load_best_model_at_end=True
)

In [16]:
from transformers import Seq2SeqTrainer

trainer = Seq2SeqTrainer(
    model=model, 
    args=training_args, 
    train_dataset= train_dataset,
    eval_dataset = val_dataset,
    callbacks = [EarlyStoppingCallback(early_stopping_patience=5)]
)

In [17]:
trainer.train()

***** Running training *****
  Num examples = 130395
  Num Epochs = 5
  Instantaneous batch size per device = 4
  Total train batch size (w. parallel, distributed & accumulation) = 12
  Gradient Accumulation steps = 1
  Total optimization steps = 54335
Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"




Step,Training Loss,Validation Loss
1000,0.4359,0.251185
2000,0.247,0.166507
3000,0.1985,0.173752
4000,0.1767,0.132703
5000,0.1654,0.121042
6000,0.151,0.131296
7000,0.1397,0.118089
8000,0.13,0.107417
9000,0.1244,0.097368
10000,0.1185,0.103013


***** Running Evaluation *****
  Num examples = 14406
  Batch size = 12
Saving model checkpoint to NEL_model_shuffled/checkpoint-1000
Configuration saved in NEL_model_shuffled/checkpoint-1000/config.json
Model weights saved in NEL_model_shuffled/checkpoint-1000/pytorch_model.bin
***** Running Evaluation *****
  Num examples = 14406
  Batch size = 12
Saving model checkpoint to NEL_model_shuffled/checkpoint-2000
Configuration saved in NEL_model_shuffled/checkpoint-2000/config.json
Model weights saved in NEL_model_shuffled/checkpoint-2000/pytorch_model.bin
***** Running Evaluation *****
  Num examples = 14406
  Batch size = 12
Saving model checkpoint to NEL_model_shuffled/checkpoint-3000
Configuration saved in NEL_model_shuffled/checkpoint-3000/config.json
Model weights saved in NEL_model_shuffled/checkpoint-3000/pytorch_model.bin
***** Running Evaluation *****
  Num examples = 14406
  Batch size = 12
Saving model checkpoint to NEL_model_shuffled/checkpoint-4000
Configuration saved in NEL

  Batch size = 12
Saving model checkpoint to NEL_model_shuffled/checkpoint-14000
Configuration saved in NEL_model_shuffled/checkpoint-14000/config.json
Model weights saved in NEL_model_shuffled/checkpoint-14000/pytorch_model.bin
Deleting older checkpoint [NEL_model_shuffled/checkpoint-9000] due to args.save_total_limit
***** Running Evaluation *****
  Num examples = 14406
  Batch size = 12
Saving model checkpoint to NEL_model_shuffled/checkpoint-15000
Configuration saved in NEL_model_shuffled/checkpoint-15000/config.json
Model weights saved in NEL_model_shuffled/checkpoint-15000/pytorch_model.bin
Deleting older checkpoint [NEL_model_shuffled/checkpoint-10000] due to args.save_total_limit
***** Running Evaluation *****
  Num examples = 14406
  Batch size = 12
Saving model checkpoint to NEL_model_shuffled/checkpoint-16000
Configuration saved in NEL_model_shuffled/checkpoint-16000/config.json
Model weights saved in NEL_model_shuffled/checkpoint-16000/pytorch_model.bin
Deleting older check

Deleting older checkpoint [NEL_model_shuffled/checkpoint-21000] due to args.save_total_limit
***** Running Evaluation *****
  Num examples = 14406
  Batch size = 12
Saving model checkpoint to NEL_model_shuffled/checkpoint-27000
Configuration saved in NEL_model_shuffled/checkpoint-27000/config.json
Model weights saved in NEL_model_shuffled/checkpoint-27000/pytorch_model.bin
Deleting older checkpoint [NEL_model_shuffled/checkpoint-22000] due to args.save_total_limit
***** Running Evaluation *****
  Num examples = 14406
  Batch size = 12
Saving model checkpoint to NEL_model_shuffled/checkpoint-28000
Configuration saved in NEL_model_shuffled/checkpoint-28000/config.json
Model weights saved in NEL_model_shuffled/checkpoint-28000/pytorch_model.bin
Deleting older checkpoint [NEL_model_shuffled/checkpoint-23000] due to args.save_total_limit
***** Running Evaluation *****
  Num examples = 14406
  Batch size = 12
Saving model checkpoint to NEL_model_shuffled/checkpoint-29000
Configuration saved 

TrainOutput(global_step=37000, training_loss=0.09810885846937024, metrics={'train_runtime': 41263.5624, 'train_samples_per_second': 15.8, 'train_steps_per_second': 1.317, 'total_flos': 2.6085598750688256e+17, 'train_loss': 0.09810885846937024, 'epoch': 3.4})

In [18]:
print('finished')

finished


In [19]:
# Num examples = 130395
# Num Epochs = 5
# Instantaneous batch size per device = 4
# Total train batch size (w. parallel, distributed & accumulation) = 12
# Gradient Accumulation steps = 1
# Total optimization steps = 54335