<a href="https://colab.research.google.com/github/ganji759/Flood-Prediction-Using-Machine-Learning/blob/main/Hackathon/Image_Segmentation_Challenge.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **Hackathon IndabaX DRC - UPN**
---


In this challenge you will fill in the missing code cells. The main goal is to train your machine learning model and then evaluate it.

This notebook implements a U-Net architecture with Generalized Divisive Normalization (GDN) for semantic segmentation on the Cityscapes dataset.

The winner will be chosen based on:
- Code complexity
- Features
- Model consistency

## Requirements:
- PyTorch
- torchvision
- PIL
- numpy
- matplotlib
- scikit-learn
- tqdm

Import PyTorch together with its neural-network, dataset, and data-loader utilities.

In [None]:

# %% Import libraries
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

In [None]:
import numpy as np
import random
import os
from PIL import Image
import matplotlib.pyplot as plt
from sklearn.metrics import jaccard_score
from tqdm.auto import tqdm
import warnings
warnings.filterwarnings('ignore')

In [None]:

# Report the installed PyTorch build and whether a CUDA device is usable.
print(
    f"PyTorch version: {torch.__version__}",
    f"CUDA available: {torch.cuda.is_available()}",
    sep="\n",
)


In [None]:

# %% Set random seeds for reproducibility
def set_seed(seed=42):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random`` module, NumPy and PyTorch (CPU, plus every CUDA
    device when CUDA is available), puts cuDNN into deterministic mode with
    benchmarking switched off, and stores the seed in the ``PYTHONHASHSEED``
    environment variable.

    Args:
        seed: Integer seed shared by all generators.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)

    # Host-side generators all accept the seed as their single argument.
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    # Trade cuDNN auto-tuning for repeatable kernel selection.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


set_seed(42)

In [None]:
# %% Check GPU availability
# Use the CUDA device when one is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")

# Only query GPU details when the CUDA device was actually selected.
if device.type == 'cuda':
    print(
        f"GPU: {torch.cuda.get_device_name(0)}",
        f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB",
        sep="\n",
    )

In [None]:
# %% [markdown]
# ## Dataset Setup

# %% Setup Kaggle API (if needed)
def setup_kaggle_api(username, key):
    """Write Kaggle API credentials to ``~/.kaggle/kaggle.json``.

    The directory is created when missing, the credentials are stored under
    the ``"username"`` and ``"key"`` fields (the field names the Kaggle
    credential file uses), and the file is restricted to owner read/write.

    Args:
        username: Kaggle account name.
        key: Kaggle API key belonging to that account.
    """
    # Imported locally because the notebook's import cells never load json.
    import json

    kaggle_dir = os.path.join(os.path.expanduser("~"), ".kaggle")
    os.makedirs(kaggle_dir, exist_ok=True)

    kaggle_json = os.path.join(kaggle_dir, "kaggle.json")
    with open(kaggle_json, "w") as f:
        # The field must be named "key"; any other name is not picked up as
        # the API token.
        json.dump({"username": username, "key": key}, f)

    # Keep the secret readable by the owner only.
    os.chmod(kaggle_json, 0o600)
    print("Kaggle API configured successfully!")

In [None]:
# Uncomment and fill in your credentials if needed
# setup_kaggle_api("your_username", "your_api_key")

# %% Download dataset
import kagglehub

# Download Cityscapes dataset
# Fetch the dataset through kagglehub; on any failure report the error and
# continue with a hard-coded fallback directory instead of stopping.
try:
    path = kagglehub.dataset_download("shuvoalok/cityscapes")
except Exception as download_error:
    print(f"Error downloading dataset: {download_error}")
    path = "/kaggle/input/cityscapes"  # Fallback path
else:
    print(f"Dataset downloaded to: {path}")

In [None]:
# %% Define dataset paths

DATASET_ROOT = path
train_images_folder_path = os.path.join(DATASET_ROOT, "train", "img")
train_mask_folder_path = os.path.join(DATASET_ROOT, "train", "label")
test_images_folder_path = os.path.join(DATASET_ROOT, "val", "img")
test_mask_folder_path = os.path.join(DATASET_ROOT, "val", "label")

# Verify paths exist.
# The loop variables deliberately avoid the name `path`: reusing it would
# overwrite the dataset root read above, so re-running this cell would build
# every folder path from the last folder checked.
for split_name, split_dir in [
    ("Train images", train_images_folder_path),
    ("Train masks", train_mask_folder_path),
    ("Test images", test_images_folder_path),
    ("Test masks", test_mask_folder_path)
]:
    if os.path.isdir(split_dir):
        # Case-insensitive match so '.PNG' files are counted as well.
        num_files = sum(
            1 for f in os.listdir(split_dir) if f.lower().endswith('.png')
        )
        print(f"✓ {split_name}: {num_files} files")
    else:
        print(f"✗ {split_name}: Path not found!")

In [None]:
# %% [markdown]
# ## Constants and Configuration

# %% Define constants
# Input resolution fed to the network.
IMG_HEIGHT = 96
IMG_WIDTH = 256

