In [46]:
import torch
import torch.nn as nn
from torchsummary import summary
from torch.utils.tensorboard import SummaryWriter

In [51]:
class ResidualBlock(nn.Module):
    def __init__(self, in_channels, intermediate_channels, stride=1, identity=None):
        super(ResidualBlock, self).__init__()
        self.identity = identity
        self.conv1 = nn.Conv2d(
            in_channels=in_channels,
            out_channels=intermediate_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=False
        )
        self.bn1 = nn.BatchNorm2d(intermediate_channels)
        self.conv2 = nn.Conv2d(
            in_channels=intermediate_channels,
            out_channels=intermediate_channels,
            kernel_size=3,
            stride=stride,
            padding=1,
            bias=False
        )
        self.bn2 = nn.BatchNorm2d(intermediate_channels)
        self.conv3 = nn.Conv2d(
            in_channels=intermediate_channels,
            out_channels=intermediate_channels * 4,
            kernel_size=1,
            padding=0,
            stride=1,
            bias=False
        )
        self.bn3 = nn.BatchNorm2d(intermediate_channels * 4)
        self.relu1 = nn.ReLU()
        self.relu2 = nn.ReLU()
        self.relu3 = nn.ReLU()

    def forward(self, x):
        if self.identity:
            identity = x.clone()

        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu1(x)

        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu2(x)

        x = self.conv3(x)
        x = self.bn3(x)

        if self.identity:
            x += identity
            
        out = self.relu3(x)
        return out



In [43]:
class ResNet(nn.Module):
    def __init__(self, ResidualBlock, layers, img_channels, num_classes):
        super(ResNet, self).__init__()

        self.in_channels = 64

        self.conv1 = nn.Conv2d(
            in_channels=img_channels,
            out_channels=64,
            kernel_size=7,
            stride=2,
            padding=3
        )
        self.bn1 = nn.BatchNorm2d(64)

        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.relu = nn.ReLU()

        self.layer1 = self._make_layer(ResidualBlock, layers[0], 64, stride=1)
        self.layer2 = self._make_layer(ResidualBlock, layers[1], 128, stride=2)
        self.layer3 = self._make_layer(ResidualBlock, layers[2], 256, stride=2)
        self.layer4 = self._make_layer(ResidualBlock, layers[3], 512, stride=2)

        self.avgpool = nn.AvgPool2d(kernel_size=7)
        self.fc = nn.Linear(2048, num_classes)


    
    def _make_layer(self, ResidualBlock, num_residual_blocks, intermediate_channels, stride):

        layers = []


        layers += [ResidualBlock(self.in_channels, intermediate_channels, stride)]

        self.in_channels = intermediate_channels * 4

        for i in range(num_residual_blocks - 1):
            layers += [ResidualBlock(self.in_channels, intermediate_channels, identity=True)]

        return nn.Sequential(*layers)

        
    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
    
        x = self.avgpool(x)
        x = x.reshape(x.shape[0], -1)
        out = self.fc(x)
        return out


        

In [52]:
device = "cuda" if torch.cuda.is_available() else "cpu"

model = ResNet(ResidualBlock, [3, 4, 6, 3], img_channels=3, num_classes=1000).to(device)
x = torch.randn(10, 3, 224, 224).to(device)

writer = SummaryWriter()
writer.add_graph(model, x.to(device))

summary(model, (3, 224, 224))

print(model(x).shape)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 112, 112]           9,472
       BatchNorm2d-2         [-1, 64, 112, 112]             128
              ReLU-3         [-1, 64, 112, 112]               0
         MaxPool2d-4           [-1, 64, 56, 56]               0
            Conv2d-5           [-1, 64, 56, 56]           4,096
       BatchNorm2d-6           [-1, 64, 56, 56]             128
              ReLU-7           [-1, 64, 56, 56]               0
            Conv2d-8           [-1, 64, 56, 56]          36,864
       BatchNorm2d-9           [-1, 64, 56, 56]             128
             ReLU-10           [-1, 64, 56, 56]               0
           Conv2d-11          [-1, 256, 56, 56]          16,384
      BatchNorm2d-12          [-1, 256, 56, 56]             512
             ReLU-13          [-1, 256, 56, 56]               0
    ResidualBlock-14          [-1, 256,