# Learner Framework

This is the core framework I'm going to use to train all the ML models in this repository. It's based on the FastAI miniAI learner class from the Deep Learning for Coders part 2 course, with minor tweaks and refactoring: partly so I can make sure I understand how it works, and partly because I don't entirely agree with the coding style used in the course.

I want to make this a very flexible framework for working on basically any model I could need. I'm going to do this using a callback-heavy architecture, kind of like the strategy pattern except with a lot of interior mutability. The callbacks will have access to the learner itself and be able to change basically anything about it. What this means is that there are a lot of instance variables for things that would normally stay internal to the training loop, like the current epoch number.

Callbacks themselves need to inherit from the `Callback` class, which quite frankly is barely an implementation. All it does is set a default priority value, which dictates the order in which callbacks run. In the documentation I will endeavour to keep a list of all the methods that can be called from the Learner's train method. The way it's done means there isn't an explicit set of methods that could be called. (How very un-Rust of me. Luckily this isn't Rust.)
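
To make the dispatch idea concrete, here is a minimal sketch of it in isolation. It assumes callbacks are sorted by their `priority` attribute with the lowest value running first, and that hooks are looked up by name and skipped when missing. `Recorder`, `EpochOffset` and `ToyLearner` are hypothetical names used only for illustration; they are not part of the framework.

```python
class Callback:
    """Stand-in for the base class: it only supplies a default priority."""
    priority = 0


class Recorder(Callback):
    """Hypothetical callback that writes a note onto the learner."""
    priority = 1  # assumed to run after callbacks with the default priority of 0

    def before_epoch(self, learner):
        learner.log.append(f"recorder saw epoch {learner.epoch}")


class EpochOffset(Callback):
    """Hypothetical callback that mutates learner state directly."""

    def before_epoch(self, learner):
        learner.epoch += 100


class ToyLearner:
    """Hypothetical stand-in for Learner, reduced to the dispatch logic."""

    def __init__(self, callbacks):
        self.callbacks = list(callbacks)
        self.epoch = 0
        self.log = []

    def run_callbacks(self, name):
        # Lower priority values run first; hooks a callback lacks are skipped.
        for callback in sorted(self.callbacks, key=lambda cb: cb.priority):
            hook = getattr(callback, name, None)
            if hook is not None:
                hook(self)


learner = ToyLearner([Recorder(), EpochOffset()])
learner.run_callbacks("before_epoch")
learner.run_callbacks("after_epoch")  # no callback defines this hook, so nothing runs
print(learner.log)  # ['recorder saw epoch 100']
```

Because `EpochOffset` keeps the default priority it runs before `Recorder`, so the recorder sees the already-mutated epoch number. That is the interior mutability at work: any callback can change what later callbacks observe.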

In [1]:
from contextlib import contextmanager
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset

In [5]:
class Callback:
    """Base class for callbacks for the Learner class. This should never be instantiated on its own.
    It wouldn't do anything if you did but it would be a waste of time.

    Below is a list of all the methods that could be run by the Learner class. Do note that each one has 2 variants, "before_" and "after_", prepended to the name.
        - train
        - epoch
        - batch
        - validation_batch
        - test_batch
    """
    priority = 0
    

In [None]:
class CancelTrainException(Exception): pass
class CancelEpochException(Exception): pass
class CancelBatchException(Exception): pass
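
As a rough illustration of how these exceptions are meant to be used, the sketch below has a callback raise `CancelBatchException` from a `before_batch` hook, and a loop that catches it and moves on to the next batch. `SkipOddBatches` and `ToyLoop` are hypothetical names, and the loop body is a placeholder for the real forward/backward/step work.

```python
class CancelBatchException(Exception):
    pass


class SkipOddBatches:
    """Hypothetical callback: cancels every odd-numbered batch."""

    def before_batch(self, learner):
        if learner.batch_num % 2 == 1:
            raise CancelBatchException


class ToyLoop:
    """Hypothetical stand-in for the learner's batch loop."""

    def __init__(self, callback):
        self.callback = callback
        self.processed = []

    def one_epoch(self, num_batches):
        for j in range(num_batches):
            self.batch_num = j
            try:
                self.callback.before_batch(self)
                self.processed.append(j)  # stands in for forward/backward/step
            except CancelBatchException:
                pass  # abandon this batch and carry on with the next one


loop = ToyLoop(SkipOddBatches())
loop.one_epoch(5)
print(loop.processed)  # [0, 2, 4]
```

The same pattern applies one level up: an epoch-level or train-level exception unwinds to the matching `except` and the outer loop continues or stops.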

In [3]:
class Learner:

    def __init__(self, model, optimiser, loss_func, dataloaders, callbacks=None):
        self.model = model
        self.optimiser = optimiser
        self.loss_func = loss_func
        self.dataloaders = dataloaders
        self.callbacks = [] if callbacks is None else list(callbacks)

    def one_epoch(self):
        # self.train_dataset was never set; the training data lives in the dataloaders dict
        for j, (data, labels) in enumerate(self.dataloaders["train"]):
            self.batch_num = j
            self.data = data
            self.labels = labels
            try:
                with self._callback_manager("batch"):
                    self.optimiser.zero_grad()

                    self.predictions = self.model(self.data)

                    self.loss = self.loss_func(self.predictions, self.labels)

                    self.loss.backward()

                    self.optimiser.step()
            except CancelBatchException:
                pass


    def train(self, epochs=1, learning_rate=1E-3, validation_epochs=1):
        self.num_epochs = epochs
        try: 
            with self._callback_manager("train"):
                for i in range(epochs):
                    self.epoch = i
                    if "train" in self.dataloaders:
                        try: 
                            with self._callback_manager("epoch"):
                                self.one_epoch()
                        except CancelEpochException:
                            pass
                        finally: 
                            if "validation" in self.dataloaders and self.epoch % validation_epochs == 0: 
                                self._validate()
        except CancelTrainException:
            pass

        if "test" in self.dataloaders:
            self._test()

    
    def _validate(self):
        with torch.no_grad():
            for i, (data, labels) in enumerate(self.dataloaders["validation"]):
                self.data, self.labels = data, labels
                with self._callback_manager("validation_batch"):
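                    # Forward pass takes only the inputs, matching the training loop
                    self.validation_predictions = self.model(self.data)
                    # Score this batch's own predictions, not the last training predictions
                    self.validation_loss_batch = self.loss_func(self.validation_predictions, self.labels)

                    if i:
                        # NOTE: accumulate_loss is not defined on Learner in this notebook; it has to be supplied elsewhere
                        self.validation_loss = self.accumulate_loss(self.validation_loss, self.validation_loss_batch)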
                    else:
                        self.validation_loss = self.validation_loss_batch

    def _test(self):
        with torch.no_grad():
            for i, (data, labels) in enumerate(self.dataloaders["test"]):
                self.data, self.labels = data, labels
                with self._callback_manager("test_batch"):
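                    # Forward pass takes only the inputs, matching the training loop
                    self.test_predictions = self.model(self.data)
                    # Score this batch's own predictions, not the last training predictions
                    self.test_loss_batch = self.loss_func(self.test_predictions, self.labels)

                    if i:
                        # NOTE: accumulate_loss is not defined on Learner in this notebook; it has to be supplied elsewhere
                        self.test_loss = self.accumulate_loss(self.test_loss, self.test_loss_batch)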
                    else:
                        self.test_loss = self.test_loss_batch

    @contextmanager
    def _callback_manager(self, name):
        # Is there a way to move the try except block into here?
        # Might have to pass in a list of exceptions in here but that's doable
        # That way I could run the after callbacks even if the context was cancelled early
        # Also would be nice to have a hook to let callbacks know if the batch/epoch/training was cancelled early
        # Something like self.epoch_failed = True but more computational than that
        self._run_callback("before_" + name)
        yield
        self._run_callback("after_" + name)


    def _run_callback(self, name):
        # sorted() takes the key as a keyword argument; Callback defines `priority`, not `order`
        for callback in sorted(self.callbacks, key=lambda cb: cb.priority):
            method = getattr(callback, name, None)

            if method is not None:
                method(self)
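
The comments in `_callback_manager` ask whether the try/except could move into the context manager. One possible shape is sketched below, assuming the cancel exception type is passed in as an argument and the cancelled state is stored in a dict. `ToyLearner` and the `cancelled` attribute are hypothetical and not part of the class above.

```python
from contextlib import contextmanager


class CancelBatchException(Exception):
    pass


class ToyLearner:
    """Hypothetical stand-in for Learner, reduced to the context manager."""

    def __init__(self):
        self.events = []
        self.cancelled = {}

    def _run_callback(self, name):
        self.events.append(name)  # stands in for dispatching to real callbacks

    @contextmanager
    def _callback_manager(self, name, cancel_exception):
        self.cancelled[name] = False
        # Kept outside the try: a generator-based context manager must reach
        # its yield, so a cancellation raised here still reaches the caller.
        self._run_callback("before_" + name)
        try:
            yield
        except cancel_exception:
            # Swallow the cancellation and leave a flag for the after_ hooks.
            self.cancelled[name] = True
        finally:
            # Runs whether the block finished normally or was cancelled.
            self._run_callback("after_" + name)


learner = ToyLearner()
with learner._callback_manager("batch", CancelBatchException):
    learner.events.append("step")
    raise CancelBatchException

print(learner.events)     # ['before_batch', 'step', 'after_batch']
print(learner.cancelled)  # {'batch': True}
```

With this shape the explicit try/except blocks in `one_epoch` and `train` would only be needed for cancellations raised by a "before_" hook. Whether a flag dict is the right way to expose the cancelled state is a design choice left open here.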
