In [None]:
# default_exp data.text_generation

In [None]:
#hide
%reload_ext autoreload
%autoreload 2
%matplotlib inline

# data.text_generation

> This module contains the bits required to use the fastai DataBlock API and/or mid-level data processing pipelines to organize your data for text generation tasks using architectures like BART, T5, or good ol' GPT2.  Abstractive summarization and conversational agents are good examples of such tasks.

In [None]:
#export
import ast
from functools import reduce

import torch
from transformers import *
from fastai2.text.all import *

from blurr.utils import *
from blurr.data.core import *

In [None]:
#hide
import pdb

from nbdev.showdoc import *
from fastcore.test import *

In [None]:
#cuda
torch.cuda.set_device(1)
print(f'Using GPU #{torch.cuda.current_device()}: {torch.cuda.get_device_name()}')

Using GPU #1: GeForce GTX 1080 Ti


## Text Generation tokenization, batch transform, and DataBlock methods

Text generation tasks attempt to generate a human-understandable and sensible response to a prior text.  For example, in summarization, our objective is to capture the meaning of a larger document in 1-3 sentences.

In [None]:
path = Path('./')
cnndm_df = pd.read_csv(path/'cnndm_sample.csv'); len(cnndm_df)

1000

In [None]:
cnndm_df.head(2)

Unnamed: 0,article,highlights,ds_type
0,"(CNN) -- Globalization washes like a flood over the world's cultures and economies. Floods can be destructive; however, they can also bring blessings, as the annual floods of the Nile did for ancient Egypt. The world's great universities can be crucial instruments in shaping, in a positive way, humankind's reaction to globalization and the development of humankind itself. Traditionally, universities have been defined and limited by location, creating an academic community and drawing students and scholars to that place. Eventually, some universities began to encourage students to study el...","John Sexton: Traditionally, universities have been defined and limited by location .\nGlobal campuses form a network of thought, innovation, he writes .\nFaculty can teach, Sexton says, students can team up in many cities at once .\nSexton: Research, scholarship can be shared and cultural ties made in ""century of knowledge""",train
1,"(CNN) -- Armenian President Robert Kocharian declared a state of emergency Saturday night after a day of clashes between police and protesters, a spokeswoman for the Armenian Foreign Ministry said. Opposition supporters wave an Armenian flag during a protest rally in Yerevan, Armenia, on Saturday. The protesters claim last month's presidential election was rigged. The state of emergency will ""hopefully bring some order"" to the capital, Yerevan, said Salpi Ghazarian, assistant to the Armenian foreign minister, who spoke to CNN early Sunday. The state of emergency could last until March 20, ...","NEW: Protest moves after crackdown at Freedom Square .\nOrder sought after protests over last month's election turn violent .\nDemonstrators say the election was fraudulent .\nState of emergency could last until March 20, official says .",train


In [None]:
pretrained_model_name = "facebook/bart-large-cnn"

hf_arch, hf_config, hf_tokenizer, hf_model = BLURR_MODEL_HELPER.get_hf_objects(pretrained_model_name, 
                                                                               model_cls=BartForConditionalGeneration)

hf_arch, type(hf_tokenizer), type(hf_config), type(hf_model)

Some weights of BartForConditionalGeneration were not initialized from the model checkpoint at facebook/bart-large-cnn and are newly initialized: ['final_logits_bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


('bart',
 transformers.tokenization_bart.BartTokenizer,
 transformers.configuration_bart.BartConfig,
 transformers.modeling_bart.BartForConditionalGeneration)

In [None]:
#export
class HF_TextGenerationInput(list): pass

We create a subclass of `HF_BatchTransform` for generation tasks to add `decoder_input_ids` and `labels` to our inputs during training, which in turn allows the huggingface model to calculate the loss for us.  See [here](https://huggingface.co/transformers/model_doc/bart.html#transformers.BartModel.forward) for more information on how these additional inputs are used in summarization and conversational training tasks.
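To make the role of `labels` a little more concrete, here is a minimal sketch of a token-level cross-entropy that skips any position whose label is -100. That value is the default `ignore_index` of PyTorch's `CrossEntropyLoss`. The function name, the toy logits, and the use of plain Python lists instead of tensors are all assumptions made for illustration; this is not the code the huggingface model runs internally.

```python
import math

IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss

def masked_cross_entropy(logits, labels, ignore_index=IGNORE_INDEX):
    """Mean negative log-likelihood over the positions that are not ignored.

    logits: list of per-position score lists (one score per vocabulary id)
    labels: list of target ids; ignore_index marks positions to skip
    """
    total, count = 0.0, 0
    for scores, label in zip(logits, labels):
        if label == ignore_index:
            continue  # e.g. a padding position: contributes nothing to the loss
        # log of the softmax normalizer, with the max subtracted for stability
        m = max(scores)
        log_z = m + math.log(sum(math.exp(s - m) for s in scores))
        total += log_z - scores[label]
        count += 1
    # returning 0.0 when every position is ignored is a choice made for this sketch
    return total / count if count else 0.0

# toy example: three positions over a vocabulary of three ids, last one ignored
logits = [[2.0, 0.0, 0.0], [0.0, 2.0, 0.0], [5.0, 5.0, 5.0]]
labels = [0, 1, IGNORE_INDEX]
loss = masked_cross_entropy(logits, labels)
```

Only the first two positions contribute to `loss`; the third is dropped entirely, which is why the transform below overwrites padding ids in `labels` with -100.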

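Note also that `labels` is simply the target `input_ids` offset by one position relative to `decoder_input_ids` (the transform uses `input_ids[1:]` for `labels` and `input_ids[:-1]` for `decoder_input_ids`), since the task is to predict the next token based on the current (and all previous) `decoder_input_ids`.  Padding positions in `labels` are set to -100 so that the loss ignores them.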
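The offset between the two sequences can be sketched with plain Python lists. The token ids, the `PAD_ID` value, and the helper name below are made up for illustration; the real transform does the same slicing on tensors and reads the pad id from `hf_tokenizer.pad_token_id`.

```python
PAD_ID = 1           # hypothetical pad token id for this illustration
IGNORE_INDEX = -100  # label value that the loss skips

def make_decoder_inputs_and_labels(target_ids, pad_id=PAD_ID):
    """Split one padded target sequence into decoder inputs and labels."""
    # decoder inputs: every target token except the last
    decoder_input_ids = target_ids[:-1]
    # labels: every target token except the first, with padding masked out
    labels = [IGNORE_INDEX if t == pad_id else t for t in target_ids[1:]]
    return decoder_input_ids, labels

target_ids = [0, 713, 16, 2, 1, 1]  # made-up ids, padded with PAD_ID
decoder_input_ids, labels = make_decoder_inputs_and_labels(target_ids)

print(decoder_input_ids)  # [0, 713, 16, 2, 1]
print(labels)             # [713, 16, 2, -100, -100]
```

At each position `i`, `labels[i]` is the token that follows `decoder_input_ids[i]` in the target, so both sequences are one element shorter than the original.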

And lastly, we also update our targets to just be the `input_ids` of our target sequence so that fastai's `Learner.show_results` works (again, almost all the fastai bits require returning a single tensor to work).

In [None]:
#export
class HF_TextGenerationBatchTransform(HF_BatchTransform):
    def __init__(self, hf_arch, hf_tokenizer, **kwargs):
        super().__init__(hf_arch, hf_tokenizer, HF_TextGenerationInput, **kwargs)
        
    def encodes(self, samples):  
        samples = super().encodes(samples)
        if (len(samples[0]) == 1): return samples
        
        updated_samples = []
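        for s in samples:
            # decoder inputs: every target token except the last
            s[0]['decoder_input_ids'] = s[1]['input_ids'][:-1].clone()
            # labels: every target token except the first (the next token at each step)
            s[0]['labels'] = s[1]['input_ids'][1:].clone()
            # -100 is ignored by the loss, so padding does not contribute
            s[0]['labels'][s[0]['labels'] == self.hf_tokenizer.pad_token_id] = -100
            
            # keep the targets as a single tensor of ids so fastai's show methods work
            targ_ids = s[1]['input_ids']
            
            updated_samples.append((s[0], targ_ids))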
        
        return updated_samples
    
    def decodes(self, encoded_samples):
        if (isinstance(encoded_samples, dict)): return self.hf_input_return_type([encoded_samples['input_ids']])
        return [encoded_samples]

We had to override the `decodes` method above because, while both our inputs and targets are technically the same kind of thing, we update the latter to consist of *only* the target `input_ids` so that methods like `Learner.show_results` work.  Nevertheless, because fastai remembers what they are, `HF_TokenizerTransform.decodes` will be called for both, and it works on a `list` of `input_ids`.
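The two cases that `decodes` has to handle can be sketched without fastai or transformers. The stand-in class, the function name, and the sample ids below are hypothetical; plain lists take the place of tensors.

```python
class TextGenerationInput(list):
    """Stand-in for HF_TextGenerationInput (a plain list subclass)."""

def decode_for_display(encoded_sample):
    """Return a list holding the input_ids, whatever shape the sample arrives in."""
    if isinstance(encoded_sample, dict):
        # inputs arrive as a dict of model arguments; only the ids are displayed
        return TextGenerationInput([encoded_sample['input_ids']])
    # targets were already reduced to bare ids, so they are just wrapped in a list
    return [encoded_sample]

inputs = {'input_ids': [0, 713, 16, 2], 'attention_mask': [1, 1, 1, 1]}
targets = [0, 152, 2]

decoded_inputs = decode_for_display(inputs)
decoded_targets = decode_for_display(targets)
```

Both results are lists containing a single sequence of ids, which is the shape the text above says the tokenizer transform's `decodes` expects.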

In [None]:
hf_batch_tfm = HF_TextGenerationBatchTransform(hf_arch, hf_tokenizer)

blocks = ( 
    HF_TextBlock(hf_arch, hf_tokenizer), 
    HF_TextBlock(hf_arch, hf_tokenizer, hf_batch_tfm=hf_batch_tfm, max_length=150)
)

dblock = DataBlock(blocks=blocks, 
                   get_x=ColReader('article'), 
                   get_y=ColReader('highlights'), 
                   splitter=RandomSplitter())

In [None]:
# dblock.summary(cnndm_df)

In [None]:
dls = dblock.dataloaders(cnndm_df, bs=4)

In [None]:
b = dls.one_batch()

In [None]:
len(b), b[0]['input_ids'].shape, b[1].shape

(2, torch.Size([4, 512]), torch.Size([4, 150]))

In [None]:
#export
@typedispatch
def show_batch(x:HF_TextGenerationInput, y, samples, dataloaders=None, ctxs=None, max_n=6, **kwargs):  
    res = L([ (s[0], s[1]) for s in samples ])          
    display_df(pd.DataFrame(res, columns=['text', 'target'])[:max_n])
    return ctxs

In [None]:
dls.show_batch(dataloaders=dls, max_n=2)

Unnamed: 0,text,target
0,"Israel is confronting a problem beyond the Hamas rockets screeching overhead -- a threat underfoot. The Israeli military says it is trying to demolish a sophisticated network of tunnels that run through parts of northeast Gaza, under the border and into southern Israel. Hamas has already used the tunnels several times in the past few days to attempt assaults on Israeli soil. The first attack, on July 17, was foiled but prompted Israel to announce a ground incursion into Gaza with the stated aim of taking out the tunnels. Another assault through tunnels s few days later resulted in clashes that killed more than 10 Hamas fighters and four Israeli soldiers. The assault near the town of Sderot appeared to target two communal areas ""where farmers are trying to conduct their daily lives,"" said Israeli government spokesman Mark Regev. The Hamas fighters were disguised as Israeli soldiers, according to the Israel Defense Forces. The clashes forced area roads to close, residents to shelter in their homes and tied up security forces for hours. The method of attack, in which militants spring out unexpectedly from underground, has struck fear into Israelis living near Gaza. ""Your enemy is about to blast his way into your dining room from below the floor while you are feeding your family. Sounds like a B-rated horror movie, right? This scenario is one real example of a Hamas tunnel discovered just in time by the IDF leading into a kibbutz communal dining hall,"" Benay Browne Katz, a volunteer medic and grandmother who lives in Jaffa, told CNN. 'Lower Gaza' The tunnel network has also been used during combat inside Gaza, the Israeli military says, allowing Hamas fighters to pop up and fire on soldiers or toss grenades before dropping back out of sight. Israeli military officials refer to the underground works as ""Lower Gaza"" and suggest at least some of the war is being waged underground. The tunnels aren't a new phenomenon. 
Hamas used one in 2006 to capture the Israeli soldier Gilad Shalit and take him back into Gaza. He was held captive for five years until a deal was struck for his release in exchange for more than 1,000 Palestinian prisoners. Memories of his capture were revived by a foiled attack over the weekend, in which one Hamas fighter who entered Israel through a tunnel was found to be carrying tranquilizers and handcuffs, according to the Israeli military. 'A whole industry' Israel received a warning of the growing scale and sophistication of the underground threat last year with the discovery of a tunnel that ran from the Khan Younis refugee camp in Gaza and emerged near the Israeli","Hamas has used tunnels to stage attacks in Israel and Gaza, the Israeli military says.\nThe military has destroyed some tunnels, but says it believes there are many more.\nHamas used a tunnel to capture an Israeli soldier in 2006.\nA tunnel into Israel discovered last year showed increasing sophistication."
1,"(CNN)The same superbug that contributed to two deaths in Los Angeles has been reported in North Carolina, where one person has died, a spokesman told CNN. Eighteen people have contracted carbapenem-resistant Enterobacteriaceae, or CRE, so far this year, said Kevin McCarthy, spokesman with the Carolinas HealthCare System. Of those, 15 had CRE upon admission to the hospital in Charlotte; three acquired it in the hospital, and one died, the spokesman said. The cause of death was not immediately clear. McCarthy declined to provide details on any of the patients. It was also not clear how any of them became infected. In Los Angeles, seven patients contracted CRE after routine endoscopic procedures. Two of them died, the Ronald Reagan UCLA Medical Center said last week. CRE was a contributing factor in the deaths, but the exact cause of the deaths wasn't immediately disclosed in those cases either. Hospital officials there have said the outbreak was caused by two medical scopes that still carried the deadly bacteria even though disinfection guidelines were followed. The UCLA hospital was using a duodenoscope made by Olympus, but the Food and Drug Administration is also reviewing data from the two other U.S. companies that make the devices, Fujifilm and Pentax. The medical center is contacting 179 others who underwent endoscopic procedures between October and January. It's offering them home tests to screen for the bacteria. In a statement, McCarthy said Carolinas HealthCare System uses standard methods for disinfecting its equipment, saying that all duodenoscopes that have been tested have shown to be negative for CRE. Some CRE bacteria can resist most antibiotics, the Centers for Disease Control and Prevention says on its website. 
CNN's Ben Tinker, Michael Martinez and Steve Almasy contributed to this report.","Eighteen people have contracted CRE so far this year at a hospital in Charlotte, North Carolina.\nCRE is highly resistant to many forms of antibiotic treatments, the CDC says."


## Cleanup

In [None]:
#hide
from nbdev.export import notebook2script
notebook2script()

Converted 00_utils.ipynb.
Converted 01_data-core.ipynb.
Converted 01a_data-language-modeling.ipynb.
Converted 01c_data-question-answering.ipynb.
Converted 01d_data-token-classification.ipynb.
Converted 01e_data-text-generation.ipynb.
Converted 02_modeling-core.ipynb.
Converted 02a_modeling-language-modeling.ipynb.
Converted 02c_modeling-question-answering.ipynb.
Converted 02d_modeling-token-classification.ipynb.
Converted 02e_modeling-text-generation.ipynb.
Converted index.ipynb.
