In [221]:
#| default_exp learner

## Imports

In [222]:
#|export
import math,torch,matplotlib.pyplot as plt
import fastcore.all as fc
from collections.abc import Mapping
from operator import attrgetter
from functools import partial
from copy import copy

from torch import optim
import torch.nn.functional as F
from torch.utils.data import default_collate

# from miniai.conv import *

from fastprogress import progress_bar,master_bar
from operator import itemgetter

In [223]:
import matplotlib as mpl
import torchvision.transforms.functional as TF
from contextlib import contextmanager
from torch import nn,tensor
from datasets import load_dataset,load_dataset_builder
import logging 

In [224]:
from torch.utils.data.dataloader import DataLoader, Dataset

In [225]:
logging.disable(logging.WARNING)

In [226]:
torch.set_printoptions(precision=2, linewidth=140, sci_mode=False)
torch.manual_seed(1)
mpl.rcParams['image.cmap'] = 'gray'

## Dataset

In [227]:
#| export
class DataLoaders:
    def __init__(self, *dls): self.train,self.valid = dls[:2]

    @classmethod
    def from_dd(cls, dd, batch_size, as_tuple=True, **kwargs):
        f = collate_dict(dd['train'])
        return cls(*get_dls(*dd.values(), bs=batch_size, collate_fn=f, **kwargs))

In [228]:
x,y = 'image','label'
name = "fashion_mnist"
dsd = load_dataset(name)

  0%|          | 0/2 [00:00<?, ?it/s]

In [None]:
dsd["train"][x][0]

In [None]:
#| export 
def inplace(f):
    def _f(b):
        f(b)
        return b
    return _f

In [229]:
#| export 
def collate_dict(ds):
    get = itemgetter(*ds.features)
    def _f(b): return get(default_collate(b))
    return _f

In [230]:
@inplace
def transformi(b): b[x] = [torch.flatten(TF.to_tensor(o)) for o in b[x]]

In [231]:
bs = 1024
tds = dsd.with_transform(transformi)

In [232]:
def get_dls(train_ds, valid_ds, bs, **kwargs):
    return (DataLoader(train_ds, batch_size=bs, shuffle=True, **kwargs),
            DataLoader(valid_ds, batch_size=bs*2, **kwargs))

In [233]:
dls = DataLoaders.from_dd(tds, bs, num_workers=4)
dt = dls.train


In [234]:
dt.collate_fn

<function __main__.collate_dict.<locals>._f(b)>

In [235]:
xb,yb = next(iter(dt))
xb.shape,yb[:10]

(torch.Size([1024, 784]), tensor([5, 4, 9, 4, 3, 0, 6, 5, 7, 6]))

## Callback 

In [236]:
#| export
class cb:
    order = 0
        

In [237]:
class testcb(cb):
    def __init__(self):fc.store_attr()
        
    def before_fit(self, learn):
        print(f"starting fit function")
        
    def after_fit(self, learn):
        print(f"ending fit function with loss : {learn.loss}")

In [238]:
# class testcb1(cb):
#     def __init__(self):
#         fc.store_attr()
#         self.order = -1
        
#     def before_fit(self, learn):
#         print(f"starting fit function")
        
#     def after_fit(self, learn):
#         print(f"ending fit function with loss : {learn.loss}")

In [239]:
cbs = [ testcb()]

In [240]:
sorted(cbs, key= attrgetter("order"))

[<__main__.testcb at 0x7ff3b9d75300>]

In [241]:
if getattr(testcb(), "before_fit"): pass

In [242]:
#| export

def rcb(cbs, method_name, learn):
    for cb in sorted(cbs, key= attrgetter("order")):
        method =  getattr(cb, method_name, None )
        if method : return method(learn)
        

In [243]:
rcb(cbs, "after_fit",learn)

ending fit function with loss : 0.781036913394928


## Learner Framework Basic

In [244]:
#| export 
def_device = 'mps' if torch.backends.mps.is_available() else 'cuda' if torch.cuda.is_available() else 'cpu'

