A rough copy of https://blog.paperspace.com/alexnet-pytorch/

In [1]:
import numpy as np
from PIL import Image
import torch
from torch import nn
from datasets import fetch_cifar
from helpers import train, evaluate
torch.manual_seed(1337)

<torch._C.Generator at 0x116e7c650>

In [2]:
class AlexNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(3, 96, 11, stride=4, padding=0),
            nn.BatchNorm2d(96),
            nn.ReLU(),
            nn.MaxPool2d(3, stride=2),
        )
        self.layer2 = nn.Sequential(
            nn.Conv2d(96, 256, 5, stride=1, padding=2),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.MaxPool2d(3, stride=2),
        )
        self.layer3 = nn.Sequential(
            nn.Conv2d(256, 384, 3, stride=1, padding=1),
            nn.BatchNorm2d(384),
            nn.ReLU(),
        )
        self.layer4 = nn.Sequential(
            nn.Conv2d(384, 384, 3, stride=1, padding=1),
            nn.BatchNorm2d(384),
            nn.ReLU(),
        )
        self.layer5 = nn.Sequential(
            nn.Conv2d(384, 256, 3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.MaxPool2d(3, stride=2),
        )
        self.fc1 = nn.Sequential(nn.Dropout(0.5), nn.Linear(1024, 256), nn.ReLU())
        self.fc2 = nn.Sequential(nn.Dropout(0.5), nn.Linear(256, 64), nn.ReLU())
        self.fc3 = nn.Linear(64, 10)
    
    def __call__(self, x):
        out = self.layer1(x)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = self.layer5(out)
        out = out.reshape(out.shape[0], -1)
        out = self.fc1(out)
        out = self.fc2(out)
        out = self.fc3(out)
        return out

In [3]:
def transform(x):
    x = [[Image.fromarray(z).resize((128, 128)) for z in y] for y in x]
    x = np.stack([np.stack([np.asarray(z) for z in y], axis=0) for y in x], axis=0)
    x = x.reshape(-1, 3, 128, 128)
    return x

In [4]:
(X_train, Y_train), (X_test, Y_test) = fetch_cifar(), fetch_cifar(train=False)
model = AlexNet()
learning_rate = 0.001

for _ in range(5):
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    train(model, X_train, Y_train, optimizer, 1000, BS=128, transform=transform)
    evaluate(model, X_test, Y_test, transform=transform)
    learning_rate /= 2

loss 0.95 accuracy 0.66: 100%|██████████| 1000/1000 [15:48<00:00,  1.05it/s]
100%|██████████| 79/79 [00:36<00:00,  2.14it/s]


test set accuracy is 0.62


loss 0.69 accuracy 0.75: 100%|██████████| 1000/1000 [15:44<00:00,  1.06it/s]
100%|██████████| 79/79 [00:33<00:00,  2.37it/s]


test set accuracy is 0.6512


loss 0.34 accuracy 0.90: 100%|██████████| 1000/1000 [15:07<00:00,  1.10it/s]
100%|██████████| 79/79 [00:32<00:00,  2.40it/s]


test set accuracy is 0.8141


loss 0.16 accuracy 0.95: 100%|██████████| 1000/1000 [15:06<00:00,  1.10it/s]
100%|██████████| 79/79 [00:33<00:00,  2.36it/s]


test set accuracy is 0.8275


loss 0.11 accuracy 0.97: 100%|██████████| 1000/1000 [15:11<00:00,  1.10it/s]
100%|██████████| 79/79 [00:35<00:00,  2.22it/s]

test set accuracy is 0.8383



