A rough re-implementation of the AlexNet walkthrough at https://blog.paperspace.com/alexnet-pytorch/, trained on CIFAR images upscaled to 227×227.

In [1]:
import numpy as np
from PIL import Image
import torch
from torch import nn
from datasets import fetch_cifar
from helpers import train, evaluate
# Fix PyTorch's global RNG seed so that repeated runs start from the same random state.
torch.manual_seed(1337)

<torch._C.Generator at 0x110e7c6d0>

In [2]:
class AlexNet(nn.Module):
    """AlexNet-style convolutional classifier with batch normalisation.

    Five convolutional stages (each Conv2d -> BatchNorm2d -> ReLU, with a
    3x3/stride-2 max-pool after stages 1, 2 and 5) are followed by two
    dropout-regularised fully connected layers and a final linear layer that
    produces one logit per class.

    The first fully connected layer expects 9216 = 256 * 6 * 6 features, which
    is what the convolutional stack yields for 227x227 inputs
    (227 -> 55 -> 27 -> 27 -> 13 -> 13 -> 13 -> 13 -> 6 per spatial side).

    Args:
        num_classes: Number of output logits. Defaults to 10.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(3, 96, 11, stride=4, padding=0),
            nn.BatchNorm2d(96),
            nn.ReLU(),
            nn.MaxPool2d(3, stride=2),
        )
        self.layer2 = nn.Sequential(
            nn.Conv2d(96, 256, 5, stride=1, padding=2),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.MaxPool2d(3, stride=2),
        )
        self.layer3 = nn.Sequential(
            nn.Conv2d(256, 384, 3, stride=1, padding=1),
            nn.BatchNorm2d(384),
            nn.ReLU(),
        )
        self.layer4 = nn.Sequential(
            nn.Conv2d(384, 384, 3, stride=1, padding=1),
            nn.BatchNorm2d(384),
            nn.ReLU(),
        )
        self.layer5 = nn.Sequential(
            nn.Conv2d(384, 256, 3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.MaxPool2d(3, stride=2),
        )
        self.fc1 = nn.Sequential(nn.Dropout(0.5), nn.Linear(9216, 4096), nn.ReLU())
        self.fc2 = nn.Sequential(nn.Dropout(0.5), nn.Linear(4096, 4096), nn.ReLU())
        self.fc3 = nn.Linear(4096, num_classes)

    # Implemented as ``forward`` rather than ``__call__``: nn.Module.__call__
    # is what runs registered hooks and then dispatches to ``forward``, so
    # overriding ``__call__`` directly would silently bypass that machinery.
    def forward(self, x):
        """Compute class logits for a batch of images.

        Args:
            x: Float tensor of shape (batch, 3, 227, 227).

        Returns:
            Tensor of shape (batch, num_classes) holding unnormalised logits.
        """
        out = self.layer1(x)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = self.layer5(out)
        # Collapse (C, H, W) into one feature axis for the linear layers.
        out = out.reshape(out.shape[0], -1)
        out = self.fc1(out)
        out = self.fc2(out)
        out = self.fc3(out)
        return out

In [3]:
def transform(x):
    """Upscale every 2-D plane of every sample in a batch to 227x227.

    Each element of ``x`` is iterated as a sequence of 2-D arrays; every such
    array is turned into a PIL image, resized to 227x227 with PIL's default
    resampling filter, and converted back to a NumPy array. The planes are
    re-stacked per sample, the samples are stacked into a batch, and the
    result is reshaped to ``(-1, 3, 227, 227)``.

    Args:
        x: Iterable of samples, each an iterable of 2-D arrays accepted by
            ``Image.fromarray``. Presumably a (batch, 3, H, W) uint8 array of
            CIFAR images - TODO confirm against ``fetch_cifar``.

    Returns:
        NumPy array of shape (batch, 3, 227, 227).
    """
    target_size = (227, 227)

    # First pass: resize every plane with PIL.
    resized_batch = []
    for sample in x:
        resized_planes = []
        for plane in sample:
            resized_planes.append(Image.fromarray(plane).resize(target_size))
        resized_batch.append(resized_planes)

    # Second pass: back to NumPy, one stacked array per sample.
    stacked_samples = []
    for resized_planes in resized_batch:
        plane_arrays = [np.asarray(image) for image in resized_planes]
        stacked_samples.append(np.stack(plane_arrays, axis=0))

    batch = np.stack(stacked_samples, axis=0)
    return batch.reshape(-1, 3, 227, 227)

In [4]:
# Load the training and test splits through the project's fetch_cifar helper.
X_train, Y_train = fetch_cifar()
X_test, Y_test = fetch_cifar(train=False)

# Fresh network, optimised with Adam at a learning rate of 1e-3.
model = AlexNet()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# 5000 is presumably the number of training steps and BS the batch size -
# confirm against helpers.train. The same transform is applied at train and
# evaluation time so both see 227x227 inputs.
train(
    model, X_train, Y_train, optimizer, 5000,
    BS=128, transform=transform,
)
evaluate(model, X_test, Y_test, transform=transform)

loss 1.75 accuracy 0.30:   5%|▍         | 240/5000 [21:11<7:37:05,  5.76s/it]