# ACT-UNet: Adaptive Computation Time for U-Net Segmentation

This notebook demonstrates the ACT-enhanced U-Net with reinforcement learning-based adaptive bottleneck depth.

## Overview
- **Model**: U-Net with ACT bottleneck that learns when to stop iterative refinement
- **Dataset**: Oxford-IIIT Pets (binary segmentation)
- **Training**: Actor-Critic RL approach for learning halting policy


In [None]:
# Setup and Imports
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
import os
import warnings
warnings.filterwarnings('ignore')

# Pick the compute device: CUDA when a GPU is visible, CPU otherwise
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

# Make sure the data, checkpoint and result directories exist
for _dir_name in ('datasets', 'checkpoints', 'results'):
    os.makedirs(_dir_name, exist_ok=True)


## Data Download
Download and prepare the Oxford-IIIT Pets dataset for segmentation


In [None]:
# data download
from torchvision.datasets import OxfordIIITPet
import torchvision.transforms as transforms

print("Downloading Oxford-IIIT Pets dataset...")
print("This will download ~800MB of data on first run\n")

# Build the trainval split and then the test split with identical settings:
# segmentation targets, download enabled, images converted with ToTensor.
dataset, test_dataset = (
    OxfordIIITPet(
        root='./datasets',
        split=split_name,
        target_types='segmentation',
        download=True,
        transform=transforms.ToTensor(),
    )
    for split_name in ('trainval', 'test')
)

# Report how many samples each split contains
print(f"✓ Training samples: {len(dataset)}")
print(f"✓ Test samples: {len(test_dataset)}")
print("\nDataset downloaded successfully!")


## Load Dataset and Create DataLoaders


In [None]:
# Create data loaders
from data.pets import PetsDataset

# Loader settings used below and echoed in the summary
batch_size, image_size, num_workers = 8, (256, 256), 2

# Build the train/test loaders from the data under ./datasets
train_loader, test_loader = PetsDataset.get_data_loaders(
    root='./datasets',
    size=image_size,
    batch_size=batch_size,
    num_workers=num_workers,
    download=False,  # Already downloaded
)

# One summary line per setting
print(
    f"Train batches: {len(train_loader)}",
    f"Test batches: {len(test_loader)}",
    f"Batch size: {batch_size}",
    f"Image size: {image_size}",
    sep="\n",
)


## Visualize Sample Data


In [None]:
# Visualize samples
import matplotlib.pyplot as plt

def show_samples(loader, num_samples=4):
    """Plot image / mask / overlay rows for the first samples of one batch.

    Args:
        loader: Iterable whose batches are dicts with 'image' and 'mask'
            tensors. Images are indexed as (N, C, H, W) and masks as
            (N, 1+, H, W); only mask channel 0 is drawn.
        num_samples: Maximum number of rows to draw. Clamped to the number of
            items actually present in the batch.
    """
    batch = next(iter(loader))
    images = batch['image'][:num_samples]
    masks = batch['mask'][:num_samples]

    # The batch may hold fewer items than requested; never index past its end.
    num_samples = images.shape[0]
    if num_samples == 0:
        print("No samples to display.")
        return

    # Denormalize images for visualization
    # NOTE(review): assumes the loader normalized with these (ImageNet-style)
    # per-channel statistics — confirm against data.pets.
    mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
    images = images * std + mean
    images = torch.clamp(images, 0, 1)

    # squeeze=False keeps `axes` two-dimensional even for a single row, so the
    # axes[i, j] indexing below works when num_samples == 1.
    fig, axes = plt.subplots(
        num_samples, 3, figsize=(12, num_samples*3), squeeze=False
    )

    for i in range(num_samples):
        # Original image (CHW -> HWC for imshow)
        axes[i, 0].imshow(images[i].permute(1, 2, 0))
        axes[i, 0].set_title('Input Image')
        axes[i, 0].axis('off')

        # Ground truth mask
        axes[i, 1].imshow(masks[i, 0], cmap='gray')
        axes[i, 1].set_title('Ground Truth Mask')
        axes[i, 1].axis('off')

        # Overlay: mask drawn semi-transparently on top of the image
        axes[i, 2].imshow(images[i].permute(1, 2, 0))
        axes[i, 2].imshow(masks[i, 0], alpha=0.5, cmap='jet')
        axes[i, 2].set_title('Overlay')
        axes[i, 2].axis('off')

    plt.tight_layout()
    plt.show()

