<a href="https://colab.research.google.com/github/darshank528/Project-STORM/blob/master/Resnet32.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
import torch
import torch.nn as nn
from torchsummary import summary

#setting device configuration
# Use the GPU when one is usable, otherwise fall back to the CPU.
# `is_available` has to be *called*: the bare function object is always
# truthy, which would select "cuda" even on machines without a GPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def ResNet32(target_device=None):
    """Build the two residual networks used in this notebook.

    Both networks share one skeleton: a 3x3 stem convolution, three stages
    of five residual blocks with 16/32/64 filters (stages two and three
    start with a stride-2 block), a 2x2 adaptive average pool and a linear
    classifier with 10 outputs. They differ only in the residual block:
    ``Basic_Module`` (two 3x3 convolutions) versus ``Bottleneck_Module``
    (1x1 -> 3x3 -> 1x1 convolutions).

    Args:
        target_device: Optional ``torch.device`` (or device string) the
            models are moved to. When omitted, the module-level ``device``
            is used, exactly as before.

    Returns:
        Tuple ``(model1, model2)``: the basic-block model followed by the
        bottleneck-block model, both already moved to the chosen device.
    """
    # ``alpha`` passed to every CELU activation inside the residual blocks.
    celu_alpha = 0.075

    def make_shortcut(in_channel, out_channel, stride):
        """Return the skip path: identity, or a 1x1 conv + BN projection."""
        if stride != 1 or in_channel != out_channel:
            # Shapes differ between input and output, so the skip path
            # needs a projection to make the addition well-defined.
            return nn.Sequential(
                nn.Conv2d(in_channel, out_channel, kernel_size=1,
                          stride=stride, bias=False),
                nn.BatchNorm2d(out_channel),
            )
        # An empty Sequential passes its input through unchanged.
        return nn.Sequential()

    class Basic_Module(nn.Module):
        """Residual block made of two 3x3 conv + BN layers and a shortcut."""

        def __init__(self, in_channel, out_channel, stride=1):
            super().__init__()
            # Only the first convolution applies the stride.
            self.conv1 = nn.Conv2d(in_channel, out_channel, kernel_size=3,
                                   stride=stride, padding=1, bias=False)
            self.bn1 = nn.BatchNorm2d(out_channel)
            self.conv2 = nn.Conv2d(out_channel, out_channel, kernel_size=3,
                                   padding=1, bias=False)
            self.bn2 = nn.BatchNorm2d(out_channel)
            self.shortcut = make_shortcut(in_channel, out_channel, stride)

        def forward(self, x):
            out = torch.celu(self.bn1(self.conv1(x)), alpha=celu_alpha)
            out = self.bn2(self.conv2(out))
            # Main path runs before the shortcut so that forward hooks
            # (e.g. torchsummary) see the layers in the original order.
            skip = self.shortcut(x)
            return torch.celu(out + skip, alpha=celu_alpha)

    class Bottleneck_Module(nn.Module):
        """Residual block made of 1x1, 3x3 and 1x1 conv + BN layers."""

        def __init__(self, in_channel, out_channel, stride=1):
            super().__init__()
            self.conv1 = nn.Conv2d(in_channel, out_channel, kernel_size=1,
                                   padding=0, bias=False)
            self.bn1 = nn.BatchNorm2d(out_channel)
            self.conv2 = nn.Conv2d(out_channel, out_channel, kernel_size=3,
                                   padding=1, bias=False)
            self.bn2 = nn.BatchNorm2d(out_channel)
            # In this block the stride is applied by the last convolution.
            self.conv3 = nn.Conv2d(out_channel, out_channel, kernel_size=1,
                                   padding=0, bias=False, stride=stride)
            self.bn3 = nn.BatchNorm2d(out_channel)
            self.shortcut = make_shortcut(in_channel, out_channel, stride)

        def forward(self, x):
            out = torch.celu(self.bn1(self.conv1(x)), alpha=celu_alpha)
            out = torch.celu(self.bn2(self.conv2(out)), alpha=celu_alpha)
            out = self.bn3(self.conv3(out))
            skip = self.shortcut(x)
            return torch.celu(out + skip, alpha=celu_alpha)

    class ResNet(nn.Module):
        """Stem conv, three residual stages, 2x2 pooling and a classifier."""

        def __init__(self, block, filter_map, n, classes=10):
            """
            Args:
                block: Residual block class, called as
                    ``block(in_channel, out_channel, stride)``.
                filter_map: Three channel widths, one per stage.
                n: Number of residual blocks in each stage.
                classes: Number of outputs of the final linear layer.
            """
            super().__init__()
            self.conv1 = nn.Conv2d(3, filter_map[0], kernel_size=3,
                                   padding=1, bias=False)
            self.bn1 = nn.BatchNorm2d(filter_map[0])
            self.block1 = self.MakeResnetLayer(
                block, (filter_map[0], filter_map[0]), n, stride=1)
            self.block2 = self.MakeResnetLayer(
                block, (filter_map[0], filter_map[1]), n, stride=2)
            self.block3 = self.MakeResnetLayer(
                block, (filter_map[1], filter_map[2]), n, stride=2)
            # Attribute name (including its spelling) is kept so existing
            # references to it keep working. Pools to a 2x2 map, not 1x1.
            self.GloabalAveragePool = nn.AdaptiveAvgPool2d(2)
            self.fc = nn.Linear(2 * 2 * filter_map[2], classes, bias=True)

        def MakeResnetLayer(self, block, filters, n, stride):
            """Stack ``n`` blocks; only the first one gets ``stride``."""
            in_channel, out_channel = filters
            layers = [block(in_channel, out_channel, stride)]
            layers.extend(block(out_channel, out_channel)
                          for _ in range(n - 1))
            return nn.Sequential(*layers)

        def forward(self, x):
            # initial layers (the stem uses ReLU, the blocks use CELU)
            x = torch.relu(self.bn1(self.conv1(x)))
            # three residual stages
            x = self.block1(x)
            x = self.block2(x)
            x = self.block3(x)
            # final layers
            x = self.GloabalAveragePool(x)
            # Flatten everything but the batch dimension. The previous
            # hard-coded ``view(-1, 2*2*64)`` only matched ``fc`` when
            # filter_map[2] == 64 and silently re-shaped the batch otherwise.
            x = torch.flatten(x, 1)
            return self.fc(x)

    # Fall back to the module-level ``device`` only when no explicit target
    # was given, so existing ``ResNet32()`` calls behave as before.
    target = device if target_device is None else target_device

    #creating an object of resnet model and pushing it to device(CPU/GPU)
    model1 = ResNet(Basic_Module, [16, 32, 64], 5).to(target)
    model2 = ResNet(Bottleneck_Module, [16, 32, 64], 5).to(target)
    return model1, model2

