A rough copy of https://blog.paperspace.com/alexnet-pytorch/

In [1]:
import numpy as np
from PIL import Image
import torch
from torch import nn
import torch.nn.functional as F
from datasets import fetch_cifar
from helpers import get_gpu, train, evaluate
torch.manual_seed(1337)
device = get_gpu()

In [2]:
class AlexNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 96, 11, stride=4)
        self.bn1 = nn.BatchNorm2d(96)
        self.conv2 = nn.Conv2d(96, 256, 5, padding=2)
        self.bn2 = nn.BatchNorm2d(256)
        self.conv3 = nn.Conv2d(256, 384, 3, padding=1)
        self.bn3 = nn.BatchNorm2d(384)
        self.conv4 = nn.Conv2d(384, 384, 3, padding=1)
        self.bn4 = nn.BatchNorm2d(384)
        self.conv5 = nn.Conv2d(384, 256, 3, padding=1)
        self.bn5 = nn.BatchNorm2d(256)
        self.fc1 = nn.Linear(9216, 4096)
        self.fc2 = nn.Linear(4096, 4096)
        self.fc3 = nn.Linear(4096, 10)
    
    def __call__(self, x):
        x = F.max_pool2d(self.bn1(self.conv1(x)).relu(), 3, stride=2)
        x = F.max_pool2d(self.bn2(self.conv2(x)).relu(), 3, stride=2)
        x = self.bn3(self.conv3(x)).relu()
        x = self.bn4(self.conv4(x)).relu()
        x = F.max_pool2d(self.bn5(self.conv5(x)).relu(), 3, stride=2)
        x = x.reshape(x.shape[0], -1)
        x = self.fc1(F.dropout(x)).relu()
        x = self.fc2(F.dropout(x)).relu()
        x = self.fc3(x)
        return x

In [3]:
def transform(x):
    x = [[Image.fromarray(z).resize((227, 227)) for z in y] for y in x]
    x = np.stack([np.stack([np.asarray(z) for z in y], axis=0) for y in x], axis=0)
    x = x.reshape(-1, 3, 227, 227)
    return x

In [4]:
(X_train, Y_train), (X_test, Y_test) = fetch_cifar(), fetch_cifar(train=False)
model = AlexNet().to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.005, weight_decay=0.0005, momentum=0.9)
train(model, X_train, Y_train, optimizer, 10000, BS=64, transform=transform, device=device)
evaluate(model, X_test, Y_test, transform=transform, device=device)

loss 0.25 accuracy 0.89: 100%|███████████████████████████████████████| 10000/10000 [44:10<00:00,  3.77it/s]
100%|██████████████████████████████████████████████████████████████████████| 79/79 [00:16<00:00,  4.73it/s]

test set accuracy is 0.8189



