In [1]:
import cv2
print(f"OpenCV version: {cv2.__version__}")

import torch
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

OpenCV version: 4.11.0
PyTorch version: 2.6.0+cpu
CUDA available: False


In [None]:
import torch
import os
import gc

def check_gpu():
    print(f"CUDA Available: {torch.cuda.is_available()}")

    if torch.cuda.is_available():
        print(f"Number of GPUs: {torch.cuda.device_count()}")
        print(f"GPU Device Name: {torch.cuda.get_device_name(0)}")
        print(f"Current GPU Device: {torch.cuda.current_device()}")
    else:
        print("No GPU detected. Running on CPU.")

check_gpu()

# CUDA configs
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
os.environ['TORCH_USE_CUDA_DSA'] = '1'
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
torch.cuda.empty_cache()
gc.collect()

CUDA Available: True
Number of GPUs: 1
GPU Device Name: Tesla V100-SXM2-32GB
Current GPU Device: 0


212

# **Main Code**

In [None]:
import os
import cv2
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import OneCycleLR
from transformers import SegformerForSemanticSegmentation, SegformerImageProcessor
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
import albumentations as A
from albumentations.pytorch import ToTensorV2
from sklearn.metrics import confusion_matrix, jaccard_score, accuracy_score
import json
import random
import shutil
from datetime import datetime

# Define paths - update these for your setup
BASE_DIR = "C:/Users/MSI GF66/Documents/Research ML/Tanmay"  # Update this with your path
ROOT_DIR = os.path.join(BASE_DIR, "KITTI")
TRAIN_DIR = os.path.join(ROOT_DIR, "training/image_2")
TRAIN_LABELS_DIR = os.path.join(ROOT_DIR, "training/semantic")
VAL_DIR = os.path.join(ROOT_DIR, "validation/image_2")  # Will be created if it doesn't exist
VAL_LABELS_DIR = os.path.join(ROOT_DIR, "validation/semantic")  # Will be created if it doesn't exist
TEST_DIR = os.path.join(ROOT_DIR, "testing/image_2")
CHECKPOINT_DIR = os.path.join(ROOT_DIR, "checkpoints")
VISUALIZATION_DIR = os.path.join(ROOT_DIR, "visualizations")
METRICS_DIR = os.path.join(ROOT_DIR, "metrics")
# Path to your pre-trained CamVid model checkpoint
CAMVID_CHECKPOINT = os.path.join(BASE_DIR, "CamVid/checkpoints/best_model_loss_b3.pth")  # Update with your model path

# Training Hyperparameters
IMAGE_HEIGHT = 1024  # KITTI images are typically larger
IMAGE_WIDTH = 1024
BATCH_SIZE = 1
ACCUMULATION_STEPS = 16
NUM_EPOCHS = 30
LEARNING_RATE = 2e-4
WEIGHT_DECAY = 0.01
GRAD_CLIP_VALUE = 1.0
MODEL_TYPE = "b3"  # b3, b4, or b5 - should match your CamVid pre-trained model
USE_REDUCED_CLASSES = True  # Use the reduced class set for better performance
FREEZE_ENCODER = False  # Set to True to freeze encoder during transfer learning

# KITTI Reduced Classes - easier to work with for the small dataset
KITTI_REDUCED_CLASSES = {
    0: {'name': 'background', 'color': (0, 0, 0)},  # Consolidate unlabeled classes
    1: {'name': 'road', 'color': (128, 64, 128)},
    2: {'name': 'sidewalk', 'color': (244, 35, 232)},
    3: {'name': 'building', 'color': (70, 70, 70)},
    4: {'name': 'wall', 'color': (102, 102, 156)},
    5: {'name': 'fence', 'color': (190, 153, 153)},
    6: {'name': 'pole', 'color': (153, 153, 153)},
    7: {'name': 'traffic light', 'color': (250, 170, 30)},
    8: {'name': 'traffic sign', 'color': (220, 220, 0)},
    9: {'name': 'vegetation', 'color': (107, 142, 35)},
    10: {'name': 'terrain', 'color': (152, 251, 152)},
    11: {'name': 'sky', 'color': (70, 130, 180)},
    12: {'name': 'person', 'color': (220, 20, 60)},
    13: {'name': 'rider', 'color': (255, 0, 0)},
    14: {'name': 'car', 'color': (0, 0, 142)},
    15: {'name': 'truck', 'color': (0, 0, 70)},
    16: {'name': 'bus', 'color': (0, 60, 100)},
    17: {'name': 'motorcycle', 'color': (0, 0, 230)},
    18: {'name': 'bicycle', 'color': (119, 11, 32)}
}

def create_validation_split(train_dir, label_dir, val_dir, val_label_dir, val_ratio=0.2, seed=42):
    """Create a validation split from the training data if it doesn't already exist."""
    # Create validation directories if they don't exist
    os.makedirs(val_dir, exist_ok=True)
    os.makedirs(val_label_dir, exist_ok=True)

    # Check if validation split already exists
    if len(os.listdir(val_dir)) > 0:
        print(f"Validation split already exists with {len(os.listdir(val_dir))} images. Skipping split creation.")
        return

    # Get all image files
    image_files = [f for f in os.listdir(train_dir) if f.endswith('.png')]

    # Set random seed for reproducibility
    random.seed(seed)

    # Determine number of validation samples
    num_val = int(len(image_files) * val_ratio)

    # Randomly select validation samples
    val_samples = random.sample(image_files, num_val)

    # Copy validation samples to validation directory
    for img_file in val_samples:
        # Copy image
        shutil.copy(
            os.path.join(train_dir, img_file),
            os.path.join(val_dir, img_file)
        )

        # Try to find matching label file
        label_found = False

        # First try: direct matching filename
        if os.path.exists(os.path.join(label_dir, img_file)):
            shutil.copy(
                os.path.join(label_dir, img_file),
                os.path.join(val_label_dir, img_file)
            )
            label_found = True
        else:
            # Second try: check alternative label patterns
            base_name = os.path.splitext(img_file)[0]
            label_candidates = [
                f"{base_name}.png",
                f"{base_name}_labelTrainIds.png",
                f"{base_name}_labelIds.png",
                f"{base_name}_gtFine_labelIds.png"
            ]

            for label_name in label_candidates:
                label_path = os.path.join(label_dir, label_name)
                if os.path.exists(label_path):
                    shutil.copy(
                        label_path,
                        os.path.join(val_label_dir, label_name)
                    )
                    label_found = True
                    break

        if not label_found:
            print(f"Warning: No label found for {img_file}")

    print(f"Created validation split with {num_val} samples")

# Define data augmentation pipelines
train_transforms = A.Compose([
    A.RandomResizedCrop(
        size=(IMAGE_HEIGHT, IMAGE_WIDTH),
        scale=(0.8, 1.0),
        ratio=(0.75, 1.33),
        p=1.0,
        interpolation=cv2.INTER_LINEAR
    ),
    A.OneOf([
        A.RandomBrightnessContrast(
            brightness_limit=0.2,
            contrast_limit=0.2,
            p=0.5
        ),
        A.RandomGamma(gamma_limit=(80, 120), p=0.5),
        A.RandomShadow(
            shadow_roi=(0, 0.5, 1, 1),
            p=0.5
        ),
    ], p=0.7),
    A.OneOf([
        A.Affine(
            scale=(0.9, 1.1),
            translate_percent={'x': (-0.1, 0.1), 'y': (-0.1, 0.1)},
            rotate=(-15, 15),
            border_mode=cv2.BORDER_CONSTANT
        ),
        A.ElasticTransform(
            alpha=120,
            sigma=6,
            p=0.5
        ),
    ], p=0.5),
    A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ToTensorV2()
])

val_transforms = A.Compose([
    A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH, interpolation=cv2.INTER_LINEAR),
    A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ToTensorV2()
])

