A rough copy of https://blog.paperspace.com/alexnet-pytorch/

In [1]:
import random
from PIL import Image
import numpy as np
import torch
from torch import nn, optim
from datasets import load_dataset
from helpers import get_device, train, evaluate

In [2]:
torch.manual_seed(0)
np.random.seed(0)
random.seed(0)
device = get_device()

In [3]:
def preprocess(x):
    x = x.reshape(-1, 3, 32, 32).astype(np.float32)
    std = x.std(axis=(0, 2, 3))
    mean = x.mean(axis=(0, 2, 3))
    x = ((x - mean.reshape(1, -1, 1, 1)) / std.reshape(1, -1, 1, 1))
    return x



def random_crop(image, crop_width=31, crop_height=31):
    _, height, width = image.shape
    x = random.randint(0, width - crop_width)
    y = random.randint(0, height - crop_height)
    image = image[:, x:x+crop_width, y:y+crop_height]
    return image


def random_flip(image, flip_prob=0.25):
    if random.random() < flip_prob:
        image = np.flip(image, axis=1)
    return image


def transform(x):
    x = np.stack([random_crop(y) for y in x], axis=0)
    x = ((Image.fromarray(z).resize((227, 227)) for z in y) for y in x)
    x = np.stack([random_flip(np.stack([np.asarray(z) for z in y], axis=0)) for y in x], axis=0)
    return x


def target_transform(x):
    x = ((Image.fromarray(z).resize((227, 227)) for z in y) for y in x)
    x = np.stack([np.stack([np.asarray(z) for z in y], axis=0) for y in x], axis=0)
    return x

In [4]:
dataset = load_dataset("cifar10")

X_train = preprocess(np.array([np.array(image) for image in dataset["train"]["img"]]))
Y_train = np.array(dataset["train"]["label"], dtype=np.int32)

X_test = preprocess(np.array([np.array(image) for image in dataset["test"]["img"]]))
Y_test = np.array(dataset["test"]["label"], dtype=np.int32)

