In [1]:
import torch 
import json
from tqdm import tqdm
import numpy as np
import logging
from transformers import DistilBertTokenizerFast, DistilBertForQuestionAnswering, AdamW
from torch.utils.data import Dataset, DataLoader

# Configure logging at INFO level and create this module's logger
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Run on the GPU when CUDA is available, otherwise fall back to the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def load_dataset_from_file(filepath):
    """Read a SQuAD-style JSON file and flatten it into aligned lists.

    The file is expected to hold a top-level ``data`` list of sections, each
    with ``paragraphs`` that carry a ``context`` string and a ``qas`` list.

    Args:
        filepath: Path of the JSON file to read (decoded as UTF-8).

    Returns:
        A 6-tuple ``(total_questions, total_possible, total_impossible,
        context_list, question_list, answer_list)``. The three lists are
        index-aligned and hold one entry per *answer* (a question with
        several answers appears several times, one with none not at all).
        Contexts and questions are lower-cased; answers are the original
        answer objects from the file.

    Raises:
        FileNotFoundError: Logged and re-raised when ``filepath`` is missing.
    """
    try:
        with open(filepath, 'r', encoding='utf-8') as handle:
            payload = json.load(handle)
    except FileNotFoundError:
        logger.error(f"File not found: {filepath}")
        raise

    contexts, questions, answers = [], [], []
    question_count = 0
    impossible_count = 0

    for article in payload['data']:
        for para in article['paragraphs']:
            passage = para['context']
            for qa in para['qas']:
                prompt = qa['question']
                question_count += 1
                # A missing or falsy flag means the question is answerable
                if qa.get('is_impossible'):
                    impossible_count += 1
                for entry in qa.get('answers', []):
                    contexts.append(passage.lower())
                    questions.append(prompt.lower())
                    answers.append(entry)

    # Every question is either impossible or possible, never both
    possible_count = question_count - impossible_count
    return question_count, possible_count, impossible_count, contexts, questions, answers


def calculate_end_positions(answer_data, context_data):
    """Annotate each answer dict, in place, with an ``answer_end`` offset.

    ``answer_end`` is ``answer_start`` plus the length of the lower-cased
    answer text, i.e. the offset one past the last answer character. A
    missing ``answer_start`` counts as -1 and a missing ``text`` as ``''``.

    Args:
        answer_data: Iterable of answer dicts; each one is mutated.
        context_data: Iterable paired with ``answer_data``. Its values are
            not read; it only bounds the iteration, which stops at the
            shorter of the two inputs.

    Returns:
        None. The results are written into the answer dicts.
    """
    for record, _ in zip(answer_data, context_data):
        span_length = len(record.get('text', '').lower())
        record['answer_end'] = record.get('answer_start', -1) + span_length


# Load datasets
try:
    (train_total, train_possible, train_impossible,
     train_contexts, train_questions, train_answers) = load_dataset_from_file('spoken_train-v1.1.json')
    (valid_total, valid_possible, valid_impossible,
     valid_contexts, valid_questions, valid_answers) = load_dataset_from_file('spoken_test-v1.1.json')
except Exception as e:
    logger.error(f"Error loading data: {e}")
    # Stop with a non-zero status so callers can tell the run failed.
    # The bare exit() helper is injected by the `site` module for interactive
    # use and would report success (status 0) here.
    raise SystemExit(1)

# Apply answer end position calculation (adds 'answer_end' to every answer dict)
calculate_end_positions(train_answers, train_contexts)
calculate_end_positions(valid_answers, valid_contexts)

MAX_SEQ_LEN = 512
MODEL_PATH = "distilbert-base-uncased"

# Initialize tokenizer
tokenizer = DistilBertTokenizerFast.from_pretrained(MODEL_PATH)

# Tokenize training and validation sets as (question, context) pairs;
# inputs longer than MAX_SEQ_LEN are truncated and shorter ones padded
train_encodings = tokenizer(train_questions, train_contexts, max_length=MAX_SEQ_LEN, padding=True, truncation=True)
valid_encodings = tokenizer(valid_questions, valid_contexts, max_length=MAX_SEQ_LEN, padding=True, truncation=True)