# Training hyper-parameters.
NUM_CLASSES = 30
BATCH_SIZE = 32
LEARNING_RATE = 1e-3
NUM_EPOCHS = 300
NUM_WORKERS = 2

# One (name, RGB colour) entry per class; the list position is the class index.
# Note: 'license plate' carries the same colour as 'car'.
_CLASS_PALETTE = [
    ('unlabeled', (0, 0, 0)),
    ('dynamic', (111, 74, 0)),
    ('ground', (81, 0, 81)),
    ('road', (128, 64, 128)),
    ('sidewalk', (244, 35, 232)),
    ('parking', (250, 170, 160)),
    ('rail track', (230, 150, 140)),
    ('building', (70, 70, 70)),
    ('wall', (102, 102, 156)),
    ('fence', (190, 153, 153)),
    ('guard rail', (180, 165, 180)),
    ('bridge', (150, 100, 100)),
    ('tunnel', (150, 120, 90)),
    ('pole', (153, 153, 153)),
    ('traffic light', (250, 170, 30)),
    ('traffic sign', (220, 220, 0)),
    ('vegetation', (107, 142, 35)),
    ('terrain', (152, 251, 152)),
    ('sky', (70, 130, 180)),
    ('person', (220, 20, 60)),
    ('rider', (255, 0, 0)),
    ('car', (0, 0, 142)),
    ('truck', (0, 0, 70)),
    ('bus', (0, 60, 100)),
    ('caravan', (0, 0, 90)),
    ('trailer', (0, 0, 110)),
    ('train', (0, 80, 100)),
    ('motorcycle', (0, 0, 230)),
    ('bicycle', (119, 11, 32)),
    ('license plate', (0, 0, 142)),
]

# Class names
CLASS_NAMES = [entry[0] for entry in _CLASS_PALETTE]

# Class colors (RGB)
CLASS_COLORS = np.array([entry[1] for entry in _CLASS_PALETTE], dtype=np.uint8)

print("\n".join([
    f"Configuration:",
    f"  Image size: {IMG_HEIGHT}x{IMG_WIDTH}",
    f"  Number of classes: {NUM_CLASSES}",
    f"  Batch size: {BATCH_SIZE}",
    f"  Learning rate: {LEARNING_RATE}",
    f"  Epochs: {NUM_EPOCHS}",
]))

In [None]:
# %% [markdown]
# ## GDN Layer Implementation

# %% Generalized Divisive Normalization Layer
class NonNegConstraint:
    """Callable that clamps every element of a tensor to at least 1e-15."""

    def __call__(self, tensor):
        return tensor.clamp(min=1e-15)


class GDN(nn.Module):
    """
    Generalized Divisive Normalization layer.

    Divides the input by ``(beta + conv(|x| ** alpha, gamma)) ** epsilon``,
    where the convolution mixes channels over a ``filter_size`` window with
    reflection padding.

    Based on: "Density Modeling of Images using a Generalized Normalization Transformation"
    by Ballé et al. (2016)

    Args:
        in_channels: Number of channels of the input (and output) tensor.
        filter_size: Side length of the square spatial window used by gamma.
    """

    def __init__(self, in_channels, filter_size=3):
        super().__init__()
        self.in_channels = in_channels
        self.filter_size = filter_size
        self.padding = (filter_size - 1) // 2

        # beta and gamma are trained; alpha and epsilon are frozen at one.
        self.beta = nn.Parameter(torch.ones(in_channels))
        self.alpha = nn.Parameter(torch.ones(in_channels), requires_grad=False)
        self.epsilon = nn.Parameter(torch.ones(in_channels), requires_grad=False)
        gamma_shape = (filter_size, filter_size, in_channels, in_channels)
        self.gamma = nn.Parameter(torch.zeros(*gamma_shape))

        self.constraint = NonNegConstraint()

    def forward(self, x):
        # Project all four parameters back onto the positive range before use.
        with torch.no_grad():
            for param in (self.beta, self.alpha, self.epsilon, self.gamma):
                param.data = self.constraint(param.data)

        def per_channel(vector):
            """Reshape a per-channel vector so it broadcasts over [B, C, H, W]."""
            return vector.view(1, -1, 1, 1)

        # |x| ** alpha, reflection-padded so the convolution keeps H and W.
        pad = self.padding
        energy = F.pad(
            x.abs() ** per_channel(self.alpha),
            (pad, pad, pad, pad),
            mode='reflect',
        )

        # gamma is stored as [h, w, in_c, out_c]; conv2d wants [out_c, in_c, h, w].
        kernel = self.gamma.permute(3, 2, 0, 1)
        pooled = F.conv2d(energy, kernel, padding=0)

        denominator = (per_channel(self.beta) + pooled) ** per_channel(self.epsilon)
        return x / denominator

In [None]:

# %% [markdown]
# ## Dataset Implementation

