<center><h1>GoogLeNet</h1></center>

<center><p><a href="http://arxiv.org/abs/1409.4842">Going Deeper with Convolutions</a></p></center>

<img src="https://www.mdpi.com/ijms/ijms-22-07721/article_deploy/html/images/ijms-22-07721-g002.png" width="600"/>

In [1]:
from typing import Optional

import torch
import torch.nn.functional as F
from torch import nn, Tensor

# Blocks

## Basic Conv Layer

In [2]:
class BasicConv2d(nn.Module):
    """A 2-D convolution immediately followed by an in-place ReLU.

    No normalization layer is used; the block is just ``relu(conv(x))``.

    Args:
        in_channels: number of channels in the input feature map.
        out_channels: number of channels produced by the convolution.
        **kwargs: forwarded verbatim to ``nn.Conv2d`` (e.g. ``kernel_size``,
            ``stride``, ``padding``).
    """

    def __init__(self, in_channels, out_channels, **kwargs):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, **kwargs)

    def forward(self, x):
        # In-place ReLU on the fresh conv output avoids an extra allocation.
        return F.relu(self.conv(x), inplace=True)

## Inception

In [3]:
class Inception(nn.Module):
    """Inception module: four parallel branches concatenated along the channel axis.

    Branches:
        1. 1x1 conv                                  -> ``ch1x1`` channels
        2. 1x1 reduction to ``ch3x3red`` -> 3x3 conv -> ``ch3x3`` channels
        3. 1x1 reduction to ``ch5x5red`` -> 5x5 conv -> ``ch5x5`` channels
        4. 3x3 max pool (stride 1) -> 1x1 projection -> ``pool_proj`` channels

    Each branch uses stride 1 with padding matched to its kernel size, so all four
    keep the input's height and width. The output therefore has
    ``ch1x1 + ch3x3 + ch5x5 + pool_proj`` channels.
    """

    def __init__(self, in_channels, ch1x1, ch3x3red, ch3x3, ch5x5red, ch5x5, pool_proj):
        super().__init__()

        # Branch 1: plain 1x1 convolution.
        self.branch1 = BasicConv2d(in_channels, ch1x1, kernel_size=1)

        # Branch 2: 1x1 bottleneck, then a padded 3x3 convolution.
        reduce_3x3 = BasicConv2d(in_channels, ch3x3red, kernel_size=1)
        conv_3x3 = BasicConv2d(ch3x3red, ch3x3, kernel_size=3, padding=1)
        self.branch2 = nn.Sequential(reduce_3x3, conv_3x3)

        # Branch 3: 1x1 bottleneck, then a padded 5x5 convolution.
        # torchvision.models.googlenet uses kernel_size=3 here instead of
        # kernel_size=5, which is a known bug.
        # See https://github.com/pytorch/vision/issues/906 for details.
        reduce_5x5 = BasicConv2d(in_channels, ch5x5red, kernel_size=1)
        conv_5x5 = BasicConv2d(ch5x5red, ch5x5, kernel_size=5, padding=2)
        self.branch3 = nn.Sequential(reduce_5x5, conv_5x5)

        # Branch 4: size-preserving max pool, then a 1x1 projection.
        pool = nn.MaxPool2d(kernel_size=3, stride=1, padding=1)
        project = BasicConv2d(in_channels, pool_proj, kernel_size=1)
        self.branch4 = nn.Sequential(pool, project)

    def forward(self, x):
        branches = (self.branch1, self.branch2, self.branch3, self.branch4)
        return torch.cat([branch(x) for branch in branches], 1)

## Inception Auxiliary

In [4]:
class InceptionAux(nn.Module):
    """Auxiliary classifier head attached to an intermediate feature map.

    Pipeline: adaptive average pool to 4x4 -> 1x1 ``BasicConv2d`` to 128 channels
    -> flatten (128 * 4 * 4 = 2048 features) -> linear 1024 + ReLU -> dropout
    -> linear ``num_classes``.

    Args:
        in_channels: channels of the feature map this head is attached to.
        num_classes: number of output logits.
        dropout: drop probability applied after the first linear layer.
    """

    def __init__(self, in_channels, num_classes, dropout=0.7):
        super().__init__()
        self.conv = BasicConv2d(in_channels, 128, kernel_size=1)

        # 128 channels at 4x4 spatial positions, flattened.
        self.fc1 = nn.Linear(128 * 4 * 4, 1024)
        self.fc2 = nn.Linear(1024, num_classes)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x):
        # aux1: 512 x 14 x 14 -> 512 x 4 x 4 -> 128 x 4 x 4 -> 2048 -> 1024 -> 1000
        # aux2: 528 x 14 x 14 -> 528 x 4 x 4 -> 128 x 4 x 4 -> 2048 -> 1024 -> 1000
        pooled = F.adaptive_avg_pool2d(x, (4, 4))
        features = torch.flatten(self.conv(pooled), 1)
        hidden = self.dropout(F.relu(self.fc1(features), inplace=True))
        return self.fc2(hidden)

# GoogLeNet

