In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchsummary import summary
from torch import optim
import numpy as np
import time
import copy

In [2]:
class BasicConv2d(nn.Module):
    """Convolution followed by batch normalization and ReLU.

    Args:
        in_channels: channel count of the incoming tensor.
        out_channels: number of filters produced by the convolution; also
            the feature count handed to ``nn.BatchNorm2d``.
        kernel_size: forwarded to ``nn.Conv2d`` unchanged (int or tuple).
        **kwargs: any further ``nn.Conv2d`` keyword arguments, such as
            ``stride`` or ``padding``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, **kwargs):
        super().__init__()

        # bias=False: the convolution has no bias term of its own; a
        # BatchNorm2d layer sits directly behind it.
        layers = [
            nn.Conv2d(in_channels, out_channels, kernel_size, bias=False, **kwargs),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        ]
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through conv -> batch norm -> ReLU and return the result."""
        return self.conv(x)

In [3]:
class Stem(nn.Module):
    """Input stem: three plain convolutions followed by three two-branch stages.

    Takes a 3-channel image batch and returns a 384-channel feature map.
    Each stage concatenates its two branches along the channel axis:

    * stage 1: stride-2 conv (96 ch) + max pool (64 ch)  -> 160 channels
    * stage 2: two conv stacks of 96 channels each       -> 192 channels
    * stage 3: stride-2 conv (192 ch) + max pool (192 ch) -> 384 channels

    For a 3 x 299 x 299 input the output is 384 x 35 x 35.

    The pooling branches use a 3x3 window with stride 2 and no padding,
    i.e. the same window geometry as the convolution they are concatenated
    with, so both branches always produce the same spatial size. The
    earlier 4x4 / padding-1 pooling only lined up when the incoming feature
    map had an odd height/width and made ``torch.cat`` fail otherwise
    (for example on 224 x 224 images).
    """

    def __init__(self):
        super().__init__()

        self.conv = nn.Sequential(
            BasicConv2d(3, 32, 3, stride=2, padding=0),
            BasicConv2d(32, 32, 3, stride=1, padding=0),
            BasicConv2d(32, 64, 3, stride=1, padding=1)
        )

        # Stage 1: both branches halve the resolution with a 3x3/stride-2 window.
        self.branch_conv = BasicConv2d(64, 96, 3, stride=2, padding=0)

        self.branch_pool = nn.MaxPool2d(3, stride=2, padding=0)

        # Stage 2: both stacks end in an unpadded 3x3 conv, so each trims
        # two pixels per spatial dimension.
        self.branch2_conv = nn.Sequential(
            BasicConv2d(160, 64, 1, stride=1, padding=0),
            BasicConv2d(64, 96, 3, stride=1, padding=0)
        )

        self.branch2_conv2 = nn.Sequential(
            BasicConv2d(160, 64, 1, stride=1, padding=0),
            BasicConv2d(64, 64, (7, 1), stride=1, padding=(3, 0)),
            BasicConv2d(64, 64, (1, 7), stride=1, padding=(0, 3)),
            BasicConv2d(64, 96, 3, stride=1, padding=0)
        )

        # Stage 3: same pairing as stage 1, on 192 channels.
        self.branch3_conv = BasicConv2d(192, 192, 3, stride=2, padding=0)

        self.branch3_pool = nn.MaxPool2d(3, stride=2, padding=0)

    def forward(self, x):
        """Return the 384-channel stem features for the image batch ``x``."""
        x = self.conv(x)
        x = torch.cat((self.branch_conv(x), self.branch_pool(x)), dim=1)
        x = torch.cat((self.branch2_conv(x), self.branch2_conv2(x)), dim=1)
        x = torch.cat((self.branch3_conv(x), self.branch3_pool(x)), dim=1)

        return x

