In [1]:
import numpy as np
from PIL import Image
from tinygrad.tensor import Tensor
from tinygrad.nn import Conv2d, BatchNorm2d, optim
from datasets import fetch_cifar
from extra.training import train, evaluate

In [2]:
class AlexNet:
    """AlexNet-style convolutional classifier with batch norm after every conv.

    Five ``Conv2d`` + ``BatchNorm2d`` stages (three of them followed by a
    3x3 / stride-2 max pool) feed three fully connected layers. The first
    fully connected layer expects 9216 flattened features (256 * 6 * 6),
    which is what the feature extractor yields for 3-channel 227x227 inputs
    assuming the usual conv/pool output-size arithmetic.

    The fully connected layers are kept as ``{"weight", "bias"}`` dicts so
    they can be splatted straight into ``Tensor.linear``.
    NOTE(review): confirm that ``optim.get_parameters`` descends into dict
    attributes; if it does not, the fc tensors are never handed to the
    optimizer.
    """

    def __init__(self, num_classes=10):
        """Build all layers.

        Args:
            num_classes: Output width of the last linear layer, i.e. how many
                log-probabilities are produced per sample. Defaults to 10,
                the value that used to be hard-coded.

        Raises:
            ValueError: If ``num_classes`` is smaller than 1.
        """
        if num_classes < 1:
            raise ValueError(f"num_classes must be >= 1, got {num_classes}")

        # Convolutional feature extractor; each conv is paired with a batch norm
        # over the same number of channels.
        self.conv1 = Conv2d(3, 96, 11, stride=4)
        self.bn1 = BatchNorm2d(96)
        self.conv2 = Conv2d(96, 256, 5, padding=2)
        self.bn2 = BatchNorm2d(256)
        self.conv3 = Conv2d(256, 384, 3, padding=1)
        self.bn3 = BatchNorm2d(384)
        self.conv4 = Conv2d(384, 384, 3, padding=1)
        self.bn4 = BatchNorm2d(384)
        self.conv5 = Conv2d(384, 256, 3, padding=1)
        self.bn5 = BatchNorm2d(256)

        # Classifier head: 9216 -> 4096 -> 4096 -> num_classes.
        self.fc1 = {"weight": Tensor.scaled_uniform(9216, 4096), "bias": Tensor.zeros(4096)}
        self.fc2 = {"weight": Tensor.scaled_uniform(4096, 4096), "bias": Tensor.zeros(4096)}
        self.fc3 = {"weight": Tensor.scaled_uniform(4096, num_classes), "bias": Tensor.zeros(num_classes)}

    def __call__(self, x):
        """Run the forward pass and return per-class log-probabilities.

        Args:
            x: Input tensor; its first dimension is preserved by the flatten
                step and treated as the batch dimension.

        Returns:
            The ``log_softmax`` of the final linear layer's output.
        """
        # Stages 1, 2 and 5 downsample with a 3x3 max pool of stride 2.
        x = self.bn1(self.conv1(x)).relu().max_pool2d((3, 3), stride=2)
        x = self.bn2(self.conv2(x)).relu().max_pool2d((3, 3), stride=2)
        x = self.bn3(self.conv3(x)).relu()
        x = self.bn4(self.conv4(x)).relu()
        x = self.bn5(self.conv5(x)).relu().max_pool2d((3, 3), stride=2)
        # Flatten everything except the batch dimension for the linear layers.
        x = x.reshape(x.shape[0], -1)
        # Dropout is applied in front of the first two linear layers only.
        x = x.dropout(0.5).linear(**self.fc1).relu()
        x = x.dropout(0.5).linear(**self.fc2).relu()
        x = x.linear(**self.fc3).log_softmax()
        return x

In [3]:
def transform(x):
    """Upscale every 2-D plane of every sample to 227x227 and batch the result.

    Each sample in ``x`` is iterated plane by plane; every plane is wrapped in
    a PIL image, resized to 227x227 and converted back to an array. Planes are
    stacked per sample, samples are stacked into one array, and the result is
    reshaped to ``(-1, 3, 227, 227)``.

    Args:
        x: Iterable of samples, each an iterable of 2-D arrays accepted by
            ``Image.fromarray`` (presumably a ``(N, 3, H, W)`` image batch --
            confirm against the dataset loader).

    Returns:
        A numpy array of shape ``(-1, 3, 227, 227)``.
    """
    target = (227, 227)
    batch = np.stack(
        [
            np.stack(
                [np.asarray(Image.fromarray(plane).resize(target)) for plane in sample],
                axis=0,
            )
            for sample in x
        ],
        axis=0,
    )
    return batch.reshape(-1, 3, 227, 227)

