In [12]:
from tinygrad.nn import Linear
from tinygrad.tensor import Tensor
from tinygrad.nn.optim import SGD, get_parameters
from datasets import fetch_mnist
import numpy as np

In [10]:
class TinyNet:
    """Bias-free two-layer MLP: 784 input features -> 128 hidden -> 10 classes.

    The 784 inputs are presumably flattened 28x28 MNIST images (see fetch_mnist
    usage in the training cell). Calling the model returns log-probabilities,
    i.e. the log_softmax of the final layer's outputs.
    """

    def __init__(self):
        self.l1 = Linear(784, 128, bias=False)  # input features -> hidden units
        self.l2 = Linear(128, 10, bias=False)   # hidden units -> per-class scores

    def __call__(self, x):
        """Run the forward pass and return log_softmax of the class scores."""
        hidden = self.l1(x).leakyrelu()
        scores = self.l2(hidden)
        return scores.log_softmax()


net = TinyNet()

In [14]:
# tinygrad's global training-mode flag; switched on before the training loop below.
Tensor.training = True

def cross_entropy(out, Y):
    """Mean negative log-likelihood of integer labels ``Y`` under ``out``.

    ``out`` is expected to hold log-probabilities with the class axis last
    (TinyNet ends in log_softmax). ``Y`` is a NumPy array of class labels whose
    shape is ``out.shape[:-1]``; labels are cast to int32 before indexing.

    The target has ``-num_classes`` at each true-class position and 0 elsewhere.
    Averaging ``out * target`` over all batch*num_classes entries therefore
    yields ``-mean(log p[label])`` over the batch, i.e. the usual NLL loss.
    """
    n_cls = out.shape[-1]
    flat_labels = Y.flatten().astype(np.int32)

    target = np.zeros((flat_labels.shape[0], n_cls), np.float32)
    target[np.arange(target.shape[0]), flat_labels] = -1.0 * n_cls

    # restore the original label layout with the class axis appended, to match `out`
    target = Tensor(target.reshape(list(Y.shape) + [n_cls]))
    return out.mul(target).mean()

# Training hyper-parameters (same values as before, now named instead of inlined).
LEARNING_RATE = 3e-4
BATCH_SIZE = 64
NUM_STEPS = 1000
LOG_EVERY = 100
SEED = 1337  # seeds batch sampling only; weight init happens where `net` is built

opt = SGD(get_parameters(net), lr=LEARNING_RATE)

X_train, Y_train, X_test, Y_test = fetch_mnist()

# NOTE: `net` is created in an earlier cell and is NOT re-created here, so
# re-running this cell keeps training the current weights instead of starting
# from scratch. Re-run the TinyNet cell first for a fresh run.

# Dedicated seeded generator: the sequence of sampled batches is reproducible
# and NumPy's global random state is left untouched.
rng = np.random.default_rng(SEED)

for step in range(NUM_STEPS):
    # random sample a batch (indices drawn with replacement from [0, len(X_train)))
    samp = rng.integers(0, X_train.shape[0], size=BATCH_SIZE)
    batch = Tensor(X_train[samp], requires_grad=False)
    # get the corresponding labels
    labels = Y_train[samp]

    # forward pass
    out = net(batch)

    # compute loss
    loss = cross_entropy(out, labels)

    # zero gradients, backward pass, update parameters
    opt.zero_grad()
    loss.backward()
    opt.step()

    # accuracy on this *training* batch, from the forward pass before the update
    pred = np.argmax(out.numpy(), axis=-1)
    acc = (pred == labels).mean()

    if step % LOG_EVERY == 0:
        print(f"Step {step+1} | Loss: {loss.numpy()} | Accuracy: {acc}")


Step 1 | Loss: 0.41080746054649353 | Accuracy: 0.9375
Step 101 | Loss: 1.0151039361953735 | Accuracy: 0.8125
Step 201 | Loss: 0.3535616099834442 | Accuracy: 0.953125
Step 301 | Loss: 0.3401767611503601 | Accuracy: 0.9375
Step 401 | Loss: 0.24193640053272247 | Accuracy: 0.921875
Step 501 | Loss: 0.22169612348079681 | Accuracy: 0.90625
Step 601 | Loss: 0.2963210940361023 | Accuracy: 0.921875
Step 701 | Loss: 0.3528132438659668 | Accuracy: 0.921875
Step 801 | Loss: 0.1925179362297058 | Accuracy: 0.953125
Step 901 | Loss: 0.35431843996047974 | Accuracy: 0.9375
