In [1]:
import torch
from torch import nn
from torch.nn import functional as F

In [None]:
class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, expansion, downsample=None, stride=1):
        super(ResidualBlock, self).__init__()

        layers = []
        if expansion == 1:
            layers.append(
                nn.Sequential(
                    nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False),
                    nn.BatchNorm2d(out_channels),
                    nn.ReLU(inplace=True)
                )
            )

            layers.append(
                nn.Sequential(
                    nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False),
                    nn.BatchNorm2d(out_channels)
                )
            )

        else:
            layers.append(
                nn.Sequential(
                    nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False),
                    nn.BatchNorm2d(out_channels),
                    nn.ReLU(inplace=True)
                )
            )

            layers.append(
                nn.Sequential(
                    nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False),
                    nn.BatchNorm2d(out_channels),
                    nn.ReLU(inplace=True)
                )
            )

            layers.append(
                nn.Sequential(
                    nn.Conv2d(out_channels, out_channels * expansion, kernel_size=1, stride=1, padding=0, bias=False),
                    nn.BatchNorm2d(out_channels * expansion)
                )
            )

        self.layers = nn.Sequential(*layers)

        self.downsample = downsample
        self.stride = stride
        self.relu = nn.ReLU()
            
    def forward(self, x):
        residual = x

        out = self.layers(x)

        if self.downsample is not None:
            residual = self.downsample(x)
        out += residual
        out = self.relu(out)

        return out
    
class ResNet(nn.Module):
    def __init__(self, block, block_depth=[2, 2, 2, 2], expansion=1):
        super(ResNet, self).__init__()
        
        self.in_channels = 64

        self.stem = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        )

        self.stage_1 = self._make_layer(block, 64, block_depth[0], expansion)
        self.stage_2 = self._make_layer(block, 128, block_depth[1], expansion, stride=2)
        self.stage_3 = self._make_layer(block, 256, block_depth[2], expansion, stride=2)
        self.stage_4 = self._make_layer(block, 512, block_depth[3], expansion, stride=2)
        
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * expansion, 10)
        
    def _make_layer(self, block, out_channels, blocks, expansion, stride=1):
        downsample = None
        if stride != 1 or self.in_channels != out_channels * expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.in_channels, out_channels * expansion, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels * expansion)
            )

        layers = []
        layers.append(block(self.in_channels, out_channels, expansion, downsample, stride))
        self.in_channels = out_channels * expansion
        for _ in range(1, blocks):
            layers.append(block(self.in_channels, out_channels, expansion))

        return nn.Sequential(*layers)
    
    def forward(self, x):
        out = self.stem(x)
        out = self.stage_1(out)
        out = self.stage_2(out)
        out = self.stage_3(out)
        out = self.stage_4(out)
        out = self.avg_pool(out)
        out = out.view(out.size(0), -1)
        out = self.fc(out)
        
        return out
    
class ResNet18(ResNet):
    def __init__(self, block):
        super(ResNet18, self).__init__(block)

class ResNet34(ResNet):
    def __init__(self, block):
        super(ResNet34, self).__init__(block, block_depth=[3, 4, 6, 3])

class ResNet50(ResNet):
    def __init__(self, block):
        super(ResNet50, self).__init__(block, block_depth=[3, 4, 6, 3], expansion=4)

class ResNet101(ResNet):
    def __init__(self, block):
        super(ResNet101, self).__init__(block, block_depth=[3, 4, 23, 3], expansion=4)

class ResNet152(ResNet):
    def __init__(self, block):
        super(ResNet152, self).__init__(block, block_depth=[3, 8, 36, 3], expansion=4)

In [27]:
def test_vgg():
    model = ResNet50(ResidualBlock)
    x = torch.randn(4, 3, 224, 224)
    out = model(x)
    print(out.shape)
    print(model)

test_vgg()

torch.Size([4, 64, 56, 56])
torch.Size([4, 256, 56, 56])
torch.Size([4, 512, 28, 28])
torch.Size([4, 1024, 14, 14])
torch.Size([4, 2048, 7, 7])
torch.Size([4, 10])
ResNet50(
  (conv1): Sequential(
    (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  )
  (conv2_x): Sequential(
    (0): ResidualBlock(
      (layers): Sequential(
        (0): Sequential(
          (0): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
        )
        (1): Sequential(
          (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_r