In [15]:
import math,torch,matplotlib.pyplot as plt
import fastcore.all as fc
from collections.abc import Mapping
from operator import attrgetter
from functools import partial
from copy import copy
from torch import optim
import torch.nn.functional as F
from operator import attrgetter
import matplotlib as mpl
import torchvision.transforms.functional as TF
from contextlib import contextmanager
from torch import nn,tensor
from datasets import load_dataset,load_dataset_builder
import logging
from fastcore.test import test_close
from torch.utils.data import DataLoader


from lib import *

In [16]:
# Pick the device used by default: Apple MPS if available, otherwise CUDA, otherwise the CPU.
if torch.backends.mps.is_available():
    default_device = "mps"
elif torch.cuda.is_available():
    default_device = 'cuda'
else:
    default_device = 'cpu'

# Set up data

In [17]:
@inplace
def transformi(b):
    '''
        Replaces every item stored under key `x` of the batch dict `b` with a flattened tensor.
        The dict is modified in place; `inplace` comes from `lib` (presumably it makes the
        decorated function return `b` - confirm against lib).
    '''
    flattened = []
    for img in b[x]:
        flattened.append(torch.flatten(TF.to_tensor(img)))
    b[x] = flattened

def get_model():
    '''
        Builds a fresh two-layer MLP: 784 inputs (28*28) -> 50 hidden units with ReLU -> 10 outputs.
    '''
    n_in, n_hidden, n_out = 28*28, 50, 10
    layers = [nn.Linear(n_in, n_hidden), nn.ReLU(), nn.Linear(n_hidden, n_out)]
    return nn.Sequential(*layers)

class DataLoaders:
    '''
        Holds the training data loader (`.train`) and the validation data loader (`.valid`).
    '''
    def __init__(self, train_data_loader, valid_data_loader):
        self.train, self.valid = train_data_loader, valid_data_loader

    @classmethod # alternative constructor: receives the class itself as `cls`
    def from_datasetDict(cls, datasetDict, batch_size):
        '''
            Builds one DataLoader per dataset in `datasetDict` (in the order of `.values()`)
            and passes them positionally to the constructor, so the first dataset becomes
            `.train` and the second `.valid`.
        '''
        loaders = []
        for ds in datasetDict.values():
            loaders.append(DataLoader(ds, batch_size, collate_fn=collate_dict(ds)))
        return cls(*loaders) # calling cls(...) runs __init__
    # recall that DataLoader can use multiple workers
    # don't send anything to the device here: that would be a huge overhead


In [18]:
# Names of the input and target columns; `transformi` reads the global `x`.
x,y = 'image','label'
dsd = load_dataset("fashion_mnist")
bs = 1024 # batch size
# Apply `transformi` to the dataset dict.
tds = dsd.with_transform(transformi)
dls = DataLoaders.from_datasetDict(tds, bs)
dt = dls.train
# Grab one training batch and inspect its shape and first labels.
xb,yb = next(iter(dt))
xb.shape,yb[:10]

(torch.Size([1024, 784]), tensor([9, 0, 0, 3, 0, 2, 7, 2, 5, 5]))

# Learner with callbacks

In [19]:
def run_cbs(cbs, method_name): # core function!
    '''
        Visits the callbacks in `cbs` in ascending `.order` and, for each one that has an
        attribute called `method_name`, calls it with no arguments.
        Callbacks that do not define `method_name` are silently skipped.
    '''
    by_order = attrgetter('order') # by_order(cb) returns cb.order
    for cb in sorted(cbs, key=by_order):
        hook = getattr(cb, method_name, None)
        if hook is None:
            continue
        hook()

In [20]:
class Callback():
    '''
        Base class for callbacks.
    '''
    # Position in the order of execution: callbacks are sorted by this attribute before being run.
    order = 0

In [21]:
# Custom exceptions used purely as control flow: a callback raises one of them
# and the Learner catches it at the matching level of the training loop.
class CancelFitException(Exception):
    '''Raised to stop the whole fit.'''

