In [None]:
!unzip "/content/train-v1.1.json.zip"

Archive:  /content/train-v1.1.json.zip
  inflating: train-v1.1.json         


In [None]:
!pip install evaluate

In [None]:
!pip install rouge

In [None]:
# `pip install blue` was dropped: "blue" is a Python code formatter, not a BLEU
# metric, and nothing in this notebook uses it. BLEU-style scores are available
# through the `evaluate` package installed above (e.g. evaluate.load('google_bleu')).

In [None]:
import torch
from tqdm import tqdm
import torch.nn as nn
from torch.optim import Adam
from torch.utils.data import DataLoader,Dataset
from torchvision import transforms
from transformers import T5Tokenizer, T5Model, T5ForConditionalGeneration, T5TokenizerFast
import evaluate

In [None]:
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import os
from sklearn.model_selection import train_test_split
import json

In [None]:
# JSON text is UTF-8 (RFC 8259); state the encoding explicitly instead of relying
# on the platform default, which is not UTF-8 everywhere.
with open('/content/train-v1.1.json', encoding='utf-8') as f:
  data = json.load(f)


In [None]:
# Show the top-level structure and article count instead of echoing the entire
# parsed JSON, which would flood the cell output.
list(data.keys()), len(data['data'])

In [None]:
# Drill down article -> paragraph -> qa to look at the very first question.
first_qa=data['data'][0]['paragraphs'][0]['qas'][0]
first_qa['question']

'To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France?'

In [None]:
# Flatten the article -> paragraph -> qa nesting into one record per question,
# keeping only the first annotated answer's text.
articles=[
  {"context": paragraph['context'], 'question': qa['question'], 'answers': qa['answers'][0]['text']}
  for article in data['data']
  for paragraph in article['paragraphs']
  for qa in paragraph['qas']
]


In [None]:
# Inspect one flattened record: paragraph context, question, and first answer text.
articles[1]

{'context': 'Architecturally, the school has a Catholic character. Atop the Main Building\'s gold dome is a golden statue of the Virgin Mary. Immediately in front of the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend "Venite Ad Me Omnes". Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basilica is the Grotto, a Marian place of prayer and reflection. It is a replica of the grotto at Lourdes, France where the Virgin Mary reputedly appeared to Saint Bernadette Soubirous in 1858. At the end of the main drive (and in a direct line that connects through 3 statues and the Gold Dome), is a simple, modern stone statue of Mary.',
 'question': 'What is in front of the Notre Dame Main Building?',
 'answers': 'a copper statue of Christ'}

In [None]:
# One row per question with 'context', 'question' and 'answers' columns.
# NOTE(review): this rebinds `data` (previously the parsed JSON dict); later cells use the DataFrame.
data=pd.DataFrame.from_records(articles)

In [None]:
# Preview the flattened DataFrame (one row per question).
data.head()

Unnamed: 0,context,question,answers
0,"Architecturally, the school has a Catholic cha...",To whom did the Virgin Mary allegedly appear i...,Saint Bernadette Soubirous
1,"Architecturally, the school has a Catholic cha...",What is in front of the Notre Dame Main Building?,a copper statue of Christ
2,"Architecturally, the school has a Catholic cha...",The Basilica of the Sacred heart at Notre Dame...,the Main Building
3,"Architecturally, the school has a Catholic cha...",What is the Grotto at Notre Dame?,a Marian place of prayer and reflection
4,"Architecturally, the school has a Catholic cha...",What sits on top of the Main Building at Notre...,a golden statue of the Virgin Mary


In [None]:
# Longest answer measured in whitespace-separated words (not tokenizer tokens).
t_len=max((len(answer.split()) for answer in data['answers']), default=0)
print(t_len)

43


In [None]:
# Longest question measured in whitespace-separated words (not tokenizer tokens).
q_len=max((len(question.split()) for question in data['question']), default=0)
print(q_len)

40


In [None]:
# Seed torch so DataLoader shuffling and dropout are repeatable across runs.
torch.manual_seed(42)

