In [None]:
import string
from collections import Counter

import numpy as np
import pandas as pd
import torch
from torch import nn
from torch.optim import Adam
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from transformers import BertTokenizerFast, BertModel

In [None]:
# Mount Google Drive at /content/drive; the dataset CSVs and the saved model
# path used elsewhere in this notebook live under that mount point.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Both SQuAD CSV splits are read from the same folder on the mounted drive.
_SQUAD_DIR = '/content/drive/MyDrive/SQuAD'
train_path = _SQUAD_DIR + '/train-squad.csv'
val_path = _SQUAD_DIR + '/validation-squad.csv'

In [None]:
# Shared tokenizer used by both collate functions. They call it with
# return_offsets_mapping=True, which relies on the fast tokenizer class.
tokeniser = BertTokenizerFast.from_pretrained("bert-base-uncased")


In [None]:
class eQA(Dataset):
  """Extractive question-answering dataset backed by a CSV file.

  The CSV is loaded with pandas into ``self.df``. Each item is a dict with the
  keys ``context``, ``question``, ``id``, ``answer_start`` and ``answer``
  (the latter taken from the ``text`` column).
  """

  # (output key, CSV column) pairs, in the order they appear in each item.
  _FIELDS = (
      ("context", "context"),
      ("question", "question"),
      ("id", "id"),
      ("answer_start", "answer_start"),
      ("answer", "text"),
  )

  def __init__(self, csv_path) -> None:
    """Read ``csv_path`` and replace missing answer text with an empty string."""
    super().__init__()
    frame = pd.read_csv(csv_path)
    frame["text"] = frame["text"].fillna("")
    self.df = frame

  def __len__(self):
    """Return the number of rows in the underlying frame."""
    return self.df.shape[0]

  def __getitem__(self, idx):
    """Return the row at position ``idx`` as a plain dict of raw values."""
    record = self.df.iloc[idx]
    return {key: record[column] for key, column in self._FIELDS}


In [None]:
def collate_fn(batch):
  """Tokenise a batch of training samples and attach answer token positions.

  Args:
    batch: list of dicts with the keys "context", "question",
      "answer_start" (character offset of the answer inside the context) and
      "answer" (answer text).

  Returns:
    The tokeniser's encoding (without "offset_mapping") extended with
    "start_positions" and "end_positions" tensors holding one token index
    per sample. Samples whose answer cannot be fully located among the
    context tokens (e.g. missing answer_start, or answer cut off by
    truncation) are labelled (0, 0).
  """
  encoding = tokeniser([sample["context"] for sample in batch],
                        [sample["question"] for sample in batch],
                        padding = "max_length",
                        truncation = True,
                        return_offsets_mapping = True,
                        max_length = 512,
                        return_tensors = "pt")
  # Plain Python ints are much cheaper to compare than 0-dim tensors.
  offset_mapping = encoding.pop("offset_mapping").tolist()

  start_positions = []
  end_positions = []
  for i, sample in enumerate(batch):
    answer_start = sample["answer_start"]
    answer_end = answer_start + len(sample["answer"])
    # sequence_ids is 0 for context tokens, 1 for question tokens and None for
    # special/padding tokens. Question-token offsets are relative to the
    # question string, so they must never be matched against answer offsets
    # (which are relative to the context).
    sequence_ids = encoding.sequence_ids(i)
    start_pos, end_pos = None, None
    for j, (token_start, token_end) in enumerate(offset_mapping[i]):
      if sequence_ids[j] != 0:
        continue
      if token_start <= answer_start < token_end:
        start_pos = j
      if token_start < answer_end <= token_end:
        end_pos = j
    # A half-located span (start found, end truncated away, or vice versa)
    # would be an inconsistent label, so fall back to position 0 for both.
    if start_pos is None or end_pos is None:
      start_pos, end_pos = 0, 0
    start_positions.append(start_pos)
    end_positions.append(end_pos)

  encoding["start_positions"] = torch.tensor(start_positions)
  encoding["end_positions"] = torch.tensor(end_positions)

  return encoding


