In [None]:
from fastai.data.core import DataLoaders
from fastai.learner import Learner
from fastai.metrics import accuracy
from fastai.vision.all import torch, tensor, untar_data, Image, URLs, DataLoader, Path

In [None]:
def init_param(size, std=1.0):
    """Create a trainable parameter tensor of shape ``size``.

    Values are drawn from a normal distribution with mean 0 and standard
    deviation ``std``. With the default ``std=1.0`` this is exactly
    ``torch.randn(size)``. A smaller ``std`` keeps activations of wide
    layers from blowing up.

    Args:
        size: shape passed to ``torch.randn`` (an int or a tuple of ints).
        std: multiplier applied to the standard-normal sample.

    Returns:
        A leaf tensor with ``requires_grad=True``.
    """
    return (torch.randn(size) * std).requires_grad_()

# Layers

In [None]:
class Linear:
    """Fully connected layer computing ``x @ w + b``.

    ``w`` has shape ``(in_, out_)``. ``b`` has one entry per output unit
    (shape ``(out_,)``) and broadcasts across the leading batch dimension
    of ``x @ w``.
    """

    def __init__(self, in_, out_):
        self.w = init_param((in_, out_))
        # One bias per output unit. A single shared scalar bias (shape (1,))
        # would force every unit in the layer to use the same offset.
        self.b = init_param(out_)

    def __call__(self, x):
        return x @ self.w + self.b

    @property
    def params(self):
        """Trainable tensors of this layer: ``[w, b]``."""
        return [self.w, self.b]

In [None]:
class ReLU:
    """Element-wise rectified linear unit with no trainable parameters.

    Positive entries pass through unchanged. Every other entry (including
    zero and NaN, since the comparison ``x > 0.0`` is False for those)
    becomes 0.
    """

    def __call__(self, x):
        keep = x > 0.0
        return torch.where(keep, x, torch.zeros_like(x))

# Sequential API

In [None]:
class Sequential:
    """Compose callables so that each one's output feeds the next.

    ``params`` gathers the ``params`` of every layer that defines that
    attribute, in layer order. Layers without it (e.g. activations) are
    skipped.
    """

    def __init__(self, *args):
        self.layers = [*args]

    def __call__(self, x):
        out = x
        for layer in self.layers:
            out = layer(out)
        return out

    @property
    def params(self):
        return [
            param
            for layer in self.layers
            if hasattr(layer, "params")
            for param in layer.params
        ]

# Optimizer

In [None]:
class SGD:
    """Plain stochastic gradient descent: ``param <- param - lr * grad``.

    Parameters whose ``grad`` is ``None`` (e.g. not reached by the last
    backward pass) are skipped instead of raising, as ``torch.optim.SGD``
    does.
    """

    def __init__(self, params, lr):
        self.params = list(params)
        self.lr = lr

    def step(self):
        """Apply one in-place update to every parameter that has a gradient."""
        # no_grad: the update itself must not be recorded by autograd, and it
        # allows the in-place op on leaf tensors that require grad.
        with torch.no_grad():
            for param in self.params:
                if param.grad is None:
                    continue
                param -= param.grad * self.lr

    def zero_grad(self):
        """Drop accumulated gradients by resetting ``grad`` to ``None``."""
        for param in self.params:
            param.grad = None


In [None]:
def mnist_loss(preds, targets):
    """Mean distance between sigmoid predictions and binary targets.

    Args:
        preds: raw model outputs (logits); squashed with ``sigmoid`` here.
        targets: entries equal to 1 are the positive class; anything else is
            treated as negative.

    Returns:
        A scalar tensor: the mean of ``1 - p`` over positives and ``p`` over
        negatives.

    Both tensors are flattened first. This makes model outputs of shape
    ``(N, 1)`` line up element-wise with targets of shape ``(N,)``.
    Otherwise broadcasting silently pairs every prediction with every
    target, producing an ``(N, N)`` matrix and a meaningless loss.
    """
    preds = preds.sigmoid().reshape(-1)
    targets = targets.reshape(-1)
    return torch.where(targets == 1, 1 - preds, preds).mean()

In [None]:
def acc(preds, targets):
    """Fraction of predictions whose thresholded sigmoid matches the target.

    Args:
        preds: raw model outputs (logits); a prediction counts as positive
            when ``sigmoid(pred) > 0.5``.
        targets: binary labels (1/0 or True/False).

    Returns:
        Accuracy as a Python float.

    Both inputs are flattened so ``(N, 1)`` predictions compare element-wise
    with ``(N,)`` targets rather than broadcasting to an ``(N, N)`` grid.
    """
    preds = preds.sigmoid().reshape(-1)
    targets = targets.reshape(-1)
    return ((preds > 0.5) == targets).float().mean().item()