# %% Custom Dataset Class
class CityscapesDataset(Dataset):
    """
    Custom Dataset for Cityscapes semantic segmentation.

    Each item is built from an RGB image and an RGB colour-coded label image
    that are resized to ``(height, width)``; the label colours are matched
    against ``CLASS_COLORS`` to build the class mask.

    Args:
        image_folder: Directory holding the input images.
        mask_folder: Directory holding the colour-coded label images.
        image_names: Image file names, index-aligned with ``mask_names``.
        mask_names: Label file names, index-aligned with ``image_names``.
        height: Output height in pixels.
        width: Output width in pixels.
        augment: When True, apply colour jitter to the image (never the mask).

    Returns:
        image: Tensor of shape [3, H, W], normalized to [0, 1]
        mask: Tensor of shape [NUM_CLASSES, H, W], one-hot encoded
    """
    def __init__(self, image_folder, mask_folder, image_names, mask_names,
                 height=IMG_HEIGHT, width=IMG_WIDTH, augment=False):
        self.image_folder = image_folder
        self.mask_folder = mask_folder
        self.image_names = image_names
        self.mask_names = mask_names
        self.height = height
        self.width = width
        self.augment = augment

        # Define augmentation transforms
        if augment:
            self.img_transform = transforms.Compose([
                transforms.ColorJitter(brightness=0.2, contrast=0.2,
                                     saturation=0.2, hue=0.1),
            ])
        else:
            self.img_transform = None

    def __len__(self):
        return len(self.image_names)

    def __getitem__(self, idx):
        # Load image
        img_path = os.path.join(self.image_folder, self.image_names[idx])
        image = Image.open(img_path).convert('RGB')

        # Load mask
        mask_path = os.path.join(self.mask_folder, self.mask_names[idx])
        mask = Image.open(mask_path).convert('RGB')

        # Resize; PIL takes (width, height). Nearest-neighbour keeps the label
        # colours exact instead of blending them at class borders.
        image = image.resize((self.width, self.height), Image.BILINEAR)
        mask = mask.resize((self.width, self.height), Image.NEAREST)

        # Apply augmentation to image only
        if self.augment and self.img_transform:
            image = self.img_transform(image)

        # Convert to numpy
        image = np.array(image, dtype=np.float32) / 255.0
        mask_rgb = np.array(mask, dtype=np.int32)

        # One-hot encode mask. Two palette entries share the same colour, so a
        # pixel is given to the FIRST class whose colour matches; otherwise it
        # would be marked in two channels and the mask would not be one-hot.
        one_hot_mask = np.zeros((self.height, self.width, NUM_CLASSES),
                               dtype=np.float32)
        unassigned = np.ones((self.height, self.width), dtype=bool)
        for i, color in enumerate(CLASS_COLORS):
            class_map = np.all(mask_rgb == color, axis=-1) & unassigned
            one_hot_mask[:, :, i] = class_map
            unassigned &= ~class_map

        # Pixels whose colour is not in the palette would otherwise be all-zero
        # across channels. Give them class 0, which is also what an argmax over
        # an all-zero pixel resolves to.
        one_hot_mask[:, :, 0][unassigned] = 1.0

        # Convert to tensors [C, H, W]
        image = torch.from_numpy(image).permute(2, 0, 1)
        one_hot_mask = torch.from_numpy(one_hot_mask).permute(2, 0, 1)

        return image, one_hot_mask

In [None]:

# %% [markdown]
# ## Model Architecture

# %% Convolutional Block
class ConvBlock(nn.Module):
    """
    Two stacked ``Conv2d(3x3, padding=1) -> BatchNorm2d -> ReLU -> Dropout2d``
    stages, optionally followed by a 2x2 max-pool.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by both convolutions.
        pool: When True, ``forward`` also returns the max-pooled features.
        dropout: Drop probability used by both ``Dropout2d`` layers.

    Returns (from ``forward``):
        ``(features, pooled_features)`` when ``pool`` is True, otherwise just
        ``features``.
    """

    def __init__(self, in_channels, out_channels, pool=True, dropout=0.2):
        super().__init__()
        self.pool = pool

        # Stage 1: in_channels -> out_channels
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu1 = nn.ReLU(inplace=True)
        self.dropout1 = nn.Dropout2d(dropout)

        # Stage 2: out_channels -> out_channels
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.relu2 = nn.ReLU(inplace=True)
        self.dropout2 = nn.Dropout2d(dropout)

        # The pooling layer only exists on blocks that down-sample.
        if pool:
            self.pool_layer = nn.MaxPool2d(2)

    def forward(self, x):
        stages = (
            (self.conv1, self.bn1, self.relu1, self.dropout1),
            (self.conv2, self.bn2, self.relu2, self.dropout2),
        )
        features = x
        for stage in stages:
            for layer in stage:
                features = layer(features)

        if not self.pool:
            return features
        # Keep the un-pooled features as well; they feed the skip connection.
        return features, self.pool_layer(features)

In [None]:

