In [13]:
import torch
import torch.nn as nn
from torchsummary import summary

In [14]:
class Conv_Bn_Relu(nn.Module):
  """2-D convolution followed by batch normalisation and a ReLU.

  Args:
    in_channels: channel count of the input feature map.
    out_channels: channel count produced by the convolution (and normalised
      by the batch-norm layer).
    kernel_size, stride, padding: forwarded unchanged to ``nn.Conv2d``.
  """
  def __init__(self, in_channels, out_channels, kernel_size, stride, padding):
    super().__init__()

    self.conv = nn.Conv2d(
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding)
    self.bn = nn.BatchNorm2d(out_channels)
    self.act = nn.ReLU()

  def forward(self, x):
    """Apply conv -> batch-norm -> ReLU to ``x`` and return the result."""
    return self.act(self.bn(self.conv(x)))


In [15]:
class StemBlock(nn.Module):
  """Network stem: five conv-bn-relu layers interleaved with two max-pools.

  Maps a 3-channel image to a 192-channel feature map; for a 299x299 input
  the spatial size shrinks to 35x35.
  """
  def __init__(self):
    super().__init__()
    # Positional arguments: in_channels, out_channels, kernel_size, stride, padding.
    self.conv1 = Conv_Bn_Relu(3, 32, 3, 2, 0)
    self.conv2 = Conv_Bn_Relu(32, 32, 3, 1, 0)
    self.conv3 = Conv_Bn_Relu(32, 64, 3, 1, 1)
    self.conv4 = Conv_Bn_Relu(64, 80, 1, 1, 0)
    self.conv5 = Conv_Bn_Relu(80, 192, 3, 1, 0)
    # A single pooling module is applied at two points in forward().
    self.maxpool = nn.MaxPool2d(kernel_size=(3, 3), stride=2)

  def forward(self, x):
    """Run the stem on ``x`` and return the down-sampled feature map."""
    x = self.maxpool(self.conv3(self.conv2(self.conv1(x))))
    x = self.maxpool(self.conv5(self.conv4(x)))
    return x