def to_device(x, device=def_device):
    if isinstance(x, torch.Tensor): return x.to(device)
    if isinstance(x, Mapping): return {k:v.to(device) for k,v in x.items()}
    return type(x)(to_device(o, device) for o in x)

In [245]:

class learner:
    
    def __init__(self, model, dls,loss_func, lr,opt_func=optim.SGD):fc.store_attr()
    
    def one_batch(self):
        x,y = to_device(self.b)
        self.preds = self.model(x)
        self.loss = self.loss_func(self.preds, y)
        if self.model.training: 
            self.loss.backward()
            self.opt.step()
            self.opt.zero_grad()
#         else :
#             with torch.no_grad():
#                 self.preds = self.model(b[0])
            
        
    
    def one_epoch(self, train):
        self.model.training = train
        dl = self.dls.train if train else self.dls.valid
        for self.b in dl:self.one_batch()
        
    def fit(self, epochs):
        self.model = self.model.to(def_device)
        self.opt = self.opt_func(self.model.parameters(), self.lr)
        for e in range(epochs):
            self.one_epoch(True)
            with torch.no_grad(): self.one_epoch(False)
                
            print(f"epoch {e} loss :{self.loss}")
        
    

In [246]:
m,nh = 28*28,50
model = nn.Sequential(nn.Linear(m,nh), nn.ReLU(), nn.Linear(nh,10))

In [247]:
learn = learner(model, dls, F.cross_entropy, lr=0.2)
learn.fit(3)

epoch 0 loss :0.7705832719802856
epoch 1 loss :0.6212591528892517
epoch 2 loss :0.7252416014671326


## Flexible Learner

In [None]:
from functools import 

In [None]:
class ctx_mng():
    def __init__(self, name):self.name = name
        
    def __enter__(self):print(f"Executing code before {self.name} function")
        
    def __exit__(self, *args,**kwargs):
        print(f"Executing code after{self.name} function")
   
    

    

In [None]:
with ctx_mng("before") as ctx:
    print("function")

In [None]:
class learner:
    
    def __init__(self, model, dls,loss_func, lr,cbs ,opt_func=optim.SGD):fc.store_attr()
    
    def one_batch(self):
        x,y = to_device(self.b)
        with ctx_mng('batch'):
            self.preds = self.model(x)
            self.loss = self.loss_func(self.preds, y)
            if self.model.training: 
                self.loss.backward()
                self.opt.step()
                self.opt.zero_grad()
#         else :
#             with torch.no_grad():
#                 self.preds = self.model(b[0])
            
        
    
    def one_epoch(self, train):
        self.model.training = train
        dl = self.dls.train if train else self.dls.valid
        with ctx_mng('epoch'):
            for self.b in dl:self.one_batch()
        
    def fit(self, epochs):
        self.model = self.model.to(def_device)
        self.opt = self.opt_func(self.model.parameters(), self.lr)
        with ctx_mng('fit'):
            for e in range(epochs):
                self.one_epoch(True)
                with torch.no_grad(): self.one_epoch(False)

                print(f"epoch {e} loss :{self.loss}")

In [None]:
m,nh = 28*28,50
model = nn.Sequential(nn.Linear(m,nh), nn.ReLU(), nn.Linear(nh,10))

In [None]:
learn = learner(model, dls, F.cross_entropy, lr=0.2, cbs = [])
learn.fit(1)

## More Flexible learner 

In [None]:
#| export

class CancelFitException:pass
class CancelEpochException:pass
class CancelBatchException:pass

In [248]:
"fit".title()

'Fit'

In [249]:
#| export
class cb_dec:
    def __init__(self, name):
        fc.store_attr()
        
    def __call__(self, f):
        def _f(o,*args,**kwargs):
            try:
                o.callback(f"before_{self.name}")
                f(o,*args, **kwargs)
                o.callback(f"after_{self.name}")
            except: globals()[f"Cancel{self.name.title()}Exception"]()
        return _f 
            
            

