In [1]:
import os
os.environ['DEBUG'] = '1'

import numpy as np
from typing import Union
import math
from time import perf_counter

from tinygrad.tensor import Tensor
from tinygrad.nn.optim import SGD

from lib.utils import get_mnist
from lib.dataloader import SimpleDataLoader

In [2]:
# Load the MNIST train/test splits from ../data. They stay as plain arrays here
# (presumably numpy: the training loop indexes them with numpy index arrays);
# minibatches and the test split are wrapped in tinygrad Tensors only where used.
(X_train, Y_train, X_test, Y_test) = get_mnist("../data")

In [3]:
class TinyConv:
    """Small two-conv-layer MNIST classifier built directly from tinygrad Tensors.

    Architecture:
        conv3x3(1->32) -> ReLU -> maxpool 2x2
        -> conv3x3(32->64) -> ReLU -> maxpool 2x2
        -> flatten (64 * 5 * 5 = 1600 features) -> dropout -> linear(1600 -> num_classes)

    The training code calls ``model.parameters()`` (to build the SGD optimizer) and
    ``model(batch)`` (forward pass). Both are provided here; the previous empty class
    supported neither, so the notebook could not run.

    Args:
        num_classes: number of output logits (10 for MNIST digits).
        dropout: dropout probability applied before the final linear layer.
            tinygrad only applies dropout while ``Tensor.training`` is set
            (i.e. inside ``Tensor.train()``).
    """

    def __init__(self, num_classes: int = 10, dropout: float = 0.5):
        # Conv weights use the (out_ch, in_ch, kh, kw) layout expected by Tensor.conv2d.
        self.conv1_w = Tensor.kaiming_uniform(32, 1, 3, 3)
        self.conv1_b = Tensor.zeros(32)
        self.conv2_w = Tensor.kaiming_uniform(64, 32, 3, 3)
        self.conv2_b = Tensor.zeros(64)
        # The linear weight is stored as (out, in) so kaiming_uniform's fan-in
        # (product of dims after the first) is the 1600 input features. It is
        # transposed in the forward pass.
        self.fc_w = Tensor.kaiming_uniform(num_classes, 64 * 5 * 5)
        self.fc_b = Tensor.zeros(num_classes)
        self.dropout_p = dropout

    def __call__(self, x: "Tensor") -> "Tensor":
        """Return unnormalised class logits of shape (N, num_classes) for a batch x."""
        # The layout/dtype returned by get_mnist is not visible here, so accept flat
        # (N, 784), (N, 28, 28) or (N, 1, 28, 28) input and cast to float.
        # NOTE(review): if the pixels are raw 0-255 values, consider scaling by 1/255.
        x = x.float().reshape(-1, 1, 28, 28)
        x = x.conv2d(self.conv1_w, self.conv1_b).relu().max_pool2d((2, 2))  # -> (N, 32, 13, 13)
        x = x.conv2d(self.conv2_w, self.conv2_b).relu().max_pool2d((2, 2))  # -> (N, 64, 5, 5)
        x = x.flatten(1).dropout(self.dropout_p)                            # -> (N, 1600)
        return x.linear(self.fc_w.transpose(), self.fc_b)

    def parameters(self) -> list:
        """Return every trainable Tensor, in a fixed order, for the optimizer."""
        return [self.conv1_w, self.conv1_b, self.conv2_w, self.conv2_b, self.fc_w, self.fc_b]

In [None]:
# Minibatch iterators over the two splits; only the training split is shuffled.
# NOTE(review): the training loop in this notebook samples batches directly from
# X_train/Y_train with np.random.randint and never iterates these loaders.
# Confirm whether they are still needed.
train_loader = SimpleDataLoader(
    X_train,
    Y_train,
    shuffle=True,
    batch_size=64,
)
test_loader = SimpleDataLoader(
    X_test,
    Y_test,
    shuffle=False,
    batch_size=64,
)

In [None]:
# Model and optimizer: plain SGD over every parameter the model reports.
model = TinyConv()
optim = SGD(model.parameters(), lr=1e-3)

# Training schedule.
BATCH_SIZE = 64  # samples per minibatch
EPOCHS = 10      # train/evaluate cycles
STEPS = 1_000    # upper bound on minibatches per epoch
# Minibatches needed to cover X_train once; the ceiling counts a final partial batch.
max_batches_per_epoch = math.ceil(len(X_train) / BATCH_SIZE)

In [None]:
# Training loop: each epoch runs `steps` SGD updates on randomly sampled minibatches,
# then reports the mean training loss and the accuracy on the full test split.
# NOTE(review): indices are drawn with replacement via np.random.randint, so an
# "epoch" here is `steps` random batches, not one pass over X_train, and
# train_loader/test_loader are not used. No seed is set, so runs are not reproducible.
total_time = 0.0
steps = min(STEPS, max_batches_per_epoch)  # batches per epoch, capped at STEPS

# The test split is identical every epoch, so wrap it in Tensors once, outside the loop.
X_test_t = Tensor(X_test)
Y_test_t = Tensor(Y_test)

for epoch in range(EPOCHS):
    start = perf_counter()
    running_train_loss = 0.0
    with Tensor.train():  # training mode only for the update steps
        for _ in range(steps):
            # BATCH_SIZE random row indices (was a hard-coded 64)
            samp = np.random.randint(0, X_train.shape[0], size=BATCH_SIZE)

            # get batch and labels
            batch = Tensor(X_train[samp], requires_grad=False)
            labels = Tensor(Y_train[samp])

            out = model(batch)  # forward pass
            loss = out.sparse_categorical_crossentropy(labels)  # calculate loss
            optim.zero_grad()  # zero out gradients
            loss.backward()  # backward pass
            optim.step()  # update weights

            running_train_loss += loss.numpy()

    # Mean loss over the batches actually run. Dividing by the STEPS cap instead would
    # under-report the loss whenever steps < STEPS (e.g. 938 vs 1000 for 60k rows at 64).
    train_loss = running_train_loss / steps if steps else float("nan")

    # Test accuracy over the whole test split, computed outside the Tensor.train() context.
    out = model(X_test_t)
    pred = out.argmax(axis=1)  # index of the max logit per row
    accuracy = (pred == Y_test_t).mean().numpy()

    elapsed = perf_counter() - start
    total_time += elapsed

    print(f"Epoch {epoch+1}/{EPOCHS}: {steps} Batches (max: {max_batches_per_epoch}) | Train Loss: {train_loss:.4f} | Test Accuracy: {accuracy:.4f} | Time: {elapsed:.2f}s")

print(f"Total training time: {total_time:.2f}s")