In [1]:
import torch.nn as nn
import torch
import gc
from torchsummary import summary


class AlexNet(nn.Module):
    """AlexNet-style image classifier with BatchNorm after every convolution.

    The convolutional stack is generated from ``self.architecture``: every
    tuple is ``(out_channels, kernel_size, stride, padding)`` for one
    ``Conv2d -> BatchNorm2d -> ReLU`` group, and every string entry inserts a
    ``MaxPool2d(kernel_size=3, stride=2)``.  The resulting feature map is
    pooled to a fixed 6x6 grid, flattened, and passed through a three-layer
    fully connected head.

    Args:
        num_classes: Number of output units of the final linear layer.
        in_channels: Number of channels of the input images (default 3).
    """

    # Spatial size the feature map is pooled to before the linear head.
    _POOLED_SIZE = (6, 6)

    def __init__(self, num_classes, in_channels=3):
        super().__init__()

        self.in_channels = in_channels
        self.num_classes = num_classes
        self.architecture = [
            (64, 11, 4, 2), 'M',
            (192, 5, 1, 2), 'M',
            (384, 3, 1, 1),
            (256, 3, 1, 1),
            (256, 3, 1, 1), 'M'
        ]

        self.conv_layers = self._init_conv_layers()

        # For a 224x224 input the conv stack already emits a 6x6 map, so this
        # pooling is an identity there; for other resolutions it keeps the
        # flattened size fixed so the linear head still fits. The layer has no
        # parameters, so state_dict keys and weight initialisation order are
        # unaffected.
        self.avgpool = nn.AdaptiveAvgPool2d(self._POOLED_SIZE)

        # Width of the head follows the last convolution in the spec rather
        # than a hard-coded literal.
        feature_channels = [
            spec[0] for spec in self.architecture if not isinstance(spec, str)
        ][-1]
        flat_features = feature_channels * self._POOLED_SIZE[0] * self._POOLED_SIZE[1]

        self.fcl = nn.Sequential(
            nn.Linear(flat_features, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, num_classes)
        )

    def forward(self, x):
        """Run the network on a batch of images.

        Args:
            x: Tensor of shape ``(batch, in_channels, height, width)``. The
                spatial size must be large enough to survive the strided
                convolution and the three max-pool layers.

        Returns:
            Tensor of shape ``(batch, num_classes)`` holding the raw outputs
            of the last linear layer (no softmax is applied here).
        """
        x = self.conv_layers(x)
        x = self.avgpool(x)
        # Keep the batch dimension, collapse channels and spatial dims.
        x = torch.flatten(x, 1)
        return self.fcl(x)

    def _init_conv_layers(self):
        """Build the convolutional feature extractor from ``self.architecture``.

        Returns:
            An ``nn.Sequential`` containing, in spec order, a
            ``Conv2d -> BatchNorm2d -> ReLU`` group for each tuple entry and a
            ``MaxPool2d(kernel_size=3, stride=2)`` for each string entry.
        """
        layers = []
        in_channels = self.in_channels

        for spec in self.architecture:
            if isinstance(spec, str):
                # Any string marker (the spec uses 'M') stands for max pooling.
                layers.append(nn.MaxPool2d(kernel_size=3, stride=2))
                continue

            out_channels, kernel_size, stride, padding = spec
            layers += [
                nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size,
                          stride=stride, padding=padding),
                nn.BatchNorm2d(out_channels),
                nn.ReLU()
            ]
            # The next convolution consumes what this one produced.
            in_channels = out_channels

        return nn.Sequential(*layers)

In [2]:
# Ask PyTorch to release any cached, currently unused CUDA memory before the
# model is built.
torch.cuda.empty_cache()

# Use the first GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
# classes, train, val, test = get_datasets()

# Build the 101-class network and move its parameters to the chosen device.
model = AlexNet(num_classes=101)
model = model.to(device)

# Run the garbage collector, then print a per-layer table of output shapes and
# parameter counts for a single 3x224x224 input.
gc.collect()
summary(model, (3, 224, 224))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 64, 55, 55]          23,296
       BatchNorm2d-2           [-1, 64, 55, 55]             128
              ReLU-3           [-1, 64, 55, 55]               0
         MaxPool2d-4           [-1, 64, 27, 27]               0
            Conv2d-5          [-1, 192, 27, 27]         307,392
       BatchNorm2d-6          [-1, 192, 27, 27]             384
              ReLU-7          [-1, 192, 27, 27]               0
         MaxPool2d-8          [-1, 192, 13, 13]               0
            Conv2d-9          [-1, 384, 13, 13]         663,936
      BatchNorm2d-10          [-1, 384, 13, 13]             768
             ReLU-11          [-1, 384, 13, 13]               0
           Conv2d-12          [-1, 256, 13, 13]         884,992
      BatchNorm2d-13          [-1, 256, 13, 13]             512
             ReLU-14          [-1, 256,