In [11]:
# Layer-by-layer summary of the stem alone for a 3x299x299 input.
summary(
    model=StemBlock(),
    input_size=(3, 299, 299),
    batch_size=0,
)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1          [0, 32, 149, 149]             896
       BatchNorm2d-2          [0, 32, 149, 149]              64
              ReLU-3          [0, 32, 149, 149]               0
      Conv_Bn_Relu-4          [0, 32, 149, 149]               0
            Conv2d-5          [0, 32, 147, 147]           9,248
       BatchNorm2d-6          [0, 32, 147, 147]              64
              ReLU-7          [0, 32, 147, 147]               0
      Conv_Bn_Relu-8          [0, 32, 147, 147]               0
            Conv2d-9          [0, 64, 147, 147]          18,496
      BatchNorm2d-10          [0, 64, 147, 147]             128
             ReLU-11          [0, 64, 147, 147]               0
     Conv_Bn_Relu-12          [0, 64, 147, 147]               0
        MaxPool2d-13            [0, 64, 73, 73]               0
           Conv2d-14            [0, 80,

In [19]:
class Inception_block_A(nn.Module):
  """
      Inception module from figure 5 of the paper.

      Four parallel branches see the same input and are concatenated along
      dim 1. Their widths are fixed at 96, 64, 64 and 64, so the output
      always carries 288 channels.

      Args:
        in_channels: channel count of the input feature map.
        filters: accepted for signature compatibility; this block does not
          read it.
  """
  def __init__(self, in_channels, filters):
    super().__init__()

    # Branch 1: 1x1 reduction followed by two padded 3x3 convolutions.
    self.block1 = nn.Sequential(
        Conv_Bn_Relu(in_channels, 64, 1, 1, 0),
        Conv_Bn_Relu(64, 96, 3, 1, 1),
        Conv_Bn_Relu(96, 96, 3, 1, 1))

    # Branch 2: 1x1 reduction followed by one padded 3x3 convolution.
    self.block2 = nn.Sequential(
        Conv_Bn_Relu(in_channels, 48, 1, 1, 0),
        Conv_Bn_Relu(48, 64, 3, 1, 1))

    # Branch 3: size-preserving average pool, then a 1x1 projection.
    self.block3 = nn.Sequential(
        nn.AvgPool2d(kernel_size=3, stride=1, padding=1),
        Conv_Bn_Relu(in_channels, 64, 1, 1, 0))

    # Branch 4: plain 1x1 projection.
    self.block4 = Conv_Bn_Relu(in_channels, 64, 1, 1, 0)


  def forward(self, x):
    """Run every branch on ``x`` and concatenate the results along dim 1."""
    branches = (
        self.block1(x),
        self.block2(x),
        self.block3(x),
        self.block4(x),
    )
    return torch.cat(branches, dim=1)

In [47]:
class Inception_block_B(nn.Module):
  """
      Inception module from figure 6 of the paper.

      Uses factorised 1x7 / 7x1 convolutions. Each of the four branches ends
      with 192 channels, so the concatenated output has 768 channels.

      Args:
        in_channels: channel count of the input feature map.
        filters: width of the intermediate layers in the two 7-tap branches.
  """
  def __init__(self, in_channels, filters):
    super().__init__()

    def pointwise(c_out):
      # 1x1 conv-bn-relu reading the block input.
      return Conv_Bn_Relu(in_channels, c_out, 1, 1, 0)

    def conv_1x7(c_in, c_out):
      # Padding (0, 3) keeps the spatial size unchanged.
      return Conv_Bn_Relu(c_in, c_out, (1, 7), 1, (0, 3))

    def conv_7x1(c_in, c_out):
      # Padding (3, 0) keeps the spatial size unchanged.
      return Conv_Bn_Relu(c_in, c_out, (7, 1), 1, (3, 0))

    self.block1 = nn.Sequential(
        pointwise(filters),
        conv_1x7(filters, filters),
        conv_7x1(filters, filters),
        conv_1x7(filters, filters),
        conv_7x1(filters, 192))

    self.block2 = nn.Sequential(
        pointwise(filters),
        conv_1x7(filters, filters),
        conv_7x1(filters, 192))

    self.block3 = nn.Sequential(
        nn.AvgPool2d(kernel_size=3, stride=1, padding=1),
        pointwise(192))

    self.block4 = pointwise(192)

  def forward(self, x):
    """Run every branch on ``x`` and concatenate the results along dim 1."""
    branches = (
        self.block1(x),
        self.block2(x),
        self.block3(x),
        self.block4(x),
    )
    return torch.cat(branches, dim=1)


In [41]:
class Inception_block_C(nn.Module):
  """
    Inception module from figure 7 of the paper.

    Branches 3 and 4 each fan out into a 1x3 and a 3x1 convolution whose
    outputs are concatenated. Branch widths are 320, 192, 768 and 768, so the
    block emits 2048 channels.

    Args:
      in_channels: channel count of the input feature map.
  """
  def __init__(self, in_channels):
    super().__init__()

    # Branch 1: plain 1x1 projection.
    self.block1 = Conv_Bn_Relu(in_channels=in_channels, out_channels=320, kernel_size=1, stride=1, padding=0)

    # Branch 2: size-preserving average pool, then a 1x1 projection.
    self.block2 = nn.Sequential(
        nn.AvgPool2d(kernel_size=(3, 3), stride=1, padding=1),
        Conv_Bn_Relu(in_channels=in_channels, out_channels=192, kernel_size=1, stride=1, padding=0),
    )

    # Branch 3: 1x1 trunk feeding parallel 1x3 and 3x1 convolutions.
    self.block3 = Conv_Bn_Relu(in_channels=in_channels, out_channels=384, kernel_size=1, stride=1, padding=0)
    self.block_3a = Conv_Bn_Relu(in_channels=384, out_channels=384, kernel_size=(1, 3), stride=1, padding=(0, 1))
    self.block_3b = Conv_Bn_Relu(in_channels=384, out_channels=384, kernel_size=(3, 1), stride=1, padding=(1, 0))

    # Branch 4: 1x1 then 3x3 trunk feeding parallel 1x3 and 3x1 convolutions.
    self.block4 = nn.Sequential(
        Conv_Bn_Relu(in_channels=in_channels, out_channels=448, kernel_size=1, stride=1, padding=0),
        Conv_Bn_Relu(in_channels=448, out_channels=384, kernel_size=3, stride=1, padding=1))
    self.block_4a = Conv_Bn_Relu(in_channels=384, out_channels=384, kernel_size=(1, 3), stride=1, padding=(0, 1))
    self.block_4b = Conv_Bn_Relu(in_channels=384, out_channels=384, kernel_size=(3, 1), stride=1, padding=(1, 0))

  def forward(self, x):
    """Run every branch on ``x`` and concatenate the results along dim 1."""
    branch1 = self.block1(x)
    branch2 = self.block2(x)

    trunk3 = self.block3(x)
    branch3 = torch.cat([self.block_3a(trunk3), self.block_3b(trunk3)], dim=1)

    trunk4 = self.block4(x)
    branch4 = torch.cat([self.block_4a(trunk4), self.block_4b(trunk4)], dim=1)

    return torch.cat([branch1, branch2, branch3, branch4], dim=1)



In [42]:
class Reduction_block_A(nn.Module):
  """Grid-size reduction with three stride-2 branches.

  The two convolutional branches emit 96 and 384 channels and the max-pool
  branch passes the input channels through, so the output carries
  ``96 + 384 + in_channels`` channels.

  Args:
    in_channels: channel count of the input feature map.
  """
  def __init__(self, in_channels):
    super().__init__()

    # Branch 1: 1x1 -> padded 3x3 -> unpadded stride-2 3x3.
    self.block1 = nn.Sequential(
        Conv_Bn_Relu(in_channels=in_channels, out_channels=64, kernel_size=1, stride=1, padding=0),
        Conv_Bn_Relu(in_channels=64, out_channels=96, kernel_size=3, stride=1, padding=1),
        Conv_Bn_Relu(in_channels=96, out_channels=96, kernel_size=3, stride=2, padding=0),
    )

    # Branch 2: a single unpadded stride-2 3x3 convolution.
    self.block2 = Conv_Bn_Relu(in_channels=in_channels, out_channels=384, kernel_size=3, stride=2, padding=0)

    # Branch 3: stride-2 max-pool over the raw input.
    self.block3 = nn.MaxPool2d(kernel_size=(3, 3), stride=2, padding=0)

  def forward(self, x):
    """Run every branch on ``x`` and concatenate the results along dim 1."""
    branches = (self.block1(x), self.block2(x), self.block3(x))
    return torch.cat(branches, dim=1)


In [43]:
class Reduction_block_B(nn.Module):
  """Grid-size reduction with three stride-2 branches.

  The two convolutional branches emit 192 and 320 channels and the max-pool
  branch passes the input channels through, so the output carries
  ``192 + 320 + in_channels`` channels.

  Args:
    in_channels: channel count of the input feature map.
  """
  def __init__(self, in_channels):
    super().__init__()

    # Branch 1: 1x1 -> 1x7 -> 7x1 -> unpadded stride-2 3x3.
    self.block1 = nn.Sequential(
        Conv_Bn_Relu(in_channels=in_channels, out_channels=192, kernel_size=1, stride=1, padding=0),
        Conv_Bn_Relu(in_channels=192, out_channels=192, kernel_size=(1, 7), stride=1, padding=(0, 3)),
        Conv_Bn_Relu(in_channels=192, out_channels=192, kernel_size=(7, 1), stride=1, padding=(3, 0)),
        Conv_Bn_Relu(in_channels=192, out_channels=192, kernel_size=3, stride=2, padding=0),
    )

    # Branch 2: 1x1 -> unpadded stride-2 3x3.
    self.block2 = nn.Sequential(
        Conv_Bn_Relu(in_channels=in_channels, out_channels=192, kernel_size=1, stride=1, padding=0),
        Conv_Bn_Relu(in_channels=192, out_channels=320, kernel_size=3, stride=2, padding=0),
    )

    # Branch 3: stride-2 max-pool over the raw input.
    self.block3 = nn.MaxPool2d(kernel_size=(3, 3), stride=2)

  def forward(self, x):
    """Run every branch on ``x`` and concatenate the results along dim 1."""
    branches = (self.block1(x), self.block2(x), self.block3(x))
    return torch.cat(branches, dim=1)

In [44]:
class Aux_Block(nn.Module):
  """Auxiliary classifier head.

  Pipeline: 5x5 stride-3 average pool -> 1x1 conv-bn-relu (128 channels) ->
  unpadded 5x5 conv-bn-relu (768 channels) -> flatten -> Linear(768, 1024) ->
  ReLU -> Linear(1024, 1000) -> softmax over the class dimension.

  ``fc1`` takes exactly 768 features, so the 5x5 convolution has to leave a
  1x1 map; that is the case for a 17x17 input, which pools down to 5x5.

  Args:
    in_channels: channel count of the input feature map.
  """

  def __init__(self , in_channels):
    super(Aux_Block , self).__init__()

    self.avgPool = nn.AvgPool2d(kernel_size=(5,5) , stride=3 , padding=0)
    self.conv1 = Conv_Bn_Relu(in_channels , 128 , 1 , 1 , 0)
    self.conv2 = Conv_Bn_Relu(128 , 768 , 5 , 1 , 0)
    self.fc1 = nn.Linear(in_features= 768 , out_features= 1024)
    self.fc2 = nn.Linear(in_features= 1024 , out_features= 1000)

  def forward(self , x):
    """Return per-class probabilities of shape (batch, 1000) for ``x``."""
    out = self.avgPool(x)
    out = self.conv1(out)
    out = self.conv2(out)
    out = torch.flatten(out , 1)

    # Functional activations: no throw-away nn.Module is built on every call.
    out = torch.relu(self.fc1(out))
    out = self.fc2(out)

    # `out` is (batch, classes) after the flatten, so normalise over dim 1.
    # Naming the dim avoids PyTorch's deprecated implicit-dimension choice.
    out = torch.softmax(out , dim=1)

    return out

In [50]:
class InceptionV3(nn.Module):
  """Inception-v3 style classifier assembled from the blocks defined in this file.

  Layout: stem -> 3x Inception A -> Reduction A -> 4x Inception B ->
  Reduction B -> 2x Inception C -> global average pool -> Linear(2048, 2048)
  -> ReLU -> Linear(2048, 1000) -> softmax over the class dimension.

  Channel widths between stages: 192 -> 288 -> 768 -> 1280 -> 2048.

  The auxiliary head ``self.aux`` is constructed (so its parameters are part
  of the model) but its call in ``forward`` is commented out.
  """
  def __init__(self):
    super(InceptionV3 , self).__init__()
    self.stem = StemBlock()

    self.inceptionA_1 = Inception_block_A(192 , 32)
    self.inceptionA_2 = Inception_block_A(288 , 64)
    self.inceptionA_3 = Inception_block_A(288 , 64)

    self.reductionA = Reduction_block_A(288)

    self.inceptionB_1 = Inception_block_B(768 , 128)
    self.inceptionB_2 = Inception_block_B(768 , 160)
    self.inceptionB_3 = Inception_block_B(768 , 160)
    self.inceptionB_4 = Inception_block_B(768 , 192)

    self.aux = Aux_Block(768)

    self.reductionB = Reduction_block_B(768)

    self.inceptionC_1 = Inception_block_C(1280)
    self.inceptionC_2 = Inception_block_C(2048)

    self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
    self.fc1 = nn.Linear(in_features=2048 ,out_features= 2048)
    self.fc2 = nn.Linear(in_features=2048 , out_features= 1000)

  def forward(self , x):
    """Return per-class probabilities of shape (batch, 1000) for image batch ``x``."""

    out = self.stem(x)

    out = self.inceptionA_1(out)
    out = self.inceptionA_2(out)
    out = self.inceptionA_3(out)

    out = self.reductionA(out)

    out = self.inceptionB_1(out)
    out = self.inceptionB_2(out)
    out = self.inceptionB_3(out)
    out = self.inceptionB_4(out)

    # aux = self.aux(out)

    out = self.reductionB(out)

    out = self.inceptionC_1(out)
    out = self.inceptionC_2(out)

    out = self.avgpool(out)
    # Collapse the pooled (batch, 2048, 1, 1) map to (batch, 2048).
    out = out.reshape(out.shape[0] , -1)

    # Functional activations: no throw-away nn.Module is built on every call.
    out = torch.relu(self.fc1(out))

    out = self.fc2(out)
    # `out` is (batch, classes), so normalise over dim 1. Naming the dim
    # avoids PyTorch's deprecated implicit-dimension choice for softmax.
    out = torch.softmax(out , dim=1)

    return out

In [52]:
# summary(model=InceptionV3(), input_size=(3,299,299), batch_size=0)

Collecting torchinfo
  Downloading torchinfo-1.8.0-py3-none-any.whl (23 kB)
Installing collected packages: torchinfo
Successfully installed torchinfo-1.8.0


In [56]:
# !pip install torchinfo
from torchinfo import summary
# input_size is per-sample here; batch_dim=0 has torchinfo add the batch axis.
summary(
    InceptionV3(),
    input_size=(3, 299, 299),
    batch_dim=0,
)

  return self._call_impl(*args, **kwargs)


Layer (type:depth-idx)                   Output Shape              Param #
InceptionV3                              [1, 1000]                 4,371,048
├─StemBlock: 1-1                         [1, 192, 35, 35]          --
│    └─Conv_Bn_Relu: 2-1                 [1, 32, 149, 149]         --
│    │    └─Conv2d: 3-1                  [1, 32, 149, 149]         896
│    │    └─BatchNorm2d: 3-2             [1, 32, 149, 149]         64
│    │    └─ReLU: 3-3                    [1, 32, 149, 149]         --
│    └─Conv_Bn_Relu: 2-2                 [1, 32, 147, 147]         --
│    │    └─Conv2d: 3-4                  [1, 32, 147, 147]         9,248
│    │    └─BatchNorm2d: 3-5             [1, 32, 147, 147]         64
│    │    └─ReLU: 3-6                    [1, 32, 147, 147]         --
│    └─Conv_Bn_Relu: 2-3                 [1, 64, 147, 147]         --
│    │    └─Conv2d: 3-7                  [1, 64, 147, 147]         18,496
│    │    └─BatchNorm2d: 3-8             [1, 64, 147, 147]         128