In [2]:
# Import Libraries
import os
from torch.utils.data import DataLoader, Dataset
from transformers import (
    DistilBertForSequenceClassification,
    BertForQuestionAnswering,
    AdamW,
    BertForSequenceClassification,
    AutoTokenizer
)
import torch
from torch.nn.utils.rnn import pad_sequence
from tqdm import tqdm
from sklearn.metrics import classification_report, accuracy_score
import gzip
import json
import random
from collections import defaultdict, Counter
from bs4 import BeautifulSoup
from sklearn.model_selection import train_test_split
import shutil
from torch.cuda.amp import GradScaler, autocast

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")


Using device: cuda


In [328]:
# Function to clear up
# "balanced_yes_no_data.json", "test_data.json", "train_data.json", "val_data.json", "fine_tuned_bert_yes_no_model"
paths_to_clear = ["fine_tuned_bert_short_answer_model", "balanced_yes_no_data.json", "test_data.json", "train_data.json", "val_data.json", "fine_tuned_bert_long_answer_model"]

for path in paths_to_clear:
    if os.path.exists(path):
        if os.path.isdir(path):
            shutil.rmtree(path)
            print(f"Deleted directory: {path}")
        elif os.path.isfile(path):
            os.remove(path)
            print(f"Deleted file: {path}")
    else:
        print(f"Path not found: {path}")

Path not found: fine_tuned_bert_short_answer_model
Deleted file: balanced_yes_no_data.json
Path not found: test_data.json
Path not found: train_data.json
Path not found: val_data.json
Path not found: fine_tuned_bert_long_answer_model


Here is the code section for processing Yes/No questions: This part handles tasks related to binary classification, such as determining "YES" or "NO" answers based on input data.






In [3]:
# Function to Clean HTML
def clean_html(text):
    return BeautifulSoup(text, "html.parser").get_text()
print("clean_html function defined.")

# Function to Process Files for yes or no qa
def process_file(file_path, output_data):
    oversized_long_answer_count = 0 
    excluded_examples = 0
    max_context_length = 400

    with gzip.open(file_path, 'rt', encoding='utf-8') as file:
        for i, line in enumerate(file):
            example = json.loads(line)
            question = example['question_text']
            document = example['document_text']
            annotations = example.get('annotations', [])

            if annotations and annotations[0]['yes_no_answer'] != "NONE":
                yes_no_answer = annotations[0]['yes_no_answer']
                long_answer = annotations[0].get('long_answer', {})
                short_answer = annotations[0].get('short_answers', [])
                start_token = long_answer.get("start_token")
                end_token = long_answer.get("end_token")

                if start_token is not None and end_token is not None:
                    span_length = end_token - start_token
                    if span_length > 1000:
                        oversized_long_answer_count += 1

                        if short_answer:
                            end_token = short_answer[0]["end_token"]
                            print(f"Using short answer for oversized long answer. Question: {question}")
                        else:
                            excluded_examples += 1
                            print(f"Excluding example. Question: {question}, Span Length: {span_length}")
                            continue

                cleaned_document = clean_html(document)

                context = " ".join(cleaned_document.split()[start_token:end_token]) if start_token and end_token else cleaned_document
                if len(context.split()) > max_context_length:
                    excluded_examples += 1
                    print(f"Excluding example due to context length > {max_context_length}. Question: {question}")
                    continue

                output_data["yes_no"].append({
                    "type": "yes_no",
                    "question": question,
                    "document": context,
                    "yes_no_answer": yes_no_answer
                })

    print(f"File {file_path} processed.")
    print(f"  Oversized long answers handled: {oversized_long_answer_count}")
    print(f"  Excluded examples: {excluded_examples}")

print("process_file function defined.")


clean_html function defined.
process_file function defined.


In [4]:
# Process Dataset and Balance Yes/No Data
input_folder = "simplified_natural_questions"
output_data = defaultdict(list)

all_files = [file for file in os.listdir(input_folder) if file.endswith(".jsonl.gz")]
print(f"Found {len(all_files)} files. Processing one file for inspection...")

num_files_to_process = 50
files_to_process = all_files[:num_files_to_process]

for file_name in files_to_process:
    file_path = os.path.join(input_folder, file_name)
    print(f"Processing file: {file_path}")
    process_file(file_path, output_data)

# Balance yes/no data
yes_no_data = output_data["yes_no"]
yes_data = [item for item in yes_no_data if item["yes_no_answer"] == "YES"]
no_data = [item for item in yes_no_data if item["yes_no_answer"] == "NO"]

# Find the smaller count between "YES" and "NO" and sample accordingly
min_count = min(len(yes_data), len(no_data))
balanced_yes_no_data = random.sample(yes_data, min_count) + random.sample(no_data, min_count)
random.shuffle(balanced_yes_no_data)

# Save only the balanced yes_no data
output_file = "balanced_yes_no_data.json"
with open(output_file, "w", encoding="utf-8") as f:
    json.dump(balanced_yes_no_data, f, indent=4)

print(f"Processing complete. Balanced yes/no data saved to {output_file}")