In [5]:
class GoogLeNet(nn.Module):
    """GoogLeNet (Inception v1) image classifier.

    Args:
        num_classes: size of the main (and auxiliary) logit vectors.
        aux_logits: if True, build the auxiliary heads ``aux1`` (fed by
            ``inception4a``, 512 channels) and ``aux2`` (fed by ``inception4d``,
            528 channels).
        init_weights: if True, re-initialize every ``nn.Conv2d`` / ``nn.Linear``
            with ``_initialize_weights`` after construction.
        dropout: drop probability before the final linear layer.
        dropout_aux: drop probability passed to each auxiliary head.

    ``forward`` always returns a 3-tuple ``(logits, aux1, aux2)``. The auxiliary
    outputs are computed only in training mode and only when the heads exist;
    otherwise they are ``None``.

    The spatial-size comments below assume a 3 x 224 x 224 input.
    """

    def __init__(self, num_classes=1000, aux_logits=True, init_weights=False, dropout=0.2, dropout_aux=0.7):
        super().__init__()
        self.conv1 = BasicConv2d(3, 64, kernel_size=7, stride=2, padding=3)  # 224 -> 112
        self.maxpool1 = nn.MaxPool2d(3, stride=2, ceil_mode=True)  # 112 -> 56
        self.conv2 = BasicConv2d(64, 64, kernel_size=1)
        self.conv3 = BasicConv2d(64, 192, kernel_size=3, padding=1)
        self.maxpool2 = nn.MaxPool2d(3, stride=2, ceil_mode=True)  # 56 -> 28

        self.inception3a = Inception(192, 64, 96, 128, 16, 32, 32)  # 64+128+32+32=256
        self.inception3b = Inception(256, 128, 128, 192, 32, 96, 64)  # 128+192+96+64=480
        self.maxpool3 = nn.MaxPool2d(3, stride=2, ceil_mode=True)  # 28 -> 14

        self.inception4a = Inception(480, 192, 96, 208, 16, 48, 64)  # 192+208+48+64=512
        self.inception4b = Inception(512, 160, 112, 224, 24, 64, 64)  # 160+224+64+64=512
        self.inception4c = Inception(512, 128, 128, 256, 24, 64, 64)  # 128+256+64+64=512
        self.inception4d = Inception(512, 112, 144, 288, 32, 64, 64)  # 112+288+64+64=528
        self.inception4e = Inception(528, 256, 160, 320, 32, 128, 128)  # 256+320+128+128=832
        self.maxpool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)  # 14 -> 7

        self.inception5a = Inception(832, 256, 160, 320, 32, 128, 128)  # 256+320+128+128=832
        self.inception5b = Inception(832, 384, 192, 384, 48, 128, 128)  # 384+384+128+128=1024

        if aux_logits:
            # Input channels match the outputs of inception4a (512) and inception4d (528).
            self.aux1 = InceptionAux(512, num_classes, dropout=dropout_aux)
            self.aux2 = InceptionAux(528, num_classes, dropout=dropout_aux)
        else:
            self.aux1 = None
            self.aux2 = None

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.dropout = nn.Dropout(p=dropout)
        self.fc = nn.Linear(1024, num_classes)

        if init_weights:
            self._initialize_weights()

    def forward(self, x):
        """Run the network; return ``(logits, aux1, aux2)`` (aux entries may be None)."""
        x = self.conv1(x)
        x = self.maxpool1(x)

        x = self.conv2(x)
        x = self.conv3(x)
        x = self.maxpool2(x)

        x = self.inception3a(x)
        x = self.inception3b(x)
        x = self.maxpool3(x)

        x = self.inception4a(x)
        aux1: Optional[Tensor] = None
        # Auxiliary heads only contribute during training.
        if self.training and self.aux1 is not None:
            aux1 = self.aux1(x)

        x = self.inception4b(x)
        x = self.inception4c(x)
        x = self.inception4d(x)
        aux2: Optional[Tensor] = None
        if self.training and self.aux2 is not None:
            aux2 = self.aux2(x)

        x = self.inception4e(x)
        x = self.maxpool4(x)

        x = self.inception5a(x)
        x = self.inception5b(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.dropout(x)
        x = self.fc(x)

        return x, aux1, aux2

    def _initialize_weights(self):
        """Kaiming-normal init for convs, Xavier-uniform for linears, zero biases."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

# Summary

## Data

In [6]:
# Prefer the GPU when one is visible; otherwise run everything on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Random batch: 32 samples, 3 channels, 224 x 224.
data = torch.randn(32, 3, 224, 224).to(device)

## GoogLeNet

In [7]:
from torchkeras import summary

# Build the model with default settings, then move it to the sample batch's
# device (``nn.Module.to`` returns the module itself).
net = GoogLeNet(num_classes=1000)
net = net.to(device)

# Print per-layer output shapes / parameter counts for ``data``, then drop the
# notebook's reference to the model.
summary(net, input_data=data)
del net

--------------------------------------------------------------------------
Layer (type)                            Output Shape              Param #
Conv2d-1                          [-1, 64, 112, 112]                9,472
MaxPool2d-2                         [-1, 64, 56, 56]                    0
Conv2d-3                            [-1, 64, 56, 56]                4,160
Conv2d-4                           [-1, 192, 56, 56]              110,784
MaxPool2d-5                        [-1, 192, 28, 28]                    0
Conv2d-6                            [-1, 64, 28, 28]               12,352
Conv2d-7                            [-1, 96, 28, 28]               18,528
Conv2d-8                           [-1, 128, 28, 28]              110,720
Conv2d-9                            [-1, 16, 28, 28]                3,088
Conv2d-10                           [-1, 32, 28, 28]               12,832
MaxPool2d-11                       [-1, 192, 28, 28]                    0
Conv2d-12                           [