In [1]:
import torch
import torch.nn as nn
from collections import namedtuple

In [25]:
class ResNet(nn.Module):
    def __init__(self,config,outdim):
        super().__init__()
        block,n_blocks,channels=config
        self.in_channels=channels[0]
        assert len(n_blocks) == len(channels)==4

        ## stem layer
        self.conv1=nn.Conv2d(in_channels=3,out_channels=self.in_channels,kernel_size=(7,7),stride=2,padding=3,bias=False)
        self.bn1=nn.BatchNorm2d(self.in_channels)
        self.relu=nn.ReLU(inplace=False)
        self.maxpool2d=nn.MaxPool2d(kernel_size=3,stride=2,padding=1)

        ### first block
        self.layer1=self.get_resnet_layer(block,n_blocks[0],channels[0])
        ### second block
        self.layer2=self.get_resnet_layer(block,n_blocks[1],channels[1])
        ### third block
        self.layer3=self.get_resnet_layer(block,n_blocks[2],channels[2])
        ### 4th layer
        self.layer4=self.get_resnet_layer(block,n_blocks[3],channels[3])

        ### avgpool
        self.avgpool=nn.AdaptiveAvgPool2d((1,1))
        self.fc=nn.Linear(self.in_channels,outdim)


    def get_resnet_layer(self,block,n_blocks,channels,stride=1):
        layers=[]
        if self.in_channels != block.expansion*channels:
            downsample=True
        else:
            downsample=False

        layers.append(block(self.in_channels,channels,stride,downsample))
        for i in range(1,n_blocks):
            layers.append(block(block.expansion*channels,channels))
            self.in_channels=block.expansion*channels
        
        return nn.Sequential(*layers)


    def forward(self,x):
        # Stem layer forward pass
        x=self.conv1(x)
        x=self.bn1(x)
        x=self.relu(x)
        x=self.maxpool2d(x)
        x=self.layer1(x)
        x=self.layer2(x)
        x=self.layer3(x)
        x=self.layer4(x)
        x=self.avgpool(x)
        h=x.view(x.shape[0],-1)
        x=self.fc(h)
        return x

In [26]:
class BasicBlock(nn.Module):
    expansion=1
    def __init__(self,in_channels,out_channels,stride=1,downsample=False):
        super().__init__()
        self.conv1=nn.Conv2d(in_channels=in_channels,out_channels=out_channels,kernel_size=(3,3),stride=stride,padding=1,bias=False)
        self.bn1=nn.BatchNorm2d(num_features=out_channels)
        self.relu=nn.ReLU(inplace=False)
        self.conv2=nn.Conv2d(in_channels=out_channels,out_channels=out_channels,kernel_size=3,stride=1,padding=1,bias=False)
        self.bn2=nn.BatchNorm2d(num_features=out_channels)
        if downsample:
            conv=nn.Conv2d(in_channels=in_channels,out_channels=out_channels,kernel_size=1,stride=stride,bias=False)
            bn=nn.BatchNorm2d(num_features=out_channels)
            downsample=nn.Sequential(conv,bn)
        else:
            downsample=None

    def forward(self,x):
        i=x

        x=self.conv1(x)
        x=self.bn1(x)
        x=self.relu(x)

        x=self.conv2(x)
        x=self.bn2(x)
        if self.downsample is not None:
            i=self.downsample(i)
        x+=i
        return x





In [27]:
class BottleNeck(nn.Module):
    expansion=4
    def __init__(self,in_channels,out_channels,stride=1,downsample=False):
        super().__init__()
        self.conv1=nn.Conv2d(in_channels=in_channels,out_channels=out_channels,kernel_size=(1,1),stride=1,bias=False)
        self.bn1=nn.BatchNorm2d(num_features=out_channels)
        self.relu=nn.ReLU(inplace=False)

        self.conv2=nn.Conv2d(in_channels=out_channels,out_channels=out_channels,kernel_size=(3,3),stride=stride,padding=1,bias=False)
        self.bn2=nn.BatchNorm2d(num_features=out_channels)
        
        self.conv3=nn.Conv2d(in_channels=out_channels,out_channels=self.expansion*out_channels,kernel_size=1,stride=1,bias=False)
        self.bn3=nn.BatchNorm2d(num_features=self.expansion*out_channels)
        if downsample:
            conv=nn.Conv2d(in_channels=in_channels,out_channels=self.expansion*out_channels,kernel_size=1,stride=stride,bias=False)
            bn=nn.BatchNorm2d(num_features=self.expansion*out_channels)
            downsample=nn.Sequential(conv,bn)
        else:
            downsample=None



    def forward(self,x):
        i=x
        x=self.conv1(x)
        x=self.bn1(x)
        x=self.relu(x)

        x=self.conv2(x)
        x=self.bn2(x)
        x=self.relu(x)

        x=self.conv3(x)
        x=self.bn3(x)
        if self.downsample is not None:
            i=self.downsample(i)
        x+=i
        x=nn.relu(x)
        return x

In [28]:
output_dim=10

In [29]:
ResNetConfig=namedtuple('ResNetConfig',['block','n_blocks','channels'])

In [30]:
resnet18_config=ResNetConfig(block=BasicBlock,
                             n_blocks=[2,2,2,2],
                             channels=[64,128,256,512],
                             ) 

In [31]:
resnet34_config=ResNetConfig(block=BasicBlock,
                             n_blocks=[3,4,6,3],
                             channels=[64,128,256,512],
                             ) 

In [32]:
resnet50_config=ResNetConfig(block=BottleNeck,
                             n_blocks=[3,4,6,3],
                             channels=[64,128,256,512],
                             ) 

In [33]:
resnet101_config=ResNetConfig(block=BottleNeck,
                             n_blocks=[3,4,23,3],
                             channels=[64,128,256,512],
                             ) 

In [34]:
resnet152_config=ResNetConfig(block=BottleNeck,
                             n_blocks=[3,8,36,3],
                             channels=[64,128,256,512],
                             ) 

In [35]:
resnet18=ResNet(resnet18_config,output_dim)
resnet34=ResNet(resnet34_config,output_dim)
resnet50=ResNet(resnet50_config,output_dim)
resnet101=ResNet(resnet101_config,output_dim)
resnet152=ResNet(resnet152_config,output_dim)

In [36]:
resnet101

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU()
  (maxpool2d): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BottleNeck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU()
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BottleNeck(
      (conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn