In [12]:
import json
import re
import string
from collections import Counter
import torch
from tqdm import tqdm
from evaluate import load
import torch.nn as nn
from torch.utils.data import Dataset
from transformers import AutoTokenizer

class SpokenSquad(Dataset):
    """Map-style dataset over pre-tokenised question-answering encodings.

    ``encodings`` must be indexable by the keys ``'input_ids'``,
    ``'attention_mask'``, ``'start_positions'`` and ``'end_positions'``,
    each holding one entry per example.
    """

    # Fields copied into every sample, in the order they appear in the item.
    _FIELDS = ('input_ids', 'attention_mask', 'start_positions', 'end_positions')

    def __init__(self, encodings):
        """Keep a reference to the encodings; nothing is copied or converted."""
        self.encodings = encodings

    def __len__(self):
        """Return the number of examples (one per ``input_ids`` row)."""
        return len(self.encodings['input_ids'])

    def __getitem__(self, i):
        """Return example ``i`` as a dict of freshly built tensors."""
        return {
            field: torch.tensor(self.encodings[field][i])
            for field in self._FIELDS
        }

class QAModel(nn.Module):
    """Extractive question-answering wrapper around a span-prediction model.

    The wrapped model is called with ``input_ids`` and ``attention_mask`` and
    must return an object exposing ``start_logits`` and ``end_logits``.
    """

    def __init__(self, bert_base_model, device):
        """Store the wrapped model and move the whole module to ``device``."""
        super().__init__()
        self.device = device
        self.bert = bert_base_model
        self.to(device)

    def forward(self, input_ids, attention_mask):
        """Run the wrapped model and return ``(start_logits, end_logits)``."""
        result = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        return result.start_logits, result.end_logits

    def find_focal_loss(self, start_logits, end_logits, start_positions, end_positions, gamma=1):
        """Return the mean of the focal losses for the start and end heads.

        For each head the log-probabilities (softmax over dim 1) are scaled by
        ``(1 - p) ** gamma`` before the negative log-likelihood is taken, so
        confidently predicted targets contribute less to the loss.

        Args:
            start_logits: Scores for the start index, classes along dim 1.
            end_logits: Scores for the end index, classes along dim 1.
            start_positions: Target class index per example for the start head.
            end_positions: Target class index per example for the end head.
            gamma: Exponent applied to ``1 - p``.

        Returns:
            Scalar tensor: ``(start_loss + end_loss) / 2``.
        """
        def head_loss(logits, targets):
            # Down-weighting factor (1 - p)^gamma, applied to log p.
            modulator = torch.pow(1 - torch.softmax(logits, dim=1), gamma)
            weighted_log_probs = modulator * torch.log_softmax(logits, dim=1)
            return nn.functional.nll_loss(weighted_log_probs, targets)

        start_loss = head_loss(start_logits, start_positions)
        end_loss = head_loss(end_logits, end_positions)
        return (start_loss + end_loss) / 2

    def evaluate_model(self, dataloader, tokenizer):
        """Evaluate on ``dataloader`` and return ``(average F1, WER)``.

        Each batch must provide ``input_ids``, ``attention_mask``,
        ``start_positions`` and ``end_positions``. Predicted and labelled spans
        are decoded with ``tokenizer.decode`` treating the end index as
        inclusive. F1 is averaged over examples (0 when there are none), and
        WER is computed over all decoded pairs with the ``wer`` metric.
        """
        self.eval()
        wer_metric = load("wer")
        per_example_f1 = []
        predicted_ans, actual_ans = [], []

        with torch.no_grad():
            for batch in tqdm(dataloader, desc='Evaluating Model!'):
                input_ids = batch['input_ids'].to(self.device)
                attention_mask = batch['attention_mask'].to(self.device)
                gold_starts = batch['start_positions'].to(self.device)
                gold_ends = batch['end_positions'].to(self.device)

                start_logits, end_logits = self(input_ids=input_ids, attention_mask=attention_mask)
                pred_starts = start_logits.argmax(dim=1)
                pred_ends = end_logits.argmax(dim=1)

                for row in range(input_ids.size(0)):
                    tokens = input_ids[row]
                    predicted = tokenizer.decode(tokens[pred_starts[row]:pred_ends[row] + 1])
                    actual = tokenizer.decode(tokens[gold_starts[row]:gold_ends[row] + 1])

                    per_example_f1.append(find_f1_score(predicted, actual))
                    # Blank strings are swapped for "$" before WER scoring
                    # (presumably because the metric rejects empty text).
                    predicted_ans.append(predicted or "$")
                    actual_ans.append(actual or "$")

        wer_score = wer_metric.compute(predictions=predicted_ans, references=actual_ans)
        if per_example_f1:
            avg_f1_score = sum(per_example_f1) / len(per_example_f1)
        else:
            avg_f1_score = 0
        return avg_f1_score, wer_score
        