# Validation after processing
with open(output_file, "r", encoding="utf-8") as f:
    data = json.load(f)

type_counts = Counter([item["yes_no_answer"] for item in data])
print("Counts of each question type in the dataset:")
print(type_counts)


Found 50 files. Processing one file for inspection...
Processing file: simplified_natural_questions/nq-train-06_simplified.jsonl.gz
Excluding example. Question: do canadian citizens need a visa for myanmar, Span Length: 4419
Excluding example. Question: do i need a visa for namibia from australia, Span Length: 4322
Excluding example. Question: wii sports resort do you need motion plus, Span Length: 1436
Excluding example. Question: does frank underwood become president in season 5, Span Length: 4324
Excluding example. Question: in shameless do they get the house back, Span Length: 1837
Excluding example. Question: do french citizens need a visa for south africa, Span Length: 4316
Excluding example. Question: is it illegal to own a gun in california, Span Length: 2071
Excluding example. Question: do i need a visa to go to uk from malaysia, Span Length: 3485
Excluding example due to context length > 400. Question: are protestant and church of england the same
File simplified_natural_ques

Here is the code section for processing Short Answer questions: This part processes span-based questions where a brief, specific answer is extracted from the context.








In [327]:
# Extract Context Window
def extract_context(document, start_token, end_token, window_size=50):
    tokens = document.split(" ")
    doc_length = len(tokens)

    if doc_length == 0:
        raise ValueError("Document is empty.")

    # Define window boundaries with some randomness
    max_pre_shift = max(0, start_token - 1)
    max_post_shift = max(0, doc_length - end_token - 1)
    random_shift = random.randint(-min(max_pre_shift, window_size), min(max_post_shift, window_size))

    pre_start = max(0, start_token - window_size + random_shift)
    post_end = min(doc_length, end_token + window_size + random_shift)

    # Create the randomized context window
    context_window = " ".join(tokens[pre_start:post_end])

    if not context_window.strip():
        raise ValueError("Extracted context window is empty.")

    # Adjust start and end tokens
    adjusted_start = start_token - pre_start
    adjusted_end = end_token - pre_start

    return context_window, adjusted_start, adjusted_end

print("extract_context function defined.")

clean_html function defined.
extract_context function defined.


In [None]:
# Clean HTML and Adjust Spans
def clean_html_and_adjust_spans(context, start_token, end_token):
    soup = BeautifulSoup(context, "html.parser")
    plain_text = soup.get_text(separator=" ")

    # Split tokens before and after cleaning
    original_tokens = context.split(" ")
    cleaned_tokens = plain_text.split(" ")

    if start_token >= len(original_tokens) or end_token > len(original_tokens):
        raise ValueError("Start or end token indices are out of bounds.")

    # Map original tokens to cleaned tokens
    token_map = {}
    clean_index = 0
    for i, token in enumerate(original_tokens):
        if clean_index < len(cleaned_tokens) and token == cleaned_tokens[clean_index]:
            token_map[i] = clean_index
            clean_index += 1

    adjusted_start = token_map.get(start_token, -1)
    adjusted_end = token_map.get(end_token - 1, -1) + 1  # +1 for exclusive range

    if adjusted_start == -1 or adjusted_end == -1:
        print(f"Failed to map tokens for Context: {context}")
        raise ValueError("Failed to map token indices after HTML cleaning.")

    return plain_text, adjusted_start, adjusted_end

print("clean_html_and_adjust_spans function defined.")


In [None]:
# Process Short Answers
def process_short_answers(file_path, output_data, max_short_answers=2000, window_size=50):
    try:
        with gzip.open(file_path, 'rt', encoding='utf-8') as file:
            for line in file:
                if len(output_data) >= max_short_answers:
                    break

                example = json.loads(line)
                question = example.get('question_text', "")
                document = example.get('document_text', "")
                annotations = example.get('annotations', [])

                if not document.strip():
                    print("Skipping example due to empty document.")
                    continue

                for annotation in annotations:
                    short_answers = annotation.get('short_answers', [])
                    for sa in short_answers:
                        start_token = sa.get('start_token')
                        end_token = sa.get('end_token')

                        if start_token is None or end_token is None or start_token >= len(document.split(" ")):
                            print(f"Skipping invalid span: Start: {start_token}, End: {end_token}")
                            continue

                        try:
                            context_window, adjusted_start, adjusted_end = extract_context(
                                document, start_token, end_token, window_size
                            )
                            cleaned_context, cleaned_start, cleaned_end = clean_html_and_adjust_spans(
                                context_window, adjusted_start, adjusted_end
                            )

                            output_data.append({
                                "type": "short_answer",
                                "question": question,
                                "document": cleaned_context,
                                "start_token": cleaned_start,
                                "end_token": cleaned_end
                            })
                        except ValueError as e:
                            print(f"Skipping example due to error: {e}")
    except Exception as e:
        print(f"Error processing file {file_path}: {e}")

print("process_short_answers function defined.")

In [None]:
# Here we run the process_short_answers

input_folder = "Newsimplified_natural_questions"
short_answer_data = []

all_files = [file for file in os.listdir(input_folder) if file.endswith(".jsonl.gz")]

