In [1]:
#https://github.com/ava-orange-education/Mastering-Computer-Vision-with-PyTorch-2.0 

In [None]:
# ResNet in PyTorch

import torch
import torch.nn as nn

<img src="images/Resnet.png" width="1200" height="400" />

In [None]:
# 1. THE RESIDUAL BLOCK
# This is the building block. It does Conv -> BN -> ReLU -> Conv -> BN
# Then it adds the original input (x) to the result before the final ReLU.
class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1):
        super(ResidualBlock, self).__init__()
        
        # First Convolution
        # in_channels, out_channels these are passed when called
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, 
                               stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        
        # Second Convolution
         # in_channels, out_channels these are passed when called
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, 
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        
        # THE SHORTCUT (SKIP CONNECTION) LOGIC
        # If the input size (x) doesn't match the output size (due to stride or channel change),
        # we need to resize 'x' using a 1x1 convolution so we can add them together.
        self.shortcut = nn.Sequential()#<-------------- This block executes in 2nd and 3rd block of above image 
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels)
            )

    def forward(self, x):
        identity = x  # Save the original input
        
        # Pass through the weight layers
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        
        out = self.conv2(out)
        out = self.bn2(out)
        
        # --- THE MAGIC HAPPENS HERE ---
        # Add the original input (identity) to the processed output
        out += self.shortcut(identity) 
        # ------------------------------
        
        out = self.relu(out)
        return out


<img src="images/Architecture.png" width=1000 height=500 />

In [None]:
# 2. THE MAIN RESNET MODEL
class SimpleResNet(nn.Module):
    """Small ResNet-style classifier: stem conv, three residual blocks, linear head.

    Args:
        num_classes: Number of output scores produced by the final Linear layer.
        in_channels: Number of channels in the input images (e.g. 3 for RGB,
            1 for grayscale). Defaults to 3, so existing callers are unaffected.

    Input is a (N, in_channels, H, W) batch; output is (N, num_classes) raw
    scores (no softmax is applied). Because the head uses global average
    pooling, H and W are not fixed by the architecture.
    """

    def __init__(self, num_classes=10, in_channels=3):
        super().__init__()

        # Stem: 3x3 conv with stride 1 keeps the full input resolution, then BN -> ReLU.
        self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)

        # Residual stages.
        # layer1: 64 -> 64 channels, stride 1 -> identity shortcut, size unchanged.
        self.layer1 = ResidualBlock(64, 64, stride=1)
        # layer2: 64 -> 128 channels, stride 2 halves height/width (projection shortcut).
        self.layer2 = ResidualBlock(64, 128, stride=2)
        # layer3: 128 -> 256 channels, stride 2 halves height/width again.
        self.layer3 = ResidualBlock(128, 256, stride=2)

        # Head: global average pooling to (N, 256, 1, 1), then a linear classifier.
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(256, num_classes)

    def forward(self, x):
        # Stem
        out = self.relu(self.bn1(self.conv1(x)))

        # Residual stages
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)

        # Classifier head: (N, 256, 1, 1) -> (N, 256) -> (N, num_classes)
        out = self.avg_pool(out)
        out = torch.flatten(out, 1)  # flatten every dim except the batch dim
        return self.fc(out)

In [None]:
# 3. TEST IT
# Smoke-test the model on a random "image": batch size 1, 3 channels (RGB), 32x32 pixels.
torch.manual_seed(0)  # reproducible dummy input and weight init across re-runs
dummy_input = torch.randn(1, 3, 32, 32)
model = SimpleResNet(num_classes=10)

# eval(): use BatchNorm running statistics instead of overwriting them with
# stats from a single random batch; no_grad(): inference only, no autograd graph.
model.eval()
with torch.no_grad():
    output = model(dummy_input)

print(f"Input shape: {dummy_input.shape}")
print(f"Output shape: {output.shape}")
assert output.shape == (1, 10), f"unexpected output shape {tuple(output.shape)}"