In [None]:
# default_exp data.text_generation

In [None]:
#hide
# Re-import edited project modules automatically before each cell executes
%reload_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the notebook output
%matplotlib inline

# data.text_generation

> This module contains the bits required to use the fastai DataBlock API and/or mid-level data processing pipelines to organize your data for text generation tasks using architectures like BART, T5, or good ol' GPT2, etc. Abstractive summarization and conversational agents are good examples of such tasks.

In [None]:
#export
import ast
from functools import reduce

import torch
from transformers import *
from fastai2.text.all import *

from blurr.utils import *
from blurr.data.core import *

In [None]:
#hide
import pdb

from nbdev.showdoc import *
from fastcore.test import *

In [None]:
#cuda
# Prefer GPU #1 (the device this notebook was developed on) but do not crash on
# hosts that expose a single GPU or no GPU at all.
if torch.cuda.is_available():
    torch.cuda.set_device(1 if torch.cuda.device_count() > 1 else 0)
    print(f'Using GPU #{torch.cuda.current_device()}: {torch.cuda.get_device_name()}')
else:
    print('CUDA is not available; running on CPU')

Using GPU #1: GeForce GTX 1080 Ti


## Text Generation tokenization, batch transform, and DataBlock methods

Text generation tasks attempt to generate a human-understandable and sensible response to a prior text.  For example, in summarization, our objective is to capture the meaning of a larger document in 1-3 sentences.

In [None]:
# Read the sample CSV that sits next to this notebook and report its row count
path = Path('./')
cnndm_df = pd.read_csv(path/'cnndm_sample.csv')
len(cnndm_df)

1000

In [None]:
# Preview the first two rows (columns used below: `article` as input, `highlights` as target)
cnndm_df.head(2)

Unnamed: 0,article,highlights,ds_type
0,"(CNN) -- Globalization washes like a flood over the world's cultures and economies. Floods can be destructive; however, they can also bring blessings, as the annual floods of the Nile did for ancient Egypt. The world's great universities can be crucial instruments in shaping, in a positive way, humankind's reaction to globalization and the development of humankind itself. Traditionally, universities have been defined and limited by location, creating an academic community and drawing students and scholars to that place. Eventually, some universities began to encourage students to study el...","John Sexton: Traditionally, universities have been defined and limited by location .\nGlobal campuses form a network of thought, innovation, he writes .\nFaculty can teach, Sexton says, students can team up in many cities at once .\nSexton: Research, scholarship can be shared and cultural ties made in ""century of knowledge""",train
1,"(CNN) -- Armenian President Robert Kocharian declared a state of emergency Saturday night after a day of clashes between police and protesters, a spokeswoman for the Armenian Foreign Ministry said. Opposition supporters wave an Armenian flag during a protest rally in Yerevan, Armenia, on Saturday. The protesters claim last month's presidential election was rigged. The state of emergency will ""hopefully bring some order"" to the capital, Yerevan, said Salpi Ghazarian, assistant to the Armenian foreign minister, who spoke to CNN early Sunday. The state of emergency could last until March 20, ...","NEW: Protest moves after crackdown at Freedom Square .\nOrder sought after protests over last month's election turn violent .\nDemonstrators say the election was fraudulent .\nState of emergency could last until March 20, official says .",train


In [None]:
# Checkpoint used for the summarization examples in this notebook
pretrained_model_name = "bart-large-cnn"

# Fetch the architecture name, tokenizer, config and model in a single call
hf_arch, hf_tokenizer, hf_config, hf_model = BLURR_MODEL_HELPER.get_hf_objects(
    pretrained_model_name,
    BartTokenizer,
    HF_MODELS.BartForConditionalGeneration,
)

# Display what was loaded
hf_arch, type(hf_tokenizer), type(hf_config), type(hf_model)

