In [1]:
# Run configuration.
# Flag for SQuAD v2 (which adds unanswerable questions); the dataset loaded below should agree with it.
squad_v2 = False
model_checkpoint = "distilbert-base-uncased"
# Per-device train/eval batch size. 350 is the value the recorded training run actually used
# (the training-arguments cell assigns 350 as well: 88524 train features / 350 -> 253 steps/epoch,
# x 5 epochs = 1265 global steps, matching the TrainOutput at the end of the notebook).
batch_size = 350

In [2]:
from datasets import load_dataset

# Honor the squad_v2 flag from the configuration cell instead of hard-coding SQuAD v1.1.
# NOTE: the variable name `datasets` shadows the `datasets` package name; later cells rely on it.
datasets = load_dataset("squad_v2" if squad_v2 else "squad")

In [3]:
# Show the split names, column names and row counts of the loaded DatasetDict.
datasets

DatasetDict({
    train: Dataset({
        features: ['id', 'title', 'context', 'question', 'answers'],
        num_rows: 87599
    })
    validation: Dataset({
        features: ['id', 'title', 'context', 'question', 'answers'],
        num_rows: 10570
    })
})

In [4]:
# Inspect one raw training example: `answers` holds parallel lists `text` and `answer_start`,
# where `answer_start` is a character offset into `context`.
datasets["train"][0]

{'id': '5733be284776f41900661182',
 'title': 'University_of_Notre_Dame',
 'context': 'Architecturally, the school has a Catholic character. Atop the Main Building\'s gold dome is a golden statue of the Virgin Mary. Immediately in front of the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend "Venite Ad Me Omnes". Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basilica is the Grotto, a Marian place of prayer and reflection. It is a replica of the grotto at Lourdes, France where the Virgin Mary reputedly appeared to Saint Bernadette Soubirous in 1858. At the end of the main drive (and in a direct line that connects through 3 statues and the Gold Dome), is a simple, modern stone statue of Mary.',
 'question': 'To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France?',
 'answers': {'text': ['Saint Bernadette Soubirous'], 'answer_start': [515]}}

In [5]:
from transformers import AutoTokenizer

# Load the tokenizer published with the same checkpoint as the model loaded later in the notebook.
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

In [6]:
import transformers

# prepare_train_features relies on return_offsets_mapping, overflow_to_sample_mapping and
# BatchEncoding.sequence_ids(), which are only provided by the "fast" (Rust-backed) tokenizers.
assert isinstance(tokenizer, transformers.PreTrainedTokenizerFast), (
    "This notebook requires a fast tokenizer (offset mappings / sequence_ids); "
    f"got {type(tokenizer).__name__}"
)

In [7]:
# Sanity check: a single sentence is wrapped in special tokens (the leading/trailing ids 101 / 102
# in the output are presumably [CLS] / [SEP] for this vocabulary).
tokenizer("Hello World")

{'input_ids': [101, 7592, 2088, 102], 'attention_mask': [1, 1, 1, 1]}

In [8]:
max_length = 384  # maximum tokens per feature (question + context + special tokens)
doc_stride = 128  # tokens of overlap between consecutive context chunks of one long example
# Training examples 251 and 252 are two records longer than max_length, so each is split into chunks.
test_251 = datasets['train'][251]
test_252 = datasets['train'][252]
# Assemble the two rows into the batched (column -> list of values) layout used by map(batched=True).
test_example = {
    column: [test_251[column], test_252[column]]
    for column in ('id', 'title', 'context', 'question', 'answers')
}

In [9]:
# Display the hand-built two-example batch (one list per column).
test_example