# %% U-Net with GDN
class UNetGDN(nn.Module):
    """
    U-Net for semantic segmentation whose input first passes through a GDN layer.

    Layout:
    - GDN applied to the raw input
    - Encoder: four pooling ConvBlocks producing 16, 32, 64 and 128 channels
    - Bridge: one ConvBlock without pooling, 256 channels
    - Decoder: four (transposed-conv upsample, concatenate skip, ConvBlock) steps
    - Head: 1x1 convolution to ``num_classes`` channels followed by a softmax
      over the channel dimension

    Args:
        in_channels: Channels of the input image.
        num_classes: Number of output channels (one per class).
    """

    def __init__(self, in_channels=3, num_classes=NUM_CLASSES):
        super().__init__()

        # Input normalisation
        self.gdn = GDN(in_channels)

        # Contracting path
        self.enc1 = ConvBlock(in_channels, 16, pool=True)
        self.enc2 = ConvBlock(16, 32, pool=True)
        self.enc3 = ConvBlock(32, 64, pool=True)
        self.enc4 = ConvBlock(64, 128, pool=True)

        # Bottleneck
        self.bridge = ConvBlock(128, 256, pool=False)

        # Expanding path: each decoder block sees the upsampled features
        # concatenated with the matching encoder output, hence the doubled
        # input channel count.
        self.up1 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.dec1 = ConvBlock(256, 128, pool=False)

        self.up2 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.dec2 = ConvBlock(128, 64, pool=False)

        self.up3 = nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2)
        self.dec3 = ConvBlock(64, 32, pool=False)

        self.up4 = nn.ConvTranspose2d(32, 16, kernel_size=2, stride=2)
        self.dec4 = ConvBlock(32, 16, pool=False)

        # Classification head
        self.out_conv = nn.Conv2d(16, num_classes, 1)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        features = self.gdn(x)

        # Walk down the encoder, remembering each pre-pool feature map.
        skips = []
        for encoder in (self.enc1, self.enc2, self.enc3, self.enc4):
            skip, features = encoder(features)
            skips.append(skip)

        features = self.bridge(features)

        # Walk back up, pairing each upsampling step with the skip from the
        # encoder level of the same resolution (deepest first).
        decoder_steps = (
            (self.up1, self.dec1),
            (self.up2, self.dec2),
            (self.up3, self.dec3),
            (self.up4, self.dec4),
        )
        for (upsample, decoder), skip in zip(decoder_steps, reversed(skips)):
            merged = torch.cat([upsample(features), skip], dim=1)
            features = decoder(merged)

        return self.softmax(self.out_conv(features))

In [None]:

# %% Test model instantiation
model = UNetGDN(in_channels=3, num_classes=NUM_CLASSES).to(device)
print(f"\nModel Summary:")
print(f"Total parameters: {sum(p.numel() for p in model.parameters()):,}")
print(f"Trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")

# Test forward pass.
# Run it in eval mode and without autograd: in train mode the BatchNorm layers
# would fold this random noise into their running statistics before training
# has even started, and a gradient graph would be built for nothing.
dummy_input = torch.randn(2, 3, IMG_HEIGHT, IMG_WIDTH).to(device)
model.eval()
with torch.no_grad():
    dummy_output = model(dummy_input)
model.train()  # restore the default mode of a freshly built module
print(f"\nTest forward pass:")
print(f"  Input shape: {dummy_input.shape}")
print(f"  Output shape: {dummy_output.shape}")
del dummy_input, dummy_output

In [None]:
# %% [markdown]
# ## Metrics and Visualization

# %% IoU Metric
def _class_pixel_counts(labels, num_classes):
    """Count pixels per class index, ignoring indices outside [0, num_classes)."""
    in_range = (labels >= 0) & (labels < num_classes)
    return np.bincount(labels[in_range], minlength=num_classes)


def calculate_iou(y_true, y_pred, num_classes=NUM_CLASSES):
    """
    Calculate mean Intersection over Union (IoU).

    The IoU is computed per image as the average over the classes that occur
    in that image's ground truth or prediction, and the per-image scores are
    then averaged over the batch. Classes absent from both are skipped: they
    have no defined IoU, and scoring them as 0 would drag the mean down even
    for a perfect prediction.

    Args:
        y_true: Ground truth tensor [B, C, H, W]
        y_pred: Predicted class indices [B, H, W]
        num_classes: Number of classes

    Returns:
        Mean IoU score as a float; 0.0 for an empty batch.
    """
    batch_size = y_true.shape[0]
    if batch_size == 0:
        # np.mean([]) would return NaN with a warning.
        return 0.0

    ious = []
    for i in range(batch_size):
        # Flatten predictions and ground truth to 1-D class-index arrays.
        pred_flat = y_pred[i].reshape(-1).cpu().numpy().astype(np.int64)
        true_flat = (
            torch.argmax(y_true[i], dim=0).reshape(-1).cpu().numpy().astype(np.int64)
        )

        # Per-class pixel counts: ground truth, prediction and their overlap.
        true_count = _class_pixel_counts(true_flat, num_classes)
        pred_count = _class_pixel_counts(pred_flat, num_classes)
        overlap = _class_pixel_counts(true_flat[true_flat == pred_flat], num_classes)
        union = true_count + pred_count - overlap

        present = union > 0
        if present.any():
            ious.append(float(np.mean(overlap[present] / union[present])))
        else:
            # No pixel carries a valid class index.
            ious.append(0.0)

    return float(np.mean(ious))