def _window_around_answer(passage, answer_start, answer_end, max_length):
    """Return at most ``max_length`` characters of ``passage`` centred on the answer span."""
    if len(passage) <= max_length:
        return passage
    midpoint = (answer_start + answer_end) // 2
    # Centre on the answer, then clamp so the window stays inside the passage.
    window_start = max(0, min(midpoint - max_length // 2, len(passage) - max_length))
    return passage[window_start:window_start + max_length]


def _find_token_span(sequence_ids, answer_ids):
    """Return inclusive ``(start, end)`` indices of ``answer_ids`` in ``sequence_ids``, or ``(0, 0)``."""
    span_len = len(answer_ids)
    if span_len == 0:
        # An empty answer would trivially "match" everywhere; treat it as not found.
        return 0, 0
    # Index 0 is skipped: the search begins at the second token of the sequence.
    for start in range(1, len(sequence_ids) - span_len + 1):
        if sequence_ids[start:start + span_len] == answer_ids:
            return start, start + span_len - 1
    return 0, 0


def collect_and_find_positions(file_path, tokenizer, max_length):
    """Load a SQuAD-format JSON file and build encodings with answer token positions.

    Every ``(question, answer)`` pair in the file becomes one example. Text is
    lower-cased. A passage longer than ``max_length`` characters is cut to a
    ``max_length``-character window centred on the answer before tokenisation.
    The answer is then located by searching the example's ``input_ids`` for the
    answer's own token ids.

    Args:
        file_path: Path to a JSON file shaped like
            ``{"data": [{"paragraphs": [{"context", "qas": [{"question",
            "answers": [{"text", "answer_start"}]}]}]}]}``.
        tokenizer: Callable used as ``tokenizer(questions, passages, ...)`` and
            ``tokenizer(text, ...)``; both must return a mapping with
            ``'input_ids'`` (assumed Hugging Face tokenizer convention).
        max_length: Integer limit passed to the tokenizer and used as the
            character size of the passage window.

    Returns:
        The tokenizer's batch output, updated in place with
        ``'start_positions'`` and ``'end_positions'``: the indices of the
        first and last answer token (end is inclusive). Examples whose answer
        cannot be found, or whose answer has no tokens, get ``(0, 0)``.
    """
    queries, passages_trunc, answer_texts = [], [], []

    # SQuAD-style JSON is read as UTF-8 rather than the platform default.
    with open(file_path, 'r', encoding='utf-8') as file:
        data = json.load(file)

    for topic in data['data']:
        for paragraph in topic['paragraphs']:
            passage = paragraph['context'].lower()
            for qa in paragraph['qas']:
                query = qa['question'].lower()
                for answer in qa['answers']:
                    answer_text = answer['text'].lower()
                    answer_start = answer['answer_start']
                    answer_end = answer_start + len(answer_text)
                    queries.append(query)
                    passages_trunc.append(
                        _window_around_answer(passage, answer_start, answer_end, max_length)
                    )
                    answer_texts.append(answer_text)

    # NOTE(review): `stride` normally only matters when overflowing tokens are
    # returned; kept as-is to leave the tokenizer call unchanged — confirm.
    encodings = tokenizer(queries, passages_trunc, max_length=max_length, truncation=True,
                          padding=True, return_offsets_mapping=False, stride=128)

    start_positions, end_positions = [], []
    for idx, answer_text in enumerate(answer_texts):
        answer_tokens = tokenizer(answer_text, max_length=max_length, truncation=True, padding=True)
        # Drop the first and last ids, assumed to be the special tokens the
        # tokenizer wraps around a single text — TODO confirm for other tokenizers.
        answer_ids = answer_tokens['input_ids'][1:-1]
        start_pos, end_pos = _find_token_span(encodings['input_ids'][idx], answer_ids)
        start_positions.append(start_pos)
        end_positions.append(end_pos)

    encodings.update({'start_positions': start_positions, 'end_positions': end_positions})
    return encodings

def filter_data(text):
    """Normalise an answer string for token-level comparison.

    Lower-cases ``text``, replaces the standalone words "a", "an" and "the"
    with a space, deletes every character in ``string.punctuation``, and
    collapses all whitespace runs into single spaces.
    """
    lowered = text.lower()
    without_articles = re.sub(r'\b(a|an|the)\b', ' ', lowered)
    punctuation_class = f"[{re.escape(string.punctuation)}]"
    without_punctuation = re.sub(punctuation_class, "", without_articles)
    # split() with no argument drops leading/trailing and repeated whitespace.
    words = without_punctuation.split()
    return ' '.join(words)

def find_f1_score(predicted_answer, actual_answer):
    """Return the token-level F1 between a predicted and a reference answer.

    Both strings are normalised with ``filter_data`` and split on whitespace.
    Overlap is counted as a multiset intersection, so a repeated token only
    matches as many times as it occurs in both answers. Returns ``0.0`` when
    no token is shared.
    """
    pred_tokens = filter_data(predicted_answer).split()
    gold_tokens = filter_data(actual_answer).split()

    shared = Counter(pred_tokens) & Counter(gold_tokens)
    overlap = sum(shared.values())
    if not overlap > 0:
        return 0.0

    precision = overlap / len(pred_tokens)
    recall = overlap / len(gold_tokens)
    return 2 * precision * recall / (precision + recall)