In [34]:
import math,torch,matplotlib.pyplot as plt
import fastcore.all as fc
from collections.abc import Mapping
from operator import attrgetter
from functools import partial
from copy import copy

from torch import optim
import torch.nn.functional as F

from n_framework import *

from fastprogress import progress_bar,master_bar

In [35]:
import matplotlib as mpl
import torchvision.transforms.functional as TF
from contextlib import contextmanager
from torch import nn,tensor
from datasets import load_dataset,load_dataset_builder
from n_framework import *
import logging
from fastcore.test import test_close

In [36]:
torch.set_printoptions(precision=2, linewidth=140, sci_mode=False)
torch.manual_seed(1)
mpl.rcParams['image.cmap'] = 'gray'

In [37]:
logging.disable(logging.WARNING)

## Learner

In [38]:
x,y = 'image','label'
name = 'fashion_mnist'
dsd = load_dataset(name)

  0%|          | 0/2 [00:00<?, ?it/s]

In [62]:
@inplace
def transformi(b): b[x]=[torch.flatten(TF.to_tensor(o))for o in b[x]]

In [63]:
bs=1024
tds = dsd.with_transform(transformi)

In [64]:
#|export
class DataLoaders:
    def __init__(self,*dls): self.train,self.valid = dls[:2]
    @classmethod
    def from_dd(cls,dd,batch_size, as_tuple=True):
        #return cls(*[DataLoader(ds, batch_size,num_workers=4, collate_fn=collate_dict(ds)) for ds in dd.values()])
        #return cls(*[DataLoader(ds, batch_size, collate_fn=collate_dict(ds)) for ds in dd.values()])
        return cls(*[DataLoader(ds, batch_size,num_workers=8, collate_fn=collate_dict(ds)) for ds in dd.values()])
        

In [65]:
dls = DataLoaders.from_dd(tds,bs)
dt = dls.train
xb,yb = next(iter(dt))
xb.shape,yb[:10]

(torch.Size([1024, 784]), tensor([9, 0, 0, 3, 0, 2, 7, 2, 5, 5]))

In [66]:
class Learner:
    def __init__(self, model, dls, loss_func, lr, opt_func=optim.SGD): fc.store_attr()

    def one_batch(self):
        self.xb,self.yb = to_device(self.batch)
        self.preds = self.model(self.xb)
        self.loss = self.loss_func(self.preds, self.yb)
        if self.model.training:
            self.loss.backward()
            self.opt.step()
            self.opt.zero_grad()
        with torch.no_grad(): self.calc_stats()

    def calc_stats(self):
        acc = (self.preds.argmax(dim=1)==self.yb).float().sum()
        self.accs.append(acc)
        n = len(self.xb)
        self.losses.append(self.loss*n)
        self.ns.append(n)

    def one_epoch(self, train):
        self.model.training = train
        dl = self.dls.train if train else self.dls.valid
        for self.num,self.batch in enumerate(dl): self.one_batch()
        n = sum(self.ns)
        print(self.epoch, self.model.training, sum(self.losses).item()/n, sum(self.accs).item()/n)
    
    def fit(self, n_epochs):
        self.accs,self.losses,self.ns = [],[],[]
        self.model.to(def_device)
        self.opt = self.opt_func(self.model.parameters(), self.lr)
        self.n_epochs = n_epochs
        for self.epoch in range(n_epochs):
            self.one_epoch(True)
            with torch.no_grad(): self.one_epoch(False)

In [67]:
m,nh =28*28,50
model = nn.Sequential(nn.Linear(m,nh),nn.ReLU(),nn.Linear(nh,10))

In [68]:
def_device = 'mps' if torch.backends.mps.is_available() else 'cuda' if torch.cuda.is_available() else 'cpu'

def to_device(x, device=def_device):
    if isinstance(x, torch.Tensor): return x.to(device)
    if isinstance(x, Mapping): return {k:v.to(device) for k,v in x.items()}
    return type(x)(to_device(o, device) for o in x)

def collate_device(b): return to_device(default_collate(b))

In [33]:
#When the num_workers=8
learn = Learner(model,dls, F.cross_entropy, lr=0.2)
%time learn.fit(1)

0 True 1.1843188802083333 0.5944
0 False 1.142808705357143 0.6042571428571428
CPU times: user 480 ms, sys: 200 ms, total: 681 ms
Wall time: 2.86 s


In [25]:
#When the num_workers=4
learn = Learner(model,dls, F.cross_entropy, lr=0.2)
%time learn.fit(1)

0 True 1.169957421875 0.60065
0 False 1.1254754464285714 0.6133714285714286
CPU times: user 437 ms, sys: 146 ms, total: 584 ms
Wall time: 3.6 s


In [19]:
#When the num_workers is not indicated
learn = Learner(model,dls, F.cross_entropy, lr=0.2)
%time learn.fit(1)

0 True 1.17436875 0.6026333333333334
0 False 1.1363360491071428 0.6131714285714286
CPU times: user 8.31 s, sys: 102 ms, total: 8.41 s
Wall time: 8.47 s


## Metric

In [None]:
#|export
class Metric:
    def __init__(self): self.reset()
    def reset(self): self.vals, self.ns = [],[]
    def add(self,inp, targ=None, n=1): # n is the number of items in the mini-batch.
        self.last = self.calc(inp,targ)
        self.vals.append(self.last)
        self.ns.append(n)
    @property
    def value(self):
        ns = tensor(self.ns)
        return (tensor(self.vals)*ns).sum()/ns.sum()
    def calc(self, inps,targs): return inps 

