In [None]:
#| default_exp learner

In [None]:
#|export
import math,torch,matplotlib.pyplot as plt
import fastcore.all as fc
from collections.abc import Mapping
from operator import attrgetter
from functools import partial
from copy import copy
from minima import optim
from fastprogress import progress_bar,master_bar
from operator import itemgetter
from itertools import zip_longest
import minima as mi
import minima.nn as nn
from minima.data import DataLoader, Dataset

In [None]:
import matplotlib as mpl
from contextlib import contextmanager
from datasets import load_dataset,load_dataset_builder
import logging
from fastcore.test import test_close

In [None]:
class DataLoaders:
    """Container exposing a training and a validation loader as ``.train`` / ``.valid``.

    Only the first two positional loaders are kept; any further ones are ignored.
    """
    def __init__(self, *dls): self.train,self.valid = dls[:2]

    @classmethod
    def from_dd(cls, dd, batch_size, as_tuple=True, **kwargs):
        """Build train/valid loaders from a dict-like of datasets via ``get_dls``.

        The first two values of `dd` (in its iteration order) are used as the train and
        valid datasets; any further splits are ignored, consistent with ``__init__``.
        `batch_size` is forwarded as ``get_dls``'s ``bs``; extra keyword arguments are
        forwarded unchanged. `as_tuple` is accepted for backward compatibility but is
        currently unused.
        """
        # Slice rather than splat everything: a third split would otherwise be passed
        # as an extra positional argument and make get_dls raise TypeError.
        first_two = list(dd.values())[:2]
        return cls(*get_dls(*first_two, bs=batch_size, **kwargs))

def get_dls(train_ds, valid_ds, bs, **kwargs):
    """Return a ``(train_loader, valid_loader)`` pair of ``DataLoader`` objects.

    The training loader is built with ``shuffle=True`` and batch size `bs`. The
    validation loader does not pass ``shuffle`` and uses ``2 * bs`` (presumably
    affordable because no gradients are kept during validation). Any extra keyword
    arguments are passed to both loaders.
    """
    train_dl = DataLoader(train_ds, batch_size=bs, shuffle=True, **kwargs)
    valid_dl = DataLoader(valid_ds, batch_size=bs*2, **kwargs)
    return train_dl, valid_dl

In [None]:
def load_mnist(path, kind='train'):
    """Load a gzip-compressed MNIST-format (IDX) image/label pair from `path`.

    Reads ``{kind}-labels-idx1-ubyte.gz`` and ``{kind}-images-idx3-ubyte.gz`` from the
    directory `path` (e.g. ``kind='train'`` or ``kind='t10k'``).

    Returns
    -------
    images : np.ndarray of uint8, shape (n, rows * cols)
        One flattened image per row; rows/cols come from the file header
        (784 columns for 28x28 MNIST / Fashion-MNIST).
    labels : np.ndarray of uint8, shape (n,)

    Raises
    ------
    ValueError
        If a file is too short, has an unexpected IDX magic number, or the label and
        image counts disagree.
    """
    import os
    import gzip
    import struct
    import numpy as np

    labels_path = os.path.join(path, f'{kind}-labels-idx1-ubyte.gz')
    images_path = os.path.join(path, f'{kind}-images-idx3-ubyte.gz')

    with gzip.open(labels_path, 'rb') as lbpath:
        raw_labels = lbpath.read()
    # IDX1 header: big-endian uint32 magic (2049) + item count = 8 bytes.
    if len(raw_labels) < 8:
        raise ValueError(f'{labels_path}: file too short for an IDX1 header')
    magic, n_labels = struct.unpack('>II', raw_labels[:8])
    if magic != 2049:
        raise ValueError(f'{labels_path}: unexpected IDX1 magic number {magic}')
    labels = np.frombuffer(raw_labels, dtype=np.uint8, offset=8)

    with gzip.open(images_path, 'rb') as imgpath:
        raw_images = imgpath.read()
    # IDX3 header: big-endian uint32 magic (2051), image count, rows, cols = 16 bytes.
    if len(raw_images) < 16:
        raise ValueError(f'{images_path}: file too short for an IDX3 header')
    magic, n_images, rows, cols = struct.unpack('>IIII', raw_images[:16])
    if magic != 2051:
        raise ValueError(f'{images_path}: unexpected IDX3 magic number {magic}')
    if not (n_labels == len(labels) == n_images):
        raise ValueError(
            f'label/image count mismatch: header labels={n_labels}, '
            f'label payload={len(labels)}, header images={n_images}')
    # reshape still raises ValueError if the pixel payload is truncated or padded.
    images = np.frombuffer(raw_images, dtype=np.uint8, offset=16).reshape(len(labels), rows * cols)

    return images, labels

In [None]:
# Fashion-MNIST files in MNIST IDX layout; path is relative to the current working directory.
FASHION_DIR = '../datasets/fashion'
X_train, y_train = load_mnist(FASHION_DIR, kind='train')
X_test, y_test = load_mnist(FASHION_DIR, kind='t10k')

