<a href="https://colab.research.google.com/github/ikhlas15/ATHENS-AI-Medical-Imaging/blob/main/H10_cnn_blocks_and_advanced_architectures.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **Notebook 10: Advanced CNN Architectures (ResNet, U-Net)**

### **Course**: Artificial Intelligence in Medical Imaging: From Fundamentals to Applications

***

## **1. Introduction**

Welcome to Notebook 10! You've successfully built and trained a baseline CNN. Now, we will explore the architectural innovations that have enabled deep learning models to become incredibly deep and powerful. We will move beyond simple, sequential models and into the world of state-of-the-art network design.

The architectures we discuss today, **ResNet** and **U-Net**, are arguably the two most important and influential designs in modern medical imaging AI.

#### **What you will learn today:**
*   The motivation behind **Residual Connections** and how they mitigate the "vanishing gradient" problem, making very deep networks trainable.
*   How to build a **Residual Block**, the fundamental component of a ResNet.
*   An overview of the complete **ResNet** architecture and how to use a pre-built version from `torchvision`.
*   An introduction to the **U-Net** architecture, the gold standard for image segmentation tasks in medical imaging, and its famous "encoder-decoder" structure with "skip connections."

***

## **2. Setup: Installing and Importing Libraries**

Let's start with our usual setup.


In [None]:
# Install required packages
!pip install torch torchvision medmnist scikit-learn seaborn

import torch
import torch.nn as nn
# Functional API from torch.nn (stateless operations such as F.relu)
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision
from torchvision import transforms
from medmnist import PneumoniaMNIST
import numpy as np
import matplotlib.pyplot as plt

# Set our standard random seed and device
torch.manual_seed(42)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


***

## **3. The Challenge of Deep Networks and the Rise of Residual Connections**

In theory, a deeper neural network should be more powerful because it can learn more complex and hierarchical features. In practice, early researchers found that simply stacking more layers made **performance worse**, and not only on test data: very deep plain networks also reached higher *training* error than shallower ones, so overfitting could not be the explanation.

A key reason deep plain networks are hard to optimize is the **vanishing gradient problem**. During backpropagation, the gradient signal travels backward through every layer and is multiplied by that layer's local derivative at each step. In a very deep network, these repeated multiplications can shrink the signal exponentially, until it is effectively zero by the time it reaches the early layers. As a result, these early layers stop learning.
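
To make the shrinking signal concrete, here is a minimal sketch that stacks plain `Linear` + `Sigmoid` layers and compares the gradient magnitude at the first and last layer after one backward pass. The depth, width, and sigmoid activation are illustrative assumptions chosen to make the effect easy to see; they are not taken from the ResNet experiments.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

depth, width = 20, 32  # illustrative values (assumption)

# A "plain" network: no skip connections, just layers stacked one after another
layers = []
for _ in range(depth):
    layers += [nn.Linear(width, width), nn.Sigmoid()]
plain_net = nn.Sequential(*layers)

x = torch.randn(16, width)
loss = plain_net(x).sum()
loss.backward()

# Gradient norm of the weights in the first and the last Linear layer
first_grad = plain_net[0].weight.grad.norm().item()
last_grad = plain_net[-2].weight.grad.norm().item()  # [-1] is the final Sigmoid

print(f"first layer grad norm: {first_grad:.3e}")
print(f"last layer grad norm:  {last_grad:.3e}")
```

The first layer's gradient should come out many orders of magnitude smaller than the last layer's, which is why the early layers of such a network barely update.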

The **ResNet (Residual Network)** paper introduced a brilliantly simple solution: the **residual connection**, also known as a **skip connection**.

<br>
<img src="https://upload.wikimedia.org/wikipedia/commons/b/ba/ResBlock.png" alt="Residual Connection Diagram" width="500"/>
<br>

Instead of forcing a block of layers to learn a direct mapping from an input $x$ to an output $H(x)$, a residual connection allows the block to learn a *residual function* $F(x) = H(x) - x$. The final output is then $H(x) = F(x) + x$.

This seemingly minor change has a profound effect: it creates a "shortcut" for the gradient to flow directly through the network. If a block of layers isn't useful, the network can easily learn to make $F(x)$ zero, effectively "skipping" the block by passing the input $x$ through unchanged. This makes it much easier to train extremely deep networks (some ResNets have over 150 layers!).
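
The "skipping" behaviour can be checked directly with autograd. The sketch below is a toy stand-in, not a real ResNet block: it uses a single `Linear` layer as the residual function $F$ and sets its parameters to zero, so that $H(x) = F(x) + x$ reduces to the identity and the gradient with respect to $x$ is exactly 1 everywhere.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy residual function F (assumption: a single linear layer stands in for the block)
residual_fn = nn.Linear(8, 8)
nn.init.zeros_(residual_fn.weight)  # force F(x) = 0
nn.init.zeros_(residual_fn.bias)

x = torch.randn(3, 8, requires_grad=True)
h = residual_fn(x) + x              # H(x) = F(x) + x
h.sum().backward()

# With F(x) = 0, the block passes x through unchanged...
is_identity = torch.equal(h.detach(), x.detach())
# ...and the gradient reaches x undiminished through the shortcut
grad_is_one = torch.equal(x.grad, torch.ones_like(x))

print(is_identity, grad_is_one)
```

Because $\partial H / \partial x = \partial F / \partial x + 1$, the shortcut contributes a constant term that does not shrink with depth, whatever the layers inside $F$ are doing.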

***

## **4. Building a Residual Block from Scratch**

Let's implement the core building block of a ResNet. A basic residual block consists of:
*   Two convolutional layers with Batch Normalization and ReLU activations.
*   A skip connection that adds the input of the block to the output.
*   A special "downsample" connection to handle cases where the input and output dimensions don't match (e.g., when we reduce the spatial size or increase the number of channels).
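
Whether the shortcut needs a projection comes down to shape arithmetic. As a small sketch, the helper below (a hypothetical function, not part of PyTorch) applies the standard convolution output-size formula for one spatial dimension, assuming no dilation. It shows that a 3×3 convolution with padding 1 preserves a 32-pixel side at stride 1 and halves it at stride 2, and that a 1×1 convolution at stride 2 (a common choice for the shortcut projection, assumed here) halves it in the same way.

```python
def conv_output_size(size, kernel_size, stride, padding):
    """Spatial output size of a convolution along one dimension (dilation = 1)."""
    return (size + 2 * padding - kernel_size) // stride + 1

same = conv_output_size(32, kernel_size=3, stride=1, padding=1)       # main path, stride 1
halved = conv_output_size(32, kernel_size=3, stride=2, padding=1)     # main path, stride 2
projected = conv_output_size(32, kernel_size=1, stride=2, padding=0)  # 1x1 shortcut, stride 2

print(same, halved, projected)  # 32 16 16
```

Since the main path and the shortcut must produce tensors of identical shape before they are added, the shortcut has to use the same stride as the first convolution.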


In [None]:
class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()

        # The main path
        # 1) 3×3 conv (in -> out) with the given stride; padding=1 keeps the
        #    spatial size unchanged when stride=1
        self.conv1 = nn.Conv2d(in_channels, out_channels,
                               kernel_size=3, stride=stride, padding=1, bias=False)
        # BatchNorm receives the number of output channels
        self.bn1 = nn.BatchNorm2d(out_channels)

        # 2) Second 3×3 conv (out -> out), always with stride 1
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)

        # The skip connection path (an empty nn.Sequential acts as the identity)
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            # If dimensions change, project the input with a 1×1 conv so it matches the output dimensions
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels)
            )

    def forward(self, x):
        # Main path
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        # Add the skip connection
        out += self.shortcut(x)

        # Apply relu after the addition
        out = F.relu(out)
        return out

# --- Test the block ---
# Case 1: Dimensions stay the same
block_same_dim = ResidualBlock(in_channels=64, out_channels=64, stride=1).to(device)
input_tensor_1 = torch.randn(4, 64, 32, 32).to(device) # Batch of 4, 64 channels, 32x32
output_tensor_1 = block_same_dim(input_tensor_1)
print(f"Output shape (same dim): {output_tensor_1.shape}")

# Case 2: Dimensions change (stride=2 halves the size, channels double)
block_change_dim = ResidualBlock(in_channels=64, out_channels=128, stride=2).to(device)
input_tensor_2 = torch.randn(4, 64, 32, 32).to(device)
output_tensor_2 = block_change_dim(input_tensor_2)
print(f"Output shape (changed dim): {output_tensor_2.shape}")

Output shape (same dim): torch.Size([4, 64, 32, 32])
Output shape (changed dim): torch.Size([4, 128, 16, 16])


***

## **5. The Full ResNet Architecture**

A full ResNet model is built by stacking these residual blocks. `torchvision` provides pre-built, highly optimized implementations of famous architectures like ResNet18, ResNet34, and ResNet50.

Let's load a **ResNet18** and adapt it for our 1-channel grayscale PneumoniaMNIST dataset.


In [None]:
# Load the ResNet18 architecture
# `weights=None` means we get the architecture with randomly initialized weights
# (this replaces the deprecated `pretrained=False` argument).
resnet_model = torchvision.models.resnet18(weights=None)

# --- Adapt the model for our specific problem ---
# Original ResNet is designed for 3-channel (RGB) images. We need to change the first convolutional layer
# to accept 1-channel (grayscale) images.
# Replace the first convolution so it accepts 1 channel instead of 3
resnet_model.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)


# The original model was trained on ImageNet (1000 classes). We need to change the final fully
# connected layer to output 2 classes for our pneumonia detection task.
num_ftrs = resnet_model.fc.in_features
resnet_model.fc = nn.Linear(num_ftrs, 2)

# Move the model to the device
resnet_model.to(device)

print("--- Adapted ResNet18 Architecture ---")
# print(resnet_model) # Uncomment to see the full architecture

# Let's test it with a sample batch of 4 images
val_dataset = PneumoniaMNIST(split='val', transform=transforms.ToTensor(), download=True)
images, _ = next(iter(DataLoader(val_dataset, batch_size=4)))
# Move images to the same device as the model
images = images.to(device)
output = resnet_model(images)
print(f"\nInput shape: {images.shape}")
print(f"Output shape: {output.shape}")



--- Adapted ResNet18 Architecture ---



Input shape: torch.Size([4, 1, 28, 28])
Output shape: torch.Size([4, 2])


You can now train this ResNet model with the same training loop from Notebook 06. It is a drop-in replacement for our `BaselineCNN`, with far greater depth and capacity.

***

## **6. U-Net: The Gold Standard for Medical Image Segmentation**

While ResNet is a king of **classification**, **U-Net** is the king of **segmentation**. Segmentation is the task of classifying every single pixel in an image (e.g., "this pixel belongs to a tumor," "this one belongs to healthy tissue").
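
In tensor terms, "classifying every pixel" means the network outputs one score per class at each pixel location. The sketch below uses random tensors as stand-ins for a real network output and a real ground-truth mask (both are assumptions for illustration) to show the shapes involved.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

num_classes, height, width = 2, 28, 28  # e.g. healthy tissue vs. tumor (illustrative)

# Stand-in for a segmentation network's raw output: one score per class per pixel
logits = torch.randn(4, num_classes, height, width)
# Stand-in for ground-truth masks: one integer class label per pixel
target = torch.randint(0, num_classes, (4, height, width))

# Predicted mask: the highest-scoring class at each pixel
pred_mask = logits.argmax(dim=1)

# CrossEntropyLoss accepts [N, C, H, W] logits with [N, H, W] integer targets
loss = nn.CrossEntropyLoss()(logits, target)

print(pred_mask.shape, loss.item())
```

Compared with classification, where the output is one vector of class scores per image, the output here keeps the full spatial grid.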

The U-Net architecture is famous for its beautiful, symmetric, U-shaped structure.
<br>
<img src="https://miro.medium.com/v2/resize:fit:1400/format:webp/1*VUS2cCaPB45wcHHFp_fQZQ.png" alt="U-Net Architecture Diagram" width="700"/>

Source: https://short.upm.es/3udvl
<br>

It consists of three main components:
1.  **The Contracting Path (Encoder):** This is essentially a standard classification network (like ResNet or our baseline CNN). It takes the input image and progressively downsamples it, capturing features at different scales. The deeper it goes, the more it understands "what" is in the image, but it loses information about "where."
2.  **The Expansive Path (Decoder):** This path takes the low-resolution feature map from the encoder and progressively upsamples it, aiming to reconstruct a full-resolution segmentation map.
3.  **Skip Connections:** This is the magic of U-Net. It concatenates the feature maps from the encoder at each level with the corresponding feature maps in the decoder. This allows the decoder to use both the high-level semantic information from deep in the network *and* the fine-grained spatial information from the early layers, resulting in highly precise segmentations.
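
The three components can be sketched in a deliberately tiny model with a single downsampling level. `TinyUNet` is a hypothetical class written for illustration, not the architecture from the U-Net paper: it uses padded convolutions so that encoder and decoder feature maps line up without cropping, and the channel counts are arbitrary.

```python
import torch
import torch.nn as nn

class TinyUNet(nn.Module):
    """One-level U-Net-style sketch (hypothetical, for illustration only)."""

    def __init__(self, in_channels=1, num_classes=2, base=16):
        super().__init__()
        # Encoder: extract features, then downsample
        self.enc = nn.Sequential(nn.Conv2d(in_channels, base, 3, padding=1), nn.ReLU())
        self.pool = nn.MaxPool2d(2)
        self.bottleneck = nn.Sequential(nn.Conv2d(base, base * 2, 3, padding=1), nn.ReLU())
        # Decoder: upsample back to the input resolution
        self.up = nn.ConvTranspose2d(base * 2, base, kernel_size=2, stride=2)
        # After concatenation the decoder sees base (skip) + base (upsampled) channels
        self.dec = nn.Sequential(nn.Conv2d(base * 2, base, 3, padding=1), nn.ReLU())
        self.head = nn.Conv2d(base, num_classes, kernel_size=1)

    def forward(self, x):
        skip = self.enc(x)                       # [N, base, H, W]
        deep = self.bottleneck(self.pool(skip))  # [N, 2*base, H/2, W/2]
        up = self.up(deep)                       # [N, base, H, W]
        merged = torch.cat([skip, up], dim=1)    # skip connection: concatenate channels
        return self.head(self.dec(merged))       # [N, num_classes, H, W]

model = TinyUNet()
x = torch.randn(2, 1, 28, 28)
out = model(x)
print(out.shape)
```

Note that the skip connection here concatenates along the channel dimension, whereas the residual connection in ResNet adds tensors element-wise.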

We will not build a full U-Net today, but it is crucial to understand that the **encoder** part of a U-Net is often a pre-trained classification network like ResNet. The features learned for classifying images are highly effective for localizing objects in segmentation.


***

## **7. Summary and Next Steps**

This notebook introduced you to the architectural patterns that power modern deep learning for medical imaging. You have learned:
*   Why **Residual Connections** are so effective and how they enable the training of very deep networks.
*   How to build a **Residual Block** from scratch.
*   How to use a powerful, pre-built **ResNet** model from `torchvision` and adapt it to a new task.
*   The high-level structure of **U-Net**, the standard for medical image segmentation, and its reliance on a CNN encoder and skip connections.

**Challenge**: Take the adapted ResNet18 model from this notebook and train it on the PneumoniaMNIST dataset using the training loop from Notebook 06. Compare its final validation accuracy to the baseline CNN from Notebook 07. How much of an improvement do you see?

In the next notebook, **`11_segmentation_unet.ipynb`**, we will build and train a full U-Net model from scratch to perform a real segmentation task.
