In [1]:
import torch
from torch import nn

In [54]:
class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
        super(ConvBlock, self).__init__()
        
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
        self.batch_norm = nn.BatchNorm2d(num_features=out_channels)
        
    def forward(self, X):
        return self.batch_norm(self.conv(X))
    
    
class ResNetBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(ResNetBlock, self).__init__()
        
        self.stride = 1 if in_channels == out_channels else 2
        self.projection = None
        
        self.block1 = ConvBlock(in_channels, out_channels, kernel_size=3, stride=self.stride, padding=1)
        self.block2 = ConvBlock(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.relu = nn.ReLU()
        
        if self.stride == 2:
            self.projection = ConvBlock(in_channels, out_channels, kernel_size=1, stride=self.stride, padding=0)
        
    def forward(self, X):
        copy_X = X.clone()
        
        X = self.relu(self.block1(X))
        X = self.relu(self.block2(X))
        
        if self.projection is not None:
            copy_X = self.projection(copy_X)
            
        
        X += copy_X
        print(X.shape)
        return self.relu(X)
    
    
class ResNet34(nn.Module):
    _config = [(2, 64), (2, 128), (2, 256), (2, 512)]
    
    def __init__(self, in_channels=3, num_classes=1000):
        super(ResNet34, self).__init__()
        
        self.init_conv = nn.Sequential(
            nn.Conv2d(in_channels, 64, kernel_size=7, stride=2),
            nn.BatchNorm2d(num_features=64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2)
        )
        
        self.in_channels_res = 64
        self.resnet_blocks = self._create_blocks()
        
        self.avg_pool = nn.AdaptiveAvgPool2d(output_size=1)
        self.fc = nn.Linear(512, num_classes)
        
        
    def _create_blocks(self):
        blocks = []
        
        in_channels = self.in_channels_res
        for n_repeat, out_channels in ResNet34._config:
            for _ in range(n_repeat):
                blocks.append(ResNetBlock(in_channels, out_channels))
                in_channels = out_channels
                
        return nn.Sequential(*blocks)
    
    def forward(self, X):
        X = self.init_conv(X)
        
        X = self.resnet_blocks(X)
        
        X = self.avg_pool(X)
        X = X.reshape(X.shape[0], -1)
        return self.fc(X)

In [55]:
model = ResNet34()
X = torch.randn((1, 3, 224, 224))
preds = model(X)

torch.Size([1, 64, 54, 54])
torch.Size([1, 64, 54, 54])
torch.Size([1, 128, 27, 27])
torch.Size([1, 128, 27, 27])
torch.Size([1, 256, 14, 14])
torch.Size([1, 256, 14, 14])
torch.Size([1, 512, 7, 7])
torch.Size([1, 512, 7, 7])


In [53]:
preds.shape

torch.Size([1, 1000])