class KITTIDataset(Dataset):
    """Dataset class for KITTI semantic segmentation."""
    def __init__(self, image_dir, label_dir, feature_extractor, transforms=None, use_reduced_classes=True):
        self.image_dir = image_dir
        self.label_dir = label_dir
        self.transforms = transforms
        self.feature_extractor = feature_extractor
        self.ignore_index = 255
        self.use_reduced_classes = use_reduced_classes

        # Set up class mappings
        self.class_mapping = KITTI_REDUCED_CLASSES
        self.num_classes = len(KITTI_REDUCED_CLASSES)

        # Create a mapping from original KITTI classes to reduced classes
        self.original_to_reduced = {}
        for i in range(35):  # KITTI has 35 classes
            if i == 7:  # road
                self.original_to_reduced[i] = 1
            elif i == 8:  # sidewalk
                self.original_to_reduced[i] = 2
            elif i == 11:  # building
                self.original_to_reduced[i] = 3
            elif i == 12:  # wall
                self.original_to_reduced[i] = 4
            elif i == 13:  # fence
                self.original_to_reduced[i] = 5
            elif i == 17 or i == 18:  # pole/polegroup
                self.original_to_reduced[i] = 6
            elif i == 19:  # traffic light
                self.original_to_reduced[i] = 7
            elif i == 20:  # traffic sign
                self.original_to_reduced[i] = 8
            elif i == 21:  # vegetation
                self.original_to_reduced[i] = 9
            elif i == 22:  # terrain
                self.original_to_reduced[i] = 10
            elif i == 23:  # sky
                self.original_to_reduced[i] = 11
            elif i == 24:  # person
                self.original_to_reduced[i] = 12
            elif i == 25:  # rider
                self.original_to_reduced[i] = 13
            elif i == 26:  # car
                self.original_to_reduced[i] = 14
            elif i == 27:  # truck
                self.original_to_reduced[i] = 15
            elif i == 28:  # bus
                self.original_to_reduced[i] = 16
            elif i == 32:  # motorcycle
                self.original_to_reduced[i] = 17
            elif i == 33:  # bicycle
                self.original_to_reduced[i] = 18
            else:  # map to background
                self.original_to_reduced[i] = 0

        # Validate and filter images
        self.images = []
        valid_extensions = ('.png', '.jpg', '.jpeg')

        print(f"Validating images in {image_dir}...")
        all_images = [f for f in os.listdir(image_dir) if f.lower().endswith(valid_extensions)]

        for img_name in all_images:
            # First try: direct mapping (same filename for image and label)
            img_path = os.path.join(image_dir, img_name)
            label_path = os.path.join(label_dir, img_name)

            if os.path.exists(label_path):
                img_test = cv2.imread(img_path)
                label_test = cv2.imread(label_path, cv2.IMREAD_UNCHANGED)
                if img_test is not None and label_test is not None:
                    self.images.append((img_name, img_name))
                    continue

            # Second try: check alternative label filenames
            base_name = os.path.splitext(img_name)[0]
            label_candidates = [
                os.path.join(label_dir, f"{base_name}.png"),
                os.path.join(label_dir, f"{base_name}_labelTrainIds.png"),
                os.path.join(label_dir, f"{base_name}_labelIds.png"),
                os.path.join(label_dir, f"{base_name}_gtFine_labelIds.png")
            ]

            valid_label_path = None
            for label_path in label_candidates:
                if os.path.exists(label_path):
                    img_test = cv2.imread(img_path)
                    label_test = cv2.imread(label_path, cv2.IMREAD_UNCHANGED)
                    if img_test is not None and label_test is not None:
                        valid_label_path = label_path
                        break

            if valid_label_path:
                self.images.append((img_name, os.path.basename(valid_label_path)))
            else:
                print(f"Warning: No valid label found for {img_name}")

        print(f"Found {len(self.images)} valid image-label pairs in {image_dir}")

        if len(self.images) == 0:
            raise RuntimeError(f"No valid image-label pairs found in {image_dir}")

    def __getitem__(self, idx):
        try:
            # Load and process input image
            img_name, label_name = self.images[idx]
            image_path = os.path.join(self.image_dir, img_name)
            image = cv2.imread(image_path)
            if image is None:
                raise ValueError(f"Failed to load image: {image_path}")
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

            # Load and process label image
            label_path = os.path.join(self.label_dir, label_name)

            # Check if the label is color or single channel
            label = cv2.imread(label_path, cv2.IMREAD_UNCHANGED)
            if label is None:
                raise ValueError(f"Failed to load label: {label_path}")

            # Convert label to class indices
            if len(label.shape) == 3:  # Color image (RGB)
                h, w = label.shape[:2]
                label_mask = np.zeros((h, w), dtype=np.int64)

                for class_id, class_info in self.class_mapping.items():
                    color = class_info['color']
                    # RGB to BGR for comparison with OpenCV
                    bgr_color = (color[2], color[1], color[0])
                    mask = np.all(label == bgr_color, axis=2)
                    label_mask[mask] = class_id
            else:  # Already single channel with class IDs
                label_mask = np.zeros_like(label, dtype=np.int64)

                # Map original KITTI classes to our reduced class set
                for original_id, reduced_id in self.original_to_reduced.items():
                    label_mask[label == original_id] = reduced_id

            # Apply transforms
            if self.transforms:
                try:
                    transformed = self.transforms(image=image, mask=label_mask)
                    image = transformed['image']
                    label_mask = transformed['mask']
                except Exception as e:
                    print(f"Transform error for image {img_name}: {str(e)}")
                    raise

            return {
                'pixel_values': image,
                'labels': torch.as_tensor(label_mask, dtype=torch.long),
                'image_name': img_name
            }
        except Exception as e:
            print(f"Error processing image {self.images[idx][0]}: {str(e)}")
            return self.__getitem__((idx + 1) % len(self))

    def __len__(self):
        return len(self.images)

    def get_class_names(self):
        """Return a list of class names in order."""
        return [self.class_mapping[i]['name'] for i in range(self.num_classes)]

    def get_color_map(self):
        """Return a mapping of class IDs to colors for visualization."""
        return {i: self.class_mapping[i]['color'] for i in range(self.num_classes)}

class EnhancedSegmentationLoss(nn.Module):
    """Enhanced loss function with boundary-aware components for semantic segmentation."""
    def __init__(self, num_classes, ignore_index=255, class_weights=None):
        super().__init__()
        self.num_classes = num_classes
        self.ignore_index = ignore_index

        # Use provided class weights or create default ones
        if class_weights is None:
            # Base weights - all classes equal
            class_weights = torch.ones(num_classes)

            # Enhance weights for rare or difficult classes
            rare_classes = [7, 8, 12, 13, 17, 18]  # traffic light, sign, person, rider, motorcycle, bicycle
            small_objects = [6, 7, 8, 12, 13, 17, 18]  # poles, traffic lights, signs, pedestrians, bikes
            vehicle_classes = [14, 15, 16]  # car, truck, bus

            for class_idx in range(num_classes):
                if class_idx in small_objects:
                    class_weights[class_idx] = 3.0
                elif class_idx in vehicle_classes:
                    class_weights[class_idx] = 2.0
                elif class_idx in rare_classes:
                    class_weights[class_idx] = 2.5

        self.ce_loss = nn.CrossEntropyLoss(weight=class_weights, ignore_index=ignore_index)
        self.smooth = 1e-5

    def get_boundaries(self, tensor):
        """Extract boundaries from semantic masks for boundary-aware loss."""
        boundaries = torch.zeros_like(tensor, dtype=torch.float)
        kernel_sizes = [3, 5, 7, 9]
        weights = [0.4, 0.3, 0.2, 0.1]

        for k_size, weight in zip(kernel_sizes, weights):
            pooled = F.max_pool2d(
                tensor.float(),
                kernel_size=k_size,
                stride=1,
                padding=k_size//2
            )
            boundaries += weight * (pooled != tensor.float()).float()
        return boundaries

    def calculate_iou_loss(self, pred, target):
        """Calculate IoU loss for better boundary prediction."""
        pred = F.softmax(pred, dim=1)
        pred = pred.flatten(2)
        target = F.one_hot(target, num_classes=self.num_classes).permute(0, 3, 1, 2).flatten(2)

        intersection = (pred * target).sum(-1)
        total = (pred + target).sum(-1)
        union = total - intersection
        valid_mask = union > self.smooth
        iou = torch.zeros_like(intersection)
        iou[valid_mask] = (intersection[valid_mask] + self.smooth) / (union[valid_mask] + self.smooth)

        return 1 - iou.mean()

    def forward(self, outputs, targets):
        """Combined loss function with weighted components."""
        ce_loss = self.ce_loss(outputs, targets)
        iou_loss = self.calculate_iou_loss(outputs, targets)

        edges = self.get_boundaries(targets)
        pred_edges = self.get_boundaries(torch.argmax(outputs, dim=1))
        boundary_loss = F.mse_loss(pred_edges, edges)

        total_loss = ce_loss + 0.4 * iou_loss + 0.8 * boundary_loss

        return total_loss, {'ce_loss': ce_loss.item(),
                           'iou_loss': iou_loss.item(),
                           'boundary_loss': boundary_loss.item()}

