In [1]:
from transformers import AutoTokenizer

#加载分词工具
tokenizer = AutoTokenizer.from_pretrained('distilbert-base-uncased')

tokenizer('What is your name?', 'My name is Sylvain.')

{'input_ids': [101, 2054, 2003, 2115, 2171, 1029, 102, 2026, 2171, 2003, 25353, 22144, 2378, 1012, 102], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}

In [2]:
from datasets import load_dataset, load_from_disk

#加载数据集
#dataset = load_dataset('squad')
dataset = load_from_disk('datas/squad')

#采样,数据量太大了跑不动
dataset['train'] = dataset['train'].shuffle().select(range(10000))
dataset['validation'] = dataset['validation'].shuffle().select(range(200))

print(dataset['train'][0])

dataset

{'id': '5725d643271a42140099d285', 'title': 'Hellenistic_period', 'context': "Ptolemy's family ruled Egypt until the Roman conquest of 30 BC. All the male rulers of the dynasty took the name Ptolemy. Ptolemaic queens, some of whom were the sisters of their husbands, were usually called Cleopatra, Arsinoe or Berenice. The most famous member of the line was the last queen, Cleopatra VII, known for her role in the Roman political battles between Julius Caesar and Pompey, and later between Octavian and Mark Antony. Her suicide at the conquest by Rome marked the end of Ptolemaic rule in Egypt though Hellenistic culture continued to thrive in Egypt throughout the Roman and Byzantine periods until the Muslim conquest.", 'question': "Till what year did Ptolemy's family rule Egypt?", 'answers': {'text': ['30 BC'], 'answer_start': [57]}}


DatasetDict({
    train: Dataset({
        features: ['id', 'title', 'context', 'question', 'answers'],
        num_rows: 10000
    })
    validation: Dataset({
        features: ['id', 'title', 'context', 'question', 'answers'],
        num_rows: 200
    })
})

In [3]:
#从官方教程里抄出来的函数,总之就是squad数据的处理函数,过程非常复杂,即使是官方的实现也是有问题的,我实在没本事写这个
def prepare_train_features(examples):
    # Some of the questions have lots of whitespace on the left, which is not useful and will make the
    # truncation of the context fail (the tokenized question will take a lots of space). So we remove that
    # left whitespace
    examples["question"] = [q.lstrip() for q in examples["question"]]

    # Tokenize our examples with truncation and padding, but keep the overflows using a stride. This results
    # in one example possible giving several features when a context is long, each of those features having a
    # context that overlaps a bit the context of the previous feature.
    tokenized_examples = tokenizer(
        examples['question'],
        examples['context'],
        truncation='only_second',
        max_length=384,
        stride=128,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding='max_length',
    )

    # Since one example might give us several features if it has a long context, we need a map from a feature to
    # its corresponding example. This key gives us just that.
    sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")
    # The offset mappings will give us a map from token to character position in the original context. This will
    # help us compute the start_positions and end_positions.
    offset_mapping = tokenized_examples.pop("offset_mapping")

    # Let's label those examples!
    tokenized_examples["start_positions"] = []
    tokenized_examples["end_positions"] = []

    for i, offsets in enumerate(offset_mapping):
        # We will label impossible answers with the index of the CLS token.
        input_ids = tokenized_examples["input_ids"][i]
        cls_index = input_ids.index(tokenizer.cls_token_id)

        # Grab the sequence corresponding to that example (to know what is the context and what is the question).
        sequence_ids = tokenized_examples.sequence_ids(i)

        # One example can give several spans, this is the index of the example containing this span of text.
        sample_index = sample_mapping[i]
        answers = examples["answers"][sample_index]
        # If no answers are given, set the cls_index as answer.
        if len(answers["answer_start"]) == 0:
            tokenized_examples["start_positions"].append(cls_index)
            tokenized_examples["end_positions"].append(cls_index)
        else:
            # Start/end character index of the answer in the text.
            start_char = answers["answer_start"][0]
            end_char = start_char + len(answers["text"][0])

            # Start token index of the current span in the text.
            token_start_index = 0
            while sequence_ids[token_start_index] != 1:
                token_start_index += 1

            # End token index of the current span in the text.
            token_end_index = len(input_ids) - 1
            while sequence_ids[token_end_index] != 1:
                token_end_index -= 1

            # Detect if the answer is out of the span (in which case this feature is labeled with the CLS index).
            if not (offsets[token_start_index][0] <= start_char
                    and offsets[token_end_index][1] >= end_char):
                tokenized_examples["start_positions"].append(cls_index)
                tokenized_examples["end_positions"].append(cls_index)
            else:
                # Otherwise move the token_start_index and token_end_index to the two ends of the answer.
                # Note: we could go after the last offset if the answer is the last word (edge case).
                while token_start_index < len(offsets) and offsets[
                        token_start_index][0] <= start_char:
                    token_start_index += 1
                tokenized_examples["start_positions"].append(
                    token_start_index - 1)
                while offsets[token_end_index][1] >= end_char:
                    token_end_index -= 1
                tokenized_examples["end_positions"].append(token_end_index + 1)

    return tokenized_examples