class CancelEpochException(Exception):
    '''Raised to stop the current epoch.'''

class CancelBatchException(Exception):
    '''Raised to stop the current batch.'''

In [22]:
# Minimal example callback, used below to show how run_cbs() dispatches by method name.
class CompletionCB(Callback):
    '''
        Counts the batches seen between before_fit and after_fit and prints the total.
    '''
    def before_fit(self):
        self.count = 0 # start counting from scratch at every fit

    def after_batch(self):
        self.count = self.count + 1

    def after_fit(self):
        print(f"Fit executed, completed {self.count} batches")

# Drive the callback by hand: one "batch" between before_fit and after_fit.
cbs = [CompletionCB()]
run_cbs(cbs, "before_fit")
run_cbs(cbs, "after_batch")
run_cbs(cbs, "after_fit")

Fit executed, completed 1 batches


All state is always stored as self.foo on the learner so that everything can be modified via callbacks, since each callback holds a reference to this learner.
The Learner and Callback objects are therefore tightly coupled, but this direct access to each other keeps things efficient.

In [23]:
class Learner():
    '''
        Training loop driven by callbacks.

        Loop state (opt, epoch, dl, iter, batch, preds, loss, ...) is stored on `self`,
        and every callback gets a `learner` attribute pointing back to this object,
        so callbacks can read and modify that state.
    '''
    def __init__(self, model, dls, loss_func, lr, cbs, opt_func=optim.SGD):
        '''
            model: the model to train
            dls: object with `.train` and `.valid` iterables of batches
            loss_func: called as loss_func(preds, targets)
            lr: learning rate passed to opt_func
            cbs: callbacks to run
            opt_func: called as opt_func(model.parameters(), lr) to build the optimizer
        '''
        fc.store_attr() # the arguments are then read back below as self.model, self.dls, self.cbs, ...
        for cb in cbs:
            cb.learner = self # give each callback a reference to this learner
            
    def fit(self, n_epochs):
        '''
            Runs `n_epochs` epochs, each made of one training pass and one validation pass.
            A new optimizer is created at every call.
        '''
        self.n_epochs = n_epochs
        self.epochs = range(n_epochs)
        self.opt = self.opt_func(self.model.parameters(), self.lr)
        try:
            self.callback("before_fit") # calls the .before_fit() method of every cb in self.cbs that has one
            for self.epoch in self.epochs: # the loop variable is an attribute, so callbacks can read it
                self.one_epoch(True)
                self.one_epoch(False)
            self.callback("after_fit") 
        except CancelFitException: 
            pass # a CancelFitException raised anywhere inside the try ends the fit quietly ("after_fit" is then skipped); other exceptions propagate
        
    def callback(self, method_name):
        ''' Runs `method_name` on all callbacks of this learner. '''
        run_cbs(self.cbs, method_name)
            
    def one_epoch(self, train):
        '''
            Runs one pass over the training data (train=True) or the validation data (train=False).
        '''
        self.model.train(train) # switches the model between training and evaluation mode
        self.dl = self.dls.train if train else self.dls.valid
        try:
            self.callback("before_epoch") 
            for self.iter, self.batch in enumerate(self.dl):
                try:
                    self.callback("before_batch") 
                    self.one_batch()
                    self.callback("after_batch")                     
                except CancelBatchException:
                    pass # skip the rest of this batch and move on to the next one
            self.callback('after_epoch')
        except CancelEpochException:
            pass # skip the rest of this epoch ("after_epoch" is then not called)
        
    def one_batch(self):
        '''
            Forward pass and loss for the current batch; when the model is in training mode
            also backward pass, optimizer step and gradient reset.
            self.batch[0] is fed to the model, self.batch[1] is used as the target.
        '''
        self.preds = self.model(self.batch[0])
        self.loss = self.loss_func(self.preds, self.batch[1])
        if self.model.training:
            self.loss.backward()
            self.opt.step()
            self.opt.zero_grad()

