In [73]:
#export
import pickle, gzip, math, os, time, shutil, torch, matplotlib as mpl, numpy as np
import pandas as pd
from pathlib import Path
from torch import nn
import torch
import torch.nn.functional as F
import torch.optim as optim

from torch.utils.data import DataLoader, default_collate
from typing import Mapping, Tuple, Sequence, Union, TypeVar, List, Callable

import miniai.dataset as ds
import miniai.training as training
from datasets import Dataset
from torch.utils.data import DataLoader
import datasets

In [2]:
# download MNIST from the Hugging Face Hub (a local cache is reused on later runs)
mnist = datasets.load_dataset('mnist')

Found cached dataset mnist (/Users/diegomedina-bernal/.cache/huggingface/datasets/mnist/mnist/1.0.0/9d494b7f466d6931c64fb39d58bb1249a4d85c9eb9865d9bc20960b999e2a332)
100%|██████████| 2/2 [00:00<00:00, 106.70it/s]


In [3]:
# inspect the splits: 'train' (60,000 rows) and 'test' (10,000 rows), each with 'image' and 'label'
mnist

DatasetDict({
    train: Dataset({
        features: ['image', 'label'],
        num_rows: 60000
    })
    test: Dataset({
        features: ['image', 'label'],
        num_rows: 10000
    })
})

In [4]:
# names of the input and target columns in each split
# NOTE(review): short generic globals like these are easy to shadow by accident in later cells
x, y = "image", "label"

In [5]:
# For classification tasks, it's common to use a stride of 2 and a padding of 1 (with a 3x3 kernel). A stride of 2 roughly halves the height and width of the feature map at each layer, which cuts down the computation in the layers that follow.

In [6]:
# have the HF dataset return torch tensors for the image and label columns
mnist.set_format(type='torch', columns=[x, y])

In [8]:
# shape of one raw training image: (height, width), no channel dimension yet
mnist['train']['image'][0].shape

torch.Size([28, 28])

In [30]:
# quickly build a Dataset and DataLoaders from the Hugging Face datasets (images)
class DS:
    """Minimal map-style dataset over in-memory image and label tensors.

    Images are reshaped to (N, 1, 28, 28), cast to float32 and divided by 255
    (mapping uint8 pixel values to [0, 1]). Labels are flattened to a 1-D
    int64 tensor of class indices.
    """
    def __init__(self, x, y):
        # add the single (grayscale) channel dim: Conv2d expects (N, C, H, W).
        # reshape also works in the cases where view would reject the memory layout.
        self.x = x.reshape(-1, 1, 28, 28).float() / 255.
        # flatten labels to shape (N,) with int64 dtype
        self.y = y.long().reshape(-1)

    def __len__(self):
        return len(self.x)

    def __getitem__(self, i):
        """Return the (image, label) pair at index `i`."""
        return self.x[i], self.y[i]

def to_ds(hf_ds):
    """Wrap a Hugging Face split exposing 'image' and 'label' columns in a `DS`."""
    images = hf_ds['image']
    labels = hf_ds['label']
    return DS(images, labels)
def to_dls(train_ds, valid_ds, bs, **kwargs):
    """Build (train, valid) DataLoaders.

    The training loader shuffles and uses batch size `bs`. The validation
    loader keeps dataset order and uses batch size `2 * bs`. Any extra keyword
    arguments are forwarded to both DataLoader constructors.
    """
    train_loader = DataLoader(train_ds, batch_size=bs, shuffle=True, **kwargs)
    valid_loader = DataLoader(valid_ds, batch_size=bs * 2, **kwargs)
    return train_loader, valid_loader

In [31]:
# wrap both splits; the 'test' split serves as the validation set in this notebook
train_ds, valid_ds = to_ds(mnist['train']), to_ds(mnist['test'])

In [32]:
# inspect one sample under new names, so the column-name globals `x`/`y`
# ("image"/"label") stay intact and mnist.set_format can be re-run safely
x0, y0 = train_ds[0]
x0.shape, y0.shape, y0

(torch.Size([1, 28, 28]), torch.Size([]), tensor(5))

In [33]:
# the first training sample as an (image, label) tuple returned by the DS wrapper
train_ds[0]