#调用squad数据预处理函数
examples = prepare_train_features(dataset['train'][:3])

#先看看处理后的结果
for k, v in examples.items():
    print(k, len(v), v)
    print()

#还原成文字查看,很显然,即使是huggingface的实现也是有问题的
for i in range(len(examples['input_ids'])):
    input_ids = examples['input_ids'][i]
    start_positions = examples['start_positions'][i]
    end_positions = examples['end_positions'][i]

    print('问题和文本')
    question_and_context = tokenizer.decode(input_ids)
    print(question_and_context)

    print('答案')
    answer = tokenizer.decode(input_ids[start_positions:end_positions])
    print(answer)

    print('原答案')
    original_answer = dataset['train'][i]['answers']['text'][0]
    print(original_answer)
    print()

input_ids 3 [[101, 6229, 2054, 2095, 2106, 23517, 1005, 1055, 2155, 3627, 5279, 1029, 102, 23517, 1005, 1055, 2155, 5451, 5279, 2127, 1996, 3142, 9187, 1997, 2382, 4647, 1012, 2035, 1996, 3287, 11117, 1997, 1996, 5321, 2165, 1996, 2171, 23517, 1012, 13866, 9890, 2863, 2594, 8603, 1010, 2070, 1997, 3183, 2020, 1996, 5208, 1997, 2037, 19089, 1010, 2020, 2788, 2170, 22003, 1010, 29393, 5740, 2063, 2030, 2022, 7389, 6610, 1012, 1996, 2087, 3297, 2266, 1997, 1996, 2240, 2001, 1996, 2197, 3035, 1010, 22003, 8890, 1010, 2124, 2005, 2014, 2535, 1999, 1996, 3142, 2576, 7465, 2090, 10396, 11604, 1998, 13433, 8737, 3240, 1010, 1998, 2101, 2090, 13323, 21654, 1998, 2928, 16262, 1012, 2014, 5920, 2012, 1996, 9187, 2011, 4199, 4417, 1996, 2203, 1997, 13866, 9890, 2863, 2594, 3627, 1999, 5279, 2295, 27464, 3226, 2506, 2000, 25220, 1999, 5279, 2802, 1996, 3142, 1998, 8734, 6993, 2127, 1996, 5152, 9187, 1012, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,

[CLS] till what year did ptolemy's family rule egypt? [SEP] ptolemy's family ruled egypt until the roman conquest of 30 bc. all the male rulers of the dynasty took the name ptolemy. ptolemaic queens, some of whom were the sisters of their husbands, were usually called cleopatra, arsinoe or berenice. the most famous member of the line was the last queen, cleopatra vii, known for her role in the roman political battles between julius caesar and pompey, and later between octavian and mark antony. her suicide at the conquest by rome marked the end of ptolemaic rule in egypt though hellenistic culture continued to thrive in egypt throughout the roman and byzantine periods until the muslim conquest. [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PA

In [4]:
#应用预处理函数
dataset = dataset.map(
    function=prepare_train_features,
    batched=True,
    remove_columns=['id', 'title', 'context', 'question', 'answers'])

print(dataset['train'][0])

dataset

  0%|          | 0/10 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

{'input_ids': [101, 6229, 2054, 2095, 2106, 23517, 1005, 1055, 2155, 3627, 5279, 1029, 102, 23517, 1005, 1055, 2155, 5451, 5279, 2127, 1996, 3142, 9187, 1997, 2382, 4647, 1012, 2035, 1996, 3287, 11117, 1997, 1996, 5321, 2165, 1996, 2171, 23517, 1012, 13866, 9890, 2863, 2594, 8603, 1010, 2070, 1997, 3183, 2020, 1996, 5208, 1997, 2037, 19089, 1010, 2020, 2788, 2170, 22003, 1010, 29393, 5740, 2063, 2030, 2022, 7389, 6610, 1012, 1996, 2087, 3297, 2266, 1997, 1996, 2240, 2001, 1996, 2197, 3035, 1010, 22003, 8890, 1010, 2124, 2005, 2014, 2535, 1999, 1996, 3142, 2576, 7465, 2090, 10396, 11604, 1998, 13433, 8737, 3240, 1010, 1998, 2101, 2090, 13323, 21654, 1998, 2928, 16262, 1012, 2014, 5920, 2012, 1996, 9187, 2011, 4199, 4417, 1996, 2203, 1997, 13866, 9890, 2863, 2594, 3627, 1999, 5279, 2295, 27464, 3226, 2506, 2000, 25220, 1999, 5279, 2802, 1996, 3142, 1998, 8734, 6993, 2127, 1996, 5152, 9187, 1012, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0

DatasetDict({
    train: Dataset({
        features: ['input_ids', 'attention_mask', 'start_positions', 'end_positions'],
        num_rows: 10107
    })
    validation: Dataset({
        features: ['input_ids', 'attention_mask', 'start_positions', 'end_positions'],
        num_rows: 203
    })
})

In [5]:
import torch
from transformers.data.data_collator import default_data_collator

#数据加载器
loader = torch.utils.data.DataLoader(
    dataset=dataset['train'],
    batch_size=8,
    collate_fn=default_data_collator,
    shuffle=True,
    drop_last=True,
)

for i, data in enumerate(loader):
    break

len(loader), data

(1263,
 {'input_ids': tensor([[  101,  2073,  2106,  ...,     0,     0,     0],
          [  101,  2054,  2001,  ...,     0,     0,     0],
          [  101, 15854,  6238,  ...,     0,     0,     0],
          ...,
          [  101,  2040,  2001,  ...,     0,     0,     0],
          [  101,  2054,  2724,  ...,     0,     0,     0],
          [  101,  2040,  2550,  ...,     0,     0,     0]]),
  'attention_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],
          [1, 1, 1,  ..., 0, 0, 0],
          [1, 1, 1,  ..., 0, 0, 0],
          ...,
          [1, 1, 1,  ..., 0, 0, 0],
          [1, 1, 1,  ..., 0, 0, 0],
          [1, 1, 1,  ..., 0, 0, 0]]),
  'start_positions': tensor([ 20,  55, 134,  47, 100, 153,  73,  41]),
  'end_positions': tensor([ 23,  58, 136,  51, 100, 154,  75,  42])})

In [6]:
from transformers import AutoModelForQuestionAnswering, DistilBertModel

#加载模型
#model = AutoModelForQuestionAnswering.from_pretrained('distilbert-base-uncased')


#定义下游任务模型
class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.pretrained = DistilBertModel.from_pretrained(
            'distilbert-base-uncased')

        self.fc = torch.nn.Sequential(torch.nn.Dropout(0.1),
                                      torch.nn.Linear(768, 2))
        
        #加载预训练模型的参数
        parameters = AutoModelForQuestionAnswering.from_pretrained('distilbert-base-uncased')
        self.fc[1].load_state_dict(parameters.qa_outputs.state_dict())

    def forward(self, input_ids, attention_mask, start_positions,
                end_positions):
        #[b, lens] -> [b, lens, 768]
        logits = self.pretrained(input_ids=input_ids,
                                 attention_mask=attention_mask)
        logits = logits.last_hidden_state

        #[b, lens, 768] -> [b, lens, 2]
        logits = self.fc(logits)

        #[b, lens, 2] -> [b, lens, 1],[b, lens, 1]
        start_logits, end_logits = logits.split(1, dim=2)

        #[b, lens, 1] -> [b, lens]
        start_logits = start_logits.squeeze(2)
        end_logits = end_logits.squeeze(2)

        #起点和终点都不能超出句子的长度
        lens = start_logits.shape[1]
        start_positions = start_positions.clamp(0, lens)
        end_positions = end_positions.clamp(0, lens)

        criterion = torch.nn.CrossEntropyLoss(ignore_index=lens)

        start_loss = criterion(start_logits, start_positions)
        end_loss = criterion(end_logits, end_positions)
        loss = (start_loss + end_loss) / 2

        return {
            'loss': loss,
            'start_logits': start_logits,
            'end_logits': end_logits
        }


model = Model()

#统计参数量
print(sum(i.numel() for i in model.parameters()) / 10000)

out = model(**data)

out['loss'], out['start_logits'].shape, out['end_logits'].shape

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertModel: ['vocab_transform.weight', 'vocab_layer_norm.weight', 'vocab_projector.bias', 'vocab_projector.weight', 'vocab_transform.bias', 'vocab_layer_norm.bias']
- This IS expected if you are initializing DistilBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForQuestionAnswering: ['vocab_transform.weight', 'vocab_layer_norm.weight', 'vocab_projector.bias', 'vocab_projector.weight', 'vocab_transform.bias', 'vocab_layer_

6636.4418


(tensor(5.8415, grad_fn=<DivBackward0>),
 torch.Size([8, 384]),
 torch.Size([8, 384]))

In [7]:
#测试
def test():
    model.eval()

    #数据加载器
    loader_val = torch.utils.data.DataLoader(
        dataset=dataset['validation'],
        batch_size=16,
        collate_fn=default_data_collator,
        shuffle=True,
        drop_last=True,
    )

    start_offset = 0
    end_offset = 0
    total = 0
    for i, data in enumerate(loader_val):
        #计算
        with torch.no_grad():
            out = model(**data)

        start_offset += (out['start_logits'].argmax(dim=1) -
                         data['start_positions']).abs().sum().item()

        end_offset += (out['end_logits'].argmax(dim=1) -
                       data['end_positions']).abs().sum().item()

        total += 16

        if i % 10 == 0:
            print(i)

        if i == 50:
            break

    print(start_offset / total, end_offset / total)

    start_logits = out['start_logits'].argmax(dim=1)
    end_logits = out['end_logits'].argmax(dim=1)
    for i in range(4):
        input_ids = data['input_ids'][i]

        pred_answer = input_ids[start_logits[i]:end_logits[i]]

        label_answer = input_ids[
            data['start_positions'][i]:data['end_positions'][i]]

        print('input_ids=', tokenizer.decode(input_ids))
        print('pred_answer=', tokenizer.decode(pred_answer))
        print('label_answer=', tokenizer.decode(label_answer))
        print()


test()

0
10
58.484375 61.0
input_ids= [CLS] what shape are pyrenoids? [SEP] the chloroplasts of some hornworts and algae contain structures called pyrenoids. they are not found in higher plants. pyrenoids are roughly spherical and highly refractive bodies which are a site of starch accumulation in plants that contain them. they consist of a matrix opaque to electrons, surrounded by two hemispherical starch plates. the starch is accumulated as the pyrenoids mature. in algae with carbon concentrating mechanisms, the enzyme rubisco is found in the pyrenoids. starch can also accumulate around the pyrenoids when co2 is scarce. pyrenoids can divide to form new pyrenoids, or be produced " de novo ". [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]

In [8]:
from transformers import AdamW
from transformers.optimization import get_scheduler


#训练
def train():
    optimizer = AdamW(model.parameters(), lr=2e-5)
    scheduler = get_scheduler(name='linear',
                              num_warmup_steps=0,
                              num_training_steps=len(loader),
                              optimizer=optimizer)

    model.train()
    for i, data in enumerate(loader):
        out = model(**data)
        loss = out['loss']

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

        optimizer.step()
        scheduler.step()

        optimizer.zero_grad()
        model.zero_grad()

        if i % 50 == 0:
            lr = optimizer.state_dict()['param_groups'][0]['lr']

            start_offset = (out['start_logits'].argmax(dim=1) -
                            data['start_positions']).abs().sum().item() / 8

            end_offset = (out['end_logits'].argmax(dim=1) -
                          data['end_positions']).abs().sum().item() / 8

            print(i, loss.item(), lr, start_offset, end_offset)

    torch.save(model, 'models/3.阅读理解.model')


train()



0 5.984038352966309 1.9984164687252575e-05 79.125 57.25
50 4.169722080230713 1.9192399049881235e-05 38.875 36.75
100 4.359471321105957 1.84006334125099e-05 63.125 44.5
150 3.296112537384033 1.760886777513856e-05 43.0 45.0
200 3.181190252304077 1.6817102137767223e-05 40.75 57.875
250 3.5123019218444824 1.6025336500395887e-05 60.25 64.75
300 2.9549078941345215 1.5233570863024545e-05 17.0 40.375
350 3.4506635665893555 1.4441805225653207e-05 7.625 13.5
400 2.231827735900879 1.365003958828187e-05 8.125 7.5
450 2.395646572113037 1.2858273950910532e-05 7.375 13.0
500 2.952989101409912 1.2066508313539194e-05 30.5 34.375
550 2.6806797981262207 1.1274742676167856e-05 23.625 22.25
600 2.2227702140808105 1.0482977038796518e-05 20.375 33.875
650 1.6418771743774414 9.69121140142518e-06 10.0 12.375
700 1.9634692668914795 8.899445764053842e-06 12.75 13.25
750 1.7344026565551758 8.107680126682502e-06 36.875 10.5
800 2.1363699436187744 7.315914489311164e-06 5.0 5.0
850 1.0939271450042725 6.5241488519398

In [9]:
model = torch.load('models/3.阅读理解.model')
test()

0
10
9.78125 13.145833333333334
input_ids= [CLS] what are the top 4 - 5 % graduating students honored with? [SEP] harvard's academic programs operate on a semester calendar beginning in early september and ending in mid - may. undergraduates typically take four half - courses per term and must maintain a four - course rate average to be considered full - time. in many concentrations, students can elect to pursue a basic program or an honors - eligible program requiring a senior thesis and / or advanced course work. students graduating in the top 4 – 5 % of the class are awarded degrees summa cum laude, students in the next 15 % of the class are awarded magna cum laude, and the next 30 % of the class are awarded cum laude. harvard has chapters of academic honor societies such as phi beta kappa and various committees and departments also award several hundred named prizes annually. harvard, along with other universities, has been accused of grade inflation, although there is evidence tha