## Device configuration

In [1]:
import torch

# Prefer a GPU backend: Apple-silicon Metal (MPS) first, then NVIDIA CUDA,
# otherwise fall back to the CPU. The cells below move the model and data with
# `.to(device)`, so they work unchanged whichever backend is chosen.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

mps


## Model

In [2]:
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    """Small CNN classifier: two conv -> ReLU -> 2x2 max-pool stages and a linear head.

    The attribute names ``l1``/``l2``/``l3`` are the keys used when the
    tinygrad weights are copied in (``'l1.weight'``, ``'l1.bias'``, ...), so
    they must not be renamed.

    Expects a float batch of shape (N, 1, 28, 28) and returns (N, 10) raw
    logits (no softmax applied).
    """

    def __init__(self):
        super().__init__()
        # Layers are created in the order l1, l2, l3 so the default random
        # initialisation draws from the RNG in the same sequence.
        self.l1 = nn.Conv2d(1, 32, kernel_size=3)   # 28x28 -> 26x26, pooled to 13x13
        self.l2 = nn.Conv2d(32, 64, kernel_size=3)  # 13x13 -> 11x11, pooled to 5x5
        # 64 channels * 5 * 5 spatial positions = 1600 flattened features.
        self.l3 = nn.Linear(64 * 5 * 5, 10)

    def forward(self, x):
        hidden = F.relu(self.l1(x))
        hidden = F.max_pool2d(hidden, kernel_size=2)
        hidden = F.relu(self.l2(hidden))
        hidden = F.max_pool2d(hidden, kernel_size=2)
        features = hidden.flatten(start_dim=1)
        # Dropout only takes effect while the module is in training mode.
        features = F.dropout(features, p=0.5, training=self.training)
        return self.l3(features)

## Get the MNIST dataset

In [3]:
from tinygrad.nn.datasets import mnist

X_train, Y_train, X_test, Y_test = mnist()


def _images_to_torch(images):
    """tinygrad image Tensor -> float32 torch tensor laid out as (N, 1, 28, 28)."""
    return torch.from_numpy(images.numpy()).float().reshape(-1, 1, 28, 28)


def _labels_to_torch(labels):
    """tinygrad label Tensor -> int64 torch tensor (same dtype argmax indices use)."""
    return torch.from_numpy(labels.numpy()).long()


# Convert tinygrad Tensors to PyTorch tensors (via NumPy).
X_train, X_test = _images_to_torch(X_train), _images_to_torch(X_test)
Y_train, Y_test = _labels_to_torch(Y_train), _labels_to_torch(Y_test)

print(X_train.shape, X_train.dtype, Y_train.shape, Y_train.dtype)

https://storage.googleapis.com/cvdf-datasets/mnist/train-images-idx3-ubyte.gz: 4
https://storage.googleapis.com/cvdf-datasets/mnist/train-labels-idx1-ubyte.gz: 6
https://storage.googleapis.com/cvdf-datasets/mnist/t10k-images-idx3-ubyte.gz: 7.
https://storage.googleapis.com/cvdf-datasets/mnist/t10k-labels-idx1-ubyte.gz: 10


torch.Size([60000, 1, 28, 28]) torch.float32 torch.Size([60000]) torch.int64


## Baseline: evaluate the untrained model

In [4]:
# Build a fresh model and put it, plus every dataset split, on the selected device.
model = Model().to(device)
X_train, Y_train, X_test, Y_test = (
    split.to(device) for split in (X_train, Y_train, X_test, Y_test)
)

In [5]:
# Test-set accuracy of the randomly initialised model; with 10 classes this
# should sit near chance. A new nn.Module starts in training mode, so eval()
# is required to switch dropout off (otherwise the score is randomly perturbed),
# and no_grad() avoids building an autograd graph over the whole test set.
model.eval()
with torch.no_grad():
    acc = (model(X_test).argmax(dim=1) == Y_test).float().mean()
print(acc.item())

0.12239999324083328


### Load the tinygrad-trained weights

In [6]:
import numpy as np

# NOTE(security): allow_pickle=True unpickles the file, which can execute
# arbitrary code - only load weight files from a trusted source.
loaded_weights = np.load('tinygrad_weights.npy', allow_pickle=True).item()

# The file holds a dict keyed like the PyTorch parameters ('l1.weight',
# 'l1.bias', ...); presumably exported from the tinygrad model - confirm.
# Copy the arrays in with load_state_dict rather than assigning `.data`. This
# raises on any shape mismatch, casts to each parameter's dtype, and copies
# in place so every parameter stays on `device`. A missing key raises KeyError.
weights_state = {
    name: torch.as_tensor(loaded_weights[name])
    for name in model.state_dict()
}
model.load_state_dict(weights_state)

## Predicted class probabilities for one test image

In [7]:
# Class probabilities for the first test image under the loaded weights.
model.eval()  # inference mode for dropout
test_image = X_test[:1]  # slice (not index) so the batch dimension is kept
with torch.no_grad():
    logits = model(test_image)
    pytorch_probs = logits.softmax(dim=1).cpu().numpy()
print("PyTorch probabilities:", pytorch_probs)

PyTorch probabilities: [[2.6096042e-13 1.2902114e-15 3.1515359e-07 2.1935040e-09 4.6694077e-19
  1.0634544e-16 2.9910848e-26 9.9999964e-01 2.9652318e-12 2.1830036e-12]]
