# Binary MNIST classification

In [108]:
import random
from engine import Value

# using torch to download MNIST dataset
import torch
import torchvision 
from torchvision import transforms

# MLP definition

In [109]:
class Neuron:
    """A single sigmoid unit computing ``sigmoid(w . x + b)`` over ``Value`` scalars.

    Attributes:
        w: list of ``nin`` weight ``Value`` objects.
        b: bias ``Value``.
    """

    def __init__(self, nin):
        # Draw the weights first and the bias last, each uniformly from [-1, 1].
        self.w = []
        for _ in range(nin):
            self.w.append(Value(random.uniform(-1, 1)))
        self.b = Value(random.uniform(-1, 1))

    def __call__(self, x):
        """Return the activation of this neuron for the input sequence ``x``."""
        # Start from the bias and add one weight * input product at a time.
        # zip() stops at the shorter sequence, so surplus inputs are ignored.
        act = self.b
        for weight, feature in zip(self.w, x):
            act = act + weight * feature
        return act.sigmoid()

    def parameters(self):
        """Return every trainable ``Value``: all weights followed by the bias."""
        return [*self.w, self.b]


class Layer:
    """A group of ``nout`` neurons that all read the same ``nin``-wide input."""

    def __init__(self, nin, nout):
        self.neurons = []
        for _ in range(nout):
            self.neurons.append(Neuron(nin))

    def __call__(self, x):
        """Evaluate every neuron on ``x``.

        Returns the bare output when the layer has exactly one neuron,
        otherwise the list of outputs (one per neuron, in order).
        """
        activations = [neuron(x) for neuron in self.neurons]
        if len(activations) == 1:
            return activations[0]
        return activations

    def parameters(self):
        """Return the parameters of all neurons as one flat list."""
        collected = []
        for neuron in self.neurons:
            collected.extend(neuron.parameters())
        return collected


class MLP:
    """Multi-layer perceptron: a chain of ``Layer`` objects applied in order.

    Args:
        nin: number of input features.
        nouts: widths of the successive layers; any iterable of ints
            (list, tuple, ...). The last entry is the output width.
    """

    def __init__(self, nin, nouts):
        # Unpacking accepts any iterable, whereas ``[nin] + nouts`` only
        # worked when ``nouts`` was a list.
        widths = [nin, *nouts]
        self.layers = [
            Layer(fan_in, fan_out) for fan_in, fan_out in zip(widths, widths[1:])
        ]

    def __call__(self, x):
        """Feed ``x`` through every layer and return the last layer's output."""
        for depth, layer in enumerate(self.layers):
            # A width-1 layer hands back a bare value rather than a list.
            # Wrap it so the following layer can still iterate over its
            # inputs. The caller-supplied input (depth 0) is left untouched.
            if depth > 0 and not isinstance(x, list):
                x = [x]
            x = layer(x)
        return x

    def parameters(self):
        """Return the parameters of all layers as one flat list."""
        return [p for layer in self.layers for p in layer.parameters()]


# Training

In [110]:
# Convert each image to a tensor, then normalise it with mean 0.1307 and
# std 0.3081 (presumably the MNIST training-set statistics - confirm).
train_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
])

# Training split; download=True fetches the archives into ./data when needed.
train_dataset = torchvision.datasets.MNIST(
    root='./data',
    train=True,
    transform=train_transform,
    download=True,
)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw



In [111]:
# Keep only the digits 0 and 1 so the task becomes binary classification.
# The dataset is walked exactly once; images and labels stay index-aligned.
binary_train_pairs = [
    (image, label) for image, label in train_dataset if label in (0, 1)
]
xs = [image for image, _ in binary_train_pairs]
ys = [label for _, label in binary_train_pairs]

In [112]:
# Sanity check: the image and label lists must have the same length.
len(xs), len(ys)

(12665, 12665)

In [113]:
# Collapse every image tensor to 1-D so it can be fed to the MLP as a flat
# feature vector.
xs = [torch.flatten(image) for image in xs]

In [None]:
# Turn each tensor into a plain Python list of floats; the Value-based
# neurons multiply their weights with the inputs one scalar at a time.
xs = [tensor.tolist() for tensor in xs]

In [118]:
# Seed Python's RNG before the weights are drawn (Neuron uses random.uniform),
# so initialisation - and every number derived from it - is reproducible
# after a kernel restart.
random.seed(42)

# Input width is taken from the data (784 for flattened 28x28 MNIST images)
# instead of being hard-coded; two hidden layers of 16 units, one sigmoid output.
n = MLP(len(xs[0]), [16, 16, 1])

