<a href="https://colab.research.google.com/github/moh2236945/pytorch_classification/blob/master/models/ResNet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import os
import shutil
from collections import OrderedDict
from IPython.display import clear_output

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import transforms, datasets
from torchsummary import summary
from torch.utils.data import Dataset, DataLoader, random_split

In [2]:
class LambdaLayer(nn.Module):
    """Adapter that lets a plain callable be used as an ``nn.Module``.

    Args:
        lambd: Callable applied to the input in ``forward``.
    """

    def __init__(self, lambd):
        super().__init__()
        # Stored as a plain attribute; the callable is invoked unchanged.
        self.lambd = lambd

    def forward(self, x):
        """Return ``lambd(x)``."""
        fn = self.lambd
        return fn(x)

class ConvBlock(nn.Module):
    """Basic residual block made of two 3x3 convolutions.

    The main path is conv3x3 -> BN -> ReLU -> conv3x3 -> BN. Its output is
    added to a shortcut of the input and the sum goes through a final ReLU.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels produced by the block.
        stride: Stride of the first convolution and of the shortcut.
        option: Shortcut used when the input and output shapes differ.
            ``'A'`` subsamples spatially and zero-pads the channel axis
            (no parameters); ``'B'`` is a strided 1x1 convolution followed
            by batch normalisation.

    Raises:
        ValueError: If a non-identity shortcut is needed and ``option`` is
            neither ``'A'`` nor ``'B'``, or if option ``'A'`` is asked to
            reduce the number of channels.
    """

    def __init__(self, in_channels, out_channels, stride=1, option='A'):
        super(ConvBlock, self).__init__()

        self.features = nn.Sequential(OrderedDict([
            ('conv1', nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)),
            ('bn1', nn.BatchNorm2d(out_channels)),
            ('act1', nn.ReLU()),
            ('conv2', nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)),
            ('bn2', nn.BatchNorm2d(out_channels))
        ]))

        # An empty Sequential is the identity; used when shapes already match.
        self.shortcut = nn.Sequential()

        if stride != 1 or in_channels != out_channels:
            if option == 'A':
                extra = out_channels - in_channels
                if extra < 0:
                    raise ValueError(
                        "option 'A' cannot reduce channels "
                        "(in_channels=%d, out_channels=%d)" % (in_channels, out_channels))
                # Split the missing channels as evenly as possible between
                # the front and the back of the channel axis.
                front = extra // 2
                back = extra - front
                # ::stride subsamples H and W the same way the strided conv
                # does; F.pad takes pairs from the last dim backwards,
                # i.e. (W, H, C, N).
                self.shortcut = LambdaLayer(lambda x:
                            F.pad(x[:, :, ::stride, ::stride], (0, 0, 0, 0, front, back, 0, 0)))
            elif option == 'B':
                # The projection must emit exactly out_channels so that it
                # can be added to the output of the main path.
                self.shortcut = nn.Sequential(OrderedDict([
                    ('s_conv1', nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, padding=0, bias=False)),
                    ('s_bn1', nn.BatchNorm2d(out_channels))
                ]))
            else:
                raise ValueError("option must be 'A' or 'B', got %r" % (option,))

    def forward(self, x):
        """Return ``relu(features(x) + shortcut(x))``."""
        out = self.features(x)
        out += self.shortcut(x)
        out = F.relu(out)
        return out

class ResNet(nn.Module):
    """
        ResNet architecture for CIFAR-10.

        A 3x3 stem convolution (3 -> 16 channels) is followed by three
        stages of residual blocks with 16, 32 and 64 channels, global
        average pooling and a linear classifier.

        Args:
            block_type: Callable invoked as
                ``block_type(in_channels, out_channels, stride)`` that
                returns the module for one residual block.
            num_blocks: Sequence whose first three entries give the number
                of blocks in stage 1, 2 and 3.
            num_classes: Size of the classifier output. Defaults to 10.
    """
    def __init__(self, block_type, num_blocks, num_classes=10):
        super(ResNet, self).__init__()

        # Running channel count; updated by __build_layer as stages are added.
        self.in_channels = 16

        self.conv0 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn0 = nn.BatchNorm2d(16)
        self.block1 = self.__build_layer(block_type, 16, num_blocks[0], mismatch_stride=1)
        self.block2 = self.__build_layer(block_type, 32, num_blocks[1], mismatch_stride=2)
        self.block3 = self.__build_layer(block_type, 64, num_blocks[2], mismatch_stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1,1))
        self.linear = nn.Linear(64, num_classes)

    def __build_layer(self, block_type, out_channels, num_blocks, mismatch_stride):
        """Stack the blocks of one stage into an ``nn.Sequential``.

        Only the first block uses ``mismatch_stride``; the remaining ones
        use stride 1. ``self.in_channels`` is set to ``out_channels`` so the
        next block (and the next stage) starts from the right width.
        """
        strides = [mismatch_stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block_type(self.in_channels, out_channels, stride))
            self.in_channels = out_channels
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return the classifier output for a batch of 3-channel images."""
        out = F.relu(self.bn0(self.conv0(x)))
        out = self.block1(out)
        out = self.block2(out)
        out = self.block3(out)
        out = self.avgpool(out)
        # Collapse the pooled (C, 1, 1) maps into one feature vector per sample.
        out = torch.flatten(out, 1)
        out = self.linear(out)
        return out
def ResNet56():
    """Build the 56-layer ResNet: three stages of nine ``ConvBlock``s each."""
    blocks_per_stage = [9] * 3
    return ResNet(block_type=ConvBlock, num_blocks=blocks_per_stage)


In [3]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Build the network, move it to the chosen device and print a per-layer
# summary for a single 3x32x32 input.
model = ResNet56()
model = model.to(device)
summary(model, (3, 32, 32))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 16, 32, 32]             432
       BatchNorm2d-2           [-1, 16, 32, 32]              32
            Conv2d-3           [-1, 16, 32, 32]           2,304
       BatchNorm2d-4           [-1, 16, 32, 32]              32
              ReLU-5           [-1, 16, 32, 32]               0
            Conv2d-6           [-1, 16, 32, 32]           2,304
       BatchNorm2d-7           [-1, 16, 32, 32]              32
         ConvBlock-8           [-1, 16, 32, 32]               0
            Conv2d-9           [-1, 16, 32, 32]           2,304
      BatchNorm2d-10           [-1, 16, 32, 32]              32
             ReLU-11           [-1, 16, 32, 32]               0
           Conv2d-12           [-1, 16, 32, 32]           2,304
      BatchNorm2d-13           [-1, 16, 32, 32]              32
        ConvBlock-14           [-1, 16,