In [1]:
# default_exp callback

# Callbacks
> Callbacks for predicting within AdaptNLP using the fastai framework

In [2]:
#hide
from nbverbose.showdoc import *

In [3]:
#export
from fastcore.basics import store_attr
from fastcore.meta import delegates

from fastai.callback.core import Callback, CancelBatchException

from transformers import PreTrainedModel

In [4]:
#export
class GatherInputsCallback(Callback):
    """
    Builds the basic input dictionary for HuggingFace Transformers models

    Before every batch, the leading elements of `self.learn.xb` are mapped to the keys
    `input_ids`, `attention_mask` and (when a third element is present) `token_type_ids`.
    The resulting dictionary is stored on `self.learn.inputs`.

    If further data is expected or needed from the batch, the additional Callback(s)
    should have an order of -2
    """
    order = -3

    def before_validate(self):
        """
        Sets `n_inp` on `self.learn.dls` to the length of one batch drawn from `self.dl`
        """
        self.learn.dls.n_inp = len(self.dl.one_batch())

    def before_batch(self):
        """
        Stores a dictionary built from `self.learn.xb` in `self.learn.inputs`, with keys
            `{"input_ids", "attention_mask", "token_type_ids"}`
        when the batch has more than two elements, or
            `{"input_ids", "attention_mask"}`
        otherwise
        """
        batch = self.learn.xb
        gathered = dict(input_ids=batch[0], attention_mask=batch[1])

        # Only the first three positions of the batch are ever read
        if len(batch) > 2:
            gathered["token_type_ids"] = batch[2]

        self.learn.inputs = gathered

In [5]:
show_doc(GatherInputsCallback.before_validate)

<h4 id="GatherInputsCallback.before_validate" class="doc_header"><code>GatherInputsCallback.before_validate</code><a href="__main__.py#L14" class="source_link" style="float:right">[source]</a></h4>

> <code>GatherInputsCallback.before_validate</code>()

Sets the number of inputs in `self.dls`



In [6]:
show_doc(GatherInputsCallback.before_batch)

<h4 id="GatherInputsCallback.before_batch" class="doc_header"><code>GatherInputsCallback.before_batch</code><a href="__main__.py#L21" class="source_link" style="float:right">[source]</a></h4>

> <code>GatherInputsCallback.before_batch</code>()

Turns `self.xb` from a tuple to a dictionary of either
    `{"input_ids", "attention_mask", "token_type_ids"}`
or
    `{"input_ids", "attention_mask"}`



In [7]:
#export
class SetInputsCallback(Callback):
    """
    Callback ordered after `GatherInputsCallback` that writes `self.learn.inputs`
    back onto `self.learn.xb`
    """
    order = -1

    def __init__(
        self,
        as_dict=False # Whether to leave `self.xb` as a dictionary of values
    ):
        # Saves `as_dict` on the instance; `before_batch` reads it as `self.as_dict`
        store_attr()

    def before_batch(self):
        """
        Set `self.learn.xb` to `self.learn.inputs` itself when `as_dict` is truthy,
        otherwise to a list of `self.learn.inputs.values()`
        """
        self.learn.xb = (
            self.learn.inputs
            if self.as_dict
            else list(self.learn.inputs.values())
        )

In [8]:
show_doc(SetInputsCallback.before_batch)

<h4 id="SetInputsCallback.before_batch" class="doc_header"><code>SetInputsCallback.before_batch</code><a href="__main__.py#L13" class="source_link" style="float:right">[source]</a></h4>

> <code>SetInputsCallback.before_batch</code>()

Set `self.learn.xb` to `self.learn.inputs.values()`



In [9]:
#export
class GeneratorCallback(Callback):
    """
    Callback for models whose predictions come from `self.learn.model.generate`

    `before_batch` calls `generate` with the stored arguments, saves the result to
    `self.learn.pred` and then cancels the rest of the batch.
    """

    @delegates(PreTrainedModel.generate)
    def __init__(
        self,
        num_beams:int, # Number of beams for beam search
        min_length:int, # Minimum length of the generated sequence
        max_length:int, # Maximum length of the generated sequence
        early_stopping:bool, # Whether to do early stopping
        **kwargs
    ):
        # The named arguments are read back as attributes in `before_batch`;
        # everything else is kept as-is and forwarded to `generate`
        store_attr()
        self.kwargs = kwargs

    def before_batch(self):
        """
        Run `self.learn.model.generate` on the batch, store the output in
        `self.learn.pred`, then raise `CancelBatchException`
        """
        generate = self.learn.model.generate

        # `self.xb` is indexed by key here, so it must be a mapping at this point
        fixed_args = {
            "input_ids": self.xb['input_ids'],
            "attention_mask": self.xb['attention_mask'],
            "num_beams": self.num_beams,
            "min_length": self.min_length,
            "max_length": self.max_length,
            "early_stopping": self.early_stopping,
        }

        self.learn.pred = generate(**fixed_args, **self.kwargs)

        # The prediction is already set, so skip the original model inference
        raise CancelBatchException

In [10]:
show_doc(GeneratorCallback)

<h2 id="GeneratorCallback" class="doc_header"><code>class</code> <code>GeneratorCallback</code><a href="" class="source_link" style="float:right">[source]</a></h2>