# Smoke test: the untrained network should emit a single Value for one image.
n(xs[0])

Value(data=0.7329184707686976)

In [119]:
def log_loss_val(pred: Value, target: int):
    """Binary cross-entropy between a predicted probability and a 0/1 target.

    Args:
        pred: the model output as a ``Value``.
        target: the true label (0 or 1).

    Returns:
        A ``Value`` holding ``-(t * log(p) + (1 - t) * log(1 - p))``.
    """
    positive_term = target * pred.log()
    negative_term = (1 - target) * (1 - pred).log()
    return -(positive_term + negative_term)

# Worst case: predicting exactly 1.0 for a target of 0 takes log(1 - 1).
# NOTE(review): the result is finite, so Value.log presumably clamps its
# argument away from zero - confirm in engine.
a = Value(1.0)
log_loss_val(a, 0.0)

Value(data=13.815510557964274)

In [120]:
# SGD with one example per update; far fewer updates would probably be
# enough and would save time.
NUM_UPDATES = 1000
LEARNING_RATE = 0.1

# The parameter objects never change during training, so collect them once
# instead of re-walking every layer twice per step.
params = n.parameters()

# Never index past the end of the training data.
for k in range(min(NUM_UPDATES, len(xs))):
    x = xs[k]
    y = ys[k]

    # forward pass
    ypred = n(x)
    loss = log_loss_val(ypred, y)

    # backward pass: clear the gradients left over from the previous step
    # (assumes Value.backward() accumulates into .grad - confirm in engine)
    for p in params:
        p.grad = 0.0
    loss.backward()

    # update: step against the gradient
    for p in params:
        p.data += -LEARNING_RATE * p.grad

    print(k, loss.data)


0 1.3201975700983304
1 0.4289396961294389
2 0.501282814794764
3 0.3734861206064659
4 0.19984579775294842
5 2.0669692056787667
6 0.334165153218932
7 0.29620568655098656
8 1.3909599025071806
9 1.088080672464734
10 0.5502548292173876
11 1.423318039265411
12 1.3902456225359145
13 0.28657204169860495
14 1.3903703163110397
15 0.3200585348270228
16 1.2148779711050313
17 1.3023598570298394
18 0.2553177309601149
19 0.2766008236718173
20 1.3974460984587393
21 0.2780028487660291
22 0.2740176070883971
23 1.735127599142037
24 1.3609730191540474
25 1.800396445936939
26 0.43997614879451497
27 0.25302514794296327
28 0.3900939907871834
29 0.2431536228638919
30 1.429686630093435
31 0.20287774689073876
32 0.2294699906670556
33 1.7578618674233804
34 1.2084902496226912
35 1.404351389504489
36 1.0960302129210155
37 0.3614240807027042
38 0.2320585472403726
39 0.17500245959449925
40 0.14575200343479694
41 1.8864153290063834
42 1.719642312242307
43 0.15611420875158497
44 0.1938378023407314
45 0.143697443617994

# Testing

In [121]:
# Same preprocessing as the training split: tensor conversion followed by
# normalisation with mean 0.1307 and std 0.3081.
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
])

# Held-out split (train=False).
test_dataset = torchvision.datasets.MNIST(
    root='./data',
    train=False,
    transform=test_transform,
    download=True,
)

In [122]:
# Keep only the digits 0 and 1 from the test split.
# The dataset is walked exactly once; images and labels stay index-aligned.
binary_test_pairs = [
    (image, label) for image, label in test_dataset if label in (0, 1)
]
test_xs = [image for image, _ in binary_test_pairs]
test_ys = [label for _, label in binary_test_pairs]

In [123]:
# Sanity check: the test image and label lists must have the same length.
len(test_xs), len(test_ys)

(2115, 2115)

In [124]:
# Collapse every test image tensor to 1-D so it can be fed to the MLP.
test_xs = [torch.flatten(image) for image in test_xs]

In [125]:
# Turn each tensor into a plain Python list of floats for the Value-based MLP.
test_xs = [tensor.tolist() for tensor in test_xs]

In [136]:
# Only the first 100 test examples are scored, for the sake of time.
eval_pairs = list(zip(test_xs[:100], test_ys[:100]))

# Threshold the sigmoid output at 0.5 to get a hard 0/1 prediction.
hard_preds = [1 if n(features).data >= 0.5 else 0 for features, _ in eval_pairs]

correct = sum(
    label == guess for (_, label), guess in zip(eval_pairs, hard_preds)
)
total = len(eval_pairs)

print(f"accuracy: {correct / total}")

accuracy: 0.99