num_files_to_process = 50
files_to_process = all_files[:num_files_to_process]

for file_name in files_to_process:
    if len(short_answer_data) >= 2000:
        break
    file_path = os.path.join(input_folder, file_name)
    process_short_answers(file_path, short_answer_data, max_short_answers=2000)

# Save processed data
output_file = "short_answers_data.json"
with open(output_file, "w", encoding="utf-8") as f:
    json.dump(short_answer_data, f, indent=4)

print(f"Processing complete. Short answers with context saved to {output_file}.")


Here is the code section for processing Long Answer questions: This section deals with span-based tasks requiring the extraction of extended context as the answer.








In [None]:
# Count HTML Tags
def count_html_tags(context_window):
    soup = BeautifulSoup(context_window, "html.parser")
    return len(soup.find_all())

print("count_html_tags function defined.")

# Extract Context with Randomization
def extract_context_with_randomization(document, start_token, end_token, window_size=100):
    tokens = document.split(" ")
    doc_length = len(tokens)

    # Define the boundaries with random shifts
    max_pre_shift = max(0, start_token - 1)
    max_post_shift = max(0, doc_length - end_token - 1)
    random_shift = random.randint(-min(max_pre_shift, window_size), min(max_post_shift, window_size))

    pre_start = max(0, start_token - window_size + random_shift)
    post_end = min(doc_length, end_token + window_size + random_shift)

    # Create the context window
    context_window = " ".join(tokens[pre_start:post_end])

    # Adjust start and end tokens relative to the new context window
    adjusted_start = start_token - pre_start
    adjusted_end = end_token - pre_start

    return context_window, adjusted_start, adjusted_end

print("extract_context_with_randomization function defined.")


In [None]:
# Process Long Answers
def process_long_answers(file_path, output_data, max_long_answers=2000, window_size=100, max_html_tags=5):
    try:
        with gzip.open(file_path, 'rt', encoding='utf-8') as file:
            for line in file:
                if len(output_data) >= max_long_answers:
                    break

                # Parse JSON line
                example = json.loads(line)
                question = example.get('question_text', "")
                document = example.get('document_text', "")
                annotations = example.get('annotations', [])

                # Process each annotation
                for annotation in annotations:
                    long_answer = annotation.get('long_answer', {})
                    start_token = long_answer.get('start_token')
                    end_token = long_answer.get('end_token')

                    # Skip if tokens are invalid or exceed the threshold
                    if (
                        start_token is None or end_token is None or
                        start_token == -1 or end_token == -1 or
                        (end_token - start_token) >= 100
                    ):
                        continue

                    # Extract the randomized context window
                    try:
                        context_window, adjusted_start, adjusted_end = extract_context_with_randomization(
                            document, start_token, end_token, window_size
                        )
                    except ValueError as e:
                        print(f"Skipping example due to context extraction error: {e}")
                        continue

                    # Skip examples with too many HTML tags
                    if count_html_tags(context_window) > max_html_tags:
                        continue

                    # Append processed example
                    output_data.append({
                        "type": "long_answer",
                        "question": question,
                        "document": context_window,
                        "start_token": adjusted_start,
                        "end_token": adjusted_end,
                    })
    except Exception as e:
        print(f"Error processing file {file_path}: {e}")

print("process_long_answers function defined.")


In [None]:
# Here we run the process_long_answers

input_folder = "Newsimplified_natural_questions"
long_answer_data = []

all_files = [file for file in os.listdir(input_folder) if file.endswith(".jsonl.gz")]

num_files_to_process = 50 
files_to_process = all_files[:num_files_to_process]

for file_name in files_to_process:
    if len(long_answer_data) >= 2000:
        break
    file_path = os.path.join(input_folder, file_name)
    process_long_answers(file_path, long_answer_data, max_long_answers=2000)

output_file = "long_answers_data.json"
with open(output_file, "w", encoding="utf-8") as f:
    json.dump(long_answer_data, f, indent=4)

print(f"Processing complete. Extracted long answers saved to {output_file}.")

After running the chosen QA task, we can proceed with the following training, validation, and testing steps.





In [5]:
# Define File Paths and Task Type
tokenized_data_path = "balanced_yes_no_data.json" # Options: "balanced_yes_no_data.json", "short_answers_data.json", "long_answers_data.json"
train_path = "train_data.json"
val_path = "val_data.json"
test_path = "test_data.json"

# Task type
task_type = "yes_no"  # Options: "yes_no", "short_answer", "long_answer"
print(f"Task type: {task_type}")


Task type: yes_no


In [6]:
# Split Dataset Function
def split_dataset(tokenized_data_path, train_path, val_path, test_path, task_type=None, test_size=0.3, random_state=42):
    with open(tokenized_data_path, "r", encoding="utf-8") as f:
        tokenized_data = json.load(f)

    # Filter by task_type
    if task_type:
        tokenized_data = [item for item in tokenized_data if item["type"] == task_type]

    # Split into train, validation, and test sets
    train_data, temp_data = train_test_split(tokenized_data, test_size=test_size, random_state=random_state)
    val_data, test_data = train_test_split(temp_data, test_size=0.5, random_state=random_state)

    # Save splits
    with open(train_path, "w", encoding="utf-8") as f:
        json.dump(train_data, f, indent=4)
    with open(val_path, "w", encoding="utf-8") as f:
        json.dump(val_data, f, indent=4)
    with open(test_path, "w", encoding="utf-8") as f:
        json.dump(test_data, f, indent=4)

    print("Dataset split into train, validation, and test sets.")

