In [1]:
import pickle, gzip, math, os, time, shutil
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path

# torch
import torch
from torch import tensor, nn
import torch.nn.functional as F

# huggingface datasets
import datasets
from datasets import Dataset
from torch.utils.data import DataLoader

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Download MNIST from Hugging Face datasets (the recorded run reused a cached copy)
mnist = datasets.load_dataset('mnist')

Found cached dataset mnist (/Users/diegomedina-bernal/.cache/huggingface/datasets/mnist/mnist/1.0.0/9d494b7f466d6931c64fb39d58bb1249a4d85c9eb9865d9bc20960b999e2a332)
100%|██████████| 2/2 [00:00<00:00, 680.23it/s]


In [3]:
# Switch the dataset's output format to 'torch'.
mnist.set_format('torch')
# mnist.set_format(type='np', columns=['image', 'label'])  # NumPy alternative, left disabled

# Pull the 'image' and 'label' columns out of the train and test splits.
train = mnist['train']
test = mnist['test']
x_train, y_train = train['image'], train['label']
x_test, y_test = test['image'], test['label']

In [55]:
class DS:
    """Minimal map-style dataset pairing flattened 28x28 images with labels.

    Provides `__len__` and `__getitem__`, which is all `DataLoader` needs
    from a map-style dataset.

    Args:
        x: image tensor whose elements regroup into rows of 28*28 values
           (e.g. shape `(N, 28, 28)`). Values are cast to float and
           divided by 255.
        y: tensor of class labels, one per image; stored as a flat int64
           vector.
    """
    def __init__(self, x, y):
        # reshape instead of view: view raises on non-contiguous tensors,
        # reshape only copies when a view is impossible.
        self.x = x.reshape(-1, 28*28).float()/255.
        # Flat int64 labels (class-index targets for F.cross_entropy).
        self.y = y.long().reshape(-1)

    def __len__(self):
        """Number of samples (rows of the flattened image matrix)."""
        return len(self.x)

    def __getitem__(self, i):
        """Return the `(image, label)` pair at index (or slice) `i`."""
        return self.x[i], self.y[i]
    
# Wrap the raw tensors in the minimal DS dataset class
train_ds = DS(x_train, y_train)
test_ds = DS(x_test, y_test)


# One batch size for both loaders, so the two cannot drift apart.
BS = 256
# Training batches are reshuffled every epoch; evaluation order stays fixed.
train_dl = DataLoader(train_ds, batch_size=BS, shuffle=True)
test_dl = DataLoader(test_ds, batch_size=BS, shuffle=False)

In [56]:
# 2 layer nn
class Model(nn.Module):
    """Two-layer MLP: Linear -> activation -> Linear.

    Args:
        nin: number of input features.
        nh: width of the hidden layer.
        nout: number of outputs.
        act_fn: activation module *instance* placed between the two linear
            layers (must be an `nn.Module`).
    """
    def __init__(self, nin, nh, nout, act_fn=nn.ReLU()):
        super().__init__()
        assert isinstance(act_fn, nn.Module)
        # A plain Python list hides the layers from nn.Module: parameters(),
        # state_dict(), .to() and any optimizer would see nothing and the
        # model could never train. nn.ModuleList registers each layer while
        # staying iterable and indexable like a list.
        self.layers = nn.ModuleList([
            nn.Linear(nin, nh),
            act_fn,
            nn.Linear(nh, nout)])

    def forward(self, x):
        """Apply the layers in order and return the final linear output."""
        for l in self.layers: x = l(x)
        return x

In [57]:
# Sanity check: pull one batch from the training loader and inspect its shapes
xb, yb = next(iter(train_dl))
xb.shape, yb.shape

(torch.Size([256, 784]), torch.Size([256]))

In [58]:
# Instantiate the simple two-layer model
model = Model(
    nin=28*28,           # 28x28 image, already flattened to 784 features by DS
    nh=50,               # hidden layer size
    nout=10,             # output size
    act_fn=nn.ReLU()     # activation function
    )

