In [None]:
#default_exp learner

# learner
> The main interface for this library

In [None]:
#export
# Star-imports are intentional: they put every transform into this namespace
from fastinference_pytorch.transforms.data import *
from fastinference_pytorch.transforms.vision import *
from fastinference_pytorch.utils import to_device, tensor, to_numpy
from fastinference_pytorch.rebuild import make_pipelines, generate_pipeline, load_model, load_data

In [None]:
#export
from pathlib import Path
from fastcore.utils import store_attr
import torch

## `data_funcs`

`data_funcs` tell your `Learner` how to load each raw input it receives (for example, turning a filename into an image). Below is an example that opens images from their file paths; write your own `data_func` to match however your data will arrive:

In [None]:
#export
def get_image(fn, mode='RGB'):
    "Load the image stored at `fn` via `PILBase.create`, forwarding `mode` (default `'RGB'`)"
    img = PILBase.create(fn, mode=mode)
    return img

## Prediction functions

These functions run a batch through a `PyTorch` model or an `ONNX` model, respectively, and return its predictions.

In [None]:
#export
def torch_preds(batch, model):
    "Run `model` on `batch` in eval mode with autograd disabled and return its raw output"
    # `eval()` only flips module flags, so it does not need to sit inside the no-grad context
    model.eval()
    with torch.no_grad():
        preds = model(batch)
    return preds

In [None]:
#export
def onnx_preds(batch, model):
    """
    Get predictions from an ONNX `model` for a sequence of model inputs.

    Params:
      > `batch`: A sequence with one entry per model input (torch tensors or numpy arrays)
      > `model`: An object exposing `get_inputs()` and `run(output_names, feeds)`,
                 presumably an `onnxruntime.InferenceSession`
    Returns the outputs of `model.run` wrapped with `tensor`.
    """
    # The ONNX runtime consumes numpy arrays, so convert torch tensors first.
    # The original computed this conversion and then discarded it.
    if isinstance(batch[0], torch.Tensor): batch = [to_numpy(x) for x in batch]
    names = [i.name for i in model.get_inputs()]
    # Pair each declared model input name with the corresponding array, in order
    xs = {name: x for name, x in zip(names, batch)}
    return tensor(model.run(None, xs))

## Learner

In [None]:
class GenericDataset():
    "Generic dataset holding `data` together with the `pipelines` to apply to its items"
    def __init__(self, data, pipelines=None):
        """
        Params:
          > `data`: The underlying data
          > `pipelines`: Pre-built pipelines; when `None` they are built from `data` with `make_pipelines`
        """
        # NOTE(review): no item accessor (`get_item`/`__getitem__`) is implemented yet
        self.data = data
        # The original silently ignored the `pipelines` argument; honour it when given
        self.pipelines = pipelines if pipelines is not None else make_pipelines(data)

