In [1]:
import os
import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap
import cv2
from torch.utils.data import DataLoader
from transformers import SegformerForSemanticSegmentation
from tqdm import tqdm
import pandas as pd
import matplotlib.gridspec as gridspec
from mpl_toolkits.axes_grid1 import make_axes_locatable

# Set font sizes for publication-quality figures.
# NOTE: rcParams.update changes matplotlib's global defaults, so every figure
# created after this statement (in any cell of the session) uses these sizes.
plt.rcParams.update({
    'font.size': 12,
    'axes.titlesize': 14,
    'axes.labelsize': 12,
    'xtick.labelsize': 10,
    'ytick.labelsize': 10,
    'legend.fontsize': 10,
    'figure.titlesize': 16
})

# Configure paths
# Every dataset, checkpoint and output location below is derived from this
# single root, so pointing the notebook at another machine means editing only
# this line.
BASE_DIR = "/project/6087070/thatkar"

# KITTI settings
# Validation images / labels, the two checkpoints being compared, and the
# directory that receives the generated figures.
# NOTE(review): the number in each checkpoint file name presumably records the
# validation mIoU reached by that run -- confirm against the training logs.
KITTI_ROOT = os.path.join(BASE_DIR, "KITTI")
KITTI_VAL_DIR = os.path.join(KITTI_ROOT, "validation/image_2")
KITTI_VAL_LABELS_DIR = os.path.join(KITTI_ROOT, "validation/semantic")
KITTI_BASELINE_CHECKPOINT = os.path.join(KITTI_ROOT, "checkpoints/baseline/best_model_miou_0.5117.pth")
KITTI_TRANSFER_CHECKPOINT = os.path.join(KITTI_ROOT, "checkpoints/transfer/best_model_miou_0.5342.pth")
KITTI_OUTPUT_DIR = os.path.join(KITTI_ROOT, "journal_visualizations")

# IDD settings
# Same layout as the KITTI block: validation images / labels, baseline and
# transfer-learning checkpoints, and the figure output directory.
IDD_ROOT = os.path.join(BASE_DIR, "IDD")
IDD_VAL_DIR = os.path.join(IDD_ROOT, "leftImg8bit/val")
IDD_VAL_LABELS_DIR = os.path.join(IDD_ROOT, "IDD_pixelwise_masks/val")
IDD_BASELINE_CHECKPOINT = os.path.join(IDD_ROOT, "checkpoints/segformer_b3_idd/best_model_miou_0.5028.pth")
IDD_TRANSFER_CHECKPOINT = os.path.join(IDD_ROOT, "checkpoints/segformer_b3_idd_transfer/best_model_miou_0.5108.pth")
IDD_OUTPUT_DIR = os.path.join(IDD_ROOT, "journal_visualizations")

# Create output directories
# NOTE: this runs as soon as the cell/module is executed (not only inside
# main()); exist_ok=True makes re-running the cell harmless.
os.makedirs(KITTI_OUTPUT_DIR, exist_ok=True)
os.makedirs(IDD_OUTPUT_DIR, exist_ok=True)

# Define dataset classes
# Import your dataset classes or define them here
# For demonstration purposes, simplified dataset class:
class SegmentationDataset:
    """Skeleton of a map-style segmentation dataset.

    The class stores the image/label directories and a ``class_mapping``
    (indexable by consecutive integer class ids ``0 .. num_classes - 1``, each
    entry providing ``'name'`` and ``'color'``).  Pair discovery and sample
    loading are placeholders: as written the dataset is always empty and
    ``__getitem__`` returns ``None``.
    """

    def __init__(self, image_dir, label_dir, class_mapping, transforms=None):
        self.image_dir, self.label_dir = image_dir, label_dir
        self.class_mapping = class_mapping
        self.transforms = transforms

        # One entry per class; label value 255 is reserved for "ignore".
        self.num_classes = len(class_mapping)
        self.ignore_index = 255

        # Populated last so the discovery hook can read the attributes above.
        self.images = self._get_image_label_pairs()

    def _get_image_label_pairs(self):
        """Return the list of (image, label) pairs; placeholder yields none."""
        # Must be adapted to the on-disk layout of KITTI / IDD.
        return []

    def __getitem__(self, idx):
        """Placeholder sample loader; currently produces ``None``."""
        return None

    def __len__(self):
        return len(self.images)

    def get_class_names(self):
        """Return the class names ordered by class id."""
        names = []
        for class_id in range(self.num_classes):
            names.append(self.class_mapping[class_id]['name'])
        return names

    def get_color_map(self):
        """Return a ``{class_id: color}`` dict used for visualization."""
        colors = {}
        for class_id in range(self.num_classes):
            colors[class_id] = self.class_mapping[class_id]['color']
        return colors