('bart',
 transformers.tokenization_bart.BartTokenizer,
 transformers.configuration_bart.BartConfig,
 transformers.modeling_bart.BartForConditionalGeneration)

We're going to customize our `build_hf_input` for generation tasks to also include `decoder_input_ids`, which are useful for summarization and conversational training tasks (see [here](https://huggingface.co/transformers/model_doc/bart.html#transformers.BartModel.forward) for more information).

In [None]:
#export
class HF_TextGenerationInput(list):
    "Plain `list` subclass used as a type tag so `typedispatch`ed functions (e.g. `show_batch`) can target text generation inputs"

In [None]:
#export
@typedispatch
def build_hf_input(task:ForConditionalGenerationTask, tokenizer, a_tok_ids, b_tok_ids=None, targets=None,
                   max_length=512, pad_to_max_length=True, truncation_strategy='longest_first', **kwargs):
    """Build the inputs and labels for a conditional generation example.

    The encoder-side inputs are produced by the `None`-task `build_hf_input`
    implementation. When targets are supplied, `decoder_input_ids` are appended
    to those inputs and the shifted target ids are returned as labels.

    Args:
        task: dispatch tag selecting this implementation.
        tokenizer: Hugging Face tokenizer; its `prepare_for_model` method and
            pad/bos/eos token ids are used.
        a_tok_ids, b_tok_ids: token ids forwarded unchanged to the base implementation.
        targets: optional sequence whose first element holds the target token ids
            (must support `.tolist()` -- presumably a tensor; confirm against caller).
            `None` or an empty sequence means "no targets".
        max_length, pad_to_max_length, truncation_strategy: forwarded to the base
            implementation; `max_length` is also the fallback target length.
        **kwargs: `trg_max_length` (optional) sets the target sequence length.

    Returns:
        `(HF_TextGenerationInput(inputs), tuplify(target_ids))`, where `target_ids`
        is `None` when no targets were supplied.
    """
    enc_res = build_hf_input(None, tokenizer, a_tok_ids, b_tok_ids, targets,
                             max_length, pad_to_max_length, truncation_strategy)

    trg_max_length = kwargs.get('trg_max_length', max_length)

    # `targets` defaults to None, so it must be checked before calling len() on it
    has_targets = targets is not None and len(targets) > 0

    if has_targets:
        trg = tokenizer.prepare_for_model(targets[0].tolist(), None,
                                          max_length=trg_max_length,
                                          add_special_tokens=False,
                                          return_tensors='pt')

        pad_id, bos_id, eos_id = tokenizer.pad_token_id, tokenizer.bos_token_id, tokenizer.eos_token_id

        # <bos> target tokens <eos>, right-padded with pad_id, then cut to trg_max_length+1 so that
        # both shifted views below are exactly trg_max_length long
        # NOTE(review): if `prepare_for_model` returns a full trg_max_length tokens, the slice
        # drops the trailing <eos> -- confirm whether that is intended
        trg_ids = torch.cat((tensor(bos_id)[None], trg['input_ids'][0], tensor(eos_id)[None]))
        trg_ids = F.pad(trg_ids, pad=(0, trg_max_length), value=pad_id)[:trg_max_length+1]

        # the "decoder_input_ids" are trg_ids without the last position (so they begin with <bos>);
        # the labels are the same sequence shifted left by one
        enc_res[0].append(trg_ids[:-1])
        target_ids = trg_ids[1:]
    else:
        target_ids = None

    return HF_TextGenerationInput(enc_res[0]), tuplify(target_ids)

In [None]:
# Input block: plain text. Target block: generation task, targets limited to 150 tokens.
blocks = (
    HF_TextBlock(hf_arch, hf_tokenizer),
    HF_TextBlock(hf_arch, hf_tokenizer, task=ForConditionalGenerationTask(), trg_max_length=150),
)

# Articles are the inputs and highlights the targets; the split is random (unseeded)
dblock = DataBlock(
    blocks=blocks,
    get_x=ColReader('article'),
    get_y=ColReader('highlights'),
    splitter=RandomSplitter(),
)

