In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import torch.nn.functional as F
import numpy as np
from matplotlib import pyplot as plt
from PIL import Image
import os


In [None]:
# Hyperparameters
batch_size = 16
num_epochs = 50
lr = 0.001
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

# Seed every RNG this notebook draws from: NumPy (synthetic data generation)
# and PyTorch (weight initialisation, DataLoader shuffling), so that
# Restart & Run All reproduces the same results. Note: some CUDA ops are
# nondeterministic, so GPU runs may still differ slightly.
seed = 42
np.random.seed(seed)
torch.manual_seed(seed)


# U-Net: Convolutional Networks for Biomedical Image Segmentation

U-Net is a fully convolutional network architecture designed for semantic segmentation tasks, particularly in biomedical image analysis.

**Key Features:**
1. **Encoder-Decoder Structure**: Contracting path (encoder) to capture context, expanding path (decoder) to enable precise localization
2. **Skip Connections**: Concatenate feature maps from encoder to decoder to preserve fine-grained details
3. **Symmetric Architecture**: The network has a U-shaped architecture with symmetric encoder and decoder paths

**Architecture:**
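- **Encoder (Contracting Path)**: Repeated blocks of two 3x3 convolutions, each followed by a ReLU (this implementation also inserts batch normalization before each ReLU), with 2x2 max pooling between blocks to halve the spatial resolution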
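- **Decoder (Expanding Path)**: 2x upsampling — either a learned transposed convolution ("up-convolution", as in the original paper) or bilinear interpolation (the default in this notebook, `bilinear=True`) — followed by concatenation with the corresponding feature map from the encoder, then two 3x3 convolutions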
- **Final Layer**: 1x1 convolution to map feature maps to desired number of classes


In [None]:
# Double Convolution Block
class DoubleConv(nn.Module):
    """Two stacked (3x3 conv -> BatchNorm -> ReLU) stages.

    Channels flow in_channels -> mid_channels -> out_channels. When
    ``mid_channels`` is omitted (or falsy) it defaults to ``out_channels``.
    padding=1 with a 3x3 kernel keeps the spatial size unchanged.
    """
    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        mid_channels = mid_channels or out_channels
        layers = []
        for c_in, c_out in ((in_channels, mid_channels), (mid_channels, out_channels)):
            layers += [
                nn.Conv2d(c_in, c_out, kernel_size=3, padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
            ]
        # Attribute name and layer order define the state_dict keys
        # (double_conv.0, .1, .3, .4), so keep them stable.
        self.double_conv = nn.Sequential(*layers)

    def forward(self, x):
        # Batched input (N, in_channels, H, W) -> (N, out_channels, H, W)
        return self.double_conv(x)


In [None]:
# Down (Encoder) Block
class Down(nn.Module):
    """Encoder stage: 2x2 max pooling (downsamples H and W by 2, flooring) then a DoubleConv."""
    def __init__(self, in_channels, out_channels):
        super().__init__()
        pool = nn.MaxPool2d(2)
        convs = DoubleConv(in_channels, out_channels)
        # Keep the attribute name so state_dict keys stay "maxpool_conv.1.…".
        self.maxpool_conv = nn.Sequential(pool, convs)

    def forward(self, x):
        return self.maxpool_conv(x)


In [None]:
# Up (Decoder) Block
class Up(nn.Module):
    """Decoder stage: 2x upsampling, skip-connection concatenation, then DoubleConv.

    Args:
        in_channels: Channel count of the concatenated ``[x2, upsampled x1]``
            tensor fed to the DoubleConv. With ``bilinear=False`` the decoder
            input ``x1`` must itself have ``in_channels`` channels (the
            transposed conv halves it before concatenation).
        out_channels: Channels produced by the block.
        bilinear: Parameter-free bilinear upsampling (True) or a learned
            2x2 transposed convolution (False).
    """
    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()
        half = in_channels // 2
        if not bilinear:
            self.up = nn.ConvTranspose2d(in_channels, half, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)
        else:
            # Bilinear upsampling keeps the channel count, so channels are
            # reduced inside the DoubleConv via its middle width instead.
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, half)

    def forward(self, x1, x2):
        """Upsample decoder features ``x1`` and fuse them with encoder skip ``x2`` (N, C, H, W)."""
        x1 = self.up(x1)
        # Pad (negative values crop) x1 so its H/W match x2; when padding by an
        # odd amount the extra pixel goes on the right/bottom.
        diff_h = x2.size(2) - x1.size(2)
        diff_w = x2.size(3) - x1.size(3)
        x1 = F.pad(x1, [diff_w // 2, diff_w - diff_w // 2,
                        diff_h // 2, diff_h - diff_h // 2])
        # Skip connection: encoder features first, upsampled features second.
        return self.conv(torch.cat([x2, x1], dim=1))


In [None]:
# Complete U-Net Model
class UNet(nn.Module):
    """U-Net with four encoder/decoder stages that returns raw per-pixel logits.

    Args:
        n_channels: Number of input image channels.
        n_classes: Output channels of the final 1x1 conv (logits, no activation).
        bilinear: Decoder upsampling mode. With bilinear upsampling the
            bottleneck and decoder widths are halved (``factor = 2``).
    """
    def __init__(self, n_channels=1, n_classes=1, bilinear=True):
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear
        factor = 2 if bilinear else 1

        # Encoder (Contracting Path): 64 -> 128 -> 256 -> 512 -> 1024 // factor
        # (submodule names/order are kept so state_dict keys are unchanged)
        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 1024 // factor)

        # Decoder (Expanding Path): each Up fuses the previous output with an encoder skip
        self.up1 = Up(1024, 512 // factor, bilinear)
        self.up2 = Up(512, 256 // factor, bilinear)
        self.up3 = Up(256, 128 // factor, bilinear)
        self.up4 = Up(128, 64, bilinear)
        self.outc = nn.Conv2d(64, n_classes, kernel_size=1)

    def forward(self, x):
        # Encoder: keep every stage's output except the bottleneck as a skip tensor.
        skips = [self.inc(x)]
        for down in (self.down1, self.down2, self.down3):
            skips.append(down(skips[-1]))
        out = self.down4(skips[-1])

        # Decoder: consume skips deepest-first (x4, x3, x2, x1).
        for up, skip in zip((self.up1, self.up2, self.up3, self.up4), reversed(skips)):
            out = up(out, skip)
        return self.outc(out)


In [None]:
# Create synthetic dataset for demonstration
class SyntheticSegmentationDataset(Dataset):
    """Synthetic binary-segmentation dataset of noisy filled circles.

    Every ``__getitem__`` call draws 1-3 filled circles (random centre,
    radius in [10, 30), intensity in [0.3, 1.0)) on a black canvas, adds
    Gaussian noise with std 0.1 and clips the image to [0, 1]. The mask is
    1.0 inside any circle and 0.0 elsewhere.

    Args:
        num_samples: Length reported by ``__len__``.
        img_size: Height and width of the square image and mask. Must be
            greater than 40, because circle centres are drawn from
            ``[20, img_size - 20)``.
        seed: Optional non-negative int. ``None`` (default) draws from
            NumPy's global RNG, so the same index yields a different sample
            on every access. With a seed, sample ``idx`` is generated from
            ``np.random.RandomState([seed, idx])`` and is identical on every
            access (useful for a fixed validation set).

    ``__getitem__`` returns ``(img, mask)``: two ``torch.float32`` tensors
    of shape ``(1, img_size, img_size)``.
    """
    def __init__(self, num_samples=1000, img_size=128, seed=None):
        self.num_samples = num_samples
        self.img_size = img_size
        self.seed = seed

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        size = self.img_size
        # A per-index RNG keeps seeded samples reproducible regardless of
        # access order, shuffling or the number of DataLoader workers.
        rng = np.random if self.seed is None else np.random.RandomState([self.seed, idx])

        img = np.zeros((size, size), dtype=np.float32)
        mask = np.zeros((size, size), dtype=np.float32)
        # Coordinate grids are identical for every circle, so build them once.
        y, x = np.ogrid[:size, :size]

        # Add random circles (later circles overwrite earlier ones where they overlap)
        for _ in range(rng.randint(1, 4)):
            center_x = rng.randint(20, size - 20)
            center_y = rng.randint(20, size - 20)
            radius = rng.randint(10, 30)
            circle = (x - center_x) ** 2 + (y - center_y) ** 2 <= radius ** 2
            img[circle] = rng.uniform(0.3, 1.0)
            mask[circle] = 1.0

        # Add noise. rng.normal returns float64, which would silently promote
        # img to float64 and make the float32 model fail with a dtype
        # mismatch, so cast the result back to float32.
        noise = rng.normal(0, 0.1, (size, size))
        img = np.clip(img + noise, 0, 1).astype(np.float32)

        # Convert to tensors with a leading channel dimension: (1, H, W)
        return torch.from_numpy(img).unsqueeze(0), torch.from_numpy(mask).unsqueeze(0)

# Create datasets: 800 training and 200 validation samples of 128x128 images.
# NOTE: samples are generated randomly on every access, so the validation
# images differ from epoch to epoch.
train_dataset, val_dataset = (
    SyntheticSegmentationDataset(num_samples=n, img_size=128) for n in (800, 200)
)

# Shuffle only the training loader; validation is iterated in index order
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

print(f'Training samples: {len(train_dataset)}')
print(f'Validation samples: {len(val_dataset)}')


In [None]:
# Initialize model: single-channel input, one logit channel, bilinear decoder upsampling
model = UNet(n_channels=1, n_classes=1, bilinear=True)
model = model.to(device)

# Loss function (Dice Loss + BCE Loss for binary segmentation)
class DiceBCELoss(nn.Module):
    """Sum of binary cross-entropy (on logits) and soft Dice loss.

    Both terms are computed over all elements of the batch at once (the
    tensors are flattened), i.e. one Dice score for the whole batch rather
    than a mean of per-sample scores.

    Args:
        weight: Accepted for API compatibility; currently unused.
        size_average: Accepted for API compatibility; currently unused.
    """
    def __init__(self, weight=None, size_average=True):
        super(DiceBCELoss, self).__init__()

    def forward(self, inputs, targets, smooth=1):
        """Return ``BCE + (1 - Dice)`` as a scalar tensor.

        Args:
            inputs: Raw logits (sigmoid is applied internally), any shape.
            targets: Tensor with the same number of elements as ``inputs``,
                values in [0, 1]; bool/integer masks are cast to the
                logits' dtype.
            smooth: Added to the numerator and denominator of the Dice ratio
                to avoid division by zero.
        """
        # reshape (not view) so non-contiguous inputs, e.g. permuted or
        # channels-last outputs, are handled instead of raising.
        inputs = inputs.reshape(-1)
        # BCE-with-logits requires floating-point targets of matching dtype.
        targets = targets.reshape(-1).to(inputs.dtype)

        # Binary Cross Entropy
        bce = F.binary_cross_entropy_with_logits(inputs, targets, reduction='mean')

        # Soft Dice loss on probabilities
        probs = torch.sigmoid(inputs)
        intersection = (probs * targets).sum()
        dice_loss = 1 - (2. * intersection + smooth) / (probs.sum() + targets.sum() + smooth)

        return bce + dice_loss

criterion = DiceBCELoss()
optimizer = optim.Adam(params=model.parameters(), lr=lr)

# Summarise the architecture and its size
print(model)
print(f'Total parameters: {sum(map(torch.numel, model.parameters())):,}')


In [None]:
# Training function
def train_epoch(model, train_loader, optimizer, criterion, device):
    """Run one optimisation pass over ``train_loader``.

    Puts the model in training mode, takes one optimizer step per batch and
    prints the batch loss every 20 batches.

    Returns:
        The mean of the per-batch losses (sum of losses / number of batches).
    """
    model.train()
    loss_total = 0.0

    for step, (images, masks) in enumerate(train_loader, start=1):
        images, masks = images.to(device), masks.to(device)

        optimizer.zero_grad()
        loss = criterion(model(images), masks)
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        loss_total += batch_loss
        if step % 20 == 0:
            print(f'Batch [{step}/{len(train_loader)}], Loss: {batch_loss:.4f}')

    return loss_total / len(train_loader)

# Validation function
def validate(model, val_loader, criterion, device):
    """Evaluate ``model`` on ``val_loader`` without gradient tracking.

    Returns:
        ``(mean_loss, accuracy)``. ``mean_loss`` is the mean per-batch loss;
        ``accuracy`` is the percentage of pixels where the thresholded
        prediction (sigmoid > 0.5) matches the thresholded mask (> 0.5).
    """
    model.eval()
    loss_total = 0.0
    matched = 0
    counted = 0

    with torch.no_grad():
        for images, masks in val_loader:
            images, masks = images.to(device), masks.to(device)
            outputs = model(images)
            loss_total += criterion(outputs, masks).item()

            agree = (torch.sigmoid(outputs) > 0.5) == (masks > 0.5)
            matched += agree.sum().item()
            counted += masks.numel()

    accuracy = 100. * matched / counted
    return loss_total / len(val_loader), accuracy


In [None]:
# Training loop: one training pass and one validation pass per epoch,
# recording the history for the plots below.
train_losses, val_losses, val_accuracies = [], [], []

for epoch in range(num_epochs):
    print(f'\nEpoch [{epoch+1}/{num_epochs}]', '-' * 50, sep='\n')

    train_loss = train_epoch(model, train_loader, optimizer, criterion, device)
    train_losses.append(train_loss)

    val_loss, val_acc = validate(model, val_loader, criterion, device)
    val_losses.append(val_loss)
    val_accuracies.append(val_acc)

    print(f'Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, Val Accuracy: {val_acc:.2f}%')


In [None]:
# Plot training curves: losses (left) and validation pixel accuracy (right).
# Two panels only — the previous 1x3 grid left an empty third slot.
epochs = range(1, len(train_losses) + 1)  # 1-based, matching the "Epoch [k/N]" log lines
fig, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(12, 5))

ax_loss.plot(epochs, train_losses, label='Train Loss')
ax_loss.plot(epochs, val_losses, label='Val Loss')
ax_loss.set(xlabel='Epoch', ylabel='Loss', title='Training and Validation Loss')
ax_loss.legend()
ax_loss.grid(True)

ax_acc.plot(epochs, val_accuracies, label='Val Accuracy')
ax_acc.set(xlabel='Epoch', ylabel='Accuracy (%)', title='Validation Accuracy')
ax_acc.legend()
ax_acc.grid(True)

fig.tight_layout()
plt.show()


In [None]:
# Visualize predictions
def visualize_predictions(model, data_loader, num_samples=4, device='cpu'):
    """Plot input, ground-truth mask and thresholded prediction for the first batch.

    Args:
        model: Segmentation model returning per-pixel logits.
        data_loader: Iterable yielding ``(images, masks)`` batches; only the
            first batch is used.
        num_samples: Maximum number of rows to plot; clamped to the size of
            the first batch.
        device: Device the model lives on.
    """
    model.eval()
    with torch.no_grad():
        images, masks = next(iter(data_loader))
        # The first batch can be smaller than num_samples (small batch/dataset).
        num_samples = min(num_samples, images.size(0))
        images = images[:num_samples].to(device)
        masks = masks[:num_samples].to(device)
        preds = torch.sigmoid(model(images)) > 0.5

    # squeeze=False keeps `axes` 2-D even for a single row, so axes[i, j] always works.
    fig, axes = plt.subplots(num_samples, 3, figsize=(12, 4 * num_samples), squeeze=False)
    panels = (('Input Image', images), ('Ground Truth', masks), ('Prediction', preds))
    for i in range(num_samples):
        for j, (title, batch) in enumerate(panels):
            ax = axes[i, j]
            ax.imshow(batch[i].cpu().squeeze(), cmap='gray')
            ax.set_title(title)
            ax.axis('off')

    plt.tight_layout()
    plt.show()

# Show input / ground truth / prediction for images from the first validation batch
visualize_predictions(model, val_loader, device=device, num_samples=4)


In [None]:
# Calculate IoU (Intersection over Union) metric
def calculate_iou(pred, target, threshold=0.5):
    """Return the IoU between thresholded logits ``pred`` and mask ``target``.

    ``pred`` holds raw logits: pixels with sigmoid(pred) > threshold count as
    foreground. When both prediction and target are empty (union of zero)
    the IoU is defined as 1.0.
    """
    pred_mask = (torch.sigmoid(pred) > threshold).float()
    target_mask = target.float()

    overlap = (pred_mask * target_mask).sum()
    union = pred_mask.sum() + target_mask.sum() - overlap

    return 1.0 if union == 0 else (overlap / union).item()

# Evaluate per-image IoU over the whole validation set
model.eval()
ious = []
with torch.no_grad():
    for images, masks in val_loader:
        images, masks = images.to(device), masks.to(device)
        outputs = model(images)
        # Iterating a tensor walks its first (batch) dimension.
        ious.extend(calculate_iou(out, tgt) for out, tgt in zip(outputs, masks))

mean_iou = np.mean(ious)
print(f'Mean IoU on validation set: {mean_iou:.4f}')
print(f'IoU range: [{min(ious):.4f}, {max(ious):.4f}]')