def train_epoch(model, train_loader, optimizer, scheduler, scaler, criterion, device, epoch):
    """Train model for one epoch."""
    model.train()
    epoch_loss = 0
    batch_losses = []
    component_losses = {'ce_loss': 0, 'iou_loss': 0, 'boundary_loss': 0}

    pbar = tqdm(train_loader, desc=f'Training Epoch {epoch}')
    for batch_idx, batch in enumerate(pbar):
        try:
            pixel_values = batch['pixel_values'].to(device)
            labels = batch['labels'].to(device)

            with torch.amp.autocast(device_type=str(device), dtype=torch.float16):
                outputs = model(pixel_values=pixel_values)
                logits = outputs.logits

                logits = F.interpolate(
                    logits,
                    size=labels.shape[-2:],
                    mode="bilinear",
                    align_corners=False
                )

                loss, loss_components = criterion(logits, labels)
                loss = loss / ACCUMULATION_STEPS

            scaler.scale(loss).backward()

            if (batch_idx + 1) % ACCUMULATION_STEPS == 0:
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP_VALUE)
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()
                scheduler.step()

            # Track losses
            current_loss = loss.item() * ACCUMULATION_STEPS
            epoch_loss += current_loss
            batch_losses.append(current_loss)

            # Track component losses
            for k, v in loss_components.items():
                component_losses[k] += v

            pbar.set_postfix({
                'loss': current_loss,
                'avg_loss': epoch_loss / (batch_idx + 1)
            })

        except Exception as e:
            print(f"Error in batch {batch_idx}: {str(e)}")
            continue

    # Average component losses
    for k in component_losses:
        component_losses[k] /= len(train_loader)

    return epoch_loss / len(train_loader), batch_losses, component_losses

def calculate_metrics(predictions, ground_truths, num_classes, ignore_index=255):
    """Calculate comprehensive segmentation metrics including mIoU."""
    # Flatten tensors for metric calculation
    preds = np.concatenate([p.numpy().flatten() for p in predictions])
    gts = np.concatenate([g.numpy().flatten() for g in ground_truths])

    # Filter out ignored pixels
    valid_idx = gts != ignore_index
    preds = preds[valid_idx]
    gts = gts[valid_idx]

    # Overall pixel accuracy
    pixel_acc = accuracy_score(gts, preds)

    # Class-wise metrics
    class_ious = jaccard_score(gts, preds, average=None, labels=range(num_classes), zero_division=0)

    # Mean IoU
    mean_iou = np.mean(class_ious)

    # Confusion matrix
    conf_mat = confusion_matrix(gts, preds, labels=range(num_classes))

    # Class-wise precision and recall
    precision = np.zeros(num_classes)
    recall = np.zeros(num_classes)
    f1_score = np.zeros(num_classes)

    for i in range(num_classes):
        true_pos = conf_mat[i, i]
        false_pos = conf_mat[:, i].sum() - true_pos
        false_neg = conf_mat[i, :].sum() - true_pos

        # Calculate precision and recall with handling for division by zero
        if true_pos + false_pos > 0:
            precision[i] = true_pos / (true_pos + false_pos)
        if true_pos + false_neg > 0:
            recall[i] = true_pos / (true_pos + false_neg)
        if precision[i] + recall[i] > 0:
            f1_score[i] = 2 * (precision[i] * recall[i]) / (precision[i] + recall[i])

    # Frequency weighted IoU
    class_freq = np.bincount(gts, minlength=num_classes) / len(gts)
    freq_weighted_iou = (class_ious * class_freq).sum()

    metrics = {
        'pixel_accuracy': pixel_acc,
        'mean_iou': mean_iou,
        'class_iou': class_ious.tolist(),
        'precision': precision.tolist(),
        'recall': recall.tolist(),
        'f1_score': f1_score.tolist(),
        'freq_weighted_iou': freq_weighted_iou,
        'confusion_matrix': conf_mat.tolist()
    }

    return metrics

@torch.no_grad()
def validate(model, val_loader, criterion, device, dataset):
    """Validate model on validation set."""
    model.eval()
    val_loss = 0
    component_losses = {'ce_loss': 0, 'iou_loss': 0, 'boundary_loss': 0}
    predictions = []
    ground_truths = []
    image_names = []

    for batch in tqdm(val_loader, desc='Validation'):
        try:
            pixel_values = batch['pixel_values'].to(device)
            labels = batch['labels'].to(device)
            image_names.extend(batch['image_name'])

            outputs = model(pixel_values=pixel_values)
            logits = outputs.logits

            logits = F.interpolate(
                logits,
                size=labels.shape[-2:],
                mode="bilinear",
                align_corners=False
            )

            loss, loss_comps = criterion(logits, labels)
            val_loss += loss.item()

            # Track component losses
            for k, v in loss_comps.items():
                component_losses[k] += v

            preds = torch.argmax(logits, dim=1).cpu()
            predictions.append(preds)
            ground_truths.append(labels.cpu())

        except Exception as e:
            print(f"Error during validation: {str(e)}")
            continue

    # Average component losses
    for k in component_losses:
        component_losses[k] /= len(val_loader)

    # Calculate detailed metrics
    metrics = calculate_metrics(predictions, ground_truths, dataset.num_classes)
    metrics['val_loss'] = val_loss / len(val_loader)
    metrics['component_losses'] = component_losses

    return metrics, predictions, ground_truths, image_names

def visualize_predictions(images, predictions, ground_truths, class_colors, class_names, output_dir, limit=10):
    """Visualize model predictions compared to ground truth."""
    os.makedirs(output_dir, exist_ok=True)

    # Create color maps for visualization
    color_map = np.zeros((max(class_colors.keys()) + 1, 3), dtype=np.uint8)
    for class_id, color in class_colors.items():
        color_map[class_id] = color

    for i, (image_name, pred, gt) in enumerate(zip(images, predictions, ground_truths)):
        if i >= limit:
            break

        # Get shapes and ensure they're 2D
        if len(pred.shape) == 3:
            pred = pred[0]  # Take the first item if batched
        if len(gt.shape) == 3:
            gt = gt[0]  # Take the first item if batched

        # Create colored predictions and ground truth
        pred_colored = np.zeros((pred.shape[0], pred.shape[1], 3), dtype=np.uint8)
        gt_colored = np.zeros((gt.shape[0], gt.shape[1], 3), dtype=np.uint8)

        # Apply color map
        for class_id, color in class_colors.items():
            pred_mask = (pred == class_id)
            gt_mask = (gt == class_id)

            for c in range(3):
                pred_colored[:, :, c][pred_mask] = color[c]
                gt_colored[:, :, c][gt_mask] = color[c]

        # Create visualization grid
        fig, axes = plt.subplots(1, 2, figsize=(12, 6))

        axes[0].imshow(gt_colored)
        axes[0].set_title("Ground Truth")
        axes[0].axis('off')

        axes[1].imshow(pred_colored)
        axes[1].set_title("Prediction")
        axes[1].axis('off')

        plt.tight_layout()
        plt.savefig(os.path.join(output_dir, f"vis_{image_name}"), dpi=200)
        plt.close()

    # Create a legend for class colors
    fig, ax = plt.subplots(figsize=(10, 8))
    patches = []
    for class_id, class_name in enumerate(class_names):
        if class_id in class_colors:
            color = np.array(class_colors[class_id]) / 255.0
            patches.append(plt.Rectangle((0, 0), 1, 1, fc=color, label=class_name))

    ax.legend(handles=patches, loc='center', ncol=2)
    ax.set_axis_off()
    plt.savefig(os.path.join(output_dir, "class_legend.png"), dpi=200)
    plt.close()