show_samples(train_loader)


## Create Models
We'll create both a standard U-Net and an ACT-enhanced U-Net for comparison.


In [None]:
# Create models
from model.unet_base import UNet
from model.act_unet import ACTUNet

# Baseline: plain U-Net, built first and then moved to the selected device
unet_baseline = UNet(n_channels=3, n_classes=1, bilinear=True)
unet_baseline = unet_baseline.to(device)

# ACT variant: same channel/class setup plus a cap on bottleneck iterations
act_unet = ACTUNet(
    n_channels=3,
    n_classes=1,
    bilinear=True,
    max_iterations=5,  # Maximum K iterations in bottleneck
)
act_unet = act_unet.to(device)

# Count parameters
def count_parameters(model):
    """Return the total element count over the model's trainable parameters.

    Only parameters with ``requires_grad`` set contribute to the sum.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

# Side-by-side size report for both networks, plus the ACT iteration cap
print(
    f"Standard U-Net parameters: {count_parameters(unet_baseline):,}",
    f"ACT-UNet parameters: {count_parameters(act_unet):,}",
    f"Max iterations (K): {act_unet.max_iterations}",
    sep="\n",
)


## Training ACT-UNet
Train the model with RL-based actor-critic updates


In [None]:
# Initialize trainer
from training.train import ACTTrainer

# Wire the ACT model, both loaders and the checkpoint directory into the trainer.
# NOTE(review): the test loader is passed as val_loader, so the test split
# doubles as the validation set here.
trainer = ACTTrainer(
    checkpoint_dir='./checkpoints',
    device=device,
    model=act_unet,
    train_loader=train_loader,
    val_loader=test_loader,
)

# Echo the trainer state read back from its attributes
print(
    "Trainer initialized!",
    f"Starting from epoch: {trainer.epoch}",
    f"Actor-Critic alternating frequency: {trainer.alternating_freq} steps",
    sep="\n",
)


In [None]:
# Train model (set small number for demo, increase for full training)
# NOTE(review): save_freq (5) is larger than num_epochs (2); if save_freq means
# "checkpoint every N epochs", this demo run writes no periodic checkpoint —
# TODO confirm against training.train.ACTTrainer.
num_epochs = 2  # Increase to 50-100 for full training

# train() returns a pair, unpacked here into training and validation metrics
train_metrics, val_metrics = trainer.train(
    num_epochs=num_epochs,
    save_freq=5,
    val_freq=1
)


## Evaluation and Visualization


In [None]:
# Evaluate model and visualize results
@torch.no_grad()
def evaluate_and_visualize(model, loader, num_samples=4):
    """Run the model on one batch and plot inputs, targets and predictions.

    Each sample gets a row of five panels: input image, ground-truth mask,
    thresholded prediction, probability map and absolute error map. A model
    that has a ``bottleneck`` attribute is treated as ACT-UNet and is called
    with ``return_act_info=True`` so its halting iterations can be shown.

    Args:
        model: Segmentation network whose output is passed through a sigmoid.
            It is switched to eval mode and left in that mode.
        loader: Iterable of dict batches with 'image' and 'mask' tensors.
        num_samples: Maximum number of samples to show. Clamped to the number
            of items actually present in the batch.
    """
    model.eval()
    batch = next(iter(loader))
    images = batch['image'][:num_samples].to(device)
    # Masks are only used for plotting, so keep them on the CPU instead of
    # moving them to the device and copying them back for every panel.
    masks = batch['mask'][:num_samples].cpu()

    # The batch may hold fewer items than requested; never index past its end.
    num_samples = images.shape[0]
    if num_samples == 0:
        print("No samples to display.")
        return

    # Get predictions with ACT info
    if hasattr(model, 'bottleneck'):  # ACT-UNet
        predictions, act_info = model(images, return_act_info=True)
        halt_iters = act_info['halt_iterations'].cpu().numpy()
    else:  # Standard U-Net
        predictions = model(images)
        halt_iters = None

    # Convert to probabilities, then threshold at 0.5 for the binary mask
    pred_probs = torch.sigmoid(predictions).cpu()
    pred_binary = (pred_probs > 0.5).float()

    # Denormalize images
    # NOTE(review): assumes the loader normalized with these (ImageNet-style)
    # per-channel statistics — confirm against data.pets.
    mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
    images_vis = torch.clamp(images.cpu() * std + mean, 0, 1)

    # squeeze=False keeps `axes` two-dimensional even for a single row, so the
    # axes[i, j] indexing below works when num_samples == 1.
    fig, axes = plt.subplots(
        num_samples, 5, figsize=(15, num_samples*3), squeeze=False
    )

    for i in range(num_samples):
        # Input image (CHW -> HWC for imshow)
        axes[i, 0].imshow(images_vis[i].permute(1, 2, 0))
        axes[i, 0].set_title('Input')
        axes[i, 0].axis('off')

        # Ground truth
        axes[i, 1].imshow(masks[i, 0], cmap='gray')
        axes[i, 1].set_title('Ground Truth')
        axes[i, 1].axis('off')

        # Prediction (title carries the halting iteration for ACT models)
        axes[i, 2].imshow(pred_binary[i, 0], cmap='gray')
        if halt_iters is not None:
            axes[i, 2].set_title(f'Prediction (k={halt_iters[i]})')
        else:
            axes[i, 2].set_title('Prediction')
        axes[i, 2].axis('off')

        # Probability map
        axes[i, 3].imshow(pred_probs[i, 0], cmap='viridis', vmin=0, vmax=1)
        axes[i, 3].set_title('Probability')
        axes[i, 3].axis('off')

        # Error map: pixels where the binary prediction and the mask differ
        error = torch.abs(pred_binary[i, 0] - masks[i, 0])
        axes[i, 4].imshow(error, cmap='hot')
        axes[i, 4].set_title('Error')
        axes[i, 4].axis('off')

    plt.tight_layout()
    plt.show()

    # Print ACT statistics if available
    if halt_iters is not None:
        print(f"Halt iterations: {halt_iters}")
        print(f"Average iterations: {halt_iters.mean():.2f}")
        print(f"Computation saved: {(1 - halt_iters.mean()/model.max_iterations)*100:.1f}%")

# Evaluate ACT-UNet
print("ACT-UNet Results:")
evaluate_and_visualize(act_unet, test_loader)


## Adaptive Depth Analysis


In [None]:
# Analyze adaptive depth distribution
@torch.no_grad()
def analyze_depth_distribution(model, loader, num_batches=10):
    """Plot and summarize how many bottleneck iterations the model uses.

    Runs the model on up to ``num_batches`` batches, collects the per-sample
    halting iteration from the returned ACT info, and uses the per-sample
    segmentation loss as a proxy for difficulty. Draws a histogram, a
    difficulty-vs-iterations scatter and a per-value bar chart, then prints
    summary statistics and the difficulty/iterations correlation.

    Args:
        model: ACT model accepting ``return_act_info=True`` and exposing
            ``max_iterations``. It is switched to eval mode.
        loader: Iterable of dict batches with 'image' and 'mask' tensors.
        num_batches: Number of batches to consume; ``None`` consumes all.
    """
    # Local imports, resolved once per call rather than once per batch.
    from itertools import islice
    from training.losses import SegmentationLoss

    model.eval()
    loss_fn = SegmentationLoss()
    all_iterations = []
    all_difficulties = []  # We'll use loss as a proxy for difficulty

    # islice stops without fetching a batch beyond the requested count.
    for batch in islice(loader, num_batches):
        images = batch['image'].to(device)
        masks = batch['mask'].to(device)

        # Get predictions with ACT info
        predictions, act_info = model(images, return_act_info=True)
        halt_iters = act_info['halt_iterations'].cpu().numpy()
        all_iterations.extend(halt_iters)

        # Compute per-sample loss as difficulty proxy
        for j in range(images.shape[0]):
            loss = loss_fn(predictions[j:j+1], masks[j:j+1]).item()
            all_difficulties.append(loss)

    all_iterations = np.array(all_iterations)
    all_difficulties = np.array(all_difficulties)

    # Nothing collected (empty loader or num_batches == 0): the statistics
    # below would fail on an empty array, so stop here.
    if all_iterations.size == 0:
        print("No samples were processed; nothing to analyze.")
        return

    # Plot distribution
    fig, axes = plt.subplots(1, 3, figsize=(15, 4))

    # Histogram of iterations. Bin edges sit on half-integers so every integer
    # count 0..K gets its own bin regardless of which values actually occur.
    # NOTE(review): whether halt iterations are 0- or 1-based is not visible
    # here, so bins and ticks cover 0..max_iterations inclusive.
    iteration_ticks = np.arange(model.max_iterations + 1)
    bin_edges = np.arange(-0.5, model.max_iterations + 1.5, 1.0)
    axes[0].hist(all_iterations, bins=bin_edges, edgecolor='black')
    axes[0].set_xlabel('Number of Iterations (k)')
    axes[0].set_ylabel('Frequency')
    axes[0].set_title('Distribution of Adaptive Depths')
    axes[0].axvline(all_iterations.mean(), color='red', linestyle='--',
                    label=f'Mean: {all_iterations.mean():.2f}')
    axes[0].set_xticks(iteration_ticks)
    axes[0].legend()

    # Difficulty vs Iterations
    axes[1].scatter(all_difficulties, all_iterations, alpha=0.5)
    axes[1].set_xlabel('Task Difficulty (Loss)')
    axes[1].set_ylabel('Iterations Used')
    axes[1].set_title('Adaptive Depth vs Difficulty')

    # Iteration counts
    unique, counts = np.unique(all_iterations, return_counts=True)
    axes[2].bar(unique, counts, edgecolor='black')
    axes[2].set_xlabel('Iterations')
    axes[2].set_ylabel('Count')
    axes[2].set_title('Iteration Usage')
    axes[2].set_xticks(iteration_ticks)

    plt.tight_layout()
    plt.show()

    # Print statistics (leading "\n" emits a blank separator line)
    print("\nAdaptive Depth Statistics:")
    print(f"Mean iterations: {all_iterations.mean():.2f}")
    print(f"Std iterations: {all_iterations.std():.2f}")
    print(f"Min iterations: {all_iterations.min()}")
    print(f"Max iterations: {all_iterations.max()}")
    print(f"Computation saved: {(1 - all_iterations.mean()/model.max_iterations)*100:.1f}%")

    # Correlation with difficulty. Pearson correlation is undefined when either
    # series is constant or has fewer than two points; say so instead of
    # printing NaN.
    if (
        all_iterations.size < 2
        or np.std(all_iterations) == 0
        or np.std(all_difficulties) == 0
    ):
        print("Correlation (difficulty vs iterations): undefined "
              "(constant values or too few samples)")
    else:
        correlation = np.corrcoef(all_difficulties, all_iterations)[0, 1]
        print(f"Correlation (difficulty vs iterations): {correlation:.3f}")

# Analyze depth distribution
if hasattr(act_unet, 'bottleneck'):
    analyze_depth_distribution(act_unet, test_loader)
