In [None]:
# default_exp data.seq2seq.core

In [None]:
#hide
# automatically reload imported modules so edits to the library are picked up without restarting the kernel
%reload_ext autoreload
%autoreload 2
%matplotlib inline

import os
# turn off the huggingface `tokenizers` library's internal parallelism
# (presumably to avoid its fork warnings / deadlocks when DataLoader workers fork the process - confirm)
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# data.seq2seq.core

> This module contains the core seq2seq (e.g., language modeling, summarization, translation) bits required to use the fastai DataBlock API and/or mid-level data processing pipelines to organize your data in a form that huggingface transformer models can consume.

In [None]:
#export
from functools import reduce

import torch
from transformers import *
from fastai.text.all import *

from blurr.utils import *
from blurr.data.core import *

# only report errors from huggingface (`logging` here is presumably `transformers.logging`, brought in by the star imports)
logging.set_verbosity_error()

In [None]:
#hide
import pdb

from nbdev.showdoc import *
from fastcore.test import *

from fastai import __version__ as fa_version
from torch import __version__ as pt_version
from transformers import __version__ as hft_version

for lib_name, lib_version in (('pytorch', pt_version), ('fastai', fa_version), ('transformers', hft_version)):
    print(f'Using {lib_name} {lib_version}')

Using pytorch 1.7.1+cu110
Using fastai 2.2.5
Using transformers 4.2.1


In [None]:
#cuda
# this notebook was developed on a multi-GPU box using GPU #1; fall back to GPU #0 when there is no second
# device so the cell doesn't fail with an invalid device ordinal on single-GPU machines
torch.cuda.set_device(1 if torch.cuda.device_count() > 1 else 0)
print(f'Using GPU #{torch.cuda.current_device()}: {torch.cuda.get_device_name()}')

Using GPU #1: GeForce GTX 1080 Ti


In [None]:
pretrained_model_name = "facebook/bart-large-cnn"
hf_objects = BLURR_MODEL_HELPER.get_hf_objects(pretrained_model_name, model_cls=BartForConditionalGeneration)
hf_arch, hf_config, hf_tokenizer, hf_model = hf_objects

# show the architecture name along with the classes that were instantiated
hf_arch, type(hf_config), type(hf_tokenizer), type(hf_model)

('bart',
 transformers.models.bart.configuration_bart.BartConfig,
 transformers.models.bart.tokenization_bart_fast.BartTokenizerFast,
 transformers.models.bart.modeling_bart.BartForConditionalGeneration)

## Base tokenization, batch transform, and DataBlock methods

Seq2Seq tasks are essentially conditional generation tasks; this applies to specific derived tasks such as summarization and translation. Given this, we can use the *same* HF_Seq2Seq transforms, `HF_Seq2SeqInput`, and `HF_Seq2SeqBlock` for all of these tasks.

In [None]:
#export
class HF_Seq2SeqInput(HF_BaseInput):
    "The input type for seq2seq tasks (`show_batch` dispatches on it to display decoded text/target pairs)"

