<a href="https://colab.research.google.com/github/gcosma/COP509/blob/main/Tutorial8Summarization.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

If you're opening this notebook on Colab, you will probably need to install 🤗 Transformers and 🤗 Datasets as well as other dependencies. Run the following cell to install them. Note the `rouge-score` and `nltk` dependencies - even if you've used 🤗 Transformers before, you may not have these installed!

In [1]:
! pip install transformers datasets
! pip install rouge-score nltk
! pip install huggingface_hub


If you're opening this notebook locally, make sure your environment has recent versions of those libraries installed.

To share your model with the community and generate results like the one shown in the image further below via the Inference API, there are a few more steps to follow.

First, get your authentication token from the Hugging Face website (sign up [here](https://huggingface.co/join) if you haven't already!), then run the following cell and enter your token:

In [13]:
from huggingface_hub import notebook_login

notebook_login()


Then you need to install Git LFS and set up Git if you haven't already. Adapt the `git config` commands below with your own name and email:

In [14]:
!apt install git-lfs
!git config --global user.email "g.cosma@lboro.ac.uk"
!git config --global user.name "gcosma"

Reading package lists... Done
Building dependency tree... Done
Reading state information... Done
git-lfs is already the newest version (3.0.2-1ubuntu0.3).
0 upgraded, 0 newly installed, 0 to remove and 20 not upgraded.



In [16]:
!git config --global user.name
!git config --global user.email


gcosma
g.cosma@lboro.ac.uk



Make sure your version of Transformers is at least 4.16.0 since some of the functionality we use was introduced in that version:

In [19]:
import transformers

print(transformers.__version__)

4.48.3


You can find a script version of this notebook to fine-tune your model in a distributed fashion using multiple GPUs or TPUs [here](https://github.com/huggingface/transformers/tree/main/examples/tensorflow/summarization).

We also quickly upload some telemetry - this tells us which examples and software versions are getting used so we know where to prioritize our maintenance efforts. We don't collect (or care about) any personally identifiable information, but if you'd prefer not to be counted, feel free to skip this step or delete this cell entirely.

In [20]:
from transformers.utils import send_example_telemetry

send_example_telemetry("summarization_notebook", framework="tensorflow")

# Fine-tuning a model on a summarization task

In this notebook, we will see how to fine-tune one of the [🤗 Transformers](https://github.com/huggingface/transformers) models for a summarization task. We will use the [XSum dataset](https://arxiv.org/pdf/1808.08745.pdf) (for extreme summarization), which contains BBC articles accompanied by single-sentence summaries.

![Widget inference on a summarization task](https://github.com/huggingface/notebooks/blob/main/examples/images/summarization.png?raw=1)

We will see how to easily load the dataset for this task using 🤗 Datasets and how to fine-tune a model on it using Keras.

In [21]:
model_checkpoint = "t5-small"

This notebook is built to run with any model checkpoint from the [Model Hub](https://huggingface.co/models) as long as that model has a sequence-to-sequence version in the Transformers library. Here we pick the [`t5-small`](https://huggingface.co/t5-small) checkpoint.

In [5]:
!pip install -U datasets
!pip install -U evaluate
!pip install -U rouge_score
!pip install -U nltk



## Loading the dataset

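We will use the [🤗 Datasets](https://github.com/huggingface/datasets) library to download the data and the [🤗 Evaluate](https://github.com/huggingface/evaluate) library to get the metric we need for evaluation (to compare our model to the benchmark). This can be easily done with the functions `load_dataset` and `evaluate.load`. (The older `load_metric` function has been removed from recent versions of 🤗 Datasets.)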

In [2]:
from datasets import load_dataset
import evaluate

# Load the XSum dataset
raw_datasets = load_dataset("xsum", trust_remote_code=True)

# Load the ROUGE metric
rouge = evaluate.load('rouge')


The `raw_datasets` object itself is a [`DatasetDict`](https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasetdict), which contains one key each for the training, validation and test sets:

In [3]:
raw_datasets

DatasetDict({
    train: Dataset({
        features: ['document', 'summary', 'id'],
        num_rows: 204045
    })
    validation: Dataset({
        features: ['document', 'summary', 'id'],
        num_rows: 11332
    })
    test: Dataset({
        features: ['document', 'summary', 'id'],
        num_rows: 11334
    })
})

To access an actual element, you need to select a split first, then give an index:

In [4]:
raw_datasets["train"][0]

 'summary': 'Clean-up operations are continuing across the Scottish Borders and Dumfries and Galloway after flooding caused by Storm Frank.',
 'id': '35232142'}

To get a sense of what the data looks like, the following function will show some examples picked randomly from the dataset.

In [5]:
import datasets
import random
import pandas as pd
from IPython.display import display, HTML


def show_random_elements(dataset, num_examples=5):
    assert num_examples <= len(
        dataset
    ), "Can't pick more elements than there are in the dataset."
    # Draw distinct random row indices in a single call
    picks = random.sample(range(len(dataset)), num_examples)

    df = pd.DataFrame(dataset[picks])
    for column, typ in dataset.features.items():
        if isinstance(typ, datasets.ClassLabel):
            df[column] = df[column].transform(lambda i: typ.names[i])
    display(HTML(df.to_html()))

In [6]:
show_random_elements(raw_datasets["train"])

Unnamed: 0,document,summary,id
0,"Excluding petrol, sales rose 0.2% in March, against a 0.6% rise in February, which has been revised down from an initial estimated of 0.7%,\nFor the first three months of 2015, sales rose 0.9%, down from 2.2% in the first quarter of 2014.\nThe figures show consumers are still cautious about spending, analysts said.\nKeith Richardson, managing director for retail at Lloyds Bank Commercial Banking, said: ""Even with continued falls in fuel and food prices, consumers are responding to this current period of uncertainty by being just as careful about their own spending as they have been for the past few years.\n""Despite the fact that Mother's Day fell in March and Easter fell early in April, this wasn't enough to bring forward any boost in spending into March, doing nothing to allay fears that while consumers may have a little more money in their pockets, they are spending it on leisure treats like eating out and going on holiday, rather than on High Street goods,"" he said.\nUK economic growth figures for the first three months of 2015 are due to be published next week. Economists said the retail data could herald slower growth.\nAlan Clarke, at Scotiabank, said: ""The monthly data all point towards sluggish Q1 GDP next Tuesday, not the sort of reading that the coalition government will be hoping for.""\nBut Howard Archer, chief UK and European economist at IHS Global Insight, said that although the retail data was ""disappointing"", wage growth and low inflation should bolster consumer spending over the coming months.\n""Despite March's weaker-than-expected performance, the prospects for retail sales and consumer spending look bright, as purchasing power has strengthened and should continue to do so,"" Mr Archer said.","UK retail sales fell 0.5% in March from February, dragged down by a 6.2% fall in sales at petrol stations, the Office for National Statistics said.",32428935
1,"Sparked by hitting his first ball for six, off Jeetan Patel, Finch raced to 110, slowed only by a brief rain delay.\nBut Patel's five-wicket haul helped the Bears bowl out Surrey for 273, the last five wickets going down for just 30.\nAt the close, Warwickshire were 12-0 in reply with Varun Chopra seven not out and Andy Umeed unbeaten on three.\nAfter Surrey had won the toss, opener Rory Burns made 45 before being trapped lbw to become the first victim for Patel (5-62).\nAt 79-2, that brought in Finch, who then plundered his 110 from just 98 balls to move Surrey on to 243-5.\nBut, once he had been smartly taken by Surrey old boy Rikki Clarke, off Patel, the hosts lost their last five wickets in little more than half an hour.\nBoth sides are searching for only their second four-day win of the season and, if the weather holds, a result looks likely.\nSurrey batsman Aaron Finch told BBC Radio London:\n""I can't think I've hit my first ball for six before, other than in a one-day game. I remember doing it in a T20 international.\n""One of my things as a batsman is to be as positive as you can, and to take it on if the ball is there for the shot - whatever ball it is you're facing. Luckily, I got away with it this time.\n""It's obviously nice to get a few in my first innings for Surrey, although I'd have liked to have got a few more.\n""I reckon we've got a pretty decent score on that pitch. We battled hard for a lot of the day before losing wickets quickly towards the end. There's something in it for the bowlers if they get it in the right area.""",Warwickshire enjoyed a successful day in the field at Guildford despite Australian international Aaron Finch's stylish century on his Surrey debut.,36684020
2,"Firefighters were called after smoke and fire were seen coming from the ITV soap's set in Trafford Park at about 21:25 BST on Tuesday.\nA spokesman said the explosion was a ""false alarm"".\nGreater Manchester Fire and Rescue Service (GMFRS) tweeted: ""The Coronation Street set IS NOT on fire!""","Flames engulfed the set of Coronation Street after an ""explosion and fireball"" - which turned out to be special effects for a storyline.",32314167
3,"Ekeng died aged 26 of a suspected heart attack after collapsing on the pitch playing for Dinamo Bucharest.\nCameroon players past and present attended the service, and gave support to Ekeng's wife and family.\nThe country's sports minister Bidoung Kpwatt also attended, representing Cameroon's president Paul Biya.\nEkeng's brother Jacques Ekeng, paid tribute to ""Patou"" as he was known.\n""He gave more love to his friends than to his family. That is Patrick, he was generous,"" Jacques Ekeng said.\n""No matter the beauty, riches and cars, you cannot be saved from death,"" he told his brother's friends.\n""My brother is dead with a smile, and this is what comforts me.""\nEkeng's club Dinamo Bucharest said they would honour the player's memory by sending the Romanian Cup to his family in Cameroon if they win the trophy.\nThe 13-time winners face CFR Cluj in the postponed final on 17 May.","The funeral of Cameroon international Patrick Ekeng has been held in Yaounde, following his death during a league match in Romania on 6 May.",36294260
4,"Their position is only 19 places better than their worst ever ranking of 160th place, which they occupied in 2010.\nThe west African country have been has high as 73rd spot, in 2001, when they could still field former world player of the year George Weah.\nNigeria and Tunisia have swapped in the only moves in the African top 10.\nAfrica's top 10 in Fifa's rankings for April (last month's rankings in brackets):\n1 (1) Egypt\n2 (2) Senegal\n3 (3) Cameroon\n4 (4) Burkina Faso\n5 (7) Nigeria\n6 (6) DR Congo\n7 (5) Tunisia\n8 (8) Ghana\n9 (9) Ivory Coast\n10 (10) Morocco","Liberia were the biggest losers in Fifa's world rankings for April, dropping 39 places to sit in 141st spot - and a lowly 40th best in Africa.",39504425



In [15]:
import datasets
from datasets import load_dataset
import evaluate

# Load dataset
raw_datasets = load_dataset("xsum", trust_remote_code=True)

# Load metric using evaluate instead of datasets
rouge = evaluate.load('rouge')

# Example of using compute with predictions and references
fake_preds = ["hello there", "general kenobi"]
fake_labels = ["hello there", "general kenobi"]
results = rouge.compute(predictions=fake_preds, references=fake_labels)

print(results)

{'rouge1': 1.0, 'rouge2': 1.0, 'rougeL': 1.0, 'rougeLsum': 1.0}


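As the cell above shows, you call the metric's `compute` method with your predictions and labels, which need to be lists of decoded strings. Identical predictions and references score 1.0 on every ROUGE variant.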
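To make those scores less opaque, here is a minimal pure-Python sketch of ROUGE-1/ROUGE-2 (clipped n-gram overlap) and ROUGE-L (longest common subsequence), each reported as an F1 score. It is a simplified illustration, not the `rouge_score` implementation that `evaluate` wraps: it assumes lowercase whitespace tokenization, applies no stemming, and does not implement the sentence splitting behind `rougeLsum`. All helper names are hypothetical.

```python
from collections import Counter


def tokenize(text):
    # Simplified tokenization (an assumption): lowercase and split on whitespace.
    # rouge_score also strips non-alphanumeric characters and can apply stemming.
    return text.lower().split()


def ngrams(tokens, n):
    # Multiset of all contiguous n-grams in the token list
    return Counter(tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1))


def f1(overlap, pred_total, ref_total):
    # Harmonic mean of precision (overlap / prediction size) and recall (overlap / reference size)
    if overlap == 0:
        return 0.0
    precision = overlap / pred_total
    recall = overlap / ref_total
    return 2 * precision * recall / (precision + recall)


def rouge_n(prediction, reference, n=1):
    pred = ngrams(tokenize(prediction), n)
    ref = ngrams(tokenize(reference), n)
    overlap = sum((pred & ref).values())  # each n-gram counts at most min(count_pred, count_ref) times
    return f1(overlap, sum(pred.values()), sum(ref.values()))


def lcs_length(a, b):
    # Dynamic-programming longest common subsequence of two token lists
    table = [[0] * (len(b) + 1) for _ in range(len(a) + 1)]
    for i, x in enumerate(a, 1):
        for j, y in enumerate(b, 1):
            if x == y:
                table[i][j] = table[i - 1][j - 1] + 1
            else:
                table[i][j] = max(table[i - 1][j], table[i][j - 1])
    return table[-1][-1]


def rouge_l(prediction, reference):
    pred, ref = tokenize(prediction), tokenize(reference)
    return f1(lcs_length(pred, ref), len(pred), len(ref))


pred = "the cat sat on the mat"
ref = "the cat is on the mat"
print(round(rouge_n(pred, ref, n=1), 3))  # 0.833
print(round(rouge_n(pred, ref, n=2), 3))  # 0.6
print(round(rouge_l(pred, ref), 3))       # 0.833
```

Here five of the six unigrams match, three of the five bigrams match, and the longest common subsequence ("the cat on the mat") has five tokens, which is why ROUGE-2 is the strictest of the three.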

## Preprocessing the data

Before we can feed those texts to our model, we need to preprocess them. This is done by a 🤗 Transformers `Tokenizer`, which (as the name indicates) tokenizes the inputs, converts the tokens to their corresponding IDs in the pretrained vocabulary, and puts them in the format the model expects, along with any other inputs the model requires.

To do all of this, we instantiate our tokenizer with the `AutoTokenizer.from_pretrained` method, which will ensure:

- we get a tokenizer that corresponds to the model architecture we want to use,
- we download the vocabulary used when pretraining this specific checkpoint.

That vocabulary will be cached, so it's not downloaded again the next time we run the cell.


In [18]:
from transformers import AutoTokenizer

# Specify the model checkpoint (e.g., using a T5 model suitable for summarization)
model_checkpoint = "t5-small"  # or another appropriate model name

# Initialize the tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)


By default, the call above will use one of the fast tokenizers (backed by Rust) from the 🤗 Tokenizers library.

You can directly call this tokenizer on one sentence or a pair of sentences:

In [19]:
tokenizer("Hello, this is a sentence!")

{'input_ids': [8774, 6, 48, 19, 3, 9, 7142, 55, 1], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1]}

Depending on the model you selected, you will see different keys in the dictionary returned by the cell above. They don't matter much for what we're doing here (just know they are required by the model we will instantiate later); you can learn more about them in [this tutorial](https://huggingface.co/transformers/preprocessing.html) if you're interested.
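To see where the values in that dictionary come from, here is a toy sketch of what a tokenizer does: split the text, look each piece up in a vocabulary, append an end-of-sequence token, and mark every real position in the attention mask. The vocabulary and the `toy_tokenize` helper are hypothetical; the real T5 tokenizer uses a SentencePiece vocabulary of about 32k subword pieces. The trailing `1` in the outputs above is T5's `</s>` (end-of-sequence) token, which this sketch mimics.

```python
import re

# Hypothetical toy vocabulary. The special-token IDs follow T5's convention
# (<pad>=0, </s>=1, <unk>=2), but the word IDs are made up.
TOY_VOCAB = {
    "<pad>": 0, "</s>": 1, "<unk>": 2,
    "hello": 3, "this": 4, "is": 5, "a": 6, "sentence": 7, "another": 8,
}


def toy_tokenize(texts):
    """Map each text to IDs, append an end-of-sequence ID, and build an attention mask."""
    batch = {"input_ids": [], "attention_mask": []}
    for text in texts:
        words = re.findall(r"\w+", text.lower())  # crude word split that drops punctuation
        ids = [TOY_VOCAB.get(w, TOY_VOCAB["<unk>"]) for w in words]
        ids.append(TOY_VOCAB["</s>"])
        batch["input_ids"].append(ids)
        # No padding yet, so every position holds a real token
        batch["attention_mask"].append([1] * len(ids))
    return batch


print(toy_tokenize(["Hello, this is a sentence!", "This is another sentence."]))
# {'input_ids': [[3, 4, 5, 6, 7, 1], [4, 5, 8, 7, 1]], 'attention_mask': [[1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1]]}
```

A real subword tokenizer differs mainly in how it splits text (punctuation and rare words become separate pieces, which is why the T5 outputs above are longer), not in the shape of what it returns.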

Instead of one sentence, we can pass along a list of sentences:

In [20]:
tokenizer(["Hello, this is a sentence!", "This is another sentence."])

{'input_ids': [[8774, 6, 48, 19, 3, 9, 7142, 55, 1], [100, 19, 430, 7142, 5, 1]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1]]}

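To prepare the targets for our model, we tokenize them with the same tokenizer, passing them through the `text_target` argument. This makes sure the tokenizer uses the special tokens corresponding to the targets. Recent versions of 🤗 Transformers (including the 4.48.3 used here) deprecate the older `as_target_tokenizer` context manager in favor of this argument; on an older release that lacks `text_target`, wrap the call in `with tokenizer.as_target_tokenizer():` instead. For T5, inputs and targets share one tokenizer, so the IDs are the same as above:

In [21]:
print(tokenizer(text_target=["Hello, this is a sentence!", "This is another sentence."]))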

{'input_ids': [[8774, 6, 48, 19, 3, 9, 7142, 55, 1], [100, 19, 430, 7142, 5, 1]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1]]}




If you are using one of the five T5 checkpoints, we have to prefix the inputs with "summarize: ". T5 was trained on several tasks (it can also translate, for example), so it needs the prefix to know which task it has to perform.

In [22]:
if model_checkpoint in ["t5-small", "t5-base", "t5-large", "t5-3b", "t5-11b"]:
    prefix = "summarize: "
else:
    prefix = ""

We can then write the function that will preprocess our samples. We just feed them to the `tokenizer` with the argument `truncation=True`. This ensures that an input longer than `max_length` tokens is truncated to that length (without `max_length`, it would be truncated to the maximum length the model accepts). Padding is dealt with later on (in a data collator), so that we pad examples to the longest length in each batch rather than in the whole dataset.
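The two steps described above can be sketched in plain Python: truncate each sequence when it is tokenized, then pad each batch only up to its own longest member when it is collated. The `truncate` and `pad_batch` helpers are hypothetical simplifications. They assume T5's pad ID of 0, and they pad labels with -100, which is the default `label_pad_token_id` of 🤗 Transformers' `DataCollatorForSeq2Seq` and a value the model's loss ignores. The input IDs come from the tokenizer outputs above; the label IDs are arbitrary.

```python
def truncate(ids, max_length, eos_id=1):
    """Keep at most max_length IDs; assumes `ids` already ends with the EOS ID and keeps it."""
    if len(ids) <= max_length:
        return ids
    return ids[: max_length - 1] + [eos_id]


def pad_batch(features, pad_token_id=0, label_pad_token_id=-100):
    """Pad every sequence in one batch to the longest sequence in that batch."""
    max_input = max(len(f["input_ids"]) for f in features)
    max_label = max(len(f["labels"]) for f in features)
    batch = {"input_ids": [], "attention_mask": [], "labels": []}
    for f in features:
        n_pad = max_input - len(f["input_ids"])
        batch["input_ids"].append(f["input_ids"] + [pad_token_id] * n_pad)
        # 1 marks real tokens, 0 marks padding the model should not attend to
        batch["attention_mask"].append([1] * len(f["input_ids"]) + [0] * n_pad)
        # Padded label positions get -100 so they do not contribute to the loss
        batch["labels"].append(
            f["labels"] + [label_pad_token_id] * (max_label - len(f["labels"]))
        )
    return batch


features = [
    {"input_ids": [8774, 6, 48, 19, 3, 9, 7142, 55, 1], "labels": [100, 19, 1]},
    {"input_ids": [100, 19, 430, 7142, 5, 1], "labels": [8774, 1]},
]
batch = pad_batch(features)
print(batch["input_ids"][1])       # [100, 19, 430, 7142, 5, 1, 0, 0, 0]
print(batch["attention_mask"][1])  # [1, 1, 1, 1, 1, 1, 0, 0, 0]
print(batch["labels"][1])          # [8774, 1, -100]
print(truncate([5, 6, 7, 8, 1], max_length=3))  # [5, 6, 1]
```

Padding per batch keeps short batches short: this batch only grows to 9 input positions, instead of `max_input_length` (1024) positions for every example.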

In [23]:
max_input_length = 1024
max_target_length = 128


def preprocess_function(examples):
    inputs = [prefix + doc for doc in examples["document"]]
    model_inputs = tokenizer(inputs, max_length=max_input_length, truncation=True)

    # Tokenize the targets via `text_target` (replaces the deprecated
    # `as_target_tokenizer` context manager)
    labels = tokenizer(
        text_target=examples["summary"], max_length=max_target_length, truncation=True
    )

    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

This function works with one or several examples. In the case of several examples, the tokenizer will return a list of lists for each key:

In [24]:
preprocess_function(raw_datasets["train"][:2])

{'input_ids': [[21603, 10, 37, 423, 583, 13, 1783, 16, 20126, 16496, 6, 80, 13, 8, 844, 6025, 4161, 6, 19, 341, 271, 14841, 5, 7057, 161, 19, 4912, 16, 1626, 5981, 11, 186, 7540, 16, 1276, 15, 2296, 7, 5718, 2367, 14621, 4161, 57, 4125, 387, 5, 15059, 7, 30, 8, 4653, 4939, 711, 747, 522, 17879, 788, 12, 1783, 44, 8, 15763, 6029, 1813, 9, 7472, 5, 1404, 1623, 11, 5699, 277, 130, 4161, 57, 18368, 16, 20126, 16496, 227, 8, 2473, 5895, 15, 147, 89, 22411, 139, 8, 1511, 5, 1485, 3271, 3, 21926, 9, 472, 19623, 5251, 8, 616, 12, 15614, 8, 1783, 5, 37, 13818, 10564, 15, 26, 3, 9, 3, 19513, 1481, 6, 18368, 186, 1328, 2605, 30, 7488, 1887, 3, 18, 8, 711, 2309, 9517, 89, 355, 5, 3966, 1954, 9233, 15, 6, 113, 293, 7, 8, 16548, 13363, 106, 14022, 84, 47, 14621, 4161, 6, 243, 255, 228, 59, 7828, 8, 1249, 18, 545, 11298, 1773, 728, 8, 8347, 1560, 5, 611, 6, 255, 243, 72, 1709, 1528, 161, 228, 43, 118, 4006, 91, 12, 766, 8, 3, 19513, 1481, 410, 59, 5124, 5, 96, 196, 17, 19, 1256, 68, 27, 103, 317, 132

To apply this function to all the document/summary pairs in our dataset, we just use the `map` method of the `raw_datasets` object we created earlier. This applies the function to all the elements of all the splits in `raw_datasets`, so our training, validation and test data are preprocessed in one single command.

In [25]:
tokenized_datasets = raw_datasets.map(preprocess_function, batched=True)

Map:   0%|          | 0/204045 [00:00<?, ? examples/s]

Map:   0%|          | 0/11332 [00:00<?, ? examples/s]

Map:   0%|          | 0/11334 [00:00<?, ? examples/s]

Even better, the results are automatically cached by the 🤗 Datasets library, so this step is skipped the next time you run your notebook. The library is normally smart enough to detect when the function you pass to `map` has changed (and thus when the cached data must not be used). For instance, it will properly detect if you change the task in the first cell and rerun the notebook. 🤗 Datasets warns you when it uses cached files; you can pass `load_from_cache_file=False` in the call to `map` to ignore the cached files and force the preprocessing to be applied again.
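The idea behind this cache can be illustrated with a toy fingerprint: hash the function's code together with its input and arguments, and reuse a stored result only when the hash matches. This is a simplified sketch under our own assumptions; 🤗 Datasets computes its fingerprints differently (it serializes the function and chains fingerprints across transforms), and the names `fingerprint`, `cached_map` and `add_prefix` are hypothetical.

```python
import hashlib

_cache = {}  # fingerprint -> previously computed result


def fingerprint(fn, data, kwargs):
    """Toy fingerprint of a map call: function code + input data + arguments."""
    h = hashlib.sha256()
    h.update(fn.__code__.co_code)                    # the function's bytecode
    h.update(repr(fn.__code__.co_consts).encode())   # its constants, e.g. string literals
    h.update(repr(data).encode())                    # the input data
    h.update(repr(sorted(kwargs.items())).encode())  # the arguments passed to it
    return h.hexdigest()


def cached_map(fn, data, load_from_cache_file=True, **kwargs):
    """Return (result, used_cache)."""
    key = fingerprint(fn, data, kwargs)
    if load_from_cache_file and key in _cache:
        return _cache[key], True
    result = [fn(x, **kwargs) for x in data]
    _cache[key] = result
    return result, False


def add_prefix(doc, prefix=""):
    return prefix + doc


docs = ["a", "b"]
out1, hit1 = cached_map(add_prefix, docs, prefix="summarize: ")  # computed
out2, hit2 = cached_map(add_prefix, docs, prefix="summarize: ")  # cache hit
out3, hit3 = cached_map(add_prefix, docs, prefix="")             # arguments changed: recomputed
print(hit1, hit2, hit3)  # False True False
```

Passing `load_from_cache_file=False` to this toy version skips the lookup, mirroring the `map` option described above.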

Note that we passed `batched=True` to encode the texts in batches. This leverages the full benefit of the fast tokenizer we loaded earlier, which can use multi-threading to process the texts in a batch concurrently.
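With `batched=True`, the mapped function receives a dictionary of columns, each holding a list of values for a slice of rows, instead of one row at a time; this is why `preprocess_function` loops over `examples["document"]`. Below is a minimal pure-Python imitation of that calling convention. `batched_map` and `toy_preprocess` are hypothetical, the "tokenizer" is just `str.split`, and the default slice size of 1000 mirrors the default `batch_size` of `Dataset.map`.

```python
def batched_map(columns, fn, batch_size=1000):
    """Hypothetical imitation of Dataset.map(fn, batched=True).

    columns: dict mapping column name -> list of values (all the same length).
    fn: takes a dict of column slices and returns a dict of new columns.
    """
    num_rows = len(next(iter(columns.values())))
    output = {name: list(values) for name, values in columns.items()}
    new_columns = {}
    for start in range(0, num_rows, batch_size):
        # Slice every column to the same window of rows
        batch = {name: values[start:start + batch_size] for name, values in columns.items()}
        for name, values in fn(batch).items():
            new_columns.setdefault(name, []).extend(values)
    output.update(new_columns)  # new columns sit alongside the original ones
    return output


def toy_preprocess(examples):
    # Stand-in for preprocess_function: "tokenize" by splitting on spaces
    return {
        "input_ids": [("summarize: " + doc).split() for doc in examples["document"]],
        "labels": [summary.split() for summary in examples["summary"]],
    }


data = {"document": ["a b c", "d e"], "summary": ["a", "d"]}
out = batched_map(data, toy_preprocess, batch_size=1)
print(out["input_ids"])  # [['summarize:', 'a', 'b', 'c'], ['summarize:', 'd', 'e']]
```

As in the `tokenized_datasets["train"]` output further down, the original columns are kept and the new ones are added next to them.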

## Fine-tuning the model

Now that our data is ready, we can download the pretrained model and fine-tune it. Since our task is sequence-to-sequence (both the input and output are text sequences), we use the `TFAutoModelForSeq2SeqLM` class, the TensorFlow version of `AutoModelForSeq2SeqLM`. Like with the tokenizer, the `from_pretrained` method will download and cache the model for us.

In [26]:
from transformers import TFAutoModelForSeq2SeqLM, DataCollatorForSeq2Seq

model = TFAutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)

All PyTorch model weights were used when initializing TFT5ForConditionalGeneration.

All the weights of TFT5ForConditionalGeneration were initialized from the PyTorch model.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFT5ForConditionalGeneration for predictions without further training.


Note that, unlike in our classification example, we don't get a warning here. This means all the weights of the pretrained model were used and there is no randomly initialized head in this case.

Next we set some parameters, such as the learning rate and the `batch_size`, and customize the weight decay.

The last two lines set everything up so we can push the model to the [Hub](https://huggingface.co/models) at the end of training. Remove them if you didn't follow the installation steps at the top of the notebook; otherwise, you can change the value of `push_to_hub_model_id` to something you prefer.

In [27]:
batch_size = 8
learning_rate = 2e-5
weight_decay = 0.01
num_train_epochs = 1

model_name = model_checkpoint.split("/")[-1]
push_to_hub_model_id = f"{model_name}-finetuned-xsum"

Then, we need a special kind of data collator, which will not only pad the inputs to the maximum length in the batch, but also the labels. Note that our data collators are designed to work for multiple frameworks, so ensure you set the `return_tensors='np'` argument to get NumPy arrays out - you don't want to accidentally get a load of `torch.Tensor` objects in the middle of your nice TF code! You could also use `return_tensors='tf'` to get TensorFlow tensors, but our TF dataset pipeline actually uses a NumPy loader internally, which is wrapped at the end with a `tf.data.Dataset`. As a result, `np` is usually more reliable and performant when you're using it!
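The padding such a collator performs can be sketched in plain Python. The sketch assumes T5's pad id of `0` and the label padding value `-100` that `DataCollatorForSeq2Seq` uses by default (the loss ignores positions labelled `-100`). It returns lists rather than NumPy arrays and omits the `decoder_input_ids` the real collator can also prepare from the model; `pad_seq2seq_batch` is a hypothetical name.

```python
def pad_seq2seq_batch(features, pad_token_id=0, label_pad_token_id=-100):
    """Hypothetical sketch: pad inputs and labels to the longest item in this batch."""
    max_input_len = max(len(f["input_ids"]) for f in features)
    max_label_len = max(len(f["labels"]) for f in features)
    batch = {"input_ids": [], "attention_mask": [], "labels": []}
    for f in features:
        n_pad = max_input_len - len(f["input_ids"])
        # Inputs are padded with the pad token; the attention mask marks real tokens with 1
        batch["input_ids"].append(f["input_ids"] + [pad_token_id] * n_pad)
        batch["attention_mask"].append([1] * len(f["input_ids"]) + [0] * n_pad)
        # Labels are padded with -100 so that those positions are ignored by the loss
        n_label_pad = max_label_len - len(f["labels"])
        batch["labels"].append(f["labels"] + [label_pad_token_id] * n_label_pad)
    return batch


# Two illustrative tokenized examples of different lengths
features = [
    {"input_ids": [8774, 6, 1], "labels": [100, 1]},
    {"input_ids": [100, 1], "labels": [5, 6, 1]},
]
print(pad_seq2seq_batch(features))
```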

We also want to compute `ROUGE` metrics, which will require us to generate text from our model. To speed things up, we can compile our generation loop with XLA. This results in a *huge* speedup - up to 100X! The downside of XLA generation, though, is that it doesn't like variable input shapes, because it needs to run a new compilation for each new input shape! To compensate for that, let's use `pad_to_multiple_of` for the dataset we use for text generation. This will reduce the number of unique input shapes a lot, meaning we can get the benefits of XLA generation with only a few compilations.
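The effect of `pad_to_multiple_of` on the number of XLA compilations can be estimated with simple arithmetic: each batch is padded to its longest sequence, and that length is then rounded up to the next multiple. The sequence lengths below are made-up illustrative values, not measurements from XSum, and `padded_length` is a hypothetical helper.

```python
def padded_length(longest, multiple=None):
    """Length a batch is padded to: its longest sequence, optionally rounded up."""
    if multiple is None:
        return longest
    return ((longest + multiple - 1) // multiple) * multiple


# Illustrative (made-up) longest-sequence lengths for eight batches
batch_maxima = [412, 389, 455, 1024, 301, 388, 510, 497]

shapes_plain = {padded_length(n) for n in batch_maxima}
shapes_bucketed = {padded_length(n, 128) for n in batch_maxima}
print(len(shapes_plain), sorted(shapes_bucketed))  # 8 [384, 512, 1024]
```

Without bucketing, each of these eight batches has a new shape and would trigger a compilation; with `pad_to_multiple_of=128`, only three distinct shapes remain.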

In [28]:
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model, return_tensors="np")

generation_data_collator = DataCollatorForSeq2Seq(tokenizer, model=model, return_tensors="np", pad_to_multiple_of=128)

In [29]:
tokenized_datasets["train"]

Dataset({
    features: ['document', 'summary', 'id', 'input_ids', 'attention_mask', 'labels'],
    num_rows: 204045
})

Next, we convert our datasets to `tf.data.Dataset`, which Keras understands natively. There are two ways to do this - we can use the slightly more low-level [`Dataset.to_tf_dataset()`](https://huggingface.co/docs/datasets/package_reference/main_classes#datasets.Dataset.to_tf_dataset) method, or we can use [`Model.prepare_tf_dataset()`](https://huggingface.co/docs/transformers/main_classes/model#transformers.TFPreTrainedModel.prepare_tf_dataset). The main difference between these two is that the `Model` method can inspect the model to determine which column names it can use as input, which means you don't need to specify them yourself. Make sure to specify the collator we just created as our `collate_fn`!

In [30]:
train_dataset = model.prepare_tf_dataset(
    tokenized_datasets["train"],
    batch_size=batch_size,
    shuffle=True,
    collate_fn=data_collator,
)

validation_dataset = model.prepare_tf_dataset(
    tokenized_datasets["validation"],
    batch_size=batch_size,
    shuffle=False,
    collate_fn=data_collator,
)

generation_dataset = model.prepare_tf_dataset(
    tokenized_datasets["validation"],
    batch_size=8,
    shuffle=False,
    collate_fn=generation_data_collator
)
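As a sanity check, the number of batches per epoch follows from the row counts. This sketch assumes, per the `prepare_tf_dataset` documentation, that `drop_remainder` defaults to the value of `shuffle`, so the shuffled training set drops its final partial batch while the validation set keeps it. The 204,045 training rows come from the `tokenized_datasets["train"]` output above, and 11,332 is the validation split size reported by the `map` progress bars earlier.

```python
import math

demo_batch_size = 8
num_train_rows = 204045       # from tokenized_datasets["train"] above
num_validation_rows = 11332   # validation split size reported by map() above

# Assumption: drop_remainder follows shuffle (True for train, False for validation)
train_steps = num_train_rows // demo_batch_size                      # partial batch dropped
validation_steps = math.ceil(num_validation_rows / demo_batch_size)  # partial batch kept

print(train_steps, validation_steps)  # 25505 1417
```

The training figure matches the `25505` steps shown in the Keras progress bar of the training cell further down.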

Now we initialize our optimizer and compile the model. Note that most Transformers models compute the loss internally: we can train on this loss simply by not specifying a loss when we call `compile()`.

In [31]:
from transformers import AdamWeightDecay
import tensorflow as tf

optimizer = AdamWeightDecay(learning_rate=learning_rate, weight_decay_rate=weight_decay)
model.compile(optimizer=optimizer)
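Roughly speaking, the internal loss of a sequence-to-sequence model is a token-level cross-entropy in which label positions set to `-100` (the label padding value `DataCollatorForSeq2Seq` uses by default) are ignored. The pure-Python sketch below shows that masking with made-up probabilities; `masked_cross_entropy` is a hypothetical helper, not the library's implementation.

```python
import math


def masked_cross_entropy(token_probs, labels, ignore_index=-100):
    """Mean negative log-likelihood over positions whose label is not ignore_index.

    token_probs: per-position dicts mapping token id -> predicted probability.
    labels: per-position target ids, with ignore_index marking padding.
    """
    losses = [
        -math.log(probs[label])
        for probs, label in zip(token_probs, labels)
        if label != ignore_index  # padded positions contribute nothing
    ]
    return sum(losses) / len(losses)


# Made-up predictions for a 3-position target whose last position is padding
token_probs = [{100: 0.5, 1: 0.5}, {100: 0.25, 1: 0.75}, {0: 1.0}]
labels = [100, 1, -100]
loss = masked_cross_entropy(token_probs, labels)
print(loss)
```

The padded third position never reaches `math.log`, so the loss equals the average over the two real tokens.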

Now we can train our model. We can also add a few optional callbacks here, which you can remove if they aren't useful to you. In no particular order, these are:
- PushToHubCallback will sync up our model with the Hub - this allows us to resume training from other machines, share the model after training is finished, and even test the model's inference quality midway through training!
- TensorBoard is a built-in Keras callback that logs TensorBoard metrics.
- KerasMetricCallback is a callback for computing advanced metrics. There are a number of common metrics in NLP like ROUGE which are hard to fit into your compiled training loop because they depend on decoding predictions and labels back to strings with the tokenizer, and calling arbitrary Python functions to compute the metric. The KerasMetricCallback will wrap a metric function, outputting metrics as training progresses.

If this is the first time you've seen `KerasMetricCallback`, it's worth explaining what exactly is going on here. The callback takes two main arguments - a `metric_fn` and an `eval_dataset`. It then iterates over the `eval_dataset` and collects the model's outputs for each sample, before passing the `list` of predictions and the associated `list` of labels to the user-defined `metric_fn`. If the `predict_with_generate` argument is `True`, then it will call `model.generate()` for each input sample instead of `model.predict()` - this is useful for metrics that expect generated text from the model, like `ROUGE`.

This callback allows complex metrics to be computed each epoch that would not function as a standard Keras Metric. Metric values are printed each epoch, and can be used by other callbacks like `TensorBoard` or `EarlyStopping`.
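The control flow described above can be imitated in plain Python: loop over the evaluation batches, generate for each one, collect predictions and labels, call `metric_fn` once at the end of the epoch, and merge the result into that epoch's `logs` dictionary so later callbacks can read it. `ToyMetricCallback`, `fake_generate` and `toy_metric_fn` are hypothetical stand-ins, not the Transformers classes.

```python
class ToyMetricCallback:
    """Hypothetical imitation of the KerasMetricCallback flow (no Keras needed)."""

    def __init__(self, metric_fn, eval_dataset, generate_fn):
        self.metric_fn = metric_fn
        self.eval_dataset = eval_dataset  # iterable of (inputs, labels) batches
        self.generate_fn = generate_fn    # stands in for model.generate()

    def on_epoch_end(self, epoch, logs):
        predictions, labels = [], []
        for batch_inputs, batch_labels in self.eval_dataset:
            predictions.extend(self.generate_fn(batch_inputs))
            labels.extend(batch_labels)
        # The metric sees the whole evaluation set at once, not one batch at a time
        logs.update(self.metric_fn((predictions, labels)))


def fake_generate(batch_inputs):
    # Toy "model": reverse each input sequence
    return [list(reversed(x)) for x in batch_inputs]


def toy_metric_fn(eval_predictions):
    predictions, labels = eval_predictions
    exact = sum(p == l for p, l in zip(predictions, labels))
    return {"exact_match": exact / len(labels)}


eval_batches = [([[1, 2], [3, 4]], [[2, 1], [0, 0]]), ([[5, 6]], [[6, 5]])]
logs = {"loss": 3.9}
ToyMetricCallback(toy_metric_fn, eval_batches, fake_generate).on_epoch_end(0, logs)
print(logs)  # {'loss': 3.9, 'exact_match': 0.6666666666666666}
```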

In [32]:
import numpy as np
import nltk


def metric_fn(eval_predictions):
    predictions, labels = eval_predictions
    decoded_predictions = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    for label in labels:
        label[label < 0] = tokenizer.pad_token_id  # Replace masked label tokens
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    # Rouge expects a newline after each sentence
    decoded_predictions = [
        "\n".join(nltk.sent_tokenize(pred.strip())) for pred in decoded_predictions
    ]
    decoded_labels = [
        "\n".join(nltk.sent_tokenize(label.strip())) for label in decoded_labels
    ]
    result = metric.compute(
        predictions=decoded_predictions, references=decoded_labels, use_stemmer=True
    )
    # Extract a few results
    result = {key: value.mid.fmeasure * 100 for key, value in result.items()}
    # Add mean generated length
    prediction_lens = [
        np.count_nonzero(pred != tokenizer.pad_token_id) for pred in predictions
    ]
    result["gen_len"] = np.mean(prediction_lens)

    return result
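For intuition about the ROUGE scores `metric_fn` reports, ROUGE-1 F-measure compares unigram overlap between a prediction and a reference: precision is the overlap divided by the prediction length, recall is the overlap divided by the reference length, and F is their harmonic mean. The sketch below is a simplified version under our own assumptions: its tokenizer only roughly follows `rouge-score` (lowercase, alphanumeric runs) and it does not apply the Porter stemming that `use_stemmer=True` enables. `rouge1_f` is a hypothetical name.

```python
import re
from collections import Counter


def tokenize(text):
    # Lowercase and keep runs of letters/digits (a simplification of rouge-score's tokenizer)
    return re.findall(r"[a-z0-9]+", text.lower())


def rouge1_f(prediction, reference):
    pred_counts = Counter(tokenize(prediction))
    ref_counts = Counter(tokenize(reference))
    # Clipped overlap: each unigram counts at most as often as it appears in both texts
    overlap = sum((pred_counts & ref_counts).values())
    if overlap == 0:
        return 0.0
    precision = overlap / sum(pred_counts.values())
    recall = overlap / sum(ref_counts.values())
    return 2 * precision * recall / (precision + recall)


score = rouge1_f("Flooding hit Newton Stewart.", "Newton Stewart was badly hit by flooding.")
print(round(score * 100, 2))  # 72.73, scaled by 100 as in metric_fn above
```

Here all four predicted words appear in the reference (precision 1) but they cover only four of its seven words (recall 4/7), giving F = 8/11.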

In [42]:
from huggingface_hub import notebook_login

notebook_login()

And now we can try training our model. By default, we only do a single epoch of training here, as the inputs are very long, which means training is quite slow. However, you may wish to experiment with larger pre-trained models and longer training runs if you want to maximize the quality of your summaries.

In [None]:
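from tensorflow.keras.callbacks import TensorBoard
from transformers.keras_callbacks import KerasMetricCallback
import evaluate
import numpy as np

# Load the ROUGE metric from 🤗 Evaluate
rouge = evaluate.load('rouge')

def metric_fn(eval_pred):
    predictions, labels = eval_pred
    # Positions padded with -100 (by the collator or when batches are concatenated)
    # are not valid token ids, so replace them with the pad token id before decoding
    predictions = np.where(predictions < 0, tokenizer.pad_token_id, predictions)
    labels = np.where(labels < 0, tokenizer.pad_token_id, labels)
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    return rouge.compute(predictions=decoded_preds, references=decoded_labels)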

# Set up metric callback
metric_callback = KerasMetricCallback(
    metric_fn=metric_fn,
    eval_dataset=generation_dataset,
    predict_with_generate=True,
    use_xla_generation=True
)

# Train the model with just the metric callback
model.fit(
    train_dataset,
    validation_data=validation_dataset,
    epochs=num_train_epochs,
    callbacks=[metric_callback]
)

    8/25505 [..............................] - ETA: 328:11:35 - loss: 3.8962

If you pushed the model to the Hub with `PushToHubCallback`, you can now share it with all your friends, family, and favorite pets: they can all load it with the identifier `"your-username/the-name-you-picked"`. For instance:

```python
from transformers import TFAutoModelForSeq2SeqLM

model = TFAutoModelForSeq2SeqLM.from_pretrained("your-username/my-awesome-model")
```

In [None]:
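from huggingface_hub import notebook_login
from tensorflow.keras.callbacks import TensorBoard
from transformers.keras_callbacks import KerasMetricCallback, PushToHubCallback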

# This will prompt you to enter your token
notebook_login()

# After logging in, set up your model ID with your actual username
your_username = "YOUR_ACTUAL_USERNAME"  # Replace with your Hugging Face username
model_name = "summarization-model"
push_to_hub_model_id = f"{your_username}/{model_name}"

# Set up your callbacks with the authenticated model ID
tensorboard_callback = TensorBoard(
    log_dir="./summarization_model_save/logs",
    update_freq='epoch'
)

push_to_hub_callback = PushToHubCallback(
    output_dir="./summarization_model_save",
    tokenizer=tokenizer,
    hub_model_id=push_to_hub_model_id
)

metric_callback = KerasMetricCallback(
    metric_fn=metric_fn,
    eval_dataset=generation_dataset,
    predict_with_generate=True,
    use_xla_generation=True
)

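callbacks = [metric_callback, tensorboard_callback, push_to_hub_callback]

# The callbacks only take effect when passed to fit(); PushToHubCallback
# uploads the model and tokenizer to the Hub as training progresses
model.fit(
    train_dataset,
    validation_data=validation_dataset,
    epochs=num_train_epochs,
    callbacks=callbacks,
)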

## Inference

Now that we've trained our model, let's see how to load it and use it to summarize text in the future! First, let's load it from the Hub. This means we can resume the code from here without rerunning everything above each time.

In [None]:
from transformers import AutoTokenizer, TFAutoModelForSeq2SeqLM

# You can of course substitute your own username and model here if you've trained and uploaded it!
model_name = 'Rocketknight1/t5-small-finetuned-xsum'
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = TFAutoModelForSeq2SeqLM.from_pretrained(model_name)

Now let's tokenize a document from the training set and generate a summary for it. Don't forget to add `"summarize: "` at the start if you're using a T5 model.

In [None]:
document = 'The full cost of damage in Newton Stewart, one of the areas worst affected, is still being assessed.\nRepair work is ongoing in Hawick and many roads in Peeblesshire remain badly affected by standing water.\nTrains on the west coast mainline face disruption due to damage at the Lamington Viaduct.\nMany businesses and householders were affected by flooding in Newton Stewart after the River Cree overflowed into the town.\nFirst Minister Nicola Sturgeon visited the area to inspect the damage.\nThe waters breached a retaining wall, flooding many commercial properties on Victoria Street - the main shopping thoroughfare.\nJeanette Tate, who owns the Cinnamon Cafe which was badly affected, said she could not fault the multi-agency response once the flood hit.\nHowever, she said more preventative work could have been carried out to ensure the retaining wall did not fail.\n"It is difficult but I do think there is so much publicity for Dumfries and the Nith - and I totally appreciate that - but it is almost like we\'re neglected or forgotten," she said.\n"That may not be true but it is perhaps my perspective over the last few days.\n"Why were you not ready to help us a bit more when the warning and the alarm alerts had gone out?"\nMeanwhile, a flood alert remains in place across the Borders because of the constant rain.\nPeebles was badly hit by problems, sparking calls to introduce more defences in the area.\nScottish Borders Council has put a list on its website of the roads worst affected and drivers have been urged not to ignore closure signs.\nThe Labour Party\'s deputy Scottish leader Alex Rowley was in Hawick on Monday to see the situation first hand.\nHe said it was important to get the flood protection plan right but backed calls to speed up the process.\n"I was quite taken aback by the amount of damage that has been done," he said.\n"Obviously it is heart-breaking for people who have been forced out of their homes and the impact on businesses."\nHe said it was important that "immediate steps" were taken to protect the areas most vulnerable and a clear timetable put in place for flood prevention plans.\nHave you been affected by flooding in Dumfries and Galloway or the Borders? Tell us about your experience of the situation and how it was handled. Email us on selkirk.news@bbc.co.uk or dumfries@bbc.co.uk.'
if 't5' in model_name:
    document = "summarize: " + document
tokenized = tokenizer([document], return_tensors='np')
out = model.generate(**tokenized, max_length=128)

In [None]:
with tokenizer.as_target_tokenizer():
    print(tokenizer.decode(out[0]))

Not bad for a single epoch of training! Of course, the flood warning isn't much use to them after they've been flooded, but the summary correctly identified flooding in Dumfries and the Nith as the key event.

## Using XLA in inference

If you just want to generate a few summaries, the code above is all you need. However, generation can be **much** faster if you use XLA, and if you want to generate data in bulk, you should probably use it! If you're using XLA, though, remember that you'll need to do a new XLA compilation for every input size you pass to the model. This means that you should keep your batch size constant, and consider padding inputs to the same length, or using `pad_to_multiple_of` in your tokenizer to reduce the number of different input shapes you pass. Let's show an example of that:

In [None]:
import tensorflow as tf

@tf.function(jit_compile=True)
def generate(inputs):
    return model.generate(**inputs, max_length=128)

# pad_to_multiple_of only takes effect when padding is enabled
tokenized_data = tokenizer([document], return_tensors="np", padding=True, pad_to_multiple_of=128)
out = generate(tokenized_data)

In [None]:
with tokenizer.as_target_tokenizer():
    print(tokenizer.decode(out[0]))
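When generating in bulk with XLA, one way to keep the batch size constant is to top up the final, smaller batch with dummy inputs and discard their outputs. A minimal sketch of that batching logic, independent of the model; `fixed_size_batches` is a hypothetical helper and the empty-string filler is an assumption.

```python
def fixed_size_batches(documents, batch_size, filler=""):
    """Yield (batch, num_real) pairs where every batch has exactly batch_size items.

    The last batch is topped up with filler documents; num_real tells the caller
    how many leading outputs to keep.
    """
    for start in range(0, len(documents), batch_size):
        batch = documents[start:start + batch_size]
        num_real = len(batch)
        yield batch + [filler] * (batch_size - num_real), num_real


docs = ["doc1", "doc2", "doc3", "doc4", "doc5"]
for batch, num_real in fixed_size_batches(docs, batch_size=2):
    print(batch, num_real)
```

Each batch would then be tokenized with `padding=True` and `pad_to_multiple_of`, passed to the compiled `generate` function, and only its first `num_real` outputs kept.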

When using XLA generation, you'll notice that the first call to generate with a new input shape takes a long time because XLA has to compile your function, but subsequent calls are extremely quick. Also, XLA always generates to the maximum length, which can lead to a lot of padding tokens in your output! These are easy to remove, however:

In [None]:
with tokenizer.as_target_tokenizer():
    print(tokenizer.decode(out[0], skip_special_tokens=True))
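Conceptually, `skip_special_tokens=True` drops ids such as the pad/decoder-start token and the end-of-sequence token before turning ids back into text. The sketch below does that id filtering by hand, assuming T5's ids of `0` for padding (which T5 also uses as its decoder start token) and `1` for end of sequence. The generated ids are made up, and `strip_special_ids` is a hypothetical helper.

```python
def strip_special_ids(token_ids, special_ids=(0, 1)):
    """Remove special-token ids (T5: 0 = pad/decoder start, 1 = end of sequence)."""
    return [t for t in token_ids if t not in special_ids]


# Made-up XLA output: start token, four real tokens, EOS, then padding up to max length
generated = [0, 8774, 6, 48, 19, 1, 0, 0, 0, 0]
print(strip_special_ids(generated))  # [8774, 6, 48, 19]
```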

## Pipeline API

The pipeline API offers a convenient shortcut for all of this, but doesn't (yet!) support XLA generation:

In [None]:
from transformers import pipeline

summarizer = pipeline('text2text-generation', model_name, framework="tf")

Remember that, if we're using a T5 model, we prepended `"summarize: "` to our input above. Don't forget to do the same when you're getting summaries for new texts!

In [None]:
summarizer(document, max_length=128)

Easy!