In [None]:
# Change the working directory to the parent folder (presumably the project root, so that the
# local `tensorviewer` / `utils` imports below resolve — TODO confirm).
# NOTE(review): not idempotent — every re-run of this cell moves up one more directory.
%cd ..

In [None]:
%reload_ext autoreload
%autoreload 2

In [None]:
from dataclasses import dataclass, field
from typing import Protocol 

In [None]:
import numpy as np
import torch
import matplotlib.pyplot as plt

In [None]:
from tensorviewer import tv
from utils import get_mnist

# Data Preview

In [None]:
# Unpack get_mnist() into train/validation inputs and labels, converting each part to a tensor
# (torch.as_tensor reuses the underlying memory where it can instead of copying).
x_trn, y_trn, x_val, y_val = map(torch.as_tensor, get_mnist())

In [None]:
# Training rows 10-19, each reshaped into a 28x28 image for viewing.
imgs = torch.reshape(x_trn[10:20], (10, 28, 28))

In [None]:
# Display the `imgs` batch with `tv`, the project-local viewer imported from `tensorviewer`.
tv(imgs)

# MLP

In [None]:
from functools import reduce

In [None]:
# Print floats with 2 digits of precision so the logits/probabilities shown below stay compact.
torch.set_printoptions(precision=2)

In [None]:
from functools import reduce
import torch.nn as nn
import torch.nn.functional as F

class MLP(nn.Module):
    """Two-layer perceptron: Linear(n_i -> n_h) -> ReLU -> Linear(n_h -> n_o).

    Args:
        n_i: number of input features.
        n_h: width of the hidden layer.
        n_o: number of outputs.
    """

    def __init__(self, n_i, n_h, n_o):
        super().__init__()
        hidden = nn.Linear(n_i, n_h)
        output = nn.Linear(n_h, n_o)
        # ModuleList (not a plain list) so the layers' parameters are registered on this module.
        self.layers = nn.ModuleList([hidden, nn.ReLU(), output])

    def forward(self, x):
        """Feed `x` through every layer in order and return the last layer's output."""
        out = x
        for layer in self.layers:
            out = layer(out)
        return out

In [None]:
# Seed torch so the randomly initialised weights, and therefore the outputs printed below,
# are the same on every Restart & Run All.
torch.manual_seed(0)
mlp = MLP(784, 50, 10)  # 784 inputs (a flattened 28x28 image), 50 hidden units, 10 output logits

In [None]:
# Forward training rows 5-9 through the freshly initialised network: one row of 10 logits per example.
logits = mlp(x_trn[5:10])
logits

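## Log Softmax and Cross-Entropy
Cross-entropy is first built naively as `-log(softmax(logits))`, then rebuilt on a log-softmax that uses the log-sum-exp trick (subtract each row's maximum before exponentiating), which avoids overflow in `exp` and `log(0)` when a probability underflows.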

In [None]:
def softmax(t, dim: int = 1):
    """Softmax of `t` along `dim` (default 1, i.e. across each row of a 2-D batch of logits).

    The maximum along `dim` is subtracted before exponentiating. This leaves the result
    mathematically unchanged, because the factor exp(-max) cancels in the ratio. It also
    keeps `exp` from overflowing to inf, which would otherwise turn inf / inf into nan
    for large logits.

    Args:
        t: tensor of logits.
        dim: dimension to normalise over; the output sums to 1 along it.

    Returns:
        Tensor with the same shape as `t`.
    """
    e = torch.exp(t - t.max(dim=dim, keepdim=True).values)
    return e / e.sum(dim=dim, keepdim=True)

In [None]:
# Sanity check: the hand-written softmax matches PyTorch's torch.softmax over the same dimension.
assert torch.allclose(softmax(logits), torch.softmax(logits, dim=1))

In [None]:
# Per-example class probabilities: softmax normalises each row of logits so it sums to 1.
probs = softmax(logits)
probs

In [None]:
def ce(logits, targets):
    """Naive cross-entropy: batch mean of -log p[i, targets[i]], where p = softmax(logits).

    The log is taken after normalising, so a target probability that underflows to 0 gives an
    infinite loss. The `log_softmax`-based version of `ce` further down avoids this.
    """
    target_col = targets.view(-1, 1)  # gather() wants an (N, 1) index column
    p_target = softmax(logits).gather(1, target_col)[:, 0]
    return -(p_target.log().mean())

In [None]:
# Fresh logits `o` and the matching labels `y` for training rows 5-9, used by the cross-entropy checks.
o, y = mlp(x_trn[5:10]), y_trn[5:10]

In [None]:
# The naive softmax-then-log cross-entropy agrees with F.cross_entropy on this batch.
assert torch.allclose(ce(o, y), F.cross_entropy(o, y))

In [None]:
def log_softmax(t, dim: int = 1):
    """Log-softmax of `t` along `dim` (default 1), computed with the log-sum-exp trick.

    log_softmax(t) = s - log(sum(exp(s))), where s = t - max(t). Subtracting the maximum keeps
    every exponent <= 0, so `exp` cannot overflow. The sum is >= 1, so `log` never sees 0.

    The input is NOT modified: a new tensor is returned. An in-place version would
    overwrite the caller's logits. It would also raise on leaf tensors that require grad,
    and corrupt autograd when the producing op saved its output for backward.

    Args:
        t: tensor of logits.
        dim: dimension to normalise over.

    Returns:
        Tensor with the same shape as `t`, holding log-probabilities along `dim`.
    """
    shifted = t - t.max(dim=dim, keepdim=True).values
    return shifted - shifted.exp().sum(dim=dim, keepdim=True).log()

In [None]:
def ce(logits, targets):
    """Cross-entropy from log-probabilities: batch mean of -log_softmax(logits)[i, targets[i]]."""
    index = targets.view(-1, 1)  # (N, 1) column of class indices for gather()
    picked = log_softmax(logits).gather(1, index).squeeze(1)
    return picked.mean().neg()

In [None]:
# Compute the reference before calling `ce`. An in-place `log_softmax` inside `ce` would otherwise
# overwrite `o`, and F.cross_entropy would then be evaluated on log-probabilities instead of the logits.
expected = F.cross_entropy(o, y)
assert torch.allclose(ce(o, y), expected)

## Dataset

In [None]:
class Dataset:
    """Pairs inputs `x` with labels `y` of the same length.

    Indexing with `index` returns the tuple `(x[index], y[index])`, so any index the underlying
    containers accept (an int, a slice, an index array) works.
    """

    def __init__(self, x, y):
        assert len(x) == len(y)
        self.x = x
        self.y = y

    def __len__(self):
        return len(self.x)

    def __getitem__(self, index):
        x_item = self.x[index]
        y_item = self.y[index]
        return x_item, y_item

In [None]:
# Wrap the full training split so examples can be fetched as (input, label) pairs.
ds = Dataset(x_trn, y_trn)

In [None]:
# Shapes of the first example's input and of its label.
ds[0][0].shape, ds[0][1].shape

## DataLoader

In [None]:
class Sampler(Protocol):
    """Interface for the index-ordering strategies consumed by `DataLoader`."""

    def get_idx(self, dataset: Dataset):
        """Return the indices of `dataset` in the order they should be visited."""
        ...

@dataclass
class SequentialSampler(Sampler):
    """Returns the dataset's indices in their natural order, 0 .. len(dataset) - 1."""

    def get_idx(self, dataset: Dataset):
        count = len(dataset)
        return np.arange(count)

@dataclass
class RandomSampler(Sampler):
    """Returns every dataset index exactly once, in an order drawn from numpy's global RNG."""

    def get_idx(self, dataset: Dataset):
        # For an int argument, permutation(n) builds arange(n) and shuffles it in place with the
        # same global RNG, i.e. it shuffles exactly the indices SequentialSampler would return.
        return np.random.permutation(len(dataset))
    
@dataclass
class DataLoader:
    """Iterates over `dataset` in mini-batches of up to `bs` items.

    The visiting order comes from `sampler`. When none is given, a `RandomSampler` is used if
    `shuffle` is true, otherwise a `SequentialSampler`. Each batch is obtained by indexing the
    dataset with a slice of the sampler's index array. The final batch is therefore shorter
    when `len(dataset)` is not a multiple of `bs`.

    Raises:
        ValueError: if `bs` is smaller than 1.
    """
    dataset: Dataset
    bs: int = 1
    shuffle: bool = False
    sampler: Sampler = None

    def __post_init__(self):
        # Fail at construction instead of at iteration time (range step 0) or in len() (division by 0).
        if self.bs < 1:
            raise ValueError(f"bs must be a positive integer, got {self.bs}")
        if self.sampler is None:
            self.sampler = RandomSampler() if self.shuffle else SequentialSampler()

    def __iter__(self):
        idx = self.sampler.get_idx(self.dataset)
        for start in range(0, len(idx), self.bs):
            yield self.dataset[idx[start:start + self.bs]]

    def __len__(self):
        # Number of batches __iter__ yields, including a short final batch: ceil(n / bs) with
        # exact integer arithmetic. (Flooring with `//` first would drop that final batch.)
        return (len(self.dataset) + self.bs - 1) // self.bs

In [None]:
# First batch from a sequential loader: items 0-3 as an (inputs, labels) pair.
next(iter(DataLoader(ds, bs=4)))

In [None]:
# First batch from a shuffled loader: 4 randomly chosen items (differs per run unless numpy is seeded).
next(iter(DataLoader(ds, bs=4, shuffle=True)))

## Training Loop

In [None]:
def acc(x, y):
    """Fraction of positions where tensor `x` equals tensor `y` (e.g. predicted vs. true labels)."""
    hits = x.eq(y).float()
    return hits.mean()

In [None]:
class Optimizer:
    """Plain stochastic gradient descent: `param <- param - lr * param.grad`.

    Parameters whose `.grad` is still None (no backward pass has reached them yet) are skipped
    by both `step` and `zero_grad` instead of raising.

    Args:
        params: iterable of tensors/parameters to optimise.
        lr: learning rate.
    """

    def __init__(self, params, lr):
        # Materialise `params`: it may be a one-shot generator such as `module.parameters()`.
        self.params, self.lr = list(params), lr

    def step(self):
        """Apply one SGD update to every parameter that has a gradient."""
        # Update in place without recording the update in the autograd graph.
        with torch.no_grad():
            for param in self.params:
                if param.grad is not None:
                    param -= self.lr * param.grad

    def zero_grad(self):
        """Reset gradients to zero; backward() accumulates into `.grad` rather than overwriting it."""
        for param in self.params:
            if param.grad is not None:
                param.grad.zero_()

In [None]:
from collections import defaultdict

epochs = 10
lr = 0.03
seed = 42

# Seed both RNGs this run draws from: torch for the weight init, numpy for RandomSampler's shuffling.
torch.manual_seed(seed)
np.random.seed(seed)

mlp = MLP(784, 50, 10)
opt = Optimizer(mlp.parameters(), lr)

trn_dl = DataLoader(Dataset(x_trn, y_trn), bs=32, shuffle=True)
val_dl = DataLoader(Dataset(x_val, y_val), bs=64)

for epoch in range(epochs):
    cnts = defaultdict(float)

    # Train. Each batch's mean loss is weighted by its size, then divided by the number of samples
    # actually seen, so the epoch average is exact even when the last batch is short.
    n_trn = 0
    for xb, yb in trn_dl:
        loss = ce(mlp(xb), yb)
        loss.backward()
        opt.step()
        opt.zero_grad()
        cnts["train_loss"] += loss.item() * len(yb)
        n_trn += len(yb)
    cnts["train_loss"] /= max(n_trn, 1)

    # Validate without building autograd graphs. Softmax preserves the order within a row, so the
    # argmax of the logits already is the predicted class.
    n_val = 0
    with torch.no_grad():
        for xb, yb in val_dl:
            logits = mlp(xb)
            pred = logits.argmax(dim=1)
            cnts["valid_loss"] += ce(logits, yb).item() * len(yb)
            cnts["valid_acc"] += acc(pred, yb).item() * len(yb)
            n_val += len(yb)
    cnts["valid_loss"] /= max(n_val, 1)
    cnts["valid_acc"] /= max(n_val, 1)
    cnts["epoch"] = epoch

    print("Epoch {epoch:03d} | loss(trn) = {train_loss:.4f} | loss(val) = {valid_loss:.4f} | acc = {valid_acc:.4f}".format(**cnts))