print("Splitting the dataset...")
split_dataset(tokenized_data_path, train_path, val_path, test_path, task_type=task_type)


Splitting the dataset...
Dataset split into train, validation, and test sets.


In [7]:
# Define collate_fn Function for DataLoader

def collate_fn(batch):
    input_ids = [item["input_ids"] for item in batch]
    attention_mask = [item["attention_mask"] for item in batch]
    types = [item["type"] for item in batch]

    # Initialize empty lists for optional fields
    labels = []
    start_positions = []
    end_positions = []

    for item in batch:
        if item["type"] == "yes_no":
            labels.append(item["label"])
        elif item["type"] in ["long_answer", "short_answer"]:
            start_positions.append(item["start_positions"])
            end_positions.append(item["end_positions"])

    collated_data = {
        "input_ids": pad_sequence(input_ids, batch_first=True, padding_value=0),
        "attention_mask": pad_sequence(attention_mask, batch_first=True, padding_value=0),
        "types": types,
    }

    if labels:
        collated_data["labels"] = torch.tensor(labels, dtype=torch.long)
    if start_positions and end_positions:
        collated_data["start_positions"] = torch.tensor(start_positions, dtype=torch.long)
        collated_data["end_positions"] = torch.tensor(end_positions, dtype=torch.long)

    return collated_data

print("collate_fn function defined.")


collate_fn function defined.


In [10]:
# Define QADataset Class
class QADataset(Dataset):
    def __init__(self, file_path, task_type="long_answer", tokenizer_name="bert-base-uncased"):
        with open(file_path, "r", encoding="utf-8") as f:
            self.data = [item for item in json.load(f) if item["type"] == task_type]
        self.task_type = task_type
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]

        # Tokenize with offset mapping
        tokenized = self.tokenizer(
            item["question"],
            item["document"],
            truncation=True,
            max_length=512,
            padding="max_length",
            return_tensors="pt",
            return_offsets_mapping=(self.task_type in ["short_answer", "long_answer"])
        )

        input_ids = tokenized["input_ids"].squeeze(0).tolist()
        attention_mask = tokenized["attention_mask"].squeeze(0).tolist()

        # Prepare output based on task type
        output = {
            "input_ids": torch.tensor(input_ids, dtype=torch.long),
            "attention_mask": torch.tensor(attention_mask, dtype=torch.long),
            "type": item["type"],
        }

        if self.task_type == "yes_no":
            output["label"] = torch.tensor(1 if item.get("yes_no_answer") == "YES" else 0, dtype=torch.long)
        elif self.task_type in ["short_answer", "long_answer"]:
            offset_mapping = tokenized["offset_mapping"].squeeze(0).tolist()
            sep_index = input_ids.index(self.tokenizer.sep_token_id) + 1

            # Adjust start and end positions using offset mapping (only for span-based tasks)
            tokenized_start, tokenized_end = None, None
            char_start = len(" ".join(item["document"].split()[:item["start_token"]])) + 1
            char_end = len(" ".join(item["document"].split()[:item["end_token"]]))

            for idx, (start_offset, end_offset) in enumerate(offset_mapping[sep_index:], start=sep_index):
                if start_offset <= char_start < end_offset:
                    tokenized_start = idx
                if start_offset < char_end <= end_offset:
                    tokenized_end = idx

            if tokenized_start is None or tokenized_end is None:
                # Attempt to map span using a more relaxed approach
                for idx, (start_offset, end_offset) in enumerate(offset_mapping):
                    if start_offset <= char_start <= end_offset:
                        tokenized_start = idx
                    if start_offset <= char_end <= end_offset:
                        tokenized_end = idx

                if tokenized_start is None or tokenized_end is None:
                    print(f"Failed to map tokens for example {idx}.")
                    print(f"Question: {item['question']}")
                    print(f"Document: {item['document']}")
                    print(f"Start Token: {item['start_token']}, End Token: {item['end_token']}")
                    print(f"Character Start: {char_start}, Character End: {char_end}")
                    print(f"Offset Mapping: {offset_mapping}")
                    raise ValueError(f"Failed to map tokens for example {idx}.")

            output["start_positions"] = torch.tensor(tokenized_start, dtype=torch.long)
            output["end_positions"] = torch.tensor(tokenized_end, dtype=torch.long)

        return output

print("QADataset class defined.")


QADataset class defined.


In [23]:
# Here is the important paramters:
batch_size = 8
learning_rate = 4e-5
num_epochs = 5

In [24]:
# Data Loaders
print("Creating data loaders...")
train_dataset = QADataset(train_path, task_type=task_type)
val_dataset = QADataset(val_path, task_type=task_type)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(val_dataset, batch_size=batch_size, collate_fn=collate_fn)