tokenizer=T5TokenizerFast.from_pretrained('t5-base')
model=T5ForConditionalGeneration.from_pretrained('t5-base',return_dict=True)
device= 'cuda'if torch.cuda.is_available() else 'cpu'
# Move the model before constructing the optimizer, as the PyTorch optim docs recommend.
model.to(device)
optimizer=Adam(model.parameters(),lr=0.00001)
# Encoder input length in tokens. predict_ans encodes (question, context) pairs with
# this length, so the longest question alone (40 words, computed above) leaves no
# room for the context; 512 is the input length T5 was pre-trained with.
q_len=512
# Decoder target length in tokens. 43 is the longest answer in *words* (computed above),
# so the occasional answer with more tokens than that is truncated.
t_len=43
batch_size=4
epochs=5

In [None]:
class QADataset(Dataset):
  """Map-style dataset turning (question, context) -> answer rows into T5 inputs.

  Args:
    data: DataFrame with 'question', 'context' and 'answers' string columns.
    tokenizer: callable tokenizer (e.g. T5TokenizerFast) returning 'input_ids' and
      'attention_mask' lists; must expose `pad_token_id`.
    q_len: max encoder length in tokens for the question+context pair.
    t_len: max decoder length in tokens for the answer.

  Each item is a dict of 1-D LongTensors:
    'inputs_ids' / 'attention_mask': encoder ids and mask (the key name is kept
      because the training/validation cells read batch['inputs_ids']),
    'labels': answer ids with padding replaced by -100 so the loss ignores it,
    'decoder_attention_mask': answer mask.
  """
  def __init__(self,data,tokenizer,q_len,t_len):
    self.data=data
    self.tokenizer=tokenizer
    self.q_len=q_len
    self.t_len=t_len
    # tolist() gives positional access, so a shuffled index (e.g. after
    # train_test_split) does not matter.
    self.question=self.data['question'].tolist()
    self.answer=self.data['answers'].tolist()
    self.context=self.data['context'].tolist()

  def __len__(self):
    return len(self.data)

  def __getitem__(self,idx):
    question=self.question[idx]
    answer=self.answer[idx]
    context=self.context[idx]

    # Encode the question together with its context (same call shape as predict_ans);
    # with the question alone the model has nothing to extract the answer from.
    question_token=self.tokenizer(question,context,max_length=self.q_len,padding='max_length',truncation=True,add_special_tokens=True)
    answer_token=self.tokenizer(answer,max_length=self.t_len,padding='max_length',truncation=True,add_special_tokens=True)

    # Mask padding out of the loss. This must be done on a tensor: on a plain list
    # `labels==0` is just False, so `labels[labels==0]` would hit index 0 instead.
    labels=torch.tensor(answer_token['input_ids'],dtype=torch.long)
    labels[labels==self.tokenizer.pad_token_id]=-100

    return {
        "inputs_ids":torch.tensor(question_token['input_ids'],dtype=torch.long),
        "attention_mask":torch.tensor(question_token['attention_mask'],dtype=torch.long),
        "labels":labels,
        "decoder_attention_mask":torch.tensor(answer_token['attention_mask'],dtype=torch.long)
    }




In [None]:
# Dataset over the full frame, used only to sanity-check one encoded example in the next cell.
sample_data=QADataset(data,tokenizer,q_len,t_len)

In [None]:
# Inspect the encoded tensors for the first example (encoder ids/mask, labels, decoder mask).
sample_data[0]

{'inputs_ids': tensor([  304,  4068,   410,     8, 16823,  3790,     3, 18280,  2385,    16,
           507,  3449,    16,   301,  1211,  1395,  1410,    58,     1,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0]),
 'attention_mask': tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),
 'labels': tensor([-100, 8942,    9,   26, 1954,  264, 8371, 8283,    1,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0]),
 'decoder_attention_mask': tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])}

In [None]:
# NOTE(review): rows are split independently, so questions about the same paragraph
# can land in both splits; grouping by 'context' would give a cleaner validation estimate.
train_data,val_data=train_test_split(data,test_size=0.2,random_state=42)

train_dataset=QADataset(train_data,tokenizer,q_len,t_len)
val_dataset=QADataset(val_data,tokenizer,q_len,t_len)

train_data_loader=DataLoader(train_dataset,batch_size=batch_size,shuffle=True)
# Evaluation does not benefit from shuffling; a fixed order keeps the validation pass deterministic.
val_data_loader=DataLoader(val_dataset,batch_size=batch_size,shuffle=False)