# Define utility functions
def denormalize_image(img_tensor):
    """Convert an ImageNet-normalized image tensor back to a displayable image.

    Args:
        img_tensor: Tensor indexed as (3, H, W) that was normalized with the
            ImageNet mean/std used below.  It may live on any device and may
            be attached to an autograd graph.

    Returns:
        ``numpy.ndarray`` of shape (H, W, 3) and dtype ``uint8`` with values
        clipped to ``[0, 255]``.
    """
    # .numpy() refuses tensors that require grad or sit on an accelerator, so
    # take a detached CPU view first (a no-op for plain CPU tensors).
    img_tensor = img_tensor.detach().cpu()

    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

    # Undo normalization, then move channels last for matplotlib/OpenCV.
    img = (img_tensor * std + mean).permute(1, 2, 0).numpy()
    img = np.clip(img, 0, 1)
    return (img * 255).astype(np.uint8)

def create_colored_mask(mask, class_colors):
    """Render a 2-D label map as an RGB image.

    Args:
        mask: 2-D array of class ids.
        class_colors: Mapping ``class_id -> color`` where the first three
            entries of each color are used as the R, G, B values.

    Returns:
        ``uint8`` array of shape (H, W, 3).  Pixels whose id has no entry in
        ``class_colors`` stay black; pixels equal to 255 (ignore label) are
        painted grey (180, 180, 180), overriding any class color.
    """
    rows, cols = mask.shape
    canvas = np.zeros((rows, cols, 3), dtype=np.uint8)

    for label, rgb in class_colors.items():
        selected = (mask == label)
        if not np.any(selected):
            # Class absent from this mask: nothing to paint.
            continue
        for channel in (0, 1, 2):
            canvas[selected, channel] = rgb[channel]

    # The ignore label is painted last so it wins over any class color.
    ignored = (mask == 255)
    if np.any(ignored):
        canvas[ignored] = [180, 180, 180]

    return canvas

def calculate_metrics(pred_mask, gt_mask, num_classes, ignore_index=255):
    """Compute per-class IoU and their mean for one prediction/label pair.

    Pixels whose ground-truth value equals ``ignore_index`` are excluded from
    both the intersection and the union.

    Returns:
        dict with
        ``'class_ious'``: list of length ``num_classes``; ``nan`` for a class
        that appears in neither the prediction nor the ground truth, and
        ``'mean_iou'``: mean over the non-NaN entries (0 if there are none).
    """
    # Same for every class, so evaluate it once.
    keep = (gt_mask != ignore_index)

    per_class = []
    for cls in range(num_classes):
        pred_hits = (pred_mask == cls) & keep
        gt_hits = (gt_mask == cls) & keep

        overlap = (pred_hits & gt_hits).sum()
        combined = (pred_hits | gt_hits).sum()

        # An empty union means the class is absent from this image.
        per_class.append(float('nan') if combined == 0 else overlap / combined)

    present = [value for value in per_class if not np.isnan(value)]

    return {
        'class_ious': per_class,
        'mean_iou': np.mean(present) if present else 0,
    }

def _predict_mask(model, pixel_values, target_size):
    """Run ``model`` on one batch and return the argmax class map of sample 0.

    The logits are bilinearly resized to ``target_size`` (H, W) before the
    argmax so the result can be compared pixel-by-pixel with the labels.
    """
    with torch.no_grad():
        logits = model(pixel_values=pixel_values).logits
        logits = F.interpolate(logits, size=target_size, mode="bilinear", align_corners=False)
        return torch.argmax(logits, dim=1)[0].cpu().numpy()