# Model and Optimizer
print("Initializing model and optimizer...")

if task_type == "yes_no":
    model = DistilBertForSequenceClassification.from_pretrained("distilbert-base-uncased", num_labels=2)
elif task_type in ["short_answer", "long_answer"]:
    model = BertForQuestionAnswering.from_pretrained("bert-base-uncased")
    print("Loaded BertForQuestionAnswering for span-based tasks.")
else:
    raise ValueError(f"Unknown task type: {task_type}")

model.to(device)
optimizer = AdamW(model.parameters(), lr=learning_rate)


Creating data loaders...
Initializing model and optimizer...


Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [25]:
# Training Loop
print("Starting training...")

for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    yes_correct = 0
    no_correct = 0
    yes_total = 0
    no_total = 0

    for batch in tqdm(train_loader, desc=f"Training Epoch {epoch + 1}/{num_epochs}"):
        optimizer.zero_grad()
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)

        if task_type == "yes_no":
            labels = batch["labels"].to(device)
            outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
            loss = outputs.loss

            # Predictions for accuracy
            predictions = torch.argmax(outputs.logits, dim=1)
            yes_correct += ((predictions == 1) & (labels == 1)).sum().item()
            no_correct += ((predictions == 0) & (labels == 0)).sum().item()
            yes_total += (labels == 1).sum().item()
            no_total += (labels == 0).sum().item()

        elif task_type in ["short_answer", "long_answer"]:
            start_positions = batch["start_positions"].to(device)
            end_positions = batch["end_positions"].to(device)
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                start_positions=start_positions,
                end_positions=end_positions,
            )
            loss = outputs.loss

        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    print(f"Epoch {epoch + 1} - Loss: {total_loss / len(train_loader)}")

    if task_type == "yes_no":
        yes_accuracy = yes_correct / yes_total if yes_total > 0 else 0
        no_accuracy = no_correct / no_total if no_total > 0 else 0
        print(f"Training YES Accuracy: {yes_accuracy:.4f}")
        print(f"Training NO Accuracy: {no_accuracy:.4f}")


Starting training...


Training Epoch 1/5: 100%|██████████| 217/217 [00:26<00:00,  8.27it/s]


Epoch 1 - Loss: 0.6984546407027179
Training YES Accuracy: 0.4954
Training NO Accuracy: 0.4786


Training Epoch 2/5: 100%|██████████| 217/217 [00:26<00:00,  8.24it/s]


Epoch 2 - Loss: 0.6833275133563627
Training YES Accuracy: 0.5358
Training NO Accuracy: 0.5898


Training Epoch 3/5: 100%|██████████| 217/217 [00:26<00:00,  8.25it/s]


Epoch 3 - Loss: 0.5812813378698815
Training YES Accuracy: 0.6859
Training NO Accuracy: 0.7161


Training Epoch 4/5: 100%|██████████| 217/217 [00:26<00:00,  8.22it/s]


Epoch 4 - Loss: 0.27626121874194814
Training YES Accuracy: 0.8961
Training NO Accuracy: 0.8992


Training Epoch 5/5: 100%|██████████| 217/217 [00:26<00:00,  8.19it/s]

Epoch 5 - Loss: 0.09752215037963563
Training YES Accuracy: 0.9654
Training NO Accuracy: 0.9664





In [26]:
# Validation Loop
print("Starting validation...")

model.eval()
span_metrics = {"exact_match": 0, "f1": 0}
total_span_examples = 0
all_labels, all_predictions = [], []

with torch.no_grad():
    for batch in tqdm(val_loader, desc="Evaluating"):
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)

        if task_type == "yes_no":
            labels = batch["labels"].to(device)
            outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
            predictions = torch.argmax(outputs.logits, dim=1)

            all_labels.extend(labels.cpu().numpy())
            all_predictions.extend(predictions.cpu().numpy())

        elif task_type in ["short_answer", "long_answer"]:
            start_positions = batch["start_positions"].to(device)
            end_positions = batch["end_positions"].to(device)
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                start_positions=start_positions,
                end_positions=end_positions
            )
            start_preds = torch.argmax(outputs.start_logits, dim=1)
            end_preds = torch.argmax(outputs.end_logits, dim=1)

            for i in range(len(start_positions)):
                pred_span = set(range(start_preds[i].item(), end_preds[i].item() + 1))
                true_span = set(range(start_positions[i].item(), end_positions[i].item() + 1))
                if pred_span == true_span:
                    span_metrics["exact_match"] += 1
                intersection = len(pred_span & true_span)
                precision = intersection / len(pred_span) if pred_span else 0
                recall = intersection / len(true_span) if true_span else 0
                f1 = 2 * precision * recall / (precision + recall) if precision + recall > 0 else 0
                span_metrics["f1"] += f1

            total_span_examples += len(start_positions)

if task_type == "yes_no":
    yes_accuracy = accuracy_score(all_labels, all_predictions)
    report = classification_report(all_labels, all_predictions, target_names=["NO", "YES"])
    print(f"Validation Accuracy: {yes_accuracy:.4f}")
    print(report)

