In [19]:
import torch
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter

In [20]:
class Block(nn.Module):
    """Bottleneck residual block: 1x1 reduce -> 3x3 -> 1x1 expand, plus a skip path.

    Maps ``in_channels`` input channels to ``out_channels * expansion`` output
    channels. The 3x3 convolution carries ``stride``, so any spatial
    downsampling happens there.

    Args:
        in_channels: Channels of the block's input.
        out_channels: Width of the inner convolutions; the block outputs
            ``out_channels * expansion`` channels.
        identity_downsample: Optional module applied to the skip path so it
            matches the residual path's shape (needed when ``stride != 1`` or
            ``in_channels != out_channels * expansion``). ``None`` adds the
            input unchanged.
        stride: Stride of the 3x3 convolution.
    """

    # Channel multiplier of the final 1x1 conv. Declared on the class (not only
    # set in __init__) so model builders can read ``Block.expansion`` without
    # instantiating a block; ``self.expansion`` still resolves to it.
    expansion = 4

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        identity_downsample=None,  # Used to downsample the input to match the output
        stride: int = 1,
    ):
        super().__init__()
        expanded_channels = out_channels * self.expansion

        # NOTE: convolutions keep PyTorch's default bias=True. The BatchNorm that
        # follows each one makes that bias redundant (torchvision's ResNet uses
        # bias=False), but changing it would alter parameter counts and
        # state_dict keys, so it is kept as-is.
        self.conv1 = nn.Conv2d(
            in_channels, out_channels, kernel_size=1, stride=1, padding=0
        )
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(
            out_channels, out_channels, kernel_size=3, stride=stride, padding=1
        )
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.conv3 = nn.Conv2d(
            out_channels, expanded_channels, kernel_size=1, stride=1, padding=0
        )
        self.bn3 = nn.BatchNorm2d(expanded_channels)
        self.relu = nn.ReLU()
        self.identity_downsample = identity_downsample

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``relu(residual(x) + skip(x))``."""
        identity = x

        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        # No ReLU before the addition: the activation is applied after the sum.
        out = self.bn3(self.conv3(out))

        if self.identity_downsample is not None:
            identity = self.identity_downsample(identity)

        out += identity
        return self.relu(out)

In [21]:
class ResNet(nn.Module):
    """ResNet classifier assembled from a residual block class.

    Layout: 7x7/2 stem conv -> BN -> ReLU -> 3x3/2 max-pool -> four residual
    stages (widths 64, 128, 256, 512 with strides 1, 2, 2, 2) -> adaptive
    average pool to 1x1 -> fully connected classifier.

    Args:
        block: Residual block class, called as
            ``block(in_channels, out_channels, identity_downsample, stride)`` and
            ``block(in_channels, out_channels)``. A block must output
            ``out_channels * block.expansion`` channels; if the class defines no
            ``expansion`` attribute, 4 (bottleneck) is assumed.
        layers: Blocks per stage, e.g. [3, 4, 6, 3] means 3 blocks in the first
            stage, 4 in the second, 6 in the third and 3 in the fourth. Only
            the first four entries are used.
        image_channels: Input image channels (3 for RGB, 1 for grayscale).
        num_classes: Number of classes in the dataset (classifier outputs).
    """

    def __init__(
        self,
        block,
        layers,
        image_channels: int,
        num_classes: int,
    ):
        super().__init__()
        # Running channel count; _make_layer advances it stage by stage.
        self.in_channels = 64
        self.conv1 = nn.Conv2d(image_channels, 64, kernel_size=7, stride=2, padding=3)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # ResNet stages (output channels shown for expansion 4).
        self.layer1 = self._make_layer(block, layers[0], out_channels=64, stride=1)  # 256
        self.layer2 = self._make_layer(block, layers[1], out_channels=128, stride=2)  # 512
        self.layer3 = self._make_layer(block, layers[2], out_channels=256, stride=2)  # 1024
        self.layer4 = self._make_layer(block, layers[3], out_channels=512, stride=2)  # 2048

        # AdaptiveAvgPool2d((1, 1)) always yields 1x1 maps regardless of the
        # input image size.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        # After layer4, self.in_channels == 512 * expansion (2048 for bottleneck).
        self.fc = nn.Linear(self.in_channels, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map an image batch to class logits of shape (batch, num_classes)."""
        x = self.relu(self.bn1(self.conv1(x)))
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)  # (N, C, 1, 1) -> (N, C)
        return self.fc(x)

    def _make_layer(self, block, num_residual_blocks, out_channels, stride):
        """Build one residual stage and advance ``self.in_channels``.

        The first block may change resolution (``stride``) and channel count, so
        it receives a 1x1 conv + BatchNorm projection on its skip path when
        needed. The remaining ``num_residual_blocks - 1`` blocks preserve shape.
        At least one block is always created.
        """
        expansion = getattr(block, "expansion", 4)
        stage_out_channels = out_channels * expansion

        identity_downsample = None
        if stride != 1 or self.in_channels != stage_out_channels:
            # Project the skip path so it matches the block's output shape.
            identity_downsample = nn.Sequential(
                nn.Conv2d(
                    self.in_channels,
                    stage_out_channels,
                    kernel_size=1,
                    stride=stride,
                ),
                nn.BatchNorm2d(stage_out_channels),
            )

        layers = [block(self.in_channels, out_channels, identity_downsample, stride)]
        # Every following block consumes the expanded channel count
        # (e.g. 256 -> 64 bottleneck -> 256 again).
        self.in_channels = stage_out_channels
        layers.extend(
            block(self.in_channels, out_channels)
            for _ in range(num_residual_blocks - 1)
        )
        return nn.Sequential(*layers)