def plot_training_curves(train_losses, val_losses, component_losses, metrics, save_dir):
    """Plot training curves and metrics."""
    os.makedirs(save_dir, exist_ok=True)

    # Plot overall losses
    plt.figure(figsize=(12, 6))
    plt.plot(train_losses, label='Training Loss', color='blue', alpha=0.7)
    plt.plot(val_losses, label='Validation Loss', color='red', alpha=0.7)

    window_size = 5
    if len(train_losses) >= window_size:
        train_ma = np.convolve(train_losses, np.ones(window_size)/window_size, mode='valid')
        val_ma = np.convolve(val_losses, np.ones(window_size)/window_size, mode='valid')
        plt.plot(range(window_size-1, len(train_losses)), train_ma,
                '--', color='darkblue', alpha=0.5, label='Train Moving Avg')
        plt.plot(range(window_size-1, len(val_losses)), val_ma,
                '--', color='darkred', alpha=0.5, label='Val Moving Avg')

    plt.title('Training and Validation Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.savefig(os.path.join(save_dir, 'loss_curves.png'), dpi=200)
    plt.close()

    # Plot component losses
    plt.figure(figsize=(12, 6))
    plt.plot([d['ce_loss'] for d in component_losses['train']], label='CE Loss (Train)', color='blue')
    plt.plot([d['iou_loss'] for d in component_losses['train']], label='IoU Loss (Train)', color='green')
    plt.plot([d['boundary_loss'] for d in component_losses['train']], label='Boundary Loss (Train)', color='orange')
    plt.plot([d['ce_loss'] for d in component_losses['val']], label='CE Loss (Val)', color='blue', linestyle='--')
    plt.plot([d['iou_loss'] for d in component_losses['val']], label='IoU Loss (Val)', color='green', linestyle='--')
    plt.plot([d['boundary_loss'] for d in component_losses['val']], label='Boundary Loss (Val)', color='orange', linestyle='--')

    plt.title('Loss Components Over Time')
    plt.xlabel('Epoch')
    plt.ylabel('Loss Value')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.savefig(os.path.join(save_dir, 'component_losses.png'), dpi=200)
    plt.close()

    # Plot metrics
    plt.figure(figsize=(12, 6))
    plt.plot(metrics['mean_iou'], label='Mean IoU', color='blue', marker='o')
    plt.plot(metrics['pixel_accuracy'], label='Pixel Accuracy', color='green', marker='s')
    plt.plot(metrics['freq_weighted_iou'], label='Freq Weighted IoU', color='red', marker='^')

    plt.title('Segmentation Metrics Over Time')
    plt.xlabel('Epoch')
    plt.ylabel('Metric Value')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.ylim(0, 1)
    plt.savefig(os.path.join(save_dir, 'metrics.png'), dpi=200)
    plt.close()

    # Plot per-class IoU
    if len(metrics['class_iou']) > 0:
        class_names = [f"Class {i}" for i in range(len(metrics['class_iou'][0]))]
        last_class_iou = metrics['class_iou'][-1]

        plt.figure(figsize=(12, 8))
        plt.bar(class_names, last_class_iou)
        plt.title('IoU by Class (Final Epoch)')
        plt.xlabel('Class')
        plt.ylabel('IoU')
        plt.xticks(rotation=45, ha='right')
        plt.grid(True, alpha=0.3, axis='y')
        plt.tight_layout()
        plt.savefig(os.path.join(save_dir, 'class_iou.png'), dpi=200)
        plt.close()

def evaluate_transfer_learning_performance():
    """Compare performance of baseline and transfer learning."""
    # Load metrics from the best checkpoints
    transfer_metrics_path = os.path.join(METRICS_DIR, 'transfer/best_metrics.json')
    baseline_metrics_path = os.path.join(METRICS_DIR, 'baseline/best_metrics.json')

    if os.path.exists(transfer_metrics_path) and os.path.exists(baseline_metrics_path):
        with open(transfer_metrics_path, 'r') as f:
            transfer_metrics = json.load(f)

        with open(baseline_metrics_path, 'r') as f:
            baseline_metrics = json.load(f)

        # Create comparison directory
        comparison_dir = os.path.join(VISUALIZATION_DIR, 'comparison')
        os.makedirs(comparison_dir, exist_ok=True)

        # Plot overall metrics comparison
        metrics_to_compare = ['mean_iou', 'pixel_accuracy', 'freq_weighted_iou']
        values_transfer = [transfer_metrics[m] for m in metrics_to_compare]
        values_baseline = [baseline_metrics[m] for m in metrics_to_compare]

        plt.figure(figsize=(10, 6))
        x = np.arange(len(metrics_to_compare))
        width = 0.35

        plt.bar(x - width/2, values_transfer, width, label='With Transfer Learning')
        plt.bar(x + width/2, values_baseline, width, label='From Scratch')

        plt.xticks(x, [m.replace('_', ' ').title() for m in metrics_to_compare])
        plt.ylim(0, 1)
        plt.ylabel('Score')
        plt.title('Performance Comparison: Transfer Learning vs Training from Scratch')
        plt.legend()
        plt.grid(True, alpha=0.3, axis='y')
        plt.savefig(os.path.join(comparison_dir, 'metrics_comparison.png'), dpi=200)
        plt.close()

        # Create per-class IoU comparison
        # class_names = [f"Class {i}" for i in range(len(transfer_metrics['class_iou']))]
        class_names = [KITTI_REDUCED_CLASSES[i]['name'] for i in range(len(transfer_metrics['class_iou']))]


        plt.figure(figsize=(14, 8))
        x = np.arange(len(class_names))
        width = 0.35

        plt.bar(x - width/2, transfer_metrics['class_iou'], width, label='With Transfer Learning')
        plt.bar(x + width/2, baseline_metrics['class_iou'], width, label='From Scratch')

        plt.xticks(x, class_names, rotation=45, ha='right')
        plt.ylim(0, 1)
        plt.ylabel('IoU Score')
        plt.title('Per-Class IoU Comparison: Transfer Learning vs Training from Scratch')
        plt.legend()
        plt.grid(True, alpha=0.3, axis='y')
        plt.tight_layout()
        plt.savefig(os.path.join(comparison_dir, 'class_iou_comparison.png'), dpi=200)
        plt.close()

        # Generate report
        with open(os.path.join(METRICS_DIR, 'performance_comparison.txt'), 'w') as f:
            f.write("# SegFormer Transfer Learning Performance Comparison\n\n")

            f.write("## Overall Metrics\n\n")
            f.write("| Metric | With Transfer Learning | From Scratch | Improvement |\n")
            f.write("|--------|------------------------|--------------|-------------|\n")

            for metric in metrics_to_compare:
                tl_value = transfer_metrics[metric]
                baseline_value = baseline_metrics[metric]
                improvement = tl_value - baseline_value
                improvement_percent = (improvement / baseline_value) * 100 if baseline_value > 0 else float('inf')

                f.write(f"| {metric.replace('_', ' ').title()} | {tl_value:.4f} | {baseline_value:.4f} | {improvement_percent:+.2f}% |\n")

            f.write("\n## Classes with Most Improvement\n\n")
            f.write("| Class | With Transfer Learning | From Scratch | Improvement |\n")
            f.write("|-------|------------------------|--------------|-------------|\n")

            # Calculate improvement for each class
            improvements = []
            for i, (tl, baseline) in enumerate(zip(transfer_metrics['class_iou'], baseline_metrics['class_iou'])):
                if baseline > 0:
                    improvement_percent = (tl - baseline) / baseline * 100
                else:
                    improvement_percent = float('inf') if tl > 0 else 0

                improvements.append((i, tl, baseline, improvement_percent))

            # Sort by improvement (descending)
            improvements.sort(key=lambda x: x[3], reverse=True)

            # Write top 5 most improved classes
            for i, tl, baseline, imp_percent in improvements[:5]:
                f.write(f"| Class {i} | {tl:.4f} | {baseline:.4f} | {imp_percent:+.2f}% |\n")

        print("Performance comparison completed. Results saved to metrics directory.")
    else:
        print("Metrics files not found. Run both baseline and transfer learning experiments first.")

def train_and_evaluate(experiment_type):
    """
    Train and evaluate SegFormer model on KITTI.
    experiment_type: 'baseline' (from scratch) or 'transfer' (using CamVid weights)
    """
    # Set random seeds for reproducibility
    random.seed(42)
    np.random.seed(42)
    torch.manual_seed(42)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(42)

    # Clear CUDA cache before starting
    torch.cuda.empty_cache()

    # Set up experiment directories
    exp_checkpoint_dir = os.path.join(CHECKPOINT_DIR, experiment_type)
    exp_visualization_dir = os.path.join(VISUALIZATION_DIR, experiment_type)
    exp_metrics_dir = os.path.join(METRICS_DIR, experiment_type)

    # Create directories
    for directory in [exp_checkpoint_dir, exp_visualization_dir, exp_metrics_dir]:
        os.makedirs(directory, exist_ok=True)

    # Get device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Create validation split if it doesn't exist
    if not os.path.exists(VAL_DIR) or len(os.listdir(VAL_DIR)) == 0:
        print("Creating validation split...")
        create_validation_split(
            TRAIN_DIR,
            TRAIN_LABELS_DIR,
            VAL_DIR,
            VAL_LABELS_DIR,
            val_ratio=0.2
        )

    try:
        # Initialize feature extractor
        feature_extractor = SegformerImageProcessor.from_pretrained(
            f"nvidia/mit-{MODEL_TYPE}",
            do_reduce_labels=True,
            do_rescale=False,
            size={"height": IMAGE_HEIGHT, "width": IMAGE_WIDTH}
        )

        # Create datasets
        train_dataset = KITTIDataset(
            TRAIN_DIR,
            TRAIN_LABELS_DIR,
            feature_extractor,
            transforms=train_transforms,
            use_reduced_classes=USE_REDUCED_CLASSES
        )

        val_dataset = KITTIDataset(
            VAL_DIR,
            VAL_LABELS_DIR,
            feature_extractor,
            transforms=val_transforms,
            use_reduced_classes=USE_REDUCED_CLASSES
        )

        print(f"Number of classes: {train_dataset.num_classes}")
        print(f"Number of training samples: {len(train_dataset)}")
        print(f"Number of validation samples: {len(val_dataset)}")

        # Create dataloaders
        train_loader = DataLoader(
            train_dataset,
            batch_size=BATCH_SIZE,
            shuffle=True,
            num_workers=4,
            pin_memory=True
        )

        val_loader = DataLoader(
            val_dataset,
            batch_size=BATCH_SIZE,
            shuffle=False,
            num_workers=4,
            pin_memory=True
        )

        # Initialize model
        model = SegformerForSemanticSegmentation.from_pretrained(
            f"nvidia/mit-{MODEL_TYPE}",
            num_labels=train_dataset.num_classes,
            id2label={str(i): str(i) for i in range(train_dataset.num_classes)},
            label2id={str(i): i for i in range(train_dataset.num_classes)},
            ignore_mismatched_sizes=True
        ).to(device)

        # Enable memory efficient attention if available
        if hasattr(model.config, 'use_memory_efficient_attention'):
            model.config.use_memory_efficient_attention = True

        # For transfer learning experiment, load CamVid weights
        if experiment_type == 'transfer' and os.path.exists(CAMVID_CHECKPOINT):
            print(f"Loading weights from CamVid checkpoint: {CAMVID_CHECKPOINT}")
            # Load state dictionary
            checkpoint = torch.load(CAMVID_CHECKPOINT, map_location=device)

            if 'model_state_dict' in checkpoint:
                state_dict = checkpoint['model_state_dict']
            else:
                state_dict = checkpoint  # In case the entire state dict was saved

            # Filter out mismatched keys (especially the classification head)
            model_dict = model.state_dict()
            pretrained_dict = {k: v for k, v in state_dict.items()
                              if k in model_dict and v.shape == model_dict[k].shape}

            # Update model with pretrained weights
            model_dict.update(pretrained_dict)
            model.load_state_dict(model_dict)

            print(f"Loaded {len(pretrained_dict)}/{len(model_dict)} layers from checkpoint")

            # Freeze encoder parameters if specified
            if FREEZE_ENCODER:
                print("Freezing encoder parameters...")
                for name, param in model.segformer.encoder.named_parameters():
                    param.requires_grad = False

                # Only train the decoder (segmentation head)
                trainable_params = sum(p.numel() for p in model.decode_head.parameters() if p.requires_grad)
                total_params = sum(p.numel() for p in model.parameters())
                print(f"Trainable parameters: {trainable_params:,} ({trainable_params/total_params:.2%})")

        # Initialize optimizer and scheduler
        optimizer = torch.optim.AdamW(
            filter(lambda p: p.requires_grad, model.parameters()),
            lr=LEARNING_RATE,
            weight_decay=WEIGHT_DECAY
        )

        total_steps = len(train_loader) * NUM_EPOCHS // ACCUMULATION_STEPS
        scheduler = OneCycleLR(
            optimizer,
            max_lr=LEARNING_RATE,
            total_steps=total_steps,
            pct_start=0.1
        )

        # Initialize criterion
        criterion = EnhancedSegmentationLoss(train_dataset.num_classes).to(device)

        # Initialize AMP scaler
        scaler = torch.amp.GradScaler()

        # Training loop
        train_losses = []
        val_losses = []
        all_metrics = []
        component_losses = {'train': [], 'val': []}
        best_val_miou = 0

        for epoch in range(NUM_EPOCHS):
            print(f"\nEpoch {epoch+1}/{NUM_EPOCHS}")

            # Training
            train_loss, batch_losses, train_comp_losses = train_epoch(
                model, train_loader, optimizer, scheduler, scaler, criterion, device, epoch
            )
            train_losses.append(train_loss)
            component_losses['train'].append(train_comp_losses)

            # Validation
            val_metrics, predictions, ground_truths, image_names = validate(
                model, val_loader, criterion, device, val_dataset
            )
            val_losses.append(val_metrics['val_loss'])
            component_losses['val'].append(val_metrics['component_losses'])
            all_metrics.append(val_metrics)

            print(f"Train Loss: {train_loss:.4f}")
            print(f"Val Loss: {val_metrics['val_loss']:.4f}")
            print(f"Val Mean IoU: {val_metrics['mean_iou']:.4f}")
            print(f"Val Pixel Accuracy: {val_metrics['pixel_accuracy']:.4f}")

            # Generate visualizations for this epoch
            epoch_vis_dir = os.path.join(exp_visualization_dir, f'epoch_{epoch+1}')
            visualize_predictions(
                image_names[:10],  # Use first 10 images
                predictions[:10],
                ground_truths[:10],
                val_dataset.get_color_map(),
                val_dataset.get_class_names(),
                epoch_vis_dir
            )

            # Save checkpoint
            if val_metrics['mean_iou'] > best_val_miou:
                best_val_miou = val_metrics['mean_iou']
                checkpoint = {
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'scheduler_state_dict': scheduler.state_dict(),
                    'val_loss': val_metrics['val_loss'],
                    'val_miou': val_metrics['mean_iou'],
                    'config': {
                        'model_type': MODEL_TYPE,
                        'image_size': (IMAGE_HEIGHT, IMAGE_WIDTH),
                        'num_classes': train_dataset.num_classes,
                        'class_names': val_dataset.get_class_names(),
                        'use_reduced_classes': USE_REDUCED_CLASSES,
                        'transfer_learning': experiment_type == 'transfer',
                        'freeze_encoder': FREEZE_ENCODER and experiment_type == 'transfer'
                    }
                }
                checkpoint_path = os.path.join(
                    exp_checkpoint_dir,
                    f'best_model_miou_{val_metrics["mean_iou"]:.4f}.pth'
                )
                torch.save(checkpoint, checkpoint_path)

                # Also save the best metrics for later comparison
                with open(os.path.join(exp_metrics_dir, 'best_metrics.json'), 'w') as f:
                    json.dump(val_metrics, f, indent=2)

                print(f"New best model saved! Val Mean IoU: {val_metrics['mean_iou']:.4f}")

            # Plot training curves
            plot_training_curves(
                train_losses,
                val_losses,
                component_losses,
                {
                    'mean_iou': [m['mean_iou'] for m in all_metrics],
                    'pixel_accuracy': [m['pixel_accuracy'] for m in all_metrics],
                    'freq_weighted_iou': [m['freq_weighted_iou'] for m in all_metrics],
                    'class_iou': [m['class_iou'] for m in all_metrics]
                },
                exp_visualization_dir
            )

            # Save metrics for this epoch
            with open(os.path.join(exp_metrics_dir, f'metrics_epoch_{epoch+1}.json'), 'w') as f:
                json.dump(val_metrics, f, indent=2)

            # Clear cache after each epoch
            torch.cuda.empty_cache()

        print(f"Training completed for {experiment_type} experiment!")

        # Save final model
        final_checkpoint = {
            'model_state_dict': model.state_dict(),
            'config': {
                'model_type': MODEL_TYPE,
                'image_size': (IMAGE_HEIGHT, IMAGE_WIDTH),
                'num_classes': train_dataset.num_classes,
                'class_names': val_dataset.get_class_names(),
                'use_reduced_classes': USE_REDUCED_CLASSES,
                'transfer_learning': experiment_type == 'transfer'
            }
        }
        torch.save(
            final_checkpoint,
            os.path.join(exp_checkpoint_dir, f'final_model.pth')
        )

    except Exception as e:
        print(f"Error during training: {str(e)}")
        import traceback
        traceback.print_exc()

# Functions for running experiments in Jupyter
def run_baseline():
    """Run baseline experiment (train from scratch)."""
    print("Running baseline experiment (SegFormer on KITTI from scratch)")
    train_and_evaluate('baseline')

def run_transfer():
    """Run transfer learning experiment (using CamVid weights)."""
    print("Running transfer learning experiment (CamVid â†’ KITTI)")
    train_and_evaluate('transfer')

def run_both():
    """Run both experiments sequentially and compare."""
    print("Running baseline experiment (SegFormer on KITTI from scratch)")
    train_and_evaluate('baseline')
    print("\n\nRunning transfer learning experiment (CamVid â†’ KITTI)")
    train_and_evaluate('transfer')
    print("\n\nGenerating comparison reports")
    evaluate_transfer_learning_performance()

def run_comparison():
    """Just compare existing results."""
    print("Generating comparison reports")
    evaluate_transfer_learning_performance()

# If running as a script (not in Jupyter)
# Comment out the entire main block
'''
if __name__ == '__main__':
    import sys

    if len(sys.argv) > 1:
        mode = sys.argv[1]
        if mode == 'baseline':
            run_baseline()
        elif mode == 'transfer':
            run_transfer()
        elif mode == 'compare':
            run_comparison()
        else:
            run_both()
    else:
        run_both()
'''

# Add explicit call to run_transfer at the bottom
print("About to start transfer learning...")
run_comparison()

About to start transfer learning...
Generating comparison reports
Performance comparison completed. Results saved to metrics directory.


In [4]:
import torch
import cv2
import numpy as np
import matplotlib.pyplot as plt
import os
from transformers import SegformerForSemanticSegmentation, SegformerImageProcessor
import albumentations as A
from albumentations.pytorch import ToTensorV2

def get_transform(image_height, image_width):
    return A.Compose([
        A.Resize(height=image_height, width=image_width),
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2()
    ])

def register_hooks(model, activations, gradients):
    def forward_hook(module, input, output):
        activations.append(output)

    def backward_hook(module, grad_input, grad_output):
        gradients.append(grad_output[0])

    handle_f = model.decode_head.classifier.register_forward_hook(forward_hook)
    handle_b = model.decode_head.classifier.register_backward_hook(backward_hook)
    return handle_f, handle_b

def generate_gradcam(model, input_tensor, class_index, activations, gradients):
    output = model(pixel_values=input_tensor)
    logits = output.logits  # (1, C, H, W)

    model.zero_grad()
    logits[0, class_index].mean().backward()

    act = activations[0].detach().cpu()[0]  # [C, H, W]
    grad = gradients[0].detach().cpu()[0]   # [C, H, W]
    weights = grad.mean(dim=(1, 2))         # GAP

    cam = (weights[:, None, None] * act).sum(dim=0)
    cam = torch.relu(cam)
    cam -= cam.min()
    cam /= cam.max()
    return cam.numpy()

def overlay_heatmap(original_img, cam):
    cam_resized = cv2.resize(cam, (original_img.shape[1], original_img.shape[0]))
    heatmap = cv2.applyColorMap(np.uint8(255 * cam_resized), cv2.COLORMAP_JET)
    overlay = cv2.addWeighted(original_img, 0.5, heatmap, 0.5, 0)
    return overlay

def run_gradcam(image_path, checkpoint_path, model_type='b3', image_height=1024, image_width=1024, class_index=None, output_path='gradcam_output.png'):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load model config
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)

    model = SegformerForSemanticSegmentation.from_pretrained(
        f"nvidia/mit-{model_type}",
        num_labels=checkpoint['config']['num_classes'],
        ignore_mismatched_sizes=True
    )
    model.load_state_dict(checkpoint['model_state_dict'])  # Load your trained weights
    model.to(device).eval()

    # Load and transform image
    img = cv2.imread(image_path)
    img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    transform = get_transform(image_height, image_width)
    transformed = transform(image=img_rgb)
    input_tensor = transformed['image'].unsqueeze(0).to(device)

    # Register hooks
    activations, gradients = [], []
    handle_f, handle_b = register_hooks(model, activations, gradients)

    # Run forward pass and Grad-CAM
    with torch.no_grad():
        logits = model(pixel_values=input_tensor).logits

    # Safe spatial indexing
    _, _, H_out, W_out = logits.shape
    class_index = logits[0, :, H_out // 2, W_out // 2].argmax().item()

    cam = generate_gradcam(model, input_tensor, class_index, activations, gradients)

    # Generate and save visualization
    overlay = overlay_heatmap(img, cam)
    cv2.imwrite(output_path, overlay)

    # Cleanup
    handle_f.remove()
    handle_b.remove()

    print(f"Grad-CAM saved to: {output_path}")

# Example usage:
# run_gradcam(
#     image_path="PATH/TO/YOUR/TEST_IMAGE.png",
#     checkpoint_path="PATH/TO/best_model_miou.pth",
#     output_path="PATH/TO/save/gradcam.png"
# )


In [5]:
def run_multiclass_gradcam(image_path, checkpoint_path, model_type='b3', image_height=1024, image_width=1024, output_dir='gradcam_outputs'):
    os.makedirs(output_dir, exist_ok=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
    num_classes = checkpoint['config']['num_classes']

    model = SegformerForSemanticSegmentation.from_pretrained(
        f"nvidia/mit-{model_type}",
        num_labels=num_classes,
        ignore_mismatched_sizes=True
    )
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device).eval()

    img = cv2.imread(image_path)
    img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    transform = get_transform(image_height, image_width)
    transformed = transform(image=img_rgb)
    input_tensor = transformed['image'].unsqueeze(0).to(device)

    activations, gradients = [], []
    handle_f, handle_b = register_hooks(model, activations, gradients)

    # Run a forward pass once
    with torch.no_grad():
        logits = model(pixel_values=input_tensor).logits

    for class_index in range(num_classes):
        # Reset hooks storage for clean gradients
        activations.clear()
        gradients.clear()

        cam = generate_gradcam(model, input_tensor, class_index, activations, gradients)
        overlay = overlay_heatmap(img, cam)

        class_names = checkpoint['config'].get('class_names', [f"Class_{i}" for i in range(num_classes)])
        out_path = os.path.join(output_dir, f"{class_names[class_index]}_gradcam.png")

        cv2.imwrite(out_path, overlay)
        print(f"Saved: {out_path}")

    handle_f.remove()
    handle_b.remove()


In [6]:
run_multiclass_gradcam(
    image_path="KITTI/testing/image_2/000135_10.png",
    checkpoint_path="best_model_miou_0.5342.pth",
    output_dir="output/gradcam_visualization.png"
)


Some weights of SegformerForSemanticSegmentation were not initialized from the model checkpoint at nvidia/mit-b3 and are newly initialized: ['decode_head.batch_norm.bias', 'decode_head.batch_norm.num_batches_tracked', 'decode_head.batch_norm.running_mean', 'decode_head.batch_norm.running_var', 'decode_head.batch_norm.weight', 'decode_head.classifier.bias', 'decode_head.classifier.weight', 'decode_head.linear_c.0.proj.bias', 'decode_head.linear_c.0.proj.weight', 'decode_head.linear_c.1.proj.bias', 'decode_head.linear_c.1.proj.weight', 'decode_head.linear_c.2.proj.bias', 'decode_head.linear_c.2.proj.weight', 'decode_head.linear_c.3.proj.bias', 'decode_head.linear_c.3.proj.weight', 'decode_head.linear_fuse.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)


Saved: output/gradcam_visualization.png\background_gradcam.png
Saved: output/gradcam_visualization.png\road_gradcam.png
Saved: output/gradcam_visualization.png\sidewalk_gradcam.png
Saved: output/gradcam_visualization.png\building_gradcam.png
Saved: output/gradcam_visualization.png\wall_gradcam.png
Saved: output/gradcam_visualization.png\fence_gradcam.png
Saved: output/gradcam_visualization.png\pole_gradcam.png
Saved: output/gradcam_visualization.png\traffic light_gradcam.png
Saved: output/gradcam_visualization.png\traffic sign_gradcam.png
Saved: output/gradcam_visualization.png\vegetation_gradcam.png
Saved: output/gradcam_visualization.png\terrain_gradcam.png
Saved: output/gradcam_visualization.png\sky_gradcam.png
Saved: output/gradcam_visualization.png\person_gradcam.png
Saved: output/gradcam_visualization.png\rider_gradcam.png
Saved: output/gradcam_visualization.png\car_gradcam.png
Saved: output/gradcam_visualization.png\truck_gradcam.png
Saved: output/gradcam_visualization.png\bus_g

  heatmap = cv2.applyColorMap(np.uint8(255 * cam_resized), cv2.COLORMAP_JET)


Saved: output/gradcam_visualization.png\motorcycle_gradcam.png
Saved: output/gradcam_visualization.png\bicycle_gradcam.png


In [36]:
import torch
import cv2
import numpy as np
import matplotlib.pyplot as plt
import os
from transformers import SegformerForSemanticSegmentation
import albumentations as A
from albumentations.pytorch import ToTensorV2
from PIL import Image, ImageDraw, ImageFont

def get_transform(image_height, image_width):
    return A.Compose([
        A.Resize(height=image_height, width=image_width),
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2()
    ])

def register_hooks(model, activations, gradients):
    def forward_hook(module, input, output):
        activations.append(output)

    def backward_hook(module, grad_input, grad_output):
        gradients.append(grad_output[0])

    handle_f = model.decode_head.classifier.register_forward_hook(forward_hook)
    handle_b = model.decode_head.classifier.register_backward_hook(backward_hook)
    return handle_f, handle_b

def generate_gradcam(model, input_tensor, class_index, activations, gradients):
    output = model(pixel_values=input_tensor)
    logits = output.logits

    model.zero_grad()
    logits[0, class_index].mean().backward()

    act = activations[0].detach().cpu()[0]
    grad = gradients[0].detach().cpu()[0]
    weights = grad.mean(dim=(1, 2))

    cam = (weights[:, None, None] * act).sum(dim=0)
    cam = torch.relu(cam)
    cam -= cam.min()
    cam /= (cam.max() + 1e-5)
    return cam.numpy()

def apply_class_colormap(cam, color_map_id):
    cam_uint8 = np.uint8(255 * cam)
    heatmap = cv2.applyColorMap(cam_uint8, color_map_id)
    return heatmap

def create_label_legend(class_names, colormaps, save_path):
    num_classes = len(class_names)
    img_height = 30 * num_classes + 10
    legend_img = Image.new("RGB", (300, img_height), color=(255, 255, 255))
    draw = ImageDraw.Draw(legend_img)
    
    try:
        font = ImageFont.truetype("arial.ttf", size=18)
    except:
        font = ImageFont.load_default()

    for i, name in enumerate(class_names):
        color_rgb = tuple(int(255 * c) for c in plt.cm.get_cmap('tab20')(i % 20)[:3])
        draw.rectangle([(10, i * 30 + 10), (30, i * 30 + 30)], fill=color_rgb)
        draw.text((40, i * 30 + 10), name, fill=(0, 0, 0), font=font)

    legend_img.save(save_path)
    print(f"âœ… Label legend saved at: {save_path}")

def run_combined_multiclass_gradcam(image_path, checkpoint_path, model_type='b3', image_height=1024, image_width=1024, output_path='output/combined_gradcam.png'):
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
    num_classes = checkpoint['config']['num_classes']
    class_names = checkpoint['config'].get('class_names', [f"Class_{i}" for i in range(num_classes)])

    model = SegformerForSemanticSegmentation.from_pretrained(
        f"nvidia/mit-{model_type}",
        num_labels=num_classes,
        ignore_mismatched_sizes=True
    )
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device).eval()

    img = cv2.imread(image_path)
    img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    transform = get_transform(image_height, image_width)
    transformed = transform(image=img_rgb)
    input_tensor = transformed['image'].unsqueeze(0).to(device)

    activations, gradients = [], []
    handle_f, handle_b = register_hooks(model, activations, gradients)

    colormaps = [
        cv2.COLORMAP_JET, cv2.COLORMAP_HOT, cv2.COLORMAP_COOL, cv2.COLORMAP_OCEAN,
        cv2.COLORMAP_SUMMER, cv2.COLORMAP_AUTUMN, cv2.COLORMAP_WINTER, cv2.COLORMAP_SPRING,
        cv2.COLORMAP_PINK, cv2.COLORMAP_BONE
    ]

    combined_heatmap = np.zeros_like(img_rgb, dtype=np.float32)

    for class_index in range(num_classes):
        activations.clear()
        gradients.clear()

        cam = generate_gradcam(model, input_tensor, class_index, activations, gradients)
        cam_resized = cv2.resize(cam, (img.shape[1], img.shape[0]))
        heatmap = apply_class_colormap(cam_resized, colormaps[class_index % len(colormaps)])
        heatmap_rgb = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0

        # Boost and accumulate for final composite
        combined_heatmap += 1.5 * heatmap_rgb

        # ðŸ’¾ Save per-class heatmap
        class_overlay = (0.5 * img_rgb / 255.0 + 0.5 * heatmap_rgb)
        class_overlay = np.clip(class_overlay * 255, 0, 255).astype(np.uint8)
        class_name = class_names[class_index].replace(" ", "_").lower()
        class_out_path = os.path.join(os.path.dirname(output_path), f"gradcam_{class_name}.png")
        cv2.imwrite(class_out_path, cv2.cvtColor(class_overlay, cv2.COLOR_RGB2BGR))
        print(f"ðŸ“Œ Saved individual Grad-CAM for: {class_name}")


    combined_heatmap = np.clip(combined_heatmap, 0, 1)
    overlay = (0.5 * img_rgb / 255.0 + 0.5 * combined_heatmap)
    overlay = np.clip(overlay * 255, 0, 255).astype(np.uint8)

    cv2.imwrite(output_path, cv2.cvtColor(overlay, cv2.COLOR_RGB2BGR))
    print(f"âœ… Combined Grad-CAM saved at: {output_path}")

    label_path = output_path.replace(".png", "_labels.png")
    create_label_legend(class_names, colormaps, label_path)

    handle_f.remove()
    handle_b.remove()


# Example usage
# run_combined_multiclass_gradcam(
#     image_path="KITTI/testing/image_2/000089_10.png",
#     checkpoint_path="best_model_miou_0.5342.pth",
#     output_path="output/combined_gradcam.png"
# )


In [38]:
run_combined_multiclass_gradcam(
    image_path="KITTI/testing/image_2/000089_10.png",
    checkpoint_path="best_model_miou_0.5342.pth",
    output_path="output/gradcam_visualization_combined.png"
)

Some weights of SegformerForSemanticSegmentation were not initialized from the model checkpoint at nvidia/mit-b3 and are newly initialized: ['decode_head.batch_norm.bias', 'decode_head.batch_norm.num_batches_tracked', 'decode_head.batch_norm.running_mean', 'decode_head.batch_norm.running_var', 'decode_head.batch_norm.weight', 'decode_head.classifier.bias', 'decode_head.classifier.weight', 'decode_head.linear_c.0.proj.bias', 'decode_head.linear_c.0.proj.weight', 'decode_head.linear_c.1.proj.bias', 'decode_head.linear_c.1.proj.weight', 'decode_head.linear_c.2.proj.bias', 'decode_head.linear_c.2.proj.weight', 'decode_head.linear_c.3.proj.bias', 'decode_head.linear_c.3.proj.weight', 'decode_head.linear_fuse.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


ðŸ“Œ Saved individual Grad-CAM for: background
ðŸ“Œ Saved individual Grad-CAM for: road
ðŸ“Œ Saved individual Grad-CAM for: sidewalk
ðŸ“Œ Saved individual Grad-CAM for: building
ðŸ“Œ Saved individual Grad-CAM for: wall
ðŸ“Œ Saved individual Grad-CAM for: fence
ðŸ“Œ Saved individual Grad-CAM for: pole
ðŸ“Œ Saved individual Grad-CAM for: traffic_light
ðŸ“Œ Saved individual Grad-CAM for: traffic_sign
ðŸ“Œ Saved individual Grad-CAM for: vegetation
ðŸ“Œ Saved individual Grad-CAM for: terrain
ðŸ“Œ Saved individual Grad-CAM for: sky
ðŸ“Œ Saved individual Grad-CAM for: person
ðŸ“Œ Saved individual Grad-CAM for: rider
ðŸ“Œ Saved individual Grad-CAM for: car
ðŸ“Œ Saved individual Grad-CAM for: truck
ðŸ“Œ Saved individual Grad-CAM for: bus
ðŸ“Œ Saved individual Grad-CAM for: motorcycle
ðŸ“Œ Saved individual Grad-CAM for: bicycle
âœ… Combined Grad-CAM saved at: output/gradcam_visualization_combined.png


  color_rgb = tuple(int(255 * c) for c in plt.cm.get_cmap('tab20')(i % 20)[:3])


âœ… Label legend saved at: output/gradcam_visualization_combined_labels.png


In [2]:
import torch
import cv2
import numpy as np
import matplotlib.pyplot as plt
from transformers import SegformerForSemanticSegmentation
import albumentations as A
from albumentations.pytorch import ToTensorV2

# Resize and normalize transforms
def get_transform(image_height, image_width):
    return A.Compose([
        A.Resize(height=image_height, width=image_width),
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2()
    ])

# Extract attention rollout from MiT encoder
def get_attention_rollout(model, input_tensor, image_size):
    attention_maps = []

    def hook(module, input, output):
        attention = module.attention_scores.detach().cpu()
        attention_maps.append(attention)

    handles = []
    for blk in model.segformer.encoder.block:
        handles.append(blk.attn.register_forward_hook(hook))

    _ = model(pixel_values=input_tensor)

    for h in handles:
        h.remove()

    # Stack attention maps and average heads: [layers, heads, N, N] -> [layers, N, N]
    attn = torch.stack(attention_maps)[:, 0]  # remove batch dim
    attn = attn.mean(dim=1)

    # Add identity and normalize
    eye = torch.eye(attn.size(-1)).to(attn.device)
    attn = attn + eye
    attn = attn / attn.sum(dim=-1, keepdim=True)

    rollout = attn[0]
    for i in range(1, attn.shape[0]):
        rollout = attn[i] @ rollout

    cls_attn = rollout[0, 1:]  # exclude CLS token
    num_patches = cls_attn.shape[0]
    h = w = int(num_patches ** 0.5)
    cls_attn_map = cls_attn.reshape(h, w)

    # Resize to full image size
    cls_attn_map = cv2.resize(cls_attn_map.numpy(), image_size, interpolation=cv2.INTER_LINEAR)
    cls_attn_map = (cls_attn_map - cls_attn_map.min()) / (cls_attn_map.max() - cls_attn_map.min())
    return cls_attn_map

# Top-level function: load model + image + save overlay
def visualize_rollout_on_image(image_path, checkpoint_path, model_type="b3",
                               image_height=1024, image_width=1024,
                               output_path="attention_rollout.png"):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)

    # Load model and restore trained weights
    model = SegformerForSemanticSegmentation.from_pretrained(
        f"nvidia/mit-{model_type}",
        num_labels=checkpoint['config']['num_classes'],
        ignore_mismatched_sizes=True
    )
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device).eval()

    # Load and preprocess image
    img_bgr = cv2.imread(image_path)
    img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
    transform = get_transform(image_height, image_width)
    transformed = transform(image=img_rgb)
    input_tensor = transformed['image'].unsqueeze(0).to(device)

    # Compute attention rollout
    rollout_map = get_attention_rollout(model, input_tensor, (image_width, image_height))

    # Create overlay
    heatmap = cv2.applyColorMap(np.uint8(255 * rollout_map), cv2.COLORMAP_JET)
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
    overlay = cv2.addWeighted(img_rgb, 0.5, heatmap, 0.5, 0)

    # Save
    cv2.imwrite(output_path, cv2.cvtColor(overlay, cv2.COLOR_RGB2BGR))
    print(f"âœ… Saved attention rollout to: {output_path}")


  from .autonotebook import tqdm as notebook_tqdm


In [3]:
visualize_rollout_on_image(
    image_path="KITTI/testing/image_2/000093_10.png",
    checkpoint_path="checkpoints/transfer/best_model_miou.pth",
    output_path="output/attention_rollout.png"
)


FileNotFoundError: [Errno 2] No such file or directory: 'checkpoints/transfer/best_model_miou.pth'