In [None]:
#| default_exp learner

In [None]:
#| export
import math, torch, matplotlib.pyplot as plt

import fastcore.all as fc
from collections.abc import Mapping
from operator import attrgetter
from functools import partial
from copy import copy

from torch import optim
import torch.nn.functional as F

from fastai_course.conv import *

from fastprogress import progress_bar,master_bar

In [None]:
import warnings
# NOTE(review): this silences *every* warning for the rest of the kernel session,
# including deprecation warnings that may point at real problems; consider
# filtering a specific category/module instead.
warnings.filterwarnings('ignore')

In [None]:
import matplotlib as mpl
import torchvision.transforms.functional as TF
from contextlib import contextmanager
from torch import nn,tensor
from datasets import load_dataset,load_dataset_builder
from fastai_course.datasets import *
import logging
from fastcore.test import test_close

In [None]:
# Compact tensor printing, a fixed torch RNG seed for repeatable runs,
# and grayscale as the default image colormap.
torch.set_printoptions(precision=2, linewidth=140, sci_mode=False)
torch.manual_seed(1)
mpl.rcParams['image.cmap'] = 'gray'

In [None]:
# Disable all logging calls of severity ERROR and below (presumably to quiet library logs).
logging.disable(logging.ERROR)

In [None]:
# Column names in the Hugging Face dataset. `x` and `y` are module-level globals;
# `transformi` below reads `x`.
x,y = 'image','label'
name = 'fashion_mnist'
dsd = load_dataset(name)

In [None]:
dsd['train'][0]

In [None]:
@inplace
def transformi(b):
    "Turn every image in column `x` of batch `b` into a flattened 1-D float tensor, mutating `b`."
    # assumes the column holds PIL images (what `TF.to_tensor` accepts) -- TODO confirm
    images = b[x]
    b[x] = [TF.to_tensor(img).flatten() for img in images]

In [None]:
bs = 1024
# Attach `transformi` so items come back with flattened image tensors.
tds = dsd.with_transform(transformi)

In [None]:
# A single transformed item: the image is now a flat 1-D tensor.
tds['train'][0]['image'].shape

In [None]:
# presumably builds train/valid loaders from the DatasetDict splits -- TODO confirm in DataLoaders.from_dd
dls = DataLoaders.from_dd(tds, bs, num_workers=4)
dt = dls.train
xb, yb = next(iter(dt))
xb.shape, yb.shape, yb[:10]

In [None]:
#|export
class CancelFitException(Exception):
    "Raise from a callback to abort the whole `fit` call."

class CancelBatchException(Exception):
    "Raise from a callback to skip the rest of the current batch."

class CancelEpochException(Exception):
    "Raise from a callback to end the current epoch early."

In [None]:
#|export
class Callback:
    "Base class for training callbacks; `order` sets run priority (lower values run first)."
    order = 0

In [None]:
a = Callback()

In [None]:
# `attrgetter('order')` is the sort key `run_cbs` uses; here it reads the default order (0).
attrgetter('order')(a)

In [None]:
#|export
def run_cbs(cbs, method_nm, learn=None):
    "Call `method_nm(learn)` on each callback in ascending `order`; callbacks lacking the method are skipped."
    ordered = sorted(cbs, key=attrgetter('order'))
    # generator keeps lookups lazy: each method is fetched right before it is called
    for handler in (getattr(cb, method_nm, None) for cb in ordered):
        if handler is None:
            continue
        handler(learn)

In [None]:
class CompletionCB(Callback):
    "Count processed batches and report the total when fitting finishes."
    def before_fit(self, learn):
        self.count = 0

    def after_batch(self, learn):
        self.count = self.count + 1

    def after_fit(self, learn):
        print(f'completed {self.count} batches')

In [None]:
# Drive the callback by hand, without a Learner: `learn` defaults to None, which
# CompletionCB never reads. Prints "completed 1 batches".
cbs = [CompletionCB()]
run_cbs(cbs, 'before_fit')
run_cbs(cbs, 'after_batch')
run_cbs(cbs, 'after_fit')