# Custom Dataset class
class QAData(Dataset):
    """Torch dataset pairing tokenized QA inputs with answer token spans.

    The answer dicts carry *character* offsets into the context
    (``answer_start`` inclusive, ``answer_end`` exclusive), while the QA
    model's ``start_positions`` / ``end_positions`` are *token* indices into
    ``input_ids``. This class converts between the two with the encodings'
    ``char_to_token`` method.

    Assumes the encodings were produced by tokenizing ``(question, context)``
    pairs with a fast tokenizer, so that the context is sequence 1 of every
    pair and ``char_to_token`` is available.
    """

    def __init__(self, encodings, answers):
        """Store the batch encodings and the per-example answer dicts.

        Args:
            encodings: Batch encoding exposing ``items()``, ``input_ids`` and
                ``char_to_token(batch_index, char_index, sequence_index=...)``.
            answers: Sequence of answer dicts aligned with ``encodings``.
        """
        self.encodings = encodings
        self.answers = answers

    def _token_span(self, idx):
        """Return ``(start_token, end_token)`` for example ``idx``.

        Falls back to ``(0, 0)`` -- the first token of the sequence -- when
        the answer has no usable character span or when either end cannot be
        mapped to a token (for instance because truncation removed it).
        """
        answer = self.answers[idx]
        start_char = answer.get('answer_start', -1)
        end_char = answer.get('answer_end', -1)
        if start_char < 0 or end_char <= start_char:
            return 0, 0

        # sequence_index=1 selects the context half of the (question, context) pair
        start_token = self.encodings.char_to_token(idx, start_char, sequence_index=1)
        # answer_end is exclusive, so the last answer character sits at end_char - 1
        end_token = self.encodings.char_to_token(idx, end_char - 1, sequence_index=1)
        if start_token is None or end_token is None:
            return 0, 0
        return start_token, end_token

    def __getitem__(self, idx):
        """Return the tensors for one example, including its token-level span."""
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        start_token, end_token = self._token_span(idx)
        item['start_positions'] = torch.tensor(start_token)
        item['end_positions'] = torch.tensor(end_token)
        return item

    def __len__(self):
        """Return the number of encoded examples."""
        return len(self.encodings.input_ids)

# Create dataset objects
train_dataset = QAData(train_encodings, train_answers)
valid_dataset = QAData(valid_encodings, valid_answers)

# Create DataLoaders for batching
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=1)

# Load pre-trained model and move to device
model = DistilBertForQuestionAnswering.from_pretrained(MODEL_PATH).to(device)

# Define optimizer.
# Use PyTorch's AdamW: the AdamW class shipped with transformers is deprecated
# in favour of it. eps and weight_decay are pinned to the values the
# transformers class used by default, so the hyperparameters are unchanged.
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5, eps=1e-6, weight_decay=0.0)

def train_one_epoch(model, data_loader, optimizer):
    """Run one training pass over ``data_loader`` and return the mean loss.

    Args:
        model: Model called with ``input_ids``, ``attention_mask``,
            ``start_positions`` and ``end_positions``; its output must
            expose a ``loss`` attribute.
        data_loader: Iterable of batches holding those four keys.
        optimizer: Optimizer stepped once per batch.

    Returns:
        The sum of the per-batch losses divided by ``len(data_loader)``.
    """
    model.train()
    tensor_keys = ('input_ids', 'attention_mask', 'start_positions', 'end_positions')
    running_loss = 0.0

    for batch in tqdm(data_loader, desc='Training'):
        # Clear gradients left over from the previous batch
        optimizer.zero_grad()

        # Move exactly the tensors the model consumes onto the target device
        inputs = {key: batch[key].to(device) for key in tensor_keys}
        loss = model(**inputs).loss
        running_loss += loss.item()

        loss.backward()
        optimizer.step()

    batch_count = len(data_loader)
    return running_loss / batch_count