In [59]:
# Smoke test: run one batch through the model and check the output shape
out = model(xb)
out.shape

torch.Size([256, 10])

In [60]:
# Compare the model-output shape with the label shape for the same batch
out.shape, yb.shape

(torch.Size([256, 10]), torch.Size([256]))

In [61]:
# Reference result: PyTorch's built-in log-softmax along dim 1 of the outputs
F.log_softmax(out, dim=1)

tensor([[-2.1796, -2.3081, -2.2505,  ..., -2.3799, -2.4470, -2.3495],
        [-2.1958, -2.3052, -2.2289,  ..., -2.4134, -2.3280, -2.3587],
        [-2.2498, -2.2982, -2.3291,  ..., -2.4082, -2.3797, -2.3591],
        ...,
        [-2.1820, -2.2844, -2.3027,  ..., -2.3895, -2.3692, -2.3150],
        [-2.1812, -2.2435, -2.3376,  ..., -2.4798, -2.3794, -2.3735],
        [-2.1512, -2.2716, -2.2467,  ..., -2.3967, -2.3707, -2.4110]],
       grad_fn=<LogSoftmaxBackward0>)

In [62]:
def log_softmax(x):
    return (x.exp()/(x.exp().sum(-1, keepdim=True))).log()

In [63]:
def accuracy(out, yb):
    """Return the fraction of rows in `out` whose arg-max along dim 1 equals `yb`."""
    predicted = out.argmax(dim=1)
    hits = predicted == yb
    return hits.float().mean()


def report(loss, preds, yb):
    """Print `loss` and the accuracy of `preds` against `yb`, two decimals each."""
    print(f"{loss:.2f}, {accuracy(preds, yb):.2f}")

In [64]:
# Same architecture as Model, rebuilt on nn.Sequential so there is a fresh
# model to experiment with.
class SimpleModel(nn.Module):
    """Linear -> activation -> Linear network wrapped in `nn.Sequential`.

    Args:
        nin: number of input features.
        nh: hidden layer width.
        nout: number of outputs (default 10).
        act_fn: activation *class* (or zero-argument factory); it is called
            once to create the activation layer.
    """
    def __init__(self, nin, nh, nout=10, act_fn=nn.ReLU):
        super().__init__()
        # Built in forward order: hidden projection, activation, output head.
        hidden = nn.Linear(nin, nh)
        activation = act_fn()
        head = nn.Linear(nh, nout)
        self.layers = nn.Sequential(hidden, activation, head)

    def forward(self, x):
        """Run `x` through the sequential stack and return the result."""
        result = self.layers(x)
        return result

In [67]:
# Rebind `model` to a fresh SimpleModel: 28*28 inputs, 50 hidden units, default nout
model = SimpleModel(28*28, 50, act_fn=nn.ReLU)

In [68]:
# Train the simple model with Adam and cross-entropy loss over train_dl.
# Every 10th epoch, print the loss and accuracy of the LAST training batch
# of that epoch. These are noisy single-batch training metrics, not a
# held-out evaluation (test_dl is not used here).
lr = 1e-2
opt = torch.optim.Adam(model.parameters(), lr=lr)
epochs = 100
for epoch in range(epochs):
    for xb, yb in train_dl:
        preds = model(xb)
        loss = F.cross_entropy(preds, yb)
        loss.backward()
        opt.step()
        # clear gradients after the update so the next batch starts from zero
        opt.zero_grad()
    # loss/preds/yb still hold the values from the final batch of this epoch
    if epoch % 10 == 0: report(loss, preds, yb)

0.16, 0.96
0.08, 0.98
0.00, 1.00
0.09, 0.97
0.02, 0.99
0.00, 1.00
0.00, 1.00
0.00, 1.00
0.00, 1.00
0.18, 0.98
