In [2]:
import torch 
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

In [None]:
'''
    ResNet(2016)
'''

class ResBlock(nn.Module):
    """Residual block from ResNet (He et al., 2016).

    Two 3x3 conv -> BatchNorm stages on the main path, with a shortcut
    connection added back in before the final ReLU::

        out = relu(bn2(conv2(relu(bn1(conv1(x))))) + shortcut(x))

    Args:
        input_channels: Number of channels of the input tensor.
        out_channels: Number of channels produced by the block.
        use_1x1conv: If True, the shortcut is a 1x1 convolution projecting
            ``input_channels`` -> ``out_channels`` with stride ``strides``.
            If False, the shortcut is the identity.
        strides: Stride of the first 3x3 convolution (and of the 1x1
            projection), used to downsample the spatial dimensions.

    Raises:
        ValueError: If ``use_1x1conv`` is False but the block changes the
            channel count or uses ``strides != 1``. The identity shortcut
            could then not be added to the main path.
    """

    def __init__(self, input_channels, out_channels, use_1x1conv=False, strides=1):
        super().__init__()
        if not use_1x1conv and (input_channels != out_channels or strides != 1):
            raise ValueError(
                "identity shortcut requires input_channels == out_channels and "
                "strides == 1; pass use_1x1conv=True to project the shortcut"
            )
        # nn.Conv2d's keyword is `stride` (not `strides`). Only the first conv
        # downsamples; the second keeps the spatial size so the main path and
        # the shortcut end up with matching shapes.
        self.conv1 = nn.Conv2d(input_channels, out_channels,
                               kernel_size=3, padding=1, stride=strides)
        self.conv2 = nn.Conv2d(out_channels, out_channels,
                               kernel_size=3, padding=1)
        self.conv3 = None
        if use_1x1conv:
            # The projection must downsample by the same stride as conv1.
            self.conv3 = nn.Conv2d(input_channels, out_channels,
                                   kernel_size=1, stride=strides)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        """Apply the residual block.

        Args:
            x: Input of shape (N, input_channels, H, W). BatchNorm2d requires
                a 4-D input.

        Returns:
            Tensor of shape (N, out_channels, H', W'), where H' and W' equal
            H and W divided by ``strides`` and rounded up.
        """
        Y = F.relu(self.bn1(self.conv1(x)))
        Y = self.bn2(self.conv2(Y))
        # Identity shortcut unless a 1x1 projection was requested.
        shortcut = x if self.conv3 is None else self.conv3(x)
        # Out-of-place add keeps autograd-saved tensors untouched.
        return F.relu(Y + shortcut)
        

In [None]:
'''
    GoogleNet(2014)--Inception Block
'''

class Inception(nn.Module):
    """Inception block from GoogLeNet (Szegedy et al., 2014).

    Four parallel paths whose outputs are concatenated along dim 1 (channels):

    1. 1x1 conv
    2. 1x1 conv -> 3x3 conv (padding 1)
    3. 1x1 conv -> 5x5 conv (padding 2)
    4. 3x3 max pool (stride 1, padding 1) -> 1x1 conv

    Every convolution is followed by a ReLU. All paths preserve the spatial
    size, so the output has ``c1 + c2[1] + c3[1] + c4`` channels.

    Args:
        input_channels: Number of channels of the input tensor.
        c1: Output channels of path 1.
        c2: Pair ``(reduce, out)``: channels after path 2's 1x1 conv and
            after its 3x3 conv.
        c3: Pair ``(reduce, out)``: channels after path 3's 1x1 conv and
            after its 5x5 conv.
        c4: Output channels of path 4's 1x1 conv.
    """

    def __init__(self, input_channels, c1, c2, c3, c4):
        super().__init__()

        # Path 1: single 1x1 convolution.
        self.path_1 = nn.Conv2d(input_channels, c1, kernel_size=1)

        # Path 2: 1x1 channel reduction followed by a 3x3 convolution.
        self.path_2_1 = nn.Conv2d(input_channels, c2[0], kernel_size=1)
        self.path_2_2 = nn.Conv2d(c2[0], c2[1], kernel_size=3, padding=1)

        # Path 3: 1x1 channel reduction followed by a 5x5 convolution.
        self.path_3_1 = nn.Conv2d(input_channels, c3[0], kernel_size=1)
        self.path_3_2 = nn.Conv2d(c3[0], c3[1], kernel_size=5, padding=2)

        # Path 4: size-preserving max pool followed by a 1x1 convolution.
        self.path_4_1 = nn.MaxPool2d(kernel_size=3, stride=1, padding=1)
        self.path_4_2 = nn.Conv2d(input_channels, c4, kernel_size=1)

    def forward(self, x):
        """Run the four paths on ``x`` and concatenate them on dim 1.

        Args:
            x: Input with ``input_channels`` channels on dim 1.

        Returns:
            Tensor with ``c1 + c2[1] + c3[1] + c4`` channels on dim 1 and the
            same spatial size as ``x``.
        """
        Y1 = F.relu(self.path_1(x))
        Y2 = F.relu(self.path_2_2(F.relu(self.path_2_1(x))))
        Y3 = F.relu(self.path_3_2(F.relu(self.path_3_1(x))))
        # No ReLU between pool and conv: GoogLeNet's pool projection feeds the
        # pooled features straight into the 1x1 conv.
        Y4 = F.relu(self.path_4_2(self.path_4_1(x)))
        return torch.cat((Y1, Y2, Y3, Y4), dim=1)
        
        
        