In [None]:
#|export
class Accuracy(Metric):
    def calc(self,inps,targs): return (inps==targs).float().mean()

In [None]:
acc = Accuracy()
acc.add(tensor([0,1,2,0,1,2]),tensor([0,1,1,2,1,0]))
acc.add(tensor([1,1,2,0,1]),tensor([0,1,1,2,1]))
acc.value

In [None]:
loss=Metric()
loss.add(0.6,n=32)
loss.add(0.9,n=2)
loss.value,round((0.6*32+0.9*2)/(32+2),2)

## Basic Callback Learner

In [69]:
#|export
def identity(*args):
    #import pdb;pdb.set_trace()
    if not args:return
    x,*args = args
    return (x,)+tuple(args) if args else x

In [70]:
identity(3,'4',1)

(3, '4', 1)

In [71]:
#|export
class CancelFitException(Exception): pass
class CancelBatchException(Exception): pass
class CancelEpochException(Exception): pass

In [72]:
#| export
def run_cbs(cbs,method_nm):
    for cb in sorted(cbs, key=attrgetter('order')):
        method = getattr(cb, method_nm, None)
        if method is not None:method()

In [73]:
#|export 
class callback(): order=0

In [85]:
class CompletionCallback(callback):
    def before_fit(self): self.count= 0
    def after_batch(self): self.count += 10
    def after_fit(self): print(f'Completed {self.count} batches')

In [86]:
cbs = [CompletionCallback()]
run_cbs(cbs, 'before_fit')
run_cbs(cbs, 'after_batch')
run_cbs(cbs, 'after_fit')

Completed 10 batches


In [87]:
cb = cbs[0]

In [88]:
getattr(cb, 'after_fit', None)()

Completed 10 batches


In [76]:
#|export
class with_cbs:
    def __init__(self,nm): self.nm = nm
    def __call__(self, f):
        def _f(o,*args,**kwargs):
            try:
                o.callback(f"before_{self.nm}")
                f(o,*args,**kwargs)
                o.callback(f"After_{self.nm}")
            except globals()[f'Cancel{self.nm.title}Exception']:pass
        return _f

In [77]:
class Learner:
    def __init__(self, model, dls, loss_func, lr, cbs,opt_func=optim.SGD): fc.store_attr()

    def one_batch(self):
        self.preds = self.model(self.xb)
        self.loss = self.loss_func(self.preds, self.yb)
        if self.model.training:
            self.loss.backward()
            self.opt.step()
            self.opt.zero_grad()

    def one_epoch(self, train):
        self.model.training = train
        self.dl = self.dls.train if train else self.dls.valid
        try:
            self.callback('before_epoch')            
            for self.iter,self.batch in enumerate(dl): 
                try:
                    self.callback('before_batch')
                    self.one_batch()
                    self.callback('after_batch')
                except CancelBatchException:pass
            self.callback('after_epoch')
        except CancelEpochException:pass
    
    def fit(self, n_epochs):
        self.n_epochs = n_epochs
        self.epochs = range(n_epochs)
        self.opt = self.opt_func(self.model.parameters(), self.lr)
        try:
            self.callback('before_fit')
            for self.epoch in self.epochs:
                self.one_epoch(True)
                self.one_epcoh(False)
            self.callback('after_fit')
        except CancelFitException:pass

    def callback(self,method_nm): run_cbs(self.cbs,method_nm)

In [78]:
m,nh = 28*28,50
def get_model(): return nn.Sequential(nn.Linear(m,nh), nn.ReLU(), nn.Linear(nh,10))

In [79]:
learn = Learner(dls, get_model(), F.cross_entropy, lr=0.2, cbs=[CompletionCallback()])
learn.fit(1)

AttributeError: 'DataLoaders' object has no attribute 'parameters'

In [None]:
#|export
class Learner():
    def __init__(self,model,dls,loss_func,lr, cbs, opt_func=optim.SGD):
        @fc.store_attr()
        for cb in cbs: cb.learn= self
            
    @with_cbs('batch')
    def one_batch(self):
        self.predit()
        self.get_loss()
        if self.model.training:
            self.backward()
            self.step()
            self.zero_grad()
    
    def one_epoch(self,train):
        self.model.train(train)
        self.dl = self.dls.train if train else self.dls.valid
        self._one_epoch()
    
    @with_cb('epoch')
    def _one_epoch(self):
        for self.iter, self.batch in enumerate(self.dl): self.one_epoch()
    
    def fit(self, n_epochs):
        self.epochs = n_epochs
        self.epochs = range(n_epochs)
        self.opt = self.opt_func(self.model.parameters(),lr)
        self._fit()
        
    @with_cbs('fit')
    def _fit(self):
        for self.epoch in self.epochs:
            self.one_epoch(True)
            self.one_epoch(False)
    
    def __getattr__(self,name):
        if name in('predict','get_loss', 'backward','step','zero_grad'): return partial(self.callback,name)
        raise AttributeError(name)
    def callback(self,method_nm):
        for cb in sorted(self.cbs, key=attrgetter('order')): getattr(cb,method_nm,identity)()