# Binary MNIST classification

In [1]:
import random
from engine import Value
import nn

# using torch to download MNIST dataset
import torch
import torchvision 
from torchvision import transforms

# MLP definition

In [2]:
class MLP(nn.Module):
    def __init__(self, nin, nouts):
        sz = [nin] + nouts
        self.layers = []
        for i in range(len(nouts)):
            self.layers.append(nn.Linear(sz[i], sz[i+1]))
            self.layers.append(nn.Sigmoid())

    def __call__(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

    def parameters(self):
        return [p for layer in self.layers for p in layer.parameters()]

# Training

In [3]:
train_dataset = torchvision.datasets.MNIST(root='./data',
                                           train=True,
                                           transform=transforms.Compose([
                                                  transforms.ToTensor(),
                                                  transforms.Normalize(mean = (0.1307,), std = (0.3081,))]),
                                           download=True)

In [4]:
xs = []
ys = []

for x, y in train_dataset:
    if y == 0 or y == 1:
        xs.append(x)
        ys.append(y)

In [5]:
len(xs), len(ys)

(12665, 12665)

In [6]:
xs = list(map(torch.flatten, xs))

In [7]:
xs = list(map(lambda x: x.tolist(), xs))

In [8]:
model = MLP(784, [16, 16, 1])
model(xs[0])

[Value(data=0.4172704528019941)]

In [9]:
def log_loss_val(pred: Value, target: int):
    return -(target * pred.log() + (1 - target) * (1 - pred).log())

a = Value(1.0)
log_loss_val(a, 0.0)

Value(data=13.815510557964274)

In [10]:
# SGD: can change number of updates to tradeoff accuracy for time
for k in range(100):
  x = xs[k]
  y = ys[k]
  
  # forward pass
  ypred = model(x)[0]
  loss = log_loss_val(ypred, y)
  
  # backward pass
  for p in model.parameters():
    p.grad = 0.0
  loss.backward()
  
  # update
  for p in model.parameters():
    p.data += -0.1 * p.grad
  
  if k % 10 == 0:
    print(k, loss.data)


0 0.5400303826650329
10 0.5547694195120182
20 0.7604396685418151
30 1.475858326430516
40 0.3243103639599521
50 1.0107908171172202
60 0.820473425167408
70 0.5241338501401879
80 0.22821401839691627
90 1.4242800255207686


# Testing

In [11]:
test_dataset = torchvision.datasets.MNIST(root='./data',
                                          train=False,
                                          transform=transforms.Compose([
                                                  transforms.ToTensor(),
                                                  transforms.Normalize(mean = (0.1307,), std = (0.3081,))]),
                                          download=True)

In [12]:
test_xs = []
test_ys = []

for x, y in test_dataset:
    if y == 0 or y == 1:
        test_xs.append(x)
        test_ys.append(y)

In [13]:
len(test_xs), len(test_ys)

(2115, 2115)

In [14]:
test_xs = list(map(torch.flatten, test_xs))

In [15]:
test_xs = list(map(lambda x: x.tolist(), test_xs))

In [16]:
correct = 0
total = 0
for x, y in zip(test_xs[:30], test_ys[:30]):
    pred = model(x)[0]
    num_pred = 1 if pred.data >= 0.5 else 0
    correct += (y == num_pred)
    total += 1

print(f"accuracy: {correct / total}")

accuracy: 0.8333333333333334
