# Learner

> Generic learner & utils


Generic and highly flexible learner that lets you inject behaviour almost anywhere in the pipeline using callbacks. For further details, and to give credit where credit is due, please refer to: [FastAI course part II](https://course.fast.ai).

In [6]:
#| default_exp learner

In [7]:
%load_ext autoreload
%autoreload 2

In [8]:
#| export
from functools import partial

import fastcore.all as fc
from collections.abc import Mapping

import torch
from torch import optim, nn
import torch.nn.functional as F
from torchvision import transforms as T
from lssm.callbacks import (run_cbs, to_cpu, LRFinderCB,
                            CancelBatchException, CancelFitException, CancelEpochException)

In [26]:
#|hide
from pathlib import Path

from sklearn.pipeline import Pipeline
from sklearn.model_selection import train_test_split

import timm
from torcheval.metrics import R2Score
from torch.optim import lr_scheduler

from lssm.loading import load_ossl
from lssm.preprocessing import ToAbsorbance, ContinuumRemoval, Log1p
from lssm.dataloaders import SpectralDataset, get_dls
from lssm.callbacks import (MetricsCB, BatchSchedCB, BatchTransformCB,
                            DeviceCB, TrainCB, ProgressCB)
from lssm.transforms import GADFTfm, _resizeTfm, StatsTfm


In [11]:
#|export
class with_cbs:
    'Decorator calling "before_`nm`" and "after_`nm`" around the decorated method. Call "cleanup_`nm`" once done.'
    def __init__(self, 
                 nm:str # Name of the callback method to call
                 ): 
        self.nm = nm
    def __call__(self, f):
        # Local import so the block does not rely on the file-level import list
        from functools import wraps
        @wraps(f)  # Keep the decorated method's `__name__`/`__doc__` for introspection and docs
        def _f(o, *args, **kwargs):
            # `o` is the object owning the decorated method; it must expose `callback(name)`.
            # The return value of `f` is intentionally discarded: `_f` always returns `None`.
            try:
                o.callback(f'before_{self.nm}')
                f(o, *args, **kwargs)
                o.callback(f'after_{self.nm}')
            # The matching cancel exception is looked up by name in this module's globals
            # (e.g. nm='batch' -> `CancelBatchException`) and silently swallowed; the
            # lookup only happens when something was actually raised.
            except globals()[f'Cancel{self.nm.title()}Exception']: pass
            # Cleanup callbacks run whether the step completed, was cancelled or failed
            finally: o.callback(f'cleanup_{self.nm}')
        return _f

In [29]:
#|export
class Learner():
    'Callback-driven training loop: each step is either delegated to, or wrapped by, the callbacks in `cbs`.'
    def __init__(self,
                 model, # Model to train; `train()`, `training` and `parameters()` are used on it
                 dls=(0,), # Data loaders; `one_epoch` reads `dls.train` and `dls.valid`
                 loss_func=F.mse_loss, # Loss function, stored as `self.loss_func`
                 lr=0.1, # Learning rate used by `fit` when none is passed
                 cbs=None, # Callback(s) attached for the whole lifetime of the learner
                 opt_func=optim.SGD # Optimizer factory, called as `opt_func(params, lr)` in `fit`
                 ):
        # Normalise `cbs` (None, single callback or list) to a fastcore list before storing
        cbs = fc.L(cbs)
        fc.store_attr()

    @with_cbs('batch')
    def _one_batch(self):
        'Process `self.batch`: predict and compute the loss, then, in training mode only, backward, step and zero grad.'
        # `predict`, `get_loss`, `backward`, `step` and `zero_grad` are not defined on the
        # class: `__getattr__` turns each of them into a callback dispatch.
        self.predict()
        self.callback('after_predict')
        self.get_loss()
        self.callback('after_loss')
        if self.training:
            self.backward()
            self.callback('after_backward')
            self.step()
            self.callback('after_step')
            self.zero_grad()

    @with_cbs('epoch')
    def _one_epoch(self):
        'Iterate over `self.dl`, exposing the batch index as `self.iter` and the batch as `self.batch`.'
        for self.iter,self.batch in enumerate(self.dl):
            self._one_batch()

    def one_epoch(self, training):
        'Run one epoch on `dls.train` (if `training`) or `dls.valid`, with the model mode set accordingly.'
        self.model.train(training)
        self.dl = self.dls.train if training else self.dls.valid
        self._one_epoch()

    @with_cbs('fit')
    def _fit(self, train, valid):
        'For each epoch in `self.epochs`, run the training and/or validation pass.'
        for self.epoch in self.epochs:
            if train: self.one_epoch(True)
            # Validation runs with gradient tracking disabled
            if valid: torch.no_grad()(self.one_epoch)(False)

    def fit(self, n_epochs=1, train=True, valid=True, cbs=None, lr=None):
        'Fit for `n_epochs`; `cbs` are attached for this call only and `lr` defaults to `self.lr`.'
        cbs = fc.L(cbs)
        # `add_cb` and `rm_cb` were added in lesson 18
        for cb in cbs: self.cbs.append(cb)
        try:
            self.n_epochs = n_epochs
            self.epochs = range(n_epochs)
            if lr is None: lr = self.lr
            # A new optimizer is built on every call (skipped when `opt_func` is falsy)
            if self.opt_func: self.opt = self.opt_func(self.model.parameters(), lr)
            self._fit(train, valid)
        finally:
            # Always detach the per-call callbacks, even if fitting raised
            for cb in cbs: self.cbs.remove(cb)

    def get_preds(self, data):
        'Put the model in eval mode, run a single batch pass on `data` and return `self.preds`.'
        self.model.train(False)
        self.batch = data
        self._one_batch()
        # `preds` is not assigned anywhere in this class; it is expected to be set on the
        # learner by a callback during `predict`. Read it from this instance rather than
        # from a module-level `learn` variable, which only exists in an interactive session.
        return self.preds
        
    def __getattr__(self, name):
        'Resolve the loop steps to callback dispatches; only reached when normal attribute lookup fails.'
        if name in ('predict','get_loss','backward','step','zero_grad'): return partial(self.callback, name)
        raise AttributeError(name)

    def callback(self, method_nm):
        'Run `method_nm` on the attached callbacks via `run_cbs`, passing this learner.'
        run_cbs(self.cbs, method_nm, self)
    
    @property
    def training(self):
        'Whether the model is currently in training mode.'
        return self.model.training

Example:

In [10]:
#|eval: false
# Model: timm architecture with 1 input channel and 1 output
model_name = 'resnet18'
model = timm.create_model(model_name, num_classes=1, in_chans=1, pretrained=True)

# Load dataset
analytes = 'k.ext_usda.a725_cmolc.kg'
data = load_ossl(analytes, spectra_type='visnir')
X, y, X_names, smp_idx, ds_name, ds_label = data

# Preprocess: features go through the two-step pipeline, targets through Log1p
X = Pipeline(steps=[
    ('to_abs', ToAbsorbance()),
    ('cr', ContinuumRemoval(X_names)),
]).fit_transform(X)
y = Log1p().fit_transform(y)

# Train/valid split (90/10), stratified on `ds_name`
n_smp = None # For demo. purpose: `None` keeps every sample
X_train, X_valid, y_train, y_valid = train_test_split(
    X[:n_smp, :], y[:n_smp],
    test_size=0.1,
    stratify=ds_name[:n_smp],
    random_state=41,
)

# Get PyTorch datasets
train_ds = SpectralDataset(X_train, y_train)
valid_ds = SpectralDataset(X_valid, y_valid)

# Then PyTorch dataloaders
dls = get_dls(train_ds, valid_ds, bs=32)


100%|██████████| 44489/44489 [00:15<00:00, 2859.72it/s]


In [12]:
#|eval: false

# Training configuration
epochs = 1
lr = 5e-3

# Metrics callback tracking R2
metrics = MetricsCB(r2=R2Score())

# One-cycle LR schedule; `tmax` is the number of training batches over all epochs
# (presumably `BatchSchedCB` steps the scheduler once per batch -- confirm)
tmax = epochs * len(dls.train)
sched = partial(lr_scheduler.OneCycleLR, max_lr=lr, total_steps=tmax)
xtra = [BatchSchedCB(sched)]

# Batch-level transforms
gadf = BatchTransformCB(GADFTfm())
resize = BatchTransformCB(_resizeTfm)
stats = BatchTransformCB(StatsTfm(model.default_cfg))

# Assemble the callbacks, build the learner and train
cbs = [
    DeviceCB(),
    gadf,
    resize,
    stats,
    TrainCB(),
    metrics,
    ProgressCB(plot=False),
]

learn = Learner(model=model, dls=dls, loss_func=nn.MSELoss(),
                lr=lr, cbs=cbs + xtra, opt_func=optim.AdamW)

learn.fit(n_epochs=epochs)

r2,loss,epoch,train
0.36,0.095,0,train
0.539,0.068,0,eval


In [13]:
#|export
@fc.patch
def lr_find(self:Learner,
            gamma=1.3, # Forwarded to `LRFinderCB`
            max_mult=3, # Forwarded to `LRFinderCB`
            start_lr=1e-5, # Learning rate passed to `fit`
            max_epochs=10 # Number of epochs passed to `fit`
            ):
    'Run `fit` with an `LRFinderCB` attached for the duration of the call.'
    finder = LRFinderCB(gamma=gamma, max_mult=max_mult)
    self.fit(max_epochs, lr=start_lr, cbs=finder)

## Hooks

# 1D convolution: 16 input channels, 33 output channels, kernel size 3, stride 2
m = nn.Conv1d(in_channels=16, out_channels=33, kernel_size=3, stride=2)