if total_span_examples > 0:
    span_metrics["exact_match"] /= total_span_examples
    span_metrics["f1"] /= total_span_examples
    print(f"Validation Exact Match (EM): {span_metrics['exact_match']:.4f}")
    print(f"Validation F1 Score: {span_metrics['f1']:.4f}")


Starting validation...


Evaluating: 100%|██████████| 47/47 [00:01<00:00, 23.78it/s]

Validation Accuracy: 0.5676
              precision    recall  f1-score   support

          NO       0.57      0.37      0.44       175
         YES       0.57      0.75      0.65       195

    accuracy                           0.57       370
   macro avg       0.57      0.56      0.55       370
weighted avg       0.57      0.57      0.55       370






In [30]:
# Load Fine-Tuned Model and Tokenizer
model_path = f"./fine_tuned_bert_{task_type}_model"
tokenizer_path = model_path

if task_type == "yes_no":
    model = DistilBertForSequenceClassification.from_pretrained(model_path)
elif task_type in ["short_answer", "long_answer"]:
    model = BertForQuestionAnswering.from_pretrained(model_path)
else:
    raise ValueError(f"Unsupported task type: {task_type}")

tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
model.to(device)
model.eval()

print(f"Loaded model and tokenizer for task type: {task_type}")


Loaded model and tokenizer for task type: yes_no


In [27]:
# Save Model and Tokenizer
output_dir = f"./fine_tuned_bert_{task_type}_model"
os.makedirs(output_dir, exist_ok=True)
model.save_pretrained(output_dir)
train_dataset.tokenizer.save_pretrained(output_dir)
print(f"Model saved to {output_dir}.")


Model saved to ./fine_tuned_bert_yes_no_model.


Here is tests






In [28]:
# Load Test Data
test_path = "test_data.json"
with open(test_path, "r", encoding="utf-8") as f:
    test_data = json.load(f)

print(f"Loaded test data from {test_path}. Total examples: {len(test_data)}")

# Define Helper Functions
def compute_exact_match(pred_start, pred_end, true_start, true_end):
    return (pred_start == true_start) and (pred_end == true_end)

def compute_f1(pred_start, pred_end, true_start, true_end):
    pred_set = set(range(pred_start, pred_end + 1))
    true_set = set(range(true_start, true_end + 1))
    intersection = pred_set & true_set
    if not pred_set or not true_set:
        return 1 if pred_set == true_set else 0
    precision = len(intersection) / len(pred_set)
    recall = len(intersection) / len(true_set)
    return 2 * (precision * recall) / (precision + recall) if precision + recall > 0 else 0

print("Helper functions defined.")


Loaded test data from test_data.json. Total examples: 371
Helper functions defined.


In [31]:
# Automatic Testing
all_predictions = []
all_ground_truths = []
span_metrics = {"exact_match": 0, "f1": 0}
total_span_examples = 0

for item in tqdm(test_data, desc="Testing"):
    question = item["question"]
    context = item["document"] if "document" in item else item["context"]

    if task_type == "yes_no":
        ground_truth = item["yes_no_answer"]
        inputs = tokenizer(
            question, context, return_tensors="pt", truncation=True, max_length=512, padding="longest"
        ).to(device)

        # Remove `token_type_ids` for DistilBERT compatibility
        if "token_type_ids" in inputs:
            del inputs["token_type_ids"]

        # Perform inference
        with torch.no_grad():
            outputs = model(**inputs)
            logits = outputs.logits
            probabilities = torch.softmax(logits, dim=-1)
            predicted_class = torch.argmax(probabilities, dim=-1).item()

        # Map predicted class to "YES" or "NO"
        predicted_answer = "YES" if predicted_class == 1 else "NO"
        all_predictions.append(predicted_answer)
        all_ground_truths.append(ground_truth)

    elif task_type in ["short_answer", "long_answer"]:
        true_start = item["start_token"]
        true_end = item["end_token"]
        tokenized = tokenizer(
            question, context, truncation=True, max_length=512, padding="max_length", return_tensors="pt", return_offsets_mapping=True
        ).to(device)

        input_ids = tokenized["input_ids"]
        attention_mask = tokenized["attention_mask"]
        offset_mapping = tokenized["offset_mapping"].squeeze(0).tolist()
        sep_index = input_ids.squeeze(0).tolist().index(tokenizer.sep_token_id) + 1
        char_start = len(" ".join(context.split()[:true_start])) + 1
        char_end = len(" ".join(context.split()[:true_end]))
        pred_start, pred_end = None, None

        for idx, (start_offset, end_offset) in enumerate(offset_mapping[sep_index:], start=sep_index):
            if start_offset <= char_start < end_offset:
                pred_start = idx
            if start_offset < char_end <= end_offset:
                pred_end = idx

        if pred_start is None or pred_end is None:
            continue

        # Perform inference
        with torch.no_grad():
            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            start_logits = outputs.start_logits
            end_logits = outputs.end_logits
            model_start = torch.argmax(start_logits, dim=1).item()
            model_end = torch.argmax(end_logits, dim=1).item()

        # Compute metrics
        exact_match = compute_exact_match(model_start, model_end, pred_start, pred_end)
        f1 = compute_f1(model_start, model_end, pred_start, pred_end)
        span_metrics["exact_match"] += exact_match
        span_metrics["f1"] += f1
        total_span_examples += 1