In [None]:
def collate_fn_val(batch):
  """Tokenise a batch of validation samples, keeping data needed for scoring.

  Args:
    batch: list of dicts with the keys "context", "question",
      "answer_start" (character offset of the answer inside the context) and
      "answer" (answer text).

  Returns:
    The tokeniser's encoding, still containing "offset_mapping", extended
    with "start_positions" / "end_positions" tensors (one token index per
    sample) and the raw "context" and "answer" string lists. Samples whose
    answer cannot be fully located among the context tokens are labelled
    (0, 0).
  """
  encoding = tokeniser([sample["context"] for sample in batch],
                        [sample["question"] for sample in batch],
                        padding = "max_length",
                        truncation = True,
                        return_offsets_mapping = True,
                        max_length = 512,
                        return_tensors = "pt")
  # Plain Python ints are much cheaper to compare than 0-dim tensors; the
  # tensor itself stays in the encoding for use at evaluation time.
  offset_mapping = encoding["offset_mapping"].tolist()

  start_positions = []
  end_positions = []
  for i, sample in enumerate(batch):
    answer_start = sample["answer_start"]
    answer_end = answer_start + len(sample["answer"])
    # sequence_ids is 0 for context tokens, 1 for question tokens and None for
    # special/padding tokens. Question-token offsets are relative to the
    # question string, so they must never be matched against answer offsets
    # (which are relative to the context).
    sequence_ids = encoding.sequence_ids(i)
    start_pos, end_pos = None, None
    for j, (token_start, token_end) in enumerate(offset_mapping[i]):
      if sequence_ids[j] != 0:
        continue
      if token_start <= answer_start < token_end:
        start_pos = j
      if token_start < answer_end <= token_end:
        end_pos = j
    # A half-located span would be an inconsistent label; use 0 for both.
    if start_pos is None or end_pos is None:
      start_pos, end_pos = 0, 0
    start_positions.append(start_pos)
    end_positions.append(end_pos)

  encoding["start_positions"] = torch.tensor(start_positions)
  encoding["end_positions"] = torch.tensor(end_positions)
  encoding["context"] = [sample["context"] for sample in batch]
  encoding["answer"] = [sample["answer"] for sample in batch]

  return encoding


In [None]:
# Build one dataset per CSV split.
trainset, valset = eQA(train_path), eQA(val_path)

# Smoke test: collate the first three training rows and show the result.
raw_batch_data = [trainset[row_idx] for row_idx in range(3)]
print(collate_fn(raw_batch_data))

In [None]:
# Shuffle the training data every epoch; without it the model always sees the
# rows in CSV order. Validation order does not affect the metrics, so it is
# left unshuffled.
trainloader = DataLoader(trainset, batch_size=8, shuffle=True, collate_fn=collate_fn)
valLoader = DataLoader(valset, batch_size=8, collate_fn=collate_fn_val)


In [None]:
class BiLSTMBERT(nn.Module):
  """Frozen BERT encoder followed by a bidirectional LSTM and two linear heads.

  The two heads each emit one score per token: one for the answer start and
  one for the answer end.
  """

  def __init__(self) -> None:
    super().__init__()
    # Pretrained encoder; none of its parameters receive gradients.
    self.embeddings = BertModel.from_pretrained("bert-base-uncased")
    for param in self.embeddings.parameters():
      param.requires_grad = False

    bert_dim = self.embeddings.config.hidden_size
    self.lstm = nn.LSTM(
        input_size=bert_dim,
        hidden_size=bert_dim,
        batch_first=True,
        bidirectional=True,
    )
    # A bidirectional LSTM concatenates both directions: 2 * bert_dim features.
    lstm_out_dim = bert_dim * 2
    self.output1 = nn.Linear(lstm_out_dim, 1)
    self.output2 = nn.Linear(lstm_out_dim, 1)

  def forward(self, input_ids, token_type, attention_mask):
    """Return ``(start_logits, end_logits)``, each with a trailing dim of 1."""
    bert_out = self.embeddings(
        input_ids,
        attention_mask=attention_mask,
        token_type_ids=token_type,
    )
    lstm_out, _ = self.lstm(bert_out.last_hidden_state)
    return self.output1(lstm_out), self.output2(lstm_out)




In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU instead of
# crashing on `.to("cuda")`.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = BiLSTMBERT().to(device)
# Start and end positions are each treated as a classification over tokens.
loss_fn = nn.CrossEntropyLoss()
optimiser = Adam(model.parameters(), lr = 0.001)


In [None]:
# NOTE(review): `validate` is defined in a later cell of this notebook; that
# cell must be executed before this one, otherwise the call at the end of the
# first epoch raises NameError.
epochs = 5
for epoch in range(epochs):
  # validate() puts the model in eval mode, so training mode has to be
  # re-enabled at the start of every epoch rather than once before the loop.
  model.train()
  m = 0             # number of batches seen in this epoch
  running_loss = 0  # sum of the per-batch losses in this epoch
  progress_bar = tqdm(trainloader, desc=f"Epoch {epoch+1}")
  for batch in progress_bar:
    input_ids = batch["input_ids"].to(device)
    token_type_ids = batch["token_type_ids"].to(device)
    attention_mask = batch["attention_mask"].to(device)
    start = batch["start_positions"].to(device)
    end = batch["end_positions"].to(device)

    start_logits, end_logits = model(input_ids,token_type_ids,attention_mask)
    # Drop the trailing size-1 dim so each row is a score per token position.
    start_logits = start_logits.squeeze(-1)
    end_logits = end_logits.squeeze(-1)

    loss_start = loss_fn(start_logits,start)
    loss_end = loss_fn(end_logits, end)
    loss_avg = (loss_start + loss_end) / 2

    loss_value = loss_avg.item()
    running_loss += loss_value
    m += 1
    progress_bar.set_postfix(loss=loss_value, avg_loss=running_loss / m)

    optimiser.zero_grad()
    loss_avg.backward()
    optimiser.step()
  # Evaluate on the same device the model lives on.
  validate(model,valLoader,device,epoch)





