In [None]:
from collections.abc import Mapping
from functools import wraps
from operator import attrgetter

In [None]:
class Callback:
    """Base class for callbacks.

    `order` decides when a callback runs relative to the others: `run_cbs` sorts
    callbacks by ascending `order`, so lower values run first.
    """
    order = 0

In [None]:
#from pdb import set_trace

In [None]:
def run_cbs(cbs, method_nm, learn=None):
    """Call `method_nm(learn)` on every callback in `cbs` that defines it.

    Callbacks are visited in ascending `order`; because `sorted` is stable, callbacks
    with equal `order` keep their original relative position. Callbacks lacking the
    method are silently skipped.
    """
    ordered = sorted(cbs, key=attrgetter('order'))
    for callback in ordered:
        hook = getattr(callback, method_nm, None)
        if hook is None:
            continue
        hook(learn)

#### **Explanation for the `sorted` mechanism**
There is a lot going on in `sorted(cbs, key=attrgetter('order'))`. `attrgetter('order')` returns a function that takes an object and returns its `order` attribute. `sorted` calls that function on each callback (a class instance) and orders the callbacks by the returned values (lowest first), e.g.:

```Python
class MyClass():      
    order = 0
    
# Create instance of MyClass
obj = MyClass()

# Import the attrgetter function from the operator module
from operator import attrgetter

# Use attrgetter to build a function that reads the 'order' attribute of an instance
get_order = attrgetter('order')

# Get the 'order' attribute from the object
order = get_order(obj)


# Print the results
print(order)  # Output: 0

```

In [None]:
class CompletionCB(Callback):
    """Counts the batches processed during a fit and prints the total when the fit ends."""
    def before_fit(self, learn):
        self.count = 0  # reset at the start of every fit, not every epoch

    def after_batch(self, learn):
        self.count = self.count + 1

    def after_fit(self, learn):
        print(f'Completed {self.count} batches')

In [None]:
cbs =[CompletionCB()]
# Fire the hooks by hand in the order a one-batch fit would; no Learner is needed, learn defaults to None
run_cbs(cbs, 'before_fit')
run_cbs(cbs, 'after_batch')
run_cbs(cbs, 'after_fit')

Completed 1 batches


#### **Understand `run_cbs` by calling callbacks step by step below**
Note: the `getattr(cb,'before_fit')(None)` syntax differs from the one in the recording because I'm using the latest `run_cbs` function, which JH updated after the lesson.

In [None]:
cb = cbs[0]  # the CompletionCB instance created above

`getattr` returns a bound method, which `run_cbs` then calls; here we call it ourselves, passing `None` as `learn`.

In [None]:
# same as cb.before_fit(None): fetch the method by name, then call it (resets count to 0)
getattr(cb,'before_fit')(None)

In [None]:
# same as cb.after_batch(None): count goes from 0 to 1
getattr(cb,'after_batch')(None)

In [None]:
# same as cb.after_fit(None): prints the count
getattr(cb,'after_fit')(None)

Completed 1 batches


As expected, `after_batch` ran only once, so `self.count` is 1.

:::{.callout-note}
#### **with_cbs**
The decorators used in the `Learner` class are explained in Lesson 15, around 1:28.
`o` in the code below is the `self` of the decorated method, i.e. the `Learner` instance the method is called on (not the `Learner` class itself).

:::

In [None]:
class with_cbs:
    """Decorator factory that surrounds a method with callback hooks.

    `@with_cbs('batch')` turns `f(o, ...)` into a wrapper that, with `o` being the
    object the method is called on (its `self`, e.g. a Learner):
      1. calls `o.callback('before_batch')`,
      2. runs `f` and keeps its return value,
      3. calls `o.callback('after_batch')` and returns `f`'s result.
    It always calls `o.callback('cleanup_batch')` last.

    If `CancelBatchException` is raised during steps 1-3, the remaining steps are
    skipped, the exception is swallowed, and the wrapper returns None. The exception
    class is looked up by name in this module's globals. Any other exception
    propagates unchanged.
    """
    def __init__(self, nm): self.nm = nm

    def __call__(self, f):
        @wraps(f)  # keep f's __name__/__doc__ on the wrapper
        def _f(o, *args, **kwargs):
            try:
                o.callback(f'before_{self.nm}')
                res = f(o, *args, **kwargs)
                o.callback(f'after_{self.nm}')
                return res
            except BaseException as e:
                # Resolve the cancel exception only once something was raised: it may be
                # defined after this decorator is applied. If it does not exist (yet),
                # re-raise the real error instead of masking it with a KeyError.
                cancel_exc = globals().get(f'Cancel{self.nm.title()}Exception')
                if cancel_exc is None or not isinstance(e, cancel_exc): raise
            finally: o.callback(f'cleanup_{self.nm}')
        return _f

In [None]:
import torch.nn.functional as F
from torch import optim

### **Updated version of the Learner class**
This is the updated version of the Learner class as it appears in the original notebook. I skipped the earlier versions shown in the video and the earlier examples.

In [None]:
import torch

In [None]:
# This only constructs a no_grad object (see the repr below); nothing is entered, so gradient
# tracking is unaffected. It takes effect only via `with torch.no_grad():` or as a decorator.
torch.no_grad()

<torch.autograd.grad_mode.no_grad>

In [None]:
# #|export
# class Learner():
#     def __init__(self, model, dls=(0,), loss_func=F.mse_loss, lr=0.1, cbs=None, opt_func=optim.SGD):
#         #check below for fc.L(cbs)
#         cbs = fc.L(cbs)
#         fc.store_attr()

#     @with_cbs('batch')
#     def _one_batch(self):
#         self.predict()
#         self.callback('after_predict')
#         self.get_loss()
#         self.callback('after_loss')
#         if self.training:
#             self.backward()
#             self.callback('after_backward')
#             self.step()
#             self.callback('after_step')
#             self.zero_grad()

#     @with_cbs('epoch')
#     def _one_epoch(self):
#         for self.iter,self.batch in enumerate(self.dl): self._one_batch()

#     def one_epoch(self, training):
#         self.model.train(training) # this 'train' comes from torch/nn/modules/module.py
#         self.dl = self.dls.train if training else self.dls.valid
#         self._one_epoch()

#     @with_cbs('fit')
#     def _fit(self, train, valid):
#         for self.epoch in self.epochs:
#             if train: self.one_epoch(True)
#             if valid: torch.no_grad()(self.one_epoch)(False) #wow what is going on here? No `with` statement.

#     def fit(self, n_epochs=1, train=True, valid=True, cbs=None, lr=None):
#         print(cbs)
#         #this is changes the 'None' to empty list but makes the code a bit less readable for me.
#         cbs = fc.L(cbs)
#         print(cbs)
#         # `add_cb` and `rm_cb` were added in lesson 18
#         for cb in cbs: self.cbs.append(cb)
#         try:
#             self.n_epochs = n_epochs # is it redundant?
#             self.epochs = range(n_epochs)
#             if lr is None: lr = self.lr
#             if self.opt_func: self.opt = self.opt_func(self.model.parameters(), lr)
#             self._fit(train, valid)
#         finally:
#             for cb in cbs: self.cbs.remove(cb)

#     def __getattr__(self, name):
#         if name in ('predict','get_loss','backward','step','zero_grad'): return partial(self.callback, name)
#         raise AttributeError(name)

#     def callback(self, method_nm): run_cbs(self.cbs, method_nm, self)
    
#     @property
#     def training(self): return self.model.training

In [None]:
#|export
class Learner():
    """Minimal training loop in which callbacks do the actual work.

    The loop is fit -> for each of `n_epochs`: (training epoch, validation epoch) -> batches.
    `with_cbs` wraps the fit, epoch and batch levels, so callbacks get before_*/after_*/
    cleanup_* hooks and can cancel a level with the matching Cancel*Exception.
    The per-batch steps (predict, get_loss, backward, step, zero_grad) are not implemented
    here: `__getattr__` forwards them to every callback in `self.cbs` that defines them.
    `dls` must provide `.train` and `.valid` (read in `one_epoch`).

    NOTE(review): `fc` and `partial` are only imported in a later cell. That works because
    both are looked up when the methods run, not when the class is defined, but importing
    them before this cell would make Restart & Run All less fragile.
    """
    def __init__(self, model, dls=(0,), loss_func=F.mse_loss, lr=0.1, cbs=None, opt_func=optim.SGD):
        # fc.L wraps `cbs` in a list-like container (None becomes an empty one),
        # so `fit` can always append/remove callbacks.
        cbs = fc.L(cbs)
        # store_attr copies the __init__ arguments onto self (self.model, self.dls, ...).
        # It reads them from this frame, so self.cbs receives the fc.L built just above.
        fc.store_attr()

    def fit(self, n_epochs=1, train=True, valid=True, cbs=None, lr=None):
        """Run `n_epochs` epochs of training and/or validation.

        `cbs` are extra callbacks active only for this call: they are appended to `self.cbs`
        and removed again in the `finally` block, even if fitting raises.
        `lr` overrides `self.lr` for this call only (`self.lr` itself is not changed).
        """
        # fc.L turns None into an empty list, which keeps the loops below simple
        cbs = fc.L(cbs)
        # `add_cb` and `rm_cb` were added in lesson 18
        for cb in cbs: self.cbs.append(cb)
        try:
            self.n_epochs = n_epochs # not read by Learner itself; presumably kept for callbacks that need the epoch count
            self.epochs = range(n_epochs)
            if lr is None: lr = self.lr
            # a fresh optimizer is built on every fit call, over all model parameters
            if self.opt_func: self.opt = self.opt_func(self.model.parameters(), lr)
            self._fit(train, valid)
        finally:
            for cb in cbs: self.cbs.remove(cb)

    # the whole loop is wrapped in before_fit / after_fit / cleanup_fit hooks
    @with_cbs('fit')
    def _fit(self, train, valid):
        # self.epoch is the loop variable, so callbacks can read the current epoch number
        for self.epoch in self.epochs:
            if train: self.one_epoch(True)
            # torch.no_grad() also works as a decorator: torch.no_grad()(fn) returns fn wrapped so it
            # runs with gradient tracking off, same as `with torch.no_grad(): self.one_epoch(False)`
            if valid: torch.no_grad()(self.one_epoch)(False)

    def one_epoch(self, training):
        # nn.Module.train(mode) switches train/eval mode; the `training` property reads it back
        self.model.train(training) # this 'train' comes from torch/nn/modules/module.py
        self.dl = self.dls.train if training else self.dls.valid
        self._one_epoch()

    @with_cbs('epoch')
    def _one_epoch(self):
        # iter/batch are stored on self so callbacks can access the current batch
        for self.iter,self.batch in enumerate(self.dl): self._one_batch()

    @with_cbs('batch')
    def _one_batch(self):
        # each step is forwarded to callbacks by __getattr__; the explicit after_* calls
        # give callbacks extra hook points between the steps
        self.predict()
        self.callback('after_predict')
        self.get_loss()
        self.callback('after_loss')
        # backward/step/zero_grad run only in train mode, never during validation
        if self.training:
            self.backward()
            self.callback('after_backward')
            self.step()
            self.callback('after_step')
            self.zero_grad()

    def __getattr__(self, name):
        # only reached when normal attribute lookup fails: the five step names become
        # self.callback(name), i.e. run that method on every callback that defines it
        if name in ('predict','get_loss','backward','step','zero_grad'): return partial(self.callback, name)
        raise AttributeError(name)

    # run hook `method_nm` on all callbacks (sorted by `order`), passing this learner
    def callback(self, method_nm): run_cbs(self.cbs, method_nm, self)

    # True while the model is in train mode, False after model.train(False)
    @property
    def training(self): return self.model.training

In [None]:
from torch import nn,tensor
from datasets import load_dataset
from miniai.datasets import *
import fastcore.all as fc
import torch
import torchvision.transforms.functional as TF
from functools import partial

In [None]:
# print tensors with 2 decimals, wide lines, and no scientific notation
torch.set_printoptions(precision=2, linewidth=140, sci_mode=False)

#### **To test CompletionCB we need a dataset**

In [None]:
# column names in the Hugging Face dataset; `x` is also used by `transformi` below
x,y = 'image','label'
name = "fashion_mnist"
# presumably a DatasetDict holding the train/test splits (loaded from the local cache, per the output)
dsd = load_dataset(name)

Found cached dataset fashion_mnist (/home/niyazi/.cache/huggingface/datasets/fashion_mnist/fashion_mnist/1.0.0/0a671f063342996f19779d38c0ab4abef9c64f757b35af8134b331c294d7ba48)


  0%|          | 0/2 [00:00<?, ?it/s]

___
**this is explained previously ---->>>**

In [None]:
@inplace
def transformi(b):
    """Replace column `x` of batch `b` with flattened tensors, one per image.

    `b` is modified in place and nothing is returned. Each item is presumably a PIL
    image, which `TF.to_tensor` converts before `torch.flatten` makes it 1-D.
    """
    images = b[x]
    b[x] = [torch.flatten(TF.to_tensor(img)) for img in images]

In [None]:
bs = 1024  # batch size
# with_transform applies transformi on the fly whenever items are accessed
tds = dsd.with_transform(transformi)

**<<<---- this is explained previously**
___

In [None]:
# expected to expose .train and .valid loaders (Learner.one_epoch reads both)
dls = DataLoaders.from_dd(tds, bs, num_workers=4)

In [None]:
m, nh = 28 * 28, 50  # input features (a flattened 28x28 image) and hidden units

def get_model():
    """Return a fresh two-layer MLP: m inputs -> nh hidden units (ReLU) -> 10 outputs."""
    layers = [nn.Linear(m, nh), nn.ReLU(), nn.Linear(nh, 10)]
    return nn.Sequential(*layers)

#### **After one epoch, 64 batches have run (training and validation batches combined)**

In [None]:
model = get_model()
# CompletionCB only: without TrainCB, predict/get_loss/... reach no callback and do nothing
learn = Learner(model,dls,F.cross_entropy, lr =0.2, cbs=[CompletionCB()])

In [None]:
learn.fit(1)  # count is reset only in before_fit, so it covers training and validation batches

Completed 64 batches


#### **Now let's use exceptions**
The exceptions below all inherit from `Exception`; `with_cbs` catches them by name to cancel a fit, epoch or batch.

In [None]:
class CancelFitException(Exception):
    """Raise from a callback to abort the whole fit (caught by the `with_cbs('fit')` wrapper)."""

class CancelBatchException(Exception):
    """Raise from a callback to abort the current batch (caught by the `with_cbs('batch')` wrapper)."""

class CancelEpochException(Exception):
    """Raise from a callback to abort the current epoch (caught by the `with_cbs('epoch')` wrapper)."""

`SingleBatchCB` raises `CancelEpochException()` after the first batch.

In [None]:
class SingleBatchCB(Callback):
    """Cancels the current epoch after its first batch, so each training and validation pass runs one batch."""
    # order 1 runs after order-0 callbacks (e.g. CompletionCB), so they still see the batch
    order = 1

    def after_batch(self, learn):
        raise CancelEpochException()

In [None]:
model = get_model()
# SingleBatchCB (order 1) runs after CompletionCB (order 0), so each batch is counted before the epoch is cancelled
learn = Learner(model,dls,F.cross_entropy, lr =0.2, cbs=[SingleBatchCB(),CompletionCB()])

In [None]:
learn.fit(1)  # expect 2: one training batch plus one validation batch

Completed 2 batches


#### **It worked**
Two batches ran: one for training and one for validation.

#### **Metrics**
---
**This is for explanation only; it is not used in the code.**

In [None]:
#from pdb import set_trace

In [None]:
class Metric:
    """Running, weighted average of per-batch values (explanation-only example).

    `add` stores `calc(inp, targ)` together with a weight `n`; `value` returns
    sum(val * n) / sum(n). The base `calc` passes `inp` through unchanged, so a plain
    `Metric` averages the numbers it is given (e.g. batch losses). Subclasses override
    `calc` to derive the per-batch value from predictions and targets.
    """
    def __init__(self): self.reset()

    def reset(self):
        self.vals = []
        self.ns = []

    def add(self, inp, targ=None, n=1):
        self.last = self.calc(inp, targ)
        self.vals.append(self.last)
        self.ns.append(n)

    @property
    def value(self):
        weights = tensor(self.ns)
        weighted_total = (tensor(self.vals) * weights).sum()
        return weighted_total / weights.sum()

    def calc(self, inps, targs): return inps

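Honestly, I don't fully get this polymorphism — why and how the loss is calculated this way. The idea seems to be: `Metric.calc` returns its input unchanged, so a plain `Metric` just averages the values you `add` (e.g. batch losses), weighted by `n`. `Accuracy` overrides `calc` to turn predictions and targets into a per-batch accuracy first. It is not used by the learner; it is only here for the sake of explanation.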

In [None]:
class Accuracy(Metric):
    """Per-batch accuracy: the fraction of positions where predictions equal targets."""
    def calc(self, inps, targs):
        hits = inps.eq(targs)
        return hits.float().mean()

In [None]:
# acc = Accuracy()
# acc.add(tensor([0, 1, 2, 0, 1, 2]), tensor([0, 1, 1, 2, 1, 0]))
# acc.add(tensor([1, 1, 2, 0, 1]), tensor([0, 1, 1, 2, 1]))
# acc.value

tensor(0.45)

In [None]:
# loss = Metric()
# loss.add(0.6, n=32)
# loss.add(0.9, n=2)
# loss.value, round((0.6*32+0.9*2)/(32+2), 2)

(tensor(0.62), 0.62)

**This is for explanation only; it is not used in the code.**

---

#### **Add MetricsCB**

In [None]:
def to_cpu(x):
    """Recursively detach tensors and move them to the CPU.

    Mappings become plain dicts, lists stay lists and tuples stay tuples, with every
    element converted recursively. Leaves without a `.detach()` method (ints, strings,
    None, ...) are returned unchanged instead of raising AttributeError.
    """
    if isinstance(x, Mapping): return {k: to_cpu(v) for k, v in x.items()}
    if isinstance(x, list): return [to_cpu(o) for o in x]
    if isinstance(x, tuple): return tuple(to_cpu(o) for o in x)
    # non-tensor leaves (e.g. labels stored as strings) have nothing to detach
    if not hasattr(x, 'detach'): return x
    return x.detach().cpu()

`learn.preds` (and `learn.loss`) are created in `TrainCB`, so `MetricsCB` only works when `TrainCB` (or another callback that sets them) is used as well.

In [None]:
class MetricsCB(Callback):
    """Tracks metrics and a weighted mean loss, printing a summary after every epoch.

    Metrics are objects with `reset`/`update`/`compute` (e.g. torcheval metrics).
    Positional metrics are stored under their class name, keyword metrics under their
    keyword. `after_batch` reads `learn.batch`, `learn.preds` and `learn.loss`, so
    another callback (such as TrainCB) has to set `preds` and `loss`.
    """
    def __init__(self, *ms, **metrics):
        for metric in ms:
            metrics[type(metric).__name__] = metric
        self.metrics = metrics
        # shallow copy, so adding 'loss' below does not touch self.metrics
        self.all_metrics = copy(metrics)
        self.loss = Mean()
        self.all_metrics['loss'] = self.loss

    def _log(self, d): print(d)

    def before_fit(self, learn): learn.metrics = self

    def before_epoch(self, learn):
        # every epoch (training and validation alike) starts from a clean slate
        for metric in self.all_metrics.values():
            metric.reset()

    def after_epoch(self, learn):
        summary = {nm: f'{metric.compute():.3f}' for nm, metric in self.all_metrics.items()}
        summary['epoch'] = learn.epoch
        summary['train'] = 'train' if learn.model.training else 'eval'
        self._log(summary)

    def after_batch(self, learn):
        xb, yb, *_ = to_cpu(learn.batch)
        for metric in self.metrics.values():
            metric.update(to_cpu(learn.preds), yb)
        # loss weighted by len(xb), i.e. the number of items when inputs are batch-first
        self.loss.update(to_cpu(learn.loss), weight=len(xb))

In [None]:
from torcheval.metrics import MulticlassAccuracy,Mean
from copy import copy
from miniai.conv import * #this is for def_device
from collections.abc import Mapping

In [None]:
class DeviceCB(Callback):
    """Moves the model (once, before fitting) and every batch onto `device`."""
    # store_attr saves the `device` argument as self.device
    def __init__(self, device=def_device): fc.store_attr()

    def before_fit(self, learn):
        model = learn.model
        if hasattr(model, 'to'):  # models without a .to() method are left alone
            model.to(self.device)

    def before_batch(self, learn):
        learn.batch = to_device(learn.batch, device=self.device)

Note: around 32:00 in Lesson 16, JH talks about using Hugging Face–style dictionary batches and Accelerator in the context of this callback.

In [None]:
class TrainCB(Callback):
    """Implements the per-batch training steps that Learner forwards to its callbacks.

    The first `n_inp` items of `learn.batch` are the model inputs; the remaining items
    are passed to the loss function after the predictions.
    """
    def __init__(self, n_inp=1): self.n_inp = n_inp

    def predict(self, learn):
        inputs = learn.batch[:self.n_inp]
        learn.preds = learn.model(*inputs)

    def get_loss(self, learn):
        targets = learn.batch[self.n_inp:]
        learn.loss = learn.loss_func(learn.preds, *targets)

    def backward(self, learn): learn.loss.backward()

    def step(self, learn): learn.opt.step()

    def zero_grad(self, learn): learn.opt.zero_grad()

In [None]:
model = get_model()
# a keyword metric is logged under its keyword, here 'accuracy'
metrics = MetricsCB(accuracy=MulticlassAccuracy())
# TrainCB does predict/loss/backward/step; DeviceCB moves the model and batches to def_device
learn = Learner(model, dls, F.cross_entropy, lr=0.2, cbs=[TrainCB(),DeviceCB(), metrics])

In [None]:
learn.fit(1)  # MetricsCB prints one line per epoch pass: training first, then validation ('eval')

{'accuracy': '0.599', 'loss': '1.210', 'epoch': 0, 'train': 'train'}
{'accuracy': '0.707', 'loss': '0.808', 'epoch': 0, 'train': 'eval'}


In [None]:
class MetricsCB_bits(Callback):
    """Stripped-down copy of MetricsCB.__init__, used to inspect how *ms / **metrics are collected."""
    def __init__(self, *ms, **metrics):
        # shows that positional metrics arrive as a tuple and keyword metrics as a dict
        print(f'ms:{ms} type:{type(ms)},-- metrics:{metrics} type:{type(metrics)}')
        # keep the raw positional metrics: later cells inspect `metrics_bits.ms`,
        # which otherwise fails with AttributeError on a fresh kernel
        self.ms = ms
        for o in ms:
            metrics[type(o).__name__] = o
        self.metrics = metrics
        # shallow copy, so the 'loss' entry below is added only to all_metrics
        self.all_metrics = copy(metrics)
        self.all_metrics['loss'] = self.loss = Mean()
        

In [None]:
# a positional metric is stored under its class name, 'MulticlassAccuracy'
metrics_bits = MetricsCB_bits(MulticlassAccuracy())

ms:(<torcheval.metrics.classification.accuracy.MulticlassAccuracy object>,) type:<class 'tuple'>,-- metrics:{} type:<class 'dict'>


In [None]:
metrics_bits.metrics  # only the metrics passed in; 'loss' is added to the copy in all_metrics

{'MulticlassAccuracy': <torcheval.metrics.classification.accuracy.MulticlassAccuracy>}

In [None]:
# all_metrics holds the passed-in metrics plus the 'loss' Mean added in __init__
print(metrics_bits.all_metrics.values())

dict_values([<torcheval.metrics.classification.accuracy.MulticlassAccuracy object>, <torcheval.metrics.aggregation.mean.Mean object>])


In [None]:
metrics_bits.ms  # relies on MetricsCB_bits.__init__ storing the positional metrics as self.ms

(<torcheval.metrics.classification.accuracy.MulticlassAccuracy>,)

In [None]:
# print each positional metric (also relies on self.ms being stored)
for i in metrics_bits.ms:
    print(i)
    

<torcheval.metrics.classification.accuracy.MulticlassAccuracy object>


In [None]:
metrics_bits.all_metrics.values()  # same content as the print above, shown via the notebook's rich display

dict_values([<torcheval.metrics.classification.accuracy.MulticlassAccuracy object>, <torcheval.metrics.aggregation.mean.Mean object>])

In [None]:
# NOTE(review): `o.reset` without parentheses only collects the bound methods (see the output);
# nothing is reset. MetricsCB.before_epoch actually calls `o.reset()`.
[o.reset for o in metrics_bits.all_metrics.values()]

[<bound method Metric.reset of <torcheval.metrics.classification.accuracy.MulticlassAccuracy object>>,
 <bound method Metric.reset of <torcheval.metrics.aggregation.mean.Mean object>>]