In [None]:
class Learner():
    "Minimal training loop that announces each stage to callbacks via `run_cbs`."
    def __init__(self, model, dls, loss_func, lr, cbs, opt_func=optim.SGD): fc.store_attr()

    def one_batch(self):
        "Forward pass and loss; in training mode also backprop, step and zero the gradients."
        self.preds = self.model(self.batch[0])
        self.loss = self.loss_func(self.preds, self.batch[1])
        if self.model.training:
            self.loss.backward()
            self.opt.step()
            self.opt.zero_grad()

    def one_epoch(self, train):
        "Run one pass over `dls.train` (train=True) or `dls.valid`, firing epoch/batch callbacks."
        self.model.train(train)
        self.dl = self.dls.train if train else self.dls.valid
        try:
            self.callback('before_epoch')
            for self.iter, self.batch in enumerate(self.dl):
                try:
                    self.callback('before_batch')
                    self.one_batch()
                    self.callback('after_batch')
                except CancelBatchException: pass
            self.callback('after_epoch')
        except CancelEpochException: pass

    def fit(self, n_epochs):
        "Create the optimizer, then alternate a training and a validation epoch `n_epochs` times."
        self.n_epochs = n_epochs
        self.epochs = range(n_epochs)
        self.opt = self.opt_func(self.model.parameters(), self.lr)
        try:
            self.callback('before_fit')
            for self.epoch in self.epochs:
                self.one_epoch(True)
                # Validation never calls backward(), so don't build autograd graphs
                # (saves memory and time).
                with torch.no_grad(): self.one_epoch(False)
            self.callback('after_fit')
        except CancelFitException: pass

    def callback(self, method_name):
        "Invoke `method_name` on every callback (ascending `order`), passing this learner."
        run_cbs(self.cbs, method_name, self)

In [None]:
# m: length of a flattened 28x28 image (model input size); nh: hidden units.
m, nh = 28*28, 50
model = nn.Sequential(
    nn.Linear(m, nh),
    nn.ReLU(),
    nn.Linear(nh, 10)
)

In [None]:
# One epoch with the first Learner; CompletionCB prints the batch count at the end.
learn = Learner(model, dls, F.cross_entropy, lr=0.2, cbs=[CompletionCB()])
learn.fit(1)

In [None]:
#| export
class SingleBatchCB(Callback):
    "Abort the whole `fit` right after the first batch (handy for quick smoke tests)."
    order = 1  # runs after default-order (0) callbacks, so their `after_batch` still fires

    def after_batch(self, learn):
        message = 'stop training'
        print(message)
        raise CancelFitException

In [None]:
# CompletionCB (order 0) counts the first batch, then SingleBatchCB (order 1) raises
# CancelFitException. `after_fit` is skipped, so only "stop training" is printed.
learn = Learner(model, dls, F.cross_entropy, lr=0.2, cbs=[CompletionCB(), SingleBatchCB()])
learn.fit(1)

In [None]:
#| export
from torcheval.metrics import MulticlassAccuracy, Mean, BinaryNormalizedEntropy

In [None]:
metric = MulticlassAccuracy()
# (predicted labels, targets): 2 of the 4 predictions match, so expect 0.5.
metric.update(tensor([0, 2, 1, 3]), tensor([0, 1, 2, 3]))
metric.compute()

In [None]:
# After reset() nothing is accumulated; compute() shows the metric's empty state
# (torcheval may emit a warning here).
metric.reset()
metric.compute()

In [None]:
#| export
def to_cpu(x):
    "Recursively detach tensors in `x` and move them to the CPU, keeping dict/list/tuple nesting; fp16 is upcast to fp32."
    if isinstance(x, Mapping):
        return {key: to_cpu(val) for key, val in x.items()}
    if isinstance(x, (list, tuple)):
        converted = [to_cpu(item) for item in x]
        return converted if isinstance(x, list) else tuple(converted)
    out = x.detach().cpu()
    if out.dtype == torch.float16:
        out = out.float()
    return out

In [None]:
#| export
class MetricsCB(Callback):
    "Accumulate metrics plus a batch-size-weighted mean loss over each epoch, then log them."
    def __init__(self, *ms, **metrics):
        # Positional metrics are keyed by their class name; keyword metrics keep their given name.
        for o in ms: metrics[type(o).__name__] = o
        self.metrics = metrics
        # `all_metrics` = user metrics + running loss; only `metrics` receive (preds, targets).
        self.all_metrics = copy(metrics)
        self.all_metrics['loss'] = self.loss = Mean()

    def _log(self, d): print(d)
    # Expose ourselves on the learner so other callbacks (e.g. ProgressCB) can find us.
    def before_fit(self, learn): learn.metrics = self

    def before_epoch(self, learn):
        for o in self.all_metrics.values(): o.reset()

    def after_epoch(self, learn):
        log = {k: f'{v.compute():.3f}' for k,v in self.all_metrics.items()}
        log['epoch'] = learn.epoch
        log['train'] = 'train' if learn.model.training else 'eval'
        self._log(log)

    def after_batch(self, learn):
        x,y,*_ = to_cpu(learn.batch)
        preds = to_cpu(learn.preds)  # convert once, not once per metric
        for m in self.metrics.values(): m.update(preds, y)
        # Weight each batch loss by its size so the epoch value is a per-sample mean.
        self.loss.update(to_cpu(learn.loss), weight=len(x))