In [5]:
class AlexNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(3, 96, 11, stride=4, padding=0),
            nn.BatchNorm2d(96),
            nn.ReLU(),
            nn.MaxPool2d(3, 2),
        )
        self.layer2 = nn.Sequential(
            nn.Conv2d(96, 256, 5, stride=1, padding=2),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.MaxPool2d(3, 2),
        )
        self.layer3 = nn.Sequential(
            nn.Conv2d(256, 384, 3, stride=1, padding=1),
            nn.BatchNorm2d(384),
            nn.ReLU(),
        )
        self.layer4 = nn.Sequential(
            nn.Conv2d(384, 384, 3, stride=1, padding=1),
            nn.BatchNorm2d(384),
            nn.ReLU(),
        )
        self.layer5 = nn.Sequential(
            nn.Conv2d(384, 256, 3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.MaxPool2d(3, 2),
        )
        self.linear1 = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(9216, 4096),
            nn.ReLU(),
        )
        self.linear2 = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(4096, 4096),
            nn.ReLU(),
        )
        self.linear3 = nn.Linear(4096, 10)

    def __call__(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.layer5(x)
        x = x.view(x.size(0), -1)
        x = self.linear1(x)
        x = self.linear2(x)
        x = self.linear3(x)
        return x

In [6]:
model = AlexNet().to(device)
model(torch.ones((1, 3, 227, 227), device=device))

tensor([[ 0.0546,  0.2139,  0.4053, -0.1958,  0.0628,  0.2879, -0.3752, -0.0284,
         -0.1380,  0.3627]], device='mps:0', grad_fn=<LinearBackward0>)

In [7]:
lr = 0.005
weight_decay = 0.005
momentum = 0.9
epochs = 20
batch_size = 64

In [8]:
optimizer = optim.SGD(model.parameters(), lr=lr, weight_decay=weight_decay, momentum=momentum)
train_steps = len(X_train) // batch_size
test_steps = len(X_test) // batch_size

for epoch in range(epochs):
    train(model, X_train, Y_train, optimizer, train_steps, device=device, transform=transform)
    evaluate(model, X_test, Y_test, device=device, target_transform=target_transform)

loss 1.43 accuracy 0.49: 100%|██████████| 781/781 [06:47<00:00,  1.91it/s]
100%|██████████| 79/79 [00:16<00:00,  4.90it/s]


test set accuracy is 0.3539


loss 1.24 accuracy 0.52: 100%|██████████| 781/781 [07:14<00:00,  1.80it/s]
100%|██████████| 79/79 [00:15<00:00,  4.97it/s]


test set accuracy is 0.4514


loss 1.32 accuracy 0.52: 100%|██████████| 781/781 [07:15<00:00,  1.80it/s]
100%|██████████| 79/79 [00:15<00:00,  4.95it/s]


test set accuracy is 0.4437


loss 1.28 accuracy 0.54: 100%|██████████| 781/781 [07:14<00:00,  1.80it/s]
100%|██████████| 79/79 [00:15<00:00,  4.94it/s]


test set accuracy is 0.4737


loss 1.01 accuracy 0.61: 100%|██████████| 781/781 [07:14<00:00,  1.80it/s]
100%|██████████| 79/79 [00:15<00:00,  4.98it/s]


test set accuracy is 0.4824


loss 1.10 accuracy 0.59: 100%|██████████| 781/781 [07:14<00:00,  1.80it/s]
100%|██████████| 79/79 [00:15<00:00,  4.96it/s]


test set accuracy is 0.5411


loss 1.00 accuracy 0.61: 100%|██████████| 781/781 [07:14<00:00,  1.80it/s]
100%|██████████| 79/79 [00:15<00:00,  4.99it/s]


test set accuracy is 0.4854


loss 0.92 accuracy 0.64: 100%|██████████| 781/781 [07:14<00:00,  1.80it/s]
100%|██████████| 79/79 [00:15<00:00,  4.95it/s]


test set accuracy is 0.5142


loss 0.98 accuracy 0.68: 100%|██████████| 781/781 [07:13<00:00,  1.80it/s]
100%|██████████| 79/79 [00:15<00:00,  4.96it/s]


test set accuracy is 0.3851


loss 0.90 accuracy 0.72: 100%|██████████| 781/781 [07:14<00:00,  1.80it/s]
100%|██████████| 79/79 [00:15<00:00,  4.96it/s]


test set accuracy is 0.5457


loss 0.99 accuracy 0.63: 100%|██████████| 781/781 [07:14<00:00,  1.80it/s]
100%|██████████| 79/79 [00:15<00:00,  4.96it/s]


test set accuracy is 0.4536


loss 1.08 accuracy 0.59: 100%|██████████| 781/781 [07:14<00:00,  1.80it/s]
100%|██████████| 79/79 [00:15<00:00,  4.97it/s]


test set accuracy is 0.5136


loss 1.02 accuracy 0.65: 100%|██████████| 781/781 [07:14<00:00,  1.80it/s]
100%|██████████| 79/79 [00:15<00:00,  4.97it/s]


test set accuracy is 0.523


loss 0.88 accuracy 0.71: 100%|██████████| 781/781 [07:14<00:00,  1.80it/s]
100%|██████████| 79/79 [00:15<00:00,  4.96it/s]


test set accuracy is 0.5188


loss 0.76 accuracy 0.73: 100%|██████████| 781/781 [07:13<00:00,  1.80it/s]
100%|██████████| 79/79 [00:15<00:00,  4.96it/s]


test set accuracy is 0.5301


loss 1.02 accuracy 0.61: 100%|██████████| 781/781 [07:14<00:00,  1.80it/s]
100%|██████████| 79/79 [00:15<00:00,  4.96it/s]


test set accuracy is 0.5406


loss 0.80 accuracy 0.67: 100%|██████████| 781/781 [07:14<00:00,  1.80it/s]
100%|██████████| 79/79 [00:15<00:00,  4.95it/s]


test set accuracy is 0.4693


loss 0.87 accuracy 0.70: 100%|██████████| 781/781 [07:14<00:00,  1.80it/s]
100%|██████████| 79/79 [00:16<00:00,  4.93it/s]


test set accuracy is 0.4909


loss 0.88 accuracy 0.67: 100%|██████████| 781/781 [07:14<00:00,  1.80it/s]
100%|██████████| 79/79 [00:15<00:00,  4.97it/s]


test set accuracy is 0.5542


loss 0.87 accuracy 0.62: 100%|██████████| 781/781 [07:14<00:00,  1.80it/s]
100%|██████████| 79/79 [00:15<00:00,  4.96it/s]

test set accuracy is 0.56