In [None]:
# Uncomment to trace the DataBlock pipeline step by step when debugging:
# dblock.summary(cnndm_df)

In [None]:
# Build the DataLoaders from the sample frame with a batch size of 4
dls = dblock.dataloaders(cnndm_df, bs=4)

In [None]:
# Grab a single batch to inspect its structure
b = dls.one_batch()

In [None]:
# Number of batch elements, shape of the first input tensor, and shape of the targets
# (bs=4; inputs use the default max_length=512, targets use trg_max_length=150)
len(b), b[0][0].shape, b[1].shape

(2, torch.Size([4, 512]), torch.Size([4, 150]))

In [None]:
#export
@typedispatch
def show_batch(x:HF_TextGenerationInput, y, samples, hf_tokenizer, skip_special_tokens=True, 
               ctxs=None, max_n=6, **kwargs):  
    """Display up to `max_n` decoded (input text, target text) pairs as a DataFrame.

    Args:
        x: batch inputs; `x[0]` is iterated row by row and each row is decoded.
        y: batch targets, iterated in step with `x[0]`.
        samples: unused here; kept for the `show_batch` dispatch signature.
        hf_tokenizer: tokenizer providing `decode` and `pad_token`.
        skip_special_tokens: forwarded to `hf_tokenizer.decode`.
        ctxs: returned unchanged.
        max_n: maximum number of rows to display.

    Returns:
        `ctxs`, unchanged.
    """
    from itertools import islice

    def _to_text(tok_ids):
        # decode, then strip any pad tokens that survive decoding
        decoded = hf_tokenizer.decode(tok_ids, skip_special_tokens=skip_special_tokens)
        return decoded.replace(hf_tokenizer.pad_token, '')

    pairs = zip(x[0], y)
    # Only decode rows that will be displayed; None/negative max_n keep the original
    # "decode everything, then slice" behaviour.
    if max_n is not None and max_n >= 0: pairs = islice(pairs, max_n)

    res = L()
    for inp, trg in pairs:
        res.append((_to_text(inp), _to_text(trg)))

    display_df(pd.DataFrame(res, columns=['text', 'target'])[:max_n])
    return ctxs

In [None]:
# Show two decoded (text, target) examples; the tokenizer is passed through for decoding
dls.show_batch(hf_tokenizer=hf_tokenizer, max_n=2)

