In [1]:
import torch
import torch.nn as nn

In [2]:
def activation_func(activation):
    """Return a newly constructed activation module selected by name.

    Args:
        activation: one of 'relu', 'leaky_relu', 'selu', 'sigmoid' or 'none'
            ('none' yields ``nn.Identity``).

    Returns:
        A fresh ``nn.Module`` instance for the requested activation.

    Raises:
        KeyError: if ``activation`` is not one of the supported names.
    """
    # Map each supported name to a zero-argument constructor.
    factories = {
        'relu': lambda: nn.ReLU(inplace=True),
        'leaky_relu': lambda: nn.LeakyReLU(negative_slope=0.01, inplace=True),
        'selu': lambda: nn.SELU(inplace=True),
        'sigmoid': lambda: nn.Sigmoid(),
        'none': lambda: nn.Identity(),
    }
    build = factories[activation]
    return build()

**Basic Block**

First, let's create the **basic block**.
As we are implementing ResNet-18 and ResNet-34, the **basic block** module comprises:
* **(1)** convolutional + batch_normalization layers
* ReLU activation
* **(2)** convolutional + batch_normalization layers

In [3]:
class BasicBlock(nn.Module):
    """Two 3x3 convolution + batch-norm layers with an activation in between.

    No shortcut addition and no final activation are applied here; the block
    returns the output of the second batch-norm layer.

    Args:
        in_channels: number of channels of the input feature map.
        out_channels: number of channels produced by both convolutions.
        downsampling: stride of the first convolution. It is only honoured
            when ``in_channels == out_channels``; otherwise the stride is
            forced to 2.
        activation: name passed to ``activation_func`` for the activation
            applied after the first batch-norm layer.
    """

    def __init__(self, in_channels, out_channels, downsampling = 1, activation='relu'):
        super().__init__()

        self.in_channels, self.out_channels = in_channels, out_channels

        # A change in channel count always implies a stride of 2,
        # regardless of the value the caller passed.
        self.downsampling = downsampling if in_channels == out_channels else 2

        self.conv1 = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=3,
            stride=self.downsampling,
            padding=1,
            bias=False,
        )
        self.bn1 = nn.BatchNorm2d(num_features=out_channels)
        self.conv2 = nn.Conv2d(
            out_channels,
            out_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=False,
        )
        self.bn2 = nn.BatchNorm2d(num_features=out_channels)
        self.activate = activation_func(activation)

    def forward(self, x):
        """Apply conv1 -> bn1 -> activation -> conv2 -> bn2 to ``x``."""
        hidden = self.activate(self.bn1(self.conv1(x)))
        return self.bn2(self.conv2(hidden))

**Residual Block**

Here, we create the residual part of the network.

The parameter *depth* gives the number of **basic blocks** that should be stacked to form each **residual block**.

As in the original paper, "*we perform downsampling directly by
convolutional layers that have a stride of 2*", and the first *conv* layer of each **residual block** is the one with stride = 2.  



<img height="200" src="https://miro.medium.com/max/713/1*D0F3UitQ2l5Q0Ak-tjEdJg.png" srcset="https://miro.medium.com/max/345/1*D0F3UitQ2l5Q0Ak-tjEdJg.png 276w, https://miro.medium.com/max/690/1*D0F3UitQ2l5Q0Ak-tjEdJg.png 552w, https://miro.medium.com/max/713/1*D0F3UitQ2l5Q0Ak-tjEdJg.png 570w" sizes="570px" float = "center">



In [4]:
class ResNetResidualBlock(nn.Module):
    """A stack of ``depth`` basic blocks, each paired with a shortcut connection.

    The first basic block maps ``in_channels`` to ``out_channels``; the
    remaining ``depth - 1`` blocks keep ``out_channels``. For every basic
    block whose input and output channel counts differ, the shortcut is a
    1x1 convolution followed by batch norm; otherwise it is an identity.

    Args:
        in_channels: number of channels of the input feature map.
        out_channels: number of channels produced by every basic block.
        depth: number of basic blocks to stack (must be at least 1).
        downsampling: accepted for signature compatibility but not used;
            the stride is derived from the channel counts (1 when they are
            equal, 2 otherwise).
        activation: name passed to ``activation_func``; used both inside each
            basic block and after each shortcut addition.

    Raises:
        ValueError: if ``depth`` is smaller than 1.
    """

    def __init__(self, in_channels, out_channels, depth, downsampling = 1, activation='relu'):
        super().__init__()

        # With depth < 1 there would be one basic block but no shortcut,
        # so forward() could only fail; reject it up front instead.
        if depth < 1:
            raise ValueError(f"depth must be >= 1, got {depth}")

        self.in_channels = in_channels
        self.out_channels = out_channels

        self.downsampling = 1 if self.in_channels == self.out_channels else 2

        # Forward the requested activation so the inner blocks do not
        # silently fall back to their default.
        self.blocks = nn.ModuleList([
            BasicBlock(self.in_channels, self.out_channels, activation=activation),
            *[BasicBlock(self.out_channels, self.out_channels, activation=activation)
              for _ in range(depth - 1)]
        ])

        # One shortcut per basic block: a projection when the channel count
        # changes, an identity otherwise.
        self.shortcuts = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(block.in_channels, block.out_channels, kernel_size = 1, stride = self.downsampling, bias = False),
                nn.BatchNorm2d(block.out_channels))
            if block.in_channels != block.out_channels else nn.Identity()
            for block in self.blocks
        ])

        self.activate = activation_func(activation)

    def forward(self, x):
        """Run ``x`` through every basic block, adding its shortcut and activating."""
        for block, shortcut in zip(self.blocks, self.shortcuts):
            residual = shortcut(x)
            x = block(x)
            # In-place add targets the tensor just produced by the block,
            # not the tensor the caller passed in.
            x += residual
            x = self.activate(x)

        return x

