In [1]:
import json
import torch
from torch.utils.data import DataLoader
from transformers import DistilBertTokenizerFast, DistilBertForQuestionAnswering, AdamW
from tqdm import tqdm
from collections import Counter
import string
import re

  torch.utils._pytree._register_pytree_node(
  torch.utils._pytree._register_pytree_node(


In [2]:
def load_squad_data(file_path):
    """
    Load SQuAD data from JSON format and extract contexts, questions, and answers.
    """
    with open(file_path, 'rb') as file:
        data = json.load(file)
    
    contexts, questions, answers = [], [], []
    for article in data['data']:
        for paragraph in article['paragraphs']:
            context = paragraph['context']
            for qa_pair in paragraph['qas']:
                question = qa_pair['question']
                answer_key = 'plausible_answers' if 'plausible_answers' in qa_pair else 'answers'
                for answer in qa_pair[answer_key]:
                    contexts.append(context)
                    questions.append(question)
                    answers.append(answer)
                    
    return contexts, questions, answers

# Load train and validation data
train_contexts, train_questions, train_answers = load_squad_data('squad/train-v2.0.json')
val_contexts, val_questions, val_answers = load_squad_data('squad/test-v2.0.json')


In [3]:
def adjust_answer_end_indices(answers, contexts):
    """
    Adjust answer end indices to match the exact positions in the context.
    """
    for answer, context in zip(answers, contexts):
        answer_text = answer['text']
        start_index = answer['answer_start']
        end_index = start_index + len(answer_text)

        # Check if the substring matches the answer text
        if context[start_index:end_index] == answer_text:
            answer['answer_end'] = end_index
        else:
            # Try shifting the start index if there's a mismatch
            for offset in [1, 2]:
                if context[start_index - offset:end_index - offset] == answer_text:
                    answer['answer_start'] = start_index - offset
                    answer['answer_end'] = end_index - offset
                    break
            else:
                # Default the end index to the end of the answer if no match is found
                answer['answer_end'] = start_index + len(answer_text)

# Re-run the function on train and validation answers
adjust_answer_end_indices(train_answers, train_contexts)
adjust_answer_end_indices(val_answers, val_contexts)


In [4]:
# Initialize tokenizer
tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-uncased')

# Tokenize the datasets
train_encodings = tokenizer(train_contexts, train_questions, truncation=True, padding=True)
val_encodings = tokenizer(val_contexts, val_questions, truncation=True, padding=True)



In [5]:
def map_token_positions(encodings, answers):
    start_positions, end_positions = [], []

    for idx, answer in enumerate(answers):
        start_positions.append(encodings.char_to_token(idx, answer['answer_start']))
        end_positions.append(encodings.char_to_token(idx, answer['answer_end']))

        if start_positions[-1] is None:
            start_positions[-1] = tokenizer.model_max_length
        shift = 1
        while end_positions[-1] is None:
            end_positions[-1] = encodings.char_to_token(idx, answer['answer_end'] - shift)
            shift += 1

    encodings.update({'start_positions': start_positions, 'end_positions': end_positions})

# Map token positions for both datasets
map_token_positions(train_encodings, train_answers)
map_token_positions(val_encodings, val_answers)


In [6]:
class SquadDataset(torch.utils.data.Dataset):
    def __init__(self, encodings):
        self.encodings = encodings

    def __getitem__(self, index):
        return {key: torch.tensor(value[index]) for key, value in self.encodings.items()}

    def __len__(self):
        return len(self.encodings.input_ids)



In [7]:
# Create train and validation datasets
train_dataset = SquadDataset(train_encodings)
val_dataset = SquadDataset(val_encodings)

In [8]:
# Model setup
model = DistilBertForQuestionAnswering.from_pretrained("distilbert-base-uncased")

# Setup device and optimizer
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model.to(device)
model.train()
print(device)
optim = AdamW(model.parameters(), lr=2e-4)

Some weights of DistilBertForQuestionAnswering were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['qa_outputs.bias', 'qa_outputs.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
cuda




In [9]:
# Training loop
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
for epoch in range(3):
    model.train()
    loop = tqdm(train_loader, leave=True)
    for batch in loop:
        optim.zero_grad()
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        start_positions = batch['start_positions'].to(device)
        end_positions = batch['end_positions'].to(device)
        outputs = model(input_ids, attention_mask=attention_mask,
                        start_positions=start_positions,
                        end_positions=end_positions)
        loss = outputs[0]
        loss.backward()
        optim.step()
        loop.set_description(f'Epoch {epoch}')
        loop.set_postfix(loss=loss.item())

# Save model
model_path = 'models/'
model.save_pretrained(model_path)
tokenizer.save_pretrained(model_path)


Epoch 0: 100%|██████████| 2320/2320 [09:07<00:00,  4.24it/s, loss=1.87]
Epoch 1: 100%|██████████| 2320/2320 [09:03<00:00,  4.27it/s, loss=1.75] 
Epoch 2: 100%|██████████| 2320/2320 [09:03<00:00,  4.27it/s, loss=1.36] 


('models/tokenizer_config.json',
 'models/special_tokens_map.json',
 'models/vocab.txt',
 'models/added_tokens.json',
 'models/tokenizer.json')

In [10]:
model = DistilBertForQuestionAnswering.from_pretrained("models/")
model.to(device)

  return torch.load(checkpoint_file, map_location=map_location)


DistilBertForQuestionAnswering(
  (distilbert): DistilBertModel(
    (embeddings): Embeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (transformer): Transformer(
      (layer): ModuleList(
        (0-5): 6 x TransformerBlock(
          (attention): MultiHeadSelfAttention(
            (dropout): Dropout(p=0.1, inplace=False)
            (q_lin): Linear(in_features=768, out_features=768, bias=True)
            (k_lin): Linear(in_features=768, out_features=768, bias=True)
            (v_lin): Linear(in_features=768, out_features=768, bias=True)
            (out_lin): Linear(in_features=768, out_features=768, bias=True)
          )
          (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (ffn): FFN(
            (dropout): Dropout(p=0.1, inplace=False)
      

In [None]:

# Evaluation on validation set
model.eval()
val_loader = DataLoader(val_dataset, batch_size=16)
acc, answers, references = [], [], []
loop = tqdm(val_loader)

for batch in loop:
    with torch.no_grad():
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        start_true = batch['start_positions'].to(device)
        end_true = batch['end_positions'].to(device)
        outputs = model(input_ids, attention_mask=attention_mask)
        start_pred = torch.argmax(outputs['start_logits'], dim=1)
        end_pred = torch.argmax(outputs['end_logits'], dim=1)
        acc.append(((start_pred == start_true).sum() / len(start_pred)).item())
        acc.append(((end_pred == end_true).sum() / len(end_pred)).item())

        for i in range(start_pred.shape[0]):
            all_tokens = tokenizer.convert_ids_to_tokens(batch['input_ids'][i])
            answer = ' '.join(all_tokens[start_pred[i]: end_pred[i] + 1])
            ref = ' '.join(all_tokens[start_true[i]: end_true[i] + 1])
            answer_ids = tokenizer.convert_tokens_to_ids(answer.split())
            answer = tokenizer.decode(answer_ids)
            answers.append(answer)
            references.append(ref)


 44%|████▍     | 490/1116 [00:46<00:59, 10.52it/s]

In [20]:
# Evaluation metrics
def normalize_answer(s):
    def remove_articles(text):
        return re.sub(r'\\b(a|an|the)\\b', ' ', text)
    def white_space_fix(text):
        return ' '.join(text.split())
    def remove_punc(text):
        exclude = set(string.punctuation)
        return ''.join(ch for ch in text if ch not in exclude)
    def lower(text):
        return text.lower()
    return white_space_fix(remove_articles(remove_punc(lower(s))))

In [21]:
def exact_match_score(prediction, ground_truth):
    return normalize_answer(prediction) == normalize_answer(ground_truth)

In [22]:
def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
    return max([metric_fn(prediction, gt) for gt in ground_truths])


In [23]:
def f1_score(prediction, ground_truth):
    pred_tokens = normalize_answer(prediction).split()
    gt_tokens = normalize_answer(ground_truth).split()
    common = Counter(pred_tokens) & Counter(gt_tokens)
    num_same = sum(common.values())
    if num_same == 0:
        return 0
    precision = 1.0 * num_same / len(pred_tokens)
    recall = 1.0 * num_same / len(gt_tokens)
    return (2 * precision * recall) / (precision + recall)

In [24]:
def evaluate(gold_answers, predictions):
    f1 = exact_match = total = 0
    for gt, prediction in zip(gold_answers, predictions):
        total += 1
        exact_match += metric_max_over_ground_truths(exact_match_score, prediction, [gt])
        f1 += metric_max_over_ground_truths(f1_score, prediction, [gt])
    exact_match = 100.0 * exact_match / total
    f1 = 100.0 * f1 / total
    return {'f1': f1}

In [25]:
# Run evaluation
evaluation_results = evaluate(references, answers)
print("Evaluation Results:", evaluation_results)

Evaluation Results: {'f1': 14.200461257580255}