> <code>GeneratorCallback</code>(**`num_beams`**:`int`, **`min_length`**:`int`, **`max_length`**:`int`, **`early_stopping`**:`bool`, **`input_ids`**:`Optional`\[`LongTensor`\]=*`None`*, **`do_sample`**:`Optional`\[`bool`\]=*`None`*, **`temperature`**:`Optional`\[`float`\]=*`None`*, **`top_k`**:`Optional`\[`int`\]=*`None`*, **`top_p`**:`Optional`\[`float`\]=*`None`*, **`repetition_penalty`**:`Optional`\[`float`\]=*`None`*, **`bad_words_ids`**:`Optional`\[`Iterable`\[`int`\]\]=*`None`*, **`bos_token_id`**:`Optional`\[`int`\]=*`None`*, **`pad_token_id`**:`Optional`\[`int`\]=*`None`*, **`eos_token_id`**:`Optional`\[`int`\]=*`None`*, **`length_penalty`**:`Optional`\[`float`\]=*`None`*, **`no_repeat_ngram_size`**:`Optional`\[`int`\]=*`None`*, **`encoder_no_repeat_ngram_size`**:`Optional`\[`int`\]=*`None`*, **`num_return_sequences`**:`Optional`\[`int`\]=*`None`*, **`max_time`**:`Optional`\[`float`\]=*`None`*, **`max_new_tokens`**:`Optional`\[`int`\]=*`None`*, **`decoder_start_token_id`**:`Optional`\[`int`\]=*`None`*, **`use_cache`**:`Optional`\[`bool`\]=*`None`*, **`num_beam_groups`**:`Optional`\[`int`\]=*`None`*, **`diversity_penalty`**:`Optional`\[`float`\]=*`None`*, **`prefix_allowed_tokens_fn`**:`Optional`\[`Callable`\[`int`, `Tensor`, `List`\[`int`\]\]\]=*`None`*, **`output_attentions`**:`Optional`\[`bool`\]=*`None`*, **`output_hidden_states`**:`Optional`\[`bool`\]=*`None`*, **`output_scores`**:`Optional`\[`bool`\]=*`None`*, **`return_dict_in_generate`**:`Optional`\[`bool`\]=*`None`*, **`forced_bos_token_id`**:`Optional`\[`int`\]=*`None`*, **`forced_eos_token_id`**:`Optional`\[`int`\]=*`None`*, **`remove_invalid_values`**:`Optional`\[`bool`\]=*`None`*, **`synced_gpus`**:`Optional`\[`bool`\]=*`None`*) :: `Callback`

Callback used for models that utilize `self.model.generate`

**Parameters:**


 - **`num_beams`** : *`<class 'int'>`*	<p>Number of beams for beam search</p>


 - **`min_length`** : *`<class 'int'>`*	<p>Minimal length of the sequence generated</p>


 - **`max_length`** : *`<class 'int'>`*	<p>Maximum length of the sequence generated</p>


 - **`early_stopping`** : *`<class 'bool'>`*	<p>Whether to do early stopping</p>


 - **`input_ids`** : *`typing.Union[torch.LongTensor, NoneType]`*, *optional*

 - **`do_sample`** : *`typing.Union[bool, NoneType]`*, *optional*

 - **`temperature`** : *`typing.Union[float, NoneType]`*, *optional*

 - **`top_k`** : *`typing.Union[int, NoneType]`*, *optional*

 - **`top_p`** : *`typing.Union[float, NoneType]`*, *optional*

 - **`repetition_penalty`** : *`typing.Union[float, NoneType]`*, *optional*

 - **`bad_words_ids`** : *`typing.Union[typing.Iterable[int], NoneType]`*, *optional*

 - **`bos_token_id`** : *`typing.Union[int, NoneType]`*, *optional*

 - **`pad_token_id`** : *`typing.Union[int, NoneType]`*, *optional*

 - **`eos_token_id`** : *`typing.Union[int, NoneType]`*, *optional*

 - **`length_penalty`** : *`typing.Union[float, NoneType]`*, *optional*

 - **`no_repeat_ngram_size`** : *`typing.Union[int, NoneType]`*, *optional*

 - **`encoder_no_repeat_ngram_size`** : *`typing.Union[int, NoneType]`*, *optional*

 - **`num_return_sequences`** : *`typing.Union[int, NoneType]`*, *optional*

 - **`max_time`** : *`typing.Union[float, NoneType]`*, *optional*

 - **`max_new_tokens`** : *`typing.Union[int, NoneType]`*, *optional*

 - **`decoder_start_token_id`** : *`typing.Union[int, NoneType]`*, *optional*

 - **`use_cache`** : *`typing.Union[bool, NoneType]`*, *optional*

 - **`num_beam_groups`** : *`typing.Union[int, NoneType]`*, *optional*

 - **`diversity_penalty`** : *`typing.Union[float, NoneType]`*, *optional*

 - **`prefix_allowed_tokens_fn`** : *`typing.Union[typing.Callable[[int, torch.Tensor], typing.List[int]], NoneType]`*, *optional*

 - **`output_attentions`** : *`typing.Union[bool, NoneType]`*, *optional*

 - **`output_hidden_states`** : *`typing.Union[bool, NoneType]`*, *optional*

 - **`output_scores`** : *`typing.Union[bool, NoneType]`*, *optional*

 - **`return_dict_in_generate`** : *`typing.Union[bool, NoneType]`*, *optional*

 - **`forced_bos_token_id`** : *`typing.Union[int, NoneType]`*, *optional*

 - **`forced_eos_token_id`** : *`typing.Union[int, NoneType]`*, *optional*

 - **`remove_invalid_values`** : *`typing.Union[bool, NoneType]`*, *optional*

 - **`synced_gpus`** : *`typing.Union[bool, NoneType]`*, *optional*