In [4]:
class Inception_Resnet_A(nn.Module):
    """Residual block with three parallel conv branches; emits 384 channels.

    The branches (32, 32 and 64 channels) are concatenated to 128 channels,
    projected to 384 by a 1x1 convolution and added to a 1x1-projected copy
    of the input. Batch norm and ReLU are applied to the sum. Every
    convolution keeps the spatial size unchanged.

    Args:
        in_channels: channel count of the incoming tensor.
    """

    def __init__(self, in_channels):
        super().__init__()

        # Branch 1: a single 1x1 convolution.
        self.branch_conv = BasicConv2d(in_channels, 32, kernel_size=1, stride=1, padding=0)

        # Branch 2: 1x1 followed by one padded 3x3.
        branch2_layers = [
            BasicConv2d(in_channels, 32, kernel_size=1, stride=1, padding=0),
            BasicConv2d(32, 32, kernel_size=3, stride=1, padding=1),
        ]
        self.branch_conv2 = nn.Sequential(*branch2_layers)

        # Branch 3: 1x1 followed by two padded 3x3 convolutions.
        branch3_layers = [
            BasicConv2d(in_channels, 32, kernel_size=1, stride=1, padding=0),
            BasicConv2d(32, 48, kernel_size=3, stride=1, padding=1),
            BasicConv2d(48, 64, kernel_size=3, stride=1, padding=1),
        ]
        self.branch_conv3 = nn.Sequential(*branch3_layers)

        # 32 + 32 + 64 = 128 concatenated channels are projected to 384.
        self.reduction_conv = nn.Conv2d(128, 384, kernel_size=1, stride=1, padding=0)

        # 1x1 projection of the block input used as the residual path.
        self.shortcut = nn.Conv2d(in_channels, 384, kernel_size=1, stride=1, padding=0)

        self.bn = nn.BatchNorm2d(384)

        self.relu = nn.ReLU()

    def forward(self, x):
        """Return ReLU(BN(shortcut(x) + projected branch output))."""
        residual = self.shortcut(x)

        branch_outputs = [
            self.branch_conv(x),
            self.branch_conv2(x),
            self.branch_conv3(x),
        ]
        projected = self.reduction_conv(torch.cat(branch_outputs, dim=1))

        return self.relu(self.bn(residual + projected))

In [5]:
class ReductionA(nn.Module):
    """Downsampling block made of a max pool and two stride-2 conv branches.

    The three branch outputs are concatenated along the channel axis, so the
    block emits ``in_channels + n + m`` channels; that number is stored in
    ``self.output_channels``.

    Args:
        in_channels: channel count of the incoming tensor.
        k: width of the 1x1 convolution opening the three-layer branch.
        l: width of the padded 3x3 convolution in the middle of that branch.
        m: width of the stride-2 3x3 convolution closing that branch.
        n: width of the single stride-2 3x3 convolution branch.
    """

    def __init__(self, in_channels, k, l, m, n):
        super().__init__()

        # All three branches finish with a 3x3 window, stride 2, no padding.
        self.branch_pool = nn.MaxPool2d(kernel_size=3, stride=2)

        self.branch_conv = BasicConv2d(in_channels, n, kernel_size=3, stride=2, padding=0)

        deep_branch = [
            BasicConv2d(in_channels, k, kernel_size=1, stride=1, padding=0),
            BasicConv2d(k, l, kernel_size=3, stride=1, padding=1),
            BasicConv2d(l, m, kernel_size=3, stride=2, padding=0),
        ]
        self.branch_conv2 = nn.Sequential(*deep_branch)

        # The pooled branch keeps in_channels; the conv branches add n and m.
        self.output_channels = in_channels + n + m

    def forward(self, x):
        """Concatenate pool, single-conv and three-conv branches on dim 1."""
        pooled = self.branch_pool(x)
        single = self.branch_conv(x)
        stacked = self.branch_conv2(x)

        return torch.cat((pooled, single, stacked), dim=1)

In [6]:
class Inception_Resnet_B(nn.Module):
    """Residual block with a 1x1 branch and a 1x7 / 7x1 branch; emits 1152 channels.

    The two 192-channel branch outputs are concatenated to 384 channels,
    projected to 1152 by a 1x1 convolution, multiplied by 0.1 and added to a
    1x1-projected copy of the input. Batch norm and ReLU follow the sum.
    Every convolution keeps the spatial size unchanged.

    Args:
        in_channels: channel count of the incoming tensor.
    """

    def __init__(self, in_channels):
        super().__init__()

        # Branch 1: a single 1x1 convolution.
        self.branch_conv = BasicConv2d(in_channels, 192, kernel_size=1, stride=1, padding=0)

        # Branch 2: 1x1, then a 1x7 and a 7x1 convolution, each padded on
        # its long axis so the spatial size is preserved.
        factorized = [
            BasicConv2d(in_channels, 128, kernel_size=1, stride=1, padding=0),
            BasicConv2d(128, 160, kernel_size=(1, 7), stride=1, padding=(0, 3)),
            BasicConv2d(160, 192, kernel_size=(7, 1), stride=1, padding=(3, 0)),
        ]
        self.branch_conv2 = nn.Sequential(*factorized)

        # 192 + 192 = 384 concatenated channels are projected to 1152.
        self.reduction_conv = nn.Conv2d(384, 1152, kernel_size=1, stride=1, padding=0)

        # 1x1 projection of the block input used as the residual path.
        self.shortcut = nn.Conv2d(in_channels, 1152, kernel_size=1, stride=1, padding=0)

        self.bn = nn.BatchNorm2d(1152)

        self.relu = nn.ReLU()

    def forward(self, x):
        """Return ReLU(BN(0.1 * projected branch output + shortcut(x)))."""
        residual = self.shortcut(x)

        merged = torch.cat((self.branch_conv(x), self.branch_conv2(x)), dim=1)
        # The branch output is scaled by 0.1 before it joins the shortcut.
        scaled = self.reduction_conv(merged) * 0.1

        return self.relu(self.bn(scaled + residual))