In [None]:
def validate(model, valLoader, device, epoch):
    """Evaluate `model` on `valLoader` and report exact match and token F1.

    Args:
        model: Module called as ``model(input_ids, token_type_ids,
            attention_mask)`` that returns ``(start_logits, end_logits)``,
            each with a trailing dimension of size 1.
        valLoader: Iterable of batches holding ``input_ids``,
            ``token_type_ids``, ``attention_mask``, ``offset_mapping``,
            ``context`` and ``answer``.
        device: Device the input tensors are moved to.
        epoch: Zero-based epoch index, used only for the progress-bar label.

    Returns:
        ``(exact_match, f1)`` averaged over all samples, or ``(0.0, 0.0)``
        when the loader yields no samples. The same values are printed.
    """
    punctuation_table = str.maketrans("", "", string.punctuation)

    def normalise(text):
        return text.lower().translate(punctuation_table).strip()

    def compute_f1(pred, truth):
        pred_tokens = normalise(pred).split()
        truth_tokens = normalise(truth).split()
        # With an empty side, F1 is 1 only if both are empty (agrees with EM).
        if not pred_tokens or not truth_tokens:
            return float(pred_tokens == truth_tokens)
        # Multiset overlap: repeated words count as often as they are shared.
        overlap = sum((Counter(pred_tokens) & Counter(truth_tokens)).values())
        if overlap == 0:
            return 0.0
        precision = overlap / len(pred_tokens)
        recall = overlap / len(truth_tokens)
        return 2 * (precision * recall) / (precision + recall)

    total_em = 0
    total_f1 = 0.0
    n = 0

    # Remember the current mode so the caller's training loop is not silently
    # left in eval mode after validation.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            progress_bar = tqdm(valLoader, desc=f"Epoch {epoch+1}")
            for batch in progress_bar:
                input_ids = batch["input_ids"].to(device)
                token_type_ids = batch["token_type_ids"].to(device)
                attention_mask = batch["attention_mask"].to(device)

                start_logits, end_logits = model(input_ids, token_type_ids, attention_mask)
                start_logits = start_logits.squeeze(-1)
                end_logits = end_logits.squeeze(-1)

                # Only first-segment, non-padding tokens may be predicted:
                # offsets of question tokens index into the question, not the
                # context, so slicing the context with them yields garbage.
                # Assumes the context is the first segment (token_type_ids
                # == 0), as collate_fn_val passes it first.
                context_mask = (token_type_ids == 0) & (attention_mask == 1)
                start_logits = start_logits.masked_fill(~context_mask, float("-inf"))
                end_logits = end_logits.masked_fill(~context_mask, float("-inf"))

                predicted_start = torch.argmax(start_logits, dim=-1)
                predicted_end = torch.argmax(end_logits, dim=-1)

                for i in range(len(predicted_start)):
                    offsets = batch["offset_mapping"][i]
                    context = batch["context"][i]
                    answer = batch["answer"][i]

                    i_start = predicted_start[i].item()
                    i_end = predicted_end[i].item()

                    # Character span: start of the first predicted token to
                    # the end of the last predicted token.
                    start_char = int(offsets[i_start][0])
                    end_char = int(offsets[i_end][1])

                    predicted_text = context[start_char:end_char]
                    gt_text = answer

                    print(f"GT: {gt_text}")
                    print(f"Pred: {predicted_text}")

                    total_em += int(normalise(predicted_text) == normalise(gt_text))
                    total_f1 += compute_f1(predicted_text, gt_text)
                    n += 1
    finally:
        model.train(was_training)

    if n == 0:
        print("Validation set is empty; no metrics computed.")
        return 0.0, 0.0

    em_score = total_em / n
    f1_score = total_f1 / n
    print(f"Validation EM: {em_score}")
    print(f"Validation f1: {f1_score}")
    return em_score, f1_score

In [None]:
# Save only the model's state_dict to Google Drive. The state_dict covers every
# submodule of `model`, so the frozen BERT weights are written as well.
PATH = '/content/drive/MyDrive/SQuAD/model.pth'
torch.save(model.state_dict(), PATH)