In [1]:
import torch
import torch.nn as nn

## GoogLeNet

GoogLeNet was introduced in *Going Deeper with Convolutions* (Szegedy et al., 2014). The paper's abstract summarizes the design as follows:

> The main hallmark of this architecture is the improved utilization of the computing resources inside the network. This was achieved by a carefully crafted design that allows for increasing the depth and width of the network while keeping the computational budget constant. To optimize quality, the architectural decisions were based on the Hebbian principle and the intuition of multi-scale processing. One particular incarnation used in our submission for ILSVRC14 is called GoogLeNet, a 22 layers deep network, the quality of which is assessed in the context of classification and detection.

[Paper](https://arxiv.org/pdf/1409.4842)

#### Inception Module
<img src="https://i.ibb.co/b6TTddX/image.png" alt="image" border="0">

In [2]:
class ConvBatchRelu(nn.Module):
  """Basic building unit: bias-free 2-D convolution, batch norm, then ReLU.

  Args:
    in_channels: number of channels of the incoming feature map.
    out_channels: number of channels produced by the convolution.
    kernel_size: convolution kernel size.
    stride: convolution stride.
    padding: zero padding applied by the convolution (default 0).
  """

  def __init__(self, in_channels, out_channels, kernel_size, stride, padding=0):
    super().__init__()
    # The convolution carries no bias term; BatchNorm2d follows it directly.
    self.conv = nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        bias=False,
    )
    self.bn = nn.BatchNorm2d(out_channels)
    # In-place activation: overwrites the batch-norm output tensor.
    self.relu = nn.ReLU(inplace=True)

  def forward(self, x):
    """Apply conv -> batch norm -> ReLU to ``x`` and return the result."""
    features = self.conv(x)
    normalized = self.bn(features)
    return self.relu(normalized)

In [7]:
class Inception(nn.Module):
    """Inception block: four parallel branches concatenated along the channel axis.

    Branches (all keep the spatial size of the input):
      1. 1x1 convolution.
      2. 1x1 reduction followed by a 3x3 convolution.
      3. 1x1 reduction followed by a 5x5 convolution.
      4. 3x3 max pooling (stride 1) followed by a 1x1 projection.

    The output therefore has ``out1x1 + out3x3 + out5x5 + pool_proj`` channels.

    Args:
        in_channels: channels of the incoming feature map.
        out1x1: output channels of branch 1.
        in3x3: channels after the 1x1 reduction in branch 2.
        out3x3: output channels of branch 2.
        in5x5: channels after the 1x1 reduction in branch 3.
        out5x5: output channels of branch 3.
        pool_proj: output channels of branch 4.
        num_classes: number of classes predicted by the auxiliary head;
            required when ``aux`` is True.
        aux: when True, the block also returns auxiliary logits computed
            from the block's *input*.

    Raises:
        ValueError: if ``aux`` is True but ``num_classes`` is None.

    Note:
        The auxiliary linear layer is an ``nn.LazyLinear``: its input width
        is inferred on the first forward pass. Run one forward pass before
        building an optimizer or counting parameters.
    """

    def __init__(self, in_channels, out1x1, in3x3, out3x3, in5x5, out5x5, pool_proj, num_classes=None, aux=False):
        super().__init__()
        if aux and num_classes is None:
            raise ValueError("num_classes must be provided when aux=True")

        self.branch1 = ConvBatchRelu(in_channels, out1x1, kernel_size=1, stride=1, padding=0)

        self.branch2 = nn.Sequential(
            ConvBatchRelu(in_channels, in3x3, kernel_size=1, stride=1, padding=0),
            ConvBatchRelu(in3x3, out3x3, kernel_size=3, stride=1, padding=1),
        )

        self.branch3 = nn.Sequential(
            ConvBatchRelu(in_channels, in5x5, kernel_size=1, stride=1, padding=0),
            # A true 5x5 kernel; padding=2 keeps the spatial size unchanged.
            ConvBatchRelu(in5x5, out5x5, kernel_size=5, stride=1, padding=2),
        )

        self.branch4 = nn.Sequential(
            nn.MaxPool2d(kernel_size=3, stride=1, padding=1, ceil_mode=True),
            ConvBatchRelu(in_channels, pool_proj, kernel_size=1, stride=1, padding=0),
        )

        self.aux = aux
        self.num_classes = num_classes
        if aux:
            # The classifier is created once here so that its weights are
            # registered with the module: they are trained, saved in the
            # state_dict and moved by .to()/.cuda() together with the rest.
            self.aux_branch = nn.Sequential(
                nn.AvgPool2d(kernel_size=5, stride=3),
                nn.Flatten(1),
                nn.LazyLinear(num_classes),
            )

    def forward(self, x):
        """Run the four branches on ``x`` and concatenate them on dim 1.

        Returns:
            The concatenated feature map, or a ``(features, aux_logits)``
            tuple when the block was built with ``aux=True``.
        """
        out = torch.cat(
            [self.branch1(x), self.branch2(x), self.branch3(x), self.branch4(x)],
            1,
        )

        if self.aux:
            # The auxiliary head reads the block input, not the concatenation.
            return out, self.aux_branch(x)

        return out

#### Architecture

<img src="https://i.ibb.co/vLb6ys5/image-2024-06-13-100528355.png" alt="image-2024-06-13-100528355" border="0">

