# Problem: Train a 3D CNN for segmenting CT images

### Problem Statement

You are tasked with implementing and evaluating a 3D CNN model in PyTorch for semantic segmentation of synthetically generated CT images.
First, review the input and label data shapes. Then define a `MedCNN` model class whose `forward` method follows an encoder-decoder architecture, with input and output channels chosen to match those shapes.

### Requirements

1. **Implement** a `MedCNN` model class with `Conv3d` and `ConvTranspose3d` layers for downsampling and upsampling respectively.
2. **Define** Dice loss for the problem.
3. **Perform** transfer learning from a pretrained ResNet18 backbone, a common strategy for custom architectures.
4. **Train** the model for 5 epochs.

### Constraints

- Use PyTorch's built-in convolution layers
- Ensure there is a segmentation head at the end of the network

<details>
  <summary>💡 Hint</summary>
  - Strip off the average-pooling (`avgpool`) and linear (`fc`) layers from ResNet18 using `list(resnet_model.children())[:-2]`
  <br>
  - [Conv3d](https://pytorch.org/docs/stable/generated/torch.nn.Conv3d.html)
  <br>
  - [ConvTranspose3d](https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose3d.html)
  <br>
  - [Forum discussion on model.children](https://discuss.pytorch.org/t/module-children-vs-module-modules/4551)
</details>


In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms



In [2]:
# Generate synthetic CT-scan data [batch, slices, channels, width, height] and matching binary segmentation masks
torch.manual_seed(42)
batch = 100
num_slices = 10
channels = 3
width = 256
height = 256

ct_images = torch.randn(size=(batch, num_slices, channels, width, height))
segmentation_masks = (
    torch.randn(size=(batch, num_slices, 1, width, height)) > 0
).float()

print(f"CT images (train examples) shape: {ct_images.shape}")
print(f"Segmentation binary masks (labels) shape: {segmentation_masks.shape}")

CT images (train examples) shape: torch.Size([100, 10, 3, 256, 256])
Segmentation binary masks (labels) shape: torch.Size([100, 10, 1, 256, 256])


In [3]:
# Define the MedCNN class and its forward method
class MedCNN(nn.Module):
    def __init__(self, backbone, out_channel=1):
        super(MedCNN, self).__init__()
        self.backbone = backbone

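        # TODO: Add "downsampling" convolutional layers. Note: with stride 1 and padding 1
        # these reduce the channel count (512 -> 256 -> 128) but keep D, W, H unchanged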
        self.conv1 = nn.Conv3d(512, 256, kernel_size=3, padding=1)
        self.conv2 = nn.Conv3d(256, 128, kernel_size=3, padding=1)

        # TODO: Add Upsample convolutional layers
        self.conv_transpose1 = nn.ConvTranspose3d(
            128, 64, kernel_size=(1, 4, 4), stride=(1, 4, 4)
        )
        self.conv_transpose2 = nn.ConvTranspose3d(
            64, 16, kernel_size=(1, 8, 8), stride=(1, 8, 8)
        )

        # TODO: Final convolution layer from 16 to 1 channel
        self.conv3 = nn.Conv3d(16, out_channel, kernel_size=1)

        self.relu = nn.ReLU()

    def forward(self, x):
        b, d, c, w, h = x.size()  # Input size: [B, D, C, W, H]
        print(f"Input shape [B, D, C, W, H]: {b, d, c, w, h}")

        # TODO: make changes to the shape of the input such that it is compatible with ResNet
        x = x.view(-1, c, w, h)
        y1 = self.backbone(x)
        print(f"ResNet output shape [B, C, W, H]: {y1.shape}")

        # TODO: take output features from the backbone ResNet and make it compatible with Conv3D format
        _, new_c, new_w, new_h = y1.shape
        y1 = y1.view(b, d, new_c, new_w, new_h)
        y1 = y1.permute(0, 2, 1, 3, 4)
        y1 = self.relu(y1)
        print(f"ResNet output shape [B, C, D, W, H]: {y1.shape}")

        # TODO: Downsampling
        y2 = self.relu(self.conv2(self.relu(self.conv1(y1))))
        print(f"Downsampled output shape [B, C, D, W, H]: {y2.shape}")

        # TODO: Upsampling
        y3 = self.relu(self.conv_transpose2(self.relu(self.conv_transpose1(y2))))
        print(f"Upsampled output shape [B, C, D, W, H]: {y3.shape}")

        # TODO: final segmentation head
        y4 = torch.sigmoid(self.conv3(y3))
        y4 = y4.permute(0, 2, 1, 3, 4)
        print(f"Output shape [B, D, C, W, H]: {y4.shape}")

        return y4
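
The shape handling in `forward` can be verified on small tensors. The sketch below is an illustration under stated assumptions, not the notebook's model: `conv_transpose_out` is a hypothetical helper that applies the output-size formula from the `ConvTranspose3d` documentation, and the channel counts are shrunk to keep it cheap to run.

```python
import torch
import torch.nn as nn

def conv_transpose_out(size, kernel, stride, padding=0, output_padding=0, dilation=1):
    """Output length of a transposed convolution along one dimension."""
    return (size - 1) * stride - 2 * padding + dilation * (kernel - 1) + output_padding + 1

# Fold the slice dimension into the batch so a 2D backbone can process each slice
b, d, c, w, h = 2, 10, 3, 16, 16
x = torch.arange(b * d * c * w * h, dtype=torch.float32).view(b, d, c, w, h)
folded = x.view(-1, c, w, h)  # [B*D, C, W, H]; slice j of sample i sits at index i*d + j

# Unfold again and move channels ahead of depth, the layout Conv3d expects
unfolded = folded.view(b, d, c, w, h).permute(0, 2, 1, 3, 4)  # [B, C, D, W, H]

# Same kernel and stride settings as the model, with fewer channels (assumption)
up1 = nn.ConvTranspose3d(4, 2, kernel_size=(1, 4, 4), stride=(1, 4, 4))
up2 = nn.ConvTranspose3d(2, 1, kernel_size=(1, 8, 8), stride=(1, 8, 8))
with torch.no_grad():
    out = up2(up1(torch.randn(1, 4, 10, 8, 8)))

print(conv_transpose_out(8, 4, 4), conv_transpose_out(32, 8, 8))
print(out.shape)
```

Because the kernel size equals the stride in both layers, each one multiplies width and height by that factor (8 to 32, then 32 to 256), while the depth kernel and stride of 1 leave the 10 slices untouched.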

In [4]:
# TODO: define Dice loss
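def compute_dice_loss(pred, labels, eps=1e-8):
    """
    Args
    pred: predicted probabilities, shape [B, D, 1, W, H]
    labels: binary ground-truth masks, shape [B, D, 1, W, H]

    Returns
    dice_loss: scalar tensor, 1 - Dice coefficient computed over all voxels
    """
    # reshape (rather than view) also works if the input is non-contiguous,
    # which can happen because pred comes out of a permute
    pred_flat = pred.reshape(-1)
    labels_flat = labels.reshape(-1)
    intersection = (pred_flat * labels_flat).sum()
    # Dice denominator: the sum of both masses, not a set union
    denominator = pred_flat.sum() + labels_flat.sum()
    dice = (2.0 * intersection + eps) / (denominator + eps)
    return 1 - dice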
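
A few hand-checkable cases make the behavior of this loss concrete. The sketch below re-implements the same formula under the hypothetical name `dice_loss_sketch` so that it runs on its own; the four-element tensors are made up for illustration.

```python
import torch

def dice_loss_sketch(pred, labels, eps=1e-8):
    """1 - Dice coefficient over all elements; returns a scalar tensor."""
    pred_flat = pred.reshape(-1)
    labels_flat = labels.reshape(-1)
    intersection = (pred_flat * labels_flat).sum()
    denominator = pred_flat.sum() + labels_flat.sum()
    return 1 - (2.0 * intersection + eps) / (denominator + eps)

labels = torch.tensor([1.0, 0.0, 1.0, 0.0])

# Perfect overlap: Dice = 2*2 / (2+2) = 1, so the loss is 0
perfect = dice_loss_sketch(labels.clone(), labels)

# One of two positives matched: Dice = 2*1 / (2+2) = 0.5, so the loss is 0.5
half = dice_loss_sketch(torch.tensor([1.0, 1.0, 0.0, 0.0]), labels)

# Predicting 1 everywhere on half-positive labels: Dice = 2*2 / (4+2) = 2/3, loss 1/3
all_ones = dice_loss_sketch(torch.ones(4), labels)

# Both empty: eps keeps the ratio defined (eps / eps = 1), so the loss is 0
empty = dice_loss_sketch(torch.zeros(4), torch.zeros(4))

print(perfect.item(), half.item(), all_ones.item(), empty.item())
```

The all-ones case is worth noting here: the masks in this notebook are random with roughly half the voxels positive, so a model that outputs values near 1 everywhere would score a loss near 1/3. That is close to the 0.3333 logged at epoch 1, which may indicate near-constant predictions rather than learned structure.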

In [5]:
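# Define ResNet18 as the backbone, removing its last two layers (avgpool and fc).
# `weights=` replaces the deprecated `pretrained=True` (torchvision >= 0.13).
resnet_model = torchvision.models.resnet18(
    weights=torchvision.models.ResNet18_Weights.DEFAULT
)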
resnet_model = nn.Sequential(*list(resnet_model.children())[:-2])

model = MedCNN(backbone=resnet_model)

optimizer = optim.Adam(model.parameters(), lr=0.01)
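
The cell above fine-tunes every parameter, backbone included. A common transfer-learning variant freezes the pretrained backbone and trains only the new layers. The following is a minimal sketch of that alternative, not what this notebook does: `backbone` and `head` are tiny hypothetical stand-ins for the stripped ResNet18 and the new layers, and the learning rate is an arbitrary choice.

```python
import torch
import torch.nn as nn
import torch.optim as optim

# Hypothetical stand-ins: a "pretrained" feature extractor and a new trainable head
backbone = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3, padding=1), nn.ReLU())
head = nn.Conv2d(8, 1, kernel_size=1)

# Freeze the backbone so no gradients are computed for its weights
for param in backbone.parameters():
    param.requires_grad = False

# Hand only the still-trainable parameters to the optimizer
all_params = list(backbone.parameters()) + list(head.parameters())
trainable = [p for p in all_params if p.requires_grad]
optimizer = optim.Adam(trainable, lr=1e-3)

# One training step: gradients reach the head but not the frozen backbone
x = torch.randn(2, 3, 16, 16)
loss = head(backbone(x)).mean()
optimizer.zero_grad()
loss.backward()
optimizer.step()

print(len(trainable), backbone[0].weight.grad is None, head.weight.grad is not None)
```

Freezing cuts memory and compute for the backward pass and protects the pretrained features, at the cost of less capacity to adapt to the new data. Whether it helps here is an open question that the source does not test.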



In [6]:
epochs = 5
for epoch in range(epochs):
    optimizer.zero_grad()
    pred = model(ct_images)
    loss = compute_dice_loss(pred, segmentation_masks)
    loss.backward()
    optimizer.step()
    print(f"Loss at epoch {epoch}: {loss.item()}")

Input shape [B, D, C, W, H]: (100, 10, 3, 256, 256)
ResNet output shape [B, C, W, H]: torch.Size([1000, 512, 8, 8])
ResNet output shape [B, C, D, W, H]: torch.Size([100, 512, 10, 8, 8])
Downsampled output shape [B, C, D, W, H]: torch.Size([100, 128, 10, 8, 8])
Upsampled output shape [B, C, D, W, H]: torch.Size([100, 16, 10, 256, 256])
Output shape [B, D, C, W, H]: torch.Size([100, 10, 1, 256, 256])
Loss at epoch 0: 0.48381954431533813
Input shape [B, D, C, W, H]: (100, 10, 3, 256, 256)
ResNet output shape [B, C, W, H]: torch.Size([1000, 512, 8, 8])
ResNet output shape [B, C, D, W, H]: torch.Size([100, 512, 10, 8, 8])
Downsampled output shape [B, C, D, W, H]: torch.Size([100, 128, 10, 8, 8])
Upsampled output shape [B, C, D, W, H]: torch.Size([100, 16, 10, 256, 256])
Output shape [B, D, C, W, H]: torch.Size([100, 10, 1, 256, 256])
Loss at epoch 1: 0.3333361744880676
Input shape [B, D, C, W, H]: (100, 10, 3, 256, 256)
ResNet output shape [B, C, W, H]: torch.Size([1000, 512, 8, 8])
ResNet 