(tensor([[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 

In [34]:
# wrap the datasets in DataLoaders (batch size 256 for training, twice that for validation)
train_dl, valid_dl = to_dls(train_ds, valid_ds, 256)

In [35]:
# grab one (shuffled) training batch to check the shapes the model will receive
xb, yb = next(iter(train_dl))
xb.shape, yb.shape

(torch.Size([256, 1, 28, 28]), torch.Size([256]))

## Creating the CNN

In [36]:
#export
def conv(ni, nf, ks=3, stride=2, act=True):
    """Conv2d from `ni` to `nf` channels with padding `ks // 2`, optionally followed by ReLU.

    For odd `ks` and the default stride of 2, the output height/width is
    ceil(input / 2). Pass `act=False` to get the bare Conv2d (e.g. for a final
    logits layer).
    """
    layer = nn.Conv2d(ni, nf, kernel_size=ks, stride=stride, padding=ks // 2)
    return nn.Sequential(layer, nn.ReLU()) if act else layer

In [39]:
# create a simple model: five stride-2 convs (ks=3, padding=1) shrink the 28x28 input
# 28 -> 14 -> 7 -> 4 -> 2 -> 1, so the last conv leaves 10 channels at 1x1
simple_cnn = nn.Sequential(
    conv(1, 4),
    conv(4, 8),
    conv(8, 16),
    conv(16, 16),
    conv(16, 10, act=False),  # no ReLU on the last layer: its outputs are the raw class scores
    nn.Flatten() # flatten (N, 10, 1, 1) to (N, 10) to get one score per class
)

In [40]:
# sanity-check the forward pass on the sample batch: one 10-way score vector per image
out = simple_cnn(xb)
out.shape

torch.Size([256, 10])

In [42]:
#export
# pick the default device: Apple MPS first, then CUDA, falling back to CPU
if torch.backends.mps.is_available():
    def_device = 'mps'
elif torch.cuda.is_available():
    def_device = 'cuda'
else:
    def_device = 'cpu'

def to_device(x, device=def_device):
    """Move every tensor inside `x` to `device`, recursing through containers.

    - Tensor: moved with `.to(device)`.
    - Mapping: rebuilt as a plain dict, with each value converted recursively
      (so nested dicts/lists of tensors work, not only flat dicts of tensors).
    - str/bytes and other non-iterable leaves (ints, None, ...): returned unchanged.
    - Any other iterable (list, tuple, ...): rebuilt with `type(x)(...)` from its
      converted elements.
    """
    if isinstance(x, torch.Tensor): return x.to(device)
    if isinstance(x, Mapping): return {k: to_device(v, device) for k, v in x.items()}
    # iterating a str yields strs, so recursing into one would never bottom out
    if isinstance(x, (str, bytes)) or not hasattr(x, '__iter__'): return x
    return type(x)(to_device(o, device) for o in x)

def collate_device(b):
    """Collate a list of samples with `default_collate`, then move the batch to the default device."""
    batch = default_collate(b)
    return to_device(batch)

In [44]:
# We take a more classical approach and write our own fit function again, mainly to repeat it for learning purposes.
# cross-entropy on raw scores (logits) with integer class-index targets
loss_fn =  F.cross_entropy

In [49]:
# loss of the untrained model on the sample batch; near-uniform predictions give about ln(10) ≈ 2.30
loss_fn(out, yb)

tensor(2.3061, grad_fn=<NllLossBackward0>)

In [82]:
#export
def accuracy(out, yb):
    """Mean fraction of rows whose highest-scoring class (argmax over dim 1) equals the target."""
    preds = out.argmax(dim=1)
    hits = (preds == yb).float()
    return hits.mean()

In [99]:
# number of correct predictions in the batch, recovered from the mean accuracy
((out.argmax(dim=1)==yb).float()).mean() * yb.size(0) 

tensor(21.)

In [100]:
# the same count computed directly by summing the boolean matches
(out.argmax(1)==yb).sum()

tensor(21)

In [86]:
def _run_epoch(model, loss_fn, dl, device, opt=None):
    """Run one pass over `dl` and return batch-size-weighted (mean_loss, mean_accuracy).

    If `opt` is given, every batch also does backward + optimizer step (a
    training pass); otherwise only forward pass and metrics are computed.
    The caller handles `model.train()` / `model.eval()` and `torch.no_grad()`.
    Returns (nan, nan) for an empty loader instead of dividing by zero.
    """
    tot_loss, tot_acc, n = 0., 0., 0
    for xb, yb in dl:
        xb, yb = xb.to(device), yb.to(device)
        out = model(xb)
        loss = loss_fn(out, yb)
        if opt is not None:
            loss.backward()
            opt.step()
            opt.zero_grad()
        bs = yb.size(0)
        n += bs
        # assumes loss_fn and accuracy return per-batch means (F.cross_entropy's default
        # reduction; accuracy uses .mean()), so weight by batch size to get the true
        # dataset mean even when the last batch is smaller
        tot_loss += loss.item() * bs
        tot_acc += accuracy(out, yb).item() * bs
    if n == 0:
        return float('nan'), float('nan')
    return tot_loss / n, tot_acc / n

def fit(epochs, model, loss_fn, opt, train_dl, valid_dl, device=def_device):
    """Train `model` with `opt` for `epochs` epochs, printing train/valid accuracy and loss each epoch.

    Batches are moved to `device` as they are loaded. Validation runs with the
    model in eval mode under `torch.no_grad()`. Returns None.
    """
    # NOTE(review): nn.Module.to moves parameters in place, so an optimizer created
    # before this call should still reference the same parameters
    model = model.to(device)
    for epoch in range(epochs):
        model.train()
        train_loss, train_acc = _run_epoch(model, loss_fn, train_dl, device, opt)
        model.eval()
        with torch.no_grad():
            valid_loss, valid_acc = _run_epoch(model, loss_fn, valid_dl, device)
        print(f"Epoch: {epoch+1}/{epochs}")
        print(f"train_acc: {train_acc} train_loss: {train_loss}")
        print(f"valid_acc: {valid_acc} valid_loss: {valid_loss}")
        print("-------"*3)

In [87]:
# INSTANTIATE AND CHANGE THINGS HERE
# create a simple model - just moved it here to put everything in the same spot to instantiate
# Seed torch's global RNG first so weight initialisation is reproducible when this cell is re-run.
# train_dl was built without its own generator, so its shuffle order follows this seed as well
# (bitwise determinism of the kernels on mps/cuda is still not guaranteed).
torch.manual_seed(42)
simple_cnn = nn.Sequential(
    conv(1, 4),
    conv(4, 8),
    conv(8, 16),
    conv(16, 16),
    conv(16, 10, act=False),
    nn.Flatten() # flatten the output to just get the classification
)

loss_fn = F.cross_entropy
lr = 1e-2
opt = optim.Adam(simple_cnn.parameters(), lr=lr)

In [88]:
# Train the model: 10 epochs with the optimizer above; metrics are printed after every epoch
fit(
    epochs=10, model=simple_cnn, loss_fn=loss_fn, opt=opt,
    train_dl=train_dl, valid_dl=valid_dl, device=def_device)

Epoch: 1/10
train_acc: 0.8591499924659729 train_loss: 0.4430797946373622
valid_acc: 0.953000009059906 valid_loss: 0.14700570645332336
---------------------
Epoch: 2/10
train_acc: 0.9566166400909424 train_loss: 0.14215702647368114
valid_acc: 0.9653000235557556 valid_loss: 0.10716609196662903
---------------------
Epoch: 3/10
train_acc: 0.9674999713897705 train_loss: 0.10673344823122025
valid_acc: 0.9668999910354614 valid_loss: 0.10464389328956604
---------------------
Epoch: 4/10
train_acc: 0.9703166484832764 train_loss: 0.09585008159081142
valid_acc: 0.9731000065803528 valid_loss: 0.08459901063442231
---------------------
Epoch: 5/10
train_acc: 0.9746666550636292 train_loss: 0.08311864763895671
valid_acc: 0.9653000235557556 valid_loss: 0.10152685663700103
---------------------
Epoch: 6/10
train_acc: 0.9753000140190125 train_loss: 0.07941865905026595
valid_acc: 0.9753999710083008 valid_loss: 0.07555193499326705
---------------------
Epoch: 7/10
train_acc: 0.9767333269119263 train_loss: 

In [89]:
# tune it further with a smaller learning rate (lr/4)
# NOTE: this constructs a new Adam optimizer, so the running moment estimates held by
# the previous `opt` are discarded; lowering the lr via `opt.param_groups` would keep them.
opt = optim.Adam(simple_cnn.parameters(), lr=lr/4)
fit(
    epochs=4, model=simple_cnn, loss_fn=loss_fn, opt=opt,
    train_dl=train_dl, valid_dl=valid_dl, device=def_device)

Epoch: 1/4
train_acc: 0.9857333302497864 train_loss: 0.04549092020293077
valid_acc: 0.9822999835014343 valid_loss: 0.05534515790939331
---------------------
Epoch: 2/4
train_acc: 0.9864166378974915 train_loss: 0.04183988188902537
valid_acc: 0.9818999767303467 valid_loss: 0.059020787489414216
---------------------
Epoch: 3/4
train_acc: 0.9878166913986206 train_loss: 0.0397154920215408
valid_acc: 0.980400025844574 valid_loss: 0.0611995587348938
---------------------
Epoch: 4/4
train_acc: 0.9880666732788086 train_loss: 0.03836365413069725
valid_acc: 0.9818999767303467 valid_loss: 0.05482535640001297
---------------------


In [None]:
# TODO: implement https://arxiv.org/pdf/2306.16999.pdf