# Instantiate both networks once: the basic-block model first, then the
# bottleneck-block model (the order in which ResNet32 returns them).
ModelBasic, ModelBottleneck = ResNet32()

In [2]:
# Print a layer-by-layer summary of the basic-block model for a single
# 3-channel 32x32 input.
summary(ModelBasic, input_size=(3, 32, 32))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 16, 32, 32]             432
       BatchNorm2d-2           [-1, 16, 32, 32]              32
            Conv2d-3           [-1, 16, 32, 32]           2,304
       BatchNorm2d-4           [-1, 16, 32, 32]              32
            Conv2d-5           [-1, 16, 32, 32]           2,304
       BatchNorm2d-6           [-1, 16, 32, 32]              32
      Basic_Module-7           [-1, 16, 32, 32]               0
            Conv2d-8           [-1, 16, 32, 32]           2,304
       BatchNorm2d-9           [-1, 16, 32, 32]              32
           Conv2d-10           [-1, 16, 32, 32]           2,304
      BatchNorm2d-11           [-1, 16, 32, 32]              32
     Basic_Module-12           [-1, 16, 32, 32]               0
           Conv2d-13           [-1, 16, 32, 32]           2,304
      BatchNorm2d-14           [-1, 16,

In [3]:
# Print a layer-by-layer summary of the bottleneck-block model for a single
# 3-channel 32x32 input.
summary(ModelBottleneck, input_size=(3, 32, 32))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 16, 32, 32]             432
       BatchNorm2d-2           [-1, 16, 32, 32]              32
            Conv2d-3           [-1, 16, 32, 32]             256
       BatchNorm2d-4           [-1, 16, 32, 32]              32
            Conv2d-5           [-1, 16, 32, 32]           2,304
       BatchNorm2d-6           [-1, 16, 32, 32]              32
            Conv2d-7           [-1, 16, 32, 32]             256
       BatchNorm2d-8           [-1, 16, 32, 32]              32
 Bottleneck_Module-9           [-1, 16, 32, 32]               0
           Conv2d-10           [-1, 16, 32, 32]             256
      BatchNorm2d-11           [-1, 16, 32, 32]              32
           Conv2d-12           [-1, 16, 32, 32]           2,304
      BatchNorm2d-13           [-1, 16, 32, 32]              32
           Conv2d-14           [-1, 16,