In [24]:
# `cbs` still holds the CompletionCB created above, so the fit reports how many batches it ran.
learn = Learner(model=get_model(), dls=dls, loss_func=F.cross_entropy, lr=0.2, cbs=cbs)
learn.fit(1)

Fit executed, completed 69 batches


# Other useful callbacks

In [25]:
class DeviceCB(Callback):
    '''
        Moves the model (once, before the fit) and every batch (before it is processed)
        to `device`.
        self.learner is set by the Learner constructor.
    '''
    def __init__(self, device=default_device):
        self.device = device

    def before_fit(self):
        model = self.learner.model
        model.to(self.device)

    def before_batch(self):
        learner = self.learner
        learner.batch = to_device(learner.batch, self.device)

In [26]:
# Same training run as before, with the model and the batches moved to the default device.
cbs = [DeviceCB()]
learn = Learner(get_model(), dls, F.cross_entropy, 0.2, cbs)
learn.fit(1)

# Metrics

In [27]:
class Metric:
    '''
        Base class to be extended when a particular metric is desired.
        If not extended, `value` is the average of the added inputs, weighted by batch_size.
    '''
    def __init__(self):
        self.reset()

    def reset(self):
        ''' forgets everything added so far '''
        self.vals = []
        self.ns = []

    def add(self, input, target=None, batch_size=1):
        '''
            records the result of calc(input, target) for one minibatch together with its weight
            target: optional
        '''
        result = self.calc(input, target)
        self.last = result # most recent per-batch result
        self.vals.append(result)
        self.ns.append(batch_size)

    @property # allows reading .value without ()
    def value(self):
        weights = torch.tensor(self.ns)
        weighted_sum = (torch.tensor(self.vals)*weights).sum()
        return weighted_sum/weights.sum()

    def calc(self, inputs, targets):
        ''' method to be overridden in derived classes; the base version returns inputs unchanged '''
        return inputs

# Accuracy subclass

In [28]:
class Accuracy(Metric):
    '''
        Fraction of predictions equal to their target.
        Pass the number of items of each minibatch as `batch_size` to `add`, so that
        `value` is the accuracy over all the items added.
    '''
    def calc(self, inputs, targets):
        # Per-batch accuracy: it has to be the mean (not the count of correct items),
        # because Metric.value weights each stored value by its batch size.
        return (inputs==targets).float().mean()

## Now that we have implemented metrics on our own, we can use the ready-made ones from torchmetrics: https://torchmetrics.readthedocs.io/en/stable/

In [29]:
# ! pip install torchmetrics

from torchmetrics.classification import MulticlassAccuracy #(https://torchmetrics.readthedocs.io/en/stable/classification/accuracy.html#multiclassaccuracy)
from torchmetrics.aggregation import MeanMetric # (https://torchmetrics.readthedocs.io/en/stable/aggregation/mean.html)

Example usage:

In [30]:
# Number of classes = largest label in the training set + 1.
n_classes = ((tensor(dsd['train'][y]).unique()).max()+1).item()
metric = MulticlassAccuracy(n_classes)

# Two toy minibatches of predicted classes (x_hat) and true classes (y).
x_hat_b1 = tensor([0,1,2,0,1]) 
y_b1 = tensor([0,0,2,1,1]) 

x_hat_b2 = tensor([1,1,2,0,0]) 
y_b2 = tensor([0,1,2,0,0]) 

# update() accumulates state batch by batch; compute() returns the metric over everything seen so far.
metric.update(x_hat_b1, y_b1)
metric.update(x_hat_b2, y_b2)

# NOTE: 7 of the 10 predictions are correct, yet the result is 0.7556 = (3/5 + 2/3 + 2/2)/3,
# i.e. the average of the per-class accuracies (torchmetrics' default "macro" averaging).
metric.compute()

tensor(0.7556)

In [31]:
# After reset() the accumulated state is gone.
metric.reset()
metric.compute()



tensor(0.)

In [32]:
# MeanMetric keeps a weighted running mean; here the weights play the role of batch sizes.
loss = MeanMetric()

loss.update(.9, weight=32)
loss.update(.6, weight=2)

# Compare against the weighted mean computed by hand.
loss.compute(), (.9*32 +.6*2)/(32+2)

(tensor(0.8824), 0.8823529411764706)

So let's now create a MetricCB class to be inserted in our cbs execution list.
This class uses the metrics from torchmetrics.

In [35]:
class MetricCB(Callback):
    '''
    Single access point for all the metrics tracked during a fit: it updates them after every
    batch, logs them after every epoch and resets them before the next one.

    You can construct it as:
        metric = MetricCB(MulticlassAccuracy(n_classes))
    or:
        metric = MetricCB(accuracy=MulticlassAccuracy(n_classes))

    in the first case the metric is stored under its class name ('MulticlassAccuracy'),
    in the second case under the given keyword ('accuracy').
    A MeanMetric tracking the loss is always added under the key 'loss'.
    '''

    # *ms = tuple of positional inputs
    # **metrics = dict of keyword inputs
    def __init__(self, *ms, **metrics):
        metrics.update({type(m).__name__: m for m in ms}) # positional metrics are keyed by class name
        self.metrics = metrics # the user-supplied metrics only
        self.loss = MeanMetric()
        self.all_metrics = copy(metrics) # the user-supplied metrics plus the loss
        self.all_metrics['loss'] = self.loss

    def before_fit(self): # IMPORTANT
        # make this MetricCB reachable from the learner
        self.learner.metrics = self

    def before_epoch(self):
        # start every epoch from a clean state
        for metric in self.all_metrics.values():
            metric.reset()

    def after_batch(self):
        learner = self.learner
        inputs, targets = learner.batch
        for metric in self.metrics.values():
            metric.update(learner.preds, targets) # feed (predictions, targets) of this minibatch
        self.loss.update(learner.loss, weight=len(inputs)) # weight the loss by the batch size

    def after_epoch(self):
        # build the dict to be printed/logged
        log = {}
        for name, metric in self.all_metrics.items():
            log[name] = f'{metric.compute()}'
        log['epoch'] = self.learner.epoch
        log['train'] = self.learner.model.training
        self._log(log)

    def _log(self, d):
        # to override for more complex presentations
        print(d)

In [36]:
# Train for one epoch while tracking accuracy and loss; one dict is printed per pass (train, then valid).
cbs = [MetricCB(MulticlassAccuracy(n_classes))]
learn = Learner(get_model(), dls, F.cross_entropy, 0.2, cbs)
learn.fit(1)

{'MulticlassAccuracy': '0.5950833559036255', 'loss': '1.186214804649353', 'epoch': 0, 'train': True}
{'MulticlassAccuracy': '0.6886999607086182', 'loss': '0.8226090669631958', 'epoch': 0, 'train': False}


# More flexible learner with context manager
### Code design: exceptions as control flow

In [37]:
class TrainCB(Callback):
    '''
        Callback implementing the individual steps of processing one batch.
        Every step reads and writes its state on the learner.
    '''
    def predict(self):
        lrn = self.learner
        lrn.preds = lrn.model(lrn.batch[0]) # batch[0] is fed to the model

    def get_loss(self):
        lrn = self.learner
        lrn.loss = lrn.loss_func(lrn.preds, lrn.batch[1]) # batch[1] is used as the target

    def backward(self):
        loss = self.learner.loss
        loss.backward()

    def step(self):
        opt = self.learner.opt
        opt.step()

    def zero_grad(self):
        opt = self.learner.opt
        opt.zero_grad()

In [40]:
class Learner():
    '''
        Callback-driven training loop in which the batch steps themselves (predict, get_loss,
        backward, step, zero_grad) are delegated to the callbacks, and the before_*/after_*
        hooks plus the Cancel*Exception handling are factored into a context manager.
    '''
    def __init__(self, model, dls, loss_func, lr, cbs, opt_func=optim.SGD):
        fc.store_attr() # the arguments are then read back below as self.model, self.dls, self.cbs, ...
        for cb in cbs:
            cb.learner = self # give each callback a reference to this learner
            
    @contextmanager # turns this generator into something usable in a `with` statement
    def callback_ctx(self, context_name):
        '''
        Wraps the body of a `with` statement between the before_<context_name> and
        after_<context_name> callbacks.

        The code before the yield runs when the `with` is entered, the body of the `with`
        runs at the yield, and the code after the yield runs when the body completes.
        If a before_ callback, the body or an after_ callback raises
        Cancel<Context_name>Exception, the exception is swallowed here; when it comes from a
        before_ callback or from the body, the after_<context_name> callbacks are not run.
        '''
        try: 
            self.callback(f'before_{context_name}') 
            yield # the body of the with statement runs here
            self.callback(f'after_{context_name}')
        except globals()[f'Cancel{context_name.title()}Exception']: 
            # the exception class is looked up by name in the module globals; title() capitalizes
            # the first letter, e.g. 'fit' -> CancelFitException
            pass
   
    def fit(self, n_epochs):
        '''
            Runs `n_epochs` epochs, each made of one training pass and one validation pass.
            A new optimizer is created at every call.
        '''
        self.n_epochs = n_epochs 
        self.epochs = range(n_epochs)
        self.opt = self.opt_func(self.model.parameters(), self.lr)
        with self.callback_ctx("fit"): # entering runs the before_fit callbacks
            for self.epoch in self.epochs:
                self.one_epoch(True)
                self.one_epoch(False)
            # leaving the with block normally runs the after_fit callbacks

    def one_epoch(self, train):
        '''
            Runs one pass over the training data (train=True) or the validation data (train=False).
        '''
        self.model.train(train)
        self.dl = self.dls.train if train else self.dls.valid
        with self.callback_ctx('epoch'):
            for self.iter, self.batch in enumerate(self.dl):
                with self.callback_ctx("batch"): 
                    # the steps below are not defined on this class: they go through __getattr__,
                    # which dispatches them to the callbacks (e.g. TrainCB)
                    self.predict() # extracted to make the learner very flexible
                    self.get_loss() # whatever logic you want can be plugged in here
                    # self.predict is first resolved via __getattr__, which returns a function; the () then calls it
                    if self.model.training:
                        self.backward()
                        self.step()
                        self.zero_grad()                

    def callback(self, method_name):
        ''' Runs `method_name` on all callbacks of this learner. '''
        run_cbs(self.cbs, method_name)
                        
    def __getattr__(self, name):
        # Only called when normal attribute lookup fails, so a subclass that defines these
        # methods itself bypasses the callbacks.
        if name in ("predict", "get_loss", "backward", "step", "zero_grad"):
            return partial(self.callback, name) # calling the result runs self.callback(name)
        raise AttributeError(name)

In [41]:
# TrainCB supplies the batch steps that this Learner looks up via __getattr__.
cbs = [TrainCB(), MetricCB(MulticlassAccuracy(n_classes))]
learn = Learner(get_model(), dls, F.cross_entropy, 0.2, cbs)
learn.fit(1)

{'MulticlassAccuracy': '0.5978000164031982', 'loss': '1.168588638305664', 'epoch': 0, 'train': True}
{'MulticlassAccuracy': '0.6824000477790833', 'loss': '0.8261908292770386', 'epoch': 0, 'train': False}


# MomentumLearner or MomentumCB

In [None]:
# Another way to do the same is to subclass the learner and implement directly
# the methods required to train it.

# IDEA: instead of zeroing the gradients we keep them "alive" but shrink them by a factor momentum < 1
# NB: pytorch autograd ONLY ADDS to gradients
class MomentumLearner(Learner):
    '''
        Learner that implements the batch steps itself and, instead of zeroing the gradients
        after each optimizer step, scales them by `momentum`, so that the next backward pass
        adds the new gradients on top of the decayed old ones.
    '''
    def __init__(self, model, dls, loss_func, lr, cbs, opt_func=optim.SGD, momentum=.85):
        self.momentum=momentum
        super().__init__(model, dls, loss_func, lr, cbs, opt_func)

    def predict(self):
        self.preds = self.model(self.batch[0])
    
    def get_loss(self):
        self.loss = self.loss_func(self.preds, self.batch[1])
        
    def backward(self):
        self.loss.backward()
        
    def step(self):
        self.opt.step()
        
    def zero_grad(self):
        ''' Decays the gradients by `momentum` instead of resetting them to zero. '''
        with torch.no_grad():
            for p in self.model.parameters():
                # parameters that received no gradient (frozen or unused in the forward pass)
                # have p.grad None: nothing to decay
                if p.grad is not None:
                    p.grad *= self.momentum
                

In [None]:
# NOTE: MomentumLearner defines predict/get_loss/backward/step/zero_grad itself, so
# Learner.__getattr__ is never reached for them and the TrainCB methods are not called here.
cbs = [TrainCB(), MetricCB(MulticlassAccuracy(n_classes))]
learn = MomentumLearner(get_model(), dls, F.cross_entropy, lr=0.2, cbs=cbs, momentum=.85)
learn.fit(1)

# LRFinderCB

In [None]:
class LRFinderCB(Callback):
    '''
        Learning-rate finder: after every training batch it records the current lr and loss
        and multiplies the lr of every param group by `lr_multiplier`.
        The fit is cancelled as soon as the loss exceeds 3 times the smallest loss seen.
    '''
    def __init__(self, lr_multiplier=1.3):
        fc.store_attr()
        
    def before_fit(self):
        self.lrs, self.losses = [], []
        self.min = math.inf # smallest loss seen so far
        
    def after_batch(self):
        if not self.learner.model.training: 
            # no lr search on validation batches: skip the rest of the epoch
            # (this used to raise the undefined name CancelEpochExceptionLoss -> NameError)
            raise CancelEpochException()
        
        self.lrs.append(self.learner.opt.param_groups[0]['lr'])
        # detached and on the cpu, so that the stored losses can be plotted whatever the training device
        loss = self.learner.loss.detach().cpu()
        self.losses.append(loss)
        
        if loss < self.min:
            self.min = loss
            
        if loss > self.min*3: # stopping criterion
            raise CancelFitException() # cancel the whole fit! Nice
            
        for g in self.learner.opt.param_groups:
            g['lr'] *= self.lr_multiplier
            
    def plot(self):
        ''' Plots the recorded losses against the learning rates (log scale on the x axis). '''
        plt.plot(self.lrs, self.losses)
        plt.xscale('log')
        plt.xlabel("Learning rate")
        plt.ylabel("Loss")
        
    

In [None]:
# Start from a very small lr and let the callback increase it after every batch.
lrf = LRFinderCB()
cbs = [lrf]
learn = MomentumLearner(get_model(), dls, F.cross_entropy, lr=1e-4, cbs=cbs, momentum=.85)
learn.fit(1)

In [None]:
# Manual version of the same plot:
# plt.plot(lrf.lrs, lrf.losses)
# plt.xscale('log')
# plt.xlabel("Learning rate")
# plt.ylabel("Loss")

lrf.plot() # the plotting code lives on the callback, which keeps it self-contained

Increase the lr over time and plot it against the loss; then find how high we can bring the lr before the loss blows up towards inf.
Lr to be chosen: 0.1.

Let's now re-implement the LRFinderCB using PyTorch lr schedulers, just to see that PyTorch lr schedulers are actually not doing much!

In [None]:
from torch.optim.lr_scheduler import ExponentialLR

The scheduler multiplies the lr of all params in the model by a factor gamma, which the scheduler takes as input. 
The multiplication occurs only when scheduler.step() is called.

