# Mask R-CNN Inference Visualization Pipeline

This notebook visualizes **every stage** of our custom Mask R-CNN during inference:

## Pipeline Stages Visualized:
1. **Input Image** - Original image preprocessing
2. **EfficientNet Backbone** - Feature extraction at multiple scales (C2-C5)
3. **CBAM Attention** - Channel and Spatial attention at each backbone stage
4. **Feature Pyramid Network (FPN)** - Multi-scale feature fusion (P2-P5)
5. **FPN Attention** - CBAM attention on FPN outputs
6. **Region Proposal Network (RPN)** - Anchor-based proposals
7. **RoI Align** - Feature extraction with bilinear sampling
8. **Box Head** - FC layers for classification and regression
9. **Mask Head** - Convolutional layers for segmentation
10. **Grad-CAM Heatmaps** - Class activation maps showing "where the model looks"
11. **Final Predictions** - Boxes, classes, and masks

---

## 1. Setup and Configuration

In [None]:
import sys
import os
from pathlib import Path

# Add project root to path
PROJECT_ROOT = Path(os.getcwd()).parent if 'notebooks' in os.getcwd() else Path(os.getcwd())
sys.path.insert(0, str(PROJECT_ROOT))

print(f"Project root: {PROJECT_ROOT}")

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import cv2

# Render matplotlib figures inline (static images) in the notebook
%matplotlib inline
plt.rcParams['figure.figsize'] = (16, 10)
plt.rcParams['figure.dpi'] = 100

# Check device
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")
if device == 'cuda':
    print(f"GPU: {torch.cuda.get_device_name(0)}")

In [None]:
# Import visualization pipeline
from visualization.gradcam_pipeline import (
    MaskRCNNVisualizationPipeline,
    GradCAMConfig,
    load_pipeline,
    denormalize_image,
    overlay_heatmap,
    visualize_feature_maps,
    visualize_fpn_features,
    visualize_roi_align_grid,
    visualize_box_head_features,
    visualize_mask_head_stages,
    visualize_final_predictions,
    ISAID_CLASS_LABELS,
    ISAID_COLORS,
)

from models.maskrcnn_model import get_custom_maskrcnn

print("Visualization pipeline imported successfully")

In [None]:
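# Run this cell before importing from `visualization` or `models`: both packages live in this repository
!git clone https://github.com/michaelo-ponteski/isaid-instance-segmentation.git
%cd /kaggle/working/isaid-instance-segmentation
!git pull
!git switch gradcam
!pip install --upgrade wandb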

In [None]:
import wandb
from kaggle_secrets import UserSecretsClient

user_secrets = UserSecretsClient()
my_secret = user_secrets.get_secret("wandb_key")
wandb.login(key=my_secret)

In [None]:
run = wandb.init()
artifact = run.use_artifact('marek-olnk-put-pozna-/isaid-custom-segmentation/isaid-model:v21', type='model')
artifact_dir = artifact.download()

## 2. Configuration

Set your model checkpoint path and image path here:

In [None]:
# =============================================================================
# CONFIGURATION 
# =============================================================================

# Path to your trained model checkpoint
CHECKPOINT_PATH = artifact_dir

# Path to an image for inference (can be from validation set)
IMAGE_PATH = "data/val/images/sample.png"  # <-- CHANGE THIS

# Or use a sample from the dataset
USE_DATASET_SAMPLE = True  # Set to False to load IMAGE_PATH instead
DATASET_ROOT = "data"  # Path to iSAID dataset
SAMPLE_INDEX = 0  # Which sample to visualize

# Model configuration (must match training)
NUM_CLASSES = 16  # iSAID has 15 classes + background

# Visualization settings
CONF_THRESHOLD = 0.5  # Confidence threshold for predictions
OUTPUT_DIR = "./gradcam_outputs"  # Where to save visualizations

print("Configuration set!")
print(f"  Checkpoint: {CHECKPOINT_PATH}")
print(f"  Output directory: {OUTPUT_DIR}")

## 3. Load Model and Create Pipeline

In [None]:
# Create visualization configuration
config = GradCAMConfig(
    device=device,
    conf_threshold=CONF_THRESHOLD,
    output_dir=OUTPUT_DIR,
    colormap='jet',
    alpha_overlay=0.5,
)

# Create model
print("Creating model...")
model = get_custom_maskrcnn(
    num_classes=NUM_CLASSES,
    pretrained_backbone=False,  # We'll load trained weights
)

# Create pipeline
pipeline = MaskRCNNVisualizationPipeline(model, config)

print(f"\nModel architecture:")
print(f"  - Backbone: EfficientNet-B0 with CBAM attention")
print(f"  - FPN: Custom with attention modules")
print(f"  - Number of classes: {NUM_CLASSES}")

In [None]:
# Load checkpoint
checkpoint_path = PROJECT_ROOT / CHECKPOINT_PATH

if checkpoint_path.exists():
    pipeline.load_model_weights(str(checkpoint_path))
    print("Model weights loaded successfully!")
else:
    print(f"Checkpoint not found at: {checkpoint_path}")
    print("Using randomly initialized weights for demonstration.")
    print("\nTo use trained weights, update CHECKPOINT_PATH in the configuration cell.")

## 4. Load Test Image

In [None]:
# Load image
if USE_DATASET_SAMPLE:
    # Load from dataset
    try:
        from datasets.isaid_dataset import get_isaid_dataset
        from training.transforms import get_transform
        
        val_dataset = get_isaid_dataset(
            root=str(PROJECT_ROOT / DATASET_ROOT),
            split='val',
            transforms=get_transform(train=False),
        )
        
        # Get sample
        image_tensor, target = val_dataset[SAMPLE_INDEX]
        
        # Denormalize for visualization
        image_np = denormalize_image(image_tensor)
        
        print(f"Loaded sample {SAMPLE_INDEX} from validation set")
        print(f"  Image shape: {image_tensor.shape}")
        print(f"  Number of GT objects: {len(target['boxes'])}")
        
        # Show ground truth classes
        gt_classes = [ISAID_CLASS_LABELS[l.item()] for l in target['labels']]
        print(f"  GT classes: {gt_classes}")
        
    except Exception as e:
        print(f"Could not load from dataset: {e}")
        print("Falling back to image path...")
        USE_DATASET_SAMPLE = False

if not USE_DATASET_SAMPLE:
    # Load from file
    image_path = PROJECT_ROOT / IMAGE_PATH
    
    if image_path.exists():
        image_np = np.array(Image.open(image_path).convert('RGB'))
        print(f"Loaded image from {image_path}")
        print(f"  Image shape: {image_np.shape}")
    else:
        # Create a dummy image for demonstration
        print(f"Image not found at {image_path}")
        print("Creating a dummy test image...")
        image_np = np.random.randint(0, 255, (800, 800, 3), dtype=np.uint8)

# Display the input image
plt.figure(figsize=(10, 10))
plt.imshow(image_np)
plt.title('Input Image', fontsize=14)
plt.axis('off')
plt.show()
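
The dataset branch above relies on `denormalize_image` to turn a normalized tensor back into a displayable image. As a rough sketch of what such a step usually does, the snippet below undoes a per-channel `(x - mean) / std` normalization. The ImageNet statistics and the name `denormalize_sketch` are assumptions for illustration; the project's own transform may use other values or skip normalization entirely.

```python
import numpy as np

# Assumed ImageNet statistics; the project's transforms may use different values.
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)

def denormalize_sketch(chw, mean=IMAGENET_MEAN, std=IMAGENET_STD):
    """Map a normalized (3, H, W) float array back to an (H, W, 3) uint8 image."""
    hwc = np.transpose(chw, (1, 2, 0))   # channels-last for plotting
    hwc = hwc * std + mean               # undo (x - mean) / std
    return (np.clip(hwc, 0.0, 1.0) * 255.0).round().astype(np.uint8)

# Round trip on a small synthetic image
rgb = np.random.default_rng(0).integers(0, 256, size=(4, 5, 3), dtype=np.uint8)
normalized = ((rgb / 255.0 - IMAGENET_MEAN) / IMAGENET_STD).transpose(2, 0, 1).astype(np.float32)
restored = denormalize_sketch(normalized)
```

If the round trip does not reproduce the original pixels on real data, the mean and std used here probably differ from the ones applied by the dataset transform.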

---

# Stage-by-Stage Visualization

Now we'll visualize each stage of the inference pipeline step by step.

## Stage 1: Preprocessing & Feature Extraction Setup

In [None]:
# Preprocess image and run inference with feature extraction
print("Running inference with feature extraction hooks...")

# Preprocess
image_tensor, image_display = pipeline.preprocess_image(image_np)
print(f"Preprocessed tensor shape: {image_tensor.shape}")

# Run inference with hooks
results = pipeline.run_inference_with_gradients()

predictions = results['predictions']
fpn_features = results['fpn_features']
extracted = results['extracted_features']

print(f"\n Inference complete!")
print(f"   Detections found: {len(predictions['boxes'])}")
print(f"   Extracted features: {len(extracted)} tensors")
print(f"   FPN levels: {list(fpn_features.keys())}")

In [None]:
# List all extracted features
print("\n Extracted Features Summary:")
print("=" * 60)

for name, tensor in sorted(extracted.items()):
    if isinstance(tensor, torch.Tensor):
        shape = list(tensor.shape)
        print(f"  {name:40s} -> {shape}")

## Stage 2: EfficientNet Backbone Features

In [None]:
# Visualize backbone stages
backbone_stages = {
    'backbone_stage_3': 'C2 (After MBConv1)',
    'backbone_stage_4': 'C3 (After MBConv2)', 
    'backbone_stage_5': 'C4 (After MBConv3)',
    'backbone_stage_7': 'C5 (Final Features)',
}

fig, axes = plt.subplots(2, 4, figsize=(20, 10))
axes = axes.flatten()

ax_idx = 0
for stage_key, stage_name in backbone_stages.items():
    if stage_key in extracted:
        feat = extracted[stage_key]
        if feat.dim() == 4:
            feat = feat[0]
        
        # Mean activation
        feat_mean = feat.mean(dim=0).cpu().numpy()
        feat_mean = (feat_mean - feat_mean.min()) / (feat_mean.max() - feat_mean.min() + 1e-8)
        
        axes[ax_idx].imshow(feat_mean, cmap='viridis')
        axes[ax_idx].set_title(f'{stage_name}\nMean ({feat.shape[0]} channels)', fontsize=10)
        axes[ax_idx].axis('off')
        ax_idx += 1
        
        # Max activation
        feat_max = feat.max(dim=0)[0].cpu().numpy()
        feat_max = (feat_max - feat_max.min()) / (feat_max.max() - feat_max.min() + 1e-8)
        
        axes[ax_idx].imshow(feat_max, cmap='hot')
        axes[ax_idx].set_title(f'{stage_name}\nMax Activation', fontsize=10)
        axes[ax_idx].axis('off')
        ax_idx += 1

# Hide unused axes
for i in range(ax_idx, len(axes)):
    axes[i].axis('off')

fig.suptitle('Stage 2: EfficientNet Backbone Feature Maps', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()
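
The cell above reduces each `(C, H, W)` feature map to a single image in two ways, channel mean and channel max, and then rescales it to `[0, 1]`. The same pattern recurs in most later stages, so here is a minimal NumPy sketch of it. The helper names `minmax` and `summarize_channels` are hypothetical and not part of the project.

```python
import numpy as np

def minmax(x, eps=1e-8):
    """Rescale an array to [0, 1]; eps guards against a constant map."""
    return (x - x.min()) / (x.max() - x.min() + eps)

def summarize_channels(feat):
    """Reduce a (C, H, W) feature map to two (H, W) summaries."""
    return minmax(feat.mean(axis=0)), minmax(feat.max(axis=0))

feat = np.zeros((3, 2, 2), dtype=np.float32)
feat[0, 0, 0] = 6.0   # one strongly active channel at the top-left cell
feat[:, 1, 1] = 1.0   # all channels weakly active at the bottom-right cell
mean_map, max_map = summarize_channels(feat)
```

The mean rewards locations where many channels agree, while the max highlights locations where any single channel fires strongly, which is why the two panels can look quite different for the same stage.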

## Stage 3: CBAM Attention Visualization

In [None]:
# Visualize CBAM attention outputs at each backbone stage
cbam_stages = [
    ('cbam_attention_c2', 'CBAM @ C2 (40 channels)'),
    ('cbam_attention_c3', 'CBAM @ C3 (80 channels)'),
    ('cbam_attention_c4', 'CBAM @ C4 (112 channels)'),
    ('cbam_attention_c5', 'CBAM @ C5 (320 channels)'),
]

fig, axes = plt.subplots(2, 4, figsize=(20, 10))

for idx, (stage_key, stage_name) in enumerate(cbam_stages):
    if stage_key in extracted:
        feat = extracted[stage_key]
        if feat.dim() == 4:
            feat = feat[0]
        
        # Top row: Mean activation (what CBAM emphasizes)
        feat_mean = feat.mean(dim=0).cpu().numpy()
        feat_mean = (feat_mean - feat_mean.min()) / (feat_mean.max() - feat_mean.min() + 1e-8)
        
        axes[0, idx].imshow(feat_mean, cmap='viridis')
        axes[0, idx].set_title(f'{stage_name}\nMean Activation', fontsize=10)
        axes[0, idx].axis('off')
        
        # Bottom row: Resized overlay on image
        feat_resized = cv2.resize(feat_mean, (image_np.shape[1], image_np.shape[0]))
        overlay = overlay_heatmap(image_np, feat_resized, alpha=0.5, colormap='jet')
        
        axes[1, idx].imshow(overlay)
        axes[1, idx].set_title(f'Attention Overlay', fontsize=10)
        axes[1, idx].axis('off')
    else:
        axes[0, idx].axis('off')
        axes[1, idx].axis('off')

fig.suptitle('Stage 3: CBAM Attention - Where the Model Focuses at Each Scale', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()
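
To make the "channel and spatial attention" in this stage concrete, the following is a small NumPy sketch of a CBAM-style block: channel weights from a shared MLP over average- and max-pooled descriptors, followed by a spatial map from a 7x7 convolution over stacked mean and max maps. The weights are random and biases are omitted, so this only illustrates the data flow; it is not the project's module, and all names are hypothetical.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def channel_attention(x, w1, w2):
    """x: (C, H, W). Shared two-layer MLP on avg- and max-pooled descriptors."""
    avg = x.mean(axis=(1, 2))                      # (C,)
    mx = x.max(axis=(1, 2))                        # (C,)
    mlp = lambda v: w2 @ np.maximum(w1 @ v, 0.0)   # C -> C/r -> C with ReLU
    return sigmoid(mlp(avg) + mlp(mx))             # (C,) weights in (0, 1)

def spatial_attention(x, kernel):
    """x: (C, H, W); kernel: (2, k, k) applied to stacked [mean, max] maps."""
    stacked = np.stack([x.mean(axis=0), x.max(axis=0)])   # (2, H, W)
    k = kernel.shape[-1]
    pad = k // 2
    padded = np.pad(stacked, ((0, 0), (pad, pad), (pad, pad)))
    h, w = x.shape[1:]
    out = np.empty((h, w), dtype=x.dtype)
    for i in range(h):
        for j in range(w):
            out[i, j] = np.sum(padded[:, i:i + k, j:j + k] * kernel)
    return sigmoid(out)                            # (H, W) weights in (0, 1)

def cbam(x, w1, w2, kernel):
    x = x * channel_attention(x, w1, w2)[:, None, None]   # reweight channels
    return x * spatial_attention(x, kernel)[None]         # reweight locations

rng = np.random.default_rng(0)
C, H, W, r = 8, 6, 6, 4
x = rng.normal(size=(C, H, W))
w1 = rng.normal(size=(C // r, C)) * 0.1
w2 = rng.normal(size=(C, C // r)) * 0.1
kernel = rng.normal(size=(2, 7, 7)) * 0.1
y = cbam(x, w1, w2, kernel)
```

Because both attention maps lie in `(0, 1)`, the block can only attenuate activations. The plots above therefore show what survives the attention, not an explicit attention mask.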

## Stage 4: Feature Pyramid Network (FPN)

In [None]:
# Visualize FPN multi-scale features
fig = visualize_fpn_features(fpn_features, figsize=(20, 10))
fig.suptitle('Stage 4: FPN Multi-Scale Features (P2-P5)', fontsize=14, fontweight='bold')
plt.show()

In [None]:
# FPN features overlaid on image
fig, axes = plt.subplots(2, 4, figsize=(20, 10))

for idx, level in enumerate(['P2', 'P3', 'P4', 'P5']):
    if level in fpn_features:
        feat = fpn_features[level]
        if feat.dim() == 4:
            feat = feat[0]
        
        # Mean activation
        feat_mean = feat.mean(dim=0).cpu().numpy()
        feat_mean = (feat_mean - feat_mean.min()) / (feat_mean.max() - feat_mean.min() + 1e-8)
        
        # Top row: Raw feature map
        axes[0, idx].imshow(feat_mean, cmap='plasma')
        axes[0, idx].set_title(f'{level}: {feat.shape[1]}x{feat.shape[2]}\n({feat.shape[0]} channels)', fontsize=10)
        axes[0, idx].axis('off')
        
        # Bottom row: Overlay on image
        feat_resized = cv2.resize(feat_mean, (image_np.shape[1], image_np.shape[0]))
        overlay = overlay_heatmap(image_np, feat_resized, alpha=0.5, colormap='plasma')
        
        axes[1, idx].imshow(overlay)
        axes[1, idx].set_title(f'{level} Overlay on Image', fontsize=10)
        axes[1, idx].axis('off')

fig.suptitle('Stage 4: FPN Feature Maps at Each Scale', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()
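
The multi-scale fusion shown above can be sketched as a top-down pathway: each backbone level is projected to a common channel count with a 1x1 convolution, and the coarser pyramid level is upsampled and added. The sketch below assumes each level halves the resolution of the previous one and leaves out the 3x3 smoothing convolutions and attention modules; the function names and the 16 output channels are illustrative choices, not the project's configuration.

```python
import numpy as np

def lateral_1x1(feat, weight):
    """1x1 convolution as a per-pixel channel mix: (C_in, H, W) -> (C_out, H, W)."""
    return np.einsum('oc,chw->ohw', weight, feat)

def upsample2x_nearest(feat):
    return feat.repeat(2, axis=1).repeat(2, axis=2)

def fpn_top_down(backbone_feats, lateral_weights):
    """backbone_feats: list ordered fine -> coarse (C2..C5)."""
    laterals = [lateral_1x1(f, w) for f, w in zip(backbone_feats, lateral_weights)]
    outputs = [laterals[-1]]                      # coarsest level starts the pathway
    for lat in reversed(laterals[:-1]):
        outputs.append(lat + upsample2x_nearest(outputs[-1]))
    return outputs[::-1]                          # back to fine -> coarse (P2..P5)

rng = np.random.default_rng(0)
channels = [40, 80, 112, 320]   # channel counts quoted for C2..C5 in Stage 3
sizes = [32, 16, 8, 4]          # assumed spatial sizes, halving at each level
feats = [rng.normal(size=(c, s, s)) for c, s in zip(channels, sizes)]
weights = [rng.normal(size=(16, c)) * 0.01 for c in channels]
pyramid = fpn_top_down(feats, weights)
```

Every pyramid level ends up with the same number of channels, which is what lets the RPN and RoI heads share weights across scales.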

## Stage 5: FPN Attention Modules

In [None]:
# Visualize FPN attention outputs
fig, axes = plt.subplots(2, 4, figsize=(20, 10))

fpn_attention_keys = ['fpn_attention_0', 'fpn_attention_1', 'fpn_attention_2', 'fpn_attention_3']
fpn_level_names = ['P5 (smallest)', 'P4', 'P3', 'P2 (largest)']

for idx, (key, name) in enumerate(zip(fpn_attention_keys, fpn_level_names)):
    if key in extracted:
        feat = extracted[key]
        if feat.dim() == 4:
            feat = feat[0]
        
        feat_mean = feat.mean(dim=0).cpu().numpy()
        feat_mean = (feat_mean - feat_mean.min()) / (feat_mean.max() - feat_mean.min() + 1e-8)
        
        # Top: Feature map
        axes[0, idx].imshow(feat_mean, cmap='viridis')
        axes[0, idx].set_title(f'{name}\nAfter Attention', fontsize=10)
        axes[0, idx].axis('off')
        
        # Bottom: Overlay
        feat_resized = cv2.resize(feat_mean, (image_np.shape[1], image_np.shape[0]))
        overlay = overlay_heatmap(image_np, feat_resized, alpha=0.5)
        
        axes[1, idx].imshow(overlay)
        axes[1, idx].set_title(f'Attention Overlay', fontsize=10)
        axes[1, idx].axis('off')
    else:
        axes[0, idx].axis('off')
        axes[1, idx].axis('off')

fig.suptitle('Stage 5: FPN Attention Module Outputs', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()
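
The overlays in Stages 3 to 5 come from `overlay_heatmap`, whose implementation is not shown in this notebook. A plausible minimal version is plain alpha blending of a color-mapped heatmap with the image, sketched below. The two-color ramp is a stand-in for a real colormap such as `jet`, and `overlay_sketch` is a hypothetical name.

```python
import numpy as np

def blue_red_colormap(heatmap):
    """Map values in [0, 1] to RGB: 0 -> blue, 1 -> red (a stand-in for 'jet')."""
    rgb = np.zeros(heatmap.shape + (3,), dtype=np.float32)
    rgb[..., 0] = heatmap          # red grows with activation
    rgb[..., 2] = 1.0 - heatmap    # blue fades with activation
    return rgb * 255.0

def overlay_sketch(image, heatmap, alpha=0.5):
    """Blend an (H, W, 3) uint8 image with an (H, W) heatmap in [0, 1]."""
    colored = blue_red_colormap(heatmap.astype(np.float32))
    blended = (1.0 - alpha) * image.astype(np.float32) + alpha * colored
    return np.clip(blended, 0, 255).astype(np.uint8)

image = np.full((2, 2, 3), 100, dtype=np.uint8)
heatmap = np.array([[0.0, 1.0], [0.5, 0.5]])
out = overlay_sketch(image, heatmap, alpha=0.5)
```

With `alpha=0.5` the image and the heatmap contribute equally; lowering `alpha` keeps more of the image detail visible.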

## Stage 6: RPN Proposals (Region Proposal Network)

In [None]:
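# Visualize the detections returned by the model as a stand-in for RPN proposals
import matplotlib.patches as patches

fig, axes = plt.subplots(1, 2, figsize=(16, 8))

# All detections returned by the model (a torchvision-style Mask R-CNN has already applied NMS here;
# only the CONF_THRESHOLD filter in the right panel is new)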
boxes = predictions['boxes'].cpu().numpy()
scores = predictions['scores'].cpu().numpy()
labels = predictions['labels'].cpu().numpy()

# Left: All predictions colored by confidence
axes[0].imshow(image_np)
cmap = plt.get_cmap('RdYlGn')

for box, score in zip(boxes, scores):
    x1, y1, x2, y2 = box
    color = cmap(score)
    rect = patches.Rectangle(
        (x1, y1), x2-x1, y2-y1,
        linewidth=1, edgecolor=color, facecolor='none', alpha=0.7
    )
    axes[0].add_patch(rect)

axes[0].set_title(f'All Predictions ({len(boxes)} boxes)\nColored by confidence (red=low, green=high)', fontsize=12)
axes[0].axis('off')

# Right: High confidence predictions only
axes[1].imshow(image_np)

high_conf = scores >= CONF_THRESHOLD
for box, score, label in zip(boxes[high_conf], scores[high_conf], labels[high_conf]):
    x1, y1, x2, y2 = box
    color = np.array(ISAID_COLORS.get(label, [255, 255, 255])) / 255.0
    
    rect = patches.Rectangle(
        (x1, y1), x2-x1, y2-y1,
        linewidth=2, edgecolor=color, facecolor='none'
    )
    axes[1].add_patch(rect)
    axes[1].text(x1, y1-3, f'{ISAID_CLASS_LABELS[label]}: {score:.2f}',
                 fontsize=8, color='white', backgroundcolor=color)

axes[1].set_title(f'High Confidence (>{CONF_THRESHOLD}) Predictions ({high_conf.sum()} boxes)', fontsize=12)
axes[1].axis('off')

fig.suptitle('Stage 6: Region Proposals and NMS Results', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()
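
Between raw proposals and the boxes plotted above sits non-maximum suppression (NMS). The following is a minimal greedy NMS in NumPy, intended only to illustrate the idea; a real pipeline would more likely call a library routine, and the function names here are hypothetical.

```python
import numpy as np

def iou_one_to_many(box, boxes):
    """IoU between one (x1, y1, x2, y2) box and an (N, 4) array of boxes."""
    x1 = np.maximum(box[0], boxes[:, 0])
    y1 = np.maximum(box[1], boxes[:, 1])
    x2 = np.minimum(box[2], boxes[:, 2])
    y2 = np.minimum(box[3], boxes[:, 3])
    inter = np.clip(x2 - x1, 0, None) * np.clip(y2 - y1, 0, None)
    area = (box[2] - box[0]) * (box[3] - box[1])
    areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
    return inter / (area + areas - inter)

def nms(boxes, scores, iou_threshold=0.5):
    """Greedy NMS: keep the best box, drop neighbours that overlap it too much."""
    order = np.argsort(scores)[::-1]
    keep = []
    while order.size > 0:
        best, rest = order[0], order[1:]
        keep.append(int(best))
        order = rest[iou_one_to_many(boxes[best], boxes[rest]) <= iou_threshold]
    return keep

boxes = np.array([[0, 0, 10, 10], [1, 1, 11, 11], [20, 20, 30, 30]], dtype=np.float32)
scores = np.array([0.9, 0.8, 0.7])
kept = nms(boxes, scores)
```

The first two boxes overlap with an IoU of about 0.68, so only the higher-scoring one survives, while the distant third box is kept.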

## Stage 7: RoI Align Sampling Grid

In [None]:
# Visualize RoI Align sampling grids
if len(predictions['boxes']) > 0:
    # Use top 4 predictions
    top_k = min(4, len(predictions['boxes']))
    top_boxes = predictions['boxes'][:top_k]
    
    fig = visualize_roi_align_grid(
        image_np,
        top_boxes,
        output_size=7,
        sampling_ratio=2,
        max_boxes=top_k,
        figsize=(20, 6)
    )
    fig.suptitle('Stage 7: RoI Align - Bilinear Sampling Grid Visualization', fontsize=14, fontweight='bold')
    plt.show()
else:
    print("No detections to visualize RoI Align grid")

In [None]:
# Detailed RoI Align explanation
if len(predictions['boxes']) > 0:
    fig, axes = plt.subplots(2, 4, figsize=(20, 10))
    
    box_idx = 0
    box = predictions['boxes'][box_idx].cpu().numpy()
    x1, y1, x2, y2 = box
    w, h = x2 - x1, y2 - y1
    
    output_size = 7
    sampling_ratio = 2
    
    # Row 1: Different aspects of RoI Align
    # 1. Original RoI region
    crop = image_np[max(0, int(y1)):int(y2), max(0, int(x1)):int(x2)]  # clamp so negative coordinates do not wrap around
    axes[0, 0].imshow(crop)
    axes[0, 0].set_title(f'Original RoI\n{int(w)}x{int(h)} pixels', fontsize=10)
    axes[0, 0].axis('off')
    
    # 2. Grid overlay
    axes[0, 1].imshow(crop)
    cell_w = w / output_size
    cell_h = h / output_size
    for gx in range(output_size + 1):
        x = gx * cell_w
        axes[0, 1].axvline(x, color='yellow', linewidth=1)
    for gy in range(output_size + 1):
        y = gy * cell_h
        axes[0, 1].axhline(y, color='yellow', linewidth=1)
    axes[0, 1].set_title(f'{output_size}x{output_size} Output Grid\n{output_size*output_size} cells', fontsize=10)
    axes[0, 1].axis('off')
    
    # 3. Sampling points
    axes[0, 2].imshow(crop)
    for gx in range(output_size):
        for gy in range(output_size):
            for sx in range(sampling_ratio):
                for sy in range(sampling_ratio):
                    px = (gx + (sx + 0.5) / sampling_ratio) * cell_w
                    py = (gy + (sy + 0.5) / sampling_ratio) * cell_h
                    axes[0, 2].plot(px, py, 'r.', markersize=3)
    axes[0, 2].set_title(f'Bilinear Sampling Points\n{output_size**2 * sampling_ratio**2} points', fontsize=10)
    axes[0, 2].axis('off')
    
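    # 4. Image-space preview of the 7x7 output: average the crop down to
    #    output_size x output_size, then enlarge with nearest-neighbour for display
    pooled_preview = cv2.resize(crop, (output_size, output_size), interpolation=cv2.INTER_AREA)
    axes[0, 3].imshow(cv2.resize(pooled_preview, (output_size*20, output_size*20), interpolation=cv2.INTER_NEAREST))
    axes[0, 3].set_title(f'{output_size}x{output_size} Output Preview\n(image crop, not pooled features)', fontsize=10)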
    axes[0, 3].axis('off')
    
    # Row 2: Multiple boxes
    for i in range(min(4, len(predictions['boxes']))):
        box = predictions['boxes'][i].cpu().numpy()
        x1, y1, x2, y2 = box
        crop = image_np[max(0,int(y1)):int(y2), max(0,int(x1)):int(x2)]
        if crop.size > 0:
            axes[1, i].imshow(crop)
            label = predictions['labels'][i].item()
            score = predictions['scores'][i].item()
            axes[1, i].set_title(f'RoI {i+1}: {ISAID_CLASS_LABELS[label]}\nConf: {score:.2f}', fontsize=10)
        axes[1, i].axis('off')
    
    fig.suptitle('Stage 7: RoI Align - How Features are Extracted from Each Proposal', fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.show()
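
The sampling points drawn above can be turned into an actual pooling step. Below is a single-channel RoI Align sketch in NumPy: each output bin averages `sampling_ratio**2` bilinear samples taken at the same fractional positions as in the plot. It assumes the box is already in feature-map coordinates (spatial scale 1) and ignores the half-pixel alignment option that some implementations offer; the function names are hypothetical.

```python
import numpy as np

def bilinear(feat, y, x):
    """Bilinearly interpolate a 2-D map at a fractional (y, x) location."""
    h, w = feat.shape
    y = min(max(y, 0.0), h - 1.0)
    x = min(max(x, 0.0), w - 1.0)
    y0, x0 = int(np.floor(y)), int(np.floor(x))
    y1, x1 = min(y0 + 1, h - 1), min(x0 + 1, w - 1)
    dy, dx = y - y0, x - x0
    top = feat[y0, x0] * (1 - dx) + feat[y0, x1] * dx
    bottom = feat[y1, x0] * (1 - dx) + feat[y1, x1] * dx
    return top * (1 - dy) + bottom * dy

def roi_align_sketch(feat, box, output_size=7, sampling_ratio=2):
    """Average sampling_ratio**2 bilinear samples inside each output bin."""
    x1, y1, x2, y2 = box
    bin_w = (x2 - x1) / output_size
    bin_h = (y2 - y1) / output_size
    out = np.zeros((output_size, output_size), dtype=np.float64)
    for gy in range(output_size):
        for gx in range(output_size):
            samples = [
                bilinear(feat,
                         y1 + (gy + (sy + 0.5) / sampling_ratio) * bin_h,
                         x1 + (gx + (sx + 0.5) / sampling_ratio) * bin_w)
                for sy in range(sampling_ratio)
                for sx in range(sampling_ratio)
            ]
            out[gy, gx] = np.mean(samples)
    return out

# A horizontal ramp: the value at (y, x) equals x
ramp = np.tile(np.arange(32, dtype=np.float64), (32, 1))
pooled = roi_align_sketch(ramp, box=(4.0, 4.0, 18.0, 18.0))
```

On the ramp, every bin returns the x-coordinate of its own center (5, 7, ..., 17), which shows that no coordinate is rounded to the pixel grid along the way.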

## Stage 8: Box Head Analysis

In [None]:
# Visualize box head features
if all(k in extracted for k in ('box_head_fc1', 'box_head_fc2', 'box_cls_score')):
    fig = visualize_box_head_features(
        extracted['box_head_fc1'],
        extracted['box_head_fc2'],
        extracted['box_cls_score'],
        extracted.get('box_bbox_pred', torch.zeros(1)),
        ISAID_CLASS_LABELS,
        max_rois=min(8, len(predictions['boxes'])) if len(predictions['boxes']) > 0 else 8,
    )
    fig.suptitle('Stage 8: Box Head - Classification & Regression', fontsize=14, fontweight='bold')
    plt.show()
else:
    print("Box head features not captured - hooks may not have fired")

In [None]:
# Detailed class probability analysis
if 'box_cls_score' in extracted and len(predictions['boxes']) > 0:
    import torch.nn.functional as F
    
    cls_scores = extracted['box_cls_score']
    cls_probs = F.softmax(cls_scores, dim=1).cpu().numpy()
    
    fig, axes = plt.subplots(2, 4, figsize=(20, 8))
    axes = axes.flatten()
    
    for i in range(min(8, len(cls_probs))):
        probs = cls_probs[i]
        
        # Get top 5 classes
        top_indices = np.argsort(probs)[::-1][:5]
        top_probs = probs[top_indices]
        top_names = [ISAID_CLASS_LABELS[idx][:10] for idx in top_indices]
        
        colors = [np.array(ISAID_COLORS.get(idx, [128,128,128]))/255.0 for idx in top_indices]
        
        bars = axes[i].barh(range(5), top_probs, color=colors)
        axes[i].set_yticks(range(5))
        axes[i].set_yticklabels(top_names)
        axes[i].set_xlim(0, 1)
        axes[i].set_title(f'RoI {i+1}', fontsize=10)
        axes[i].invert_yaxis()
        
        # Add probability values
        for j, (bar, p) in enumerate(zip(bars, top_probs)):
            axes[i].text(p + 0.02, j, f'{p:.2f}', va='center', fontsize=8)
    
    # Hide unused axes
    for i in range(min(8, len(cls_probs)), 8):
        axes[i].axis('off')
    
    fig.suptitle('Stage 8: Per-RoI Class Probability Distribution (Top 5 Classes)', fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.show()
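
The box head produces two things per RoI: class logits, which the cell above turns into probabilities with a softmax, and box deltas that refine the proposal. The sketch below shows both steps in NumPy. The `(dx, dy, dw, dh)` parameterization and the `(10, 10, 5, 5)` scaling weights follow the usual Faster R-CNN convention and are an assumption about this model, not something the notebook confirms.

```python
import numpy as np

def softmax(logits):
    z = logits - logits.max(axis=1, keepdims=True)   # stabilise the exponent
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

def decode_boxes(proposals, deltas, weights=(10.0, 10.0, 5.0, 5.0)):
    """Apply (dx, dy, dw, dh) deltas to (x1, y1, x2, y2) proposals."""
    wx, wy, ww, wh = weights
    widths = proposals[:, 2] - proposals[:, 0]
    heights = proposals[:, 3] - proposals[:, 1]
    cx = proposals[:, 0] + 0.5 * widths
    cy = proposals[:, 1] + 0.5 * heights
    new_cx = cx + deltas[:, 0] / wx * widths      # shift relative to box size
    new_cy = cy + deltas[:, 1] / wy * heights
    new_w = widths * np.exp(deltas[:, 2] / ww)    # scale in log space
    new_h = heights * np.exp(deltas[:, 3] / wh)
    return np.stack([new_cx - 0.5 * new_w, new_cy - 0.5 * new_h,
                     new_cx + 0.5 * new_w, new_cy + 0.5 * new_h], axis=1)

logits = np.array([[2.0, 0.0, 0.0], [0.0, 0.0, 0.0]])
probs = softmax(logits)
proposals = np.array([[10.0, 20.0, 30.0, 60.0]])
shifted = decode_boxes(proposals, np.array([[5.0, 0.0, 0.0, 0.0]]))
```

Here a `dx` of 5 moves the 20-pixel-wide box by half its width to the right, while zero deltas would leave a proposal unchanged.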

## Stage 9: Mask Head Stages

In [None]:
# Visualize mask head intermediate stages
mask_features = {k: v for k, v in extracted.items() if k.startswith('mask_head_')}

if mask_features and len(predictions['boxes']) > 0:
    print(f"Mask head stages captured: {list(mask_features.keys())}")
    
    fig = visualize_mask_head_stages(mask_features, box_idx=0)
    fig.suptitle('Stage 9: Mask Head - Convolutional Stages', fontsize=14, fontweight='bold')
    plt.show()
else:
    print("Mask head features not captured or no detections")

In [None]:
# Visualize individual mask predictions
if 'masks' in predictions and len(predictions['masks']) > 0:
    masks = predictions['masks'].cpu().numpy()
    
    n_masks = min(8, len(masks))
    fig, axes = plt.subplots(2, n_masks, figsize=(4*n_masks, 8), squeeze=False)  # squeeze=False keeps axes 2-D when n_masks == 1
    
    for i in range(n_masks):
        mask = masks[i]
        if mask.ndim == 3:
            mask = mask[0]  # Remove channel dim
        
        label = predictions['labels'][i].item()
        score = predictions['scores'][i].item()
        box = predictions['boxes'][i].cpu().numpy()
        
        # Top: Raw mask probability
        axes[0, i].imshow(mask, cmap='hot', vmin=0, vmax=1)
        axes[0, i].set_title(f'{ISAID_CLASS_LABELS[label]}\nP={score:.2f}', fontsize=10)
        axes[0, i].axis('off')
        
        # Bottom: Binary mask on image crop
        x1, y1, x2, y2 = box.astype(int)
        crop = image_np[max(0,y1):y2, max(0,x1):x2].copy()
        mask_binary = (mask > 0.5).astype(np.uint8)
        mask_crop = mask_binary[max(0,y1):y2, max(0,x1):x2]
        
        if crop.size > 0 and mask_crop.size > 0:
            # Resize mask to crop size if needed
            if mask_crop.shape != crop.shape[:2]:
                mask_crop = cv2.resize(mask_crop, (crop.shape[1], crop.shape[0]))
            
            color = np.array(ISAID_COLORS.get(label, [255, 0, 0]))
            overlay = crop.copy().astype(np.float32)
            overlay[mask_crop > 0] = overlay[mask_crop > 0] * 0.5 + color * 0.5
            overlay = np.clip(overlay, 0, 255).astype(np.uint8)
            
            axes[1, i].imshow(overlay)
        axes[1, i].set_title('Masked Region', fontsize=10)
        axes[1, i].axis('off')
    
    fig.suptitle('Stage 9: Individual Mask Predictions', fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.show()
else:
    print("No mask predictions available")
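
The full-size masks plotted above are typically obtained by resizing a small per-RoI probability map to the box and thresholding it. The sketch below illustrates that pasting step in NumPy. It uses nearest-neighbour resizing for brevity where real implementations normally interpolate bilinearly, and the 28x28 mask size and the name `paste_mask` are assumptions for illustration.

```python
import numpy as np

def resize_nearest(mask, out_h, out_w):
    """Nearest-neighbour resize of a 2-D array (stand-in for bilinear)."""
    rows = np.arange(out_h) * mask.shape[0] // out_h
    cols = np.arange(out_w) * mask.shape[1] // out_w
    return mask[rows[:, None], cols[None, :]]

def paste_mask(mask_prob, box, image_h, image_w, threshold=0.5):
    """Place a small per-RoI probability mask into a full-size binary mask."""
    x1, y1, x2, y2 = (int(round(v)) for v in box)
    x1, y1 = max(x1, 0), max(y1, 0)
    x2, y2 = min(x2, image_w), min(y2, image_h)
    full = np.zeros((image_h, image_w), dtype=np.uint8)
    if x2 > x1 and y2 > y1:
        resized = resize_nearest(mask_prob, y2 - y1, x2 - x1)
        full[y1:y2, x1:x2] = (resized > threshold).astype(np.uint8)
    return full

mask_prob = np.zeros((28, 28), dtype=np.float32)
mask_prob[:, 14:] = 0.9                      # right half is "object"
full = paste_mask(mask_prob, box=(10, 10, 30, 50), image_h=64, image_w=64)
```

The square mask is stretched to the 20x40 box, so the right half of the box is marked as object regardless of the box's aspect ratio.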

## Stage 10: Grad-CAM Heatmaps

In [None]:
# Generate activation heatmaps from different FPN levels (a gradient-free stand-in for Grad-CAM)
fig, axes = plt.subplots(2, 4, figsize=(20, 10))

levels = ['P2', 'P3', 'P4', 'P5']

for idx, level in enumerate(levels):
    if level in fpn_features:
        feat = fpn_features[level]
        if feat.dim() == 4:
            feat = feat[0]
        
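        # Activation-based heatmap: channel mean of the FPN features.
        # True Grad-CAM would weight each channel by the gradient of a target score; no gradients are used here.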
        heatmap = feat.mean(dim=0).cpu().numpy()
        heatmap = (heatmap - heatmap.min()) / (heatmap.max() - heatmap.min() + 1e-8)
        heatmap = cv2.resize(heatmap, (image_np.shape[1], image_np.shape[0]))
        
        # Top row: Heatmap only
        axes[0, idx].imshow(heatmap, cmap='jet')
        axes[0, idx].set_title(f'{level} Activation Map', fontsize=12)
        axes[0, idx].axis('off')
        
        # Bottom row: Overlay on image
        overlay = overlay_heatmap(image_np, heatmap, alpha=0.5, colormap='jet')
        axes[1, idx].imshow(overlay)
        axes[1, idx].set_title(f'{level} Overlay', fontsize=12)
        axes[1, idx].axis('off')

fig.suptitle('Stage 10: Grad-CAM Heatmaps - Where the Model Looks at Each FPN Level', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

In [None]:
# Combined multi-scale heatmap
fig, axes = plt.subplots(1, 3, figsize=(18, 6))

# Original image
axes[0].imshow(image_np)
axes[0].set_title('Original Image', fontsize=14)
axes[0].axis('off')

# Combine heatmaps from multiple scales
combined_heatmap = np.zeros((image_np.shape[0], image_np.shape[1]), dtype=np.float32)
weights = {'P2': 0.1, 'P3': 0.2, 'P4': 0.3, 'P5': 0.4}  # Higher weight for deeper features

for level, weight in weights.items():
    if level in fpn_features:
        feat = fpn_features[level]
        if feat.dim() == 4:
            feat = feat[0]
        
        heatmap = feat.mean(dim=0).cpu().numpy()
        heatmap = (heatmap - heatmap.min()) / (heatmap.max() - heatmap.min() + 1e-8)
        heatmap = cv2.resize(heatmap, (image_np.shape[1], image_np.shape[0]))
        combined_heatmap += weight * heatmap

# Normalize combined
combined_heatmap = (combined_heatmap - combined_heatmap.min()) / (combined_heatmap.max() - combined_heatmap.min() + 1e-8)

# Multi-scale heatmap
axes[1].imshow(combined_heatmap, cmap='jet')
axes[1].set_title('Multi-Scale Combined Heatmap', fontsize=14)
axes[1].axis('off')

# Overlay
overlay = overlay_heatmap(image_np, combined_heatmap, alpha=0.5, colormap='jet')
axes[2].imshow(overlay)
axes[2].set_title('Combined Heatmap Overlay', fontsize=14)
axes[2].axis('off')

fig.suptitle('Multi-Scale Grad-CAM: Combined Feature Importance', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()
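
The heatmaps in this stage average FPN activations over channels. For comparison, the sketch below shows the gradient-weighted form that Grad-CAM itself uses: each channel is weighted by the spatial mean of the gradient of a target score, the weighted sum is passed through a ReLU, and the result is normalized. The activations and gradients are hand-made toy arrays; obtaining real gradients from the model is outside the scope of this sketch, and `grad_cam` is a hypothetical helper.

```python
import numpy as np

def grad_cam(activations, gradients, eps=1e-8):
    """activations, gradients: (C, H, W) arrays for one layer and one target score."""
    alphas = gradients.mean(axis=(1, 2))              # per-channel importance
    cam = np.tensordot(alphas, activations, axes=1)   # weighted channel sum -> (H, W)
    cam = np.maximum(cam, 0.0)                        # keep positive evidence only
    return cam / (cam.max() + eps)

# Two channels: the target score rises with channel 0 and falls with channel 1
activations = np.zeros((2, 2, 2))
activations[0, 0, 0] = 4.0
activations[1, 1, 1] = 4.0
gradients = np.stack([np.full((2, 2), 0.5), np.full((2, 2), -0.5)])
cam = grad_cam(activations, gradients)
mean_map = activations.mean(axis=0)
```

In this toy case the channel mean lights up both corners equally, whereas the gradient-weighted map keeps only the corner that supports the target score. That difference is what makes Grad-CAM class-specific.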

## Stage 11: Final Predictions with Full Visualization

In [None]:
# Final comprehensive visualization
fig = visualize_final_predictions(
    image_np,
    predictions,
    gradcam_heatmap=combined_heatmap,
    class_labels=ISAID_CLASS_LABELS,
    conf_threshold=CONF_THRESHOLD,
    figsize=(18, 6)
)
fig.suptitle('Stage 11: Final Predictions with Masks and Grad-CAM', fontsize=14, fontweight='bold')
plt.show()

In [None]:
# Print detection summary
print("\n" + "="*60)
print("DETECTION SUMMARY")
print("="*60)

boxes = predictions['boxes'].cpu().numpy()
labels = predictions['labels'].cpu().numpy()
scores = predictions['scores'].cpu().numpy()

high_conf = scores >= CONF_THRESHOLD
print(f"\nTotal detections: {len(boxes)}")
print(f"High confidence (>{CONF_THRESHOLD}): {high_conf.sum()}")

print(f"\n{'#':<4} {'Class':<20} {'Confidence':<12} {'Box (x1,y1,x2,y2)'}")
print("-" * 70)

for i, (box, label, score) in enumerate(zip(boxes[high_conf], labels[high_conf], scores[high_conf])):
    class_name = ISAID_CLASS_LABELS.get(label, f'Class {label}')
    box_str = f"({box[0]:.0f}, {box[1]:.0f}, {box[2]:.0f}, {box[3]:.0f})"
    print(f"{i+1:<4} {class_name:<20} {score:<12.4f} {box_str}")

# Class distribution
if high_conf.sum() > 0:
    print("\n Class Distribution:")
    unique_labels, counts = np.unique(labels[high_conf], return_counts=True)
    for label, count in zip(unique_labels, counts):
        class_name = ISAID_CLASS_LABELS.get(label, f'Class {label}')
        print(f"   {class_name}: {count}")

---

## Complete Pipeline Summary Grid

In [None]:
# Generate comprehensive summary grid
fig = pipeline._create_summary_grid(results, combined_heatmap)
plt.show()

---

## Save All Visualizations

In [None]:
# Save all visualizations to disk
SAVE_OUTPUTS = True  # Set to False to skip saving

if SAVE_OUTPUTS:
    import os
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    
    print(f"Generating and saving all visualizations to {OUTPUT_DIR}/...")
    
    figures = pipeline.generate_all_visualizations(
        image_np,
        output_dir=OUTPUT_DIR,
        save=True
    )
    
    print(f"\n Saved {len(figures)} visualization files!")
    print(f"\nFiles saved:")
    for fname in sorted(os.listdir(OUTPUT_DIR)):
        if fname.endswith('.png'):
            print(f"   {fname}")
else:
    print("Set SAVE_OUTPUTS = True to save visualizations")

---

## Analyze Multiple Images

In [None]:
# Function to quickly analyze any image
def analyze_image(image_source, save_prefix=None):
    """
    Quick analysis of an image.
    
    Args:
        image_source: Path to image, numpy array, or dataset index (int)
        save_prefix: Optional prefix for saving outputs
    """
    # Load image
    if isinstance(image_source, int):
        # Load from dataset (requires `val_dataset` from the "Load Test Image" cell)
        img_tensor, target = val_dataset[image_source]
        img_np = denormalize_image(img_tensor)
    elif isinstance(image_source, (str, Path)):
        img_np = np.array(Image.open(image_source).convert('RGB'))
    else:
        img_np = image_source
    
    # Run pipeline
    pipeline.preprocess_image(img_np)
    results = pipeline.run_inference_with_gradients()
    
    # Quick summary
    preds = results['predictions']
    n_det = (preds['scores'] >= CONF_THRESHOLD).sum().item()
    
    # Visualization
    fpn_feat = results['fpn_features']
    heatmap = None
    if 'P4' in fpn_feat:
        feat = fpn_feat['P4'][0] if fpn_feat['P4'].dim() == 4 else fpn_feat['P4']
        heatmap = feat.mean(dim=0).cpu().numpy()
        heatmap = (heatmap - heatmap.min()) / (heatmap.max() - heatmap.min() + 1e-8)
        heatmap = cv2.resize(heatmap, (img_np.shape[1], img_np.shape[0]))
    
    fig = visualize_final_predictions(
        img_np, preds, heatmap,
        class_labels=ISAID_CLASS_LABELS,
        conf_threshold=CONF_THRESHOLD
    )
    
    if save_prefix:
        os.makedirs(OUTPUT_DIR, exist_ok=True)  # the directory may not exist yet
        fig.savefig(f'{OUTPUT_DIR}/{save_prefix}_predictions.png', dpi=150, bbox_inches='tight')
    
    plt.show()
    print(f"Found {n_det} objects with confidence >= {CONF_THRESHOLD}")
    
    return results

print("Function `analyze_image()` ready!")
print("\nUsage examples:")
print("  analyze_image(5)                    # Analyze dataset sample #5")
print("  analyze_image('path/to/image.jpg')  # Analyze image file")
print("  analyze_image(image_array)          # Analyze numpy array")

In [None]:
# Example: Analyze a few more samples
# Uncomment and modify as needed:

# for idx in [1, 2, 3]:
#     print(f"\n{'='*60}")
#     print(f"Analyzing sample {idx}")
#     print('='*60)
#     analyze_image(idx, save_prefix=f'sample_{idx}')

---

## Summary

This notebook visualized the complete inference pipeline of our Custom Mask R-CNN:

| Stage | Component | What We Saw |
|-------|-----------|-------------|
| 1 | Input | Original image preprocessing |
| 2 | Backbone | EfficientNet feature extraction (C2-C5) |
| 3 | CBAM | Channel and spatial attention at each scale |
| 4 | FPN | Multi-scale feature fusion (P2-P5) |
| 5 | FPN Attention | Attention-enhanced FPN features |
| 6 | RPN | Region proposals generation |
| 7 | RoI Align | Bilinear sampling grid for each proposal |
| 8 | Box Head | Classification and bbox regression |
| 9 | Mask Head | Instance segmentation prediction |
| 10 | Grad-CAM | Where the model "looks" to make predictions |
| 11 | Output | Final boxes, classes, masks, and confidence scores |

**Key Insights:**
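- CBAM reweights channels and spatial locations; the Stage 3 overlays show which regions keep high activation after attention
- FPN fuses backbone features into P2-P5 so that objects of different sizes can be detected at a matching resolution
- The Stage 10 heatmaps are channel-mean activations of FPN features, so they are class-agnostic; per-class Grad-CAM would additionally need gradients of a class score
- RoI Align samples features bilinearly at fractional coordinates, avoiding the coordinate quantization of RoI pooling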