In [3]:
class Stem(nn.Module):
  """Input stem: 7x7/2 conv, max pool, 1x1 conv, 3x3 conv, max pool.

  Takes a 3-channel image batch and produces a 192-channel feature map.
  The strided convolution and the two stride-2 poolings each halve the
  spatial resolution.
  """

  def __init__(self):
    super().__init__()
    self.conv1 = ConvBatchRelu(3, 64, kernel_size=7, stride=2, padding=3)
    self.mp1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
    # 1x1 conv block that keeps the channel count at 64.
    self.norm = ConvBatchRelu(64, 64, kernel_size=1, stride=1)
    self.conv2 = ConvBatchRelu(64, 192, kernel_size=3, stride=1, padding=1)
    self.mp2 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

  def forward(self, x):
    """Apply the stem layers to ``x`` in sequence and return the result."""
    pipeline = (self.conv1, self.mp1, self.norm, self.conv2, self.mp2)
    for layer in pipeline:
      x = layer(x)
    return x

class GGNet(nn.Module):
  """GoogLeNet-style classifier built from a stem and nine Inception blocks.

  Layout: stem -> inception 3a, 3b -> max pool -> inception 4a..4e ->
  max pool -> inception 5a, 5b -> global average pool -> dropout -> linear.
  Blocks 4b and 4e are built with ``aux=True`` and additionally emit
  auxiliary logits.

  Args:
    num_classes: number of output classes for the main and auxiliary heads.
  """

  def __init__(self, num_classes):
    super().__init__()
    self.stem = Stem()

    # Inception Module 3 (192 -> 256 -> 480 channels)
    self.inception3a = Inception(in_channels=192, out1x1=64,
                                 in3x3=96, out3x3=128,
                                 in5x5=16, out5x5=32,
                                 pool_proj=32)

    self.inception3b = Inception(in_channels=256, out1x1=128,
                                 in3x3=128, out3x3=192,
                                 in5x5=32, out5x5=96,
                                 pool_proj=64)

    self.mp1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

    # Inception Module 4 (480 -> 512 -> 512 -> 512 -> 528 -> 832 channels)
    self.inception4a = Inception(in_channels=480, out1x1=192,
                                 in3x3=96, out3x3=208,
                                 in5x5=16, out5x5=48,
                                 pool_proj=64)

    self.inception4b = Inception(in_channels=512, out1x1=160,
                                 in3x3=112, out3x3=224,
                                 in5x5=24, out5x5=64,
                                 pool_proj=64, num_classes=num_classes, aux=True)

    self.inception4c = Inception(in_channels=512, out1x1=128,
                                 in3x3=128, out3x3=256,
                                 in5x5=24, out5x5=64,
                                 pool_proj=64)

    self.inception4d = Inception(in_channels=512, out1x1=112,
                                 in3x3=144, out3x3=288,
                                 in5x5=32, out5x5=64,
                                 pool_proj=64)

    self.inception4e = Inception(in_channels=528, out1x1=256,
                                 in3x3=160, out3x3=320,
                                 in5x5=32, out5x5=128,
                                 pool_proj=128, num_classes=num_classes, aux=True)

    self.mp2 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

    # Inception Module 5 (832 -> 832 -> 1024 channels)
    self.inception5a = Inception(in_channels=832, out1x1=256,
                                 in3x3=160, out3x3=320,
                                 in5x5=32, out5x5=128,
                                 pool_proj=128)

    self.inception5b = Inception(in_channels=832, out1x1=384,
                                 in3x3=192, out3x3=384,
                                 in5x5=48, out5x5=128,
                                 pool_proj=128)

    # Output
    # Global average pooling: equals a 7x7 average pool on a 7x7 feature map
    # (the 224x224 case) but does not hard-code that spatial size, so the
    # 1024-wide linear layer below always receives one value per channel.
    self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
    self.dropout = nn.Dropout(p=0.1, inplace=True)
    self.fc = nn.Linear(1024, num_classes)


  def forward(self, x):
    """Classify a batch of images.

    Args:
      x: image batch with 3 channels (the stem's first conv expects 3).

    Returns:
      Tuple ``(logits, aux_4b, aux_4e)``: the main logits followed by the
      auxiliary logits emitted by blocks 4b and 4e.
    """
    x = self.stem(x)
    x = self.inception3a(x)
    x = self.inception3b(x)
    x = self.mp1(x)

    x = self.inception4a(x)
    x, aux_4b = self.inception4b(x)
    x = self.inception4c(x)
    x = self.inception4d(x)
    x, aux_4e = self.inception4e(x)
    x = self.mp2(x)

    x = self.inception5a(x)
    x = self.inception5b(x)

    x = self.avg_pool(x)
    x = self.dropout(x)
    x = torch.flatten(x, 1)
    x = self.fc(x)
    return x, aux_4b, aux_4e

In [6]:
# Smoke test: push a random batch of 12 RGB 224x224 images through a
# 100-class model and display the shapes of the three returned logit tensors.
x = torch.randn((12, 3, 224, 224))
model = GGNet(num_classes=100)
y, aux1, aux2 = model(x)

tuple(logits.shape for logits in (y, aux1, aux2))

(torch.Size([12, 100]), torch.Size([12, 100]), torch.Size([12, 100]))