In [None]:
#| export
class DeviceCB(Callback):
    "Move the model (once, when fitting starts) and every batch onto `device`."
    def __init__(self, device=def_device):
        fc.store_attr()

    def before_fit(self, learn):
        model = learn.model
        if not hasattr(model, 'to'):
            return
        model.to(self.device)

    def before_batch(self, learn):
        learn.batch = to_device(learn.batch, device=self.device)

In [None]:
def get_model():
    "Build a fresh MLP: `m` inputs -> `nh` ReLU hidden units -> 10 outputs (reads module-level `m`, `nh` at call time)."
    layers = [nn.Linear(m, nh), nn.ReLU(), nn.Linear(nh, 10)]
    return nn.Sequential(*layers)

In [None]:
model = get_model()
metrics = MetricsCB(accuracy=MulticlassAccuracy())

In [None]:
# Still the first (non-exported) Learner; MetricsCB prints one dict per train/eval epoch.
learn = Learner(model, dls, F.cross_entropy, lr=0.2, cbs=[DeviceCB(), metrics])
learn.fit(5)

In [None]:
import contextlib

In [None]:
@contextlib.contextmanager
def simple_context_manager():
    "Print 'Enter' on entry and 'Exit' on exit, even when the `with` body raises."
    print('Enter')
    try:
        yield
    finally:
        # An exception in the with-body is re-raised at the `yield`; without
        # try/finally the cleanup below would be skipped.
        print('Exit')

with simple_context_manager():
    print('Inside the with block')

In [None]:
@contextlib.contextmanager
def buggy_context_manager():
    "Deliberately broken demo: raises before it ever yields."
    print("Enter")
    # Raising before `yield` makes entering the `with` fail: the with-body and
    # everything after `yield` never run, and the exception reaches the caller.
    raise Exception("Exception before yield")
    yield  # unreachable, but required: it makes this a generator function
    print("Exit")

try:
    with buggy_context_manager():
        print("Inside the with block")
except Exception as e:
    # Prints: Caught exception: Exception before yield
    print(f"Caught exception: {e}")

In [None]:
# Peek at one training batch; b[0] holds the flattened images, b[1] the labels.
b = next(iter(dls.train))

In [None]:
b[0].shape

In [None]:
#| export
class TrainerCB(Callback):
    "Supply the training steps for a callback-driven Learner: the first `n_inp` batch items feed the model, the rest go to `loss_func`."
    def __init__(self, n_inp=1): fc.store_attr()

    def predict(self, learn):
        inputs = learn.batch[:self.n_inp]
        learn.preds = learn.model(*inputs)

    def get_loss(self, learn):
        targets = learn.batch[self.n_inp:]
        learn.loss = learn.loss_func(learn.preds, *targets)

    def backward(self, learn): learn.loss.backward()

    def step(self, learn): learn.opt.step()

    def zero_grad(self, learn): learn.opt.zero_grad()

In [None]:
t = TrainerCB()
t.n_inp

In [None]:
# fastcore `L.range` + `map`: 0, 2, ..., 38. ProgressCB uses the same pattern for plot x-positions.
fc.L.range(20).map(lambda x: x * 2)