We create a subclass of `HF_BeforeBatchTransform` for seq2seq tasks (e.g., summarization and translation) to add `labels` to our inputs during training, which will in turn allow the huggingface model to calculate the loss for us (encoder-decoder models such as BART build their `decoder_input_ids` from these `labels` when none are supplied). See [here](https://huggingface.co/transformers/glossary.html#labels) and [here](https://huggingface.co/transformers/glossary.html#decoder-input-ids) for more information on these additional inputs used in summarization, translation, and conversational training tasks. How they should look for particular architectures can be found by looking at those models' `forward` function docs (see [here](https://huggingface.co/transformers/model_doc/bart.html#transformers.BartModel.forward) for BART, for example).

Note also that the `decoder_input_ids` are simply the `labels` (our target ids) shifted to the right by one, since the task is to predict the next token based on the current (and all previous) `decoder_input_ids`.

And lastly, we also update our targets to just be the `input_ids` of our target sequence so that fastai's `Learner.show_results` works (again, almost all the fastai bits require returning a single tensor to work).

In [None]:
#export
def default_text_gen_kwargs(hf_config, hf_model, task=None):
    """Build default keyword arguments for `hf_model.generate` from `hf_config`.

    Every parameter of `hf_model.generate` that also appears as a key in `hf_config.to_dict()` is included
    with the config's value. If `task` is given and `hf_config.task_specific_params[task]` exists, its
    entries are merged on top (overriding the generic values).

    Args:
        hf_config: huggingface config object; must provide `to_dict()`.
        hf_model: huggingface model exposing a `generate` method.
        task: optional key into `hf_config.task_specific_params` (e.g., "summarization").

    Returns:
        dict of keyword arguments that can be passed to `hf_model.generate`.
    """
    hf_config_dict = hf_config.to_dict()
    generate_params = inspect.signature(hf_model.generate).parameters
    text_gen_kwargs = {k: hf_config_dict[k] for k in generate_params if k in hf_config_dict}

    if (task is not None):
        # not all configs have `task_specific_params`, it may be None, and it may not contain `task`; in any of
        # those cases keep the generic defaults. Only these lookup failures are swallowed, not every error.
        try:
            text_gen_kwargs = {**text_gen_kwargs, **hf_config.task_specific_params[task]}
        except (AttributeError, KeyError, TypeError):
            pass

    return text_gen_kwargs

In [None]:
# no `task` given, so only the generic (non task-specific) generation defaults from the config are returned
default_text_gen_kwargs(hf_config, hf_model, task=None)

{'max_length': 142,
 'min_length': 56,
 'do_sample': False,
 'early_stopping': True,
 'num_beams': 4,
 'temperature': 1.0,
 'top_k': 50,
 'top_p': 1.0,
 'repetition_penalty': 1.0,
 'bad_words_ids': None,
 'bos_token_id': 0,
 'pad_token_id': 1,
 'eos_token_id': 2,
 'length_penalty': 2.0,
 'no_repeat_ngram_size': 3,
 'num_return_sequences': 1,
 'decoder_start_token_id': 2,
 'use_cache': True,
 'num_beam_groups': 1,
 'diversity_penalty': 0.0,
 'output_attentions': False,
 'output_hidden_states': False,
 'output_scores': False,
 'return_dict_in_generate': False}

In [None]:
#hide
t = torch.randn(3, 3)
# left-pad every row with a 1 and drop the last column, i.e., shift each row right by one position
shifted = F.pad(t, pad=(1, 0), value=1)[:, :-1]
shifted

tensor([[ 1.0000,  0.3607, -1.2068],
        [ 1.0000, -0.9912, -0.3943],
        [ 1.0000,  1.2608, -0.7091]])

In [None]:
#export
class HF_Seq2SeqBeforeBatchTransform(HF_BeforeBatchTransform):
    """Before-batch transform that tokenizes (source, target) text pairs for seq2seq models.

    Source texts become the model inputs (e.g., `input_ids`, `attention_mask`). When targets are present, the
    tokenized target `input_ids` are added to those inputs as `labels` and are also returned as the fastai target.

    Args:
        hf_arch, hf_config, hf_tokenizer, hf_model: the huggingface objects for the model.
        ignore_token_id: token id treated as "ignored" (defaults to fastai's `CrossEntropyLossFlat().ignore_index`).
        max_length: maximum tokenized length of the source texts.
        max_target_length: maximum tokenized length of the target texts.
        padding, truncation: passed through to the tokenizer.
        tok_kwargs: extra keyword arguments passed to the tokenizer.
        text_gen_kwargs: generation keyword arguments stored on the transform (a private copy is kept).
    """
    
    def __init__(self, hf_arch, hf_config, hf_tokenizer, hf_model, 
                 ignore_token_id=CrossEntropyLossFlat().ignore_index,
                 max_length=None, max_target_length=None, padding=True, truncation=True,
                 tok_kwargs={}, text_gen_kwargs={}, **kwargs):
                 
        super().__init__(hf_arch, hf_config, hf_tokenizer, hf_model,
                         max_length=max_length, padding=padding, truncation=truncation, is_split_into_words=False, 
                         tok_kwargs=tok_kwargs.copy(), **kwargs)
        
        # keep a private copy so instances never share (and mutate) the `{}` default or the caller's dict;
        # `store_attr` reads the current local value, so it stores this copy
        text_gen_kwargs = dict(text_gen_kwargs or {})
        store_attr(self=self, names='text_gen_kwargs, max_target_length, ignore_token_id')
    
    def encodes(self, samples): 
        """Tokenize a list of samples into model inputs (plus `labels`/targets when target texts are present).

        Each sample's first item is the source text and its optional second item is the target text; any
        further items are passed through unchanged.

        Returns:
            list of tuples `(inputs_dict, *target_ids, *sample[2:])`, one per sample; `target_ids` is omitted
            when there are no target texts.
        """
        samples = L(samples)
        
        # tokenize
        src_texts = samples.itemgot(0).items
        tgt_texts = samples.itemgot(1).items if (len(samples[0]) > 1) else None
        
        try:
            tok_d = self.hf_tokenizer.prepare_seq2seq_batch(src_texts=src_texts, tgt_texts=tgt_texts, 
                                                            max_length=self.max_length, 
                                                            max_target_length=self.max_target_length,
                                                            padding=self.padding, 
                                                            truncation=self.truncation, 
                                                            return_tensors='pt', 
                                                            **self.tok_kwargs)
        except NotImplementedError:
            # not all seq2seq models implement "prepare_seq2seq_batch" (e.g., blenderbot), so tokenize the sources
            # and targets separately and use the targets' input_ids as the labels
            tok_d = self.hf_tokenizer(src_texts, max_length=self.max_length, padding=self.padding, 
                                      truncation=self.truncation, return_tensors='pt', **self.tok_kwargs)
            
            if (tgt_texts):
                tok_d_targs = self.hf_tokenizer(tgt_texts, max_length=self.max_target_length, padding=self.padding, 
                                                truncation=self.truncation, return_tensors='pt', **self.tok_kwargs)
            
                tok_d['labels'] = tok_d_targs['input_ids']
        
        # add in target ids for us to use if fastai is calculating the loss (empty when there are no targets)
        targ_ids = [[]] * len(samples)
        if ('labels' in tok_d):
            # NOTE(review): this maps any `ignore_token_id` in the labels to the pad token; it does NOT map padding
            # to `ignore_token_id`, so padded label positions are presumably counted by the loss - confirm intended
            tok_d['labels'].masked_fill_(tok_d['labels'] == self.ignore_token_id, self.hf_tokenizer.pad_token_id)
            targ_ids = tok_d['labels'].clone()

        # update samples with tokenized inputs (e.g. input_ids, attention_mask, etc...)
        d_keys = tok_d.keys()
        updated_samples = [ (*[{k: tok_d[k][idx] for k in d_keys}], *tuplify(targ_ids[idx]), *sample[2:]) 
                            for idx, sample in enumerate(samples) ]
        
        return updated_samples

We include a new AFTER batch `Transform` and `TransformBlock` specific to text-to-text tasks.

In [None]:
#export
class HF_Seq2SeqAfterBatchTransform(HF_AfterBatchTransform):
    "After-batch transform for text-to-text tasks; `decodes` wraps the source `input_ids` in `self.input_return_type`"
    def decodes(self, encoded_samples):
        # the batch may arrive either as a dict of tokenized inputs or as the `input_ids` tensor itself
        if isinstance(encoded_samples, dict):
            input_ids = encoded_samples['input_ids']
        else:
            input_ids = encoded_samples
        return self.input_return_type(input_ids, hf_tokenizer=self.hf_tokenizer)
    
    
class HF_Seq2SeqBlock(HF_TextBlock):
    """A `TransformBlock` for seq2seq (text-to-text) tasks.

    Builds a `HF_Seq2SeqBeforeBatchTransform` and a `HF_Seq2SeqAfterBatchTransform` (unless they are passed in)
    and hands them to `HF_TextBlock`.

    When `text_gen_kwargs` is empty (or None), `default_text_gen_kwargs(hf_config, hf_model)` is used instead.
    The result is stored on the block as `self.text_gen_kwargs`, and a copy is given to any before-batch
    transform this block creates.
    """
    
    def __init__(self, hf_arch=None, hf_config=None, hf_tokenizer=None, hf_model=None,
                 before_batch_tfm=None, after_batch_tfm=None,
                 max_length=None, max_target_length=None, padding=True, truncation=True, 
                 input_return_type=HF_Seq2SeqInput, dl_type=SortedDL, 
                 tok_kwargs={}, text_gen_kwargs={}, before_batch_kwargs={}, after_batch_kwargs={}, **kwargs):
        
        # we need to pass text_gen_kwargs into our HF_Seq2SeqBeforeBatchTransform (use default unless specified)
        if (not text_gen_kwargs): 
            # take the config/model from a user-supplied before_batch_tfm when they weren't passed directly
            if (hf_config is None): hf_config = before_batch_tfm.hf_config
            if (hf_model is None): hf_model = before_batch_tfm.hf_model
            self.text_gen_kwargs = default_text_gen_kwargs(hf_config, hf_model)
        else:
            self.text_gen_kwargs = text_gen_kwargs.copy()
            
        # construct our before_batch and after_batch tfms as usual
        if (before_batch_tfm is None): 
            # pass the resolved (default or user-specified) kwargs, not the raw, possibly empty, argument
            before_batch_tfm = HF_Seq2SeqBeforeBatchTransform(hf_arch, hf_config, hf_tokenizer, hf_model,
                                                              max_length=max_length, 
                                                              max_target_length=max_target_length,
                                                              padding=padding, 
                                                              truncation=truncation,
                                                              tok_kwargs=tok_kwargs.copy(), 
                                                              text_gen_kwargs=self.text_gen_kwargs.copy(), 
                                                              **before_batch_kwargs.copy())

        if (after_batch_tfm is None): 
            hf_tokenizer = hf_tokenizer if (hf_tokenizer is not None) else before_batch_tfm.hf_tokenizer
            after_batch_tfm = HF_Seq2SeqAfterBatchTransform(hf_tokenizer, input_return_type,
                                                            **after_batch_kwargs.copy())
                
        super().__init__(before_batch_tfm=before_batch_tfm, after_batch_tfm=after_batch_tfm,
                         max_length=max_length, padding=padding, truncation=truncation, 
                         is_split_into_words=False, 
                         input_return_type=input_return_type, dl_type=dl_type, 
                         tok_kwargs=tok_kwargs, 
                         before_batch_kwargs=before_batch_kwargs, 
                         after_batch_kwargs=after_batch_kwargs, 
                         **kwargs)

... and a `DataLoaders.show_batch` for seq2seq tasks.

In [None]:
#export
@typedispatch
def show_batch(x:HF_Seq2SeqInput, y, samples, dataloaders, ctxs=None, max_n=6, 
               input_trunc_at=None, target_trunc_at=None, **kwargs):  
    "Display up to `max_n` decoded (text, target) pairs of a seq2seq batch as a DataFrame; returns `ctxs` unchanged"
    # the blurr before-batch transform carries both the tokenizer and the id used to mark ignored target tokens
    blurr_tfm = get_blurr_tfm(dataloaders.before_batch)
    tokenizer, ignore_id = blurr_tfm.hf_tokenizer, blurr_tfm.ignore_token_id

    def _decode(token_ids, trunc_at):
        return tokenizer.decode(token_ids, skip_special_tokens=True)[:trunc_at]

    rows = []
    for sample in samples:
        src_ids = sample[0]
        src_text = _decode(src_ids, input_trunc_at)
        targ_ids = sample[1]
        # drop any ignored positions from the target before decoding it
        targ_text = _decode(targ_ids[targ_ids != ignore_id], target_trunc_at)
        rows.append((src_text, targ_text))

    display_df(pd.DataFrame(L(rows), columns=['text', 'target'])[:max_n])
    return ctxs

## Cleanup

In [None]:
#hide
# export the `#export` cells of the project's notebooks into the python library modules
from nbdev.export import notebook2script
notebook2script()

Converted 00_utils.ipynb.
Converted 01_data-core.ipynb.
Converted 01a_data-token-classification.ipynb.
Converted 01b_data-question-answering.ipynb.
Converted 01za_data-seq2seq-core.ipynb.
Converted 01zb_data-seq2seq-language-modeling.ipynb.
Converted 01zc_data-seq2seq-summarization.ipynb.
Converted 01zd_data-seq2seq-translation.ipynb.
Converted 02_modeling-core.ipynb.
Converted 02a_modeling-token-classification.ipynb.
Converted 02b_modeling-question-answering.ipynb.
Converted 02za_modeling-seq2seq-core.ipynb.
Converted 02zb_modeling-seq2seq-language-modeling.ipynb.
Converted 02zc_modeling-seq2seq-summarization.ipynb.
Converted 02zc_modeling-seq2seq-translation.ipynb.
Converted 99a_examples-multilabel.ipynb.
Converted index.ipynb.
