In [3]:
import os
os.environ['DEBUG'] = '1'

import numpy as np
from typing import Union
import math
from time import perf_counter

from tinygrad.tensor import Tensor
from tinygrad.nn.optim import SGD
from tinygrad.nn import Conv2d, BatchNorm2d, Linear
from tinygrad.nn.state import get_parameters

from lib.utils import get_mnist
from lib.dataloader import SimpleDataLoader

In [6]:
X_train, Y_train, X_test, Y_test = get_mnist("../../data") # these need to be tensors??

First, lets build and test with tinygrad's built in methods


In [13]:
class ConvBlock:
    def __init__(self, input_channels, output_channels, kernel_size):
        self.conv_layer = Conv2d(input_channels, output_channels, kernel_size)
        self.batch_norm_layer = BatchNorm2d(output_channels)

    def forward(self, x: Tensor) -> Tensor:
        x = self.conv_layer(x) # (batch_size, 28, 28, 1) -> (batch_size, 26, 26, 32)
        x = self.batch_norm_layer(x).relu()
        return x
    
    def __call__(self, x): return self.forward(x)

    def parameters(self) -> list:
        return get_parameters(self.conv_layer) + get_parameters(self.batch_norm_layer)

class TinyConv:
    def __init__(self):
        self.conv1 = ConvBlock(1, 32, 3)  # (batch_size, 1, 28, 28) -> (batch_size, 32, 26, 26)
        self.conv2 = ConvBlock(32, 64, 3)  # (batch_size, 32, 13, 13) -> (batch_size, 64, 11, 11)
        self.fc1 = Linear(64 * 5 * 5, 128)  # (batch_size, 1600) -> (batch_size, 128)
        self.fc2 = Linear(128, 10)  # (batch_size, 128) -> (batch_size, 10)

    def forward(self, x: Tensor) -> Tensor:
        x = x.reshape(-1, 1, 28, 28)  # (batch_size, 784) -> (batch_size, 1, 28, 28)
        x = self.conv1(x)  # (batch_size, 1, 28, 28) -> (batch_size, 32, 26, 26)
        x = x.max_pool2d(kernel_size=(2,2))  # (batch_size, 32, 26, 26) -> (batch_size, 32, 13, 13)
        x = self.conv2(x)  # (batch_size, 32, 13, 13) -> (batch_size, 64, 11, 11)
        x = x.max_pool2d(kernel_size=(2,2))  # (batch_size, 64, 11, 11) -> (batch_size, 64, 5, 5)
        x = x.reshape(x.shape[0], -1)  # (batch_size, 64, 5, 5) -> (batch_size, 1600)
        x = self.fc1(x).relu()  # (batch_size, 1600) -> (batch_size, 128)
        x = self.fc2(x).softmax()  # (batch_size, 128) -> (batch_size, 10)

        return x
    
    def __call__(self, x): return self.forward(x)

    def parameters(self) -> list:
        return get_parameters(self.conv1) + get_parameters(self.conv2) + get_parameters(self.fc1) + get_parameters(self.fc2)

In [7]:
train_loader = SimpleDataLoader(X_train, Y_train, batch_size=64, shuffle=True)
test_loader = SimpleDataLoader(X_test, Y_test, batch_size=64, shuffle=False)

In [18]:
model = TinyConv()
optim = SGD(model.parameters(), lr=0.001) # instantiate the optimizer

EPOCHS = 20
STEPS = 1000 # num of batches per epoch
BATCH_SIZE = 64
max_batches_per_epoch = math.ceil(len(X_train) / BATCH_SIZE) # handle smaller last batch

In [19]:
total_time = 0.0
steps = min(STEPS, max_batches_per_epoch)
for epoch in range(EPOCHS):
    start = perf_counter()
    running_train_loss = 0.0
    for step in range(steps):
        with Tensor.train():
            samp = np.random.randint(0, X_train.shape[0], size=(64))

            # get batch and labels
            batch = Tensor(X_train[samp], requires_grad=False)
            labels = Tensor(Y_train[samp])

            out = model(batch) # forward pass
            loss = out.sparse_categorical_crossentropy(labels) # calculate loss
            optim.zero_grad() # zero out gradients
            loss.backward() # backward pass
            optim.step() # update weights

            running_train_loss += loss.numpy()

    train_loss = running_train_loss / STEPS # loss over all batches, over num batches

    # test accuracy over the whole dataset
    out = model(Tensor(X_test))
    pred = out.argmax(axis=1) # get the index of the max value
    accuracy = (pred == Tensor(Y_test)).mean().numpy()

    elapsed = perf_counter() - start
    total_time += elapsed

    print(f"Epoch {epoch+1}/{EPOCHS}: {steps} Batches (max: {max_batches_per_epoch}) | Train Loss: {train_loss:.4f} | Test Accuracy: {accuracy:.4f} | Time: {elapsed:.2f}s")

