"""
Same as micrograd.py, but uses PyTorch for the autograd engine.
This is a way for us to check and verify correctness, and also
shows some of the similarities/differences in how PyTorch would
implement the same MLP. PyTorch lets you specify the forward pass,
records all the operations performed, and then calls backward()
"under the hood" inside its autograd engine.
"""

In [1]:
import torch
from torch import nn
from torch.nn import functional as F
from torch.nn.parameter import Parameter

In [2]:
# -----------------------------------------------------------------------------

# class that mimics the random interface in Python, fully deterministic,
# and in a way that we also control fully, and can also use in C, etc.
class RNG:
    def __init__(self, seed):
        self.state = seed

    def random_u32(self):
        # xorshift rng: https://en.wikipedia.org/wiki/Xorshift#xorshift.2A
        # doing & 0xFFFFFFFFFFFFFFFF is the same as cast to uint64 in C
        # doing & 0xFFFFFFFF is the same as cast to uint32 in C
        self.state ^= (self.state >> 12) & 0xFFFFFFFFFFFFFFFF
        self.state ^= (self.state << 25) & 0xFFFFFFFFFFFFFFFF
        self.state ^= (self.state >> 27) & 0xFFFFFFFFFFFFFFFF
        return ((self.state * 0x2545F4914F6CDD1D) >> 32) & 0xFFFFFFFF

    def random(self):
        # random float32 in [0, 1)
        return (self.random_u32() >> 8) / 16777216.0

    def uniform(self, a=0.0, b=1.0):
        # random float32 in [a, b)
        return a + (b-a) * self.random()

# generate a random dataset with 100 2-dimensional datapoints in 3 classes
def gen_data(random: RNG, n=100):
    pts = []
    for _ in range(n):
        x = random.uniform(-2.0, 2.0)
        y = random.uniform(-2.0, 2.0)
        # concentric circles
        # label = 0 if x**2 + y**2 < 1 else 1 if x**2 + y**2 < 2 else 2
        # very simple dataset
        label = 0 if x < 0 else 1 if y < 0 else 2
        pts.append(([x, y], label))
    # create train/val/test splits of the data (80%, 10%, 10%)
    tr = pts[:int(0.8*n)]
    val = pts[int(0.8*n):int(0.9*n)]
    te = pts[int(0.9*n):]
    return tr, val, te

In [3]:
random = RNG(42)


In [4]:
# Multi-Layer Perceptron (MLP) network

class Neuron(nn.Module):

    def __init__(self, nin, nonlin=True):
        super().__init__()
        self.w = Parameter(torch.tensor([random.uniform(-1, 1) * nin**-0.5 for _ in range(nin)]))
        self.b = Parameter(torch.zeros(1))
        self.nonlin = nonlin

    def forward(self, x):
        act = torch.sum(self.w * x) + self.b
        return act.tanh() if self.nonlin else act

    def __repr__(self):
        return f"{'TanH' if self.nonlin else 'Linear'}Neuron({len(self.w)})"

In [6]:
class Layer(nn.Module):

    def __init__(self, nin, nout, **kwargs):
        super().__init__()
        self.neurons = nn.ModuleList([Neuron(nin, **kwargs) for _ in range(nout)])

    def forward(self, x):
        out = [n(x) for n in self.neurons]
        return torch.stack(out, dim=-1)

    def __repr__(self):
        return f"Layer of [{', '.join(str(n) for n in self.neurons)}]"

In [7]:

class MLP(nn.Module):

    def __init__(self, nin, nouts):
        super().__init__()
        sz = [nin] + nouts
        self.layers = nn.ModuleList([Layer(sz[i], sz[i+1], nonlin=i!=len(nouts)-1) for i in range(len(nouts))])

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

    def __repr__(self):
        return f"MLP of [{', '.join(str(layer) for layer in self.layers)}]"

In [8]:
# let's train!

train_split, val_split, test_split = gen_data(random, n=100)

In [9]:
# init the model: 2D inputs, 16 neurons, 3 outputs (logits)
model = MLP(2, [16, 3])
model.to(torch.float64) # ensure we're using double precision

MLP of [Layer of [TanHNeuron(2), TanHNeuron(2), TanHNeuron(2), TanHNeuron(2), TanHNeuron(2), TanHNeuron(2), TanHNeuron(2), TanHNeuron(2), TanHNeuron(2), TanHNeuron(2), TanHNeuron(2), TanHNeuron(2), TanHNeuron(2), TanHNeuron(2), TanHNeuron(2), TanHNeuron(2)], Layer of [LinearNeuron(16), LinearNeuron(16), LinearNeuron(16)]]

In [11]:
@torch.no_grad()
def eval_split(model, split):
    model.eval()
    # evaluate the loss of a split
    loss = 0.0
    for x, y in split:
        logits = model(torch.tensor(x))
        y = torch.tensor(y).view(-1)
        loss += F.cross_entropy(logits, y).item()
    loss = loss * (1.0/len(split)) # normalize the loss
    return loss

In [12]:
# optimize using Adam
learning_rate = 1e-1
beta1 = 0.9
beta2 = 0.95
weight_decay = 1e-4
optimizer = torch.optim.AdamW(model.parameters(),
                              lr=learning_rate,
                              betas=(beta1, beta2),
                              weight_decay=weight_decay)


In [13]:
# train
for step in range(100):

    # evaluate the validation split every few steps
    if step % 10 == 0:
        val_loss = eval_split(model, val_split)
        print(f"step {step}, val loss {val_loss}")

    # forward the network (get logits of all training datapoints)
    model.train()
    losses = []
    for x, y in train_split:
        logits = model(torch.tensor(x))
        loss = F.cross_entropy(logits, torch.tensor(y).view(-1))
        losses.append(loss)
    loss = torch.stack(losses).mean()
    # backward pass (deposit the gradients)
    loss.backward()
    # update with AdamW
    optimizer.step()
    model.zero_grad()

    print(f"step {step}, train loss {loss.data}")

step 0, val loss 0.9170899790006554
step 0, train loss 0.9811897723592257
step 1, train loss 0.5148874460244011
step 2, train loss 0.3165791872557939
step 3, train loss 0.22556572887242873
step 4, train loss 0.16796694850720875
step 5, train loss 0.13766378380671834
step 6, train loss 0.12536394889857927
step 7, train loss 0.11410184236874368
step 8, train loss 0.10023924035832203
step 9, train loss 0.08682478525877427
step 10, val loss 0.0039209179136893545
step 10, train loss 0.07694778828258682
step 11, train loss 0.07041304847928265
step 12, train loss 0.06321229877098362
step 13, train loss 0.05415612070051541
step 14, train loss 0.045596503151213816
step 15, train loss 0.03945831399740126
step 16, train loss 0.03563276787500751
step 17, train loss 0.03273043483103781
step 18, train loss 0.029522486797248632
step 19, train loss 0.025962779805612846
step 20, val loss 0.002645744970112932
step 20, train loss 0.022925037924814106
step 21, train loss 0.0206919879692237
step 22, train 