In [250]:
#| export 
class learner:
    
    def __init__(self, model, dls,loss_func, lr,cbs=[] ,opt_func=optim.SGD):fc.store_attr()
    
    @cb_dec("batch")
    def one_batch(self):
        x,y = to_device(self.b)
        
        self.preds = self.model(x)
        self.loss = self.loss_func(self.preds, y)
        if self.model.training: 
            self.loss.backward()
            self.opt.step()
            self.opt.zero_grad()
#         else :
#             with torch.no_grad():
#                 self.preds = self.model(b[0])
            
        
    @cb_dec("epoch")
    def one_epoch(self, train):
        self.model.training = train
        dl = self.dls.train if train else self.dls.valid
        
        for self.b in dl:self.one_batch()
    
    @cb_dec("fit")
    def fit(self, epochs):
        self.model = self.model.to(def_device)
        self.opt = self.opt_func(self.model.parameters(), self.lr)
        
        for e in range(epochs):
            self.one_epoch(True)
            with torch.no_grad(): self.one_epoch(False)

            print(f"epoch {e} loss :{self.loss}")
                
    def callback(self, nm): rcb(self.cbs, nm, self)
        

In [251]:
# def rcb(cbs, method_name, learn):
#     for cb in sorted(cbs, key= attrgetter("order")):
#         try: return getattr(cb, method_name)(learn)
#         except Exception as e: print(e)

In [252]:
m,nh = 28*28,50
model = nn.Sequential(nn.Linear(m,nh), nn.ReLU(), nn.Linear(nh,10))

In [253]:
learn = learner(model, dls, F.cross_entropy,lr=0.2,cbs = [testcb()])
learn.fit(1)

starting fit function
epoch 0 loss :0.7497103214263916
ending fit function with loss : 0.7497103214263916


## Trainer Learner

In [None]:
#|export
class TrainLearner(learner):
    def predict(self): self.preds = self.model(self.batch[0])
    def get_loss(self): self.loss = self.loss_func(self.preds, self.batch[1])
    def backward(self): self.loss.backward()
    def step(self): self.opt.step()
    def zero_grad(self): self.opt.zero_grad()

In [254]:
m,nh = 28*28,50
model = nn.Sequential(nn.Linear(m,nh), nn.ReLU(), nn.Linear(nh,10))

learn = TrainLearner(model, dls, F.cross_entropy,lr=0.2,cbs = [testcb()])
learn.fit(1)

starting fit function
epoch 0 loss :0.8289142847061157
ending fit function with loss : 0.8289142847061157


## Momentum Learner


In [None]:
#| export 
class momentumLearner(learner):
    def __init__(self, model, dls, loss_func, lr=None, cbs=None, opt_func=optim.SGD, mom=0.85):
        self.mom = mom
        super().__init__(model, dls, loss_func, lr, cbs, opt_func)
        
    def step(self): self.opt.step()
        
    def zero_grad(self): 
        with torch.no_grad():
            for p in self.model.parameters(): p.grad *= self.mom

In [255]:
m,nh = 28*28,50
model = nn.Sequential(nn.Linear(m,nh), nn.ReLU(), nn.Linear(nh,10))

learn = TrainLearner(model, dls, F.cross_entropy,lr=0.2,cbs = [testcb()])
learn.fit(1)

starting fit function
epoch 0 loss :0.7780116200447083
ending fit function with loss : 0.7780116200447083


## Metrics Callback

In [256]:
from torcheval.metrics import MulticlassAccuracy

In [257]:
ms = [MulticlassAccuracy()]

In [258]:
metricss={}
metricss[type(ms[0]).__name__] = ms[0]

In [259]:
metricss

{'MulticlassAccuracy': <torcheval.metrics.classification.accuracy.MulticlassAccuracy at 0x7ff39a795090>}

## Export 

In [260]:
#| hide 
import nbdev; nbdev.nbdev_export()