{'id': ['5733caf74776f4190066124e', '5733caf74776f4190066124f'],
 'title': ['University_of_Notre_Dame', 'University_of_Notre_Dame'],
 'context': ["The men's basketball team has over 1,600 wins, one of only 12 schools who have reached that mark, and have appeared in 28 NCAA tournaments. Former player Austin Carr holds the record for most points scored in a single game of the tournament with 61. Although the team has never won the NCAA Tournament, they were named by the Helms Athletic Foundation as national champions twice. The team has orchestrated a number of upsets of number one ranked teams, the most notable of which was ending UCLA's record 88-game winning streak in 1974. The team has beaten an additional eight number-one teams, and those nine wins rank second, to UCLA's 10, all-time in wins against the top team. The team plays in newly renovated Purcell Pavilion (within the Edmund P. Joyce Center), which reopened for the beginning of the 2009–2010 season. The team is coached by Mik

In [10]:
def _context_token_span(sequence_ids):
    """Return (first, last) token indices of the context segment (sequence id 1) of one feature."""
    first = sequence_ids.index(1)
    last = len(sequence_ids) - 1 - sequence_ids[::-1].index(1)
    return first, last


def _answer_token_span(offsets, context_start, context_end, start_char, end_char, cls_index):
    """Map an answer's character span [start_char, end_char) to (start_token, end_token) in one feature.

    Returns (cls_index, cls_index) when the feature's context window does not fully contain the answer.
    Both scans stay inside [context_start, context_end]: special and padding tokens carry (0, 0)
    offsets, so an unbounded forward scan walks onto [SEP]/[PAD] whenever the answer starts in the
    last context token of the window.
    """
    if not (offsets[context_start][0] <= start_char and offsets[context_end][1] >= end_char):
        return cls_index, cls_index

    # Last context token that starts at or before the answer's first character.
    start_token = context_start
    while start_token < context_end and offsets[start_token + 1][0] <= start_char:
        start_token += 1

    # Walk back from the window end while the previous token still ends at/after the answer's end.
    end_token = context_end
    while end_token > context_start and offsets[end_token - 1][1] >= end_char:
        end_token -= 1

    return start_token, end_token


def prepare_train_features(examples):
    """Tokenize SQuAD-style examples into fixed-length training features with answer-span labels.

    Extends the standard tokenizer output with two columns, ``start_positions`` and
    ``end_positions``: the token indices of the answer's first and last token in each feature.
    Long contexts are split into several overlapping features (``max_length`` tokens each,
    ``doc_stride`` tokens of overlap). A feature whose context window does not fully contain the
    answer, or whose example has no answer, is labelled with the [CLS] index for both positions.

    Args:
        examples: A batch in column layout (as passed by ``Dataset.map(batched=True)``) with keys
            ``question``, ``context`` and ``answers`` (each answer a dict with ``text`` and
            ``answer_start`` lists). A single un-batched example (string ``question``) is also
            accepted and treated as a batch of one.

    Returns:
        The tokenizer's encoding (one row per feature) with ``offset_mapping`` and
        ``overflow_to_sample_mapping`` removed and ``start_positions`` / ``end_positions`` added.

    Uses the module-level ``tokenizer`` (must be a fast tokenizer), ``max_length`` and
    ``doc_stride``. The caller's ``examples`` is not modified.
    """
    questions = examples["question"]
    contexts = examples["context"]
    answers_per_example = examples["answers"]
    if isinstance(questions, str):
        # Single example: wrap as a batch of one so per-sample indexing below works.
        questions, contexts, answers_per_example = [questions], [contexts], [answers_per_example]
    # Leading whitespace in a question carries no information; strip it without touching the caller's data.
    questions = [q.lstrip() for q in questions]

    tokenized_examples = tokenizer(
        questions,
        contexts,
        truncation="only_second",  # only ever truncate the context, never the question
        max_length=max_length,
        stride=doc_stride,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
    )

    # feature index -> index of the example it was cut from
    sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")
    # feature index -> per-token (start_char, end_char) offsets into the original text
    offset_mapping = tokenized_examples.pop("offset_mapping")

    start_positions = []
    end_positions = []
    for i, offsets in enumerate(offset_mapping):
        input_ids = tokenized_examples["input_ids"][i]
        cls_index = input_ids.index(tokenizer.cls_token_id)  # 0 for this BERT-style tokenizer
        answers = answers_per_example[sample_mapping[i]]

        if len(answers["answer_start"]) == 0:
            # Unanswerable example (SQuAD v2 style): label both positions with [CLS].
            start, end = cls_index, cls_index
        else:
            start_char = answers["answer_start"][0]
            end_char = start_char + len(answers["text"][0])
            # Sequence ids: 0 = question, 1 = context, None = special/padding token.
            context_start, context_end = _context_token_span(tokenized_examples.sequence_ids(i))
            start, end = _answer_token_span(
                offsets, context_start, context_end, start_char, end_char, cls_index
            )

        start_positions.append(start)
        end_positions.append(end)

    tokenized_examples["start_positions"] = start_positions
    tokenized_examples["end_positions"] = end_positions
    return tokenized_examples

In [11]:
# Run the feature builder on the two long examples to inspect how they are chunked and labelled.
test_tokenized_example = prepare_train_features(test_example)

In [12]:
# Tokenize every split in batches. The raw SQuAD columns are dropped, so each split keeps only
# the model inputs plus the start/end labels produced by prepare_train_features.
tokenized_datasets = datasets.map(
    prepare_train_features,
    batched=True,
    remove_columns=datasets["train"].column_names,
)

In [13]:
# Long examples yield several features each, so row counts grow
# (87599 -> 88524 train, 10570 -> 10784 validation).
tokenized_datasets

DatasetDict({
    train: Dataset({
        features: ['input_ids', 'attention_mask', 'start_positions', 'end_positions'],
        num_rows: 88524
    })
    validation: Dataset({
        features: ['input_ids', 'attention_mask', 'start_positions', 'end_positions'],
        num_rows: 10784
    })
})

In [14]:
from transformers import AutoModelForQuestionAnswering, TrainingArguments, Trainer

# Load the pretrained encoder with a question-answering head. As the warning below reports, the head
# weights (qa_outputs) are newly initialized, so the model must be fine-tuned before use.
model = AutoModelForQuestionAnswering.from_pretrained(model_checkpoint)

Some weights of DistilBertForQuestionAnswering were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['qa_outputs.bias', 'qa_outputs.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [15]:
# NOTE: this re-assigns batch_size from the configuration cell; 350 is the value the recorded run used.
batch_size = 350
model_dir = "models/" + model_checkpoint + "-finetuned-squad"

# Checkpoints are written under model_dir; validation loss is computed at the end of every epoch.
args = TrainingArguments(
    output_dir=model_dir,
    # optimisation
    learning_rate=2e-5,
    weight_decay=0.01,
    num_train_epochs=5,
    # batching (same per-device size for training and evaluation)
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    # NOTE(review): newer transformers releases rename this argument to `eval_strategy`.
    evaluation_strategy="epoch",
)

In [16]:
from transformers import default_data_collator

# Features were already padded to max_length during tokenization (padding="max_length"), so the default
# collator, which only batches the fields and does no padding of its own, is sufficient.
data_collator = default_data_collator

In [17]:
trainer = Trainer(
    model=model,
    args=args,
    data_collator=data_collator,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    # Passing the tokenizer lets the Trainer save it alongside the model checkpoints.
    # NOTE(review): newer transformers releases prefer `processing_class=` over `tokenizer=`.
    tokenizer=tokenizer,
)

Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


In [18]:
# Fine-tune for num_train_epochs. The per-epoch table shows training loss and validation loss. Training loss
# reads "No log" in epoch 1, presumably because fewer steps than the logging interval had run by then.
trainer.train()



Epoch,Training Loss,Validation Loss
1,No log,1.502634
2,2.031900,1.330522
3,2.031900,1.266645
4,1.200300,1.234033
5,1.200300,1.233527


TrainOutput(global_step=1265, training_loss=1.5085064190649704, metrics={'train_runtime': 3313.0382, 'train_samples_per_second': 133.599, 'train_steps_per_second': 0.382, 'total_flos': 4.337225635212288e+16, 'train_loss': 1.5085064190649704, 'epoch': 5.0})