In [None]:
# default_exp data.text_generation

In [None]:
#hide
%reload_ext autoreload
%autoreload 2
%matplotlib inline

# data.text_generation

> This module contains the bits required to use the fastai DataBlock API and/or mid-level data processing pipelines to organize your data for text generation tasks using architectures like BART, T5, or good ol' GPT2. Abstractive summarization and conversational agents are good examples of such tasks.

In [None]:
#export
import ast
from functools import reduce

import torch
from transformers import *
from fastai2.text.all import *

from blurr.utils import *
from blurr.data.core import *

In [None]:
#hide
import pdb

from nbdev.showdoc import *
from fastcore.test import *

In [None]:
#cuda
torch.cuda.set_device(1)
print(f'Using GPU #{torch.cuda.current_device()}: {torch.cuda.get_device_name()}')

Using GPU #1: GeForce GTX 1080 Ti


## Text Generation tokenization, batch transform, and DataBlock methods

Text generation tasks attempt to generate a human-understandable and sensible response to a prior text.  For example, in summarization, our objective is to capture the meaning of a larger document in 1-3 sentences.

In [None]:
path = Path('./')
cnndm_df = pd.read_csv(path/'cnndm_sample.csv'); len(cnndm_df)

1000

In [None]:
cnndm_df.head(2)

Unnamed: 0,article,highlights,ds_type
0,"(CNN) -- Globalization washes like a flood over the world's cultures and economies. Floods can be destructive; however, they can also bring blessings, as the annual floods of the Nile did for ancient Egypt. The world's great universities can be crucial instruments in shaping, in a positive way, humankind's reaction to globalization and the development of humankind itself. Traditionally, universities have been defined and limited by location, creating an academic community and drawing students and scholars to that place. Eventually, some universities began to encourage students to study el...","John Sexton: Traditionally, universities have been defined and limited by location .\nGlobal campuses form a network of thought, innovation, he writes .\nFaculty can teach, Sexton says, students can team up in many cities at once .\nSexton: Research, scholarship can be shared and cultural ties made in ""century of knowledge""",train
1,"(CNN) -- Armenian President Robert Kocharian declared a state of emergency Saturday night after a day of clashes between police and protesters, a spokeswoman for the Armenian Foreign Ministry said. Opposition supporters wave an Armenian flag during a protest rally in Yerevan, Armenia, on Saturday. The protesters claim last month's presidential election was rigged. The state of emergency will ""hopefully bring some order"" to the capital, Yerevan, said Salpi Ghazarian, assistant to the Armenian foreign minister, who spoke to CNN early Sunday. The state of emergency could last until March 20, ...","NEW: Protest moves after crackdown at Freedom Square .\nOrder sought after protests over last month's election turn violent .\nDemonstrators say the election was fraudulent .\nState of emergency could last until March 20, official says .",train


In [None]:
pretrained_model_name = "facebook/bart-large-cnn"

hf_arch, hf_config, hf_tokenizer, hf_model = BLURR_MODEL_HELPER.get_hf_objects(pretrained_model_name, 
                                                                               model_cls=BartForConditionalGeneration)

hf_arch, type(hf_tokenizer), type(hf_config), type(hf_model)

Some weights of BartForConditionalGeneration were not initialized from the model checkpoint at facebook/bart-large-cnn and are newly initialized: ['final_logits_bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


('bart',
 transformers.tokenization_bart.BartTokenizer,
 transformers.configuration_bart.BartConfig,
 transformers.modeling_bart.BartForConditionalGeneration)

In [None]:
#export
class HF_TextGenerationInput(list): pass

We create a subclass of `HF_BatchTransform` for generation tasks to add `decoder_input_ids` and `labels` to our inputs during training, which in turn allows the huggingface model to calculate the loss for us.  See [here](https://huggingface.co/transformers/model_doc/bart.html#transformers.BartModel.forward) for more information on how these additional inputs are used in summarization and conversational training tasks.

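Note also that `decoder_input_ids` and `labels` are the same target ids offset by one position: `decoder_input_ids` drops the last token and `labels` drops the first, since the task is to predict the next token based on the current (and all previous) `decoder_input_ids`.  Padding positions in `labels` are set to -100 so that they are ignored when the loss is calculated.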

And lastly, we also update our targets to just be the `input_ids` of our target sequence so that fastai's `Learner.show_results` works (again, almost all the fastai bits require returning a single tensor to work).
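
To make the one-position offset between `decoder_input_ids` and `labels` concrete, here is a minimal sketch using plain Python lists rather than tensors. The `shift_targets` helper and the token ids are hypothetical and exist only for illustration; the pad id of 1 is an assumption (the real transform reads it from `hf_tokenizer.pad_token_id`), and -100 is the default `ignore_index` of PyTorch's cross-entropy loss.

```python
PAD_ID = 1           # assumed pad token id, for illustration only
IGNORE_INDEX = -100  # default ignore_index of PyTorch's cross-entropy loss

def shift_targets(target_ids, pad_id=PAD_ID, ignore_index=IGNORE_INDEX):
    """Hypothetical helper mirroring the slicing done in the batch transform."""
    # Decoder inputs: every target token except the last
    decoder_input_ids = target_ids[:-1]
    # Labels: every target token except the first, with padding masked out
    labels = [ignore_index if t == pad_id else t for t in target_ids[1:]]
    return decoder_input_ids, labels

# A made-up, padded target sequence
target_ids = [0, 713, 16, 10, 2, 1, 1]
dec, labels = shift_targets(target_ids)

print(dec)     # [0, 713, 16, 10, 2, 1]
print(labels)  # [713, 16, 10, 2, -100, -100]
```

At every position `i`, `labels[i]` is the token that follows `dec[i]` in the original target, which is exactly what the decoder is asked to predict.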

In [None]:
#export
class HF_TextGenerationBatchTransform(HF_BatchTransform):
    def __init__(self, hf_arch, hf_tokenizer, **kwargs):
        super().__init__(hf_arch, hf_tokenizer, HF_TextGenerationInput, **kwargs)
        
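    def encodes(self, samples):
        samples = super().encodes(samples)
        # No targets present: nothing to add
        if (len(samples[0]) == 1): return samples

        updated_samples = []
        for s in samples:
            # Decoder inputs are the target ids without the last token
            s[0]['decoder_input_ids'] = s[1]['input_ids'][:-1].clone()
            # Labels are the target ids without the first token (the next token at each step)
            s[0]['labels'] = s[1]['input_ids'][1:].clone()
            # Mask padding with -100 so the loss ignores those positions
            s[0]['labels'][s[0]['labels'] == self.hf_tokenizer.pad_token_id] = -100

            # fastai expects a single tensor as the target
            targ_ids = s[1]['input_ids']

            updated_samples.append((s[0], targ_ids))

        return updated_samples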
    
    def decodes(self, encoded_samples):
        if (isinstance(encoded_samples, dict)): return self.hf_input_return_type([encoded_samples['input_ids']])
        return [encoded_samples]

We had to override the `decodes` method above because, while both our inputs and targets are technically the same kind of thing, we update the latter to consist of *only* the target input_ids so that methods like `Learner.show_results` work.  Nevertheless, because fastai remembers what they are, `HF_TokenizerTransform.decodes` will be called for both, and it works on a `list` of input_ids.
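
The branching in `decodes` can be sketched without fastai or torch. In this illustrative stand-in, `TextGenerationInput` and the free function `decodes` are hypothetical, simplified versions of the class and method above, and the ids are made up; the point is only to show that a dict-shaped input and a bare sequence of target ids both come out as a list holding one sequence of input_ids.

```python
class TextGenerationInput(list):
    """Hypothetical stand-in for HF_TextGenerationInput."""

def decodes(encoded_samples, input_type=TextGenerationInput):
    # Inputs arrive as a dict; keep only the input_ids, wrapped in the typed list
    if isinstance(encoded_samples, dict):
        return input_type([encoded_samples['input_ids']])
    # Targets are already bare input_ids; wrap them in a plain list
    return [encoded_samples]

x = {'input_ids': [0, 713, 2], 'attention_mask': [1, 1, 1]}  # made-up input
y = [0, 9226, 2]                                             # made-up target ids

decoded_x = decodes(x)
decoded_y = decodes(y)

print(type(decoded_x).__name__, decoded_x)  # TextGenerationInput [[0, 713, 2]]
print(type(decoded_y).__name__, decoded_y)  # list [[0, 9226, 2]]
```

Returning the typed list for inputs is what lets the `show_batch` dispatch on `HF_TextGenerationInput` pick up these batches.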

In [None]:
hf_batch_tfm = HF_TextGenerationBatchTransform(hf_arch, hf_tokenizer)

blocks = ( 
    HF_TextBlock(hf_arch, hf_tokenizer), 
    HF_TextBlock(hf_arch, hf_tokenizer, hf_batch_tfm=hf_batch_tfm, max_length=150)
)

dblock = DataBlock(blocks=blocks, 
                   get_x=ColReader('article'), 
                   get_y=ColReader('highlights'), 
                   splitter=RandomSplitter())

In [None]:
# dblock.summary(cnndm_df)

In [None]:
dls = dblock.dataloaders(cnndm_df, bs=4)

In [None]:
b = dls.one_batch()

In [None]:
len(b), b[0]['input_ids'].shape, b[1].shape

(2, torch.Size([4, 512]), torch.Size([4, 150]))

In [None]:
#export
@typedispatch
def show_batch(x:HF_TextGenerationInput, y, samples, dataloaders=None, ctxs=None, max_n=6, **kwargs):  
    res = L([ (s[0], s[1]) for s in samples ])          
    display_df(pd.DataFrame(res, columns=['text', 'target'])[:max_n])
    return ctxs

In [None]:
dls.show_batch(dataloaders=dls, max_n=2)

Unnamed: 0,text,target
0,"London (CNN) -- More than three meters above east London's Sclater Street is a mural of sprinter Usain Bolt, captured in explosive color by artist James Cochran. The street artwork, more than four meters high and six meters wide, is a dramatic sight, designed by Cochran to celebrate London's Olympic Games. Cochran, known as ""Jimmy C,"" has a style which combines his background in graffiti art and academic training in figurative realism. The UK-born artist is based in London's Shoreditch, having spent much of his life in Australia. After his mother died in a car crash when Cochran was 12, his family life deteriorated. At 16, he spent three months on the streets where he began painting with aerosol cans. Cochran later completed a visual arts degree at the University of South Australia, before going on to complete a masters degree. The Olympics miss: Why street art should be embraced not snubbed. His art often depicts the homeless, as he seeks to capture ""a more raw essence of the human subject."" Cochran's style has evolved from what he calls ""aerosol pointillism,"" with its impressionist overtones, to ""atomic pointillism, in which the subject appears to atomize. Why gritty East End is London's gold standard. Cochran also explores the relationship between individuals and the urban landscape, which has led to paintings of buildings sprouting from heads -- one of which can be seen in this Sam Taylor-Wood directed REM video. Cochran's canvases sell for thousands of pounds. But he continues to paint on the streets, and his work can be found in cities including Paris, Berlin and New York. ""When you paint on the street there is a lot more rawness to it,"" Cochran says. ""Anything can happen."" East London, where the city's Olympic Games are based, is a hub for street art. ""That's the great thing about Shoreditch and Hackney,"" says Cochran. 
""[The art] is part of the look of the street.""","James Cochran, known as ""Jimmy C,"" painted an outdoor portrait of athlete Usain Bolt.\nCochran has a background in graffiti art and academic arts training.\nHe is based in east London's Shoreditch, which is a hub for the city's street artists.\nCochran painted the mural in the hope it would last beyond the Olympic Games."
1,"(CNN) -- Sarin gas has been used several times in the Syrian civil war, including at least once by the Assad regime, France's foreign minister said Tuesday, citing results from test samples in France's possession. Laurent Fabius announced that conclusion after meeting with the head of a United Nations mission set up to establish the facts about the alleged use of chemical weapons in Syria. ""I gave him the results of tests carried out by our lab appointed by the Organization for the Prohibition of Chemical Weapons to identify chemical warfare,"" Fabius said in a statement, referring to the Swedish scientist Professor Ake Sellstrom. ""These results show the presence of sarin in the samples that are in our possession,"" Fabius said. ""In view of these elements, France now has the certainty that the sarin gas was used in Syria several times and in a localized manner."" In an interview later Tuesday with CNN affiliate France 2, Fabius blamed the Syrian government in at least one of the cases. ""There is no doubt that it is the regime and its accomplices,"" Fabius said. He added the French government examined the chain of events from the moment of the attack through the lab results to determine that government was responsible. Fabius' announcement did not say when or where the weapons may have been used or who may have used the gas in the other cases. Syrian rebels have been fighting the government for more than two years. Atrocities have been blamed on both sides. The announcement coincided with the release of a draft report posted on the website of the U.N. Human Rights Council that concludes: ""There are reasonable grounds to believe that chemical agents have been used as weapons. The precise agents, delivery systems or perpetrators could not be identified."" In Washington, White House press secretary Jay Carney said the United States was working with the French and other allies as well as the Syrian opposition to determine those answers. 
""We need to expand the evidence we have,"" he told reporters Tuesday. ""We need to make it reviewable; we need to have it corroborated before we make any decisions based on the clear violation that use of chemical weapons would represent by the Syrian regime. So, we will continue in that effort."" Asked how long that might take, he said, ""I don't have a timetable for you."" He noted that Damascus has consistently turned down U.S. requests for a U.N. investigative team to be sent to Syria. ""But we are not relying on the United Nations alone,"" he said. ""We are aggressively pursuing","NEW: Fabius says the Assad regime is culpable in at least one instance.\nFrance is certain sarin gas was used in Syria ""several times,"" Fabius says.\nThe announcement comes after a meeting with the head of a fact-finding mission.\nHuman Rights Council report: ""Reasonable grounds"" to believe chemical agents were used."


## Cleanup

In [None]:
#hide
from nbdev.export import notebook2script
notebook2script()

Converted 00_utils.ipynb.
Converted 01_data-core.ipynb.
Converted 01a_data-language-modeling.ipynb.
Converted 01c_data-question-answering.ipynb.
Converted 01d_data-token-classification.ipynb.
Converted 01e_data-text-generation.ipynb.
Converted 02_modeling-core.ipynb.
Converted 02a_modeling-language-modeling.ipynb.
Converted 02c_modeling-question-answering.ipynb.
Converted 02d_modeling-token-classification.ipynb.
Converted 02e_modeling-text-generation.ipynb.
Converted index.ipynb.
