In [62]:
!pip install datasets transformers



# 1. Preparation

In [63]:
from datasets import load_dataset

In [64]:
# squad_v2 selects the dataset: True loads SQuAD v2, False loads SQuAD v1.
# With a different dataset, True means the model may leave some questions
# unanswered ("impossible" questions), while False means every question must be answered.
squad_v2 = False
# Download the data (needs network access unless the dataset is already cached).
datasets = load_dataset("squad_v2" if squad_v2 else "squad")

Found cached dataset squad (C:/Users/96212/.cache/huggingface/datasets/squad/plain_text/1.0.0/d6ec3ceb99ca480ce37cdd35555d6cb2511d223b9150cce08a837ef62ffea453)


  0%|          | 0/2 [00:00<?, ?it/s]

In [65]:
# Show the first training record (id, title, context, question, answers).
datasets["train"][0]

{'id': '5733be284776f41900661182',
 'title': 'University_of_Notre_Dame',
 'context': 'Architecturally, the school has a Catholic character. Atop the Main Building\'s gold dome is a golden statue of the Virgin Mary. Immediately in front of the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend "Venite Ad Me Omnes". Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basilica is the Grotto, a Marian place of prayer and reflection. It is a replica of the grotto at Lourdes, France where the Virgin Mary reputedly appeared to Saint Bernadette Soubirous in 1858. At the end of the main drive (and in a direct line that connects through 3 statues and the Gold Dome), is a simple, modern stone statue of Mary.',
 'question': 'To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France?',
 'answers': {'text': ['Saint Bernadette Soubirous'], 'answer_start': [515]}}

In [66]:
from datasets import ClassLabel, Sequence
import random
import pandas as pd
from IPython.display import display, HTML

def show_random_elements(dataset, num_examples=10, seed=None):
    """Display a random sample of distinct rows from ``dataset`` as an HTML table.

    Args:
        dataset: Object supporting ``len()``, indexing with a list of row
            indices, and a ``features`` mapping of column name -> feature type.
        num_examples: Number of distinct rows to show; must not exceed
            ``len(dataset)``.
        seed: Optional seed. When given, the same rows are picked on every
            call; when ``None`` the module-level ``random`` state is used.

    Columns whose feature type is a ``ClassLabel`` (or a ``Sequence`` of
    ``ClassLabel``) are rendered with their label names instead of integer ids.
    """
    assert num_examples <= len(dataset), "Can't pick more elements than there are in the dataset."

    # random.sample draws distinct indices directly, replacing the manual
    # draw-and-retry loop.
    rng = random if seed is None else random.Random(seed)
    picks = rng.sample(range(len(dataset)), num_examples)

    df = pd.DataFrame(dataset[picks])
    for column, typ in dataset.features.items():
        if isinstance(typ, ClassLabel):
            df[column] = df[column].transform(lambda i: typ.names[i])
        elif isinstance(typ, Sequence) and isinstance(typ.feature, ClassLabel):
            df[column] = df[column].transform(lambda x: [typ.feature.names[i] for i in x])
    display(HTML(df.to_html()))

In [67]:
# Render one randomly chosen training example.
show_random_elements(datasets["train"], num_examples=1)

Unnamed: 0,id,title,context,question,answers
0,56d00567234ae51400d9c277,New_York_City,"New York has been described as the ""Capital of Baseball"". There have been 35 Major League Baseball World Series and 73 pennants won by New York teams. It is one of only five metro areas (Los Angeles, Chicago, Baltimore–Washington, and the San Francisco Bay Area being the others) to have two baseball teams. Additionally, there have been 14 World Series in which two New York City teams played each other, known as a Subway Series and occurring most recently in 2000. No other metropolitan area has had this happen more than once (Chicago in 1906, St. Louis in 1944, and the San Francisco Bay Area in 1989). The city's two current Major League Baseball teams are the New York Mets, who play at Citi Field in Queens, and the New York Yankees, who play at Yankee Stadium in the Bronx. who compete in six games of interleague play every regular season that has also come to be called the Subway Series. The Yankees have won a record 27 championships, while the Mets have won the World Series twice. The city also was once home to the Brooklyn Dodgers (now the Los Angeles Dodgers), who won the World Series once, and the New York Giants (now the San Francisco Giants), who won the World Series five times. Both teams moved to California in 1958. There are also two Minor League Baseball teams in the city, the Brooklyn Cyclones and Staten Island Yankees.",How many minor league baseball teams are there in NYC?,"{'text': ['two'], 'answer_start': [288]}"


# 2. Data processing

