In [1]:
#| default_exp learner

In [2]:
#|export
import math, torch, matplotlib.pyplot as plt
import fastcore.all as fc
from collections.abc import Mapping
from operator import attrgetter
from functools import partial
from copy import copy
from torch import optim
import torch.nn.functional as F

from diy_stable_diffusion.conv import *

from fastprogress import progress_bar,master_bar

In [3]:
import matplotlib as mpl
import torchvision.transforms.functional as TF
from torch import nn, tensor
from datasets import load_dataset, load_dataset_builder
from diy_stable_diffusion.datasets import *
import logging
from fastcore.test import test_close

In [4]:
# Notebook-wide setup: compact tensor printing, a fixed torch RNG seed,
# and a grayscale default colormap for matplotlib images.
torch.set_printoptions(precision=2, linewidth=140, sci_mode=False)
torch.manual_seed(1)
mpl.rcParams['image.cmap'] = 'gray'

In [5]:
# Silence log records of severity WARNING and below for the rest of the notebook.
logging.disable(logging.WARNING)

In [6]:
# Fetch Fashion-MNIST through Hugging Face `datasets`; the download log follows.
datasetdict = load_dataset("fashion_mnist")

Downloading builder script:   0%|          | 0.00/2.00k [00:00<?, ?B/s]

Downloading metadata:   0%|          | 0.00/1.36k [00:00<?, ?B/s]

Downloading and preparing dataset fashion_mnist/fashion_mnist (download: 29.45 MiB, generated: 34.84 MiB, post-processed: Unknown size, total: 64.29 MiB) to /root/.cache/huggingface/datasets/fashion_mnist/fashion_mnist/1.0.0/8d6c32399aa01613d96e2cbc9b13638f359ef62bb33612b077b4c247f6ef99c1...


Downloading data files:   0%|          | 0/4 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/26.4M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/29.5k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/4.42M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/5.15k [00:00<?, ?B/s]

Extracting data files:   0%|          | 0/4 [00:00<?, ?it/s]

Generating train split:   0%|          | 0/60000 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/10000 [00:00<?, ? examples/s]

Dataset fashion_mnist downloaded and prepared to /root/.cache/huggingface/datasets/fashion_mnist/fashion_mnist/1.0.0/8d6c32399aa01613d96e2cbc9b13638f359ef62bb33612b077b4c247f6ef99c1. Subsequent calls will reuse this data.


  0%|          | 0/2 [00:00<?, ?it/s]

In [7]:
@inplace
def transformi(b):
    "Replace every image in batch `b` with its flattened tensor representation."
    flattened = []
    for img in b['image']:
        flattened.append(torch.flatten(TF.to_tensor(img)))
    b['image'] = flattened

In [8]:
# Batch size shared by the DataLoaders built below.
bs = 1024
# Attach `transformi` as the dataset transform, then build DataLoaders from the dataset dict.
datasetdict = datasetdict.with_transform(transformi)
dls = DataLoaders.from_dd(datasetdict, bs, num_workers=4)

In [53]:
#|export

class CancelFitException(Exception):
    "Raised inside a callback to abandon the whole fit."


class CancelBatchException(Exception):
    "Raised inside a callback to abandon the current batch."


class CancelEpochException(Exception):
    "Raised inside a callback to abandon the current epoch."

class Callback:
    """Base class for learner callbacks.

    `order` decides the position in which `run_cbs` invokes callbacks
    (ascending). The names listed in `_fwd` are read-only conveniences:
    reading them on the callback returns the attribute of the same name
    from `self.learn`, which must have been assigned by the user of the
    callback.
    """
    order = 0
    _fwd = 'model','opt', 'batch', 'epoch'

    def __getattr__(self, name):
        """Forward the `_fwd` names to `self.learn`; anything else is missing.

        Only invoked after normal attribute lookup has failed.

        Raises:
            AttributeError: if `name` is not a forwarded name (or `self.learn`
                itself is not set), so `getattr(cb, name, None)` keeps working.
        """
        if name in self._fwd:
            return getattr(self.learn, name)
        # `object` has no `__getattr__`, so there is nothing to delegate to;
        # raise the conventional error naming the missing attribute.
        raise AttributeError(f'{type(self).__name__!r} object has no attribute {name!r}')

    def __setattr__(self, name, val):
        """Set the attribute, warning when it shadows a forwarded learner attribute."""
        if name in self._fwd:
            # Imported locally so the block does not rely on a module-level `warn`.
            from warnings import warn
            warn(f'setting {name} in callback, did you mean to set it on learner? Attribute accessible by proxy on the callback.')
        super().__setattr__(name, val)

    @property
    def training(self):
        "Whether the (forwarded) model is currently in training mode."
        return self.model.training
    