**ResNet Encoder**

Here, we stack the entire network encoder:
* the *gate* performs the first block of processing
* the sequential module of **blocks**, where each block is created from the *depths* and *block_sizes* parameters, which give the number of **basic blocks** and their channel counts, respectively

In [5]:
class ResNetEncoder(nn.Module):
    """Feature extractor: an input gate followed by a sequence of residual blocks.

    The gate is a 7x7 stride-2 convolution, batch norm, activation and a 3x3
    stride-2 max pool. The first residual block keeps ``block_sizes[0]``
    channels; each following block maps ``block_sizes[k]`` to
    ``block_sizes[k + 1]``.

    Args:
        in_channels: number of channels of the input.
        depths: number of basic blocks in each residual block; one residual
            block is built per entry.
        block_sizes: channel count of each residual block, indexed in
            parallel with ``depths``.
        activation: name passed to ``activation_func``; used in the gate and
            in every residual block.
    """

    def __init__(self, in_channels, depths, block_sizes, activation='relu'):
        super().__init__()

        self.in_channels = in_channels

        self.block = ResNetResidualBlock
        self.block_sizes = block_sizes
        self.n = len(self.block_sizes)

        self.gate = nn.Sequential(
            nn.Conv2d(self.in_channels, self.block_sizes[0], kernel_size=7, stride=2, padding=3, bias=False),
            nn.BatchNorm2d(self.block_sizes[0]),
            activation_func(activation),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        )

        # Pass the activation on to every stage so the whole encoder uses
        # the requested non-linearity, not only the gate.
        self.blocks = nn.Sequential(
            self.block(self.block_sizes[0], self.block_sizes[0], depth = depths[0], activation = activation),
            *[self.block(self.block_sizes[k], self.block_sizes[k + 1], depth = depths[k + 1], activation = activation)
              for k in range(len(depths) - 1)]
        )

    def forward(self, x):
        """Apply the gate and then all residual blocks to ``x``."""
        x = self.gate(x)
        x = self.blocks(x)

        return x

In [6]:
class ResNetClassifier(nn.Module):
    """Classification head: global average pool, linear layer, sigmoid.

    Args:
        in_features: input size of the linear layer.
        n_classes: output size of the linear layer.
    """

    def __init__(self, in_features, n_classes):
        super().__init__()

        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.classifier = nn.Linear(in_features, n_classes)
        self.activation = activation_func('sigmoid')

    def forward(self, x):
        """Pool ``x`` to 1x1, flatten per sample, and return sigmoid scores."""
        pooled = self.avg_pool(x)
        batch_size = pooled.size(0)
        flat = pooled.view(batch_size, -1)
        scores = self.classifier(flat)
        # Sigmoid is used as the output activation.
        return self.activation(scores)

In [7]:
class ResNet(nn.Module):
    """Full network: a ``ResNetEncoder`` followed by a ``ResNetClassifier``.

    Args:
        in_channels: number of channels of the input.
        n_classes: number of outputs of the classifier.
        depths: forwarded to the encoder.
        block_sizes: forwarded to the encoder.
    """

    def __init__(self, in_channels, n_classes, depths, block_sizes):
        super().__init__()

        self.encoder = ResNetEncoder(in_channels, depths, block_sizes)
        # The classifier input size is the channel count of the last encoder block.
        last_block = self.encoder.blocks[-1]
        self.classifier = ResNetClassifier(last_block.out_channels, n_classes)

    def forward(self, x):
        """Encode ``x`` and classify the resulting features."""
        features = self.encoder(x)
        return self.classifier(features)


In [8]:
from torchsummary import summary
import torchvision.models as models

def resnet18(in_channels, n_classes):
    """Build a ``ResNet`` with depths [2, 2, 2, 2] and widths [64, 128, 256, 512]."""
    stage_depths = [2, 2, 2, 2]
    stage_widths = [64, 128, 256, 512]
    return ResNet(in_channels, n_classes, depths=stage_depths, block_sizes=stage_widths)

def resnet34(in_channels, n_classes):
    """Build a ``ResNet`` with depths [3, 4, 6, 3] and widths [64, 128, 256, 512]."""
    stage_depths = [3, 4, 6, 3]
    stage_widths = [64, 128, 256, 512]
    return ResNet(in_channels, n_classes, depths=stage_depths, block_sizes=stage_widths)


# uncomment lines below to check if everything works out!
 

#model = resnet34(3, 1000)
#summary(model, (3, 224, 224))
#summary(models.resnet34(False), (3, 224, 224))