In [68]:
from transformers import AutoTokenizer
# Checkpoint name shared by the tokenizer here and the model loaded later.
model_checkpoint = "distilbert-base-uncased"
# Later cells call sequence_ids() and request offset mappings on the encodings;
# NOTE(review): those features presumably require the fast tokenizer selected here.
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, use_fast=True)

In [69]:
# Tokenize the first training example as a (question, context) pair and show its token ids.
example = datasets["train"][0]
tokenized_example=tokenizer(example["question"], example["context"])
tokenized_example["input_ids"]

[101,
 2000,
 3183,
 2106,
 1996,
 6261,
 2984,
 9382,
 3711,
 1999,
 8517,
 1999,
 10223,
 26371,
 2605,
 1029,
 102,
 6549,
 2135,
 1010,
 1996,
 2082,
 2038,
 1037,
 3234,
 2839,
 1012,
 10234,
 1996,
 2364,
 2311,
 1005,
 1055,
 2751,
 8514,
 2003,
 1037,
 3585,
 6231,
 1997,
 1996,
 6261,
 2984,
 1012,
 3202,
 1999,
 2392,
 1997,
 1996,
 2364,
 2311,
 1998,
 5307,
 2009,
 1010,
 2003,
 1037,
 6967,
 6231,
 1997,
 4828,
 2007,
 2608,
 2039,
 14995,
 6924,
 2007,
 1996,
 5722,
 1000,
 2310,
 3490,
 2618,
 4748,
 2033,
 18168,
 5267,
 1000,
 1012,
 2279,
 2000,
 1996,
 2364,
 2311,
 2003,
 1996,
 13546,
 1997,
 1996,
 6730,
 2540,
 1012,
 3202,
 2369,
 1996,
 13546,
 2003,
 1996,
 24665,
 23052,
 1010,
 1037,
 14042,
 2173,
 1997,
 7083,
 1998,
 9185,
 1012,
 2009,
 2003,
 1037,
 15059,
 1997,
 1996,
 24665,
 23052,
 2012,
 10223,
 26371,
 1010,
 2605,
 2073,
 1996,
 6261,
 2984,
 22353,
 2135,
 2596,
 2000,
 3002,
 16595,
 9648,
 4674,
 2061,
 12083,
 9711,
 2271,
 1999,
 8517,
 101

In [70]:
# Per-token sequence ids: None for special tokens, 0 for the first input
# (question), 1 for the second input (context), as the printed list shows.
sequence_ids = tokenized_example.sequence_ids()
print(sequence_ids)

[None, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, None, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, None]


In [71]:
# Find the first training example whose tokenized question + context is longer than
# 384 tokens (the same value a later cell assigns to max_length).
for i, example in enumerate(datasets["train"]):
    if len(tokenizer(example["question"], example["context"])["input_ids"]) > 384:
        break
# Re-read the example at the index where the scan stopped.
example = datasets["train"][i]

In [72]:
# Token count of the selected long example without any truncation.
len(tokenizer(example["question"], example["context"])["input_ids"])

396

In [73]:
max_length = 384 # maximum length of one input feature (question and context concatenated)
doc_stride = 128 # number of overlapping tokens between two adjacent slices

In [74]:
# Token count when only the second input (the context) is truncated to max_length.
len(tokenizer(example["question"], example["context"], max_length=max_length, truncation="only_second")["input_ids"])

384

In [75]:
# Tokenize again, this time keeping the overflowing context as extra slices
# that overlap the previous slice by doc_stride tokens.
tokenized_example = tokenizer(
    example["question"],
    example["context"],
    max_length=max_length,
    truncation="only_second",
    return_overflowing_tokens=True,
    stride=doc_stride
)

In [76]:
# Length (in tokens) of every slice produced for the long example.
list(map(len, tokenized_example["input_ids"]))

[384, 157]

In [77]:
# Decode the first two slices back to text to see how they overlap.
for i, x in enumerate(tokenized_example["input_ids"][:2]):
    print("切片: {}".format(i))
    print(tokenizer.decode(x))

切片: 0
[CLS] how many wins does the notre dame men's basketball team have? [SEP] the men's basketball team has over 1, 600 wins, one of only 12 schools who have reached that mark, and have appeared in 28 ncaa tournaments. former player austin carr holds the record for most points scored in a single game of the tournament with 61. although the team has never won the ncaa tournament, they were named by the helms athletic foundation as national champions twice. the team has orchestrated a number of upsets of number one ranked teams, the most notable of which was ending ucla's record 88 - game winning streak in 1974. the team has beaten an additional eight number - one teams, and those nine wins rank second, to ucla's 10, all - time in wins against the top team. the team plays in newly renovated purcell pavilion ( within the edmund p. joyce center ), which reopened for the beginning of the 2009 – 2010 season. the team is coached by mike brey, who, as of the 2014 – 15 season, his fifteenth a

In [78]:
# Same sliced tokenization, additionally requesting the offset mapping so each
# token can be traced back to its character span in the original text.
tokenized_example = tokenizer(
    example["question"],
    example["context"],
    max_length=max_length,
    truncation="only_second",
    return_overflowing_tokens=True,
    return_offsets_mapping=True,
    stride=doc_stride
)

In [79]:
# Print the (start, end) character offsets of the first 30 tokens of the second slice.
print(tokenized_example["offset_mapping"][1][:30])

[(0, 0), (0, 3), (4, 8), (9, 13), (14, 18), (19, 22), (23, 28), (29, 33), (34, 37), (37, 38), (38, 39), (40, 50), (51, 55), (56, 60), (60, 61), (0, 0), (1093, 1105), (1105, 1106), (1107, 1110), (1111, 1115), (1115, 1116), (1116, 1118), (1119, 1123), (1124, 1133), (1134, 1137), (1138, 1145), (1146, 1152), (1153, 1159), (1160, 1166), (1167, 1172)]


In [80]:
# Check that an offset pair maps a token back to the matching text in the question.
first_token_id = tokenized_example["input_ids"][0][1] # first question token of slice 0 (id 2129 per the original note)
offsets = tokenized_example["offset_mapping"][0][1] # its character span, (0, 3)
print(tokenizer.convert_ids_to_tokens([first_token_id])[0], example["question"][offsets[0]:offsets[1]])

how How


In [81]:
answers = example["answers"] # answers, e.g. {'answer_start': [30], 'text': ['over 1,600']}
start_char = answers["answer_start"][0] # character index where the answer starts (30)
end_char = start_char + len(answers["text"][0]) # character index just past the answer (40)

In [82]:
# 找到当前文本的Start token index.
token_start_index = 0 #得到context在句子1和句子2拼接后的句子中的起始位置
while sequence_ids[token_start_index] != 1: #sequence_ids区分question和context
    token_start_index += 1

In [83]:
# Find the index of the last context token in slice 0.
# Use the sequence ids of the current sliced encoding so their length matches
# input_ids; no length guard against a stale, shorter list is needed.
sequence_ids = tokenized_example.sequence_ids(0)
token_end_index = len(tokenized_example["input_ids"][0]) - 1
# Walk back over trailing special/padding tokens (sequence id None).
while sequence_ids[token_end_index] != 1:
    token_end_index -= 1

In [84]:
# Check whether the answer lies inside this slice's context span; if not, the
# sample would be labelled at the CLS token position.
offsets = tokenized_example["offset_mapping"][0]
if offsets[token_start_index][0] <= start_char and offsets[token_end_index][1] >= end_char:  # answer is inside
    # Move token_start_index / token_end_index to the two sides of the answer.
    # The start scan is bounded by the last context token (token_end_index has not
    # been moved yet): special tokens after the context have offset (0, 0), so an
    # unbounded scan would run past the context when the answer is its last token.
    while token_start_index <= token_end_index and offsets[token_start_index][0] <= start_char:
        token_start_index += 1  # cursor started on the first context token
    start_position = token_start_index - 1
    # The end scan never moves before the start token.
    while token_end_index >= start_position and offsets[token_end_index][1] >= end_char:
        token_end_index -= 1  # cursor started on the last context token
    end_position = token_end_index + 1
    print("start_position: {}, end_position: {}".format(start_position, end_position))
else:  # answer is outside this slice
    print("The answer is not in this feature.")

start_position: 23, end_position: 26


In [85]:
# Decode the located token span and print it next to the annotated answer text.
print(tokenizer.decode(tokenized_example["input_ids"][0][start_position: end_position+1]))
print(answers["text"][0])

over 1, 600
over 1,600


In [86]:
pad_on_right = tokenizer.padding_side == "right" # True when padding is on the right; the context is then passed as the second input

In [87]:
def prepare_train_features(examples):
    """Tokenize a batch of QA examples into training features with answer labels.

    Each (question, context) pair is tokenized with truncation applied to the
    context only; a context that does not fit in ``max_length`` is split into
    several overlapping slices (overlap of ``doc_stride`` tokens), each padded
    to ``max_length``. Every slice becomes one feature.

    Args:
        examples: Batch mapping with ``"question"``, ``"context"`` and
            ``"answers"`` columns; each answers entry holds ``"text"`` and
            ``"answer_start"`` lists.

    Returns:
        The tokenizer output for all slices, extended with
        ``"start_positions"`` and ``"end_positions"``: the token indices of the
        first answer inside the slice, or the CLS token index for both when the
        example has no answer or the answer is not fully inside the slice.

    Reads the notebook globals ``tokenizer``, ``pad_on_right``, ``max_length``
    and ``doc_stride``.
    """
    # Sequence id carried by context tokens: 1 when the context is the second
    # input (pad_on_right), otherwise 0.
    context_seq_id = 1 if pad_on_right else 0

    # Truncate and pad while keeping all information: an over-long example is
    # split into several inputs, and adjacent inputs overlap.
    tokenized_examples = tokenizer(
        examples["question" if pad_on_right else "context"],
        examples["context" if pad_on_right else "question"],
        truncation="only_second" if pad_on_right else "only_first",
        max_length=max_length,
        stride=doc_stride,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
    )

    # overflow_to_sample_mapping maps each slice to its source example, e.g. two
    # examples split into four slices give [0, 0, 1, 1].
    sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")
    # offset_mapping (one entry per slice) maps tokens back to character spans in
    # the original text, where the answers are annotated.
    offset_mapping = tokenized_examples.pop("offset_mapping")

    # Labels to fill in, one pair per slice.
    tokenized_examples["start_positions"] = []
    tokenized_examples["end_positions"] = []

    for i, offsets in enumerate(offset_mapping):
        input_ids = tokenized_examples["input_ids"][i]
        # Slices without a usable answer are labelled at the CLS token.
        cls_index = input_ids.index(tokenizer.cls_token_id)

        # Sequence ids tell question tokens from context tokens.
        sequence_ids = tokenized_examples.sequence_ids(i)

        # Answers of the original example this slice came from.
        answers = examples["answers"][sample_mapping[i]]

        if len(answers["answer_start"]) == 0:
            # No answer at all: label CLS.
            tokenized_examples["start_positions"].append(cls_index)
            tokenized_examples["end_positions"].append(cls_index)
            continue

        # Character-level span of the first answer.
        start_char = answers["answer_start"][0]
        end_char = start_char + len(answers["text"][0])

        # First and last context token of this slice.
        context_start = 0
        while sequence_ids[context_start] != context_seq_id:
            context_start += 1
        context_end = len(input_ids) - 1
        while sequence_ids[context_end] != context_seq_id:
            context_end -= 1

        # Answer not fully inside this slice: label CLS.
        if not (offsets[context_start][0] <= start_char and offsets[context_end][1] >= end_char):
            tokenized_examples["start_positions"].append(cls_index)
            tokenized_examples["end_positions"].append(cls_index)
            continue

        # Both scans stay inside [context_start, context_end]. Special and
        # padding tokens carry offset (0, 0), so a scan bounded only by
        # len(offsets) runs past the context when the answer starts at the last
        # context token and yields a start index after the end index.
        token_start_index = context_start
        while token_start_index <= context_end and offsets[token_start_index][0] <= start_char:
            token_start_index += 1
        tokenized_examples["start_positions"].append(token_start_index - 1)

        token_end_index = context_end
        while token_end_index >= context_start and offsets[token_end_index][1] >= end_char:
            token_end_index -= 1
        tokenized_examples["end_positions"].append(token_end_index + 1)

    return tokenized_examples

In [88]:
# Apply prepare_train_features to every split in batches, dropping the original
# columns so only the tokenizer outputs and the start/end labels remain.
tokenized_datasets = datasets.map(prepare_train_features, batched=True, remove_columns=datasets["train"].column_names)

Loading cached processed dataset at C:/Users/96212/.cache/huggingface/datasets/squad/plain_text/1.0.0/d6ec3ceb99ca480ce37cdd35555d6cb2511d223b9150cce08a837ef62ffea453\cache-c012c1f472b79710.arrow


  0%|          | 0/11 [00:00<?, ?ba/s]

In [89]:
from transformers import AutoModelForQuestionAnswering
# Load a question-answering model from the same checkpoint as the tokenizer.
# This needs PyTorch installed; without it the call raises the ImportError shown below.
model = AutoModelForQuestionAnswering.from_pretrained(model_checkpoint)

ImportError: 
AutoModelForQuestionAnswering requires the PyTorch library but it was not found in your environment. Checkout the instructions on the
installation page: https://pytorch.org/get-started/locally/ and follow the ones that match your environment.
Please note that you may need to restart your runtime after installation.


In [None]:
# Scratch cell (never executed): inspect the end-token cursor left by the span-location cells.
token_end_index

In [None]:
# Scratch cell (never executed): length of the sequence-id list currently held in the kernel.
len(sequence_ids)

In [None]:
# Scratch cell (never executed): repeat inspection of the end-token cursor.
token_end_index