In [1]:
import torch
from torch import nn
import sys
sys.path.insert(0, "..")
from helper import train, test
from data import get_dataloader
import torch.nn.functional as F
from time import time

In [2]:
# Pick the best available accelerator. "mps" (Apple-silicon GPU) is only
# selected when PyTorch reports it as usable; otherwise fall back to the CPU so
# the notebook still runs on machines with neither CUDA nor MPS.
if torch.cuda.is_available():
    device = "cuda"
elif getattr(torch.backends, "mps", None) is not None and torch.backends.mps.is_available():
    # getattr guard: older PyTorch builds have no torch.backends.mps module.
    device = "mps"
else:
    device = "cpu"

# AlexNet
![LeNet-5](https://miro.medium.com/max/1200/1*3B8iO-se13vA2QfZ4OBRSw.png)

**Note:** The dimensions of the final fully connected layers have been modified, because AlexNet was designed for ImageNet images (227x227x3), whereas CIFAR10 images are 32x32x3.

In [3]:
class AlexNet(nn.Module):
    """AlexNet-style convolutional network sized for 32x32 RGB images.

    The model is split into a convolutional ``feature_extractor`` (five
    convolutions, each followed by ReLU, with three max-pooling stages) and a
    fully connected ``classifier`` that maps the 256 extracted features to
    ``num_classes`` logits.

    The shape comments below are written as HxWxC and assume a 32x32x3 input.
    The classifier expects exactly 256 input features, so the feature
    extractor must reduce the spatial dimensions to 1x1.

    Args:
        num_classes: Number of output logits produced by the final linear
            layer. Defaults to 10.
    """

    def __init__(self, num_classes: int = 10):
        super().__init__()

        self.feature_extractor = nn.Sequential(
            # Input Dimension 32x32x3
            nn.Conv2d(in_channels=3, out_channels=96, kernel_size=11, stride=1),
            nn.ReLU(),
            # 22x22x96
            nn.MaxPool2d(kernel_size=3, stride=2),
            # 10x10x96

            nn.Conv2d(in_channels=96, out_channels=256, kernel_size=5, padding=2),
            nn.ReLU(),
            # 10x10x256
            nn.MaxPool2d(kernel_size=3, stride=2),
            # 4x4x256

            # Three 3x3 convolutions with padding=1 keep the 4x4 spatial size.
            nn.Conv2d(in_channels=256, out_channels=384, kernel_size=3, padding=1),
            nn.ReLU(),
            # 4x4x384
            nn.Conv2d(in_channels=384, out_channels=384, kernel_size=3, padding=1),
            nn.ReLU(),
            # 4x4x384
            nn.Conv2d(in_channels=384, out_channels=256, kernel_size=3, padding=1),
            nn.ReLU(),

            # 4x4x256
            nn.MaxPool2d(kernel_size=3, stride=2),
            # 1x1x256
        )

        self.classifier = nn.Sequential(
            nn.Linear(in_features=256, out_features=128),
            nn.ReLU(),
            nn.Linear(in_features=128, out_features=128),
            nn.ReLU(),
            # No activation on the last layer: the network returns raw logits.
            nn.Linear(in_features=128, out_features=num_classes)
        )

    def forward(self, x):
        """Compute class logits for a batch of images.

        Args:
            x: Batch tensor with 3 channels in dimension 1; intended for
                inputs of shape (N, 3, 32, 32).

        Returns:
            Tensor of shape (N, num_classes) holding unnormalized logits.
        """
        x = self.feature_extractor(x)
        # Keep the batch dimension and collapse everything else into one axis.
        x = torch.flatten(x, 1)
        logits = self.classifier(x)
        return logits
        

# Build the network, then move its parameters to the selected device.
alexnet = AlexNet()
alexnet = alexnet.to(device)
# Print the layer-by-layer structure as a quick sanity check.
print(alexnet)

AlexNet(
  (feature_extractor): Sequential(
    (0): Conv2d(3, 96, kernel_size=(11, 11), stride=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(96, 256, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (4): ReLU()
    (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(256, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU()
    (8): Conv2d(384, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU()
    (10): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU()
    (12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (classifier): Sequential(
    (0): Linear(in_features=256, out_features=128, bias=True)
    (1): ReLU()
    (2): Linear(in_features=128, out_features=128, bias=True)
    (3): ReLU()
    (4): Linear(in_features=128, out_features=10, bias=True)
  )
)