In [None]:
class LRFinderCB(Callback):
    '''
        Learning-rate finder built on a PyTorch scheduler: after every training batch it records
        the current lr and loss and steps an ExponentialLR scheduler created with `gamma`.
        The fit is cancelled as soon as the loss exceeds 3 times the smallest loss seen.
    '''
    def __init__(self, gamma=1.3): # gamma == lr_multiplier
        fc.store_attr()
        
    def before_fit(self):
        self.scheduler = ExponentialLR(self.learner.opt, self.gamma)
        self.lrs, self.losses = [], []
        self.min = math.inf # smallest loss seen so far
        
    def after_batch(self):
        if not self.learner.model.training: 
            # no lr search on validation batches: skip the rest of the epoch
            # (this used to raise the undefined name CancelEpochExceptionLoss -> NameError)
            raise CancelEpochException()
        
        self.lrs.append(self.learner.opt.param_groups[0]['lr'])
        # detached and on the cpu, so that the stored losses can be plotted whatever the training device
        loss = self.learner.loss.detach().cpu()
        self.losses.append(loss)
        
        if loss < self.min:
            self.min = loss
            
        if loss > self.min*3: # stopping criterion
            raise CancelFitException() # cancel the whole fit! Nice
            
        # the scheduler replaces the manual loop over opt.param_groups
        self.scheduler.step()
            
    def plot(self):
        ''' Plots the recorded losses against the learning rates (log scale on the x axis). '''
        plt.plot(self.lrs, self.losses)
        plt.xscale('log')
        plt.xlabel("Learning rate")
        plt.ylabel("Loss")
        
    

In [None]:
# Same experiment as before, now with the scheduler-based LRFinderCB.
lrf = LRFinderCB()
cbs = [lrf]
learn = MomentumLearner(get_model(), dls, F.cross_entropy, lr=1e-4, cbs=cbs, momentum=.85)
learn.fit(1)
lrf.plot()

In [None]:
from fastprogress import progress_bar, master_bar

class ProgressCB(Callback):
    '''
        Shows fastprogress bars over epochs and batches and, when `plot` is True and the learner
        has a `metrics` attribute, a live graph of training and validation losses.
    '''
    order = MetricCB.order+1 # run after MetricCB, which sets learner.metrics in its before_fit

    def __init__(self, plot=False):
        self.plot = plot

    def before_fit(self):
        lrn = self.learner
        # wrap the epochs in the master bar, so that iterating over them drives it
        self.mbar = master_bar(lrn.epochs)
        lrn.epochs = self.mbar
        self.first = True # the header row of the table has not been written yet
        if hasattr(lrn, 'metrics'):
            lrn.metrics._log = self._log # send the metrics log to the progress table
        self.losses, self.val_losses = [], []

    def _log(self, d):
        if self.first:
            self.mbar.write(list(d), table=True) # header row: the keys
            self.first = False
        self.mbar.write(list(d.values()), table=True)

    def before_epoch(self):
        lrn = self.learner
        lrn.dl = progress_bar(lrn.dl, leave=False, parent=self.mbar)

    def after_batch(self):
        lrn = self.learner
        lrn.dl.comment = f'{lrn.loss:.3f}'
        if not (self.plot and hasattr(lrn, 'metrics') and lrn.model.training):
            return
        self.losses.append(lrn.loss.item())
        if self.val_losses:
            self._update_graph(lrn.epoch)

    def after_epoch(self):
        lrn = self.learner
        if lrn.model.training:
            return
        if not (self.plot and hasattr(lrn, 'metrics')):
            return
        self.val_losses.append(lrn.metrics.all_metrics['loss'].compute())
        self._update_graph(lrn.epoch+1)

    def _update_graph(self, n_val_points):
        # Redraws the graph: training losses against the batch counter, validation losses
        # at the batch count reached at the end of each of the first `n_val_points` epochs.
        lrn = self.learner
        train_xs = fc.L.range(self.losses)
        val_xs = fc.L.range(n_val_points).map(lambda i: (i+1)*len(lrn.dls.train))
        self.mbar.update_graph([[train_xs, self.losses],[val_xs, self.val_losses]])
     