
If you're opening this notebook on Colab, you will probably need to install 🤗 Transformers and 🤗 Datasets as well as other dependencies. Run the following cell to install them.

In [2]:
! pip install datasets transformers rouge-score nltk


If you're opening this notebook locally, make sure your environment has the latest versions of those libraries installed.

To be able to share your model with the community and generate results like the one shown in the picture below via the inference API, there are a few more steps to follow.

First you have to store your authentication token from the Hugging Face website (sign up [here](https://huggingface.co/join) if you haven't already!) then execute the following cell and input your username and password:

In [3]:
from huggingface_hub import notebook_login

notebook_login()


Then you need to install Git-LFS. Run the following command:

In [4]:
!apt install git-lfs


Make sure your version of Transformers is at least 4.11.0 since the functionality was introduced in that version:

In [5]:
import transformers

print(transformers.__version__)

4.16.2
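
Comparing version strings as plain text goes wrong once a component reaches two digits (as text, "4.9" sorts after "4.11"). Below is a minimal, standard-library-only sketch of a numeric comparison. `version_tuple` and `at_least` are hypothetical helpers, not part of Transformers, and they ignore pre-release suffixes such as `.dev0`; `packaging.version.parse` is a more complete alternative.

```python
import re


def version_tuple(version: str) -> tuple:
    """Extract the numeric release part, e.g. "4.17.0.dev0" -> (4, 17, 0)."""
    match = re.match(r"\d+(?:\.\d+)*", version)
    if match is None:
        raise ValueError(f"Unrecognized version string: {version!r}")
    return tuple(int(part) for part in match.group(0).split("."))


def at_least(installed: str, minimum: str) -> bool:
    """True if `installed` is numerically >= `minimum` (suffixes ignored)."""
    a, b = version_tuple(installed), version_tuple(minimum)
    width = max(len(a), len(b))
    # Pad with zeros so that "4.11" compares equal to "4.11.0".
    a += (0,) * (width - len(a))
    b += (0,) * (width - len(b))
    return a >= b


print(at_least("4.16.2", "4.11.0"))  # True
print(at_least("4.9.1", "4.11.0"))   # False
```

In the notebook you would call it as `at_least(transformers.__version__, "4.11.0")`.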


You can find a script version of this notebook to fine-tune your model in a distributed fashion using multiple GPUs or TPUs [here](https://github.com/huggingface/transformers/tree/master/examples/seq2seq).

# Fine-tuning a model on a summarization task

In this notebook, we will see how to fine-tune one of the [🤗 Transformers](https://github.com/huggingface/transformers) models for a summarization task. We will use the [XSum dataset](https://arxiv.org/pdf/1808.08745.pdf) (for extreme summarization), which contains BBC articles accompanied by single-sentence summaries.

![Widget inference on a summarization task](https://github.com/huggingface/notebooks/blob/master/examples/images/summarization.png?raw=1)

We will see how to easily load the dataset for this task using 🤗 Datasets and how to fine-tune a model on it using the `Trainer` API.

In [6]:
model_checkpoint = "t5-small"

This notebook is built to run with any model checkpoint from the [Model Hub](https://huggingface.co/models) as long as that model has a sequence-to-sequence version in the Transformers library. Here we picked the [`t5-small`](https://huggingface.co/t5-small) checkpoint.

## Loading the dataset

We will use the [🤗 Datasets](https://github.com/huggingface/datasets) library to download the data and get the metric we need to use for evaluation (to compare our model to the benchmark). This can be easily done with the functions `load_dataset` and `load_metric`.  

In [7]:
from datasets import load_dataset, load_metric

raw_datasets = load_dataset("xsum")
metric = load_metric("rouge")

Using custom data configuration default


Downloading and preparing dataset xsum/default (download: 245.38 MiB, generated: 507.60 MiB, post-processed: Unknown size, total: 752.98 MiB) to /root/.cache/huggingface/datasets/xsum/default/1.2.0/32c23220eadddb1149b16ed2e9430a05293768cfffbdfd151058697d4c11f934...


Dataset xsum downloaded and prepared to /root/.cache/huggingface/datasets/xsum/default/1.2.0/32c23220eadddb1149b16ed2e9430a05293768cfffbdfd151058697d4c11f934. Subsequent calls will reuse this data.


The `raw_datasets` object itself is a [`DatasetDict`](https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasetdict), which contains one key each for the training, validation and test sets:

In [8]:
raw_datasets

DatasetDict({
    train: Dataset({
        features: ['document', 'summary', 'id'],
        num_rows: 204045
    })
    validation: Dataset({
        features: ['document', 'summary', 'id'],
        num_rows: 11332
    })
    test: Dataset({
        features: ['document', 'summary', 'id'],
        num_rows: 11334
    })
})

To access an actual element, you need to select a split first, then give an index:

In [9]:
raw_datasets["train"][0]

 'id': '35232142',
 'summary': 'Clean-up operations are continuing across the Scottish Borders and Dumfries and Galloway after flooding caused by Storm Frank.'}

To get a sense of what the data looks like, the following function shows a few examples picked at random from the dataset.

In [10]:
import datasets
import random
import pandas as pd
from IPython.display import display, HTML

def show_random_elements(dataset, num_examples=5):
    assert num_examples <= len(dataset), "Can't pick more elements than there are in the dataset."
    # Pick num_examples distinct random row indices
    picks = []
    for _ in range(num_examples):
        pick = random.randint(0, len(dataset)-1)
        while pick in picks:
            pick = random.randint(0, len(dataset)-1)
        picks.append(pick)
    # Build a DataFrame from the sampled rows and show class labels by name
    df = pd.DataFrame(dataset[picks])
    for column, typ in dataset.features.items():
        if isinstance(typ, datasets.ClassLabel):
            df[column] = df[column].transform(lambda i: typ.names[i])
    display(HTML(df.to_html()))
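
`show_random_elements` re-draws an index until it finds one that has not been picked yet. The standard library's `random.sample` draws without replacement in a single call, so the loop could also be written as `picks = random.sample(range(len(dataset)), num_examples)`. The sketch below isolates that step; `pick_distinct_indices` and its `seed` parameter are hypothetical additions for reproducibility, not part of the original notebook.

```python
import random


def pick_distinct_indices(n_rows: int, num_examples: int, seed=None) -> list:
    """Return num_examples distinct row indices in [0, n_rows)."""
    if num_examples > n_rows:
        raise ValueError("Can't pick more elements than there are in the dataset.")
    rng = random.Random(seed)
    # sample() draws without replacement, so no retry loop is needed.
    return rng.sample(range(n_rows), num_examples)


print(pick_distinct_indices(10, 5, seed=0))
```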

In [11]:
show_random_elements(raw_datasets["train"])

index,document,summary,id
0,"Paul Cummins received threats by email, phone and letter over plans for services charities to benefit from the £10m raised, he told the Sunday Times.\nThe Derbyshire artist said he believed it was because some people felt the charities were ""involved in war"".\nBlood Swept Lands And Seas Of Red saw 888,246 ceramic poppies ""planted"" in the moat.\nEach poppy represented one British or colonial military death during World War One.\nThe installation culminated on Armistice Day when the final poppy was planted by 13-year old cadet Harry Hayes from Berkshire, before the guns were fired 21 times and a two-minute silence was observed.\nMr Cummins said: ""The threats came, I suppose, because they felt that the money was going to charities which in some way were involved in war.""\nHe added that police had been called in over the matter.\nIt is thought about five million people visited Blood-Swept Lands and Seas of Red, the title of which was inspired by a line from the will of a Derbyshire serviceman who died in Flanders.\nHe had described ""the blood-swept lands and seas of red, where angels fear to tread"".\nThe following day a team of 8,000 volunteers started removing the 888,246 poppies and sending them to the people who bought them for £25 each.\nThe net proceeds plus 10% of every sale generated the £10m being shared between Help for Heroes, the Royal British Legion, Combat Stress, Cobseo, Coming Home and SSAFA.\nMr Cummins also told the paper he was working on a new British-based project involving ceramic tulips, although the location is yet to be disclosed.\nHe added that he had had offers to work on projects elsewhere in Europe as well as another ""distant part of the world"".\nWorld War One Centenary","The artist behind the Tower of London poppy exhibition received death threats, he has revealed.",31477733
1,"Mason Jones, from Deri, Caerphilly county, died in 2005 after eating contaminated meat at school.\nWilliam John Tudor was jailed for a year for breaking food safety laws.\nBut Cardiff Civil Justice Centre heard the Crown Prosecution Service has admitted it made the ""wrong decision"" in not pressing manslaughter charges.\nThe infected meat which killed Mason came from Tudor's butchers in Bridgend which also supplied more than 40 other schools in the south Wales valleys.\nAbout 160 people fell ill during the E. coli outbreak.\nThe original inquest in 2010 into Mason's death recorded a narrative conclusion. Delivering his verdict, coroner David Bowen said: ""I have agonised over a verdict of unlawful killing but despite substantial, some might say horrific, breaches of food hygiene regulations the evidence is not strong enough.\n""There is little doubt Mason was owed a duty of care and a catalogue of failures to observe basic food hygiene breached that duty.\n""But it is not enough for there to be a breach of the duty of care, however extensive and reprehensible that may be.""\nOn Tuesday, Cardiff Civil Justice centre considered an appeal by Mason's family to get the original inquest narrative conclusion quashed and a second inquest heard.\nMason's family, including his older brother who was also infected during the outbreak, were all present in court.\nRepresenting the family, Mark Powell QC, claimed the coroner got the original verdict wrong.\n""We say the coroner misdirected the jury and misdirected himself as to the essential elements of the offence,"" he said.\n""We don't say the coroner was acting improperly, we say he was wrong.""\nFollowing Mason's death, the CPS declined to press charges of manslaughter against Tudor, of Cowbridge, Vale of Glamorgan.\nThe hearing was told that the CPS have since said to the family that was the wrong decision.\nMr Powell said the family is motivated by ""a desire to right a wrong"".\nThe hearing was told Tudor had sufficient training in food hygiene and knew there was a risk someone might die.\nThe Civil Justice Centre heard a judgement would be handed down at a later date.","A butcher whose E. coli-infected meat killed a five-year-old boy should have been charged with manslaughter, a court has heard.",34452603
2,"With a ""degree of technical competence rarely seen"", Regin had probably taken years to develop, Symantec said.\nAnd a nation state may have written it to serve its spying agencies' needs.\nThe program had been used in ""systematic spying campaigns"" over the past six years, Symantec said.\nAimed at Windows users, Regin slowly infiltrated its targets, taking care at each stage to hide its tracks, the company said.\n""Many components of Regin remain undiscovered and additional functionality and versions may exist,"" it added.\n""Its design makes it highly suited for persistent, long-term surveillance operations against targets.""\nJason Steer, director of technology strategy at security firm FireEye, said: ""These types of toolkits have existed for a few years now.""\nHe added: ""It's a challenge to the whole security industry as to how they find these malicious and sophisticated pieces of code,""\nSecurity firms were better at spotting such things even though Regin and its ilk were built to fool modern-day tools that look for malicious programs and monitor activity to spot anything suspicious. The techniques Regin used to sneak on to a network and communicate with its creators were very complicated, he said.\n""It's clearly been written by someone that has much more than making money in mind,"" he said.\nMr Steer said the tip-offs about Regin and similarly sophisticated threats often came from government agencies who kept an eye on the cyber spying capabilities of both friendly and hostile nations.\nVictims had been infected via spoofed versions of well-known websites and by exploiting known vulnerabilities in web browser software, said Symantec in a detailed analysis.\nIn a blogpost, security company F-Secure said it had first encountered Regin in 2009 after investigating what was making a server on the network of one of its customers crash repeatedly. Closer investigation revealed the culprit to be Regin which was attempting to insert itself into the heart of the software controlling the server.\nChief research officer Mikko Hypponen said: ""Finding malware of this calibre is very rare.\n""We're still missing big parts of the puzzle.""\n""Nevertheless, it's obvious this is a very complicated malware written by a well-equipped nation-state."" He added that the malware did not look like it originated in China or Russia - the places suspected of creating many other stealthy, spying programs.\nSecurity firm Kaspersky Lab said it too had spotted Regin being used to infiltrate networks and steal data. In one attack, Regin was used to gather administrative details for a mobile phone network in the Middle East that, if used, would have given attackers control over the system.\nSymantec said it had captured the first copies of Regin in a small number of organisations between 2008 and 2011.\nSoon after, the malware had appeared to have been withdrawn, but a new version found in 2013 was now being actively used.\nOnly about 100 Regin infections have so far been identified.\nIt is believed to provide the ability to:\nSymantec said that Regin had a lot in common with other malicious programs such as Flame, Duqu and Stuxnet, also thought to be written by nation states to aid their spying efforts.","An ""extremely complex"" and ""stealthy"" spying program has been stealing data from ISPs, energy companies, airlines and research-and-development labs, a security company has said.",30145265
3,"The moors provide an ideal habitat for mountain hares.\nBut gamekeepers may cull the animals in an effort to protect red grouse from the tick-borne louping ill virus.\nWildlife groups say there is a lack of scientific evidence to support the claim that culling hares protects grouse.\nSimon Jones of the Scottish Wildlife Trust, one of the groups campaigning for a moratorium, said: ""We, along with the other organisations, are calling for a three year ban, to allow time for all those involved to take stock of the longer term impacts of large scale culling.\n""Once the results of the study have been published we will then be able to identify the best ways to monitor mountain hare populations and measure the impact that management is having on their conservation status.\n""The unregulated and seemingly unsustainable culling that is endemic on many grouse moors is a threat to these important populations.""\nDuncan Orr-Ewing of RSPB Scotland said: ""Very little is known about their current numbers of mountain hares and population trends.\n""We also don't know what impact these large scale culls are having on mountain hares' wider conservation status, which could mean the Scottish Government may be in breach of its legally binding international EU obligations to this species.""\nThe calls for a ban have drawn an angry response from gamekeepers' leaders.\nA spokesman for the Scottish Gamekeepers Association said: ""For groups with environmental credentials to call for such an environmentally irresponsible measure beggars belief.\n""The numbers of tick, already a growing problem in the countryside, will escalate, endangering any bird that nests on the ground, not to mention the potential repercussions for human health.\n""It will be bad for birds and bad for biodiversity.""\nThe ten organisations calling for the three year ban are:\n· Highland Foundation for Wildlife\n· John Muir Trust\n· National Trust for Scotland\n· RSPB Scotland\n· Royal Zoological Society of Scotland\n· Scottish Raptor Study Group\n· The Scottish Wildlife Trust\n· The Cairngorms Campaign\n· The Mammal Society\n· The Badenoch and Strathspey Conservation Group\nMeanwhile, landowners have said such a move would be ""ill informed"" and ""heavy handed"".\nTim Baynes of the Scottish Moorland Group told BBC Scotland: ""We firmly believe the management of the hare population as it is at present is sustainable.\n""This does mean some culling for sound habitat conservation reasons, but not to the extent where it threatens a population.\n""We know in broad terms the numbers of hares and it is a population that is thriving.""\nScottish Natural Heritage has pledged to work with ""all interested parties"" to make sure that management practices are not damaging the species' long-term prospects.\nResponding to the call for a moratorium, a spokeswoman said: ""Mountain hare populations can fluctuate widely and are influenced by a range of factors - the cyclical nature of the species, habitat fragmentation, changes in land use and over-culling\n""We have already asked estates for restraint on large-scale culls of mountain hares which could jeopardise their conservation status.\n""We continue to gather evidence on numbers, status and changes and for this reason have commissioned further research by Game and Wildlife Conservation Trust and the James Hutton Institute to inform local and national management.""",Ten wildlife and conservation organisations are calling for a three-year ban on the culling of mountain hares on Scotland's grouse moors.,32292965
4,"The number 91 bus hit the trees on Kingsway, near the London School of Economics (LSE) building, in Holborn.\nLondon Ambulance Service said two people were taken to hospital with facial injuries. Two others were treated for minor injuries.\nKingsway was closed between the Great Queen Street and A4 Aldwych junctions, but has since re-opened.\nAt the scene, London Fire Brigade station manager, Gary Squires, said: ""Those involved were very lucky to escape serious injury.""\nLSE student Ethan Meade said he turned around when he heard a crash.\n""I saw the roof fall down off the side of the bus and the glass shatter everywhere.\n""The passengers seemed to be sitting there pretty stunned, as you'd expect. Police seemed to handle it very well.""",The roof of a bus has been ripped off after it hit overhanging trees in central London.,31097083


The metric is an instance of [`datasets.Metric`](https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Metric):

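In other words, the metric is a scoring function: given model predictions and reference texts, it reports how closely they match, in terms such as precision, recall and F1 (`fmeasure`).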

In [12]:
metric

Metric(name: "rouge", features: {'predictions': Value(dtype='string', id='sequence'), 'references': Value(dtype='string', id='sequence')}, usage: """
Calculates average rouge scores for a list of hypotheses and references
Args:
    predictions: list of predictions to score. Each predictions
        should be a string with tokens separated by spaces.
    references: list of reference for each prediction. Each
        reference should be a string with tokens separated by spaces.
    rouge_types: A list of rouge types to calculate.
        Valid names:
        `"rouge{n}"` (e.g. `"rouge1"`, `"rouge2"`) where: {n} is the n-gram based scoring,
        `"rougeL"`: Longest common subsequence based scoring.
        `"rougeLSum"`: rougeLsum splits text using `"\n"`.
        See details in https://github.com/huggingface/datasets/issues/617
    use_stemmer: Bool indicating whether Porter stemmer should be used to strip word suffixes.
    use_agregator: Return aggregates if this is set to True
Retu

You can call its `compute` method with your predictions and labels, which need to be lists of decoded strings:

In [13]:
fake_preds = ["hello there", "general kenobi"]
fake_labels = ["hello there", "general kenobi"]
metric.compute(predictions=fake_preds, references=fake_labels)

{'rouge1': AggregateScore(low=Score(precision=1.0, recall=1.0, fmeasure=1.0), mid=Score(precision=1.0, recall=1.0, fmeasure=1.0), high=Score(precision=1.0, recall=1.0, fmeasure=1.0)),
 'rouge2': AggregateScore(low=Score(precision=1.0, recall=1.0, fmeasure=1.0), mid=Score(precision=1.0, recall=1.0, fmeasure=1.0), high=Score(precision=1.0, recall=1.0, fmeasure=1.0)),
 'rougeL': AggregateScore(low=Score(precision=1.0, recall=1.0, fmeasure=1.0), mid=Score(precision=1.0, recall=1.0, fmeasure=1.0), high=Score(precision=1.0, recall=1.0, fmeasure=1.0)),
 'rougeLsum': AggregateScore(low=Score(precision=1.0, recall=1.0, fmeasure=1.0), mid=Score(precision=1.0, recall=1.0, fmeasure=1.0), high=Score(precision=1.0, recall=1.0, fmeasure=1.0))}
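
To make these scores concrete, here is a simplified sketch of ROUGE-1, the clipped unigram overlap between a prediction and a reference. It is an illustration under stated assumptions, not the metric's implementation: `rouge1` is a hypothetical helper that lowercases and splits on whitespace and skips stemming, whereas the `rouge_score` package behind this metric tokenizes more carefully (for example, it drops punctuation) and can apply a Porter stemmer (`use_stemmer`). The `low`/`mid`/`high` fields above come from bootstrap resampling over the examples; `mid` is the central estimate.

```python
from collections import Counter


def rouge1(prediction: str, reference: str) -> dict:
    """Simplified ROUGE-1 (assumption: lowercase whitespace tokens, no stemming)."""
    pred_counts = Counter(prediction.lower().split())
    ref_counts = Counter(reference.lower().split())
    # A shared unigram counts at most as often as it appears in both texts.
    overlap = sum((pred_counts & ref_counts).values())
    if overlap == 0:
        return {"precision": 0.0, "recall": 0.0, "fmeasure": 0.0}
    precision = overlap / sum(pred_counts.values())
    recall = overlap / sum(ref_counts.values())
    fmeasure = 2 * precision * recall / (precision + recall)
    return {"precision": precision, "recall": recall, "fmeasure": fmeasure}


print(rouge1("hello there", "hello there"))  # all three scores are 1.0
print(rouge1("the cat sat", "the cat"))      # precision 2/3, recall 1.0, fmeasure 0.8
```

The second call shows why an extra word lowers precision but not recall: every reference unigram is covered, but one predicted unigram has no match.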

## Preprocessing the data

Before we can feed those texts to our model, we need to preprocess them. This is done by a 🤗 Transformers `Tokenizer`, which will (as the name indicates) tokenize the inputs (including converting the tokens to their corresponding IDs in the pretrained vocabulary) and put them in the format the model expects, as well as generate the other inputs that the model requires.

To do all of this, we instantiate our tokenizer with the `AutoTokenizer.from_pretrained` method, which will ensure:

- we get a tokenizer that corresponds to the model architecture we want to use,
- we download the vocabulary used when pretraining this specific checkpoint.

That vocabulary will be cached, so it's not downloaded again the next time we run the cell.

In [14]:
from transformers import AutoTokenizer
    
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)


By default, the call above will use one of the fast tokenizers (backed by Rust) from the 🤗 Tokenizers library.

You can directly call this tokenizer on one sentence or a pair of sentences:

In [15]:
tokenizer("Hello, this one sentence!")

{'input_ids': [8774, 6, 48, 80, 7142, 55, 1], 'attention_mask': [1, 1, 1, 1, 1, 1, 1]}

Depending on the model you selected, you will see different keys in the dictionary returned by the cell above. They don't matter much for what we're doing here (just know they are required by the model we will instantiate later); you can learn more about them in [this tutorial](https://huggingface.co/transformers/preprocessing.html) if you're interested.

Instead of one sentence, we can pass along a list of sentences:

In [16]:
tokenizer(["Hello, this one sentence!", "This is another sentence."])

{'input_ids': [[8774, 6, 48, 80, 7142, 55, 1], [100, 19, 430, 7142, 5, 1]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1]]}

To prepare the targets for our model, we need to tokenize them inside the `as_target_tokenizer` context manager. This will make sure the tokenizer uses the special tokens corresponding to the targets:

In [17]:
with tokenizer.as_target_tokenizer():
    print(tokenizer(["Hello, this one sentence!", "This is another sentence."]))

{'input_ids': [[8774, 6, 48, 80, 7142, 55, 1], [100, 19, 430, 7142, 5, 1]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1]]}


If you are using one of the five T5 checkpoints, you have to prefix the inputs with "summarize:" (the model can also translate, and it needs the prefix to know which task it has to perform).

In [18]:
if model_checkpoint in ["t5-small", "t5-base", "t5-large", "t5-3b", "t5-11b"]:
    prefix = "summarize: "
else:
    prefix = ""

We can then write the function that will preprocess our samples. We just feed them to the `tokenizer` with the argument `truncation=True`. This ensures that an input longer than what the selected model can handle is truncated to the maximum length accepted by the model. Padding is dealt with later on (in a data collator), so that we pad examples to the longest length in the batch and not in the whole dataset.
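
Padding each batch only to its own longest sequence, rather than to the longest sequence in the whole dataset, keeps short batches short. The sketch below shows the idea on the two token-ID lists the tokenizer produced earlier. `pad_batch` is a hypothetical helper, and `pad_id=0` assumes T5's padding token ID; in the notebook this job is done by the data collator.

```python
def pad_batch(batch, pad_id=0):
    """Right-pad every ID list to the length of the longest one in this batch."""
    longest = max(len(ids) for ids in batch)
    input_ids = [ids + [pad_id] * (longest - len(ids)) for ids in batch]
    # 1 marks a real token, 0 marks padding the model should ignore.
    attention_mask = [[1] * len(ids) + [0] * (longest - len(ids)) for ids in batch]
    return {"input_ids": input_ids, "attention_mask": attention_mask}


batch = [[8774, 6, 48, 80, 7142, 55, 1], [100, 19, 430, 7142, 5, 1]]
padded = pad_batch(batch)
print(padded["input_ids"][1])       # [100, 19, 430, 7142, 5, 1, 0]
print(padded["attention_mask"][1])  # [1, 1, 1, 1, 1, 1, 0]
```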

In [19]:
max_input_length = 1024
max_target_length = 128

def preprocess_function(examples):
    '''
    Convert the data (document text and summary) into the corresponding token IDs.
    '''
    inputs = [prefix + doc for doc in examples["document"]]
    model_inputs = tokenizer(inputs, max_length=max_input_length, truncation=True)

    # Setup the tokenizer for targets
    with tokenizer.as_target_tokenizer():
        labels = tokenizer(examples["summary"], max_length=max_target_length, truncation=True)

    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

This function works with one or several examples. In the case of several examples, the tokenizer will return a list of lists for each key:

In [20]:
preprocess_function(raw_datasets['train'][:2])

{'input_ids': [[21603, 10, 37, 423, 583, 13, 1783, 16, 20126, 16496, 6, 80, 13, 8, 844, 6025, 4161, 6, 19, 341, 271, 14841, 5, 7057, 161, 19, 4912, 16, 1626, 5981, 11, 186, 7540, 16, 1276, 15, 2296, 7, 5718, 2367, 14621, 4161, 57, 4125, 387, 5, 15059, 7, 30, 8, 4653, 4939, 711, 747, 522, 17879, 788, 12, 1783, 44, 8, 15763, 6029, 1813, 9, 7472, 5, 1404, 1623, 11, 5699, 277, 130, 4161, 57, 18368, 16, 20126, 16496, 227, 8, 2473, 5895, 15, 147, 89, 22411, 139, 8, 1511, 5, 1485, 3271, 3, 21926, 9, 472, 19623, 5251, 8, 616, 12, 15614, 8, 1783, 5, 37, 13818, 10564, 15, 26, 3, 9, 3, 19513, 1481, 6, 18368, 186, 1328, 2605, 30, 7488, 1887, 3, 18, 8, 711, 2309, 9517, 89, 355, 5, 3966, 1954, 9233, 15, 6, 113, 293, 7, 8, 16548, 13363, 106, 14022, 84, 47, 14621, 4161, 6, 243, 255, 228, 59, 7828, 8, 1249, 18, 545, 11298, 1773, 728, 8, 8347, 1560, 5, 611, 6, 255, 243, 72, 1709, 1528, 161, 228, 43, 118, 4006, 91, 12, 766, 8, 3, 19513, 1481, 410, 59, 5124, 5, 96, 196, 17, 19, 1256, 68, 27, 103, 317, 132

To apply this function on all the document/summary pairs in our dataset, we just use the `map` method of the `raw_datasets` object we created earlier. This will apply the function on all the elements of all the splits in `raw_datasets`, so our training, validation and testing data will be preprocessed in one single command.

In [21]:
tokenized_datasets = raw_datasets.map(preprocess_function, batched=True)


Even better, the results are automatically cached by the 🤗 Datasets library to avoid spending time on this step the next time you run your notebook. The 🤗 Datasets library is normally smart enough to detect when the function you pass to `map` has changed (and thus the cached data must not be reused). For instance, it will properly detect if you change the task in the first cell and rerun the notebook. 🤗 Datasets warns you when it uses cached files; you can pass `load_from_cache_file=False` in the call to `map` to not use the cached files and force the preprocessing to be applied again.

Note that we passed `batched=True` to encode the texts in batches. This is to leverage the full benefit of the fast tokenizer we loaded earlier, which will use multi-threading to process the texts in a batch concurrently.
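
To see what `batched=True` means for `preprocess_function`, here is a plain-Python sketch of the calling pattern: the function receives a dict of columns, each holding a slice of rows, and returns a dict of new columns. `map_batched` and `count_words` are hypothetical stand-ins, not the 🤗 Datasets implementation. The real `map` also keeps the original columns (which is why `document`, `summary` and `id` appear in the training log), caches its results, and uses a default batch size of 1,000, which splits the 204,045 training rows into 205 batches.

```python
import math


def map_batched(columns, function, batch_size=1000):
    """Call `function` on consecutive row slices of a column dict; concatenate its outputs."""
    n_rows = len(next(iter(columns.values())))
    outputs = {}
    for start in range(0, n_rows, batch_size):
        batch = {name: values[start:start + batch_size] for name, values in columns.items()}
        for name, values in function(batch).items():
            outputs.setdefault(name, []).extend(values)
    return outputs


def count_words(batch):
    # Toy batched function: returns one value per row in the batch.
    return {"n_words": [len(doc.split()) for doc in batch["document"]]}


toy = {"document": ["a b", "c", "d e f"], "summary": ["x", "y", "z"]}
print(map_batched(toy, count_words, batch_size=2))  # {'n_words': [2, 1, 3]}
print(math.ceil(204045 / 1000))                     # 205
```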

## Fine-tuning the model

Now that our data is ready, we can download the pretrained model and fine-tune it. Since our task is of the sequence-to-sequence kind, we use the `AutoModelForSeq2SeqLM` class. Like with the tokenizer, the `from_pretrained` method will download and cache the model for us.

In [22]:
from transformers import AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer

model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)


Note that we don't get a warning like in our classification example. This means we used all the weights of the pretrained model and there is no randomly initialized head in this case.

To instantiate a `Seq2SeqTrainer`, we will need to define three more things. The most important is the [`Seq2SeqTrainingArguments`](https://huggingface.co/transformers/main_classes/trainer.html#transformers.Seq2SeqTrainingArguments), which is a class that contains all the attributes to customize the training. It requires one folder name, which will be used to save the checkpoints of the model; all other arguments are optional:

In [23]:
batch_size = 16
model_name = model_checkpoint.split("/")[-1]
args = Seq2SeqTrainingArguments(
    f"{model_name}-finetuned-xsum",   # output folder for checkpoints
    evaluation_strategy="epoch",      # evaluate at the end of every epoch
    learning_rate=2e-5,               # learning rate
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    weight_decay=0.01,                # weight decay rate
    save_total_limit=3,               # keep at most three checkpoints on disk
    num_train_epochs=1,               # number of training epochs
    predict_with_generate=True,       # use generate() so evaluation scores real summaries
    fp16=True,                        # mixed-precision training (faster; needs a CUDA GPU)
    # push_to_hub=True,
)

Here we set the evaluation to be done at the end of each epoch, tweak the learning rate, use the `batch_size` defined at the top of the cell and customize the weight decay. Since the `Seq2SeqTrainer` will save the model regularly and our dataset is quite large, we tell it to keep at most three checkpoints. Lastly, we use the `predict_with_generate` option (to properly generate summaries) and activate mixed precision training (to go a bit faster).
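
As a sanity check on these settings, the number of optimizer steps follows directly from the dataset size and the batch size. The arithmetic below assumes a single device and no gradient accumulation (both confirmed by the training log further down), plus the `Trainer` default of saving a checkpoint every 500 steps (`save_steps=500`), since this notebook does not set it.

```python
import math

num_train_examples = 204045  # size of the train split shown earlier
batch_size = 16              # per_device_train_batch_size
num_devices = 1              # assumption: one GPU
grad_accum_steps = 1         # Trainer default
num_train_epochs = 1
save_steps = 500             # assumption: Trainer default, not set in this notebook

effective_batch = batch_size * num_devices * grad_accum_steps
# The last, partial batch still counts as a step.
steps_per_epoch = math.ceil(num_train_examples / effective_batch)
total_steps = steps_per_epoch * num_train_epochs
checkpoints_written = total_steps // save_steps

print(total_steps)          # 12753
print(checkpoints_written)  # 25
```

The first number matches "Total optimization steps" in the log; with `save_total_limit=3`, only the three most recent of those checkpoints stay on disk.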

The last argument, `push_to_hub=True` (commented out in the cell above), sets everything up so we can push the model to the [Hub](https://huggingface.co/models) regularly during training. Leave it out if you didn't follow the installation steps at the top of the notebook. If you want to save your model locally under a name that is different from the name of the repository it will be pushed to, or if you want to push your model under an organization and not your namespace, use the `hub_model_id` argument to set the repo name (it needs to be the full name, including your namespace: for instance `"sgugger/t5-finetuned-xsum"` or `"huggingface/t5-finetuned-xsum"`).

Then, we need a special kind of data collator, which will pad not only the inputs to the maximum length in the batch, but also the labels:

In [24]:
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)
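
Labels are padded differently from inputs: `DataCollatorForSeq2Seq` fills them with `label_pad_token_id`, which defaults to -100, the index that PyTorch's cross-entropy loss ignores by default (`ignore_index=-100`), so padded positions do not contribute to the loss. Because we pass `model=model`, the collator can also prepare the decoder inputs from the labels. The sketch below imitates only the label padding and a masked average loss for one sequence; `pad_labels` and `masked_mean_loss` are hypothetical helpers, and right padding is an assumption that matches T5's tokenizer.

```python
def pad_labels(label_batch, label_pad_token_id=-100):
    """Right-pad target ID lists to the longest one in the batch."""
    longest = max(len(labels) for labels in label_batch)
    return [labels + [label_pad_token_id] * (longest - len(labels)) for labels in label_batch]


def masked_mean_loss(token_losses, labels, ignore_index=-100):
    """Average per-token losses over positions whose label is not ignore_index."""
    kept = [loss for loss, label in zip(token_losses, labels) if label != ignore_index]
    return sum(kept) / len(kept)


labels = pad_labels([[100, 19, 430, 1], [37, 1]])
print(labels)  # [[100, 19, 430, 1], [37, 1, -100, -100]]
# Hypothetical per-token losses for the second sequence; padded slots are skipped.
print(masked_mean_loss([2.0, 4.0, 9.9, 9.9], labels[1]))  # 3.0
```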

The last thing to define for our `Seq2SeqTrainer` is how to compute the metrics from the predictions. We need to define a function for this, which will just use the `metric` we loaded earlier, and we have to do a bit of preprocessing to decode the predictions into texts:

In [25]:
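import nltk
import numpy as np

# nltk.sent_tokenize needs the Punkt sentence model; download it once.
nltk.download("punkt", quiet=True)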

def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    # Replace -100 in the labels as we can't decode them.
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    
    # Rouge expects a newline after each sentence
    decoded_preds = ["\n".join(nltk.sent_tokenize(pred.strip())) for pred in decoded_preds]
    decoded_labels = ["\n".join(nltk.sent_tokenize(label.strip())) for label in decoded_labels]
    
    result = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
    # Extract a few results
    result = {key: value.mid.fmeasure * 100 for key, value in result.items()}
    
    # Add mean generated length
    prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in predictions]
    result["gen_len"] = np.mean(prediction_lens)
    
    return {k: round(v, 4) for k, v in result.items()}
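
The two NumPy steps in `compute_metrics` can be checked on toy arrays. The token IDs below are made up, and `PAD_ID = 0` assumes T5's padding token, which T5 also uses as its decoder start token, so each generated row begins with a 0 that `gen_len` does not count.

```python
import numpy as np

PAD_ID = 0  # assumption: T5's pad_token_id

# Toy labels as the collator would pad them (-100 marks padding).
labels = np.array([[37, 423, 1, -100, -100],
                   [100, 19, 430, 7142, 1]])
# -100 is not a valid vocabulary ID, so swap it for the pad ID before decoding.
decodable = np.where(labels != -100, labels, PAD_ID)
print(decodable.tolist())  # [[37, 423, 1, 0, 0], [100, 19, 430, 7142, 1]]

# Toy generated IDs: a leading decoder start token (0) and right padding.
predictions = np.array([[0, 37, 423, 1, 0],
                        [0, 100, 19, 5, 1]])
gen_lens = [np.count_nonzero(pred != PAD_ID) for pred in predictions]
print(gen_lens, np.mean(gen_lens))
```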

Then we just need to pass all of this along with our datasets to the `Seq2SeqTrainer`:

In [26]:
trainer = Seq2SeqTrainer(
    model,
    args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics
)

Using amp half precision backend


We can now fine-tune our model by just calling the `train` method:

In [27]:
trainer.train()

The following columns in the training set  don't have a corresponding argument in `T5ForConditionalGeneration.forward` and have been ignored: document, summary, id.
***** Running training *****
  Num examples = 204045
  Num Epochs = 1
  Instantaneous batch size per device = 16
  Total train batch size (w. parallel, distributed & accumulation) = 16
  Gradient Accumulation steps = 1
  Total optimization steps = 12753


RuntimeError: ignored

You can now upload the result of the training to the Hub; just execute this instruction (uncomment it first if you logged in at the top of the notebook):

In [None]:
# trainer.push_to_hub()

You can now share this model with all your friends, family, favorite pets: they can all load it with the identifier `"your-username/the-name-you-picked"` so for instance:

```python
from transformers import AutoModelForSeq2SeqLM

model = AutoModelForSeq2SeqLM.from_pretrained("sgugger/my-awesome-model")
```
