In [13]:
import fastcore.basics as fc
import torch
from torch import optim, nn

from utils import with_cbs, run_cbs, inplace

from exports import Callback

import torch.nn.functional as F

from datasets import load_dataset, load_dataset_builder

import torchvision.transforms.functional as TF

from torcheval.metrics import MulticlassAccuracy, Mean

from copy import copy


In [14]:
class Learner:
    """Training-loop skeleton whose steps are overridable hooks driven by callbacks.

    Subclasses fill in `predict`, `get_loss`, `backward`, `step` and `zero_grad`;
    the versions here are no-ops. `_one_batch`, `_one_epoch` and `_fit` are
    decorated with `with_cbs('batch' | 'epoch' | 'fit')`, and `callback`
    forwards named events (e.g. 'after_predict') to `run_cbs`.

    Args:
        model: module being trained; `one_epoch` toggles its `train()` mode and
            `_one_batch` reads `model.training` to decide whether to update.
        dls: object exposing `.train` and `.valid` data loaders.
        loss_func: loss function, used by subclasses' `get_loss`.
        lr: learning rate used by `fit` when no `lr` is passed there.
        opt_func: optimizer factory called as `opt_func(params, lr)`; if falsy,
            no optimizer is created.
        cbs: iterable of callbacks, or None for none.
    """

    def __init__(self, model, dls=(0,), loss_func=F.mse_loss, lr=0.1, opt_func=optim.SGD, cbs=None):
        # Always work on a private list: `fit` appends/removes temporary
        # callbacks on it, and `list(None)` used to crash for the default.
        cbs = [] if cbs is None else list(cbs)
        # `store_attr` saves every constructor argument (including the rebound
        # `cbs` above) as an attribute of the same name.
        fc.store_attr()

    # Hooks for subclasses; intentionally no-ops in the base class.
    def predict(self): pass
    def get_loss(self): pass
    def backward(self): pass
    def step(self): pass
    def zero_grad(self): pass

    @with_cbs('batch')
    def _one_batch(self):
        """Process `self.batch`: forward + loss always; update only in training mode."""
        self.predict()
        self.callback('after_predict')
        self.get_loss()
        self.callback('after_loss')

        if self.model.training:
            self.backward()
            self.callback('after_backward')
            self.step()
            self.callback('after_step')
            self.zero_grad()

    @with_cbs('epoch')
    def _one_epoch(self):
        """Iterate `self.dl`, exposing `self.iter` and `self.batch` to hooks and callbacks."""
        for self.iter, self.batch in enumerate(self.dl):
            self._one_batch()

    def one_epoch(self, training):
        """Run one epoch on the training loader (`training=True`) or the validation loader."""
        self.model.train(training)
        self.dl = self.dls.train if training else self.dls.valid
        self._one_epoch()

    @with_cbs('fit')
    def _fit(self, train, valid):
        """Loop over `self.epochs`, running a training and/or gradient-free validation pass."""
        for self.epoch in self.epochs:
            if train: self.one_epoch(True)
            if valid: torch.no_grad()(self.one_epoch)(False)

    def fit(self, n_epochs=1, train=True, valid=True, lr=None, cbs=None):
        """Train for `n_epochs`, optionally with extra callbacks active only during this call.

        Args:
            n_epochs: number of epochs to run.
            train: whether to run a training pass each epoch.
            valid: whether to run a validation pass each epoch.
            lr: learning rate for the optimizer created here; defaults to `self.lr`.
            cbs: extra callbacks added for this call and removed afterwards,
                even if training raises. None means no extra callbacks.
        """
        # Materialise once: the callbacks are iterated twice (add, then remove),
        # which would silently skip removal for a one-shot iterator.
        cbs = [] if cbs is None else list(cbs)
        self.cbs.extend(cbs)
        try:
            self.n_epochs = n_epochs
            self.epochs = range(n_epochs)
            if lr is None: lr = self.lr
            # A fresh optimizer is built on every `fit` call.
            if self.opt_func: self.opt = self.opt_func(self.model.parameters(), lr)
            self._fit(train, valid)
        finally:
            for cb in cbs: self.cbs.remove(cb)

    def callback(self, method_nm):
        """Forward the event name `method_nm`, together with `self.cbs`, to `run_cbs`."""
        # NOTE(review): the learner itself is not handed to `run_cbs`; confirm that
        # callbacks needing learner state (preds, loss, batch) can still reach it.
        run_cbs(self.cbs, method_nm=method_nm)

# `with_cbs` (imported from utils) is the decorator applied to Learner._one_batch, _one_epoch and _fit above.

class TrainLearner(Learner):
    """`Learner` implementing a plain supervised step on `(inputs, targets)` batches."""

    def predict(self):
        """Run the model on the batch inputs (element 0) and keep the result in `self.preds`."""
        inputs = self.batch[0]
        self.preds = self.model(inputs)

    def get_loss(self):
        """Score `self.preds` against the batch targets (element 1) into `self.loss`."""
        targets = self.batch[1]
        self.loss = self.loss_func(self.preds, targets)

    def backward(self):
        """Backpropagate through `self.loss`."""
        self.loss.backward()

    def step(self):
        """Apply one optimizer update."""
        self.opt.step()

    def zero_grad(self):
        """Clear the gradients via the optimizer."""
        self.opt.zero_grad()