In [None]:
X_train.shape, y_train.shape  # one flattened 784-value image row per label

((60000, 784), (60000,))

In [None]:
X_test.shape, y_test.shape  # the 't10k' split loaded above

((10000, 784), (10000,))

In [None]:
X_train[0].shape  # a single sample is a flat 784-element vector

(784,)

In [None]:
# Convert the numpy arrays to minima tensors; the 't10k' split serves as validation data.
X_tr, y_tr = mi.Tensor(X_train), mi.Tensor(y_train)
X_val, y_val = mi.Tensor(X_test), mi.Tensor(y_test)

In [None]:
# Custom Dataset: index-based access to (X[i], y[i]) pairs
class MyDataset(Dataset):
    """Dataset pairing row ``index`` of `X` with entry ``index`` of `y`.

    Both inputs are passed through ``mi.Tensor`` when the dataset is constructed.
    """
    def __init__(self, X, y):
        self.X, self.y = mi.Tensor(X), mi.Tensor(y)

    def __len__(self):
        # Length comes from X alone; y is not checked against it.
        return len(self.X)

    def __getitem__(self, index):
        features, target = self.X[index], self.y[index]
        return features, target

tr_ds, val_ds = MyDataset(X_tr, y_tr), MyDataset(X_val, y_val)

In [None]:
# Wrap the loader pair in DataLoaders so it exposes .train/.valid, which Learner.one_epoch reads
# (a bare tuple from get_dls would raise AttributeError there).
dls = DataLoaders(*get_dls(train_ds=tr_ds, valid_ds=val_ds, bs=64))
dt, dv = dls.train, dls.valid
xb,yb = next(iter(dt))
xb.shape,yb[:10]

((64, 784), minima.Tensor([5 0 2 7 3 5 8 5 1 4]))

In [None]:

class Learner:
    """Minimal training loop: alternates train and validation passes, printing per-pass metrics.

    Parameters
    ----------
    model : callable as ``model(xb)``; its ``training`` attribute selects train vs eval behaviour
    dls : object with ``.train`` and ``.valid`` iterables of ``(xb, yb)`` batches
        (e.g. ``DataLoaders``; a plain tuple has no ``.train`` and will not work)
    loss_func : callable ``loss_func(preds, yb)`` returning a scalar tensor
    lr : learning rate passed to ``opt_func``
    opt_func : optimizer factory, called as ``opt_func(model.parameters(), lr)``
    """
    def __init__(self, model, dls, loss_func, lr, opt_func=optim.SGD): fc.store_attr()

    def one_batch(self):
        """Forward ``self.batch`` through the model; when training, also backprop and step."""
        self.xb,self.yb = self.batch
        self.preds = self.model(self.xb)
        self.loss = self.loss_func(self.preds, self.yb)
        if self.model.training:
            self.loss.backward()
            self.opt.step()
            self.opt.zero_grad()
        # NOTE(review): torch.no_grad() only affects torch autograd; confirm whether minima
        # tensors need their own no-grad mechanism here.
        with torch.no_grad(): self.calc_stats()

    def calc_stats(self):
        """Record this batch's correct-prediction count, size-weighted loss and size."""
        n = len(self.xb)
        # assumes preds are (batch, n_classes) scores and yb holds integer class ids
        correct = (self.preds.argmax(dim=1)==self.yb).float().sum()
        # .item() stores plain numbers, so the stats lists do not hold on to per-batch
        # tensors (or any autograd graph attached to them) for the rest of the epoch.
        self.accs.append(correct.item())
        self.losses.append(self.loss.item()*n)
        self.ns.append(n)

    def one_epoch(self, train):
        """Run one pass over the train (``train=True``) or valid loader and print mean loss/accuracy."""
        # NOTE(review): only the top-level module's flag is set; if minima modules offer a
        # method that propagates the mode to children (like torch's .train()), prefer it.
        self.model.training = train
        # Reset per pass so the printed metrics describe this pass only, rather than a
        # running total over every epoch and both phases since fit() started.
        self.accs,self.losses,self.ns = [],[],[]
        dl = self.dls.train if train else self.dls.valid
        for self.num,self.batch in enumerate(dl): self.one_batch()
        n = sum(self.ns)
        # An empty loader would otherwise divide by zero.
        loss = sum(self.losses)/n if n else float('nan')
        acc = sum(self.accs)/n if n else float('nan')
        print(self.epoch, self.model.training, loss, acc)

    def fit(self, n_epochs):
        """Train for `n_epochs`, running a validation pass after each training pass."""
        self.opt = self.opt_func(self.model.parameters(), self.lr)
        self.n_epochs = n_epochs
        for self.epoch in range(n_epochs):
            self.one_epoch(True)
            with torch.no_grad(): self.one_epoch(False)