print("Testing complete.")


Testing: 100%|██████████| 371/371 [00:01<00:00, 208.45it/s]

Testing complete.





In [32]:
# Evaluation Metrics and Answer logs
if task_type == "yes_no":
    # Calculate metrics
    accuracy = accuracy_score(all_ground_truths, all_predictions)
    report = classification_report(all_ground_truths, all_predictions, target_names=["NO", "YES"])
    
    # Print overall metrics
    print(f"Overall Accuracy: {accuracy:.2f}")
    print("Classification Report:")
    print(report)

    # Print questions and answers
    print("\nQuestions and Answers:")
    for idx, item in enumerate(test_data):
        question = item["question"]
        ground_truth = item["yes_no_answer"]
        predicted_answer = all_predictions[idx]
        print(f"Q{idx + 1}: {question}")
        print(f"  Predicted Answer: {predicted_answer}")
        print(f"  Ground Truth: {ground_truth}")
        print("-" * 50)

elif task_type in ["short_answer", "long_answer"]:
    # Calculate metrics
    span_metrics["exact_match"] /= total_span_examples
    span_metrics["f1"] /= total_span_examples
    
    # Print span-based metrics
    print(f"Span-Based Validation Metrics:")
    print(f"  - Exact Match: {span_metrics['exact_match']:.4f}")
    print(f"  - F1 Score: {span_metrics['f1']:.4f}")

    # Print questions and answers
    print("\nQuestions and Answers:")
    for idx, item in enumerate(test_data):
        question = item["question"]
        context = item["document"] if "document" in item else item["context"]
        true_start = item["start_token"]
        true_end = item["end_token"]
        predicted_start = pred_start
        predicted_end = pred_end

        # Decode predicted and true spans
        predicted_span = " ".join(context.split()[predicted_start:predicted_end + 1]) if predicted_start and predicted_end else "N/A"
        true_span = " ".join(context.split()[true_start:true_end]) if true_start and true_end else "N/A"
        
        print(f"Q{idx + 1}: {question}")
        print(f"  Predicted Answer: {predicted_span}")
        print(f"  Ground Truth: {true_span}")
        print(f"  Predicted Span Indices: {model_start}-{model_end}")
        print(f"  True Span Indices: {pred_start}-{pred_end}")
        print(f"  Exact Match: {exact_match}")
        print(f"  F1 Score: {f1}")
        print("-" * 50)



Overall Accuracy: 0.54
Classification Report:
              precision    recall  f1-score   support

          NO       0.63      0.34      0.44       197
         YES       0.51      0.78      0.62       174

    accuracy                           0.54       371
   macro avg       0.57      0.56      0.53       371
weighted avg       0.58      0.54      0.52       371


Questions and Answers:
Q1: does the dog die in evil dead remake
  Predicted Answer: YES
  Ground Truth: YES
--------------------------------------------------
Q2: does st louis still have a football team
  Predicted Answer: YES
  Ground Truth: NO
--------------------------------------------------
Q3: is puerto rico part of the continental usa
  Predicted Answer: YES
  Ground Truth: NO
--------------------------------------------------
Q4: do i need a tourist visa for taiwan
  Predicted Answer: YES
  Ground Truth: NO
--------------------------------------------------
Q5: is a city and guilds equivalent to a degree
  Pre

Manual Testing for each qa





In [34]:
def test_yes_no_question(question, context, model, tokenizer, device):
    # Tokenize the input
    inputs = tokenizer(
        question,
        context,
        return_tensors="pt",
        truncation=True,
        max_length=512,
        padding="longest"
    ).to(device)

    # Remove `token_type_ids` for DistilBERT (if necessary)
    if "token_type_ids" in inputs:
        del inputs["token_type_ids"]

    # Perform inference
    with torch.no_grad():
        outputs = model(**inputs)
        logits = outputs.logits
        probabilities = torch.softmax(logits, dim=-1)
        predicted_class = torch.argmax(probabilities, dim=-1).item()

    # Map prediction to "YES" or "NO"
    predicted_answer = "YES" if predicted_class == 1 else "NO"

    # Print the results
    print(f"Question: {question}")
    print(f"Context: {context}")
    print(f"Predicted Answer: {predicted_answer}")
    print(f"Confidence: {probabilities[0][predicted_class]:.4f}")



In [35]:
# Define a question and context
question = "Is the capital of France Paris?"
context = "Paris is the capital of France, known for its rich culture and history."

# Call the function to test
test_yes_no_question(question, context, model, tokenizer, device)


Question: Is the capital of France Paris?
Context: Paris is the capital of France, known for its rich culture and history.
Predicted Answer: YES
Confidence: 0.9507