In [None]:
#|export
class ProgressCB(Callback):
    "Show fastprogress bars, write MetricsCB logs as a table and optionally live-plot train/valid loss."
    order = MetricsCB.order + 1  # run after MetricsCB so `learn.metrics` already exists
    def __init__(self, plot=False): fc.store_attr()

    def before_fit(self, learn):
        # Wrapping `learn.epochs` lets the Learner's epoch loop drive the master bar.
        learn.epochs = self.mbar = master_bar(learn.epochs)
        self.first = True
        if hasattr(learn, 'metrics'): learn.metrics._log = self._log
        self.train_losses, self.val_losses = [], []

    def _log(self, d):
        "Write `d` as a table row, emitting its keys as the header before the first row."
        if self.first:
            self.mbar.write(list(d), table=True)
            self.first = False
        self.mbar.write(list(d.values()), table=True)

    def before_epoch(self, learn): learn.dl = progress_bar(learn.dl, leave=False, parent=self.mbar)

    def after_batch(self, learn):
        learn.dl.comment = f"{learn.loss:.3f}"
        if self.plot and hasattr(learn, 'metrics') and learn.training:
            self.train_losses.append(learn.loss.item())
            # Only redraw once at least one validation point exists.
            if self.val_losses: self._update_graph(learn)

    def after_epoch(self, learn):
        if not learn.training and self.plot and hasattr(learn, 'metrics'):
            self.val_losses.append(learn.metrics.all_metrics['loss'].compute())
            self._update_graph(learn)

    def _update_graph(self, learn):
        "Plot per-batch train loss and per-epoch valid loss on a shared batch-count x-axis."
        # The i-th validation loss sits at the end of the (i+1)-th training epoch. x is
        # derived from len(val_losses), not learn.epoch, so x and y always have equal
        # length, e.g. when a validation epoch was cancelled before `after_epoch`.
        n_train = len(learn.dls.train)
        val_x = fc.L.range(self.val_losses).map(lambda i: (i+1) * n_train)
        self.mbar.update_graph([
            [fc.L.range(self.train_losses), self.train_losses],
            [val_x, self.val_losses],
        ])

In [None]:
#|export
class with_cbs:
    "Decorator: wrap a method with `before_{nm}` / `after_{nm}` / `cleanup_{nm}` callbacks on its owner."
    def __init__(self, nm): self.nm = nm

    def __call__(self, f):
        from functools import wraps

        @wraps(f)  # keep the wrapped method's __name__/__doc__
        def _f(o, *args, **kwargs):
            res = None
            try:
                o.callback(f"before_{self.nm}")
                res = f(o, *args, **kwargs)
                o.callback(f"after_{self.nm}")
            # `Cancel{Nm}Exception` (e.g. CancelBatchException) quietly aborts this stage,
            # skipping `after_{nm}`. It is looked up by name in this module; if no such class
            # exists, `()` matches nothing, so the original error propagates instead of
            # being masked by a KeyError.
            except globals().get(f'Cancel{self.nm.title()}Exception', ()): pass
            finally: o.callback(f'cleanup_{self.nm}')
            return res

        return _f

In [None]:
# `with_cbs('fit')` is a decorator instance: calling it on something returns the wrapper
# `_f` without running anything ('a' is just a placeholder "function" here).
# NOTE(review): this rebinds `cbs`; later cells reassign it before using it.
cbs = with_cbs('fit')
cbs('a')

In [None]:
class TestClass:
    "Toy host for `with_cbs`: its `callback` just echoes the stage name."
    def callback(self, method_name):
        print(method_name)

    @with_cbs('batch')
    def _one_batch(self):
        print('test callback')

In [None]:
t = TestClass()
# Prints: before_batch, test callback, after_batch, cleanup_batch
t._one_batch()

In [None]:
#| export
class Learner:
    "Callback-driven training loop: fit/epoch/batch stages fire callbacks, and the training steps come from callbacks or subclass methods."
    def __init__(self, model, dls=(0,), loss_func=F.cross_entropy, cbs=[], lr=0.1, opt_func=optim.SGD):
        self.model = model
        self.dls = dls
        self.loss_func = loss_func
        # Keep a private copy: `fc.L(a_list)` can reuse the same list object, so
        # `fit(cbs=...)` would otherwise temporarily mutate the caller's list (or the
        # shared `cbs=[]` default).
        self.cbs = fc.L(list(fc.L(cbs)))
        self.lr = lr
        self.opt_func = opt_func

    @with_cbs('batch')
    def _one_batch(self):
        self.predict()
        self.callback('after_predict')
        self.get_loss()
        self.callback('after_loss')
        if self.training:
            self.backward()
            self.callback('after_backward')
            self.step()
            self.callback('after_step')
            self.zero_grad()

    @with_cbs('epoch')
    def _one_epoch(self):
        for self.iter, self.batch in enumerate(self.dl): self._one_batch()

    def one_epoch(self, training):
        "Run a single training (`training=True`) or validation epoch."
        self.model.train(training)
        self.dl = self.dls.train if training else self.dls.valid
        self._one_epoch()

    @with_cbs('fit')
    def _fit(self, train, valid):
        for self.epoch in self.epochs:
            if train: self.one_epoch(True)
            if valid:
                with torch.no_grad(): self.one_epoch(False)

    def fit(self, n_epochs=1, train=True, valid=True, cbs=None, lr=None):
        "Train for `n_epochs`; `cbs` are attached for this call only, and `lr` overrides `self.lr` for this call."
        cbs = fc.L(cbs)
        try:
            for cb in cbs:
                self.cbs.append(cb)
            self.epochs = range(n_epochs)
            if lr is None: lr = self.lr
            self.opt = self.opt_func(self.model.parameters(), lr=lr)
            self._fit(train, valid)
        finally:
            for cb in cbs: self.cbs.remove(cb)

    def callback(self, method_name):
        "Invoke `method_name` on every callback (ascending `order`), passing this learner."
        run_cbs(self.cbs, method_name, self)

    def __getattr__(self, name):
        "Route the five training-step names to callbacks; any other missing name is a real AttributeError."
        if name in ('predict', 'get_loss', 'backward', 'step', 'zero_grad'):
            return partial(self.callback, name)
        # Must raise: returning None made `hasattr(learn, <anything>)` always True
        # (breaking ProgressCB's `hasattr(learn, 'metrics')` check) and turned typos
        # into silent Nones.
        raise AttributeError(name)

    @property
    def training(self): return self.model.training