def _clamped_window_start(center, window, limit):
    """Return the start of a ``window``-long span centred on ``center`` that fits in ``[0, limit]``."""
    return min(max(0, center - window // 2), max(0, limit - window))


def generate_journal_visualizations(dataset_name, val_dataset, baseline_model, transfer_model, 
                                   output_dir, device, num_samples=5, seed=42):
    """Generate standardized visualizations for journal publication.

    For ``num_samples`` randomly drawn validation samples this writes a
    five-panel comparison figure (image, ground truth, baseline prediction,
    transfer-learning prediction, difference map) and, when the two models
    disagree somewhere, a close-up of the region around the disagreements.
    Afterwards it writes summary figures (mean IoU bar chart, per-class IoU
    bar chart, improvement table), a CSV of per-class improvements and a
    color legend.

    Args:
        dataset_name: Label used in figure titles and output file names.
        val_dataset: Dataset whose samples are dicts with ``'pixel_values'``,
            ``'labels'`` and ``'image_name'``; must also provide
            ``get_color_map()``, ``get_class_names()`` and ``num_classes``.
        baseline_model: Model called as ``model(pixel_values=...)`` whose
            output exposes ``.logits``.
        transfer_model: Same interface as ``baseline_model``.
        output_dir: Root directory; ``comparisons/`` and ``metrics/`` are
            created beneath it.
        device: Device the input batch is moved to.
        num_samples: Number of samples to visualize.
        seed: Seed for NumPy, torch and the DataLoader shuffle.

    Note:
        This function does not switch the models to eval mode; callers should
        do so beforehand.  IoU values are averaged per image, so they are not
        directly comparable with dataset-level mIoU.
    """
    import warnings  # local import: only needed to silence all-NaN nanmean warnings

    # Set random seed for reproducibility
    np.random.seed(seed)
    torch.manual_seed(seed)

    # Create subdirectories
    comparison_dir = os.path.join(output_dir, "comparisons")
    metrics_dir = os.path.join(output_dir, "metrics")
    os.makedirs(comparison_dir, exist_ok=True)
    os.makedirs(metrics_dir, exist_ok=True)

    # Get class colors and names
    class_colors = val_dataset.get_color_map()
    class_names = val_dataset.get_class_names()

    # A seeded generator makes the shuffled sample order reproducible.
    dataloader = DataLoader(
        val_dataset,
        batch_size=1,
        shuffle=True,
        num_workers=2,
        generator=torch.Generator().manual_seed(seed),
    )

    all_baseline_metrics = []
    all_transfer_metrics = []

    progress = tqdm(dataloader, desc=f"Generating {dataset_name} visualizations")
    for sample_idx, batch in enumerate(progress):
        if sample_idx >= num_samples:
            break

        pixel_values = batch['pixel_values'].to(device)
        gt_labels = batch['labels'].cpu().numpy()[0]
        image_name = batch['image_name'][0]
        label_size = tuple(gt_labels.shape)
        height, width = label_size

        # Bring the displayed image onto the label grid so every panel, the
        # difference mask and the crop coordinates share one resolution.
        display_tensor = pixel_values[:1]
        if tuple(display_tensor.shape[-2:]) != label_size:
            display_tensor = F.interpolate(display_tensor, size=label_size,
                                           mode="bilinear", align_corners=False)
        original_img = denormalize_image(display_tensor[0].cpu())

        # Predictions of both models at label resolution
        baseline_pred = _predict_mask(baseline_model, pixel_values, label_size)
        transfer_pred = _predict_mask(transfer_model, pixel_values, label_size)

        # Calculate metrics
        baseline_metrics = calculate_metrics(baseline_pred, gt_labels, val_dataset.num_classes)
        transfer_metrics = calculate_metrics(transfer_pred, gt_labels, val_dataset.num_classes)
        all_baseline_metrics.append(baseline_metrics)
        all_transfer_metrics.append(transfer_metrics)

        # Create colored masks
        gt_colored = create_colored_mask(gt_labels, class_colors)
        baseline_colored = create_colored_mask(baseline_pred, class_colors)
        transfer_colored = create_colored_mask(transfer_pred, class_colors)

        # Difference map, sized from the labels (not from the input image):
        # green = only the transfer model is right, red = only the baseline is.
        baseline_right = (baseline_pred == gt_labels)
        transfer_right = (transfer_pred == gt_labels)
        diff_mask = np.zeros(label_size + (3,), dtype=np.uint8)
        diff_mask[transfer_right & ~baseline_right] = [0, 255, 0]
        diff_mask[baseline_right & ~transfer_right] = [255, 0, 0]

        # Publication-style comparison figure.  Both rows get equal height;
        # a tiny second row would shrink the transfer/difference panels to
        # near-invisibility.
        fig = plt.figure(figsize=(15, 10))
        gs = gridspec.GridSpec(2, 3, figure=fig)
        panels = [
            (gs[0, 0], original_img, 'Original Image'),
            (gs[0, 1], gt_colored, 'Ground Truth'),
            (gs[0, 2], baseline_colored,
             f'Baseline (mIoU: {baseline_metrics["mean_iou"]:.4f})'),
            (gs[1, 0:2], transfer_colored,
             f'Transfer Learning (mIoU: {transfer_metrics["mean_iou"]:.4f})'),
            (gs[1, 2], diff_mask,
             'Differences\nGreen: TL Better, Red: Baseline Better'),
        ]
        for cell, image, title in panels:
            ax = fig.add_subplot(cell)
            ax.imshow(image)
            ax.set_title(title)
            ax.axis('off')

        fig.tight_layout()
        fig.savefig(os.path.join(comparison_dir, f"{dataset_name}_{sample_idx+1}_{image_name}.png"),
                    dpi=300, bbox_inches='tight')
        plt.close(fig)

        # Close-up of the region where the models disagree.  The figure is
        # only created when there is something to show, so no empty figure is
        # left open.
        diff_rows, diff_cols = np.nonzero(diff_mask.any(axis=2))
        crop_size = min(200, height // 2, width // 2)
        if diff_rows.size == 0 or crop_size == 0:
            continue

        # Centre the window on the mean disagreement position, shifted inward
        # if needed so the crop always has the full crop_size extent and
        # matches the rectangle drawn on the overview.
        center_y = int(diff_rows.mean())
        center_x = int(diff_cols.mean())
        start_y = _clamped_window_start(center_y, crop_size, height)
        start_x = _clamped_window_start(center_x, crop_size, width)
        window = (slice(start_y, start_y + crop_size),
                  slice(start_x, start_x + crop_size))

        fig, axes = plt.subplots(2, 3, figsize=(15, 10))

        # Overview with the crop region outlined
        axes[0, 0].imshow(original_img)
        axes[0, 0].set_title('Original Image')
        axes[0, 0].axis('off')
        rect = plt.Rectangle((start_x, start_y), crop_size, crop_size,
                             edgecolor='white', facecolor='none', linewidth=2)
        axes[0, 0].add_patch(rect)

        closeups = [
            (axes[0, 1], original_img, 'Region of Interest'),
            (axes[0, 2], gt_colored, 'Ground Truth'),
            (axes[1, 0], baseline_colored, 'Baseline Prediction'),
            (axes[1, 1], transfer_colored, 'Transfer Learning Prediction'),
            (axes[1, 2], diff_mask, 'Differences'),
        ]
        for ax, image, title in closeups:
            ax.imshow(image[window])
            ax.set_title(title)
            ax.axis('off')

        fig.tight_layout()
        fig.savefig(os.path.join(comparison_dir, f"{dataset_name}_{sample_idx+1}_{image_name}_closeup.png"),
                    dpi=300, bbox_inches='tight')
        plt.close(fig)

    # Without any processed sample the summary statistics are undefined.
    if not all_baseline_metrics:
        print(f"No samples were processed for {dataset_name}; skipping summary figures.")
        return

    # Create summary metrics visualization
    baseline_mean_ious = [metrics['mean_iou'] for metrics in all_baseline_metrics]
    transfer_mean_ious = [metrics['mean_iou'] for metrics in all_transfer_metrics]

    # Overall mIoU comparison
    fig, ax = plt.subplots(figsize=(8, 5))
    ax.bar(['Baseline', 'Transfer Learning'],
           [np.mean(baseline_mean_ious), np.mean(transfer_mean_ious)],
           color=['skyblue', 'coral'])
    ax.set_ylabel('Mean IoU')
    ax.set_title(f'{dataset_name} - Mean IoU Comparison')
    ax.grid(axis='y', alpha=0.3)
    fig.savefig(os.path.join(metrics_dir, f"{dataset_name}_miou_comparison.png"),
                dpi=300, bbox_inches='tight')
    plt.close(fig)

    # Per-class IoU averaged over the sampled images.  Classes absent from
    # every sample yield NaN; the accompanying "mean of empty slice" warning
    # is expected and therefore suppressed.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", category=RuntimeWarning)
        class_ious_baseline = np.nanmean([m['class_ious'] for m in all_baseline_metrics], axis=0)
        class_ious_transfer = np.nanmean([m['class_ious'] for m in all_transfer_metrics], axis=0)

    # Only include classes that appear in the samples
    valid_classes = ~np.isnan(class_ious_baseline) & ~np.isnan(class_ious_transfer)
    class_indices = np.where(valid_classes)[0]

    if len(class_indices) > 0:
        fig, ax = plt.subplots(figsize=(12, 6))
        x = np.arange(len(class_indices))
        width = 0.35

        ax.bar(x - width/2, class_ious_baseline[class_indices], width,
               label='Baseline', color='skyblue')
        ax.bar(x + width/2, class_ious_transfer[class_indices], width,
               label='Transfer Learning', color='coral')

        ax.set_ylabel('IoU')
        ax.set_title(f'{dataset_name} - Per-Class IoU Comparison')
        ax.set_xticks(x)
        ax.set_xticklabels([class_names[idx] for idx in class_indices], rotation=90)
        ax.legend()
        ax.grid(axis='y', alpha=0.3)

        fig.tight_layout()
        fig.savefig(os.path.join(metrics_dir, f"{dataset_name}_class_iou_comparison.png"),
                    dpi=300, bbox_inches='tight')
        plt.close(fig)

        # Absolute and relative improvements.  The relative value stays 0
        # where the baseline IoU is 0 (division undefined).
        improvements = class_ious_transfer - class_ious_baseline
        relative_improvements = np.zeros_like(improvements)
        has_baseline = valid_classes & (class_ious_baseline > 0)
        relative_improvements[has_baseline] = (
            improvements[has_baseline] / class_ious_baseline[has_baseline] * 100
        )

        improvement_data = [
            {
                'Class': class_names[idx],
                'Baseline_IoU': class_ious_baseline[idx],
                'Transfer_IoU': class_ious_transfer[idx],
                'Absolute_Improvement': improvements[idx],
                'Relative_Improvement (%)': relative_improvements[idx],
            }
            for idx in class_indices
        ]

        # Sort by relative improvement
        improvement_data.sort(key=lambda entry: entry['Relative_Improvement (%)'], reverse=True)

        # Create a table for the paper
        fig, ax = plt.subplots(figsize=(10, len(improvement_data)*0.4 + 1))
        ax.axis('tight')
        ax.axis('off')

        table_data = [[d['Class'],
                       f"{d['Baseline_IoU']:.4f}",
                       f"{d['Transfer_IoU']:.4f}",
                       f"{d['Absolute_Improvement']:.4f}",
                       f"{d['Relative_Improvement (%)']:.2f}%"]
                      for d in improvement_data]

        table = ax.table(cellText=table_data,
                         colLabels=['Class', 'Baseline IoU', 'Transfer IoU',
                                    'Abs. Improvement', 'Rel. Improvement'],
                         loc='center', cellLoc='center')

        table.auto_set_font_size(False)
        table.set_fontsize(10)
        table.scale(1, 1.5)

        # Highlight rows with more than 5% relative improvement.  Table row 0
        # is the header, hence the enumeration starting at 1.
        for row, entry in enumerate(improvement_data, start=1):
            if entry['Relative_Improvement (%)'] > 5:
                for col in range(1, 5):
                    table[(row, col)].set_facecolor('#d5f5e3')  # Light green

        ax.set_title(f'{dataset_name} - Class-wise Improvements', fontsize=14)
        fig.tight_layout()
        fig.savefig(os.path.join(metrics_dir, f"{dataset_name}_improvement_table.png"),
                    dpi=300, bbox_inches='tight')
        plt.close(fig)

        # Save as CSV for reference
        df = pd.DataFrame(improvement_data)
        df.to_csv(os.path.join(metrics_dir, f"{dataset_name}_improvements.csv"), index=False)

    # Save the legend for class colors (only classes present in the samples)
    fig, ax = plt.subplots(figsize=(8, 6))
    ax.axis('off')

    handles = [
        plt.Rectangle((0, 0), 1, 1,
                      color=[channel / 255 for channel in class_colors[idx]],
                      label=class_names[idx])
        for idx in class_indices
    ]

    ax.legend(handles=handles, loc='center', ncol=2)
    ax.set_title(f'{dataset_name} - Class Color Legend')
    fig.tight_layout()
    fig.savefig(os.path.join(output_dir, f"{dataset_name}_color_legend.png"),
                dpi=300, bbox_inches='tight')
    plt.close(fig)

def main():
    """Main function to generate visualizations for both datasets.

    The dataset/model construction for each dataset is still a placeholder,
    so at present this only reports the selected device and progress.
    """
    # Prefer the GPU when one is visible to torch.
    if torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')
    print(f"Using device: {device}")

    for name in ("KITTI", "IDD"):
        print(f"\nGenerating {name} visualizations...")
        # Create the dataset and models for `name`, then call
        # generate_journal_visualizations(...)
        # [Code to implement]

    print("\nVisualization generation complete!")


if __name__ == "__main__":
    main()

Using device: cuda

Generating KITTI visualizations...

Generating IDD visualizations...

Visualization generation complete!


In [3]:
!pip install -U typing_extensions==4.8.0
!pip install -U torch torchvision

Looking in links: /cvmfs/soft.computecanada.ca/custom/python/wheelhouse/gentoo2023/x86-64-v3, /cvmfs/soft.computecanada.ca/custom/python/wheelhouse/gentoo2023/generic, /cvmfs/soft.computecanada.ca/custom/python/wheelhouse/generic
Processing /cvmfs/soft.computecanada.ca/custom/python/wheelhouse/generic/typing_extensions-4.8.0+computecanada-py3-none-any.whl
Installing collected packages: typing_extensions
  Attempting uninstall: typing_extensions
    Found existing installation: typing_extensions 4.7.1
    Uninstalling typing_extensions-4.7.1:
[31mERROR: Could not install packages due to an OSError: [Errno 30] Read-only file system: '/cvmfs/soft.computecanada.ca/easybuild/software/2023/x86-64-v3/Compiler/gcccore/python/3.11.5/lib/python3.11/site-packages/__pycache__/typing_extensions.cpython-311.pyc'
[0m[31m
[0mLooking in links: /cvmfs/soft.computecanada.ca/custom/python/wheelhouse/gentoo2023/x86-64-v3, /cvmfs/soft.computecanada.ca/custom/python/wheelhouse/gentoo2023/generic, /cvmfs/

In [1]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from transformers import SegformerForSemanticSegmentation
import cv2
from PIL import Image
import torchvision.transforms as transforms
from torch.nn import functional as F

class SegFormerAttentionViz:
    """Capture and visualize encoder self-attention of a SegFormer model.

    Forward hooks on every encoder attention module store the attention
    probabilities produced while an image is processed; the visualization
    methods turn those into heatmaps overlaid on the (resized) input image.

    SegFormer attends from all query positions to a spatially reduced set of
    keys, so an attention matrix is generally *not* square.  Heatmaps are
    therefore built on the key grid: the attention each key position receives,
    averaged over all queries.
    """

    # Side length (pixels) images are resized to before entering the model.
    INPUT_SIZE = 512

    def __init__(self, model_path, num_classes=32,
                 base_model="nvidia/segformer-b5-finetuned-cityscapes-1024-1024"):
        """
        Initialize with a pretrained SegFormer model

        Args:
            model_path: Path to saved model checkpoint
            num_classes: Number of segmentation classes
            base_model: Hugging Face model id (or local directory) defining
                the architecture the checkpoint was trained with.  It must
                match the checkpoint, e.g. a B3 checkpoint needs a B3 id.
        """
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.num_classes = num_classes
        self.base_model = base_model

        # Load model
        self.model = SegformerForSemanticSegmentation.from_pretrained(
            base_model,
            num_labels=num_classes,
            ignore_mismatched_sizes=True
        )
        # Load saved weights.
        # NOTE(security): torch.load unpickles the file; only load
        # checkpoints from a trusted source.
        checkpoint = torch.load(model_path, map_location=self.device)
        self.model.load_state_dict(self._unwrap_state_dict(checkpoint))
        self.model.to(self.device)
        self.model.eval()

        # Filled by the hooks during process_image()
        self.attention_maps = {}
        # Set by process_image(); None until an image has been processed
        self.original_image = None
        self._register_hooks()

        # Image preprocessing
        self.transform = transforms.Compose([
            transforms.Resize((self.INPUT_SIZE, self.INPUT_SIZE)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

    @staticmethod
    def _unwrap_state_dict(checkpoint):
        """Return the model state dict from a raw or wrapped training checkpoint."""
        if isinstance(checkpoint, dict):
            for key in ("model_state_dict", "state_dict"):
                if isinstance(checkpoint.get(key), dict):
                    return checkpoint[key]
        return checkpoint

    def _register_hooks(self):
        """Register forward hooks to extract attention maps from all encoder blocks"""
        for block_idx, block in enumerate(self.model.segformer.encoder.block):
            for layer_idx, layer in enumerate(block):
                # Each encoder layer exposes its attention module as `.attention`.
                attn_layer = getattr(layer, 'attention', None)
                if attn_layer is None:
                    continue

                # Indices are bound as defaults to avoid late-binding closures.
                def hook_fn(module, inputs, output, block_idx=block_idx, layer_idx=layer_idx):
                    # The attention module returns (context, attention_probs)
                    # only when the forward pass ran with
                    # output_attentions=True; otherwise there is nothing to
                    # capture.  Probabilities are indexed as
                    # [batch, head, query, key].
                    if isinstance(output, (tuple, list)) and len(output) > 1 and output[1] is not None:
                        self.attention_maps[f"block{block_idx}_layer{layer_idx}"] = output[1].detach()

                attn_layer.register_forward_hook(hook_fn)

    def process_image(self, image_path):
        """Process an image and get segmentation and attention maps

        Returns:
            Tuple ``(seg_pred, attention_maps)``: the argmax class map at the
            model's output resolution and the dict of captured attention
            tensors keyed ``"block{b}_layer{l}"``.
        """
        # Load and preprocess image
        image = Image.open(image_path).convert('RGB')
        input_tensor = self.transform(image).unsqueeze(0).to(self.device)

        # Start from an empty dict so maps of a previous image never leak in
        # (rebinding keeps dicts returned by earlier calls intact).
        self.attention_maps = {}

        # output_attentions=True makes the attention modules return the
        # probabilities that the hooks record.
        with torch.no_grad():
            outputs = self.model(input_tensor, output_attentions=True)
            logits = outputs.logits

        # Get predicted segmentation
        seg_pred = torch.argmax(logits, dim=1).cpu().numpy()[0]

        # Keep original image for visualization
        self.original_image = np.array(image.resize((self.INPUT_SIZE, self.INPUT_SIZE)))

        return seg_pred, self.attention_maps

    @staticmethod
    def _key_attention_grid(attn):
        """Average a (num_queries, num_keys) attention matrix over queries and
        reshape the result to the square key grid."""
        attn = np.asarray(attn, dtype=np.float32)
        if attn.ndim != 2:
            raise ValueError(f"Expected a 2-D attention matrix, got shape {attn.shape}")
        num_keys = attn.shape[1]
        side = int(round(np.sqrt(num_keys)))
        if side * side != num_keys:
            raise ValueError(f"Key dimension {num_keys} is not a square grid")
        return attn.mean(axis=0).reshape(side, side)

    @staticmethod
    def _normalize_map(attn_map):
        """Min-max scale to [0, 1]; a constant map becomes all zeros instead of NaN."""
        low = attn_map.min()
        span = attn_map.max() - low
        if span > 0:
            return (attn_map - low) / span
        return np.zeros_like(attn_map)

    def _attention_heatmap(self, attn):
        """Turn one head's attention (tensor or array) into an image-sized heatmap in [0, 1]."""
        if hasattr(attn, 'detach'):
            attn = attn.detach().cpu().numpy()
        grid = self._key_attention_grid(attn)
        resized = cv2.resize(grid, (self.INPUT_SIZE, self.INPUT_SIZE))
        return self._normalize_map(resized)

    @staticmethod
    def _draw_overlay(ax, image, heatmap, title):
        """Draw ``heatmap`` semi-transparently on top of ``image`` in ``ax``."""
        ax.imshow(image)
        ax.imshow(heatmap, alpha=0.5, cmap='jet')
        ax.set_title(title)
        ax.axis('off')

    def _require_image(self):
        """Raise a clear error when no image has been processed yet."""
        if self.original_image is None:
            raise RuntimeError("Call process_image() before visualizing attention")

    def visualize_attention(self, block_idx, layer_idx, head_idx=0):
        """
        Visualize attention map for a specific block, layer and attention head

        Args:
            block_idx: Encoder block index
            layer_idx: Layer index within the block
            head_idx: Attention head index

        Returns:
            The normalized heatmap, or ``None`` if no attention was captured
            for the requested block/layer.
        """
        key = f"block{block_idx}_layer{layer_idx}"
        if key not in self.attention_maps:
            print(f"No attention map found for {key}")
            return
        self._require_image()

        # Attention of the requested head for the first (only) batch element
        attn_map = self._attention_heatmap(self.attention_maps[key][0, head_idx])

        # Create visualization
        fig = plt.figure(figsize=(20, 10))

        ax = fig.add_subplot(1, 2, 1)
        ax.imshow(self.original_image)
        ax.set_title("Original Image")
        ax.axis('off')

        self._draw_overlay(
            fig.add_subplot(1, 2, 2), self.original_image, attn_map,
            f"Attention Map (Block {block_idx}, Layer {layer_idx}, Head {head_idx})")

        fig.tight_layout()
        fig.savefig(f"attention_b{block_idx}_l{layer_idx}_h{head_idx}.png", dpi=300)
        plt.close(fig)

        return attn_map

    def visualize_multi_scale_attention(self):
        """Visualize attention maps from different layers to show multi-scale processing"""
        self._require_image()
        fig = plt.figure(figsize=(20, 15))

        # Original image
        ax = fig.add_subplot(3, 3, 1)
        ax.imshow(self.original_image)
        ax.set_title("Original Image")
        ax.axis('off')

        # First layer, first head of every encoder block (the 3x3 grid has
        # room for at most eight attention panels).
        plot_idx = 2
        layer_idx = 0
        for block_idx in range(len(self.model.segformer.encoder.block)):
            key = f"block{block_idx}_layer{layer_idx}"
            if key not in self.attention_maps or plot_idx > 9:
                continue
            attn_map = self._attention_heatmap(self.attention_maps[key][0, 0])
            self._draw_overlay(fig.add_subplot(3, 3, plot_idx), self.original_image,
                               attn_map, f"Block {block_idx}, Layer {layer_idx}")
            plot_idx += 1

        fig.tight_layout()
        fig.savefig("multi_scale_attention.png", dpi=300)
        plt.close(fig)

    def compare_models(self, image_path, model_paths, model_names, base_models=None):
        """
        Compare attention maps between different model variants (B3, B4, B5)

        Args:
            image_path: Path to input image
            model_paths: List of paths to different model checkpoints
            model_names: List of names for the models (e.g., "B3", "B4", "B5")
            base_models: Optional list of Hugging Face model ids, one per
                checkpoint, giving each checkpoint's architecture.  Defaults
                to this instance's ``base_model`` for every checkpoint.
        """
        if len(model_paths) != len(model_names):
            raise ValueError("model_paths and model_names must have the same length")
        if base_models is None:
            base_models = [self.base_model] * len(model_paths)
        elif len(base_models) != len(model_paths):
            raise ValueError("base_models must have one entry per model path")

        rows = len(model_paths) + 1
        fig = plt.figure(figsize=(20, 5*len(model_paths)))

        # Original image
        image = Image.open(image_path).convert('RGB')
        image_resized = np.array(image.resize((self.INPUT_SIZE, self.INPUT_SIZE)))
        ax = fig.add_subplot(rows, 3, 1)
        ax.imshow(image_resized)
        ax.set_title("Original Image")
        ax.axis('off')

        # One figure row per model: segmentation, early and late attention
        for i, (model_path, model_name, base_model) in enumerate(
                zip(model_paths, model_names, base_models)):
            viz = SegFormerAttentionViz(model_path, num_classes=self.num_classes,
                                        base_model=base_model)
            seg_pred, attention_maps = viz.process_image(image_path)

            # Plot segmentation (same deterministic colormap for every model)
            ax = fig.add_subplot(rows, 3, 3*i+4)
            ax.imshow(self._colorize_segmentation(seg_pred))
            ax.set_title(f"{model_name} Segmentation")
            ax.axis('off')

            attention_panels = [
                ("block0_layer0", 3*i+5, f"{model_name} Early Layer Attention"),
                ("block3_layer0", 3*i+6, f"{model_name} Late Layer Attention"),
            ]
            for key, slot, title in attention_panels:
                if key in attention_maps:
                    attn_map = self._attention_heatmap(attention_maps[key][0, 0])
                    self._draw_overlay(fig.add_subplot(rows, 3, slot),
                                       image_resized, attn_map, title)

            # Release this model before the next one is loaded
            del viz, attention_maps

        fig.tight_layout()
        fig.savefig("model_comparison.png", dpi=300)
        plt.close(fig)

    def _colorize_segmentation(self, segmentation):
        """Convert segmentation indices to colors for visualization"""
        # A privately seeded generator gives every call the same colormap
        # (so colors are comparable across models) without touching the
        # global NumPy random state.
        rng = np.random.RandomState(0)
        colormap = rng.randint(0, 256, (256, 3), dtype=np.uint8)
        colormap[0] = [0, 0, 0]  # background

        # Apply colormap
        return colormap[segmentation]

# Usage example
# NOTE: every "path/to/..." value below is a placeholder and must be replaced
# with a real file before this cell can run.
if __name__ == "__main__":
    # Initialize with your model path
    model_path = "path/to/segformer_b5_camvid.pth"
    viz = SegFormerAttentionViz(model_path, num_classes=32)
    
    # Process an image
    image_path = "path/to/your/image.jpg"
    seg_pred, attention_maps = viz.process_image(image_path)
    
    # Visualize attention for specific block/layer/head.
    # The MiT-B5 encoder has stage depths (3, 6, 40, 3), so the last block
    # (index 3) only contains layers 0-2; layer 2 is its final layer.
    viz.visualize_attention(block_idx=3, layer_idx=2, head_idx=0)
    
    # Visualize multi-scale attention maps
    viz.visualize_multi_scale_attention()
    
    # Compare different model variants.
    # NOTE(review): each checkpoint must match the architecture that
    # SegFormerAttentionViz instantiates; a B3/B4 state dict cannot be loaded
    # into a B5 model.
    model_paths = [
        "path/to/segformer_b3_camvid.pth",
        "path/to/segformer_b4_camvid.pth",
        "path/to/segformer_b5_camvid.pth"
    ]
    model_names = ["SegFormer-B3", "SegFormer-B4", "SegFormer-B5"]
    viz.compare_models(image_path, model_paths, model_names)

config.json:   0%|          | 0.00/1.68k [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/339M [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/339M [00:00<?, ?B/s]

Some weights of SegformerForSemanticSegmentation were not initialized from the model checkpoint at nvidia/segformer-b5-finetuned-cityscapes-1024-1024 and are newly initialized because the shapes did not match:
- decode_head.classifier.weight: found shape torch.Size([19, 768, 1, 1]) in the checkpoint and torch.Size([32, 768, 1, 1]) in the model instantiated
- decode_head.classifier.bias: found shape torch.Size([19]) in the checkpoint and torch.Size([32]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


FileNotFoundError: [Errno 2] No such file or directory: 'path/to/segformer_b5_camvid.pth'