In [None]:
# default_exp data.text2text.summarization

In [None]:
#hide
%reload_ext autoreload
%autoreload 2
%matplotlib inline

import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# data.text2text.summarization

> This module contains the bits required to use the fastai DataBlock API and/or mid-level data processing pipelines to organize your data for summarization tasks using architectures like BART and T5.

In [None]:
#export
import ast
from functools import reduce

import torch
from transformers import *
from fastai.text.all import *

from blurr.utils import *
from blurr.data.core import *
from blurr.data.text2text.core import *

logging.set_verbosity_error()

In [None]:
#hide
import pdb

from nbdev.showdoc import *
from fastcore.test import *

from fastai import __version__ as fa_version
from torch import __version__ as pt_version
from transformers import __version__ as hft_version

print(f'Using pytorch {pt_version}')
print(f'Using fastai {fa_version}')
print(f'Using transformers {hft_version}')

Using pytorch 1.7.1
Using fastai 2.1.8
Using transformers 4.0.1


In [None]:
#cuda
torch.cuda.set_device(1)
print(f'Using GPU #{torch.cuda.current_device()}: {torch.cuda.get_device_name()}')

Using GPU #1: GeForce GTX 1080 Ti


## Summarization tokenization, batch transform, and DataBlock methods

Summarization tasks attempt to generate a human-understandable and sensible representation of a larger body of text (e.g., capture the meaning of a larger document in 1-3 sentences).

In [None]:
path = Path('./')
cnndm_df = pd.read_csv(path/'cnndm_sample.csv'); len(cnndm_df)

1000

In [None]:
cnndm_df.head(2)

Unnamed: 0,article,highlights,ds_type
0,"(CNN) -- Globalization washes like a flood over the world's cultures and economies. Floods can be destructive; however, they can also bring blessings, as the annual floods of the Nile did for ancient Egypt. The world's great universities can be crucial instruments in shaping, in a positive way, humankind's reaction to globalization and the development of humankind itself. Traditionally, universities have been defined and limited by location, creating an academic community and drawing students and scholars to that place. Eventually, some universities began to encourage students to study el...","John Sexton: Traditionally, universities have been defined and limited by location .\nGlobal campuses form a network of thought, innovation, he writes .\nFaculty can teach, Sexton says, students can team up in many cities at once .\nSexton: Research, scholarship can be shared and cultural ties made in ""century of knowledge""",train
1,"(CNN) -- Armenian President Robert Kocharian declared a state of emergency Saturday night after a day of clashes between police and protesters, a spokeswoman for the Armenian Foreign Ministry said. Opposition supporters wave an Armenian flag during a protest rally in Yerevan, Armenia, on Saturday. The protesters claim last month's presidential election was rigged. The state of emergency will ""hopefully bring some order"" to the capital, Yerevan, said Salpi Ghazarian, assistant to the Armenian foreign minister, who spoke to CNN early Sunday. The state of emergency could last until March 20, ...","NEW: Protest moves after crackdown at Freedom Square .\nOrder sought after protests over last month's election turn violent .\nDemonstrators say the election was fraudulent .\nState of emergency could last until March 20, official says .",train


In [None]:
pretrained_model_name = "facebook/bart-large-cnn"

hf_arch, hf_config, hf_tokenizer, hf_model = BLURR_MODEL_HELPER.get_hf_objects(pretrained_model_name, 
                                                                               model_cls=BartForConditionalGeneration)

hf_arch, type(hf_tokenizer), type(hf_config), type(hf_model)

('bart',
 transformers.models.bart.tokenization_bart_fast.BartTokenizerFast,
 transformers.models.bart.configuration_bart.BartConfig,
 transformers.models.bart.modeling_bart.BartForConditionalGeneration)

We create a subclass of `HF_BeforeBatchTransform` for summarization tasks to add `decoder_input_ids` and `labels` to our inputs during training, which will in turn allow the huggingface model to calculate the loss for us. See [here](https://huggingface.co/transformers/glossary.html#labels) and [here](https://huggingface.co/transformers/glossary.html#decoder-input-ids) for more information on these additional inputs used in summarization and conversational training tasks. How they should look for a particular architecture can be found in the docs for that model's `forward` function (see [here](https://huggingface.co/transformers/model_doc/bart.html#transformers.BartModel.forward) for BART, for example).

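Note also that `labels` is simply the target `input_ids` shifted by one position relative to `decoder_input_ids` (`decoder_input_ids` drops the last target token, `labels` drops the first), since the task is to predict the next token based on the current (and all previous) `decoder_input_ids`.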

And lastly, we also set our targets to a copy of `labels` (i.e., the `input_ids` of our target sequence with padding masked out) so that fastai's `Learner.show_results` works (again, almost all the fastai bits require returning a single tensor to work).
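
As a minimal sketch of the slicing described above, here is the same logic applied to a single made-up target sequence. The token ids, `pad_token_id = 1`, and `ignore_token_id = -100` are assumptions chosen for illustration (-100 is the default `ignore_index` of PyTorch's cross-entropy loss); real values come from the tokenizer and the transform's arguments.

```python
import torch

# Hypothetical ids for illustration only: 0 = <s>, 2 = </s>, 1 = <pad>
pad_token_id, ignore_token_id = 1, -100
trg_input_ids = torch.tensor([0, 713, 16, 10, 2, 1, 1])

# Decoder inputs: everything except the last position
decoder_input_ids = trg_input_ids[:-1].clone()

# Labels: everything except the first position, i.e. the "next token" at each step
labels = trg_input_ids[1:].clone()

# Mask padding so those positions are skipped by the loss
labels[labels == pad_token_id] = ignore_token_id

print(decoder_input_ids.tolist())  # [0, 713, 16, 10, 2, 1]
print(labels.tolist())             # [713, 16, 10, 2, -100, -100]
```

At every position `i`, `labels[i]` is the token that follows `decoder_input_ids[i]` in the target sequence, which is what the model is asked to predict.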

In [None]:
#export
class HF_SummarizationInput(HF_BaseInput): pass

In [None]:
#export
class HF_SummarizationBeforeBatchTransform(HF_BeforeBatchTransform):
    
    def __init__(self, hf_arch, hf_tokenizer, max_length=None, padding=True, truncation=True, 
                 is_split_into_words=False, n_tok_inps=2, ignore_token_id=CrossEntropyLossFlat().ignore_index, 
                 tok_kwargs={}, **kwargs):
                 
        super().__init__(hf_arch, hf_tokenizer, max_length=max_length, padding=padding, truncation=truncation, 
                         is_split_into_words=is_split_into_words, n_tok_inps=n_tok_inps, 
                         tok_kwargs=tok_kwargs.copy(), **kwargs)
        
        self.ignore_token_id = ignore_token_id
        
    def encodes(self, samples):  
        samples = super().encodes(samples)
        if (len(samples[0]) == 1): return samples
        
        updated_samples = []
        for s in samples:
            trg_input_ids = s[1]['input_ids']
            
            if (self.hf_arch in ['t5', 'pegasus']):
                # see: https://github.com/huggingface/transformers/issues/7986#issuecomment-714938591
                trg_input_ids = F.pad(trg_input_ids, pad=(1,0), value=self.hf_tokenizer.pad_token_id)[:-1]
            
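            # decoder inputs are the target ids without the last position
            s[0]['decoder_input_ids'] = trg_input_ids[:-1].clone()
            # labels are the target ids without the first position (the "next token" at each step)
            s[0]['labels'] = trg_input_ids[1:].clone()
            # replace padding with ignore_token_id so those positions are skipped by the loss
            s[0]['labels'][s[0]['labels'] == self.hf_tokenizer.pad_token_id] = self.ignore_token_id
            
            # the fastai target is a copy of the labels so that a single tensor is returned
            targ_ids = s[0]['labels'].clone()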

            updated_samples.append((s[0], targ_ids))
        
        return updated_samples

In [None]:
before_batch_tfm = HF_SummarizationBeforeBatchTransform(hf_arch, hf_tokenizer)
blocks = (HF_Text2TextBlock(before_batch_tfms=before_batch_tfm, input_return_type=HF_SummarizationInput), noop)

dblock = DataBlock(blocks=blocks, 
                   get_x=ColReader('article'), 
                   get_y=ColReader('highlights'), 
                   splitter=RandomSplitter())

Two lines!  Notice we pass in `noop` for our targets (i.e., our summaries) because the batch transform will take care of both our inputs and targets.

In [None]:
# dblock.summary(cnndm_df)

In [None]:
dls = dblock.dataloaders(cnndm_df, bs=4)

In [None]:
b = dls.one_batch()

In [None]:
len(b), b[0]['input_ids'].shape, b[1].shape

(2, torch.Size([4, 1024]), torch.Size([4, 78]))

In [None]:
t = torch.randn((3,3));

F.pad(t, pad=(1,0), value=1)[:,:-1]

tensor([[ 1.0000,  1.2162, -1.3936],
        [ 1.0000,  0.3981, -1.5165],
        [ 1.0000, -1.7347,  2.1353]])
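
The cell above appears to be a quick check of the `F.pad` call that `encodes` applies for T5 and Pegasus: one value is prepended and the last position is dropped, so the length stays the same. A deterministic 1-D variant, closer to the per-sample tensors seen in `encodes`, might look like this (the token ids are made up, and `pad_token_id = 0` is an assumption for illustration):

```python
import torch
import torch.nn.functional as F

pad_token_id = 0  # assumed value for illustration; read it from the tokenizer in practice
trg_input_ids = torch.tensor([100, 200, 300, 1])  # made-up ids; 1 stands in for </s>

# Prepend one pad token, then drop the last position to keep the length unchanged
shifted = F.pad(trg_input_ids, pad=(1, 0), value=pad_token_id)[:-1]

print(shifted.tolist())  # [0, 100, 200, 300]
```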

In [None]:
#export
@typedispatch
def show_batch(x:HF_SummarizationInput, y, samples, dataloaders, ctxs=None, max_n=6, 
               input_trunc_at=None, target_trunc_at=None, **kwargs):  
    
    before_batch_tfm = dataloaders.before_batch[0]
    hf_tokenizer = before_batch_tfm.hf_tokenizer
    ignore_token_id = before_batch_tfm.ignore_token_id
    
    res = L([ (hf_tokenizer.decode(s[0], skip_special_tokens=True)[:input_trunc_at], 
               hf_tokenizer.decode(s[1][s[1] != ignore_token_id], skip_special_tokens=True)[:target_trunc_at])
             for s in samples ])      
    
    display_df(pd.DataFrame(res, columns=['text', 'target'])[:max_n])
    return ctxs

In [None]:
dls.show_batch(dataloaders=dls, max_n=2, input_trunc_at=1000, target_trunc_at=250)

Unnamed: 0,text,target
0,"(CNN) -- Home to up to 10 percent of all known species, Mexico is recognized as one of the most biodiverse regions on the planet. The twin threats of climate change and human encroachment on natural environments are, however, threatening the existence of the country's rich wildlife. And there is a great deal to lose. In the United Nations Environment Program (UNEP) World Conservation Monitoring Centre's list of megadiverse countries Mexico ranks 11th. The list represents a group of 17 countries that harbor the majority of the Earth's species and are therefore considered extremely biodiverse. From its coral reefs in the Caribbean Sea to its tropical jungles in Chiapas and the Yucatan peninsula and its deserts and prairies in the north, Mexico boasts an incredibly rich variety of flora and fauna. Some 574 out of 717 reptile species found in Mexico -- the most in any country -- can only be encountered within its borders. It is home to 502 types of mammals, 290 species of birds, 1,150 var","Mexico hosts to up to 10 percent of all known species on Earth.\nIt is home to 502 types of mammals, 290 bird species and 26,000 types of plants.\nHuman development and climate change is placing a big strain on its biodiversity.\nThe Golden Eagle is un"
1,"(CNN) -- It's a congested, sprawling transport hub surrounded by 1950s architecture and predominantly used by commuters or tourists to cross the city of Istanbul. But proposed changes to Taksim Square have seen it become the flashpoint for protests that have swept through Turkey in the past week, leaving thousands injured and focusing the world's attention on the government of Prime Minister Recep Tayyip Erdogan. Taksim has been no stranger to violence. In 1977, at least 34 protesters died during May Day clashes with police. May 1 rallies in the square were banned in 1980 and were only allowed to legally resume in 2010. On May Day this year, there were riots after city authorities again refused to grant trade unions and youth groups permission to demonstrate in Taksim, blaming construction work being carried out in the square. Professor Ersin Kalaycioglu, professor of political science at Istanbul's Sabanci University, said significantly, Taksim Square was also known as ""republic squa",Taksim Square was where Istanbul's water was distributed -- Taksim means divide.\nThe site is seen as symbolizing the seclar Turkish republic founded by Ataturk.\nErdogan's government's plans to alter Taksim's Gezi Park prompted protests.\nThe police's
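
`show_batch` filters out `ignore_token_id` before decoding the targets, since that value is a loss-masking marker rather than a real vocabulary id. A small sketch of just the filtering step, using made-up ids and assuming `ignore_token_id = -100`:

```python
import torch

ignore_token_id = -100  # assumed value for illustration
targ_ids = torch.tensor([713, 16, 10, 2, -100, -100])  # made-up target ids

# Boolean-mask indexing keeps only the positions holding real token ids
kept = targ_ids[targ_ids != ignore_token_id]

print(kept.tolist())  # [713, 16, 10, 2]
```

The filtered tensor is what gets passed to `hf_tokenizer.decode` in the `show_batch` implementation above.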


## Tests

The tests below ensure the core DataBlock code above works for **all** pretrained summarization models available in huggingface.  These tests are excluded from the CI workflow because of how long they would take to run and the amount of data that would need to be downloaded.

**Note**: Feel free to modify the code below to test whatever pretrained summarization models you are working with. If any of your pretrained summarization models fail, please submit a GitHub issue *(or a PR if you'd like to fix it yourself)*.

In [None]:
BLURR_MODEL_HELPER.get_models(task='ConditionalGeneration')

[transformers.models.bart.modeling_bart.BartForConditionalGeneration,
 transformers.models.blenderbot.modeling_blenderbot.BlenderbotForConditionalGeneration,
 transformers.models.fsmt.modeling_fsmt.FSMTForConditionalGeneration,
 transformers.models.mbart.modeling_mbart.MBartForConditionalGeneration,
 transformers.models.mt5.modeling_mt5.MT5ForConditionalGeneration,
 transformers.models.pegasus.modeling_pegasus.PegasusForConditionalGeneration,
 transformers.models.prophetnet.modeling_prophetnet.ProphetNetForConditionalGeneration,
 transformers.models.t5.modeling_t5.T5ForConditionalGeneration,
 transformers.models.xlm_prophetnet.modeling_xlm_prophetnet.XLMProphetNetForConditionalGeneration]

In [None]:
pretrained_model_names = [
    ('facebook/bart-base',BartForConditionalGeneration),
    ('t5-small', T5ForConditionalGeneration),
    ('google/pegasus-cnn_dailymail', PegasusForConditionalGeneration)
]

In [None]:
path = Path('./')
cnndm_df = pd.read_csv(path/'cnndm_sample.csv')

In [None]:
#slow
#hide_output
task = HF_TASKS_ALL.ConditionalGeneration
bsz = 2
seq_sz = 256
trg_seq_sz = 40

test_results = []
for model_name, model_cls in pretrained_model_names:
    error=None
    
    print(f'=== {model_name} ===\n')
    
    hf_arch, hf_config, hf_tokenizer, hf_model = BLURR_MODEL_HELPER.get_hf_objects(model_name, 
                                                                                   task=task, 
                                                                                   model_cls=model_cls)
    print(f'architecture:\t{hf_arch}\ntokenizer:\t{type(hf_tokenizer).__name__}\n')
    
    before_batch_tfm = HF_SummarizationBeforeBatchTransform(hf_arch, hf_tokenizer, 
                                                            padding='max_length', 
                                                            max_length=[seq_sz, trg_seq_sz])
    
    blocks = (HF_TextBlock(before_batch_tfms=before_batch_tfm, input_return_type=HF_SummarizationInput), noop)
    
    def add_t5_prefix(inp): return f'summarize: {inp}' if (hf_arch == 't5') else inp

    dblock = DataBlock(blocks=blocks, 
                   get_x=Pipeline([ColReader('article'), add_t5_prefix]), 
                   get_y=ColReader('highlights'), 
                   splitter=RandomSplitter())

    dls = dblock.dataloaders(cnndm_df, bs=bsz) 
    b = dls.one_batch()
    
    try:
        print('*** TESTING DataLoaders ***\n')
        test_eq(len(b), 2)
        test_eq(len(b[0]['input_ids']), bsz)
        test_eq(b[0]['input_ids'].shape, torch.Size([bsz, seq_sz]))
        test_eq(len(b[1]), bsz)
        test_eq(b[1].shape, torch.Size([bsz, trg_seq_sz - 1]))

        if (hasattr(hf_tokenizer, 'add_prefix_space')):
             test_eq(hf_tokenizer.add_prefix_space, True)
            
        test_results.append((hf_arch, type(hf_tokenizer).__name__, model_name, 'PASSED', ''))
        dls.show_batch(dataloaders=dls, max_n=2, input_trunc_at=1000)
        
    except Exception as err:
        test_results.append((hf_arch, type(hf_tokenizer).__name__, model_name, 'FAILED', err))

=== facebook/bart-base ===

architecture:	bart
tokenizer:	BartTokenizerFast

*** TESTING DataLoaders ***



Unnamed: 0,text,target
0,"(CNN) -- Home to up to 10 percent of all known species, Mexico is recognized as one of the most biodiverse regions on the planet. The twin threats of climate change and human encroachment on natural environments are, however, threatening the existence of the country's rich wildlife. And there is a great deal to lose. In the United Nations Environment Program (UNEP) World Conservation Monitoring Centre's list of megadiverse countries Mexico ranks 11th. The list represents a group of 17 countries that harbor the majority of the Earth's species and are therefore considered extremely biodiverse. From its coral reefs in the Caribbean Sea to its tropical jungles in Chiapas and the Yucatan peninsula and its deserts and prairies in the north, Mexico boasts an incredibly rich variety of flora and fauna. Some 574 out of 717 reptile species found in Mexico -- the most in any country -- can only be encountered within its borders. It is home to 502 types of mammals, 290 species of birds, 1,150 var","Mexico hosts to up to 10 percent of all known species on Earth.\nIt is home to 502 types of mammals, 290 bird species and 26,000 types of plants.\nHuman development"
1,"Michael Zehaf-Bibeau, the 32-year-old gunman who attacked the Canadian Parliament and killed a soldier last week, had a familiar profile. It is that of a young man alienated from mainstream society, with few friends, without a steady job, drifting from one place to another -- often with a history of petty crime and drug abuse. Then comes the conversion to or rediscovery of Islam, and the adoption of a jihadist mindset, fed by media and online coverage of the West's involvement in wars in Iraq and Afghanistan, and by the well-oiled propaganda machine of groups like ISIS. Whether these young men acting alone (and they are almost always men) should be better described as ""lone-wolf"" terrorists or deranged criminals is debatable. ""There is no single, universally accepted definition of terrorism,"" says the FBI. In many cases, their conversion to militant Islam is about seeking identity and purpose, or even a sense of adventure. Few of these men have a deep understanding of Salafism, the de","Like many ""lone wolf"" terrorists, Ottawa gunman was alienated drifter who converted to Islam.\nConversion to militant Islam is often about seeking identity, purpose or adventure.\nSome"


=== t5-small ===

architecture:	t5
tokenizer:	T5TokenizerFast

*** TESTING DataLoaders ***



Unnamed: 0,text,target
0,"summarize: (CNN) -- Home to up to 10 percent of all known species, Mexico is recognized as one of the most biodiverse regions on the planet. The twin threats of climate change and human encroachment on natural environments are, however, threatening the existence of the country's rich wildlife. And there is a great deal to lose. In the United Nations Environment Program (UNEP) World Conservation Monitoring Centre's list of megadiverse countries Mexico ranks 11th. The list represents a group of 17 countries that harbor the majority of the Earth's species and are therefore considered extremely biodiverse. From its coral reefs in the Caribbean Sea to its tropical jungles in Chiapas and the Yucatan peninsula and its deserts and prairies in the north, Mexico boasts an incredibly rich variety of flora and fauna. Some 574 out of 717 reptile species found in Mexico -- the most in any country -- can only be encountered within its borders. It is home to 502 types of mammals, 290 species of birds,","Mexico hosts to up to 10 percent of all known species on Earth. It is home to 502 types of mammals, 290 bird species and 26,000 types of plants. Human development"
1,"summarize: (CNN Student News) -- October 27, 2009. Downloadable Maps. Download PDF maps related to today's show: • Afghanistan & Pakistan • Los Angeles & San Diego • Ft. Jackson, South Carolina. Transcript. THIS IS A RUSH TRANSCRIPT. THIS COPY MAY NOT BE IN ITS FINAL FORM AND MAY BE UPDATED. NATISHA LANCE, CNN STUDENT NEWS ANCHOR: A member of the military is making history. We'll explain how in today's edition of CNN Student News. Hi, everyone. Carl Azuz is off this week. I'm Natisha Lance. First Up: Afghan Crashes. LANCE: First up, Pakistan and Afghanistan. The countries share a border, and they also share a common problem: threats from militant groups and terrorists like the Taliban and al Qaeda. It's an issue facing both nations' governments, and one that the U.S. government is concerned about as well. That's why President Obama has been holding a series of meetings with some of his advisers. They're reviewing the U.",Consider U.S. efforts to offer Afghan citizens an alternative to the Taliban. Hear how a proposed health care bill addresses the issue of the public option. Meet a soldier


=== google/pegasus-cnn_dailymail ===



ValueError: Couldn't instantiate the backend tokenizer from one of: (1) a `tokenizers` library serialization file, (2) a slow tokenizer instance to convert or (3) an equivalent slow tokenizer class to instantiate and convert. You need to have sentencepiece installed to convert a slow tokenizer to a fast one.

In [None]:
#slow
#hide_input
test_results_df = pd.DataFrame(test_results, columns=['arch', 'tokenizer', 'model_name', 'result', 'error'])
display_df(test_results_df)

## Cleanup

In [None]:
#hide
from nbdev.export import notebook2script
notebook2script()