In [7]:
import torch
from torch import nn


## The Hyperparams, as given in https://arxiv.org/pdf/1512.03385 for CIFAR-10
class Residual_Block(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()

        self.residual_core = nn.Sequential(
            # Dimensionality Reduction - Conv 1x1
            nn.Conv2d(in_channels, out_channels // 2, 1),
            nn.BatchNorm2d(out_channels // 2),

            nn.LeakyReLU(),
            
            # Feature Extraction
            nn.Conv2d(out_channels // 2, out_channels // 2, 3),
            nn.BatchNorm2d(out_channels // 2),
            
            nn.LeakyReLU(),
            
            # Dimensionality Expansion
            nn.Conv2d(out_channels // 2, out_channels, 1),
            nn.BatchNorm2d(out_channels // 2),
        )
    
    def forward(self, x):
        residual = x
        x = self.residual_core(x)
        x = nn.LeakyReLU()(x + residual)
        return x

class ResNet(nn.Module):
    def __init__(self, n):
        super().__init__()
        self.conv_1 = Residual_Block(3, 16)
        self.list_1 = nn.ModuleList([Residual_Block(16, 16) for i in range(2 * n)])
        self.max_pool1 = nn.MaxPool2d(2, stride = 2)
        self.conv_2 = Residual_Block(16, 32)
        self.list_2 = nn.ModuleList([Residual_Block(32, 32) for i in range(2 * n - 1)])
        self.max_pool2 = nn.MaxPool2d(2, stride = 2)
        self.conv_3 = Residual_Block(32, 64)
        self.list_3 = nn.ModuleList([Residual_Block(64, 64) for i in range(2 * n)])
        self.avg_pool = nn.AvgPool2d(1)
        self.finLinLay = nn.Linear(64, 200)
        self.softmax = nn.Softmax(dim = -1)
    
    def forward(self, x):
        x = self.conv_1(x)
        x = self.list_1(x)
        x = self.max_pool1(x)
        x = self.conv_2(x)
        x = self.list_2(x)
        x = self.max_pool2(x)
        x = self.conv_3(x)
        x = self.list_3(x)
        x = self.avg_pool(x)
        x = self.finLinLay(x)
        x = self.softmax(x)
        return x




In [9]:
model = ResNet(50)
total_params = sum(p.numel() for p in model.parameters())
total_params

1829360

In [None]:
optimizer = torch.optim.Adam(model.parameters())