# Finetunning Question Answering on T5
## This notebook outlines the concepts behind the Finetuning of Question Answering task on SQUAD dataset

In [1]:
import pandas as pd
from datasets import load_dataset
from transformers import DefaultDataCollator
from transformers import T5ForConditionalGeneration, AdamW
from transformers import T5TokenizerFast as T5Tokenizer

In [2]:
import numpy as np
import torch
from pathlib import Path
import pytorch_lightning as pl
from torch.utils.data import Dataset, DataLoader
from pytorch_lightning.callbacks import ModelCheckpoint
from sklearn.model_selection import train_test_split
from termcolor import colored
import textwrap

In [3]:
MODEL_NAME = "mrm8488/t5-base-finetuned-quartz"

In [4]:
squad = load_dataset("squad")

Downloading builder script:   0%|          | 0.00/1.97k [00:00<?, ?B/s]

Downloading metadata:   0%|          | 0.00/1.02k [00:00<?, ?B/s]

Downloading and preparing dataset squad/plain_text (download: 33.51 MiB, generated: 85.63 MiB, post-processed: Unknown size, total: 119.14 MiB) to /root/.cache/huggingface/datasets/squad/plain_text/1.0.0/d6ec3ceb99ca480ce37cdd35555d6cb2511d223b9150cce08a837ef62ffea453...


Downloading data files:   0%|          | 0/2 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/8.12M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/1.05M [00:00<?, ?B/s]

Extracting data files:   0%|          | 0/2 [00:00<?, ?it/s]

Generating train split:   0%|          | 0/87599 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/10570 [00:00<?, ? examples/s]

Dataset squad downloaded and prepared to /root/.cache/huggingface/datasets/squad/plain_text/1.0.0/d6ec3ceb99ca480ce37cdd35555d6cb2511d223b9150cce08a837ef62ffea453. Subsequent calls will reuse this data.


  0%|          | 0/2 [00:00<?, ?it/s]

In [5]:
squad

DatasetDict({
    train: Dataset({
        features: ['id', 'title', 'context', 'question', 'answers'],
        num_rows: 87599
    })
    validation: Dataset({
        features: ['id', 'title', 'context', 'question', 'answers'],
        num_rows: 10570
    })
})

In [6]:
train_question = squad['train']['question']  #question
train_context = squad['train']['context']  #context

In [7]:
listing = []
listing2 = []
first_list = squad['train']['answers']  #answers
for fl in first_list:
    listing.append(fl.get('text'))
    listing2.append(fl.get('answer_start')) 
    
train_answer = []
train_answer_start = []

for one_list in listing:
    for _ in one_list:
        train_answer.append(_)

for two_list in listing2:
    for _ in two_list:
        train_answer_start.append(_)

In [8]:
data = {'question':train_question, 'context':train_context, 'answer_text':train_answer, 'answer_start':train_answer_start}

In [9]:
len(train_question ), len(train_context), len(train_answer), len(train_answer_start)

(87599, 87599, 87599, 87599)

In [10]:
df = pd.DataFrame(data)
df

Unnamed: 0,question,context,answer_text,answer_start
0,To whom did the Virgin Mary allegedly appear i...,"Architecturally, the school has a Catholic cha...",Saint Bernadette Soubirous,515
1,What is in front of the Notre Dame Main Building?,"Architecturally, the school has a Catholic cha...",a copper statue of Christ,188
2,The Basilica of the Sacred heart at Notre Dame...,"Architecturally, the school has a Catholic cha...",the Main Building,279
3,What is the Grotto at Notre Dame?,"Architecturally, the school has a Catholic cha...",a Marian place of prayer and reflection,381
4,What sits on top of the Main Building at Notre...,"Architecturally, the school has a Catholic cha...",a golden statue of the Virgin Mary,92
...,...,...,...,...
87594,In what US state did Kathmandu first establish...,"Kathmandu Metropolitan City (KMC), in order to...",Oregon,229
87595,What was Yangon previously known as?,"Kathmandu Metropolitan City (KMC), in order to...",Rangoon,414
87596,With what Belorussian city does Kathmandu have...,"Kathmandu Metropolitan City (KMC), in order to...",Minsk,476
87597,In what year did Kathmandu create its initial ...,"Kathmandu Metropolitan City (KMC), in order to...",1975,199


In [11]:
df = df.drop_duplicates(subset=["context"]).reset_index(drop=True)

In [12]:
len(df.question.unique())

18881

In [13]:
len(df.context.unique())

18891

In [14]:
tokenizer = T5Tokenizer.from_pretrained(MODEL_NAME)

Downloading:   0%|          | 0.00/1.81k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/773k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.74k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.26k [00:00<?, ?B/s]

The `xla_device` argument has been deprecated in v4.4.0 of Transformers. It is ignored and you can safely remove it from your `config.json` file.
The `xla_device` argument has been deprecated in v4.4.0 of Transformers. It is ignored and you can safely remove it from your `config.json` file.