Unnamed: 0,text,target
0,"(CNN) -- Home to up to 10 percent of all known species, Mexico is recognized as one of the most biodiverse regions on the planet. The twin threats of climate change and human encroachment on natural environments are, however, threatening the existence of the country's rich wildlife. And there is a great deal to lose. In the United Nations Environment Program (UNEP) World Conservation Monitoring Centre's list of megadiverse countries Mexico ranks 11th. The list represents a group of 17 countries that harbor the majority of the Earth's species and are therefore considered extremely biodiverse. From its coral reefs in the Caribbean Sea to its tropical jungles in Chiapas and the Yucatan peninsula and its deserts and prairies in the north, Mexico boasts an incredibly rich variety of flora and fauna. Some 574 out of 717 reptile species found in Mexico -- the most in any country -- can only be encountered within its borders. It is home to 502 types of mammals, 290 species of birds, 1,150 varieties of birds and 26,000 classifications of plants. Pronatura, a non-profit organization that works to promote conservation and sustainable development in Mexico, has selected six species which it says symbolize the problems faced by the destruction of nature. ""These are only some of the species which have some degree of conservation,"" says Eduardo Cota Corona, Director of Conservation at Pronatura. ""However, there is a countless number of species in Mexico which find themselves in danger of extinction."" Golden Eagle. It is the country's national symbol yet the Golden Eagle is close to extinction in Mexico. One of the largest raptors or birds of prey in the world, the Golden Eagle's wingspan can reach lengths greater than two metres. Only the Bald Eagle and the California Greater exceed it in size in North America. 
With its powerful hooked bill and long and sharp claws it can sometimes capture prey of a size that is surprising for its size, including crane, wild ungulates and domestic livestock, though more often than not it tends to feed off small mammals such as rabbits, hares, ground squirrels and prairie dogs as well as reptiles and small-to-medium sized birds. Primarily a solitary bird, the Golden Eagle pairs up to breed, building nests made of dry branches in cliffs and escarpments. The female typically lays two eggs which are incubated by both the male and female. Usually, only one of the hatchlings survives. The Golden Eagle can be found in","Mexico hosts to up to 10 percent of all known species on Earth.\nIt is home to 502 types of mammals, 290 bird species and 26,000 types of plants.\nHuman development and climate change is placing a big strain on its biodiversity.\nThe Golden Eagle is under threat in spite of being the country's national symbol."
1,"(CNN Student News) -- March 23, 2010. Download PDF maps related to today's show:. • Haiti • China. Transcript. THIS IS A RUSH TRANSCRIPT. THIS COPY MAY NOT BE IN ITS FINAL FORM AND MAY BE UPDATED. CARL AZUZ, CNN STUDENT NEWS ANCHOR: Happy birthday, Roger Bannister -- first man to run the mile in less than four minutes. In more than twice that time, you'll be up to speed on today's headlines. I'm Carl Azuz. First Up: Health Care. AZUZ: First up, it's the biggest expansion of the United States health care system in more than forty years. And by a vote of 219-212, the U.S. House of Representatives passed a health care reform bill late Sunday night. This is the same bill that the Senate passed last December. This means that when President Obama signs it, it's law. The House also passed a set of changes to the Senate bill. We're gonna get back to that in just a second. But first, you know this health care issue has been controversial. We want you to check out some of the reaction to last night's vote. REP. NANCY PELOSI, (D) HOUSE SPEAKER: Thirty-two million more Americans having access to health care. $1.3 trillion saved for the taxpayer and accountability for the insurance companies so they cannot come between patients and their doctors. REP. JOHN BOEHNER, (R) HOUSE MINORITY LEADER: Shame on us. Shame on this body. Shame on each and every one of you who substitutes your will and your desires above those of your fellow countrymen. BARACK OBAMA, PRESIDENT OF THE UNITED STATES: This legislation will not fix everything that ails our health care system, but it moves us decisively in the right direction. This is what change looks like. AZUZ: Okay, so the House passed the bill. Now what? Well, President Obama is expected to sign it today. Again, as soon as he does that, the Senate bill becomes law. But how about those changes passed by the House? Those still need to get through the Senate. That process can't start until the president signs the original bill. 
Some analysts are saying that the package of changes could inspire a pretty heated battle in the Senate. We're going to have more on that as the debate gets going. Another thing that's",Find out what comes next after the passage of a health care reform bill.\nLearn about a proposal that would change how student loans are funded.\nFollow the steps that led to a showdown between China and Google.\nUse the Daily Discussion to help students understand today's featured news stories.


## Cleanup

In [None]:
#hide
# Export the `#export`-flagged cells of the project's notebooks to their Python modules
from nbdev.export import notebook2script
notebook2script()

Converted 00_utils.ipynb.
Converted 01_data-core.ipynb.
Converted 01a_data-language-modeling.ipynb.
Converted 01c_data-question-answering.ipynb.
Converted 01d_data-token-classification.ipynb.
Converted 01e_data-text-generation.ipynb.
Converted 02_modeling-core.ipynb.
Converted 02a_modeling-language-modeling.ipynb.
Converted 02c_modeling-question-answering.ipynb.
Converted 02d_modeling-token-classification.ipynb.
Converted 02e_modeling-text-generation.ipynb.
Converted index.ipynb.
