In [None]:
#| default_exp learner

# Learner
> The core training loop

In [23]:
#| export
import torch
from typing import Union, List
from torch.optim import Optimizer
from fastai_reimplemented.callback.core import Callback

In [24]:
#| export
class Learner:
    """
    Groups together a model, a set of dataloaders, an optimizer, and
    a loss function to handle training.

    The constructor only stores its arguments; `call_event` dispatches a
    named event to every callback registered in `self.callbacks`.
    """
    training = False
    def __init__(
        self,
        dls, # A `DataLoaders` instance containing fastai or PyTorch `DataLoader`s
        model:torch.nn.Module, # A PyTorch model for training or inference
        loss_function:callable, # The loss function to use (must be passed explicitly; it is not inferred from `dls`)
        opt_function:Optimizer, # Optimization function for training
        learning_rate:float=1e-3, # Default learning rate
        callbacks:Union[List[Callback], None]=None, # Additional `Callback`'s to be tied to the `Learner` directly, defaults to no callbacks
        metrics:callable=None, # Metrics to be applied on the validation set
    ):
        self.dls = dls
        # TODO: optionally infer `loss_function` from `dls` when it is not provided
        self.model = model
        self.loss_function = loss_function
        self.opt_function = opt_function
        self.learning_rate = learning_rate
        # Build a fresh list per instance: a `[]` default in the signature would be
        # shared by every `Learner` created without `callbacks`. A list passed in by
        # the caller is stored as-is (not copied).
        self.callbacks = [] if callbacks is None else callbacks
        self.metrics = metrics

    def call_event(
        self,
        event:str # A valid `Callback` event
    ):
        "Calls a callback `event` for all callbacks in `self.callbacks`"
        # Callbacks are visited from the highest `order` to the lowest.
        for callback in sorted(
            self.callbacks,
            key=lambda cb: cb.order,
            reverse=True
        ):
            # Entries may be `Callback` classes rather than instances; a class is
            # instantiated with this learner before being invoked.
            # NOTE(review): a class entry is re-instantiated on every event, so any
            # state it keeps is not carried between events -- confirm this is intended.
            if isinstance(callback, type):
                callback = callback(self)
            callback(event)

In [25]:
class TestCallback(Callback):
    "Minimal `Callback` used to check that `Learner.call_event` reaches `before_batch`"
    def before_batch(self):
        "Print a fixed message so the dispatch is visible in the cell output"
        message = "Called in before batch!"
        print(message)

In [26]:
# Only the callback machinery is exercised here, so every other argument is a placeholder.
learn = Learner(
    dls=None,
    model=None,
    loss_function=None,
    opt_function=None,
    callbacks=[TestCallback],
)

In [27]:
# Dispatch the event by name; `TestCallback.before_batch` should print its message.
learn.call_event(event="before_batch")

Called in before batch!