# %% Visualization functions
def mask_to_rgb(mask, colors=CLASS_COLORS):
    """
    Paint a map of class indices with the colour assigned to each class.

    Args:
        mask: Tensor or array of class indices [H, W]
        colors: Array of RGB colors [NUM_CLASSES, 3]

    Returns:
        uint8 RGB image [H, W, 3]; pixels whose index has no entry in
        ``colors`` stay black.
    """
    height, width = mask.shape
    canvas = np.zeros((height, width, 3), dtype=np.uint8)

    class_map = mask.cpu().numpy() if torch.is_tensor(mask) else mask

    # Fill in one class at a time via a boolean selection.
    for class_idx, rgb in enumerate(colors):
        canvas[class_map == class_idx] = rgb

    return canvas

def visualize_predictions(model, dataset, num_samples=5, save_path=None):
    """Show input, ground truth and prediction for randomly chosen samples.

    The model is switched to eval mode and run without gradients on the
    global ``device``. One figure row is drawn per sample, so the figure
    never contains empty rows when the dataset has fewer than ``num_samples``
    items.

    Args:
        model: Network returning per-class scores of shape [B, C, H, W].
        dataset: Indexable dataset yielding ``(image, one_hot_mask)`` pairs.
        num_samples: Maximum number of samples to draw.
        save_path: Optional file path; when given the figure is also saved.
    """
    model.eval()

    num_rows = min(num_samples, len(dataset))
    if num_rows <= 0:
        # plt.subplots cannot create a grid with zero rows.
        print("No samples available to visualize.")
        return

    indices = random.sample(range(len(dataset)), num_rows)

    # squeeze=False keeps `axes` two-dimensional even for a single row.
    fig, axes = plt.subplots(num_rows, 3, figsize=(15, num_rows * 4), squeeze=False)

    with torch.no_grad():
        for row_axes, idx in zip(axes, indices):
            image, mask = dataset[idx]

            # Get prediction
            output = model(image.unsqueeze(0).to(device))
            pred = torch.argmax(output, dim=1).squeeze(0).cpu()

            # Convert to numpy for visualization ([C, H, W] -> [H, W, C])
            img_np = image.permute(1, 2, 0).cpu().numpy()
            true_mask = torch.argmax(mask, dim=0)

            panels = (
                (img_np, 'Input Image'),
                (mask_to_rgb(true_mask), 'Ground Truth'),
                (mask_to_rgb(pred), 'Prediction'),
            )
            for ax, (picture, title) in zip(row_axes, panels):
                ax.imshow(picture)
                ax.set_title(title)
                ax.axis('off')

    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"Predictions saved to {save_path}")
    plt.show()
    plt.close(fig)

In [None]:

# %% [markdown]
# ## Training and Evaluation

# %% Training function
def train_epoch(model, dataloader, criterion, optimizer, device):
    """Run one optimisation pass over ``dataloader``.

    Args:
        model: Network to train; put into train mode here.
        dataloader: Yields ``(images, masks)`` batches; ``masks`` hold one
            channel per class.
        criterion: Loss called as ``criterion(outputs, masks)``.
        optimizer: Optimizer stepping the model parameters.
        device: Device the batches are moved to.

    Returns:
        ``(epoch_loss, epoch_acc)``: the loss averaged over
        ``len(dataloader.dataset)`` samples and the fraction of pixels whose
        arg-max class matches the mask.
    """
    model.train()
    loss_sum = 0.0
    pixels_correct = 0
    pixels_seen = 0

    progress = tqdm(dataloader, desc='Training')
    for images, masks in progress:
        images, masks = images.to(device), masks.to(device)

        # Forward, backward, parameter update.
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, masks)
        loss.backward()
        optimizer.step()

        # Pixel accuracy: compare arg-max class of prediction and mask.
        target_classes = masks.argmax(dim=1)
        batch_correct = (outputs.argmax(dim=1) == target_classes).sum().item()
        batch_pixels = target_classes.numel()
        batch_loss = loss.item()

        # The loss is a batch mean, so weight it by the batch size.
        loss_sum += batch_loss * images.size(0)
        pixels_correct += batch_correct
        pixels_seen += batch_pixels

        progress.set_postfix({
            'loss': batch_loss,
            'acc': batch_correct / batch_pixels
        })

    return loss_sum / len(dataloader.dataset), pixels_correct / pixels_seen

In [None]:


