# 🖼️ Image Segmentation: U-Net & Mask R-CNN

**Author**: Data Science Master System  
**Difficulty**: ⭐⭐⭐⭐ Advanced  
**Time**: 90 minutes  
**Prerequisites**: 12_cv_object_detection

## Learning Objectives
- Semantic vs Instance segmentation
- Implement U-Net architecture
- Use Mask R-CNN for instance segmentation
- Medical imaging applications

In [None]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt

# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Device: {device}")

## 1. Segmentation Types

In [None]:
# One-line description of each segmentation paradigm, keyed by its name.
# NOTE: this variable name shadows the stdlib `types` module in this notebook.
types = {
    'Semantic': 'Classify each pixel (no instance separation)',
    'Instance': 'Separate individual objects of same class',
    'Panoptic': 'Semantic + Instance combined'
}

# The header emoji was previously mis-encoded (UTF-8 bytes read as Mac Roman).
print("📊 Segmentation Types:")
for name, desc in types.items():
    print(f"  {name}: {desc}")

## 2. U-Net Architecture

In [None]:
class DoubleConv(nn.Module):
    """Two stacked (3x3 Conv2d -> BatchNorm2d -> ReLU) stages.

    Both convolutions use padding=1, so the spatial size is preserved and
    only the channel count changes from ``in_ch`` to ``out_ch``.

    Args:
        in_ch: Channels of the incoming feature map.
        out_ch: Channels produced by both convolutions.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        layers = []
        # First stage maps in_ch -> out_ch, second keeps out_ch -> out_ch.
        for stage_in in (in_ch, out_ch):
            layers += [
                nn.Conv2d(stage_in, out_ch, 3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            ]
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        features = self.conv(x)
        return features

class UNet(nn.Module):
    """U-Net encoder/decoder for per-pixel prediction.

    The encoder applies four ``DoubleConv`` stages (64 -> 512 channels), each
    followed by a 2x2 max-pool. The bottleneck widens to 1024 channels. The
    decoder upsamples with 2x2 transposed convolutions and concatenates the
    matching encoder feature map (skip connection) before each ``DoubleConv``.
    A 1x1 convolution maps to ``n_classes`` channels, followed by an
    element-wise sigmoid. Every output channel is therefore an independent
    probability in [0, 1].

    Args:
        n_channels: Number of input image channels.
        n_classes: Number of output channels (one sigmoid map each).

    Input height and width do not have to be multiples of 16. When a max-pool
    floors an odd size, the upsampled decoder tensor is zero-padded back to the
    skip tensor's size before concatenation, so the output has the input's
    spatial size. After four floor-halvings, H and W must still be at least 16
    for the bottleneck to be non-empty.
    """

    def __init__(self, n_channels=3, n_classes=1):
        super().__init__()
        # Encoder
        self.enc1 = DoubleConv(n_channels, 64)
        self.enc2 = DoubleConv(64, 128)
        self.enc3 = DoubleConv(128, 256)
        self.enc4 = DoubleConv(256, 512)
        self.pool = nn.MaxPool2d(2)

        # Bottleneck
        self.bottleneck = DoubleConv(512, 1024)

        # Decoder: each DoubleConv sees (upsampled + skip) channels, hence 2x.
        self.up4 = nn.ConvTranspose2d(1024, 512, 2, stride=2)
        self.dec4 = DoubleConv(1024, 512)
        self.up3 = nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.dec3 = DoubleConv(512, 256)
        self.up2 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.dec2 = DoubleConv(256, 128)
        self.up1 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.dec1 = DoubleConv(128, 64)

        self.out = nn.Conv2d(64, n_classes, 1)

    def forward(self, x):
        # Encoder: keep every stage's output for the skip connections.
        e1 = self.enc1(x)
        e2 = self.enc2(self.pool(e1))
        e3 = self.enc3(self.pool(e2))
        e4 = self.enc4(self.pool(e3))

        # Bottleneck
        b = self.bottleneck(self.pool(e4))

        # Decoder with skip connections (size-aligned before concatenation).
        d4 = self.dec4(self._up_and_concat(self.up4, b, e4))
        d3 = self.dec3(self._up_and_concat(self.up3, d4, e3))
        d2 = self.dec2(self._up_and_concat(self.up2, d3, e2))
        d1 = self.dec1(self._up_and_concat(self.up1, d2, e1))

        return torch.sigmoid(self.out(d1))

    def _up_and_concat(self, up, x, skip):
        """Upsample ``x`` with ``up``, align it to ``skip`` and concat on channels.

        The upsampled tensor comes first, matching the channel layout the
        decoder ``DoubleConv`` stages were built for.
        """
        upsampled = self._match_size(up(x), skip)
        return torch.cat([upsampled, skip], dim=1)

    @staticmethod
    def _match_size(x, ref):
        """Return ``x`` zero-padded (or cropped) so its last two dims equal ``ref``'s.

        ``x`` is returned unchanged when the sizes already agree. A negative
        padding amount makes ``pad`` crop instead of pad.
        """
        dh = ref.size(-2) - x.size(-2)
        dw = ref.size(-1) - x.size(-1)
        if dh == 0 and dw == 0:
            return x
        # pad order for the last two dims is (left, right, top, bottom).
        return nn.functional.pad(x, [dw // 2, dw - dw // 2, dh // 2, dh - dh // 2])

# Default U-Net (3 input channels, 1 output map) placed on the selected device.
unet = UNet().to(device)
n_params = sum(param.numel() for param in unet.parameters())
print(f"U-Net parameters: {n_params:,}")

## 3. Dice Loss

In [None]:
class DiceLoss(nn.Module):
    """Soft Dice loss: ``1 - (2 * sum(p * t) + s) / (sum(p) + sum(t) + s)``.

    ``pred`` and ``target`` are flattened into single vectors, so the score is
    computed over all elements of the batch at once (not averaged per sample).
    ``pred`` is presumably a probability map in [0, 1] (e.g. a sigmoid output).

    Args:
        smooth: Constant added to numerator and denominator. When both
            ``pred`` and ``target`` are all zeros, the ratio is ``s / s = 1``
            and the loss is 0 instead of a division by zero.
    """

    def __init__(self, smooth=1.):
        super().__init__()
        self.smooth = smooth

    def forward(self, pred, target):
        # reshape (unlike view) also accepts non-contiguous tensors, e.g. the
        # result of transpose/permute, copying only when it has to.
        pred = pred.reshape(-1)
        target = target.reshape(-1)
        intersection = (pred * target).sum()
        denominator = pred.sum() + target.sum()
        return 1 - (2. * intersection + self.smooth) / (denominator + self.smooth)

# Combined loss
class CombinedLoss(nn.Module):
    """Sum of binary cross-entropy and soft Dice loss.

    BCE scores every pixel independently, while Dice scores the overlap of
    the whole prediction with the target. Both terms receive the same
    ``pred`` / ``target`` pair. ``nn.BCELoss`` requires ``pred`` to be
    probabilities in [0, 1], not logits.
    """

    def __init__(self):
        super().__init__()
        self.bce = nn.BCELoss()
        self.dice = DiceLoss()

    def forward(self, pred, target):
        bce_term = self.bce(pred, target)
        dice_term = self.dice(pred, target)
        return bce_term + dice_term

# Status message; the emoji was previously mis-encoded (UTF-8 read as Mac Roman).
print("✅ Loss functions ready")

## 4. Mask R-CNN

In [None]:
from torchvision.models.detection import maskrcnn_resnet50_fpn, MaskRCNN_ResNet50_FPN_Weights

# Pretrained Mask R-CNN (ResNet-50 + FPN backbone) from torchvision.
# DEFAULT selects torchvision's recommended pretrained weights; torchvision may
# need to download them the first time this cell runs.
weights = MaskRCNN_ResNet50_FPN_Weights.DEFAULT
mask_rcnn = maskrcnn_resnet50_fpn(weights=weights).to(device)
# Inference mode: in eval mode torchvision detection models return predictions
# instead of training losses.
mask_rcnn.eval()

# Plain strings: these messages have no placeholders, so no f-prefix is needed.
print("✅ Mask R-CNN loaded")
print("Output: boxes + labels + scores + masks")

## 5. Model Comparison

In [None]:
import pandas as pd

# Quick reference table of common segmentation models, one row per model.
comparison = pd.DataFrame(
    [
        ('U-Net', 'Semantic', 'Medical', 'Fast'),
        ('DeepLabV3', 'Semantic', 'General', 'Medium'),
        ('Mask R-CNN', 'Instance', 'Object masks', 'Slow'),
        ('SAM', 'Any', 'Zero-shot', 'Slow'),
    ],
    columns=['Model', 'Task', 'Use Case', 'Speed'],
)

# `display` is provided by the IPython/Jupyter runtime.
display(comparison)

## 🎯 Key Takeaways
1. U-Net: A strong default for medical and binary segmentation
2. Mask R-CNN: Instance segmentation built on top of object detection
3. Dice Loss: More robust to foreground/background class imbalance than pixel-wise BCE alone
4. SAM: Zero-shot segmentation (state-of-the-art)

**Next**: 14_cv_generative_models.ipynb