In [15]:
class SQUADQADataset(Dataset):
    def __init__(self, data:pd.DataFrame, tokenizer:T5Tokenizer, source_max_token_len: int = 396, target_max_token_len: int = 32,):
        self.data =  data
        self.tokenizer =  tokenizer
        self.source_max_token_len =  source_max_token_len
        self.target_max_token_len =  target_max_token_len
        
    def __len__(self):
        return len(self.data)
   
    def __getitem__(self, index: int):
        data_row = self.data.iloc[index]
        source_encoding =tokenizer(
            data_row["question"],
            data_row["context"],
            max_length=self.source_max_token_len,
            padding="max_length",
            truncation="only_second",
            return_attention_mask=True,
            add_special_tokens=True,
            return_tensors="pt"
        )

        target_encoding =tokenizer(
            data_row["answer_text"],
            max_length=self.target_max_token_len,
            padding="max_length",
            truncation=True,
            return_attention_mask=True,
            add_special_tokens=True,
            return_tensors="pt"
        )

        labels = target_encoding["input_ids"]
        labels[labels == 0] = -100
        return dict(
            question = data_row["question"],
            context=data_row["context"],
            answer_text=data_row['answer_text'],
            input_ids=source_encoding["input_ids"].flatten(),
            attention_mask=source_encoding["attention_mask"].flatten(),
            labels=labels.flatten()
        )

In [16]:
sample_dataset = SQUADQADataset(df, tokenizer)

In [17]:
train_df, val_df = train_test_split(df, test_size=0.05)

In [18]:
class SUAQDQAModule(pl.LightningDataModule):
    def __init__(self, train_df: pd.DataFrame,test_df: pd.DataFrame, 
                 tokenizer:T5Tokenizer, batch_size: int = 8, 
                 source_max_token_len: int = 396, 
                 target_max_token_len: int = 32,):
        super().__init__()
        self.train_df = train_df
        self.test_df = test_df
        self.tokenizer = tokenizer
        self.batch_size = batch_size
        self.source_max_token_len = source_max_token_len
        self.target_max_token_len = target_max_token_len
    
    def setup(self, stage=None):
        self.train_dataset = SQUADQADataset(self.train_df, self.tokenizer,
                                          self.source_max_token_len,self.target_max_token_len)
        self.test_dataset = SQUADQADataset(self.test_df, self.tokenizer,
                                         self.source_max_token_len, self.target_max_token_len)

    def train_dataloader(self):
        return DataLoader(self.train_dataset,batch_size=self.batch_size, shuffle=True,num_workers=4)
    
    def val_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=self.batch_size, num_workers=4)
    
    def test_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=1, num_workers=4)

In [19]:
BATCH_SIZE = 4
N_EPOCHS = 3

In [20]:
data_module = SUAQDQAModule(train_df, val_df, tokenizer, batch_size=BATCH_SIZE)
data_module.setup()

In [21]:
model = T5ForConditionalGeneration.from_pretrained(MODEL_NAME, return_dict = True)
# model.config

The `xla_device` argument has been deprecated in v4.4.0 of Transformers. It is ignored and you can safely remove it from your `config.json` file.


Downloading:   0%|          | 0.00/1.11G [00:00<?, ?B/s]

# Modelling

In [22]:
class SUADQAModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.model = T5ForConditionalGeneration.from_pretrained(MODEL_NAME, return_dict=True)
   
    def forward(self, input_ids, attention_mask, labels=None):
        output = self.model(
            input_ids, 
             attention_mask=attention_mask,
             labels=labels)
        return output.loss, output.logits
   
    def training_step(self, batch, batch_idx):
        input_ids = batch['input_ids']
        attention_mask=batch['attention_mask']
        labels = batch['labels']
        loss, outputs = self(input_ids, attention_mask, labels)
        self.log("train_loss", loss, prog_bar=True, logger=True)
        return {"loss": loss, "predictions":outputs, "labels": labels}
   
    def validation_step(self, batch, batch_idx):
        input_ids = batch['input_ids']
        attention_mask=batch['attention_mask']
        labels = batch['labels']
        loss, outputs = self(input_ids, attention_mask, labels)
        self.log("val_loss", loss, prog_bar=True, logger=True)
        return loss
    
    def test_step(self, batch, batch_idx):
        input_ids = batch['input_ids']
        attention_mask=batch['attention_mask']
        labels = batch['labels']
        loss, outputs = self(input_ids, attention_mask, labels)
        self.log("test_loss", loss, prog_bar=True, logger=True)
        return loss
   
    def configure_optimizers(self):
        optimizer = AdamW(self.parameters(), lr=0.0001)
        return optimizer
    


In [23]:
model = SUADQAModel()

The `xla_device` argument has been deprecated in v4.4.0 of Transformers. It is ignored and you can safely remove it from your `config.json` file.


In [24]:
checkpoint_callback = ModelCheckpoint(
     dirpath="checkpoints",
     filename="best-checkpoint",
     save_top_k=1,
     verbose=True,
     monitor="val_loss",
     mode="min"
 )

In [25]:
trainer = pl.Trainer(
     #logger = logger,
    checkpoint_callback=checkpoint_callback,
    max_epochs=N_EPOCHS,
    gpus=1,
    progress_bar_refresh_rate = 30
 )

  f"Setting `Trainer(checkpoint_callback={checkpoint_callback})` is deprecated in v1.5 and will "
  f"Setting `Trainer(progress_bar_refresh_rate={progress_bar_refresh_rate})` is deprecated in v1.5 and"


In [26]:
trainer.fit(model, data_module)



Sanity Checking: 0it [00:00, ?it/s]

  cpuset_checked))


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

In [27]:
trainer.test(model, data_module)   

Testing: 0it [00:00, ?it/s]

[{'test_loss': 0.3079359829425812}]