def _token_f1(predicted_answer, true_answer):
    """Return the token-level F1 of a prediction, or None without a reference.

    Both strings are split on whitespace. Overlap is counted as a multiset,
    so a token repeated in the prediction is only credited as often as it
    occurs in the reference.
    """
    from collections import Counter

    predicted_tokens = predicted_answer.strip().split()
    true_tokens = true_answer.strip().split()

    if not true_tokens:
        # Nothing to compare against; the caller skips this example
        return None
    if not predicted_tokens:
        # An empty prediction for a non-empty reference is a plain miss
        return 0.0

    overlap = sum((Counter(predicted_tokens) & Counter(true_tokens)).values())
    if overlap == 0:
        return 0.0

    precision = overlap / len(predicted_tokens)
    recall = overlap / len(true_tokens)
    return 2 * (precision * recall) / (precision + recall)


def evaluate_on_validation_set(model, data_loader):
    """Evaluate the model on the validation set and return the mean token F1.

    For every example the predicted span is the argmax of the start and end
    logits; the reference span is read from ``start_positions`` and
    ``end_positions``, which are used as token indices into ``input_ids``.
    Both spans are decoded to text and compared token by token.

    Examples whose reference span decodes to no tokens are skipped. Examples
    whose prediction decodes to no tokens (for instance when the predicted
    end precedes the predicted start) count as F1 = 0 rather than being
    dropped, so failed predictions lower the reported score.

    Args:
        model: Model whose output exposes ``start_logits`` and ``end_logits``.
        data_loader: Iterable of batches with ``input_ids``,
            ``attention_mask``, ``start_positions`` and ``end_positions``.

    Returns:
        The mean F1 over all scored examples, or 0.0 when none were scored.
    """
    model.eval()
    f1_scores = []

    for batch in tqdm(data_loader, desc='Evaluating'):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        start_true = batch['start_positions'].to(device)
        end_true = batch['end_positions'].to(device)

        with torch.no_grad():
            outputs = model(input_ids=input_ids, attention_mask=attention_mask)

        start_pred = torch.argmax(outputs.start_logits, dim=1)
        end_pred = torch.argmax(outputs.end_logits, dim=1)

        for ids, s_pred, e_pred, s_true, e_true in zip(input_ids, start_pred, end_pred, start_true, end_true):
            # Spans are inclusive of their end token, hence the + 1
            predicted_answer = tokenizer.decode(ids[s_pred:e_pred + 1], skip_special_tokens=True)
            true_answer = tokenizer.decode(ids[s_true:e_true + 1], skip_special_tokens=True)

            f1 = _token_f1(predicted_answer, true_answer)
            if f1 is not None:
                f1_scores.append(f1)

    return np.mean(f1_scores) if f1_scores else 0.0

# Training loop: train, validate, and stop early once the validation F1 has
# failed to improve for `patience` consecutive epochs
EPOCHS = 5
best_f1 = 0.0
patience = 3
early_stop_count = 0

for epoch in range(EPOCHS):
    train_loss = train_one_epoch(model, train_loader, optimizer)
    val_f1 = evaluate_on_validation_set(model, valid_loader)
    logger.info(f"Epoch {epoch + 1}/{EPOCHS}\n Train Loss: {train_loss:.4f}, F1 Score: {val_f1:.4f}")

    if val_f1 > best_f1:
        # New best score: remember it and reset the stall counter
        best_f1 = val_f1
        early_stop_count = 0
        continue

    # No improvement this epoch
    early_stop_count += 1
    if early_stop_count >= patience:
        logger.info("Early stopping activated!")
        break


  warn(
Some weights of DistilBertForQuestionAnswering were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['qa_outputs.bias', 'qa_outputs.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
Training: 100%|██████████| 4639/4639 [02:48<00:00, 27.58it/s]
Evaluating: 100%|██████████| 15875/15875 [00:52<00:00, 303.22it/