In [None]:
# default_exp data.summarization

In [None]:
#hide
%reload_ext autoreload
%autoreload 2
%matplotlib inline

# data.summarization

> This module contains the bits required to use the fastai DataBlock API and/or mid-level data processing pipelines to organize your data for summarization tasks using architectures like BART and T5.

In [None]:
#export
import ast
from functools import reduce

import torch
from transformers import *
from fastai.text.all import *

from blurr.utils import *
from blurr.data.core import *

In [None]:
#hide
import pdb

from nbdev.showdoc import *
from fastcore.test import *

In [None]:
#cuda
torch.cuda.set_device(1)
print(f'Using GPU #{torch.cuda.current_device()}: {torch.cuda.get_device_name()}')

Using GPU #1: GeForce GTX 1080 Ti


## Summarization tokenization, batch transform, and DataBlock methods

Summarization tasks attempt to generate a human-understandable and sensible representation of a larger body of text (e.g., capture the meaning of a larger document in 1-3 sentences).

In [None]:
path = Path('./')
cnndm_df = pd.read_csv(path/'cnndm_sample.csv'); len(cnndm_df)

1000

In [None]:
cnndm_df.head(2)

Unnamed: 0,article,highlights,ds_type
0,"(CNN) -- Globalization washes like a flood over the world's cultures and economies. Floods can be destructive; however, they can also bring blessings, as the annual floods of the Nile did for ancient Egypt. The world's great universities can be crucial instruments in shaping, in a positive way, humankind's reaction to globalization and the development of humankind itself. Traditionally, universities have been defined and limited by location, creating an academic community and drawing students and scholars to that place. Eventually, some universities began to encourage students to study el...","John Sexton: Traditionally, universities have been defined and limited by location .\nGlobal campuses form a network of thought, innovation, he writes .\nFaculty can teach, Sexton says, students can team up in many cities at once .\nSexton: Research, scholarship can be shared and cultural ties made in ""century of knowledge""",train
1,"(CNN) -- Armenian President Robert Kocharian declared a state of emergency Saturday night after a day of clashes between police and protesters, a spokeswoman for the Armenian Foreign Ministry said. Opposition supporters wave an Armenian flag during a protest rally in Yerevan, Armenia, on Saturday. The protesters claim last month's presidential election was rigged. The state of emergency will ""hopefully bring some order"" to the capital, Yerevan, said Salpi Ghazarian, assistant to the Armenian foreign minister, who spoke to CNN early Sunday. The state of emergency could last until March 20, ...","NEW: Protest moves after crackdown at Freedom Square .\nOrder sought after protests over last month's election turn violent .\nDemonstrators say the election was fraudulent .\nState of emergency could last until March 20, official says .",train


In [None]:
pretrained_model_name = "facebook/bart-large-cnn"

hf_arch, hf_config, hf_tokenizer, hf_model = BLURR_MODEL_HELPER.get_hf_objects(pretrained_model_name, 
                                                                               model_cls=BartForConditionalGeneration)

hf_arch, type(hf_tokenizer), type(hf_config), type(hf_model)

('bart',
 transformers.tokenization_bart.BartTokenizer,
 transformers.configuration_bart.BartConfig,
 transformers.modeling_bart.BartForConditionalGeneration)

In [None]:
#export
class HF_SummarizationInput(list): pass

We create a subclass of `HF_BatchTransform` for summarization tasks to add `decoder_input_ids` and `labels` to our inputs during training, which in turn allows the huggingface model to calculate the loss for us. See [here](https://huggingface.co/transformers/model_doc/bart.html#transformers.BartModel.forward) for more information on how these additional inputs are used in summarization and conversational training tasks.

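Note also that `labels` is simply the target `input_ids` shifted one position to the left relative to `decoder_input_ids` (we drop the first token to build `labels` and the last token to build `decoder_input_ids`), since the task is to predict the next token based on the current (and all previous) `decoder_input_ids`.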

And lastly, we also update our targets to just be the `input_ids` of our target sequence so that fastai's `Learner.show_results` works (again, almost all the fastai bits require returning a single tensor to work).
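
As a rough illustration of the shift described above, here is a minimal sketch that uses plain Python lists in place of tensors so it runs without any dependencies. The token ids, the `PAD_ID` value, and the helper name `build_decoder_inputs_and_labels` are made up for the example and are not part of blurr or huggingface.

```python
# Plain-Python sketch of the shift; lists stand in for the tensors used in the real transform.
PAD_ID = 1            # hypothetical pad token id, chosen only for this example
IGNORE_INDEX = -100   # label value that the loss is expected to skip

def build_decoder_inputs_and_labels(target_ids, pad_id=PAD_ID):
    """Return (decoder_input_ids, labels) for one padded target sequence."""
    # the decoder sees every target token except the last one ...
    decoder_input_ids = target_ids[:-1]
    # ... and is asked to predict every target token except the first one,
    # with padding positions masked out of the loss
    labels = [IGNORE_INDEX if t == pad_id else t for t in target_ids[1:]]
    return decoder_input_ids, labels

# made-up ids laid out as: <s> tok tok </s> <pad> <pad>
target_ids = [0, 713, 16, 2, 1, 1]
decoder_input_ids, labels = build_decoder_inputs_and_labels(target_ids)

print(decoder_input_ids)
print(labels)
```

At each position `i`, `labels[i]` is the token that follows `decoder_input_ids[i]` in the original target, which is what makes this a next-token prediction setup.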

In [None]:
#export
class HF_SummarizationBatchTransform(HF_BatchTransform):
    def __init__(self, hf_arch, hf_tokenizer, **kwargs):
        super().__init__(hf_arch, hf_tokenizer, HF_SummarizationInput, **kwargs)
        
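    def encodes(self, samples):  
        samples = super().encodes(samples)
        # samples without a target (inputs only) are returned unchanged
        if (len(samples[0]) == 1): return samples
        
        updated_samples = []
        for s in samples:
            # decoder inputs: the target ids without the last token
            s[0]['decoder_input_ids'] = s[1]['input_ids'][:-1].clone()
            # labels: the target ids without the first token (i.e., the next token to predict)
            s[0]['labels'] = s[1]['input_ids'][1:].clone()
            # mask padding with -100 so it is ignored when the loss is calculated
            s[0]['labels'][s[0]['labels'] == self.hf_tokenizer.pad_token_id] = -100
            
            # keep only the target input_ids so the target is a single tensor
            targ_ids = s[1]['input_ids']
            
            updated_samples.append((s[0], targ_ids))
        
        return updated_samples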
    
    def decodes(self, encoded_samples):
        if (isinstance(encoded_samples, dict)): return self.hf_input_return_type([encoded_samples['input_ids']])
        return [encoded_samples]

We had to override the `decodes` method above because, while both our inputs and targets start out as the same kind of object, we update the latter to consist of *only* the target `input_ids` so that methods like `Learner.show_results` work. Nevertheless, because fastai remembers what they are, `HF_TokenizerTransform.decodes` will be called for both, and it works on a `list` of input_ids.
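
To make the two branches of `decodes` concrete, here is a small standalone sketch. It assumes plain lists instead of tensors, and the names `SummarizationInputSketch` and `decodes_sketch` are hypothetical stand-ins for `HF_SummarizationInput` and the method above, not blurr APIs.

```python
class SummarizationInputSketch(list):
    """Hypothetical stand-in for `HF_SummarizationInput`."""

def decodes_sketch(encoded_samples):
    # inputs arrive as a dict of model arguments; keep only the token ids
    if isinstance(encoded_samples, dict):
        return SummarizationInputSketch([encoded_samples['input_ids']])
    # targets were already reduced to the ids alone, so just wrap them in a list
    return [encoded_samples]

# made-up values for illustration
inputs = {'input_ids': [0, 713, 16, 2], 'attention_mask': [1, 1, 1, 1]}
targets = [0, 100, 2]

decoded_inputs = decodes_sketch(inputs)
decoded_targets = decodes_sketch(targets)

print(type(decoded_inputs).__name__, decoded_inputs)
print(type(decoded_targets).__name__, decoded_targets)
```

Either way the result is a list holding the input_ids, which is the shape the downstream tokenizer decoding is described as expecting.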

In [None]:
hf_batch_tfm = HF_SummarizationBatchTransform(hf_arch, hf_tokenizer)

blocks = ( 
    HF_TextBlock(hf_arch, hf_tokenizer), 
    HF_TextBlock(hf_arch, hf_tokenizer, hf_batch_tfm=hf_batch_tfm, max_length=150, hf_input_idxs=[0,1])
)

dblock = DataBlock(blocks=blocks, 
                   get_x=ColReader('article'), 
                   get_y=ColReader('highlights'), 
                   splitter=RandomSplitter())

In [None]:
# dblock.summary(cnndm_df)

In [None]:
dls = dblock.dataloaders(cnndm_df, bs=4)

In [None]:
b = dls.one_batch()

In [None]:
len(b), b[0]['input_ids'].shape, b[1].shape

(2, torch.Size([4, 512]), torch.Size([4, 71]))

In [None]:
#export
@typedispatch
def show_batch(x:HF_SummarizationInput, y, samples, dataloaders=None, ctxs=None, max_n=6, **kwargs):  
    res = L([ (s[0], s[1]) for s in samples ])          
    display_df(pd.DataFrame(res, columns=['text', 'target'])[:max_n])
    return ctxs

In [None]:
dls.show_batch(dataloaders=dls, max_n=2)

Unnamed: 0,text,target
0,"(CNN) -- Standing outside a courthouse Sunday that the Libyan opposition is using for a base of operations in the town of Misrata, a witness described a sense of jubilation against a backdrop of blood stains and rocket fragments. ""I'm standing in the middle of a... battlefield,"" the witness told CNN by phone from Misrata after a fierce fight between rebels and Libyan leader Moammar Gadhafi's forces. People were holding their hands up, singing, chanting and cheering, he said. ""Everyone is hugging everyone."" CNN is not identifying witnesses and sources for safety reasons. Videos posted on YouTube and thought to be out of Misrata showed damage to buildings and several shots of people celebrating around the opposition flag -- once being raised on a pole, and another time being waved by a man atop a charred vehicle that had a dead body inside. A doctor at Central Misrata Hospital said 42 people were killed in the fighting -- 17 from the opposition and 25 from the pro-Gadhafi forces. Among the dead was a 3-year-old child, killed from direct fire, the doctor said. At least 85 people were wounded, the doctor said. The fighting continued on the city's outskirts Sunday evening. The witness described the opposition's victory in central Misrata even as people some 200 kilometers (125 miles) west, at a pro-Gadhafi demonstration in Tripoli, insisted the government had taken back the coastal central Libyan city. After reports of the opposition successfully holding onto Misrata, east of Tripoli, Libyan state TV showed a graphic stating that ""strict orders have been issued to the armed forces not to enter cities taken by terrorist gangs."" On Sunday morning, pro-Gadhafi militias converged on Misrata from three different points, trying to retake control of the city, the witness said. He saw four tanks, though other witnesses told him there were a total of six. Using heavy artillery, the ground forces and tanks headed for the courthouse operations base. 
Tanks fired rockets at the building, and black smoke could be seen rising from it, he said. The opposition couldn't match the government's weaponry, but rebels took to the streets using what weapons they had, such as machine guns. And some simply picked up whatever they could find, with some resorting to sticks, he said. Speaking to CNN during the battle, he said, ""People are willing to die for the cause,"" describing them as ""fearless"" and ""amazing."" Later, after the","NEW: Videos online show damage to buildings and waving of the opposition's flag.\nA doctor at a hospital in the city says 42 people were killed, 85 wounded.\nWitness in Misrata: ""Everyone is hugging everyone"" despite ""blood everywhere""\nPro-Gadhafi demonstrators in Tripoli claimed the government had taken the city."
1,"I have an uncle who has always been a robust and healthy guy. He drank a glass of skim milk every day, bragged about how many pull-ups he was doing and fit into pants he was wearing 20 years before. He didn't take a single medication and retired early. Given that he had no medical problems and ran his own business, he opted to go several years without health insurance. Eventually, when he turned 65, he picked up Medicare. What happened next was a little strange. He fell off the wagon. He exercised only sporadically, and paid hardly any attention to what he was eating. One day, I saw him eat an entire bag of potato chips. He bemoaned the fact that he was forced to buy new, bigger pants, and he stopped drinking his milk. For him, becoming newly insured had nearly the opposite effect on him of what we doctors hope to achieve. He'd become unhealthier. In many ways, my uncle was demonstrating a concept known as the moral hazard. Two economists wrote about this exact scenario in 2006. They found that many men, at the time they obtained Medicare, started behaving badly. Moral, or morale, hazard is a term largely used by economists to describe the actions of people more willing to take risks because they are insulated from the cost of their actions, in this case because of their recently obtained health insurance. In the case of these men, when they got Medicare, they took worse care of themselves; they actually exercised less. Among those who didn't visit the doctor after getting insurance, the effect was dramatic: Their overall physical activity dropped by 40%; they were 16% more likely to smoke cigarettes and 32% more likely to drink alcohol. Even if that seems extreme, it's still worth asking: Does health insurance make us healthier? The past five years have seen a tumultuous battle over Obamacare, or the Affordable Care Act, culminating in the bitter recriminations this fall over lost policies and the disastrous launch of the HealthCare.gov website. 
When I interviewed Health and Human Services Secretary Kathleen Sebelius at the end of October, she downplayed the concerns and seemed certain the site would be up and running by the end of November. The website may be working better now, but to me that's not the most important issue. In my mind, the real suspense comes from whether Obamacare will really make us a healthier America, even if it succeeds in its ambitions to dramatically expand coverage. A healthier America: That is the goal we should share as Americans","Sanjay Gupta: Moral hazard causes some to neglect health when they get health insurance.\nHe says Obamacare alone won't guarantee good health; personal habits must do that.\nHe says research shows 30 minutes of daily exercise cuts heart attack, stroke risk by a third.\nGupta: It's time to stop playing defense on your health; instead, start optimizing it yourself."


## Tests

The tests below ensure the core DataBlock code above works for **all** pretrained summarization models available in huggingface. These tests are excluded from the CI workflow because of how long they would take to run and the amount of data they would need to download.

**Note**: Feel free to modify the code below to test whatever pretrained summarization models you are working with. If any of your pretrained summarization models fail, please submit a GitHub issue *(or a PR if you'd like to fix it yourself)*.

In [None]:
BLURR_MODEL_HELPER.get_models(task='ConditionalGeneration')

[transformers.modeling_bart.BartForConditionalGeneration,
 transformers.modeling_mbart.MBartForConditionalGeneration,
 transformers.modeling_pegasus.PegasusForConditionalGeneration,
 transformers.modeling_t5.T5ForConditionalGeneration]

In [None]:
pretrained_model_names = [
    ('facebook/bart-base',BartForConditionalGeneration),
    ('t5-small', T5ForConditionalGeneration),
    ('google/pegasus-cnn_dailymail', PegasusForConditionalGeneration)
]

In [None]:
path = Path('./')
cnndm_df = pd.read_csv(path/'cnndm_sample.csv')

In [None]:
#slow
#hide_output
task = HF_TASKS_ALL.ConditionalGeneration
bsz = 2

test_results = []
for model_name, model_cls in pretrained_model_names:
    error=None
    
    print(f'=== {model_name} ===\n')
    
    hf_arch, hf_config, hf_tokenizer, hf_model = BLURR_MODEL_HELPER.get_hf_objects(model_name, 
                                                                                   task=task,
                                                                                   model_cls=model_cls)
    
    print(f'architecture:\t{hf_arch}\ntokenizer:\t{type(hf_tokenizer).__name__}\n')
    
    hf_batch_tfm = HF_SummarizationBatchTransform(hf_arch, hf_tokenizer)

    blocks = ( 
        HF_TextBlock(hf_arch, hf_tokenizer, padding='max_length', max_length=256), 
        HF_TextBlock(hf_arch, hf_tokenizer, hf_batch_tfm=hf_batch_tfm, padding='max_length', max_length=50, 
                     hf_input_idxs=[0,1])
    )

    def add_t5_prefix(inp): return f'summarize: {inp}' if (hf_arch == 't5') else inp

    dblock = DataBlock(blocks=blocks, 
                   get_x=Pipeline([ColReader('article'), add_t5_prefix]), 
                   get_y=ColReader('highlights'), 
                   splitter=RandomSplitter())

    dls = dblock.dataloaders(cnndm_df, bs=bsz) 
    b = dls.one_batch()
    
    try:
        print('*** TESTING DataLoaders ***\n')
        test_eq(len(b), 2)
        test_eq(len(b[0]['input_ids']), bsz)
        test_eq(b[0]['input_ids'].shape, torch.Size([bsz, 256]))
        test_eq(len(b[1]), bsz)
        test_eq(b[1].shape, torch.Size([bsz,50]))

        if (hasattr(hf_tokenizer, 'add_prefix_space')):
            test_eq(dls.tfms[0].kwargs['add_prefix_space'], True)
            
        test_results.append((hf_arch, type(hf_tokenizer).__name__, model_name, 'PASSED', ''))
        dls.show_batch(dataloaders=dls, max_n=2)
        
    except Exception as err:
        test_results.append((hf_arch, type(hf_tokenizer).__name__, model_name, 'FAILED', err))

=== facebook/bart-base ===

architecture:	bart
tokenizer:	BartTokenizer

*** TESTING DataLoaders ***



Unnamed: 0,text,target
0,"We may never know the why -- though there has been no shortage of speculation on the Internet -- but at least now we know what the Carter-Knowles family has to say about their already infamous elevator fight. Solange Knowles, Jay Z and his wife, Beyonce, have released a statement about surveillance video originally posted by TMZ showing Solange, the younger sister of Beyonce, hitting and kicking her brother-in-law. The trio had previously not spoken publicly about the incident in an elevator at the Standard Hotel in New York City following the Met Gala held on May 5. But on Thursday, they broke their silence. The family referred CNN to a statement they previously gave to the Associated Press. The statement says:. ""As a result of the public release of the elevator security footage from Monday, May 5th, there has been a great deal of speculation about what triggered the unfortunate incident. But the most important thing is that our family has worked through it. Jay and Solange each assume their share of responsibility for what has occurred. ""They both acknowledge their role in this private matter that has played out in the public. They both have apologized to each other and we have moved forward as a united family. ""The reports of Solange being","Solange Knowles, Jay Z and his wife, Beyonce, release a statement.\nTMZ released video showing an altercation between Solange and Jay Z in an elevator.\nThey say they have ""worked through it,"" saying all"
1,"(CNN) -- The latest trend at teen parties isn't warm beer or prescription medicines pilfered from parents' medicine cabinets. Instead, increasing numbers of youths are turning to an herb-based product to get high, and unlike marijuana, it's perfectly legal. It's known as K2 or Spice, a synthetic substance that, when smoked, gives users a marijuana-like high, according to drug authorities. Its growing popularity is causing increasing alarm among health care professionals, law enforcement authorities and lawmakers, with one Drug Enforcement Agency official calling its use the equivalent of ""playing Russian roulette."" Should some illegal drugs be legalized? Manufactured in Asia and sold online or in local stores, K2 and similar substances are marketed as herbal incense. A disclaimer on a K2-selling Web site reads: ""K2Herbal products are novelty incenses and are not for consumption."" Sold in various flavors in 3-gram bags, the product consists of herbs that are sprayed with synthetic substances that mimic THC, the high-causing natural chemical found in marijuana. A call to regulate K2. Health and drug officials say the danger in using such products is the unregulated nature of their production and makeup. ""Our biggest concern is that this particular chemical is likely manufactured","K2 or Spice, when smoked, gives users a marijuana-like high.\nDanger of products is the unregulated nature of their production and makeup.\nSide effects include heart palpitations, respiratory issues, panic attacks, hallucinations."


=== t5-small ===

architecture:	t5
tokenizer:	T5Tokenizer

*** TESTING DataLoaders ***



Unnamed: 0,text,target
0,"summarize: WASHINGTON (CNN) -- House and Senate Democrats reached agreement late Monday on a budget resolution for 2010, which includes key spending priorities for the young Obama administration. The Senate and House could vote on the budget resolution Tuesday. President Obama's budget request is $3.67 trillion. ""This budget is a major accomplishment,"" Senate Budget Committee Chairman Kent Conrad said in a statement. ""We are meeting President Obama's goals of reducing our dependence on foreign energy, striving for excellence in education, reforming our health care system, and providing middle-class tax relief."" The agreement came as lawmakers were reconciling the House and Senate versions of the budget package. The president's budget request is $3.67 trillion. The full Senate and House are each expected to vote on the fiscal 2010 budget resolution this week. The House vote could come as soon as Tuesday. Budget negotiators have fast-tracked part of the budget process. Major health reform is likely to pass this year, because the special process -- known as budget reconciliation -- won't allow Republicans to filibuster the legislation, as was widely expected. Democrats, who currently control 58 seats in the Senate, will be able to",House and Senate Dems say President Obama's goals addressed in resolution. Democratic leaders urge pay-as-you-go system that Obama has emphasized. Senate and House are each expected to vote on the budget resolution this week
1,"summarize: (CNN) -- This election should have been a walkover for the Republicans. The economy is sluggish and the United States is beset with crises abroad. Yet, Mitt Romney has committed one gaffe after another, almost as if he actually wants to lose. Perhaps the multi-millionaire has decided that the White House is too small for him. On Monday night, Romney was hit with what we might call a ""pre-gaffe"" when a private statement that he made months ago suddenly hit the Web. The video shows Romney apparently dismissing the 47% of Americans who he says don't pay federal income taxes as freeloaders. For someone who is often portrayed as cynical and uncaring, this is not good news. What will we see next? Leaked footage of Romney stealing candy from a baby? There's cause for Republicans to panic. Some commentators are starting to ask, ""Did Romney just lose the election?"" When I first saw the ""47%"" video, I wrote that it had to damage Romney's already poor likeability ratings and maybe even cost him the White House. But, after","Tim Stanley: Mitt Romney's secretly taped remarks hurt his cause but aren't fatal. He says Romney may find, like other politicians, that gaffes often don't stick. Polls in recent"


=== google/pegasus-cnn_dailymail ===

architecture:	pegasus
tokenizer:	PegasusTokenizer

*** TESTING DataLoaders ***



Unnamed: 0,text,target
0,"Editor's note: Jay S. Winuk, co-founder of MyGoodDeed, is the brother of Glenn J. Winuk, an attorney and volunteer firefighter and EMT who died in the line of duty when the South Tower of the World Trade Center collapsed on September 11, 2001. This week Glenn was posthumously honored with the 9/11 Heroes Medal of Valor from the United States of America. Jay Winuk says September 11 is best observed as a day of service to others. NEW YORK (CNN) -- The upcoming eighth anniversary of the attacks of September 11 raises a compelling question for millions of Americans: How should we best observe this uniquely tragic day in our nation's history? Surely, it should not be a holiday. This is no time for days off from work and three-day weekends to enjoy barbeques and white sales. No, September 11 is a day for reflection, and its historical and emotional significance should not lessen with time or be diminished in any way. It is a day to focus on the substantial lessons learned. I'm a 9/11 family member. My brave brother, Glenn J. Winuk, was a partner at a large law firm, Holland & Knight, located two blocks from the World Trade Center. For almost 20 years Glenn was also a","Jay Winuk: 9/11 has been recognized as a national day of service. He says it's not a day to skip work or go shopping. He says people choose to do acts of kindness, large or small. Winuk"
1,"(CNN) -- Bayern Munich closed to within two points of Bundesliga leaders Hamburg and Bayer Leverkusen with a 2-1 victory at home to Bavarian rivals Nuremberg on Saturday. Belgian defender Daniel Van Buyten celebrates his winning goal for Bayern Munich. Daniel Van Buyten's 82nd-minute winner left Bayern on 13 points behind the two top teams, who are in action on Sunday, and above Hoffenheim and Mainz on goal difference. Mainz won 3-2 away to Bochum on Saturday, while Hoffenheim scored three late goals to snatch a 4-2 victory at Borussia Moenchengladbach. They moved above Schalke, who lost 2-1 to coach Felix Magath's former team, defending champions Wolfsburg, on Friday night. Bayern broke into the top three for the first time this season despite again starting with France star Franck Ribery on the substitutes' bench, where he was joined by Miroslav Klose as young striker Thomas Muller was given the chance to add to his run of four goals in two games. Nuremberg fielded a defensive line-up, but striker Mario Gomez still managed to hit the crossbar for Bayern before Ivica Olic broke the deadlock in the 55th minute. The Croatia forward scored his second goal of the season after being fed in the penalty area by Muller","Bayern Munich close to within two points of Bundesliga's top two teams. Daniel van Buyten heads late goal against Bavarian rivals Nuremberg in 2-1 win. Bayern third on goal difference ahead of Hoffenheim and Mainz, who both win."


In [None]:
#slow
#hide_input
test_results_df = pd.DataFrame(test_results, columns=['arch', 'tokenizer', 'model_name', 'result', 'error'])
display_df(test_results_df)

Unnamed: 0,arch,tokenizer,model_name,result,error
0,bart,BartTokenizer,facebook/bart-base,PASSED,
1,t5,T5Tokenizer,t5-small,PASSED,
2,pegasus,PegasusTokenizer,google/pegasus-cnn_dailymail,PASSED,


## Cleanup

In [None]:
#hide
from nbdev.export import notebook2script
notebook2script()

Converted 00_utils.ipynb.
Converted 01_data-core.ipynb.
Converted 01a_data-token-classification.ipynb.
Converted 01b_data-question-answering.ipynb.
Converted 01e_data-summarization.ipynb.
Converted 01z_data-language-modeling.ipynb.
Converted 02_modeling-core.ipynb.
Converted 02a_modeling-token-classification.ipynb.
Converted 02b_modeling-question-answering.ipynb.
Converted 02e_modeling-summarization.ipynb.
Converted 02z_modeling-language-modeling.ipynb.
Converted index.ipynb.
