In [28]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim

In [None]:
class MyAlexNet(nn.Module):
    def __init__(
        self,
        image_size: int = 224,
        image_channels: int = 3,
        num_classes: int = 43,
    ) -> None:
        super().__init__()
        self.input_size = (image_channels, image_size, image_size)
        self.conv1 = nn.Conv2d(image_channels, 96, kernel_size=11, stride=4)
        self.pool1 = nn.MaxPool2d(kernel_size=3, stride=2)
        self.conv2 = nn.LazyConv2d(256, kernel_size=5, padding=2)
        self.pool2 = nn.MaxPool2d(kernel_size=3, stride=2)
        self.conv3 = nn.LazyConv2d(384, kernel_size=3, padding=1)
        self.conv4 = nn.LazyConv2d(384, kernel_size=3, padding=1)
        self.conv5 = nn.LazyConv2d(256, kernel_size=3, padding=1)
        self.pool3 = nn.MaxPool2d(kernel_size=3, stride=2)
        self.flatten = nn.Flatten()
        self.fc1 = nn.LazyLinear(4096)
        self.fc2 = nn.LazyLinear(4096)
        self.fc3 = nn.LazyLinear(num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        print(f'input: {x.shape}')
        x = F.relu(self.conv1(x))
        print(f'conv1: {x.shape}')
        x = self.pool1(x)
        print(f'pool1: {x.shape}')
        x = F.relu(self.conv2(x))
        print(f'conv2: {x.shape}')
        x = self.pool2(x)
        print(f'pool2: {x.shape}')
        x = F.relu(self.conv3(x))
        print(f'conv3: {x.shape}')
        x = F.relu(self.conv4(x))
        print(f'conv4: {x.shape}')
        x = self.pool3(x)
        print(f'pool3: {x.shape}')
        x = F.relu(self.conv5(x))
        print(f'conv5: {x.shape}')
        x = self.flatten(x)
        print(f'flatten: {x.shape}')
        x = F.relu(self.fc1(x))
        print(f'fc1: {x.shape}')
        x = F.relu(self.fc2(x))
        print(f'fc2: {x.shape}')
        x = self.fc3(x)
        print(f'fc3: {x.shape}')
        return x

    def summary(self) -> None:
        try:
            from torchinfo import summary

            return summary(self, (1, *self.input_size), device='cpu')
        except Exception as e:
            print(f'Error: {e}')
            print(self)

In [30]:
x = torch.randn(1, 3, 224, 224)
model = MyAlexNet()
model.summary()

input: torch.Size([1, 3, 224, 224])
conv1: torch.Size([1, 96, 54, 54])
pool1: torch.Size([1, 96, 26, 26])
conv2: torch.Size([1, 256, 26, 26])
pool2: torch.Size([1, 256, 12, 12])
conv3: torch.Size([1, 384, 12, 12])
conv4: torch.Size([1, 384, 12, 12])
pool3: torch.Size([1, 384, 5, 5])
conv5: torch.Size([1, 256, 5, 5])
flatten: torch.Size([1, 6400])
fc1: torch.Size([1, 4096])
fc2: torch.Size([1, 4096])
fc3: torch.Size([1, 43])


Layer (type:depth-idx)                   Output Shape              Param #
MyAlexNet                                [1, 43]                   --
├─Conv2d: 1-1                            [1, 96, 54, 54]           34,944
├─MaxPool2d: 1-2                         [1, 96, 26, 26]           --
├─Conv2d: 1-3                            [1, 256, 26, 26]          614,656
├─MaxPool2d: 1-4                         [1, 256, 12, 12]          --
├─Conv2d: 1-5                            [1, 384, 12, 12]          885,120
├─Conv2d: 1-6                            [1, 384, 12, 12]          1,327,488
├─MaxPool2d: 1-7                         [1, 384, 5, 5]            --
├─Conv2d: 1-8                            [1, 256, 5, 5]            884,992
├─Flatten: 1-9                           [1, 6400]                 --
├─Linear: 1-10                           [1, 4096]                 26,218,496
├─Linear: 1-11                           [1, 4096]                 16,781,312
├─Linear: 1-12                           [1