In [None]:
# default_exp configure

# Configure

> The foundational event system for `Performer` based on fastai `Callback`s

In [None]:
#hide
from nbdev.showdoc import *

In [None]:
#export
from abc import ABC, abstractmethod, abstractproperty
from typing import Any

import torch
from fastreinference.utils import SelfEnum

In [None]:
#export
@SelfEnum
class DeviceType:
    """
    Enum of all supported device placements

    Members:
      - `CPU`
      - `CUDA`
    """
    # Members are declared as bare annotations with no assigned value.
    # NOTE(review): the value each member ends up with is produced by the
    # `SelfEnum` decorator (from `fastreinference.utils`), not by this class body — confirm there.
    CPU:Any
    CUDA:Any

In [None]:
#export
def get_default_device():
    """
    Pick the device that inference should use by default.

    Returns `DeviceType.CUDA` when `torch.cuda.is_available()` is true, otherwise `DeviceType.CPU`
    """
    if torch.cuda.is_available():
        return DeviceType.CUDA
    return DeviceType.CPU

In [None]:
#export
@SelfEnum(special=["INFERENCE"])
class ManagerType:
    """
    Enum of the various context manager options you can use when doing inference, with documentation of its members
    """
    # The string attached to each member is that member's documentation text.
    # NOTE(review): how `SelfEnum` converts these declarations into member values is
    # defined in `fastreinference.utils`, not in this file — confirm there.
    NO_GRAD:Any = "Run with `torch.no_grad`"
    # `INFERENCE` is the only member named in `special=[...]` above and the only one written
    # as a 2-tuple without an annotation; presumably (explicit value, documentation) — TODO confirm against `SelfEnum`.
    INFERENCE = "inference_mode", "Run with `torch.inference_mode`"
    NONE:Any = "Keep all gradients and apply no context managers"

In [None]:
show_doc(ManagerType.NO_GRAD)

<h4 id="no_grad" class="doc_header"><code>no_grad</code><a href="" class="source_link" style="float:right">[source]</a></h4>

Run with `torch.no_grad`

In [None]:
show_doc(ManagerType.INFERENCE)

<h4 id="inference_mode" class="doc_header"><code>inference_mode</code><a href="" class="source_link" style="float:right">[source]</a></h4>

Run with `torch.inference_mode`

In [None]:
show_doc(ManagerType.NONE)

<h4 id="none" class="doc_header"><code>none</code><a href="" class="source_link" style="float:right">[source]</a></h4>

Keep all gradients and apply no context managers

In [None]:
#export
class InferenceConfiguration(ABC):
    """
    Base class used to customize what happens at each step of inference.

    Subclasses are required to implement all three events:
      - `after_drawn_batch`
      - `gather_predictions`
      - `decoding_values`

    Each event ships with a default implementation. To keep the default for an event,
    delegate to it explicitly:
      - `event_name(self, *args): return super().event_name(*args)`

    Here `event_name` stands for any of the three events listed above.

    Two class attributes can be overridden as well: `context`, a `ManagerType` selecting the
    context manager to run at inference time, and `device`, a `DeviceType` selecting
    where inference runs.
    """
    context: ManagerType = ManagerType.NO_GRAD # Context manager to run at inference time
    device: DeviceType = get_default_device() # Device used during inference, evaluated once when the class is defined. Default is cuda if available

    @abstractmethod
    def after_drawn_batch(self, batch):
        """
        Runs immediately after a batch has been drawn from the `DataLoader`.
        Use it for any last adjustments to the batch before it is sent to the model.

        The default implementation hands `batch` back untouched.
        """
        return batch

    @abstractmethod
    def gather_predictions(self, model, batch):
        """
        Runs inference with `model` on `batch`.
        Inference decorators such as `no_grad` or `inference_mode` are not applied here;
        that is done in `Performer`.

        The default implementation unpacks the batch into the model call: `model(*batch)`.
        """
        predictions = model(*batch)
        return predictions

    @abstractmethod
    def decoding_values(self, values):
        """
        Runs after predictions have been gathered on a `batch`.
        Use it for class decoding and for preparing the final datatype.

        The default implementation hands `values` back untouched.
        """
        return values

In [None]:
show_doc(InferenceConfiguration)

<h2 id="InferenceConfiguration" class="doc_header"><code>class</code> <code>InferenceConfiguration</code><a href="" class="source_link" style="float:right">[source]</a></h2>

> <code>InferenceConfiguration</code>() :: `ABC`

The foundational class for customizing behaviors during inference.

There are three methods available that must be implemented:
  - `after_drawn_batch`
  - `gather_predictions`
  - `decoding_values`
  
If an implementation should stay to its default behavior, do the following:
  - `event_name(self, *args): return super().event_name(*args)`
  
Where `event_name` is any of the three events listed above

A `context` can also be set with a `ManagerType` for what type of context manager should be ran at inference time

In [None]:
show_doc(InferenceConfiguration.gather_predictions)

<h4 id="InferenceConfiguration.gather_predictions" class="doc_header"><code>InferenceConfiguration.gather_predictions</code><a href="__main__.py#L30" class="source_link" style="float:right">[source]</a></h4>

> <code>InferenceConfiguration.gather_predictions</code>(**`model`**, **`batch`**)

Performs inference with `model` on `batch`.
Any specific inference decorators such as `no_grad` or `inference_mode` is done in `Performer`.

Default implementation is `model(*batch)`.

In [None]:
# Why abstract: it forces users to decide whether this is how they want their code to run in production.
# Since there would only be one level, it is easy to track where and how each step happens.

In [None]:
class ImageClassifierConfiguration(InferenceConfiguration):
    """
    Example `InferenceConfiguration` for image classification.

    Keeps the default batch handling, runs the model on the unpacked batch, and
    decodes the model output into class labels looked up in `vocab`.
    """
    def __init__(self, vocab):
        """
        Stores `vocab`, which is indexed by predicted class index to get a label.
        """
        self.vocab = vocab

    def after_drawn_batch(self, batch):
        """
        Keeps the default behavior and returns `batch` as given by the parent class.
        """
        # The parent's result has to be returned explicitly; calling `super()` without
        # `return` would make this method return `None` instead of the batch.
        return super().after_drawn_batch(batch)

    def gather_predictions(self, model, batch):
        """
        Runs `model` on the unpacked `batch` and returns its output.
        """
        return model(*batch)

    def decoding_values(self, values):
        """
        Decodes model outputs into class labels.

        Takes the argmax of `values` over the last dimension and looks each index up in `vocab`.
        Returns a dict with:
          - `"classes"`: list of decoded labels, one per prediction
          - `"probabilities"`: the argmax indices (see the note below)
        """
        preds = values.argmax(dim=-1)
        # Index with plain Python ints so list-like and dict-like vocabs both work
        decoded_preds = [self.vocab[idx] for idx in preds.tolist()]
        # NOTE(review): "probabilities" holds the predicted class indices, not probabilities.
        # Left unchanged for compatibility; the right value depends on whether `values`
        # are logits or already-normalized scores — confirm and fix.
        return {"classes":decoded_preds, "probabilities":preds}