In [203]:
def test_short_answer_question(question, context, model, tokenizer, device):
    # Tokenize the input
    inputs = tokenizer(
        question,
        context,
        return_tensors="pt",
        truncation=True,
        max_length=512,
        padding="max_length",
        return_offsets_mapping=True 
    ).to(device)

    # Forward pass through the model
    with torch.no_grad():
        outputs = model(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"]
        )

    # Extract start and end logits
    start_logits = outputs.start_logits
    end_logits = outputs.end_logits

    # Get the most probable start and end tokens
    start_idx = torch.argmax(start_logits, dim=1).item()
    end_idx = torch.argmax(end_logits, dim=1).item()

    # Convert token indices back to the text
    input_ids = inputs["input_ids"].squeeze(0)
    offset_mapping = inputs["offset_mapping"].squeeze(0).tolist()

    # Ensure indices are within bounds and valid
    if start_idx >= len(offset_mapping) or end_idx >= len(offset_mapping):
        return {"error": "Predicted indices are out of bounds."}

    # Map token indices to character spans in the context
    start_char, end_char = offset_mapping[start_idx][0], offset_mapping[end_idx][1]

    # Extract the answer from the context
    predicted_answer = context[start_char:end_char]

    return {
        "question": question,
        "context": context,
        "predicted_answer": predicted_answer,
        "start_idx": start_idx,
        "end_idx": end_idx,
        "start_char": start_char,
        "end_char": end_char
    }


In [243]:
# Give a short-answer question and context
question = "What is the capital of France?"
context = ("France is a country in Western Europe. Its capital city is Paris, known for its art, gastronomy, and culture. The Eiffel Tower and the Louvre Museum are among its iconic landmarks.")

result = test_short_answer_question(question, context, model, tokenizer, device)

if "error" in result:
    print(result["error"])
else:
    print(f"Question: {result['question']}")
    print(f"Predicted Answer: {result['predicted_answer']}")
    print(f"Start Token: {result['start_idx']}, End Token: {result['end_idx']}")
    print(f"Answer Span (Character Indices): {result['start_char']} - {result['end_char']}")


Question: What are the key features of the Amazon rainforest?
Predicted Answer: a
Start Token: 10, End Token: 10
Answer Span (Character Indices): 50 - 51


In [318]:
def test_long_answer_question(question, context, model, tokenizer, device):
    inputs = tokenizer(
        question,
        context,
        return_tensors="pt",
        truncation=True,
        max_length=512,
        padding="max_length",
        return_offsets_mapping=True
    ).to(device)

    with torch.no_grad():
        outputs = model(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"]
        )

    start_logits = outputs.start_logits
    end_logits = outputs.end_logits

    start_idx = torch.argmax(start_logits, dim=1).item()
    end_idx = torch.argmax(end_logits, dim=1).item()

    # Debugging outputs
    print(f"Start Index: {start_idx}, End Index: {end_idx}")
    offset_mapping = inputs["offset_mapping"].squeeze(0).tolist()

    # Check if indices are valid
    if start_idx >= len(offset_mapping) or end_idx >= len(offset_mapping):
        return {"error": "Predicted indices are out of bounds."}

    start_char, end_char = offset_mapping[start_idx][0], offset_mapping[end_idx][1]

    # Handle invalid spans
    if start_char > end_char:
        print("Invalid span detected. Adjusting end index.")
        end_char = start_char + 10 

    predicted_answer = context[start_char:end_char]
    return {
        "question": question,
        "context": context,
        "predicted_answer": predicted_answer,
        "start_idx": start_idx,
        "end_idx": end_idx,
        "start_char": start_char,
        "end_char": end_char
    }


In [322]:
# Define a long-answer question and context
question: "How does photosynthesis work in plants?"
context: "Photosynthesis is a vital process for life on Earth, enabling plants to convert sunlight into energy. It is closely linked to the carbon cycle. <p>Photosynthesis is the process by which green plants, algae, and some bacteria convert light energy, usually from the sun, into chemical energy in the form of glucose. This process takes place in the chloroplasts of plant cells, where chlorophyll captures sunlight. During photosynthesis, carbon dioxide from the air and water from the soil react in the presence of sunlight to produce glucose and oxygen.</p> Plants also produce oxygen as a byproduct, which supports life on Earth."

result = test_long_answer_question(question, context, model, tokenizer, device)

# Print the result
if "error" in result:
    print(result["error"])
else:
    print(f"Question: {result['question']}")
    print(f"Predicted Long Answer: {result['predicted_answer']}")
    print(f"Start Token: {result['start_idx']}, End Token: {result['end_idx']}")
    print(f"Answer Span (Character Indices): {result['start_char']} - {result['end_char']}")


Start Index: 46, End Index: 124
Question: What are the key features of the Amazon rainforest?
Predicted Long Answer: <p>The Amazon rainforest, also known as Amazonia, is a vast tropical rainforest in South America that covers over 5.5 million square kilometers. It is known for its incredible biodiversity, hosting millions of species of plants, animals, and insects. The Amazon River, one of the largest rivers in the world, runs through the forest, contributing to the rich ecosystem.</p>
Start Token: 46, End Token: 124
Answer Span (Character Indices): 179 - 552
