In [1]:
import torch.nn as nn
import torch
import torch.nn.functional as F
from torchsummary import summary

In [2]:
class Stem(nn.Module):
  def __init__(self):
    super(Stem , self).__init__()
    self.conv1 = nn.Conv2d(in_channels= 3 , out_channels= 64 ,kernel_size=(7,7) , stride= (2,2) , padding=(3,3))
    self.conv2 = nn.Conv2d(in_channels= 64 , out_channels= 64 ,kernel_size=(1,1) , stride= (1,1), padding=0)
    self.conv3 = nn.Conv2d(in_channels= 64 , out_channels= 192 ,kernel_size=(3,3) , stride= (1,1), padding=(1,1))
    self.maxPool = nn.MaxPool2d(kernel_size=(3,3) , stride=(2,2) , padding=1)

  def forward(self , x):
    out = self.conv1(x)
    out = F.relu(out)

    out = self.maxPool(out)

    out = self.conv2(out)
    out = F.relu(out)

    out = self.conv3(out)
    out = F.relu(out)

    out = self.maxPool(out)

    return out

In [3]:
class InceptionBlock(nn.Module):
  def __init__(self , nbr_channels ,nbr_kernels):
    super(InceptionBlock , self).__init__()
    k_1 , k_2_1 , k_2_2 , k_3_1 , k_3_2 , k_4 = nbr_kernels

    self.branch1 = nn.Sequential(
        nn.Conv2d(in_channels = nbr_channels , out_channels= k_1 , kernel_size=(1,1) , stride=(1,1)),
        nn.ReLU()
    )

    self.branch2 = nn.Sequential(
        nn.Conv2d(in_channels= nbr_channels , out_channels= k_2_1 , kernel_size= (1,1), stride=(1,1)),
        nn.ReLU(),
        nn.Conv2d(in_channels= k_2_1 , out_channels= k_2_2 , kernel_size= (3,3) , stride=(1,1) , padding=(1,1)),
        nn.ReLU()
    )

    self.branch3 = nn.Sequential(
        nn.Conv2d(in_channels= nbr_channels , out_channels= k_3_1 , kernel_size= (1,1) , stride=(1,1)),
        nn.ReLU(),
        nn.Conv2d(in_channels= k_3_1 , out_channels= k_3_2 , kernel_size= (5,5),  stride=(1,1) , padding = (2,2)),
        nn.ReLU()
    )

    self.branch4 = nn.Sequential(
        nn.MaxPool2d(kernel_size=(3,3) , stride=(1,1) , padding=(1,1)),
        nn.Conv2d(in_channels= nbr_channels , out_channels= k_4 , kernel_size= (1,1), stride=(1,1)),
        nn.ReLU()
    )

  def forward(self , x):
    out1 = self.branch1(x)
    out2 = self.branch2(x)
    out3 = self.branch3(x)
    out4 = self.branch4(x)

    return torch.cat([out1 ,out2 , out3 , out4] , 1)


In [4]:
class GoogleNet(nn.Module):
  def __init__(self):
    super(GoogleNet , self).__init__()

    # Output Size : 28*28*192
    self.stem = Stem()

    # Output Size : 28*28*256
    self.Inception1_1 = InceptionBlock(192,[64 , 96 , 128 , 16 , 32 , 32])

    # Output Size : 28*28*480
    self.Inception1_2 = InceptionBlock(256,[128 , 128 , 192 , 32 , 96 , 64])

    # Output Size : 14*14*512
    self.Inception2 = InceptionBlock(480,[192 , 96 , 208 , 16 , 48 , 64])

    # Output Size : 14*14*512
    self.Inception3_1 = InceptionBlock(512,[160 , 112 , 224 , 24 , 64 , 64])
    # Output Size : 14*14*512
    self.Inception3_2 = InceptionBlock(512,[128 , 128 , 256 , 24 , 64 , 64])
    # Output Size : 14*14*528
    self.Inception3_3 = InceptionBlock(512,[112 , 144 , 288 , 32 , 64 , 64])

    # Output Size : 14*14*832
    self.Inception4 = InceptionBlock(528,[256 , 160 , 320 , 32 , 128 , 128])


    # Output Size : 7*7*832
    self.Inception5_1 = InceptionBlock(832,[256 , 160 , 320 , 32 , 128 , 128])

    # Output Size : 7*7*1024
    self.Inception5_2 = InceptionBlock(832,[384 , 192 , 384 , 48 , 128 , 128])

    self.maxPool = nn.MaxPool2d(kernel_size=(3,3) , stride=(2,2) , padding=1)
    self.avgPool = nn.AvgPool2d(kernel_size=(7,7) , stride=(1,1))

    self.fc1 = nn.Linear(in_features=1024 , out_features =1000 )
    self.fc2 = nn.Linear(in_features=1000 , out_features =1000 )

    self.auxiliary_classifier_1 = nn.Sequential(
        nn.AvgPool2d(kernel_size=(5,5) , stride=(3,3)),
        nn.Conv2d(in_channels=512 , out_channels=128 , kernel_size=(1,1) , stride=(1,1)),
        nn.ReLU(),
        nn.Linear(in_features = 4, out_features=1024),
        nn.ReLU(),
        nn.Linear(in_features=1024 , out_features=1000),
        nn.Softmax()
    )

    self.auxiliary_classifier_2 = nn.Sequential(
        nn.AvgPool2d(kernel_size=(5,5) , stride=(3,3) , padding = (1,1)),
        nn.Conv2d(in_channels=528 , out_channels=128 , kernel_size=(1,1) , stride=(1,1)),
        nn.ReLU(),
        nn.Linear(in_features = 4, out_features=1024),
        nn.ReLU(),
        nn.Linear(in_features=1024 , out_features=1000),
        nn.Softmax()
    )

  def forward(self , x):

    out = self.stem(x)

    out = self.Inception1_1(out)
    out = self.Inception1_2(out)

    out = self.maxPool(out)

    out = self.Inception2(out)

    aux1 = self.auxiliary_classifier_1(out)

    out = self.Inception3_1(out)
    out = self.Inception3_2(out)
    out = self.Inception3_3(out)

    aux2 = self.auxiliary_classifier_2(out)

    out = self.Inception4(out)

    out = self.maxPool(out)

    out = self.Inception5_1(out)
    out = self.Inception5_2(out)

    out = self.avgPool(out)

    out = out.reshape(out.shape[0] , -1)

    out = self.fc1(out)
    out = F.relu(out)
    out = nn.Dropout(p=0.4)(out)
    out = self.fc2(out)
    out = F.softmax(out)

    return out


In [5]:
model = GoogleNet()
summary(model , (3 , 224 , 224))

  return self._call_impl(*args, **kwargs)
  out = F.softmax(out)


----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 112, 112]           9,472
         MaxPool2d-2           [-1, 64, 56, 56]               0
            Conv2d-3           [-1, 64, 56, 56]           4,160
            Conv2d-4          [-1, 192, 56, 56]         110,784
         MaxPool2d-5          [-1, 192, 28, 28]               0
              Stem-6          [-1, 192, 28, 28]               0
            Conv2d-7           [-1, 64, 28, 28]          12,352
              ReLU-8           [-1, 64, 28, 28]               0
            Conv2d-9           [-1, 96, 28, 28]          18,528
             ReLU-10           [-1, 96, 28, 28]               0
           Conv2d-11          [-1, 128, 28, 28]         110,720
             ReLU-12          [-1, 128, 28, 28]               0
           Conv2d-13           [-1, 16, 28, 28]           3,088
             ReLU-14           [-1, 16,