In [1]:
pip install transformers datasets evaluate

Note: you may need to restart the kernel to use updated packages.


In [2]:
pip install transformers[torch]

Note: you may need to restart the kernel to use updated packages.


# Training 


In [1]:
import os
import re
import sys
import string
import argparse
import datetime
import random

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

torch.manual_seed(0)

<torch._C.Generator at 0x2177a447a50>

Finetuning distillbert for context classification

In [2]:
from huggingface_hub import notebook_login

notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

## Creating training data

In [3]:
from datasets import load_dataset

squad = load_dataset("squad")

In [7]:
# generates the bad answers by randomly picking a span in the context
import random
def generate_bad_labels(example, min_length=1, max_length=5, answer_count=2):
    words = example['context'].split()
    length = len(words)
    correct_ans = example['answers']['text']
    answers = []
    while (len(answers) < answer_count):
        ans_len = random.randint(1, 5)
        start = random.randint(0, length - ans_len - 1)
        ans = ' '.join(words[start: start + ans_len])
        if (ans != correct_ans):
            answers.append(ans)
    return answers

In [5]:
# generates the bad answers by extending from current answer
def generate_bad_labels_from_answers(context, answer, min_length=1, max_length=5, answer_count=2):
    return None

In [8]:
def add_answers(dataset, generator=generate_bad_labels):
    bad_answer1 = []
    bad_answer2 = []
    correct_answer = []
    labels = []
    for example in dataset:
        bad_answers = generator(example, answer_count=2)
        bad_answer1.append(bad_answers[0])
        bad_answer2.append(bad_answers[1])
        correct_answer.append(example['answers']['text'][0])
        labels.append(random.randint(0, 2))
    dataset = dataset.add_column('bad_answer1', bad_answer1)
    dataset = dataset.add_column('bad_answer2', bad_answer2)
    dataset = dataset.add_column('correct_answer', correct_answer)
    dataset = dataset.add_column('label', labels)
    return dataset


In [8]:
squad['train'] = add_answers(squad['train'])
squad['validation'] = add_answers(squad['validation'])

ValueError: The table can't have duplicated columns but columns ['bad_answer1'] are duplicated.

In [9]:
squad['train'][0]