In [4]:
# Hyperparameters of the run, gathered in one place.
NUM_EPOCHS = 10
STEPS_PER_EPOCH = 1000  # forwarded positionally to train(); the recorded progress bar counts to 1000
BATCH_SIZE = 64
INITIAL_LR = 0.005
LR_DECAY = 0.6  # multiplicative learning-rate decay applied after every epoch
MOMENTUM = 0.9

# Load the training split first, then the held-out split.
X_train, Y_train = fetch_cifar()
X_test, Y_test = fetch_cifar(train=False)

model = AlexNet()
learning_rate = INITIAL_LR
for epoch in range(NUM_EPOCHS):
    # The optimizer is rebuilt each epoch so it is constructed with the decayed
    # learning rate. NOTE(review): a new SGD instance presumably also starts
    # with fresh momentum state -- confirm that this is intended.
    optimizer = optim.SGD(optim.get_parameters(model), lr=learning_rate, momentum=MOMENTUM)
    train(model, X_train, Y_train, optimizer, STEPS_PER_EPOCH, BS=BATCH_SIZE, transform=transform)
    evaluate(model, X_test, Y_test, transform=transform)
    learning_rate = learning_rate * LR_DECAY

loss 1.31 accuracy 0.50: 100%|█████████████████████████████████████████| 1000/1000 [47:46<00:00,  2.87s/it]
100%|██████████████████████████████████████████████████████████████████████| 79/79 [01:14<00:00,  1.07it/s]


test set accuracy is 0.508400


loss 0.86 accuracy 0.69: 100%|█████████████████████████████████████████| 1000/1000 [48:33<00:00,  2.91s/it]
100%|██████████████████████████████████████████████████████████████████████| 79/79 [01:14<00:00,  1.07it/s]


test set accuracy is 0.646100


loss 0.57 accuracy 0.78: 100%|█████████████████████████████████████████| 1000/1000 [48:39<00:00,  2.92s/it]
100%|██████████████████████████████████████████████████████████████████████| 79/79 [01:15<00:00,  1.04it/s]


test set accuracy is 0.745600


loss 0.42 accuracy 0.88: 100%|█████████████████████████████████████████| 1000/1000 [48:27<00:00,  2.91s/it]
100%|██████████████████████████████████████████████████████████████████████| 79/79 [01:14<00:00,  1.06it/s]


test set accuracy is 0.762300


loss 0.49 accuracy 0.83: 100%|█████████████████████████████████████████| 1000/1000 [48:11<00:00,  2.89s/it]
100%|██████████████████████████████████████████████████████████████████████| 79/79 [01:15<00:00,  1.04it/s]


test set accuracy is 0.777500


loss 0.90 accuracy 0.73: 100%|█████████████████████████████████████████| 1000/1000 [48:42<00:00,  2.92s/it]
100%|██████████████████████████████████████████████████████████████████████| 79/79 [01:15<00:00,  1.04it/s]


test set accuracy is 0.800000


loss 0.44 accuracy 0.86: 100%|█████████████████████████████████████████| 1000/1000 [48:43<00:00,  2.92s/it]
100%|██████████████████████████████████████████████████████████████████████| 79/79 [01:15<00:00,  1.04it/s]


test set accuracy is 0.812800


loss 0.30 accuracy 0.89: 100%|█████████████████████████████████████████| 1000/1000 [48:42<00:00,  2.92s/it]
100%|██████████████████████████████████████████████████████████████████████| 79/79 [01:15<00:00,  1.04it/s]


test set accuracy is 0.818000


loss 0.30 accuracy 0.88: 100%|█████████████████████████████████████████| 1000/1000 [48:04<00:00,  2.88s/it]
100%|██████████████████████████████████████████████████████████████████████| 79/79 [01:15<00:00,  1.04it/s]


test set accuracy is 0.817700


loss 0.37 accuracy 0.89: 100%|█████████████████████████████████████████| 1000/1000 [48:28<00:00,  2.91s/it]
100%|██████████████████████████████████████████████████████████████████████| 79/79 [01:15<00:00,  1.04it/s]

test set accuracy is 0.821900