In [7]:
class ReductionB(nn.Module):
    """Downsampling block: a max pool plus three stride-2 conv branches.

    The branch outputs are concatenated along the channel axis, giving
    ``in_channels + 384 + 288 + 320`` output channels. Every branch ends in
    a 3x3 window with stride 2 and no padding.

    Args:
        in_channels: channel count of the incoming tensor.
    """

    def __init__(self, in_channels):
        super().__init__()

        self.branch_pool = nn.MaxPool2d(kernel_size=3, stride=2)

        # 1x1 -> stride-2 3x3, ending at 384 channels.
        first_branch = [
            BasicConv2d(in_channels, 256, kernel_size=1, stride=1, padding=0),
            BasicConv2d(256, 384, kernel_size=3, stride=2, padding=0),
        ]
        self.branch_conv = nn.Sequential(*first_branch)

        # 1x1 -> stride-2 3x3, ending at 288 channels.
        second_branch = [
            BasicConv2d(in_channels, 256, kernel_size=1, stride=1, padding=0),
            BasicConv2d(256, 288, kernel_size=3, stride=2, padding=0),
        ]
        self.branch_conv2 = nn.Sequential(*second_branch)

        # 1x1 -> padded 3x3 -> stride-2 3x3, ending at 320 channels.
        third_branch = [
            BasicConv2d(in_channels, 256, kernel_size=1, stride=1, padding=0),
            BasicConv2d(256, 288, kernel_size=3, stride=1, padding=1),
            BasicConv2d(288, 320, kernel_size=3, stride=2, padding=0),
        ]
        self.branch_conv3 = nn.Sequential(*third_branch)

    def forward(self, x):
        """Concatenate the pool branch and the three conv branches on dim 1."""
        pieces = (
            self.branch_pool(x),
            self.branch_conv(x),
            self.branch_conv2(x),
            self.branch_conv3(x),
        )

        return torch.cat(pieces, dim=1)

In [8]:
class Inception_Resnet_C(nn.Module):
    """Residual block with a 1x1 branch and a 1x3 / 3x1 branch; emits 2144 channels.

    The branch outputs (192 and 256 channels) are concatenated to 448
    channels, projected to 2144 by a 1x1 convolution, multiplied by 0.1 and
    added to a 1x1-projected copy of the input. Batch norm and ReLU follow
    the sum. Every convolution keeps the spatial size unchanged.

    Args:
        in_channels: channel count of the incoming tensor.
    """

    def __init__(self, in_channels):
        super().__init__()

        # Branch 1: a single 1x1 convolution.
        self.branch_conv = BasicConv2d(in_channels, 192, kernel_size=1, stride=1, padding=0)

        # Branch 2: 1x1, then a 1x3 and a 3x1 convolution, each padded on
        # its long axis so the spatial size is preserved.
        factorized = [
            BasicConv2d(in_channels, 192, kernel_size=1, stride=1, padding=0),
            BasicConv2d(192, 224, kernel_size=(1, 3), stride=1, padding=(0, 1)),
            BasicConv2d(224, 256, kernel_size=(3, 1), stride=1, padding=(1, 0)),
        ]
        self.branch_conv2 = nn.Sequential(*factorized)

        # 192 + 256 = 448 concatenated channels are projected to 2144.
        self.reduction_conv = nn.Conv2d(448, 2144, kernel_size=1, stride=1, padding=0)

        # 1x1 projection of the block input used as the residual path.
        self.shortcut = nn.Conv2d(in_channels, 2144, kernel_size=1, stride=1, padding=0)

        self.bn = nn.BatchNorm2d(2144)

        self.relu = nn.ReLU()

    def forward(self, x):
        """Return ReLU(BN(shortcut(x) + 0.1 * projected branch output))."""
        residual = self.shortcut(x)

        merged = torch.cat((self.branch_conv(x), self.branch_conv2(x)), dim=1)
        # The branch output is scaled by 0.1 before it joins the shortcut.
        scaled = self.reduction_conv(merged) * 0.1

        return self.relu(self.bn(residual + scaled))