In [None]:
# Sanity-check the keys and shapes of one collated batch from the training DataLoader.
next(iter(train_data_loader))

{'inputs_ids': tensor([[  571,    19, 11901,    31,     7,  1424,   512,     3,  8232,    58,
              1,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0],
         [ 2645,   410,    37,  3068, 15884,  1827, 22324,    13,    16,  6622,
             58,     1,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0],
         [  363,    19,     8,  2015,   496,   358,    16,  1117,  5089,    58,
              1,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0],
         [ 1029,  4068,

In [None]:
# Trailing ';' suppresses the very long module repr in the cell output.
model.to(device);

T5ForConditionalGeneration(
  (shared): Embedding(32128, 768)
  (encoder): T5Stack(
    (embed_tokens): Embedding(32128, 768)
    (block): ModuleList(
      (0): T5Block(
        (layer): ModuleList(
          (0): T5LayerSelfAttention(
            (SelfAttention): T5Attention(
              (q): Linear(in_features=768, out_features=768, bias=False)
              (k): Linear(in_features=768, out_features=768, bias=False)
              (v): Linear(in_features=768, out_features=768, bias=False)
              (o): Linear(in_features=768, out_features=768, bias=False)
              (relative_attention_bias): Embedding(32, 12)
            )
            (layer_norm): T5LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (1): T5LayerFF(
            (DenseReluDense): T5DenseActDense(
              (wi): Linear(in_features=768, out_features=3072, bias=False)
              (wo): Linear(in_features=3072, out_features=768, bias=False)
              (dropout): Dro

In [None]:
# Validation accumulators are still created here for a validation cell that reads them.
train_loss=0
val_loss=0
train_batch_count=0
val_batch_count=0

for epoch in range(epochs):
  model.train()
  # Reset per epoch so the printed value is this epoch's mean loss, not a running
  # mean over every epoch so far.
  train_loss=0
  train_batch_count=0
  for batch in tqdm(train_data_loader, desc=f"Epoch {epoch+1}/{epochs}"):
    inputs_ids=batch['inputs_ids'].to(device)
    attention_mask=batch['attention_mask'].to(device)
    labels=batch['labels'].to(device)
    decoder_attention_mask=batch['decoder_attention_mask'].to(device)

    output=model(input_ids=inputs_ids,
                 attention_mask=attention_mask,
                 labels=labels,
                 decoder_attention_mask=decoder_attention_mask)
    optimizer.zero_grad()
    loss=output.loss
    loss.backward()
    optimizer.step()
    train_loss+=loss.item()
    train_batch_count+=1

  # max(..., 1) guards against an empty loader.
  print(f"Epoch {epoch+1}/{epochs} -> Train Loss: {train_loss/max(train_batch_count,1)}")

In [None]:
# Mean teacher-forced loss over the validation split (no parameter updates).
# Relies on `epoch`, `train_loss` and `train_batch_count` left by the training cell.
model.eval()
# Start from zero so re-running this cell does not accumulate into older totals.
val_loss=0
val_batch_count=0
with torch.no_grad():
  for batch in tqdm(val_data_loader):
    inputs_ids=batch['inputs_ids'].to(device)
    attention_mask=batch['attention_mask'].to(device)
    labels=batch['labels'].to(device)
    decoder_attention_mask=batch['decoder_attention_mask'].to(device)
    output=model(input_ids=inputs_ids,
                 attention_mask=attention_mask,
                 labels=labels,
                 decoder_attention_mask=decoder_attention_mask)
    val_loss+=output.loss.item()
    val_batch_count+=1
# Report once, after all batches, against the configured epoch count.
print(f"{epoch+1}/{epochs} -> Train loss: {train_loss/max(train_batch_count,1)}\tValidation loss: {val_loss/max(val_batch_count,1)}")

  0%|          | 4/4380 [00:00<04:10, 17.49it/s]

1/2 -> Train loss: 6.0496930536234155	Validation loss: 1.3200398683547974
1/2 -> Train loss: 6.0496930536234155	Validation loss: 1.3199824690818787
1/2 -> Train loss: 6.0496930536234155	Validation loss: 1.2340755065282185
1/2 -> Train loss: 6.0496930536234155	Validation loss: 1.3304593563079834


  0%|          | 8/4380 [00:00<03:57, 18.43it/s]

1/2 -> Train loss: 6.0496930536234155	Validation loss: 1.6487142086029052
1/2 -> Train loss: 6.0496930536234155	Validation loss: 1.5951141516367595
1/2 -> Train loss: 6.0496930536234155	Validation loss: 1.4948115178516932
1/2 -> Train loss: 6.0496930536234155	Validation loss: 1.500336304306984


  0%|          | 12/4380 [00:00<04:01, 18.10it/s]

1/2 -> Train loss: 6.0496930536234155	Validation loss: 1.4832775195439656
1/2 -> Train loss: 6.0496930536234155	Validation loss: 1.4629136443138122
1/2 -> Train loss: 6.0496930536234155	Validation loss: 1.5102866888046265
1/2 -> Train loss: 6.0496930536234155	Validation loss: 1.4635957727829616


  0%|          | 14/4380 [00:00<04:00, 18.12it/s]

1/2 -> Train loss: 6.0496930536234155	Validation loss: 1.4997358643091643
1/2 -> Train loss: 6.0496930536234155	Validation loss: 1.475655083145414
1/2 -> Train loss: 6.0496930536234155	Validation loss: 1.4321874101956686


  0%|          | 18/4380 [00:01<04:39, 15.59it/s]

1/2 -> Train loss: 6.0496930536234155	Validation loss: 1.3940244391560555
1/2 -> Train loss: 6.0496930536234155	Validation loss: 1.3951288601931404
1/2 -> Train loss: 6.0496930536234155	Validation loss: 1.363844417863422
1/2 -> Train loss: 6.0496930536234155	Validation loss: 1.3408314177864475


  1%|          | 22/4380 [00:01<04:48, 15.13it/s]

1/2 -> Train loss: 6.0496930536234155	Validation loss: 1.3212263464927674
1/2 -> Train loss: 6.0496930536234155	Validation loss: 1.3039121826489766
1/2 -> Train loss: 6.0496930536234155	Validation loss: 1.2875331342220306
1/2 -> Train loss: 6.0496930536234155	Validation loss: 1.287826146768487


  1%|          | 26/4380 [00:01<04:19, 16.78it/s]

1/2 -> Train loss: 6.0496930536234155	Validation loss: 1.291819768647353
1/2 -> Train loss: 6.0496930536234155	Validation loss: 1.3039742970466615
1/2 -> Train loss: 6.0496930536234155	Validation loss: 1.292497884768706
1/2 -> Train loss: 6.0496930536234155	Validation loss: 1.2842118055732161


  1%|          | 30/4380 [00:01<04:07, 17.59it/s]

1/2 -> Train loss: 6.0496930536234155	Validation loss: 1.3243614115885325
1/2 -> Train loss: 6.0496930536234155	Validation loss: 1.3421943043840343
1/2 -> Train loss: 6.0496930536234155	Validation loss: 1.365327384074529
1/2 -> Train loss: 6.0496930536234155	Validation loss: 1.3701287911784263


  1%|          | 34/4380 [00:02<04:26, 16.33it/s]

1/2 -> Train loss: 6.0496930536234155	Validation loss: 1.3590140137821436
1/2 -> Train loss: 6.0496930536234155	Validation loss: 1.352149556983601
1/2 -> Train loss: 6.0496930536234155	Validation loss: 1.3531348477391636
1/2 -> Train loss: 6.0496930536234155	Validation loss: 1.3471554262297494


  1%|          | 38/4380 [00:02<04:08, 17.47it/s]

1/2 -> Train loss: 6.0496930536234155	Validation loss: 1.3497549874915018
1/2 -> Train loss: 6.0496930536234155	Validation loss: 1.3648417527611192
1/2 -> Train loss: 6.0496930536234155	Validation loss: 1.3568125759300433
1/2 -> Train loss: 6.0496930536234155	Validation loss: 1.3622470436952052


  1%|          | 42/4380 [00:02<04:26, 16.31it/s]

1/2 -> Train loss: 6.0496930536234155	Validation loss: 1.3669402495026588
1/2 -> Train loss: 6.0496930536234155	Validation loss: 1.3545387619879188
1/2 -> Train loss: 6.0496930536234155	Validation loss: 1.3803215892541976


  1%|          | 46/4380 [00:02<04:07, 17.54it/s]

1/2 -> Train loss: 6.0496930536234155	Validation loss: 1.4139503892077956
1/2 -> Train loss: 6.0496930536234155	Validation loss: 1.4015375673770905
1/2 -> Train loss: 6.0496930536234155	Validation loss: 1.3948443677690294
1/2 -> Train loss: 6.0496930536234155	Validation loss: 1.3965937287911125


  1%|          | 50/4380 [00:02<04:09, 17.33it/s]

1/2 -> Train loss: 6.0496930536234155	Validation loss: 1.3926055025547108
1/2 -> Train loss: 6.0496930536234155	Validation loss: 1.5168815006812413
1/2 -> Train loss: 6.0496930536234155	Validation loss: 1.5459541009396922
1/2 -> Train loss: 6.0496930536234155	Validation loss: 1.5698761558532714


  1%|          | 54/4380 [00:03<04:03, 17.75it/s]

1/2 -> Train loss: 6.0496930536234155	Validation loss: 1.572456584257238
1/2 -> Train loss: 6.0496930536234155	Validation loss: 1.5894795931302583
1/2 -> Train loss: 6.0496930536234155	Validation loss: 1.598987804268891
1/2 -> Train loss: 6.0496930536234155	Validation loss: 1.5932042201360066


  1%|▏         | 58/4380 [00:03<03:59, 18.08it/s]

1/2 -> Train loss: 6.0496930536234155	Validation loss: 1.584083552794023
1/2 -> Train loss: 6.0496930536234155	Validation loss: 1.5755424733672823
1/2 -> Train loss: 6.0496930536234155	Validation loss: 1.5715726969534891
1/2 -> Train loss: 6.0496930536234155	Validation loss: 1.5690518186010163


  1%|▏         | 62/4380 [00:03<04:01, 17.91it/s]

1/2 -> Train loss: 6.0496930536234155	Validation loss: 1.5848225597607888
1/2 -> Train loss: 6.0496930536234155	Validation loss: 1.5865213672320049
1/2 -> Train loss: 6.0496930536234155	Validation loss: 1.5842388184344183
1/2 -> Train loss: 6.0496930536234155	Validation loss: 1.5869016628111563


  2%|▏         | 66/4380 [00:03<04:04, 17.67it/s]

1/2 -> Train loss: 6.0496930536234155	Validation loss: 1.579213019401308
1/2 -> Train loss: 6.0496930536234155	Validation loss: 1.5696773137897253
1/2 -> Train loss: 6.0496930536234155	Validation loss: 1.5670360711904672
1/2 -> Train loss: 6.0496930536234155	Validation loss: 1.559476760300723


  2%|▏         | 70/4380 [00:04<03:57, 18.15it/s]

1/2 -> Train loss: 6.0496930536234155	Validation loss: 1.5904480432396504
1/2 -> Train loss: 6.0496930536234155	Validation loss: 1.5984251341399025
1/2 -> Train loss: 6.0496930536234155	Validation loss: 1.5865221645521081
1/2 -> Train loss: 6.0496930536234155	Validation loss: 1.5923808864184787


  2%|▏         | 74/4380 [00:04<03:55, 18.26it/s]

1/2 -> Train loss: 6.0496930536234155	Validation loss: 1.5879725284979378
1/2 -> Train loss: 6.0496930536234155	Validation loss: 1.5932735684845183
1/2 -> Train loss: 6.0496930536234155	Validation loss: 1.5873669761500946
1/2 -> Train loss: 6.0496930536234155	Validation loss: 1.5869205255766172


  2%|▏         | 78/4380 [00:04<03:48, 18.80it/s]

1/2 -> Train loss: 6.0496930536234155	Validation loss: 1.593367649714152
1/2 -> Train loss: 6.0496930536234155	Validation loss: 1.592194271715064
1/2 -> Train loss: 6.0496930536234155	Validation loss: 1.5829293518871457
1/2 -> Train loss: 6.0496930536234155	Validation loss: 1.5740732512412927


  2%|▏         | 82/4380 [00:04<03:56, 18.16it/s]

1/2 -> Train loss: 6.0496930536234155	Validation loss: 1.5797761339175551
1/2 -> Train loss: 6.0496930536234155	Validation loss: 1.5776802606880664
1/2 -> Train loss: 6.0496930536234155	Validation loss: 1.5691631178797028
1/2 -> Train loss: 6.0496930536234155	Validation loss: 1.5611332720372735


  2%|▏         | 84/4380 [00:04<05:35, 12.82it/s]

1/2 -> Train loss: 6.0496930536234155	Validation loss: 1.5525141458913505
1/2 -> Train loss: 6.0496930536234155	Validation loss: 1.5580084444511504
1/2 -> Train loss: 6.0496930536234155	Validation loss: 1.5529712487669551

  2%|▏         | 86/4380 [00:05<05:40, 12.61it/s]


1/2 -> Train loss: 6.0496930536234155	Validation loss: 1.5576401165751523


  2%|▏         | 88/4380 [00:05<05:39, 12.66it/s]

1/2 -> Train loss: 6.0496930536234155	Validation loss: 1.5529949452685214
1/2 -> Train loss: 6.0496930536234155	Validation loss: 1.550241426310756
1/2 -> Train loss: 6.0496930536234155	Validation loss: 1.5524479961127378


  2%|▏         | 90/4380 [00:05<06:27, 11.07it/s]

1/2 -> Train loss: 6.0496930536234155	Validation loss: 1.5534977588388654
1/2 -> Train loss: 6.0496930536234155	Validation loss: 1.5485746814654424


  2%|▏         | 92/4380 [00:05<04:36, 15.50it/s]

1/2 -> Train loss: 6.0496930536234155	Validation loss: 1.5575156270161918





KeyboardInterrupt: 

In [None]:
# Persist weights and tokenizer side by side so the directory can be reloaded with from_pretrained.
save_dir='qa_model'
model.save_pretrained(save_dir)
tokenizer.save_pretrained(save_dir)


('qa_model/tokenizer_config.json',
 'qa_model/special_tokens_map.json',
 'qa_model/spiece.model',
 'qa_model/added_tokens.json',
 'qa_model/tokenizer.json')

In [None]:
def predict_ans(context,question,ref_ans=None):
  """Generate an answer to `question` about `context` with the fine-tuned model.

  Uses the notebook globals `model`, `tokenizer`, `device`, `q_len` and `t_len`.
  No metric is computed here; load one once (e.g. evaluate.load('google_bleu'))
  outside this function if scoring is needed.

  Args:
    context: passage text.
    question: question text.
    ref_ans: optional reference answer. When given, the context and question are
      printed and the result pairs the reference with the prediction.

  Returns:
    dict with the predicted answer plus either the reference answer (ref_ans
    given) or the context and question.
  """
  # Inference should not use dropout; the model may still be in train mode.
  model.eval()
  # return_tensors='pt' yields (1, q_len) tensors directly, so no manual batching is needed.
  inputs=tokenizer(question,context,max_length=q_len,padding='max_length',truncation=True,add_special_tokens=True,return_tensors='pt')
  inputs_ids=inputs['input_ids'].to(device)
  attention_mask=inputs['attention_mask'].to(device)
  # Allow answers up to the training target length instead of the generation-config default.
  output=model.generate(input_ids=inputs_ids,attention_mask=attention_mask,max_length=t_len)
  predicted_ans=tokenizer.decode(output[0],skip_special_tokens=True)
  if ref_ans is not None:
    print("context: \n",context,"\n")
    print("question: \n",question,"\n")
    return{
      "reference answer: ": ref_ans,
      "predicted answer: ": predicted_ans
    }
  return{
    "context: ": context,
    "question: ": question,
    "predicted answer: ": predicted_ans
  }

In [None]:
# Take the first row of the frame as a demo example for predict_ans.
first_example=data.iloc[0]
context, question, ref_ans = first_example['context'], first_example['question'], first_example['answers']

In [None]:
# Run inference on the chosen example and compare with its reference answer.
# NOTE(review): the example comes from the full frame, so it may have been in the training split.
predict_ans(context,question,ref_ans)