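We implement a modified version of Krizhevsky et al.'s AlexNet, originally designed for the ImageNet (ILSVRC 2012) competition, following the architecture shown in this [image](https://upload.wikimedia.org/wikipedia/commons/c/cc/Comparison_image_neural_networks.svg), and apply it to the CIFAR-10 dataset. Because CIFAR-10 images are only 32×32 pixels, the network is scaled down:
- the first convolution uses a 5×5 kernel with stride 1 instead of 11×11 with stride 4;
- max-pooling uses 2×2 windows instead of 3×3;
- the fully-connected layers have 256 units instead of 4096.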

In [1]:
import torch
from torch import nn
from torchvision import datasets
from torchvision import transforms
from torch.utils.data import DataLoader, Subset
import numpy as np

In [2]:
# Pick the fastest available backend: NVIDIA CUDA first, then Apple's MPS,
# falling back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

print(f"Using {device}")

if device == "cuda":
    # Name of the GPU that the default CUDA device refers to.
    active_gpu = torch.cuda.current_device()
    print(torch.cuda.get_device_name(active_gpu))

Using cuda
Tesla T4


In [5]:
class AlexNet(nn.Module):
    """AlexNet-style CNN scaled down for 32 x 32 inputs such as CIFAR10.

    Five convolutional layers, three of them followed by 2 x 2 max-pooling,
    feed a three-layer fully-connected classifier with dropout. The first
    linear layer expects a flattened 4 x 4 x 256 = 4096-element feature map,
    so the spatial input size must be 32 x 32.

    Args:
        in_channels: number of input channels (3 for RGB images).
        num_classes: number of output logits; defaults to the 10 CIFAR10
            classes.
    """

    def __init__(self, in_channels: int, num_classes: int = 10) -> None:
        super().__init__()

        # One stateless ReLU module is reused after every layer.
        self.relu = nn.ReLU()
        # Shape comments are H x W x C for a 32 x 32 x in_channels input.
        # The original implementation had kernel_size=11 and stride=4
        # 32 x 32 x 96
        self.conv0 = nn.Conv2d(in_channels, 96, kernel_size=5, padding=2)
        # Original implementation maxpools had kernel_size=3 (stride 2)
        # 16 x 16 x 96
        self.maxpool0 = nn.MaxPool2d(kernel_size=2, stride=2)
        # 16 x 16 x 256
        self.conv1 = nn.Conv2d(96, 256, kernel_size=5, padding=2)
        # 8 x 8 x 256
        self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        # conv2-conv4 use 5 x 5 kernels here; the original used 3 x 3.
        # 8 x 8 x 384
        self.conv2 = nn.Conv2d(256, 384, kernel_size=5, padding=2)
        # 8 x 8 x 384
        self.conv3 = nn.Conv2d(384, 384, kernel_size=5, padding=2)
        # 8 x 8 x 256
        self.conv4 = nn.Conv2d(384, 256, kernel_size=5, padding=2)
        # 4 x 4 x 256 -> flattened to 4096 features
        self.maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.dropout = nn.Dropout(p=0.5)
        # We make the layers have output 256 versus the original 4096
        self.layer0 = nn.Linear(4 * 4 * 256, 256)
        self.layer1 = nn.Linear(256, 256)
        self.layer2 = nn.Linear(256, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map a (batch, in_channels, 32, 32) tensor to (batch, num_classes) logits."""
        x = self.maxpool0(self.relu(self.conv0(x)))
        x = self.maxpool1(self.relu(self.conv1(x)))
        x = self.relu(self.conv2(x))
        x = self.relu(self.conv3(x))
        x = self.maxpool2(self.relu(self.conv4(x)))
        # Keep the batch dimension, flatten everything else.
        x = torch.flatten(x, 1)
        x = self.dropout(self.relu(self.layer0(x)))
        x = self.dropout(self.relu(self.layer1(x)))
        return self.layer2(x)

In [6]:
transform_CIFAR10 = transforms.Compose([
    transforms.ToTensor(),
    # Mean and standard deviation for CIFAR10 dataset
    # Sourced from gist.github.com/weiaicunzai/e623931921efefd4c331622c344d8151
    transforms.Normalize(
        mean=[0.4914, 0.4822, 0.4465],
        std=[0.2470, 0.2435, 0.2616]
    )
])

train_data = datasets.CIFAR10(
    root="data",
    train=True,
    transform=transform_CIFAR10,
    download=True
)

test_data = datasets.CIFAR10(
    root="data",
    train=False,
    transform=transform_CIFAR10,
    download=True
)

BATCH_SIZE = 64
# Fraction of the official training set used for training; the remaining
# 1 - TRAIN_FRACTION (here 10%) is held out as the validation set.
TRAIN_FRACTION = 0.9
# Backward-compatible alias: despite its name, VALID_SIZE always held the
# *training* fraction, not the validation fraction.
VALID_SIZE = TRAIN_FRACTION
# Fixed seed so the train/validation split is identical on every run.
SPLIT_SEED = 42

split_rng = np.random.default_rng(SPLIT_SEED)
shuffled_indices = split_rng.permutation(len(train_data)).tolist()
# First index that belongs to the validation split.
valid_split = int(len(train_data) * TRAIN_FRACTION)
valid_indices = shuffled_indices[valid_split:]
train_indices = shuffled_indices[:valid_split]
assert not set(train_indices) & set(valid_indices), "train/valid overlap"

valid_data = Subset(train_data, valid_indices)
train_data = Subset(train_data, train_indices)

print(f"Training data: {len(train_data)}")
print(f"Validation data: {len(valid_data)}")
print(f"Test data: {len(test_data)}")

train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
valid_loader = DataLoader(valid_data, batch_size=BATCH_SIZE)
test_loader = DataLoader(test_data, batch_size=BATCH_SIZE)

Files already downloaded and verified
Files already downloaded and verified
Training data: 45000
Validation data: 5000
Test data: 10000


In [8]:
EPOCHS = 20
LR = 0.01
WEIGHT_DECAY = 0.001
MOMENTUM = 0.9
# Seed torch's global RNG so weight initialisation, dropout masks and the
# shuffling of train_loader (which draws from the same generator) are
# reproducible across runs.
MODEL_SEED = 42
torch.manual_seed(MODEL_SEED)

model = AlexNet(3)
# DataParallel splits each batch across all visible GPUs. Note that it
# prefixes parameter names with "module." in model.state_dict().
model = nn.DataParallel(model)
model = model.to(device)

# Cross-entropy loss; passed around as `metric` by train()/test().
metric = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(
    model.parameters(),
    lr=LR,
    weight_decay=WEIGHT_DECAY,
    momentum=MOMENTUM
)

In [9]:
def train(
    loader: DataLoader,
    model: nn.Module,
    metric: nn.Module,
    optimizer: torch.optim.Optimizer,
    device: "torch.device | str | None" = None,
    log_every: int = 100,
) -> None:
    """Run one training epoch of ``model`` over ``loader``.

    Args:
        loader: yields ``(inputs, targets)`` batches.
        model: network to optimise; it is switched to training mode.
        metric: loss function, called as ``metric(pred, targets)``.
        optimizer: steps ``model``'s parameters after every batch.
        device: where each batch is moved before the forward pass. Defaults
            to the device of the model's first parameter (CPU if it has none).
        log_every: print the current batch loss and learning rate every
            ``log_every`` batches.

    Raises:
        ValueError: if ``log_every`` is smaller than 1.
    """
    if log_every < 1:
        raise ValueError("log_every must be >= 1")
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    total = len(loader.dataset)
    seen = 0
    model.train()

    for batch, (x, y) in enumerate(loader):
        x = x.to(device)
        y = y.to(device)

        pred = model(x)
        loss = metric(pred, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Count real samples so the progress is correct even when the final
        # batch is smaller than the others.
        seen += len(x)
        if batch % log_every == log_every - 1:
            print(f"\tLoss: {loss.item():>7f} [{seen:>5d} / {total:>5d}]")
            print(f"\t\tLearning rate: {optimizer.param_groups[0]['lr']:>8f}")

In [10]:
def test(
    loader: DataLoader,
    model: nn.Module,
    metric: nn.Module,
) -> None:
    total = len(loader.dataset)
    batch_total = len(loader)
    total_loss = 0
    total_correct = 0
    model.eval()

    with torch.no_grad():
        for x, y in loader:
            x = x.to(device)
            y = y.to(device)

            pred = model(x)
            total_loss += metric(pred, y).item()
            pred_correct = pred.argmax(1) == y
            total_correct += pred_correct.type(torch.float).sum().item()

        total_loss /= batch_total
        total_correct /= total
        print(f"\tAccuracy: {(100 * total_correct):>0.1f}%")
        print(f"\tAverage loss: {total_loss:>8f}")

In [11]:
for epoch in range(EPOCHS):
    print(f"Epoch: {epoch + 1}")
    train(train_loader, model, metric, optimizer)
    test(valid_loader, model, metric)

# Save the wrapped network's weights rather than the DataParallel wrapper's,
# so the keys carry no "module." prefix and the checkpoint loads directly
# into a plain AlexNet(3) (on any device count).
torch.save(getattr(model, "module", model).state_dict(), "alexnet.pth")

Epoch: 1
	Loss: 2.302569 [ 6400 / 45000]
		Learning rate: 0.010000
	Loss: 2.111802 [12800 / 45000]
		Learning rate: 0.010000
	Loss: 1.858279 [19200 / 45000]
		Learning rate: 0.010000
	Loss: 2.042349 [25600 / 45000]
		Learning rate: 0.010000
	Loss: 1.786326 [32000 / 45000]
		Learning rate: 0.010000
	Loss: 1.893871 [38400 / 45000]
		Learning rate: 0.010000
	Loss: 1.700656 [44800 / 45000]
		Learning rate: 0.010000
	Accuracy: 37.1%
	Average loss: 1.617891
Epoch: 2
	Loss: 1.425370 [ 6400 / 45000]
		Learning rate: 0.010000
	Loss: 1.585904 [12800 / 45000]
		Learning rate: 0.010000
	Loss: 1.771493 [19200 / 45000]
		Learning rate: 0.010000
	Loss: 1.666008 [25600 / 45000]
		Learning rate: 0.010000
	Loss: 1.376256 [32000 / 45000]
		Learning rate: 0.010000
	Loss: 1.598711 [38400 / 45000]
		Learning rate: 0.010000
	Loss: 1.350612 [44800 / 45000]
		Learning rate: 0.010000
	Accuracy: 44.9%
	Average loss: 1.396186
Epoch: 3
	Loss: 1.244387 [ 6400 / 45000]
		Learning rate: 0.010000
	Loss: 1.372116 [1280

In [12]:
# Final, one-off evaluation on the held-out CIFAR10 test split (the trailing
# semicolon keeps Jupyter from echoing the call's return value).
test(loader=test_loader, model=model, metric=metric);

	Accuracy: 79.1%
	Average loss: 0.800118
