<h1>Table of Contents<span class="tocSkip"></span></h1>
<div class="toc"><ul class="toc-item"></ul></div>

In [1]:
import torch
from torch import nn

<p>To understand how ResNet works, read:</p>
<a href='https://arxiv.org/pdf/1512.03385.pdf'>Deep Residual Learning for Image Recognition</a>
    
<a href='https://arxiv.org/pdf/1603.05027.pdf'>Identity Mappings in Deep Residual Networks</a>
<img src="images/resnet.jpg"  width='1000px'>
Source: <a href='https://arxiv.org/pdf/1512.03385.pdf'>Deep Residual Learning for Image Recognition</a>


In [2]:

class Residual(nn.Module):
    """Two-convolution residual block computing ``relu(F(X) + shortcut(X))``.

    The main branch ``F`` is conv3x3 -> BatchNorm -> ReLU -> conv3x3 -> BatchNorm.
    The shortcut is ``X`` itself, or a 1x1 convolution of ``X`` when
    ``downsmaple`` is true.

    Args:
        input_channels: Number of channels of the input ``X``.
        num_channels: Number of output channels of both 3x3 convolutions and
            of the 1x1 shortcut convolution (when present).
        downsmaple: If true, the shortcut is a 1x1 convolution from
            ``input_channels`` to ``num_channels`` with stride ``strides``;
            otherwise ``X`` is added unchanged. The (misspelled) name is kept
            as-is because callers pass it by keyword and it is also the
            submodule attribute name.
        strides: Stride of the first 3x3 convolution and of the 1x1 shortcut
            convolution.

    Note:
        Without the 1x1 shortcut the block adds ``X`` to the main-branch
        output directly, so the caller is responsible for requesting the
        shortcut convolution whenever ``input_channels != num_channels`` or
        ``strides != 1``.
    """

    def __init__(self, input_channels, num_channels,
                 downsmaple=False, strides=1):
        super().__init__()
        # Main branch: only the first convolution applies `strides`; the
        # second one keeps the spatial size (kernel 3, padding 1, stride 1).
        self.conv1 = nn.Conv2d(input_channels, num_channels, kernel_size=3,
                               padding=1, stride=strides)
        self.conv2 = nn.Conv2d(num_channels, num_channels, kernel_size=3,
                               padding=1)
        # Optional projection shortcut; it uses the same stride as conv1 so
        # both branches are reduced by the same factor before the addition.
        if downsmaple:
            self.downsmaple = nn.Conv2d(input_channels, num_channels,
                                        kernel_size=1, stride=strides)
        else:
            self.downsmaple = None

        self.bn1 = nn.BatchNorm2d(num_channels)
        self.bn2 = nn.BatchNorm2d(num_channels)
        self.relu = nn.ReLU()

    def forward(self, X):
        """Apply the main branch, add the shortcut, and return the ReLU of the sum."""
        Y = self.relu(self.bn1(self.conv1(X)))
        Y = self.bn2(self.conv2(Y))
        # Explicit None check instead of module truthiness: a module that
        # defines __len__ (e.g. an empty container) would evaluate as falsy.
        if self.downsmaple is not None:
            X = self.downsmaple(X)
        return self.relu(Y + X)

In [3]:
def resnet_block(input_channels, num_channels, num_residuals,
                 first_block=False):
    """Build one ResNet stage as an ``nn.Sequential`` of ``Residual`` blocks.

    Args:
        input_channels: Number of channels entering the stage.
        num_channels: Number of channels produced by every block in the stage.
        num_residuals: Number of ``Residual`` blocks in the stage.
        first_block: If false (default), the first block uses stride 2 and a
            1x1 shortcut convolution. If true, the first block keeps stride 1
            and only gets a 1x1 shortcut convolution when
            ``input_channels != num_channels``.

    Returns:
        ``nn.Sequential`` holding ``num_residuals`` blocks (empty when
        ``num_residuals`` is 0).
    """
    layers = []
    for i in range(num_residuals):
        if i > 0:
            # Every block after the first maps num_channels -> num_channels.
            layers.append(Residual(num_channels, num_channels))
        elif first_block:
            # Stride stays 1 here. Honour `input_channels` so the first conv
            # matches the incoming tensor, and project the shortcut only when
            # the channel count actually changes.
            layers.append(Residual(input_channels, num_channels,
                                   downsmaple=input_channels != num_channels))
        else:
            # Stage transition: stride 2 plus a 1x1 convolution on the shortcut.
            layers.append(Residual(input_channels, num_channels,
                                   downsmaple=True, strides=2))
    return nn.Sequential(*layers)

In [4]:
# ResNet-34 for single-channel input and 10 output classes.
# Depth: (3 + 4 + 6 + 3) = 16 residual blocks with two 3x3 convolutions each
# (32 layers) + the stem convolution + the final linear layer = 34.
resnet34 = nn.Sequential(
    # Stem: 7x7 stride-2 convolution followed by a 3x3 stride-2 max-pool.
    nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3),
    nn.BatchNorm2d(64), nn.ReLU(),
    nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
    # Four stages; first_block=True keeps stride 1 in the first stage, the
    # remaining stages start with a stride-2 block.
    resnet_block(64, 64, 3, first_block=True),
    resnet_block(64, 128, 4),
    resnet_block(128, 256, 6),
    resnet_block(256, 512, 3),
    # Head: global average pool -> flatten -> linear layer with 10 outputs.
    nn.AdaptiveAvgPool2d((1, 1)),
    nn.Flatten(), nn.Linear(512, 10),
)

In [5]:
# Shape sanity check: push one random 1x228x228 input through the network,
# one top-level layer at a time, printing the output shape after each layer.
X = torch.rand(1, 1, 228, 228)
for layer in resnet34:
    X = layer(X)
    print(type(layer).__name__, 'output shape:\t', X.shape)

Conv2d output shape:	 torch.Size([1, 64, 114, 114])
BatchNorm2d output shape:	 torch.Size([1, 64, 114, 114])
ReLU output shape:	 torch.Size([1, 64, 114, 114])
MaxPool2d output shape:	 torch.Size([1, 64, 57, 57])
Sequential output shape:	 torch.Size([1, 64, 57, 57])
Sequential output shape:	 torch.Size([1, 128, 29, 29])
Sequential output shape:	 torch.Size([1, 256, 15, 15])
Sequential output shape:	 torch.Size([1, 512, 8, 8])
AdaptiveAvgPool2d output shape:	 torch.Size([1, 512, 1, 1])
Flatten output shape:	 torch.Size([1, 512])
Linear output shape:	 torch.Size([1, 10])