print(f"Total training time: {total_time:.2f}s")

Epoch 1/20: 782 Batches (max: 782) | Train Loss: 1.7830 | Test Accuracy: 0.2257 | Time: 101.31s
Epoch 2/20: 782 Batches (max: 782) | Train Loss: 1.7200 | Test Accuracy: 0.3329 | Time: 98.13s
Epoch 3/20: 782 Batches (max: 782) | Train Loss: 1.6503 | Test Accuracy: 0.4853 | Time: 98.77s
Epoch 4/20: 782 Batches (max: 782) | Train Loss: 1.5845 | Test Accuracy: 0.5984 | Time: 99.17s
Epoch 5/20: 782 Batches (max: 782) | Train Loss: 1.5091 | Test Accuracy: 0.7537 | Time: 98.47s
Epoch 6/20: 782 Batches (max: 782) | Train Loss: 1.4257 | Test Accuracy: 0.8202 | Time: 97.98s
Epoch 7/20: 782 Batches (max: 782) | Train Loss: 1.3680 | Test Accuracy: 0.8363 | Time: 97.17s
Epoch 8/20: 782 Batches (max: 782) | Train Loss: 1.3344 | Test Accuracy: 0.8448 | Time: 100.21s
Epoch 9/20: 782 Batches (max: 782) | Train Loss: 1.3145 | Test Accuracy: 0.8512 | Time: 98.42s
Epoch 10/20: 782 Batches (max: 782) | Train Loss: 1.3014 | Test Accuracy: 0.8564 | Time: 98.44s
Epoch 11/20: 782 Batches (max: 782) | Train Los

### Full Run:

Best Test Accuracy: 0.9614

```
Epoch 1/20: 782 Batches (max: 782) | Train Loss: 1.7830 | Test Accuracy: 0.2257 | Time: 101.31s
Epoch 2/20: 782 Batches (max: 782) | Train Loss: 1.7200 | Test Accuracy: 0.3329 | Time: 98.13s
Epoch 3/20: 782 Batches (max: 782) | Train Loss: 1.6503 | Test Accuracy: 0.4853 | Time: 98.77s
Epoch 4/20: 782 Batches (max: 782) | Train Loss: 1.5845 | Test Accuracy: 0.5984 | Time: 99.17s
Epoch 5/20: 782 Batches (max: 782) | Train Loss: 1.5091 | Test Accuracy: 0.7537 | Time: 98.47s
Epoch 6/20: 782 Batches (max: 782) | Train Loss: 1.4257 | Test Accuracy: 0.8202 | Time: 97.98s
Epoch 7/20: 782 Batches (max: 782) | Train Loss: 1.3680 | Test Accuracy: 0.8363 | Time: 97.17s
Epoch 8/20: 782 Batches (max: 782) | Train Loss: 1.3344 | Test Accuracy: 0.8448 | Time: 100.21s
Epoch 9/20: 782 Batches (max: 782) | Train Loss: 1.3145 | Test Accuracy: 0.8512 | Time: 98.42s
Epoch 10/20: 782 Batches (max: 782) | Train Loss: 1.3014 | Test Accuracy: 0.8564 | Time: 98.44s
Epoch 11/20: 782 Batches (max: 782) | Train Loss: 1.2819 | Test Accuracy: 0.9171 | Time: 99.48s
Epoch 12/20: 782 Batches (max: 782) | Train Loss: 1.2563 | Test Accuracy: 0.9356 | Time: 98.68s
Epoch 13/20: 782 Batches (max: 782) | Train Loss: 1.2420 | Test Accuracy: 0.9433 | Time: 102.76s
Epoch 14/20: 782 Batches (max: 782) | Train Loss: 1.2292 | Test Accuracy: 0.9481 | Time: 102.98s
Epoch 15/20: 782 Batches (max: 782) | Train Loss: 1.2241 | Test Accuracy: 0.9510 | Time: 103.27s
Epoch 16/20: 782 Batches (max: 782) | Train Loss: 1.2160 | Test Accuracy: 0.9535 | Time: 100.80s
Epoch 17/20: 782 Batches (max: 782) | Train Loss: 1.2120 | Test Accuracy: 0.9564 | Time: 101.49s
Epoch 18/20: 782 Batches (max: 782) | Train Loss: 1.2071 | Test Accuracy: 0.9579 | Time: 104.21s
Epoch 19/20: 782 Batches (max: 782) | Train Loss: 1.2037 | Test Accuracy: 0.9605 | Time: 103.43s
Epoch 20/20: 782 Batches (max: 782) | Train Loss: 1.1996 | Test Accuracy: 0.9614 | Time: 100.46s
Total training time: 2005.64s
```