In [22]:
def ResNet50(img_channel: int = 3, num_classes: int = 1000) -> ResNet:
    """Build a ResNet-50: bottleneck ``Block``s with stage depths 3, 4, 6, 3.

    Args:
        img_channel: Number of input image channels (3 for RGB, 1 for grayscale).
        num_classes: Number of classifier outputs.
    """
    stage_depths = [3, 4, 6, 3]
    return ResNet(
        block=Block,
        layers=stage_depths,
        image_channels=img_channel,
        num_classes=num_classes,
    )

In [23]:
def ResNet101(img_channel: int = 3, num_classes: int = 1000) -> ResNet:
    """Build a ResNet-101: bottleneck ``Block``s with stage depths 3, 4, 23, 3.

    Args:
        img_channel: Number of input image channels (3 for RGB, 1 for grayscale).
        num_classes: Number of classifier outputs.
    """
    stage_depths = [3, 4, 23, 3]
    return ResNet(
        block=Block,
        layers=stage_depths,
        image_channels=img_channel,
        num_classes=num_classes,
    )

In [24]:
def ResNet152(img_channel: int = 3, num_classes: int = 1000) -> ResNet:
    """Build a ResNet-152: bottleneck ``Block``s with stage depths 3, 8, 36, 3.

    Args:
        img_channel: Number of input image channels (3 for RGB, 1 for grayscale).
        num_classes: Number of classifier outputs.
    """
    stage_depths = [3, 8, 36, 3]
    return ResNet(
        block=Block,
        layers=stage_depths,
        image_channels=img_channel,
        num_classes=num_classes,
    )

In [25]:
from torchinfo import summary

# Layer-by-layer breakdown of a fresh ResNet-50 for a single 224x224 RGB image.
# input_size is (batch_size, channels, H, W).
model = ResNet50()
model_stats = summary(model, input_size=(1, 3, 224, 224), verbose=0)
print(model_stats)

Layer (type:depth-idx)                   Output Shape              Param #
ResNet                                   [1, 1000]                 --
├─Conv2d: 1-1                            [1, 64, 112, 112]         9,472
├─BatchNorm2d: 1-2                       [1, 64, 112, 112]         128
├─ReLU: 1-3                              [1, 64, 112, 112]         --
├─MaxPool2d: 1-4                         [1, 64, 56, 56]           --
├─Sequential: 1-5                        [1, 256, 56, 56]          --
│    └─Block: 2-1                        [1, 256, 56, 56]          --
│    │    └─Conv2d: 3-1                  [1, 64, 56, 56]           4,160
│    │    └─BatchNorm2d: 3-2             [1, 64, 56, 56]           128
│    │    └─ReLU: 3-3                    [1, 64, 56, 56]           --
│    │    └─Conv2d: 3-4                  [1, 64, 56, 56]           36,928
│    │    └─BatchNorm2d: 3-5             [1, 64, 56, 56]           128
│    │    └─ReLU: 3-6                    [1, 64, 56, 56]           --
│ 

In [26]:
# Export the ResNet-50 computation graph for TensorBoard under ../runs/resnet.
# The directory must be passed to the constructor: SummaryWriter opens its event
# file in __init__, so assigning `writer.log_dir` afterwards has no effect and the
# graph would land in the default ./runs/<timestamp>_<host> folder instead.
# NOTE(review): nothing is deleted here; each re-run adds another event file to
# ../runs/resnet.
# The context manager closes (and therefore flushes) the writer on exit.
with SummaryWriter(log_dir="../runs/resnet") as writer:
    writer.add_graph(ResNet50(), torch.rand(1, 3, 224, 224))