In [None]:
# Imports
import torch
import matplotlib.pyplot as plt
import numpy as np
from torch.autograd import Variable
import torchvision.transforms as transforms
import torch.nn.functional as func
import torch.utils.model_zoo as model_zoo

In [None]:
# The argument for ResNet: very deep plain networks struggle with optimization
# (even when batch normalization is used). It was argued that this underfitting
# is not due to vanishing gradients (which can happen in very deep networks),
# as the problem is still seen in deep networks trained with batch normalization.

# If the block is the first block of the layer, then the out_ch is double
# the in_channels, otherwise, they are the same

# Downsampling achieved by increasing the stride as opposed to maxpooling
# Stride of 2 is used in the first block of each layer to downsample
class ResBlock(torch.nn.Module):
    """Basic two-convolution residual block (3x3 conv + BN, twice, plus a skip).

    Args:
        in_ch: Number of channels of the input feature map.
        out_ch: Number of channels produced by the block.
        stride: Stride of the first convolution; a value of 2 halves the
            spatial resolution. The second convolution always uses stride 1.

    Input of shape ``(N, in_ch, H, W)`` yields output of shape
    ``(N, out_ch, H', W')`` where ``H' = floor((H - 1) / stride) + 1``.
    """

    def __init__(self, in_ch, out_ch, stride):
        # Kernel size of 3 and padding of 1 for both main-path convolutions
        super(ResBlock, self).__init__()
        self.k_size = 3
        self.padding = 1
        self.conv_bn_1 = torch.nn.Sequential(
            torch.nn.Conv2d(in_ch, out_ch, self.k_size, stride, self.padding),
            torch.nn.BatchNorm2d(out_ch)
        )
        # Only the first convolution downsamples; applying the stride a second
        # time would shrink the feature map twice per block.
        self.conv_bn_2 = torch.nn.Sequential(
            torch.nn.Conv2d(out_ch, out_ch, self.k_size, 1, self.padding),
            torch.nn.BatchNorm2d(out_ch)
        )
        # The skip path must produce the same shape as the main path. When the
        # block changes resolution or channel count, project the input with a
        # strided 1x1 convolution; otherwise pass it through untouched.
        if stride != 1 or in_ch != out_ch:
            self.shortcut = torch.nn.Sequential(
                torch.nn.Conv2d(in_ch, out_ch, kernel_size=1, stride=stride,
                                bias=False),
                torch.nn.BatchNorm2d(out_ch)
            )
        else:
            self.shortcut = torch.nn.Identity()

    # Forward prop for the residual block
    def forward(self, x):
        """Return ``relu(F(x) + shortcut(x))`` for the input feature map ``x``."""
        conv_1_output = self.conv_bn_1(x)
        actv_output_1 = func.relu(conv_1_output)
        conv_2_output = self.conv_bn_2(actv_output_1)
        return func.relu(conv_2_output + self.shortcut(x))


In [None]:
# ResNet Structure: trained on ImageNet
# ResNet34
class ResNet(torch.nn.Module):
    """ResNet-34-style classifier: stem, four stages of ``ResBlock``, pooled head.

    The stem is a 7x7 stride-2 convolution with batch norm, ReLU and a 3x3
    stride-2 max-pool. It is followed by four residual stages of 3, 4, 6 and 3
    blocks with 64, 128, 256 and 512 channels, a global average pool and a
    linear layer.

    Args:
        n_classes: Number of output logits produced by the final linear layer.
    """

    def __init__(self, n_classes):
        # May have to adjust if 1) Input image is greyscale 2) Input size is different
        # given the formula O = (I + 2P - K)/S + 1
        super(ResNet, self).__init__()
        self.conv1 = torch.nn.Conv2d(in_channels=3, out_channels=64, kernel_size=7,
                                     stride=2, padding=3)
        self.bn = torch.nn.BatchNorm2d(64)
        self.maxpool1 = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        # Each call returns a Sequential holding all the blocks of one stage.
        # The max-pool above has already downsampled, so the first stage keeps
        # its resolution (stride 1); the later stages halve it.
        self.first_layer = self.add_resnet_layer(in_ch=64, out_ch=64, layer_size=3,
                                                 stride=1)
        self.second_layer = self.add_resnet_layer(in_ch=64, out_ch=128, layer_size=4)
        self.third_layer = self.add_resnet_layer(in_ch=128, out_ch=256, layer_size=6)
        self.fourth_layer = self.add_resnet_layer(in_ch=256, out_ch=512, layer_size=3)
        # Keeps the channels but takes the global average of each feature map
        # (as opposed to the max of a small window). Final output is
        # 1 x 1 x n_channels.
        self.avgpool = torch.nn.AdaptiveAvgPool2d(1)
        self.output_layer = torch.nn.Linear(512, n_classes)

    # layer_size = number of times the block repeats
    def add_resnet_layer(self, in_ch, out_ch, layer_size, stride=2):
        """Build one residual stage as a ``torch.nn.Sequential`` of ``ResBlock``.

        Args:
            in_ch: Channels entering the stage.
            out_ch: Channels produced by every block of the stage.
            layer_size: Number of blocks in the stage.
            stride: Stride handed to the first block (default 2, i.e. the
                stage downsamples). The remaining blocks always use stride 1.

        Returns:
            A ``torch.nn.Sequential`` containing ``layer_size`` blocks.
        """
        blocks = [ResBlock(in_ch, out_ch, stride)]
        blocks.extend(ResBlock(out_ch, out_ch, 1) for _ in range(1, layer_size))
        # Unpack torch modules into a Sequential container
        return torch.nn.Sequential(*blocks)

    def forward(self, x):
        """Compute class logits of shape ``(N, n_classes)`` for an image batch.

        The first convolution expects 3 input channels, i.e. ``x`` of shape
        ``(N, 3, H, W)``.
        """
        # Go through the stem: convolution, batch norm, ReLU, max-pool
        out = self.conv1(x)
        out = self.bn(out)
        out = func.relu(out)
        out = self.maxpool1(out)

        # Basic Block layers
        out = self.first_layer(out)
        out = self.second_layer(out)
        out = self.third_layer(out)
        out = self.fourth_layer(out)

        # Average pooling and fully connected layers
        out = self.avgpool(out)
        # Flatten (N, 512, 1, 1) -> (N, 512). Do not wrap the result in the
        # legacy autograd Variable: that constructor yields a tensor detached
        # from the graph, which would stop gradients from reaching every layer
        # before the final linear one.
        out = torch.flatten(out, 1)
        return self.output_layer(out)