# %% Validation function
def validate_epoch(model, dataloader, criterion, device):
    """Evaluate the model on ``dataloader`` without updating it.

    The IoU is accumulated batch by batch instead of collecting every one-hot
    mask of the whole set first. ``calculate_iou`` returns the mean over the
    images it is given, so weighting each batch score by its size yields the
    same per-image average while holding only one batch in memory.

    Args:
        model: Network to evaluate; put into eval mode here.
        dataloader: Yields ``(images, masks)`` batches; ``masks`` hold one
            channel per class.
        criterion: Loss called as ``criterion(outputs, masks)``.
        device: Device the batches are moved to.

    Returns:
        ``(epoch_loss, epoch_acc, epoch_iou)``: loss averaged over
        ``len(dataloader.dataset)`` samples, pixel accuracy, and mean IoU
        averaged over all images seen.
    """
    model.eval()
    running_loss = 0.0
    running_correct = 0
    running_total = 0
    running_iou = 0.0
    num_images = 0

    with torch.no_grad():
        pbar = tqdm(dataloader, desc='Validation')
        for images, masks in pbar:
            images = images.to(device)
            masks = masks.to(device)
            batch_size = images.size(0)

            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, masks)

            # Calculate accuracy
            preds = torch.argmax(outputs, dim=1)
            true_labels = torch.argmax(masks, dim=1)
            correct = (preds == true_labels).sum().item()
            total = true_labels.numel()

            # Update metrics (the loss is a batch mean, hence the weighting)
            running_loss += loss.item() * batch_size
            running_correct += correct
            running_total += total

            # Batch mean IoU, weighted by batch size so the epoch value is a
            # mean over images rather than over batches.
            running_iou += float(calculate_iou(masks, preds)) * batch_size
            num_images += batch_size

            # Update progress bar
            pbar.set_postfix({
                'loss': loss.item(),
                'acc': correct / total
            })

    epoch_loss = running_loss / len(dataloader.dataset)
    epoch_acc = running_correct / running_total
    epoch_iou = running_iou / num_images if num_images else 0.0

    return epoch_loss, epoch_acc, epoch_iou

In [None]:


# %% Main training function
def train_model(model, train_loader, val_loader, criterion, optimizer,
                scheduler, num_epochs, device, save_dir):
    """
    Train for ``num_epochs`` epochs, validating and checkpointing along the way.

    After every epoch the scheduler is stepped with the validation IoU and all
    metrics are appended to the history. Whenever the validation IoU beats the
    best value so far the model is written to ``best_model.pth``; every 50th
    epoch a ``checkpoint_epoch_<n>.pth`` containing the history is written too.

    Args:
        model: Network to train.
        train_loader: Loader passed to ``train_epoch``.
        val_loader: Loader passed to ``validate_epoch``.
        criterion: Loss function.
        optimizer: Optimizer; its first param group supplies the logged LR.
        scheduler: Object whose ``step(metric)`` receives the validation IoU.
        num_epochs: Number of epochs to run.
        device: Device forwarded to the epoch functions.
        save_dir: Directory for checkpoints; created when missing.

    Returns:
        ``(history, best_val_iou)`` where ``history`` maps each metric name to
        a list with one entry per epoch.
    """
    os.makedirs(save_dir, exist_ok=True)

    metric_names = ('train_loss', 'train_acc', 'val_loss', 'val_acc', 'val_iou', 'lr')
    history = {name: [] for name in metric_names}
    best_val_iou = 0.0

    for epoch in range(num_epochs):
        epoch_number = epoch + 1
        print(f"\nEpoch {epoch_number}/{num_epochs}")
        print("-" * 50)

        # One pass of training followed by one pass of validation.
        train_loss, train_acc = train_epoch(
            model, train_loader, criterion, optimizer, device
        )
        val_loss, val_acc, val_iou = validate_epoch(
            model, val_loader, criterion, device
        )

        # The scheduler watches validation IoU; read the LR back after the step.
        scheduler.step(val_iou)
        current_lr = optimizer.param_groups[0]['lr']

        # Record this epoch's numbers in the same order as metric_names.
        epoch_values = (train_loss, train_acc, val_loss, val_acc, val_iou, current_lr)
        for name, value in zip(metric_names, epoch_values):
            history[name].append(value)

        print(f"\nEpoch {epoch_number} Summary:")
        print(f"  Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f}")
        print(f"  Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.4f} | Val IoU: {val_iou:.4f}")
        print(f"  Learning Rate: {current_lr:.6f}")

        # Strict improvement only: ties keep the earlier checkpoint.
        if val_iou > best_val_iou:
            best_val_iou = val_iou
            best_state = {
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_iou': val_iou,
                'val_loss': val_loss,
                'val_acc': val_acc,
            }
            torch.save(best_state, os.path.join(save_dir, 'best_model.pth'))
            print(f"  ✓ Saved best model (IoU: {val_iou:.4f})")

        # Periodic snapshot every 50 epochs, including the history so far.
        if epoch_number % 50 == 0:
            checkpoint_path = os.path.join(save_dir, f'checkpoint_epoch_{epoch_number}.pth')
            periodic_state = {
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'history': history,
            }
            torch.save(periodic_state, checkpoint_path)
            print(f"  ✓ Saved checkpoint")

    print(f"\n{'='*50}")
    print(f"Training completed!")
    print(f"Best validation IoU: {best_val_iou:.4f}")
    print(f"{'='*50}\n")

    return history, best_val_iou

In [None]:


