A rough copy of https://blog.paperspace.com/alexnet-pytorch/

In [9]:
import random
from PIL import Image
import numpy as np
import torch
from torch import nn, optim
from datasets import load_dataset
from helpers import get_device, train, evaluate

In [10]:
torch.manual_seed(0)
np.random.seed(0)
random.seed(0)
device = get_device()

In [11]:
def preprocess(x, eps=1e-7):
    mean = np.array([0.4914, 0.4822, 0.4465]).reshape(1, 1, 1, -1)
    std = np.array([0.2023, 0.1994, 0.2010]).reshape(1, 1, 1, -1)
    x = x.astype(np.float32)
    x = (x - mean) / (std + eps)
    return x


def random_crop(image, crop_size=30):
    h, w, _ = image.shape
    x = random.randint(0, w - crop_size)
    y = random.randint(0, h - crop_size)
    return image[x:x+crop_size, y:y+crop_size, :]


def random_flip(image, flip_prob=0.5):
    if random.random() < flip_prob:
        image = np.fliplr(image)
    return image


def transform(x):
    x = np.array([random_crop(y) for y in x]).transpose(0, 3, 1, 2)
    x = ((Image.fromarray(z).resize((227, 227)) for z in y) for y in x)
    x = np.stack([random_flip(np.stack([np.asarray(z) for z in y], axis=0)) for y in x], axis=0)
    return x


def target_transform(x):
    x = x.transpose(0, 3, 1, 2)
    x = ((Image.fromarray(z).resize((227, 227)) for z in y) for y in x)
    x = np.stack([np.stack([np.asarray(z) for z in y], axis=0) for y in x], axis=0)
    return x

In [12]:
dataset = load_dataset("cifar10")

X_train = preprocess(np.array([np.array(image) for image in dataset["train"]["img"]]))
Y_train = np.array(dataset["train"]["label"], dtype=np.int32)

X_test = preprocess(np.array([np.array(image) for image in dataset["test"]["img"]]))
Y_test = np.array(dataset["test"]["label"], dtype=np.int32)

In [13]:
class AlexNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(3, 96, 11, stride=4, padding=0),
            nn.BatchNorm2d(96),
            nn.ReLU(),
            nn.MaxPool2d(3, 2),
        )
        self.layer2 = nn.Sequential(
            nn.Conv2d(96, 256, 5, stride=1, padding=2),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.MaxPool2d(3, 2),
        )
        self.layer3 = nn.Sequential(
            nn.Conv2d(256, 384, 3, stride=1, padding=1),
            nn.BatchNorm2d(384),
            nn.ReLU(),
        )
        self.layer4 = nn.Sequential(
            nn.Conv2d(384, 384, 3, stride=1, padding=1),
            nn.BatchNorm2d(384),
            nn.ReLU(),
        )
        self.layer5 = nn.Sequential(
            nn.Conv2d(384, 256, 3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.MaxPool2d(3, 2),
        )
        self.linear1 = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(9216, 4096),
            nn.ReLU(),
        )
        self.linear2 = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(4096, 4096),
            nn.ReLU(),
        )
        self.linear3 = nn.Linear(4096, 10)

    def __call__(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.layer5(x)
        x = x.view(x.size(0), -1)
        x = self.linear1(x)
        x = self.linear2(x)
        x = self.linear3(x)
        return x

In [14]:
model = AlexNet().to(device)
with torch.no_grad():
    model(torch.ones((1, 3, 227, 227), device=device))

In [15]:
epochs = 10
lr = 0.005
weight_decay = lr / epochs
momentum = 0.9
batch_size = 128

In [16]:
optimizer = optim.SGD(model.parameters(), lr=lr, weight_decay=weight_decay, momentum=momentum)
train_steps = len(X_train) // batch_size

for epoch in range(epochs):
    train(model, X_train, Y_train, optimizer, train_steps, device=device, transform=transform)
    evaluate(model, X_test, Y_test, device=device, target_transform=target_transform)

loss 1.14 accuracy 0.59: 100%|██████████| 390/390 [03:17<00:00,  1.98it/s]
100%|██████████| 79/79 [00:14<00:00,  5.42it/s]


test set accuracy is 0.4569


loss 1.07 accuracy 0.66: 100%|██████████| 390/390 [03:13<00:00,  2.02it/s]
100%|██████████| 79/79 [00:14<00:00,  5.43it/s]


test set accuracy is 0.4008


loss 1.10 accuracy 0.56: 100%|██████████| 390/390 [03:07<00:00,  2.08it/s]
100%|██████████| 79/79 [00:14<00:00,  5.46it/s]


test set accuracy is 0.6468


loss 0.66 accuracy 0.76: 100%|██████████| 390/390 [03:08<00:00,  2.07it/s]
100%|██████████| 79/79 [00:14<00:00,  5.40it/s]


test set accuracy is 0.6306


loss 0.71 accuracy 0.76: 100%|██████████| 390/390 [03:10<00:00,  2.05it/s]
100%|██████████| 79/79 [00:14<00:00,  5.42it/s]


test set accuracy is 0.6936


loss 0.76 accuracy 0.78: 100%|██████████| 390/390 [03:06<00:00,  2.09it/s]
100%|██████████| 79/79 [00:14<00:00,  5.55it/s]


test set accuracy is 0.6189


loss 0.70 accuracy 0.74: 100%|██████████| 390/390 [03:09<00:00,  2.06it/s]
100%|██████████| 79/79 [00:14<00:00,  5.49it/s]


test set accuracy is 0.6908


loss 0.71 accuracy 0.77: 100%|██████████| 390/390 [03:11<00:00,  2.04it/s]
100%|██████████| 79/79 [00:14<00:00,  5.33it/s]


test set accuracy is 0.7148


loss 0.54 accuracy 0.83: 100%|██████████| 390/390 [03:16<00:00,  1.98it/s]
100%|██████████| 79/79 [00:14<00:00,  5.33it/s]


test set accuracy is 0.7248


loss 0.58 accuracy 0.81: 100%|██████████| 390/390 [03:10<00:00,  2.05it/s]
100%|██████████| 79/79 [00:14<00:00,  5.45it/s]

test set accuracy is 0.7514



