Implement AlexNet using PyTorch.

Link: https://proceedings.neurips.cc/paper_files/paper/2012/file/c399862d3b9d6b76c8436e924a68c45b-Paper.pdf

5 convolutional and 3 fully connected layers

Use ReLU activations after each convolutional and fully connected layer.

ReLUs have the desirable property that they do not require input normalization to prevent them
from saturating.

Uses overlapping pooling (i.e. pooling windows that overlap).

Mistakes:

Missing flatten: After conv layers, you need to flatten before the first linear layer

Wrong input size to linear layer: After the conv layers, the spatial dimensions aren't 1x1, so input to fcl_1 should be calculated, not just 256

In [1]:
import torch
from torch import nn 
from torchvision import transforms, datasets
from torch.utils.data import DataLoader, Subset

class AlexNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.dropout = nn.Dropout(0.5) # from paper
        self.relu = nn.ReLU() # from paper 

        self.layers = nn.Sequential(
            nn.Conv2d(3, 96, kernel_size=5, stride=1, padding=2),
            nn.ReLU(),
            nn.LocalResponseNorm(size=5, alpha=0.0001, beta=0.75, k=2),
            nn.MaxPool2d(kernel_size=3, stride=2),
            
            nn.Conv2d(96, 256, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.LocalResponseNorm(size=5, alpha=0.0001, beta=0.75, k=2),
            nn.MaxPool2d(kernel_size=3, stride=2),

            nn.Conv2d(256, 384, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(384, 384, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2),
        )

        # After self.layers definition, add:
        with torch.no_grad():
            x = torch.randn(1, 3, 32, 32)
            out = self.layers(x)
            self.flattened_size = out.view(1, -1).size(1)

        self.fcl_1 = nn.Linear(self.flattened_size, 2048)        
        self.fcl_2 = nn.Linear(2048, 2048)
        self.fcl_3 = nn.Linear(2048, 100)
    
    def forward(self, x):
        x = self.layers(x)
        x = x.view(x.size(0), -1)
        x = self.fcl_1(x)
        x = self.relu(x)
        x = self.dropout(x)
        x = self.fcl_2(x)
        x = self.relu(x)
        x = self.dropout(x)
        x = self.fcl_3(x)

        return x

In [2]:
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.set_float32_matmul_precision('high')

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

if torch.cuda.is_available():
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
    torch.cuda.empty_cache()

train_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),  # Mild to avoid over-distortion
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5071, 0.4865, 0.4409], std=[0.2673, 0.2564, 0.2761]),
    transforms.RandomErasing(p=0.25)  # Apply after normalization for consistency
])

test_transform = transforms.Compose([
    transforms.ToTensor(), # Moved ToTensor before Normalize (good practice)
    transforms.Normalize(mean=[0.5071, 0.4867, 0.4408], std=[0.2675, 0.2565, 0.2761])
])

# Load raw datasets
cifar_train_raw = datasets.CIFAR100(root="./data", train=True, download=True, transform=None)

train_size = int(0.9 * len(cifar_train_raw))  # 48,000

train_indices = list(range(0, train_size))
val_indices = list(range(train_size, len(cifar_train_raw)))

# Create datasets with appropriate transforms
cifar_train = Subset(
    datasets.CIFAR100(root="./data", train=True, transform=train_transform),
    train_indices
)
cifar_val = Subset(
    datasets.CIFAR100(root="./data", train=True, transform=test_transform),
    val_indices
)
# Use original test set (10,000 samples) - close to 10% of 60,000
cifar_test = datasets.CIFAR100(root="./data", train=False, transform=test_transform)

train_loader = DataLoader(
    cifar_train,  # Use directly
    batch_size=1024,
    shuffle=True,
    num_workers=2,
    pin_memory=True,
    persistent_workers=True,
    prefetch_factor=6
)

val_loader = DataLoader(
    cifar_val,  # Use directly
    batch_size=1024,
    shuffle=False,
    num_workers=2,
    pin_memory=True,
    persistent_workers=True,
    prefetch_factor=6
)

test_loader = DataLoader(
    cifar_test,
    batch_size=1024,
    shuffle=False,
    num_workers=2,
    pin_memory=True,
    persistent_workers=True,
    prefetch_factor=6
)

num_classes = 100

model = AlexNet().to(device)

num_epochs = 40
loss_function = nn.CrossEntropyLoss(label_smoothing=0.1)
base_lr = 4e-3

batch_scale = 1024 / 256  # 4x larger batches
scaled_lr = base_lr * batch_scale**0.5  # Square root scaling

optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=1e-3,  # Lower initial LR
    weight_decay=1e-4  # Reduce weight decay
)

scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=3e-3,  # Much lower peak (was 5e-3)
    epochs=num_epochs,
    steps_per_epoch=len(train_loader),
    pct_start=0.3,  # Shorter warmup
    anneal_strategy='cos',
    div_factor=25.0,  # Gentler start
    final_div_factor=1000.0
)

best_val_loss = float('inf')

for epoch in range(num_epochs):
    print(f'Starting Epoch {epoch+1}')
    model.train()

    current_loss = 0.0
    num_batches = 0

    for i, data in enumerate(train_loader):
        inputs, targets = data
        inputs, targets = inputs.to(device), targets.to(device)
            
        outputs = model(inputs)
        loss = loss_function(outputs, targets)
        loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # ADD THIS

        optimizer.step()
        optimizer.zero_grad(set_to_none=True)
        scheduler.step()

        current_loss += loss.item()
        num_batches += 1

        if i % 50 == 0:
            print(f'Batch {i}/{len(train_loader)}, Loss: {loss.item():.4f}')


    avg_train_loss = current_loss / num_batches
    print(f'Epoch {epoch+1} finished')
    print(f'Training - Loss: {avg_train_loss:.4f}')

    if (epoch + 1) % 2 == 0:
        model.eval()
        val_loss = 0.0
        val_batches = 0

        print(f'Epoch {epoch+1} finished')
        print(f'average training loss is {avg_train_loss:.4f}')

        with torch.no_grad():
            for val_data in val_loader:
                val_inputs, val_targets = val_data
                val_inputs, val_targets = val_inputs.to(device), val_targets.to(device)  # Convert inputs to FP16

                val_outputs = model(val_inputs)
                val_batch_loss = loss_function(val_outputs, val_targets)

                val_loss += val_batch_loss.item()
                val_batches += 1


        avg_val_loss = val_loss / val_batches

        print(f'Epoch {epoch+1} finished')
        print(f'Training - Loss: {avg_train_loss:.4f}')
        print(f'Validation - Loss: {avg_val_loss:.4f}')

if torch.cuda.is_available():
    torch.cuda.empty_cache()


  self.setter(val)


Using device: cuda
GPU Memory: 15.8 GB


100%|██████████| 169M/169M [00:26<00:00, 6.46MB/s] 


Starting Epoch 1
Batch 0/44, Loss: 4.6047
Epoch 1 finished
Training - Loss: 4.5423
Starting Epoch 2
Batch 0/44, Loss: 4.3966
Epoch 2 finished
Training - Loss: 4.3356
Epoch 2 finished
average training loss is 4.3356
Epoch 2 finished
Training - Loss: 4.3356
Validation - Loss: 4.2326
Starting Epoch 3
Batch 0/44, Loss: 4.2520
Epoch 3 finished
Training - Loss: 4.2309
Starting Epoch 4
Batch 0/44, Loss: 4.1746
Epoch 4 finished
Training - Loss: 4.1451
Epoch 4 finished
average training loss is 4.1451
Epoch 4 finished
Training - Loss: 4.1451
Validation - Loss: 4.0215
Starting Epoch 5
Batch 0/44, Loss: 4.0595
Epoch 5 finished
Training - Loss: 4.0858
Starting Epoch 6
Batch 0/44, Loss: 4.0886
Epoch 6 finished
Training - Loss: 4.0503
Epoch 6 finished
average training loss is 4.0503
Epoch 6 finished
Training - Loss: 4.0503
Validation - Loss: 3.8687
Starting Epoch 7
Batch 0/44, Loss: 3.9807
Epoch 7 finished
Training - Loss: 3.9509
Starting Epoch 8
Batch 0/44, Loss: 3.9163
Epoch 8 finished
Training - L

In [None]:
def evaluate_test_set(model):
    model.eval()
    correct = 0
    total = 0
    
    print("Starting evaluation...")
    
    with torch.no_grad():
        for data in test_loader:
            images, labels = data
            images, labels = images.to(device), labels.to(device)
            
            # Use autocast for consistency if you trained with it
            outputs = model(images)
            
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = 100 * correct / total
    print(f'Accuracy of the network on the test images: {accuracy:.2f}%')
    return accuracy

print("\n=== Running standard evaluation ===")
standard_accuracy = evaluate_test_set(model)   
print(f'Standard Test Accuracy: {standard_accuracy:.4f} ({standard_accuracy:.2f}%)')


=== Running standard evaluation ===
Starting evaluation...
Accuracy of the network on the test images: 53.60%
Standard Test Accuracy: 53.6000 (5360.00%)