# %% Plot training history
def plot_training_history(history, save_path=None):
    """Plot training curves on a 2x2 grid.

    Panels: loss (train and validation), accuracy (train and validation),
    validation IoU, and the learning rate on a log scale. Each series is
    drawn exactly once.

    Args:
        history: Dict with the keys 'train_loss', 'val_loss', 'train_acc',
            'val_acc', 'val_iou' and 'lr', each a per-epoch sequence.
        save_path: Optional file path; when given the figure is also saved.
    """
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))

    def _decorate(ax, title, ylabel, legend=False):
        """Apply the title, axis labels, optional legend and grid shared by all panels."""
        ax.set_title(title, fontsize=14, fontweight='bold')
        ax.set_xlabel('Epoch')
        ax.set_ylabel(ylabel)
        if legend:
            ax.legend()
        ax.grid(True, alpha=0.3)

    # Loss
    loss_ax = axes[0, 0]
    loss_ax.plot(history['train_loss'], label='Train', linewidth=2)
    loss_ax.plot(history['val_loss'], label='Validation', linewidth=2)
    _decorate(loss_ax, 'Model Loss', 'Loss', legend=True)

    # Accuracy
    acc_ax = axes[0, 1]
    acc_ax.plot(history['train_acc'], label='Train', linewidth=2)
    acc_ax.plot(history['val_acc'], label='Validation', linewidth=2)
    _decorate(acc_ax, 'Model Accuracy', 'Accuracy', legend=True)

    # Validation IoU
    iou_ax = axes[1, 0]
    iou_ax.plot(history['val_iou'], linewidth=2, color='green')
    _decorate(iou_ax, 'Validation IoU', 'IoU')

    # Learning Rate (log scale so step-wise reductions stay readable)
    lr_ax = axes[1, 1]
    lr_ax.plot(history['lr'], linewidth=2, color='red')
    lr_ax.set_yscale('log')
    _decorate(lr_ax, 'Learning Rate', 'Learning Rate')

    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"Training curves saved to {save_path}")
    plt.show()
    plt.close()

In [None]:


# %% [markdown]
# ## Data Preparation

# %% Utility: get image/mask filenames
def get_image_mask_names(image_dir, mask_dir, ext='.png'):
    images = sorted([f for f in os.listdir(image_dir) if f.lower().endswith(ext)])
    masks  = sorted([f for f in os.listdir(mask_dir)  if f.lower().endswith(ext)])

    assert len(images) == len(masks), \
        f"Mismatch between images ({len(images)}) and masks ({len(masks)}) in {image_dir} / {mask_dir}"

    # Optional: check that names (without extension) match
    for img_name, mask_name in zip(images, masks):
        if os.path.splitext(img_name)[0] != os.path.splitext(mask_name)[0]:
            raise ValueError(f"Filename mismatch: {img_name} vs {mask_name}")

    return images, masks

# Train / Val file lists
train_image_names, train_mask_names = get_image_mask_names(
    train_images_folder_path, train_mask_folder_path
)
val_image_names, val_mask_names = get_image_mask_names(
    test_images_folder_path, test_mask_folder_path
)

print(f"Number of training samples: {len(train_image_names)}")
print(f"Number of validation samples: {len(val_image_names)}")

# %% Create Datasets and DataLoaders
pin_memory = torch.cuda.is_available()

# Colour augmentation is applied to the training split only.
train_dataset = CityscapesDataset(
    image_folder=train_images_folder_path,
    mask_folder=train_mask_folder_path,
    image_names=train_image_names,
    mask_names=train_mask_names,
    height=IMG_HEIGHT,
    width=IMG_WIDTH,
    augment=True
)

val_dataset = CityscapesDataset(
    image_folder=test_images_folder_path,
    mask_folder=test_mask_folder_path,
    image_names=val_image_names,
    mask_names=val_mask_names,
    height=IMG_HEIGHT,
    width=IMG_WIDTH,
    augment=False
)

# persistent_workers keeps the worker processes alive between epochs instead
# of starting and tearing them down for both loaders on every epoch. The
# option is only valid with at least one worker, hence the guard.
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    pin_memory=pin_memory,
    persistent_workers=NUM_WORKERS > 0
)

val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=pin_memory,
    persistent_workers=NUM_WORKERS > 0
)

print(f"Train loader batches: {len(train_loader)}")
print(f"Val loader batches: {len(val_loader)}")

# %% [markdown]
# ## Loss, Optimizer, and Scheduler

# BCE over one-hot masks and softmax probabilities
# Checkpoints and figures of this run are written below this directory.
SAVE_DIR = "./unet_gdn_cityscapes"
os.makedirs(SAVE_DIR, exist_ok=True)

# Binary cross-entropy between the model's per-class probabilities and the
# per-class mask channels.
criterion = nn.BCELoss()

# Adam over every model parameter at the configured learning rate.
optimizer = optim.Adam(params=model.parameters(), lr=LEARNING_RATE)

# Halve the learning rate after 5 epochs without improvement of the monitored
# metric (mode='max': higher is better), never going below 1e-6.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='max', factor=0.5, patience=5, min_lr=1e-6
)

# %% [markdown]
# ## Run Training

