In [1]:
from PIL import Image
import numpy as np
import torch
from torch import nn, optim
from datasets import load_dataset
from helpers import get_device, train, evaluate

In [2]:
torch.manual_seed(0)
np.random.seed(0)
device = get_device()

In [3]:
def transform(x):
    x = x.reshape(-1, 3, 32, 32).astype(np.float32)
    x = ((Image.fromarray(z).resize((224, 224)) for z in y) for y in x)
    x = np.stack([np.stack([np.asarray(z) for z in y], axis=0) for y in x], axis=0)
    return x

In [4]:
dataset = load_dataset("cifar10")

X_train = (np.array([np.array(image) for image in dataset["train"]["img"]]) / 255.0 - 0.5) / 0.25
Y_train = np.array(dataset["train"]["label"], dtype=np.int32)

X_test = (np.array([np.array(image) for image in dataset["test"]["img"]]) / 255.0 - 0.5) / 0.25
Y_test = np.array(dataset["test"]["label"], dtype=np.int32)

Using the latest cached version of the dataset since cifar10 couldn't be found on the Hugging Face Hub
Found the latest cached dataset configuration 'plain_text' at /Users/jacky/.cache/huggingface/datasets/cifar10/plain_text/0.0.0/0b2714987fa478483af9968de7c934580d0bb9a2 (last modified on Thu Jan  9 21:23:26 2025).


In [5]:
class AlexNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, 11, stride=4, padding=2)
        self.max_pool1 = nn.MaxPool2d(3, stride=2)
        self.conv2 = nn.Conv2d(64, 192, kernel_size=5, padding=2)
        self.max_pool2 = nn.MaxPool2d(2, stride=2)
        self.conv3 = nn.Conv2d(192, 384, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(384, 256, kernel_size=3, padding=1)
        self.conv5 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.max_pool3 = nn.MaxPool2d(3, stride=2)
        self.dropout1 = nn.Dropout(0.25)
        self.l1 = nn.Linear(9216, 4096)
        self.dropout2 = nn.Dropout(0.25)
        self.l2 = nn.Linear(4096, 4096)
        self.l3 = nn.Linear(4096, 10)

    def __call__(self, x):
        x = self.conv1(x).relu()
        x = self.max_pool1(x)
        x = self.conv2(x).relu()
        x = self.max_pool2(x)
        x = self.conv3(x).relu()
        x = self.conv4(x).relu()
        x = self.conv5(x).relu()
        x = self.max_pool3(x)
        x = x.view(-1, 9216)
        x = self.l1(self.dropout1(x)).relu()
        x = self.l2(self.dropout2(x)).relu()
        x = self.l3(x)
        return x

In [6]:
model = AlexNet().to(device)
model(torch.ones((1, 3, 224, 224), device=device))

tensor([[-0.0108,  0.0145, -0.0041, -0.0121,  0.0098, -0.0203,  0.0007,  0.0177,
          0.0001, -0.0193]], device='mps:0', grad_fn=<LinearBackward0>)

In [None]:
lr = 0.01
weight_decay = 0.001
momentum = 0.9
epochs = 5
batch_size = 128

In [None]:
optimizer = optim.SGD(model.parameters(), lr=lr, weight_decay=weight_decay, momentum=momentum)
train_steps = len(X_train) // batch_size
test_steps = len(X_test) // batch_size

for epoch in range(epochs):
    train(model, X_train, Y_train, optimizer, train_steps, device=device, transform=transform)

evaluate(model, X_test, Y_test, device=device, transform=transform)

loss 2.09 accuracy 0.22:  71%|███████   | 275/390 [01:32<00:39,  2.92it/s]