In [None]:
model = get_model()
# TrainerCB supplies predict/get_loss/backward/step/zero_grad for the plain Learner.
# ProgressCB (order 1) runs after MetricsCB (order 0), so it can take over MetricsCB's logging.
cbs = [
    TrainerCB(),
    DeviceCB(),
    MetricsCB(accuracy=MulticlassAccuracy()),
    ProgressCB(plot=True),
    # SingleBatchCB()
]

learn = Learner(model, dls, F.cross_entropy, cbs=cbs, lr=0.2)

In [None]:
learn.fit(1)

In [None]:
dl = dls.train
batch = next(iter(dl))
batch

In [None]:
batch[0].shape, batch[1].shape

In [None]:
# Outside `fit`: `predict` isn't defined on Learner, so it is resolved via __getattr__
# into a callback call. It reuses `learn.batch` left over from the last batch of `fit`.
learn.predict()

In [None]:
# Equivalent to learn.predict(): run `predict` on every callback that has it.
f = partial(run_cbs, cbs, 'predict')

In [None]:
# run_cbs returns nothing, so this prints None.
print(f(learn))
# TODO: dig further into how ProgressCB works.

In [None]:
# Loss of the most recent batch, set by TrainerCB.get_loss.
print(getattr(learn, 'loss'))

In [None]:
#|export
class TrainLearner(Learner):
    "Learner whose training steps are ordinary methods instead of callbacks (batch[0] is the input, batch[1] the target)."
    def predict(self):
        xb = self.batch[0]
        self.preds = self.model(xb)

    def get_loss(self):
        yb = self.batch[1]
        self.loss = self.loss_func(self.preds, yb)

    def backward(self):
        self.loss.backward()

    def step(self):
        self.opt.step()

    def zero_grad(self):
        self.opt.zero_grad()

In [None]:
model = get_model()
# TrainerCB is not needed: TrainLearner defines the step methods itself, so normal
# attribute lookup finds them before Learner.__getattr__ is consulted.
cbs = [
    # TrainerCB(),
    DeviceCB(),
    MetricsCB(accuracy=MulticlassAccuracy()),
    ProgressCB(plot=True),
    # SingleBatchCB()
]

learn = TrainLearner(model, dls, F.cross_entropy, cbs=cbs, lr=0.2)
learn.fit(5)

In [None]:
#|export
class MomentumLearner(TrainLearner):
    "TrainLearner with simple momentum: gradients are scaled by `mom` instead of being zeroed after each step."
    def __init__(self, model, dls=(0,), loss_func=F.cross_entropy, cbs=[], lr=0.1, opt_func=optim.SGD, mom=0.8):
        self.mom = mom
        super().__init__(model, dls, loss_func, cbs, lr, opt_func)

    def zero_grad(self):
        "Decay the stored gradients by `mom`; the next `backward()` accumulates onto them."
        with torch.no_grad():
            for p in self.model.parameters():
                # Parameters that never received a gradient (unused or frozen) have
                # `grad is None`; `None *= mom` would raise TypeError.
                if p.grad is not None: p.grad *= self.mom