{'id': '5733be284776f41900661182',
 'title': 'University_of_Notre_Dame',
 'context': 'Architecturally, the school has a Catholic character. Atop the Main Building\'s gold dome is a golden statue of the Virgin Mary. Immediately in front of the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend "Venite Ad Me Omnes". Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basilica is the Grotto, a Marian place of prayer and reflection. It is a replica of the grotto at Lourdes, France where the Virgin Mary reputedly appeared to Saint Bernadette Soubirous in 1858. At the end of the main drive (and in a direct line that connects through 3 statues and the Gold Dome), is a simple, modern stone statue of Mary.',
 'question': 'To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France?',
 'answers': {'text': ['Saint Bernadette Soubirous'], 'answer_start': [515]},
 'bad_answer1': 'and facing it,',
 'bad_answer2': 

In [10]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")

In [11]:
ans_names = ['correct_answer', 'bad_answer1', 'bad_answer2']
def preprocess_function(examples):
    context = [[c] * 3 for c in examples["context"]]
    question = examples["question"]
    labels = examples["label"]
    qna = [
        [f"{q} {examples[ans][i]}" for ans in ans_names] for i, q in enumerate(question)
    ]
    for i, q in enumerate(qna):
        label = labels[i]
        q[0], q[label] = q[label], q[0]
    context = sum(context, [])
    qna = sum(qna, [])

    tokenized_examples = tokenizer(context, qna, truncation=True)
    return {k: [v[i : i + 3] for i in range(0, len(v), 3)] for k, v in tokenized_examples.items()}

In [12]:
tokenized_squad = squad.map(preprocess_function, batched=True)

Map:   0%|          | 0/87599 [00:00<?, ? examples/s]

Map:   0%|          | 0/10570 [00:00<?, ? examples/s]

In [11]:
from dataclasses import dataclass
from transformers.tokenization_utils_base import PreTrainedTokenizerBase, PaddingStrategy
from typing import Optional, Union
import torch


@dataclass
class DataCollatorForMultipleChoice:
    """
    Data collator that will dynamically pad the inputs for multiple choice received.
    """

    tokenizer: PreTrainedTokenizerBase
    padding: Union[bool, str, PaddingStrategy] = True
    max_length: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None

    def __call__(self, features):
        label_name = "label" if "label" in features[0].keys() else "labels"
        labels = [feature.pop(label_name) for feature in features]
        batch_size = len(features)
        num_choices = len(features[0]["input_ids"])
        flattened_features = [
            [{k: v[i] for k, v in feature.items()} for i in range(num_choices)] for feature in features
        ]
        flattened_features = sum(flattened_features, [])

        batch = self.tokenizer.pad(
            flattened_features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors="pt",
        )

        batch = {k: v.view(batch_size, num_choices, -1) for k, v in batch.items()}
        batch["labels"] = torch.tensor(labels, dtype=torch.int64)
        return batch

In [16]:
from transformers import AutoModelForMultipleChoice, TrainingArguments, Trainer

model = AutoModelForMultipleChoice.from_pretrained("distilbert-base-uncased")

Some weights of DistilBertForMultipleChoice were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'pre_classifier.bias', 'classifier.weight', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Not computing metrics here due to issue with memory

In [17]:
training_args = TrainingArguments(
    output_dir="squad_qna_model",
    evaluation_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    learning_rate=5e-5,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    num_train_epochs=3,
    weight_decay=0.01,
    push_to_hub=True,
    prediction_loss_only=True
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_squad["train"],
    eval_dataset=tokenized_squad["validation"],
    tokenizer=tokenizer,
    data_collator=DataCollatorForMultipleChoice(tokenizer=tokenizer),
)

trainer.train()

You're using a DistilBertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Epoch,Training Loss,Validation Loss
1,0.2317,0.158884
2,0.1097,0.201094
3,0.038,0.203859


TrainOutput(global_step=65700, training_loss=0.13413862392601175, metrics={'train_runtime': 14648.8023, 'train_samples_per_second': 17.94, 'train_steps_per_second': 4.485, 'total_flos': 5.031553594176346e+16, 'train_loss': 0.13413862392601175, 'epoch': 3.0})

In [18]:
trainer.push_to_hub()

'https://huggingface.co/Clyvey/squad_qna_model/tree/main/'

In [19]:
trainer.save_model('qna_model.pt')

# Evaluation

In [2]:
from transformers import AutoModelForMultipleChoice, AutoTokenizer

model = AutoModelForMultipleChoice.from_pretrained("./qna_model.pt")

tokenizer = AutoTokenizer.from_pretrained("./qna_model.pt")


## Data generation

don't need to run this if training data was previously generated

In [6]:
from datasets import load_dataset

squad = load_dataset("squad")

In [9]:
squad['validation'] = add_answers(squad['validation'])
squad['train'] = add_answers(squad['train'])

In [15]:
ans_names = ['correct_answer', 'bad_answer1', 'bad_answer2']
def preprocess_function(examples):
    context = [[c] * 3 for c in examples["context"]]
    question = examples["question"]
    labels = examples["label"]
    qna = [
        [f"{q} {examples[ans][i]}" for ans in ans_names] for i, q in enumerate(question)
    ]
    for i, q in enumerate(qna):
        label = labels[i]
        q[0], q[label] = q[label], q[0]
    context = sum(context, [])
    qna = sum(qna, [])

    tokenized_examples = tokenizer(context, qna, truncation=True)
    return {k: [v[i : i + 3] for i in range(0, len(v), 3)] for k, v in tokenized_examples.items()}

tokenized_squad = squad.map(preprocess_function, batched=True)

Map:   0%|          | 0/87599 [00:00<?, ? examples/s]

Map:   0%|          | 0/10570 [00:00<?, ? examples/s]

In [12]:
import evaluate
import numpy as np

accuracy = evaluate.load("accuracy")

def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=1)
    return accuracy.compute(predictions=predictions, references=labels)

In [19]:
training_args = TrainingArguments(
    output_dir="squad_qna_model",
    evaluation_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    learning_rate=5e-5,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    num_train_epochs=3,
    weight_decay=0.01,
    push_to_hub=True,
)

eval_trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_squad["train"],
    eval_dataset=tokenized_squad["validation"],
    tokenizer=tokenizer,
    data_collator=DataCollatorForMultipleChoice(tokenizer=tokenizer),
    compute_metrics=compute_metrics
)

In [20]:
eval_trainer.evaluate()

{'eval_loss': 0.17389702796936035,
 'eval_accuracy': 0.9616840113528855,
 'eval_runtime': 116.7652,
 'eval_samples_per_second': 90.524,
 'eval_steps_per_second': 22.635}