In [9]:
class InceptionResNetV2(nn.Module):
    """Classifier assembled from the stem, residual and reduction blocks above.

    Layout: Stem -> A x Inception_Resnet_A -> ReductionA ->
    B x Inception_Resnet_B -> ReductionB -> C x Inception_Resnet_C ->
    global average pool -> dropout (p=0.2) -> linear classifier.

    The channel count is tracked while the stack is built, so each block
    receives the width actually produced by its predecessor. This keeps the
    model valid when ``m``/``n`` differ from their defaults or when ``B`` or
    ``C`` is 0; with the default arguments the resulting network is the
    same as with the fixed widths 384 / 1152 / 2144.

    Args:
        A: number of ``Inception_Resnet_A`` blocks.
        B: number of ``Inception_Resnet_B`` blocks.
        C: number of ``Inception_Resnet_C`` blocks.
        k, l, m, n: branch widths forwarded to ``ReductionA``.
        num_classes: size of the output logit vector.
    """

    def __init__(self, A, B, C, k=256, l=256, m=384, n=384, num_classes=10):
        super().__init__()

        # Output widths that the block classes in this notebook hard-code.
        stem_width = 384
        block_a_width = 384
        block_b_width = 1152
        block_c_width = 2144
        # ReductionB concatenates its input with conv branches of
        # 384, 288 and 320 channels.
        reduction_b_extra = 384 + 288 + 320

        blocks = [Stem()]
        channels = stem_width

        for _ in range(A):
            blocks.append(Inception_Resnet_A(channels))
            channels = block_a_width
        reduction_a = ReductionA(channels, k, l, m, n)
        blocks.append(reduction_a)
        # ReductionA reports its own output width (in_channels + n + m).
        channels = reduction_a.output_channels

        for _ in range(B):
            blocks.append(Inception_Resnet_B(channels))
            channels = block_b_width
        blocks.append(ReductionB(channels))
        channels = channels + reduction_b_extra

        for _ in range(C):
            blocks.append(Inception_Resnet_C(channels))
            channels = block_c_width

        self.features = nn.Sequential(*blocks)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

        # Dropout is applied to the flattened (batch, channels) tensor.
        # nn.Dropout2d is meant for 3-D/4-D inputs and emits a deprecation
        # warning for 2-D input; plain Dropout is the supported equivalent.
        self.dropout = nn.Dropout(0.2)

        self.linear = nn.Linear(channels, num_classes)

    def forward(self, x):
        """Return class logits of shape (batch, num_classes) for images ``x``."""
        x = self.features(x)

        x = self.avgpool(x)

        # (batch, channels, 1, 1) -> (batch, channels)
        x = torch.flatten(x, 1)

        x = self.dropout(x)

        x = self.linear(x)

        return x

In [10]:
# Use the GPU when PyTorch can see one, otherwise stay on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# 10 A-blocks, 20 B-blocks and 10 C-blocks; the remaining constructor
# arguments keep their defaults.
model = InceptionResNetV2(A=10, B=20, C=10)
model = model.to(device)

# Print a per-layer table of output shapes and parameter counts for a
# single 3 x 299 x 299 input.
summary(model, input_size=(3, 299, 299), device=device.type)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 32, 149, 149]             864
       BatchNorm2d-2         [-1, 32, 149, 149]              64
              ReLU-3         [-1, 32, 149, 149]               0
       BasicConv2d-4         [-1, 32, 149, 149]               0
            Conv2d-5         [-1, 32, 147, 147]           9,216
       BatchNorm2d-6         [-1, 32, 147, 147]              64
              ReLU-7         [-1, 32, 147, 147]               0
       BasicConv2d-8         [-1, 32, 147, 147]               0
            Conv2d-9         [-1, 64, 147, 147]          18,432
      BatchNorm2d-10         [-1, 64, 147, 147]             128
             ReLU-11         [-1, 64, 147, 147]               0
      BasicConv2d-12         [-1, 64, 147, 147]               0
           Conv2d-13           [-1, 96, 73, 73]          55,296
      BatchNorm2d-14           [-1, 96,