In [None]:
model = get_model()
# NOTE(review): this callback list is rebuilt in several cells; a small helper would remove the duplication.
cbs = [
    # TrainerCB(),
    DeviceCB(),
    MetricsCB(accuracy=MulticlassAccuracy()),
    ProgressCB(plot=True),
    # SingleBatchCB()
]

# Default mom=0.8: each step keeps 80% of the previous gradients.
learn = MomentumLearner(model, dls, F.cross_entropy, cbs=cbs, lr=0.2)
learn.fit(5)

In [None]:
#|export
from torch.optim.lr_scheduler import ExponentialLR
import pandas as pd

In [None]:
#|export
class LRFinderCB(Callback):
    "LR range test: multiply the lr by `gamma` after every training batch, record (lr, loss), stop once the loss blows up, and plot the curve."
    def __init__(self, gamma=1.3, max_mult=3):
        fc.store_attr()

    def before_fit(self, learn):
        self.sched = ExponentialLR(learn.opt, self.gamma)
        self.lrs,self.losses = [],[]
        self.min = math.inf

    def after_batch(self, learn):
        # Only training batches matter; abandon any validation epoch immediately.
        if not learn.training:
            raise CancelEpochException()
        # The lr that was used for the batch just processed.
        self.lrs.append(self.sched.get_last_lr()[0])
        loss = learn.loss.item()
        self.losses.append(loss)
        if loss < self.min: self.min = loss
        # Divergence: NaN/inf loss, or more than `max_mult` times the best loss seen so far.
        if not math.isfinite(loss) or loss > self.min * self.max_mult:
            raise CancelFitException()
        # Multiply the lr by `gamma` for the next batch.
        self.sched.step()

    def cleanup_fit(self, learn):
        "Plot loss against lr on a log x-axis; called even when the fit was cancelled."
        # `before_fit` may never have run if an earlier callback aborted the fit.
        lrs, losses = getattr(self, 'lrs', []), getattr(self, 'losses', [])
        if not lrs: return
        _, ax = plt.subplots()
        ax.plot(lrs, losses)
        ax.set_xscale('log')
        ax.set_xlabel('learning rate')
        ax.set_ylabel('loss')

In [None]:
model = get_model()
cbs = [
    # TrainerCB(),
    DeviceCB(),
    # MetricsCB(accuracy=MulticlassAccuracy()),
    # ProgressCB(plot=True),
    # SingleBatchCB()
]

# lr=1e-5 is the starting lr that LRFinderCB grows from; LRFinderCB is attached for this fit only.
learn = MomentumLearner(model, dls, F.cross_entropy, cbs=cbs, lr=1e-5)
learn.fit(1, cbs=LRFinderCB())

In [None]:
# NOTE(review): exact copy of the previous cell. It only re-runs the LR finder on another
# freshly initialised model (the RNG has advanced, so the curve differs); consider removing it.
model = get_model()
cbs = [
    # TrainerCB(),
    DeviceCB(),
    # MetricsCB(accuracy=MulticlassAccuracy()),
    # ProgressCB(plot=True),
    # SingleBatchCB()
]

learn = MomentumLearner(model, dls, F.cross_entropy, cbs=cbs, lr=1e-5)
learn.fit(1, cbs=LRFinderCB())

In [None]:
#|export
@fc.patch
def lr_find(self: Learner, gamma=1.3, max_mult=3, start_lr=1e-5, max_epochs=10):
    "LR range test: train only (no validation) from `start_lr` with a temporary `LRFinderCB`."
    finder = LRFinderCB(max_mult=max_mult, gamma=gamma)
    self.fit(max_epochs, train=True, valid=False, lr=start_lr, cbs=finder)

In [None]:
# Use a freshly initialised model: the `model` from the previous cell has already been
# trained at ever-growing learning rates by the LR finder, which would distort this curve.
learn = MomentumLearner(get_model(), dls, F.cross_entropy, cbs=cbs, lr=1e-5)
learn.lr_find()

In [None]:
#|export
@fc.patch
def lr_find_test(self: Learner, gamma=1.3):
    "Scratch check that `fc.patch` adds a method to Learner: just echoes `gamma`. NOTE(review): probably shouldn't be exported."
    message = f'gamma is {gamma}'
    print(message)

In [None]:
learn = MomentumLearner(model, dls, F.cross_entropy, cbs=cbs, lr=1e-5)
# Methods patched onto Learner are inherited by subclasses; prints "gamma is 1.3".
learn.lr_find_test()

In [None]:
# Write the `#| export` cells into the `learner` module named by `#| default_exp`.
import nbdev; nbdev.nbdev_export()