In [None]:
#export
class Learner():
    """
    Similar to a `fastai` learner, but for inference only

    Params:
      > `path` (str or Path): Directory where your data and model are stored, relative to the `cwd`
      > `data_fn` (str): Filename of your pickled data
      > `model_fn` (str): Filename of your model
      > `data_func` (function): A function that loads a single raw input (e.g. a filename) into an item.
                     Defaults to `get_image`, which opens an image at a location with Pillow
      > `bs` (int): The batch size to use per inference (this can be tweaked later via `learn.bs`)
      > `cpu` (bool): Whether to run on the CPU (`True`) or on CUDA (`False`)
      > `onnx` (bool): Whether the model is expected to be in ONNX (`True`) or PyTorch (`False`) format

    Example use:

    learn = Learner('models', data_fn='data', model_fn='model', data_func=get_image, bs=4, cpu=True)
    """
    def __init__(self, path = Path('.'), data_fn='data', model_fn='model', data_func=None, bs=16, cpu=False, onnx=False):
        data = load_data(path, data_fn)
        self.n_inp = data['n_inp']
        self.pipelines = make_pipelines(data)
        self.after_item = self.pipelines['after_item']
        self.after_batch = self.pipelines['after_batch']
        self.tfm_y = generate_pipeline(data['tfms'], order=False)
        self.model = load_model(path, model_fn, cpu, onnx)
        self.device = 'cpu' if cpu else 'cuda'
        # `get_preds` needs to know which prediction function to dispatch to
        self.onnx = onnx
        # Explicit assignment instead of `store_attr`, whose signature differs between fastcore versions.
        # Fall back to `get_image` as documented above (a `None` data_func would crash in `_make_data`).
        self.data_func = data_func if data_func is not None else get_image
        self.bs = bs
        # Default decoder used by `get_preds` when no `decode_func` is passed there
        self.decode_func = None

    def _make_data(self, data):
        "Passes `data` through `data_func`, `after_item` and `after_batch`, splitting it into batches of at most `self.bs`"
        n_items = len(data)
        self.n_batches = (n_items + self.bs - 1) // self.bs  # ceil(n_items / bs)
        batches, items = [], []
        for idx, d in enumerate(data):
            d = self.data_func(d)
            for tfm in self.after_item: d = tfm(d)
            items.append(d)
            # Flush when the batch is full, or on the last item so a trailing partial batch is not dropped
            if len(items) == self.bs or idx == n_items - 1:
                batch = to_device(torch.stack(items, dim=0), self.device)
                for tfm in self.after_batch: batch = tfm(batch)
                batches.append(batch)
                items = []
        return batches

    def _decode_inputs(self, inps, outs):
        """
        Concatenates the collected per-batch inputs, decodes them through `after_batch` (in reverse order),
        and appends the result to `outs`.

        `inps` is a list of per-batch tensors when `self.n_inp == 1`, otherwise one such list per model input.
        The appended value is a single tensor, or a list of tensors when `self.n_inp > 1`.
        """
        multi = self.n_inp > 1
        per_input = inps if multi else [inps]
        decoded = []
        for chunks in per_input:
            x = torch.cat(chunks, dim=0)
            for tfm in self.after_batch[::-1]:
                if hasattr(tfm, 'can_decode'):
                    x = to_device(tfm(x, decode=True), 'cpu')
            decoded.append(x)
        outs.append(decoded if multi else decoded[0])
        return outs

    def get_preds(self, data, raw_outs=False, decode_func=None, with_input=False):
        """
        Gather predictions on `data` with possible decoding.

        Params:
          > `data`: Incoming data formatted to what `self.data_func` is expecting
          > `raw_outs`: Whether to return the raw outputs even when a decode function is available
          > `decode_func`: A function to use for decoding potential outputs; falls back to `self.decode_func`.
                           While the default is `None`, see `decode_cel` for an example
          > `with_input`: Whether to also return the inputs, decoded through `after_batch`

        Returns a list whose first element is the predictions (raw numpy array, or the concatenated decoded
        outputs) followed, when `with_input`, by the decoded inputs.
        """
        # The original checked `self.decode_func` (always None) and so never used the argument
        if decode_func is None: decode_func = self.decode_func
        inps = [[] for _ in range(self.n_inp)]
        raw, dec_out = [], []
        for batch in self._make_data(data):
            # `after_batch` may yield a single tensor or a tuple of tensors; normalize to a tuple of model
            # inputs so that slicing selects inputs rather than samples of a stacked tensor
            xs = tuple(batch) if isinstance(batch, (tuple, list)) else (batch,)
            xs = xs[:self.n_inp]
            if with_input:
                for store, x in zip(inps, xs): store.append(x.cpu())
            if self.onnx: out = onnx_preds(list(xs), self.model)
            else: out = torch_preds(xs[0] if len(xs) == 1 else xs, self.model)
            raw.append(out)
            if decode_func is not None: dec_out.append(decode_func(out))
        raw = torch.cat(raw, dim=0).cpu().numpy()
        if decode_func is None or raw_outs: outs = [raw]
        else: outs = [torch.cat(dec_out, dim=0)]
        if with_input: outs = self._decode_inputs(inps[0] if self.n_inp == 1 else inps, outs)
        return outs