In [None]:
# default_exp callbacks

# Callbacks

> The foundational Callback event system for `Performer`

In [1]:
#hide
from nbdev.showdoc import *

In [12]:
#export
from abc import ABC, abstractmethod

In [20]:
#export
class InferenceCallback(ABC):
    """
    Base class for customizing each stage of inference.

    Subclasses must provide all three events:
      - `after_drawn_batch`
      - `gather_predictions`
      - `decoding_values`

    To keep the default behavior for an event, delegate to the parent class:
      - `event_name(self, *args): return super().event_name(*args)`

    Here `event_name` stands for any one of the three events above.
    """
    @abstractmethod
    def after_drawn_batch(self, batch):
        """
        Called immediately after a batch is drawn from the `DataLoader`.
        Make any final adjustments to the batch here, before it reaches the model.

        Default implementation: hand back `batch` unchanged.
        """
        unchanged_batch = batch
        return unchanged_batch

    @abstractmethod
    def gather_predictions(self, model, batch):
        """
        Runs `model` on `batch` and returns its output.
        Inference decorators such as `no_grad` or `inference_mode` are applied by `Performer`, not here.

        Default implementation: `model(*batch)`, i.e. the batch is unpacked into positional arguments.
        """
        model_inputs = batch
        predictions = model(*model_inputs)
        return predictions

    @abstractmethod
    def decoding_values(self, values):
        """
        Called once predictions have been gathered for a `batch`.
        Do any class decoding and final datatype preparation here.

        Default implementation: hand back `values` unchanged.
        """
        decoded = values
        return decoded

In [21]:
show_doc(InferenceCallback)

<h2 id="InferenceCallback" class="doc_header"><code>class</code> <code>InferenceCallback</code><a href="" class="source_link" style="float:right">[source]</a></h2>

> <code>InferenceCallback</code>() :: `ABC`

The foundational class for customizing behaviors during inference.

There are three methods available that must be implemented:
  - `after_drawn_batch`
  - `gather_predictions`
  - `decoding_values`
  
If an implementation should stay to its default behavior, do the following:
  - `event_name(self, *args): return super().event_name(*args)`
  
Where `event_name` is one of the three events listed above

In [22]:
show_doc(InferenceCallback.gather_predictions)

<h4 id="InferenceCallback.gather_predictions" class="doc_header"><code>InferenceCallback.gather_predictions</code><a href="__main__.py#L26" class="source_link" style="float:right">[source]</a></h4>

> <code>InferenceCallback.gather_predictions</code>(**`model`**, **`batch`**)

Performs inference with `model` on `batch`.
Any specific inference decorators such as `no_grad` or `inference_mode` is done in `Performer`.

Default implementation is `model(*batch)`.

In [None]:
# Why abstract: it forces users to decide explicitly whether the default behavior is how they want their code to run in production.
# Since there is only one level of subclassing, it is easy to track where and how each event is handled.

In [8]:
class ImageClassifierCallback(InferenceCallback):
    """
    Minimal demo callback: echoes every drawn batch to stdout and leaves
    the other two events at their `InferenceCallback` defaults.
    """
    def after_drawn_batch(self, batch):
        # Echo the batch for inspection, then pass it through untouched
        print(batch)
        return batch

    def gather_predictions(self, *args):
        # Keep the base-class default (`model(*batch)`)
        return super().gather_predictions(*args)

    def decoding_values(self, *args):
        # Keep the base-class default (values returned unchanged)
        return super().decoding_values(*args)

In [None]:
class ImageClassifierCallback(InferenceCallback):
    """
    Example callback for single-label classification.

    Decodes the model's outputs into labels via `vocab` and also reports
    the per-class probabilities.
    """
    def __init__(self, vocab):
        # `vocab` is indexed by predicted class index to obtain a label
        # (e.g. a list of class names)
        self.vocab = vocab

    def after_drawn_batch(self, batch):
        # `after_drawn_batch` is abstract in `InferenceCallback`; without this
        # override the class cannot be instantiated. Keep the default behavior.
        return super().after_drawn_batch(batch)

    def gather_predictions(self, model, batch):
        # Default behavior: `model(*batch)`
        return super().gather_predictions(model, batch)

    def decoding_values(self, values):
        """
        Turns raw model outputs into class labels and probabilities.

        NOTE(review): assumes `values` is a torch tensor of unnormalized scores
        (logits) for a batch, with classes on the last dimension. TODO confirm
        against the model's actual output.

        Returns a dict with:
          - "classes": the decoded label of the top-scoring class for each item
          - "probabilities": the softmax distribution over classes for each item
        """
        probs = values.softmax(dim=-1)
        pred_idxs = probs.argmax(dim=-1)
        # `.tolist()` yields plain Python ints, so `vocab` may be a list, a dict
        # keyed by int, or any other int-indexable mapping
        decoded_preds = [self.vocab[idx] for idx in pred_idxs.tolist()]
        return {"classes": decoded_preds, "probabilities": probs}

In [None]:
# Performer should only do the following automatically:
#  Device placement of both model and batch
#  Pull a batch from the DataLoader
#  From item(s), create a new `DataLoader`
#  Whether we run in torch.no_grad or not

#  Be configurable to allow for training_augmentation or test_augmentation when using a fastai `DataLoader`
#  Maybe as a context manager? e.g. `with Performer.in_training_state(): ...` to run inference. The state lives on `Performer`, so the `DataLoader` doesn't need to track it.
#    Should raise a warning if this is used with a `DataLoader` that isn't from fastai

# Maybe a `from_pipes` method?