def run_cbs(cbs, methodname, learn=None):
    """Invoke `methodname(learn)` on each callback that defines it.

    Callbacks are visited in ascending `order`; those without the method
    (or where it is `None`) are skipped.
    """
    ordered = sorted(cbs, key=attrgetter('order'))
    for cb in ordered:
        hook = getattr(cb, methodname, None)
        if hook is None:
            continue
        hook(learn)



In [24]:
#|export
class SingleBatchCB(Callback):
    "Stop the fit as soon as the first batch has completed."
    order = 1

    def after_batch(self, learn):
        "Abort training by raising `CancelFitException`."
        raise CancelFitException

In [25]:
class Metrics:
    """Accumulates per-batch values and reports their weighted mean.

    Subclasses override `calc` to turn (inputs, targets) into a batch value.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        "Forget all recorded values and weights."
        self.vals, self.ns = [], []

    def add(self, inp, targ = None, n=1):
        "Record `calc(inp, targ)` with weight `n`; the result is kept in `self.last`."
        self.last = self.calc(inp, targ)
        self.vals.append(self.last)
        self.ns.append(n)

    @property
    def value(self):
        "Mean of the recorded values, weighted by their `n`."
        weights = tensor(self.ns)
        weighted = tensor(self.vals) * weights
        return weighted.sum() / weights.sum()

    def calc(self, inps, targs):
        "Default batch value: the inputs unchanged."
        return inps

In [26]:
class Accuracy(Metrics):
    "Metric whose batch value is the fraction of predictions equal to the targets."

    def calc(self, inps, targs):
        matches = inps == targs
        return matches.float().mean()

In [27]:
# First batch: 0 of 3 correct; second batch: 1 of 1 correct.
# Both are added with the default weight n=1, so the result is (0 + 1) / 2.
acc = Accuracy()
acc.add(tensor([1, 1, 1]), tensor([2, 2, 2]))
acc.add(tensor([2]), tensor([2]))
acc.value

tensor(0.50)

In [28]:
#|export
from torcheval.metrics import Mean,MulticlassAccuracy

In [29]:
#|export
def to_cpu(x):
    """Recursively detach tensors and move them to the CPU.

    Mappings become plain dicts, lists stay lists and tuples stay tuples,
    with `to_cpu` applied to every element. float16 tensors are upcast to
    float32 after the move. Leaves that are not tensor-like (no `detach`
    method, e.g. ints, strings or `None`) are returned unchanged instead of
    raising `AttributeError`.
    """
    if isinstance(x, Mapping):
        return {k: to_cpu(v) for k, v in x.items()}
    if isinstance(x, list):
        return [to_cpu(i) for i in x]
    if isinstance(x, tuple):
        return tuple(to_cpu(i) for i in x)
    # Non-tensor leaves (labels kept as Python objects, metadata, ...) pass through.
    if not hasattr(x, 'detach'):
        return x

    x = x.detach().cpu()

    return x.float() if x.dtype == torch.float16 else x

In [30]:
# a bunch of all the rest of the stuff in the tuple packed into _
a, b, *_ = (1, 2, 3, 4, 5)
a, b, _

(1, 2, [3, 4, 5])

In [31]:
'''
note; copy pasted, only hand-copied the exported version
'''

class Learner():
    """First, non-exported training loop: the steps are hard-coded and
    callbacks are only notified around fit, epoch and batch."""

    def __init__(self, model, dls, loss_func, lr, cbs, opt_func=optim.SGD):
        fc.store_attr()

    def one_batch(self):
        "Forward pass and loss; in training mode also backward, step and zero_grad."
        self.preds = self.model(self.batch[0])
        self.loss = self.loss_func(self.preds, self.batch[1])
        if not self.model.training:
            return
        self.loss.backward()
        self.opt.step()
        self.opt.zero_grad()

    def one_epoch(self, train):
        "Run one pass over the training (`train=True`) or validation DataLoader."
        self.model.train(train)
        if train:
            self.dl = self.dls.train
        else:
            self.dl = self.dls.valid
        try:
            self.callback('before_epoch')
            for self.iter,self.batch in enumerate(self.dl):
                # A cancelled batch is skipped; the epoch carries on with the next one.
                try:
                    self.callback('before_batch')
                    self.one_batch()
                    self.callback('after_batch')
                except CancelBatchException:
                    pass
            self.callback('after_epoch')
        except CancelEpochException:
            pass

    def fit(self, n_epochs):
        "Create the optimizer, then alternate a training and a validation epoch `n_epochs` times."
        self.n_epochs = n_epochs
        self.epochs = range(n_epochs)
        self.opt = self.opt_func(self.model.parameters(), self.lr)
        try:
            self.callback('before_fit')
            for self.epoch in self.epochs:
                for is_train in (True, False):
                    self.one_epoch(is_train)
            self.callback('after_fit')
        except CancelFitException:
            pass

    def callback(self, method_nm):
        "Dispatch `method_nm` to every registered callback."
        run_cbs(self.cbs, method_nm, self)

In [57]:
#|export
class MetricsCB(Callback):
    """Tracks the loss plus any user-supplied metrics and logs them once per epoch.

    Positional metrics are keyed by their class name; keyword metrics keep
    the given keyword as their name.
    """

    def __init__(self, *ms, **metrics):
        metrics.update({type(o).__name__: o for o in ms})
        self.metrics = metrics
        self.all_metrics = copy(metrics)
        loss_metric = Mean()
        self.all_metrics['loss'] = loss_metric
        self.loss = loss_metric

    def _log(self, d):
        "Default sink for the per-epoch log dict."
        print(d)

    def before_fit(self, learn):
        "Expose this callback on the learner as `learn.metrics`."
        learn.metrics = self

    def before_epoch(self, learn):
        "Reset every metric (including the loss) at the start of an epoch."
        for metric in self.all_metrics.values():
            metric.reset()

    def after_epoch(self, learn):
        "Format each metric to three decimals and log it with the epoch and mode."
        log = {}
        for name, metric in self.all_metrics.items():
            log[name] = f'{metric.compute():.3f}'
        log['epoch'] = learn.epoch
        log['train'] = 'train' if learn.model.training else 'eval'
        self._log(log)

    def after_batch(self, learn):
        "Update the metrics with this batch's predictions and the loss weighted by `len(x)`."
        x, y, *_ = to_cpu(learn.batch)
        for metric in self.metrics.values():
            metric.update(to_cpu(learn.preds), y)
        self.loss.update(to_cpu(learn.loss), weight=len(x))

In [39]:
#|export
class DeviceCB(Callback):
    """Moves the model (once, before fit) and every batch to a device.

    Args:
        device: target device; defaults to `def_device`.
    """
    def __init__(self, device = def_device):
        # Store the caller's choice; previously the argument was ignored and
        # `def_device` was always used.
        self.device = device

    def before_fit(self, learn):
        "Move the model to the device if it supports `.to`."
        if hasattr(learn.model, 'to'):
            learn.model.to(self.device)

    def before_batch(self, learn):
        "Replace `learn.batch` with a copy placed on the device."
        learn.batch = to_device(learn.batch, device = self.device)

In [40]:
# Layer sizes for the small fully connected model built by `get_model`.
in_features = 784
hidden_features = 50


def get_model():
    "Build a fresh Linear -> ReLU -> Linear network with 10 outputs."
    layers = [
        nn.Linear(in_features, hidden_features),
        nn.ReLU(),
        nn.Linear(hidden_features, 10),
    ]
    return nn.Sequential(*layers)

In [41]:
# Fresh model instance for the first Learner experiment.
model=get_model()

In [58]:
# Metrics callback tracking accuracy (keyword name) alongside the loss it always adds.
metrics = MetricsCB(accuracy = MulticlassAccuracy())

In [None]:
# NOTE(review): this passes a new, loss-only MetricsCB() rather than the
# `metrics` object created in the previous cell — confirm that is intended.
learn = Learner(model, dls, F.cross_entropy, lr=0.2, 
                cbs = [DeviceCB(), MetricsCB()])

In [None]:
# Train for a single epoch.
learn.fit(1)

## Updated learner
This version uses a decorator (`with_cbs`) to run the callbacks, as done at the end of the lesson notebook.

In [None]:
#|export
class with_cbs:
    """Decorator that surrounds a learner method with callback hooks.

    For `with_cbs('batch')` the wrapped method runs as:
    `before_batch` -> method -> `after_batch`, then `cleanup_batch` always.
    If an exception class named `Cancel<Name>Exception` (here
    `CancelBatchException`) exists in this module's globals, raising it
    silently skips the rest of the method and the `after_` hook. When no
    such class exists nothing is swallowed and every exception propagates.
    """
    def __init__(self, cbname):
        self.cbname = cbname

    def __call__(self, f):
        # Local import: the file only imports `partial` from functools.
        from functools import wraps

        @wraps(f)  # keep the wrapped method's name and docstring
        def _f(other, *args, **kwargs):
            # `other` is the `self` of the decorated method, i.e. the learner.
            try:
                other.callback(f'before_{self.cbname}')
                f(other, *args, **kwargs)
                other.callback(f'after_{self.cbname}')
            # The except expression is evaluated only while an exception is
            # propagating. Falling back to `()` (which matches nothing) lets
            # the real error through instead of masking it with a KeyError
            # when no matching Cancel exception class is defined.
            except globals().get(f'Cancel{self.cbname.title()}Exception', ()):
                pass
            finally:
                other.callback(f'cleanup_{self.cbname}')

        return _f

In [None]:
#|export
class Learner():
    """Training loop whose individual steps are supplied by callbacks.

    `predict`, `get_loss`, `backward`, `step` and `zero_grad` are not defined
    here: unless a subclass defines them, `__getattr__` turns each into a
    callback dispatch of the same name. The batch, epoch and fit stages are
    wrapped with `with_cbs`, which adds the before_/after_/cleanup_ hooks.
    """
    def __init__(self, model, dls=(0,), loss_func = F.mse_loss, lr=0.1,
                 cbs=None, opt_func=optim.SGD):
        # Wrap `cbs` in an fc.L, then store every argument as an attribute.
        cbs = fc.L(cbs)
        fc.store_attr()
    
    @with_cbs('batch')
    def _one_batch(self):
        # Each step is followed by its own `after_*` notification; the
        # backward/step/zero_grad part only runs in training mode.
        self.predict()
        self.callback('after_predict')
        self.get_loss()
        self.callback('after_loss')
        if self.training:
            self.backward()
            self.callback('after_backward')
            self.step()
            self.callback('after_step')
            self.zero_grad()
    
    
    @with_cbs('epoch')
    def _one_epoch(self):
        # The loop targets are attributes so callbacks can read `iter` and `batch`.
        for self.iter, self.batch in enumerate(self.dl):
            self._one_batch()
    
    def one_epoch(self, is_training):
        "Set the model mode, pick the matching DataLoader and run one epoch."
        self.model.train(is_training)
        self.dl = self.dls.train if is_training else self.dls.valid
        self._one_epoch()
        
    @with_cbs('fit')
    def _fit(self, do_train, do_validate):
        # Iterate over `self.epoch_range` (set in `fit`) rather than a local
        # range so that a `before_fit` callback can replace it — ProgressCB
        # swaps in a master_bar, for example.
        for self.epoch in self.epoch_range:
            if do_train:
                self.one_epoch(True)
            if do_validate:
                # Validation runs without gradient tracking.
                with torch.no_grad():
                    self.one_epoch(False)
                    
    def fit(self, n_epochs=1, do_train=True, do_validate=True, cbs=None, lr=None):
        """Train for `n_epochs`.

        `cbs` are extra callbacks active for this call only; `lr` overrides
        `self.lr` for this call when given.
        """
        cbs=fc.L(cbs)
        for cb in cbs:
            # Temporarily registered; removed again in the `finally` below.
            self.cbs.append(cb) 
        try:
            self.n_epochs = n_epochs
            self.epoch_range = range(n_epochs)
            if lr is None:
                lr = self.lr

            # With `opt_func=None` no optimizer is created here.
            # NOTE(review): `self.opt` must then be provided some other way — confirm.
            if self.opt_func is not None:
                # A new optimizer is built on every call and kept on the learner.
                self.opt = self.opt_func(self.model.parameters(), lr)
            self._fit(do_train, do_validate)
        finally:
            for cb in cbs:
                self.cbs.remove(cb)
    
    def __getattr__(self, name):
        # Only reached when normal lookup fails: the five training steps fall
        # back to a callback dispatch, anything else is a genuine miss.
        if name in ('predict', 'get_loss', 'backward', 'step', 'zero_grad'):
            return partial(self.callback, name)
        else:
            raise AttributeError(name)
            
    def callback(self, name):
        "Dispatch `name` to every registered callback, passing this learner."
        run_cbs(self.cbs, name, self)
    
    @property
    def training(self):
        "Whether the model is currently in training mode."
        return self.model.training

In [None]:
#|export
class TrainCB(Callback):
    """Implements the five training steps for a `Learner` as callbacks.

    The first `n_inp` items of the batch are fed to the model; the remaining
    items are passed to the loss function after the predictions.
    """

    def __init__(self, n_inp=1):
        self.n_inp = n_inp

    def predict(self, learn):
        inputs = learn.batch[:self.n_inp]
        learn.preds = learn.model(*inputs)

    def get_loss(self, learn):
        targets = learn.batch[self.n_inp:]
        learn.loss = learn.loss_func(learn.preds, *targets)

    def backward(self, learn):
        learn.loss.backward()

    def step(self, learn):
        learn.opt.step()

    def zero_grad(self, learn):
        learn.opt.zero_grad()

In [None]:
# Scratch experiment: wrap a 10-step range in a fastprogress master_bar.
x =master_bar(range(10))

In [None]:
# Build a large throwaway list on each step so the bar's progress is visible.
for i in x:
    z=[i for i in range(int(1e7))]
    x.show()

In [None]:
# Compare the ways of listing a dict's keys: list(m), m.keys() and list(m.keys()).
m={'a': 1, 'z': 99}
list(m),m.keys(),list(m.keys())

In [None]:
#|export
class ProgressCB(Callback):
    """Shows fastprogress bars for epochs and batches and, with `plot=True`,
    a live graph of training and validation loss.

    Runs after MetricsCB (higher `order`) so that `learn.metrics` is already
    set when `before_fit` looks for it.
    """
    order = MetricsCB.order + 1
    def __init__(self, plot=False):
        self.plot = plot
    
    def before_fit(self, learn):
        # Replace the learner's epoch range with a master bar over it.
        learn.epoch_range = self.mbar = master_bar(learn.epoch_range)
        self.first = True
        
        # If a MetricsCB registered itself as `learn.metrics` (it does so in
        # its own `before_fit`), redirect its per-epoch logging into the
        # progress-bar table instead of `print`.
        # NOTE(review): any other object stored as `learn.metrics` would be
        # patched as well.
        if hasattr(learn, 'metrics'):
            learn.metrics._log = self._log
        
        # Per-batch training losses and per-epoch validation losses for the graph.
        self.losses = []
        self.val_losses = []
    
    def _log(self, logdict):
        # The first call writes the header row (the keys); every call writes a value row.
        if self.first:
            self.mbar.write(list(logdict.keys()), table=True)
            self.first = False
        
        self.mbar.write(list(logdict.values()), table=True)
        
    def before_epoch(self, learn):
        # Wrap the current DataLoader in a child progress bar of the master bar.
        learn.dl = progress_bar(learn.dl, leave=False, parent=self.mbar)
    
    def after_batch(self, learn):
        # Relies on `before_epoch` having replaced `learn.dl` with a
        # progress_bar, whose `comment` shows the current loss.
        learn.dl.comment = f'{learn.loss:.3f}'
        
        if self.plot and hasattr(learn, 'metrics') and learn.training:
            self.losses.append(learn.loss.item())
            # Redraw only once at least one validation loss exists.
            if self.val_losses:
                self.mbar.update_graph(
                    [
                        [fc.L.range(self.losses), self.losses],
                        [
                            # x positions of the validation points: the batch
                            # count at the end of each epoch before this one.
                            fc.L.range(learn.epoch).map(lambda x: (x+1)*len(learn.dls.train)),
                            self.val_losses
                        ]
                    ])
    
    def after_epoch(self, learn):
        # Only validation epochs contribute a point to the validation curve.
        if not learn.training:
            if self.plot and hasattr(learn, 'metrics'):
                self.val_losses.append(learn.metrics.all_metrics['loss'].compute())
                self.mbar.update_graph(
                    [[fc.L.range(self.losses), self.losses],
                     [
                         # `learn.epoch+1` because this epoch's loss was just appended.
                         fc.L.range(learn.epoch+1).map(lambda x: (x+1)*len(learn.dls.train)),
                         self.val_losses
                     ]
                    ])    

In [None]:
# Fresh model and metrics callback for the callback-driven Learner.
model=get_model()
metrics=MetricsCB(accuracy=MulticlassAccuracy())


In [None]:
# TrainCB supplies the training steps; the other callbacks add device placement, metrics and progress display.
cbs=[TrainCB(), DeviceCB(), metrics, ProgressCB(plot=True)]

In [None]:
# Build the callback-driven Learner with cross-entropy loss.
learn = Learner(model, dls, F.cross_entropy, lr=0.2, cbs=cbs)

In [None]:
# Train for two epochs.
learn.fit(2)

In [None]:
#|export
class TrainLearner(Learner):
    """Learner with the five training steps built in as methods, so no
    training callback is needed. Expects batches of (input, target)."""

    def predict(self):
        self.preds = self.model(self.batch[0])

    def get_loss(self):
        self.loss = self.loss_func(self.preds, self.batch[1])

    def backward(self):
        self.loss.backward()

    def step(self):
        self.opt.step()

    def zero_grad(self):
        self.opt.zero_grad()

In [None]:
#|export
class MomentumLearner(TrainLearner):
    """TrainLearner that keeps a decayed copy of the gradients between batches.

    Instead of zeroing gradients after each step, they are scaled by `mom`,
    so the next backward pass accumulates on top of them.

    Args:
        mom: factor applied to every gradient in `zero_grad`.
        (remaining arguments are forwarded to `Learner.__init__`)
    """
    def __init__(self, model, dls, loss_func, lr=None, cbs=None, opt_func=optim.SGD, mom=0.85):
        self.mom = mom
        super().__init__(model, dls, loss_func, lr, cbs, opt_func)

    def zero_grad(self):
        "Scale existing gradients by `mom` rather than clearing them."
        with torch.no_grad():
            for p in self.model.parameters():
                # Parameters that received no gradient (frozen or unused in
                # the forward pass) have `grad is None`; skip them instead of
                # failing on `None *= float`.
                if p.grad is not None:
                    p.grad *= self.mom

In [None]:

class TrainMomentumCB(TrainCB):
    """TrainCB variant that decays gradients by `mom` instead of zeroing them."""

    def __init__(self, mom):
        self.mom = mom
        super().__init__()

    def zero_grad(self, learn):
        "Scale existing gradients by `mom` rather than clearing them."
        with torch.no_grad():
            for p in learn.model.parameters():
                # Skip parameters that have no gradient yet (frozen or unused
                # in the forward pass); `None *= float` would raise.
                if p.grad is not None:
                    p.grad *= self.mom

In [None]:
# Same setup as before, with TrainMomentumCB (mom=0.85) supplying the training steps.
metrics = MetricsCB(accuracy=MulticlassAccuracy())
cbs = [TrainMomentumCB(0.85), DeviceCB(), metrics, ProgressCB(plot=True)]
learn = Learner(get_model(), dls, F.cross_entropy, lr=0.25, cbs=cbs)
learn.fit(3)

## LR finder callback


In [None]:
class LRFinderCB(Callback):
    """Learning-rate finder, first version: multiply the LR by `lr_mult`
    after every training batch and stop once the loss exceeds three times
    the lowest loss seen."""

    def __init__(self, lr_mult = 1.3):
        self.lr_mult = lr_mult

    def before_fit(self, learn):
        "Start with empty LR/loss histories and no known minimum."
        self.lrs = []
        self.losses = []
        self.min = math.inf

    def after_batch(self, learn):
        # Validation batches are not part of the sweep: skip the whole epoch.
        if not learn.training:
            raise CancelEpochException()

        # Record the LR of the first parameter group together with the loss.
        current_lr = learn.opt.param_groups[0]['lr']
        self.lrs.append(current_lr)
        loss = to_cpu(learn.loss)
        self.losses.append(loss)

        if loss < self.min:
            self.min = loss
        # Stop the fit once the loss has clearly blown up.
        if loss > self.min * 3:
            raise CancelFitException()

        # Raise the LR of every parameter group for the next batch.
        for group in learn.opt.param_groups:
            group['lr'] *= self.lr_mult

    def after_fit(self, learn):
        "Plot loss against learning rate on a logarithmic x axis."
        plt.plot(self.lrs, self.losses)
        plt.xscale('log')


In [None]:
# Run the LR sweep starting from a very small learning rate (1e-4).
lrfind = LRFinderCB()
cbs = [lrfind, TrainMomentumCB(0.85), DeviceCB(), metrics, ProgressCB(plot=True)]
learn = Learner(get_model(), dls, F.cross_entropy, lr=1e-4, cbs=cbs)
learn.fit(2)

In [None]:
# Loss against learning rate on a linear x axis.
plt.plot(lrfind.lrs, lrfind.losses)

In [None]:
# The same curve with a logarithmic x axis.
plt.plot(lrfind.lrs, lrfind.losses)
plt.xscale('log')

In [None]:
#|export
from torch.optim.lr_scheduler import ExponentialLR

In [None]:
#|export
class LRFinderCB(Callback):
    """Learning-rate finder driven by an `ExponentialLR` scheduler.

    After each training batch the scheduler is stepped with factor `gamma`.
    The fit is cancelled when the loss is NaN or exceeds `max_mult` times
    the lowest loss seen so far.
    """

    def __init__(self, gamma=1.3, max_mult=3):
        fc.store_attr()

    def before_fit(self, learn):
        "Reset the histories and attach a scheduler to the learner's optimizer."
        self.lrs, self.losses = [], []
        self.min = math.inf
        # Stepping this scheduler changes the LR stored inside the optimizer.
        self.sched = ExponentialLR(learn.opt, self.gamma)

    def after_batch(self, learn):
        # Validation batches are not part of the sweep: skip the whole epoch.
        if not learn.training:
            raise CancelEpochException()

        # Read the LR before `sched.step()` below, i.e. the one used for this batch.
        current_lr = learn.opt.param_groups[0]['lr']
        self.lrs.append(current_lr)
        loss = to_cpu(learn.loss)
        self.losses.append(loss)

        if loss < self.min:
            self.min = loss
        diverged = math.isnan(loss) or (loss > self.min * self.max_mult)
        if diverged:
            raise CancelFitException()

        self.sched.step()

    # `cleanup_fit` still runs when the fit is cancelled with
    # CancelFitException, whereas `after_fit` would be skipped.
    def cleanup_fit(self, learn):
        "Plot loss against learning rate on a logarithmic x axis."
        plt.plot(self.lrs, self.losses)
        plt.xscale('log')

In [None]:
# LR sweep with MomentumLearner; the finder is passed as a fit-time callback only.
learn = MomentumLearner(get_model(), dls, F.cross_entropy,
                        lr=1e-5, cbs=[DeviceCB()])
learn.fit(3, cbs=[LRFinderCB()])

In [None]:
@fc.patch?

In [None]:
#|export
@fc.patch
def lr_find(self:Learner, gamma=1.3, max_mult=3, start_lr=1e-5, max_epochs=10):
    "Run a learning-rate sweep: fit from `start_lr` for up to `max_epochs` with an `LRFinderCB`."
    finder = LRFinderCB(gamma=gamma, max_mult=max_mult)
    self.fit(max_epochs, lr=start_lr, cbs=finder)

In [None]:
# No `lr` is given to the learner; `lr_find` supplies its own `start_lr` to `fit`.
MomentumLearner(get_model(), dls, F.cross_entropy, cbs=[DeviceCB()]).lr_find()

In [59]:
# Export the cells marked `#|export` via nbdev.
import nbdev; nbdev.nbdev_export()