# Trainer

In [None]:
class Trainer:
    """Minimal epoch-based training loop.

    Args:
        train_dl: iterable of ``(x, y)`` batches used for optimization.
        val_dl: iterable of ``(x, y)`` batches used for evaluation.
        model: callable mapping ``x`` to predictions; must expose ``.params``.
        lr: learning rate handed to ``optim_cls``.
        optim_cls: constructed as ``optim_cls(model.params, lr)``; the result
            must provide ``step()`` and ``zero_grad()``.
        loss_fn: ``loss_fn(preds, targets)`` returning a single-element
            tensor that supports ``.backward()``.
    """

    def __init__(self, train_dl, val_dl, model, lr, optim_cls, loss_fn):
        self.train_dl, self.val_dl = train_dl, val_dl
        self.model = model
        self.optim = optim_cls(self.model.params, lr)
        self.loss_fn = loss_fn

    def _calc_grad(self, x, y):
        """Forward + backward on one batch; return the batch loss as a float."""
        batch_loss = self.loss_fn(self.model(x), y)
        batch_loss.backward()
        return batch_loss.detach().cpu().item()

    def _train_step(self, *args, **kwargs):
        """Run one pass over ``train_dl``, updating after every batch.

        Returns the mean of the per-batch losses.
        """
        batch_losses = []
        for xb, yb in self.train_dl:
            batch_losses.append(self._calc_grad(xb, yb))
            # Gradients are cleared right after each update.
            self.optim.step()
            self.optim.zero_grad()
        return tensor(batch_losses).mean().item()

    def _val_step(self, *args, **kwargs):
        """Evaluate on the whole of ``val_dl`` without tracking gradients.

        Returns ``(loss, accuracy)``: ``loss_fn`` on the concatenated outputs
        and ``acc`` on the flattened outputs/targets.
        """
        with torch.no_grad():
            outputs = [(self.model(xb), yb) for xb, yb in self.val_dl]

        all_preds = torch.cat([pred for pred, _ in outputs]).float()
        all_targets = torch.cat([target for _, target in outputs]).float()

        val_loss = self.loss_fn(all_preds, all_targets).item()
        val_acc = acc(all_preds.reshape(-1), all_targets.reshape(-1))
        return val_loss, val_acc

    def fit(self, epochs):
        """Train for ``epochs`` passes, printing losses and accuracy each epoch."""
        for epoch in range(epochs):
            train_loss = self._train_step()
            val_loss, val_acc = self._val_step()
            print(
                f"Epoch [{epoch}]: train_loss: {round(train_loss, 4)}, val_loss: {round(val_loss, 4)}, val_acc:{round(val_acc, 4)}"
            )

In [None]:
# download the data
path = untar_data(URLs.MNIST_SAMPLE)
print(path)

# NOTE(review): presumably makes fastai display Paths relative to `path` — verify
Path.BASE_PATH = path


def _load_digit_images(root, split, digit):
    """Stack every image in ``root/split/digit`` into one float tensor.

    Pixel values are divided by 255, so they end up in [0, 1] assuming
    8-bit images.
    """
    files = (root / split / digit).ls()
    return torch.stack([tensor(Image.open(f)) for f in files]).float() / 255


# load and convert the images to float tensors
train_3 = _load_digit_images(path, "train", "3")
train_7 = _load_digit_images(path, "train", "7")

val_3 = _load_digit_images(path, "valid", "3")
val_7 = _load_digit_images(path, "valid", "7")

# training and validation dataset: flatten each image to a 28*28 vector;
# 3s are labelled 1.0 and 7s 0.0
train_X = torch.cat([train_3, train_7]).reshape(-1, 28 * 28)
train_y = tensor([1.0] * len(train_3) + [0.0] * len(train_7))

val_X = torch.cat([val_3, val_7]).reshape(-1, 28 * 28)
val_y = tensor([1.0] * len(val_3) + [0.0] * len(val_7))

# dataloaders
train_dl = DataLoader(
    list(zip(train_X, train_y)), bs=16, num_workers=4, shuffle=True
)
val_dl = DataLoader(list(zip(val_X, val_y)), bs=32, num_workers=4)

# fastai DataLoaders bundle; not used by the hand-written Trainer below
dls = DataLoaders(train_dl, val_dl)

# MLP: 784 -> 32 -> 64 -> 1 logit
model = Sequential(
    Linear(28 * 28, 32),
    ReLU(),
    Linear(32, 64),
    ReLU(),
    Linear(64, 1),
)

lrn = Trainer(train_dl, val_dl, model, 9e-2, SGD, mnist_loss)
lrn.fit(40)