# Train the model, keeping the per-epoch history and the best validation IoU.
history, best_val_iou = train_model(
    model, train_loader, val_loader, criterion, optimizer, scheduler,
    num_epochs=NUM_EPOCHS, device=device, save_dir=SAVE_DIR,
)

# %% Plot training history
training_curves_path = os.path.join(SAVE_DIR, "training_curves.png")
plot_training_history(history, save_path=training_curves_path)

# %% Visualize some predictions on the validation set
val_predictions_path = os.path.join(SAVE_DIR, "val_predictions.png")
visualize_predictions(model, val_dataset, num_samples=3, save_path=val_predictions_path)


PyTorch version: 2.9.0+cu126
CUDA available: True
Using device: cuda
GPU: Tesla T4
Memory: 15.83 GB
Downloading from https://www.kaggle.com/api/v1/datasets/download/shuvoalok/cityscapes?dataset_version_number=2...


100%|██████████| 199M/199M [00:01<00:00, 162MB/s]

Extracting files...





Dataset downloaded to: /root/.cache/kagglehub/datasets/shuvoalok/cityscapes/versions/2
✓ Train images: 2975 files
✓ Train masks: 2975 files
✓ Test images: 500 files
✓ Test masks: 500 files
Configuration:
  Image size: 96x256
  Number of classes: 30
  Batch size: 32
  Learning rate: 0.001
  Epochs: 300

Model Summary:
Total parameters: 1,944,632
Trainable parameters: 1,944,626

Test forward pass:
  Input shape: torch.Size([2, 3, 96, 256])
  Output shape: torch.Size([2, 30, 96, 256])
Number of training samples: 2975
Number of validation samples: 500
Train loader batches: 93
Val loader batches: 16

Epoch 1/300
--------------------------------------------------


Training:   0%|          | 0/93 [00:00<?, ?it/s]

Validation:   0%|          | 0/16 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7b1921546840>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1654, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1637, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7b1921546840>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1654, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 16


Epoch 1 Summary:
  Train Loss: 0.0413 | Train Acc: 0.0930
  Val Loss: 0.0370 | Val Acc: 0.1094 | Val IoU: 0.0066
  Learning Rate: 0.001000
  ✓ Saved best model (IoU: 0.0066)

Epoch 2/300
--------------------------------------------------


Training:   0%|          | 0/93 [00:00<?, ?it/s]

Validation:   0%|          | 0/16 [00:00<?, ?it/s]


Epoch 2 Summary:
  Train Loss: 0.0374 | Train Acc: 0.1287
  Val Loss: 0.0365 | Val Acc: 0.1178 | Val IoU: 0.0069
  Learning Rate: 0.001000
  ✓ Saved best model (IoU: 0.0069)

Epoch 3/300
--------------------------------------------------


Training:   0%|          | 0/93 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7b1921546840>Exception ignored in: 
<function _MultiProcessingDataLoaderIter.__del__ at 0x7b1921546840>Traceback (most recent call last):

  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1654, in __del__
Traceback (most recent call last):
      File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1654, in __del__
self._shutdown_workers()    
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1637, in _shutdown_workers
self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1637, in _shutdown_workers
        if w.is_alive():
if w.is_alive(): 
             ^^^^^^^^^^^^^^^^^^^^^^^^

  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
        assert self.

Validation:   0%|          | 0/16 [00:00<?, ?it/s]


Epoch 3 Summary:
  Train Loss: 0.0368 | Train Acc: 0.1298
  Val Loss: 0.0361 | Val Acc: 0.1184 | Val IoU: 0.0067
  Learning Rate: 0.001000

Epoch 4/300
--------------------------------------------------


Training:   0%|          | 0/93 [00:00<?, ?it/s]

Validation:   0%|          | 0/16 [00:00<?, ?it/s]


Epoch 4 Summary:
  Train Loss: 0.0364 | Train Acc: 0.1293
  Val Loss: 0.0360 | Val Acc: 0.1201 | Val IoU: 0.0069
  Learning Rate: 0.001000
  ✓ Saved best model (IoU: 0.0069)

Epoch 5/300
--------------------------------------------------


Training:   0%|          | 0/93 [00:00<?, ?it/s]

Validation:   0%|          | 0/16 [00:00<?, ?it/s]


Epoch 5 Summary:
  Train Loss: 0.0363 | Train Acc: 0.1305
  Val Loss: 0.0357 | Val Acc: 0.1283 | Val IoU: 0.0072
  Learning Rate: 0.001000
  ✓ Saved best model (IoU: 0.0072)

Epoch 6/300
--------------------------------------------------


Training:   0%|          | 0/93 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7b1921546840>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1654, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1637, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7b1921546840>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1654, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 16

Validation:   0%|          | 0/16 [00:00<?, ?it/s]


Epoch 6 Summary:
  Train Loss: 0.0361 | Train Acc: 0.1321
  Val Loss: 0.0357 | Val Acc: 0.1252 | Val IoU: 0.0070
  Learning Rate: 0.001000

Epoch 7/300
--------------------------------------------------


Training:   0%|          | 0/93 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7b1921546840>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1654, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1637, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7b1921546840>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1654, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 16