In [15]:
from torch.utils.data import default_collate, DataLoader
from operator import itemgetter


def collate_dict(ds):
    """Make a collate function that returns `ds`'s feature columns from a batch.

    The returned callable runs `default_collate` on a list of dict samples and
    picks out the columns named in `ds.features`, in that order. With several
    features the result is a tuple; with exactly one it is that single value.
    """
    pick_columns = itemgetter(*ds.features)

    def _collate(batch):
        stacked = default_collate(batch)
        return pick_columns(stacked)

    return _collate

def get_dls(train_ds, valid_ds, bs, **kwargs):
    """Return a `(train, valid)` pair of DataLoaders.

    The training loader shuffles and uses batch size `bs`; the validation
    loader keeps dataset order and uses batch size `2 * bs`. Any extra
    keyword arguments are passed to both loaders.
    """
    train_dl = DataLoader(train_ds, batch_size=bs, shuffle=True, **kwargs)
    valid_dl = DataLoader(valid_ds, batch_size=2 * bs, **kwargs)
    return train_dl, valid_dl

class DataLoaders:
    """Container for a training and a validation data loader.

    Only the first two positional arguments are kept, as `train` and `valid`.
    """

    def __init__(self, *dls):
        self.train, self.valid = dls[:2]

    @classmethod
    def from_dd(cls, dd, batch_size, as_tuple=True, **kwargs):
        """Build loaders from a mapping of split name -> dataset (e.g. a `DatasetDict`).

        The 'train' split feeds the training loader; the first other split, in
        `dd`'s iteration order, feeds the validation loader. Batches are
        collated with `collate_dict` built from the 'train' split's features.

        Args:
            dd: mapping with a 'train' split and at least one other split.
            batch_size: passed to `get_dls` as `bs`.
            as_tuple: accepted for backward compatibility; currently unused.
            **kwargs: extra keyword arguments forwarded to `get_dls`.

        Raises:
            KeyError: if `dd` has no 'train' split.
            ValueError: if `dd` has no split besides 'train'.
        """
        train_ds = dd['train']
        # Select splits by name instead of unpacking `dd.values()`: unpacking
        # passed every split positionally (a TypeError with more than two
        # splits) and trained on whichever split happened to come first.
        others = [ds for name, ds in dd.items() if name != 'train']
        if not others:
            raise ValueError(f"expected a validation split besides 'train', got splits {list(dd)}")
        f = collate_dict(train_ds)
        return cls(*get_dls(train_ds, others[0], bs=batch_size, collate_fn=f, **kwargs))
        

In [16]:
# Fashion-MNIST loaded through the `datasets` library, plus the batch size.
dsd = load_dataset('fashion_mnist')
bs = 1024
# Column names: `x` holds the images (converted by `transformi`), `y` the labels.
x = 'image'
y = 'label'

@inplace
def transformi(b):
    """Replace the images in column `x` of batch `b` with flattened tensors.

    Each image goes through `TF.to_tensor` and is then flattened to 1-D.
    NOTE(review): `inplace` comes from `utils`; presumably it makes the
    mutated batch the return value expected by `with_transform` — confirm.
    """
    images = b[x]
    b[x] = [TF.to_tensor(img).flatten() for img in images]

# Attach `transformi` as an on-access transform for every split.
tds = dsd.with_transform(transform=transformi)

In [22]:
class MetricsCB(Callback):
    """Callback that holds metric objects plus a torcheval `Mean` for the loss.

    Metrics may be passed positionally (keyed by their class name) or by
    keyword (keyed by the keyword). A positional metric overrides a keyword
    metric with the same name.
    """

    def __init__(self, *ms, **metrics):
        super().__init__()
        metrics.update({type(metric).__name__: metric for metric in ms})
        self.metrics = metrics
        # `all_metrics` is a separate dict so it can also hold the loss tracker
        # without adding it to `metrics`.
        self.all_metrics = dict(metrics)
        loss_mean = Mean()
        self.all_metrics['loss'] = loss_mean
        self.loss = loss_mean

    def before_fit(self, *args, **kwargs):
        # NOTE(review): placeholder — only prints a blank line; nothing is reset
        # or reported yet.
        print()

In [23]:
# MLP sizes: `m` is a flattened 28x28 image, `nh` the hidden-layer width.
m = 28 * 28
nh = 50
dls = DataLoaders.from_dd(tds, bs)
def get_model():
    """Return a fresh MLP: Linear(m, nh) -> ReLU -> Linear(nh, 10)."""
    layers = [nn.Linear(m, nh), nn.ReLU(), nn.Linear(nh, 10)]
    return nn.Sequential(*layers)

cbs = [MetricsCB(accuracy=MulticlassAccuracy())]
model = get_model()

# Train the MLP for one epoch with SGD and cross-entropy loss.
learner = TrainLearner(
    model,
    dls,
    loss_func=F.cross_entropy,
    lr=0.2,
    opt_func=optim.SGD,
    cbs=cbs,
)
learner.fit(n_epochs=1)




AttributeError: 'TrainLearner' object has no attribute 'calc_stats'