Hierarchical Explainable Prototypical Network (HEPN) on CUB-200-2011

Integrates learned part prototypes into the few-shot meta-learning process
for enhanced explainability, and uses a manually implemented Grad-CAM for visual explanations.

In [None]:
# Mount Google Drive when running inside Colab; BASE_PATH in the configuration cell points into it.
try:
    from google.colab import drive
except ImportError:  # not running on Colab: the data must already be reachable on the local filesystem
    drive = None
    print("google.colab not available; skipping Drive mount. Make sure BASE_PATH points to a local copy of the data.")
else:
    drive.mount('/content/drive')

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms
import torchvision.models as models
import numpy as np
import pandas as pd
import random
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
from tqdm.notebook import tqdm # Use standard tqdm if not in notebook
import math
import os
from PIL import Image
import cv2 # For resizing CAM and colormap
import copy
import json # For split saving/loading

In [None]:
# -------------------------------------
# --- Configuration ---
# -------------------------------------
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {DEVICE}")

# Dataset params
# Adapt DATA_DIR to use BASE_PATH
BASE_PATH = '/content/drive/MyDrive' # Update with your path
DATA_DIR = os.path.join(BASE_PATH, 'CUB_200_2011')
IMAGES_DIR = os.path.join(DATA_DIR, 'images')
IMAGE_SIZE = 224
SPLIT_FILE = os.path.join(BASE_PATH, "cub_meta_split.json") # Use BASE_PATH
FORCE_RESPLIT = False # Set True to ignore saved split

# Episode params (Adjust as needed)
N_WAY = 3
K_SHOT = 5
N_QUERY = 10
N_TRAIN_EPISODES = 5000 # Needs sufficient episodes to learn parts
N_TEST_EPISODES = 600

# --- HEPN Model Params ---
EMBEDDING_DIM_GLOBAL = 256  # Dimension for global embeddings (Level 1)
EMBEDDING_DIM_PATCH = 512   # Dimension for patch embeddings (Level 2 - ResNet18 output)
NUM_PART_PROTOTYPES = 5    # Number of learnable generic part prototypes (Hyperparameter)
PART_PROJECTION_INTERVAL = 500 # How often (in episodes) to project part prototypes
PROJECTION_BATCH_SIZE = 128 # Batch size for projection dataset processing

# Encoder Params
PRETRAINED = True
FREEZE_UNTIL_LAYER = "layer3" # Or None, layer1, layer2 etc.
DROPOUT_RATE = 0.5
TARGET_LAYER_NAME = 'encoder.layer4' # Target layer for manual Grad-CAM

# Training params
LR_BACKBONE = 1e-5
LR_HEAD_GLOBAL = 1e-4
LR_PART_PROTOTYPES = 1e-4 # Learning rate for the part prototypes themselves
LR_COMBINER = 1e-4      # Learning rate for combination layer/params
WEIGHT_DECAY = 5e-4
LABEL_SMOOTHING = 0.1
GRADIENT_CLIP_NORM = 1.0

# --- Loss Weights (Crucial Hyperparameters) ---
LAMBDA_CLST = 0.1         # Weight for Part Prototype Cluster Loss
LAMBDA_DIVERSITY = 0.05   # Weight for Part Prototype Diversity Loss
LAMBDA_L1 = 0.00          # Weight for L1 penalty on part activations (encourages sparsity)
LAMBDA_PART_SIM = 0.5     # Weight for combining part similarity score into final logits

# Visualization Params
# PATCH_SIZE_VIS = 28 # Note: Actual patch size depends on encoder stride

In [None]:

# -------------------------------------
# --- CUB Data Handling (Robust Version) ---
# -------------------------------------

def parse_cub_metadata(data_dir, images_dir=None):
    """Read the CUB-200-2011 metadata files and join them into a single DataFrame.

    Args:
        data_dir: Root of the extracted dataset. It must contain ``images.txt``,
            ``image_class_labels.txt``, ``train_test_split.txt``,
            ``bounding_boxes.txt`` and ``classes.txt``.
        images_dir: Folder that the relative paths in ``images.txt`` are
            resolved against. Defaults to ``<data_dir>/images``, so the
            returned ``full_path`` values always agree with ``data_dir``.

    Returns:
        ``(df, class_id_to_name)``:
        - ``df`` has one row per image, with columns ``img_id``, ``filepath``,
          ``class_id`` (0-based), ``is_training``, ``x``, ``y``, ``width``,
          ``height`` and ``full_path``.
        - ``class_id_to_name`` maps 0-based class IDs to class names.

    Raises:
        FileNotFoundError: if any required metadata file is missing.
    """
    if images_dir is None:
        # Derive from the argument rather than a global so a different data_dir yields correct paths.
        images_dir = os.path.join(data_dir, 'images')

    images_path = os.path.join(data_dir, 'images.txt')
    labels_path = os.path.join(data_dir, 'image_class_labels.txt')
    split_path = os.path.join(data_dir, 'train_test_split.txt')
    bbox_path = os.path.join(data_dir, 'bounding_boxes.txt')
    classes_path = os.path.join(data_dir, 'classes.txt')

    required_files = [images_path, labels_path, split_path, bbox_path, classes_path]
    for f_path in required_files:
        if not os.path.exists(f_path):
            raise FileNotFoundError(f"Metadata file not found: {f_path}.")

    df_images = pd.read_csv(images_path, sep=' ', names=['img_id', 'filepath'])
    df_labels = pd.read_csv(labels_path, sep=' ', names=['img_id', 'class_id'])
    df_split = pd.read_csv(split_path, sep=' ', names=['img_id', 'is_training'])
    df_bboxes = pd.read_csv(bbox_path, sep=' ', names=['img_id', 'x', 'y', 'width', 'height'])
    df_classes = pd.read_csv(classes_path, sep=' ', names=['class_id', 'class_name'])

    df_labels['class_id'] = df_labels['class_id'] - 1  # CUB class IDs are 1-based on disk
    df = df_images.merge(df_labels, on='img_id')
    df = df.merge(df_split, on='img_id')
    df = df.merge(df_bboxes, on='img_id')
    df['full_path'] = df['filepath'].apply(lambda rel_path: os.path.join(images_dir, rel_path))

    class_id_to_name = df_classes.set_index('class_id')['class_name'].to_dict()
    class_id_to_name = {(k - 1): v for k, v in class_id_to_name.items()}  # re-key to 0-based IDs
    print(f"Parsed metadata for {len(df)} total image entries across {df['class_id'].nunique()} original classes.")
    return df, class_id_to_name


class CubDataset(Dataset):
    """
    Map-style dataset over rows of a CUB metadata DataFrame; crops each image to its bounding box.

    ``df_subset`` must provide the columns ``full_path``, ``subset_class_id``,
    ``x``, ``y``, ``width`` and ``height``.

    ``__getitem__`` returns ``(image, label)``, where ``label`` is the row's
    ``subset_class_id``. If the file cannot be opened, or the transform returns
    None, it returns ``(None, -1)`` so callers can skip the sample. Note that
    PyTorch's default collate function cannot batch a ``None`` image.
    """

    def __init__(self, df_subset, transform=None):
        self.df = df_subset
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        img_info = self.df.iloc[idx]
        img_path = img_info['full_path']
        label = img_info['subset_class_id']  # 0-based ID within the current subset

        try:
            # The context manager closes the file deterministically; convert() forces a
            # full decode, so the returned RGB image no longer depends on the open file.
            with Image.open(img_path) as raw_image:
                image = raw_image.convert('RGB')
        except FileNotFoundError:
            print(f"Warning: Image not found {img_path}. Skipping.")
            return None, -1
        except Exception as e:  # corrupt/unsupported file: report it and let the caller skip the sample
            print(f"Error loading image {img_path}: {e}")
            return None, -1

        image = self._crop_to_bbox(image, img_info['x'], img_info['y'],
                                   img_info['width'], img_info['height'])

        if self.transform:
            image = self.transform(image)

        # Guard against a transform that signals failure by returning None.
        if image is None:
            print(f"Warning: Transform returned None for image {img_path}. Skipping.")
            return None, -1

        return image, label

    @staticmethod
    def _crop_to_bbox(image, x, y, w, h):
        """Crop ``image`` to the (x, y, w, h) box expanded to whole pixels and clipped to the image bounds.

        The image is returned unchanged when the clipped box is empty.
        """
        img_width, img_height = image.size
        left, upper = max(0, int(np.floor(x))), max(0, int(np.floor(y)))
        right = min(img_width, int(np.ceil(x + w)))
        lower = min(img_height, int(np.ceil(y + h)))
        if right > left and lower > upper:
            return image.crop((left, upper, right, lower))
        return image


def _detect_available_class_ids(images_dir):
    """Return the sorted, de-duplicated 0-based original class IDs parsed from ``NNN.Name`` sub-folders of ``images_dir``."""
    class_ids = set()
    for folder_name in os.listdir(images_dir):
        if not os.path.isdir(os.path.join(images_dir, folder_name)):
            continue
        try:
            class_ids.add(int(folder_name.split('.')[0]) - 1)  # folder prefix is the 1-based CUB class number
        except ValueError:
            print(f"Skipping folder '{folder_name}' - cannot parse class number.")
    return sorted(class_ids)


def _load_saved_class_split(split_save_path, available_original_class_ids):
    """Load a saved meta-train/test class split and validate it against the detected classes.

    Returns ``(train_ids, test_ids)`` (original 0-based CUB IDs) on success, or
    ``(None, None)`` when the file is malformed, unreadable, references
    undetected classes, has overlapping sides, or has an empty side. A warning
    is printed in each failure case.
    """
    print(f"Loading existing class split from: {split_save_path}")
    try:
        with open(split_save_path, 'r') as f:
            split_data = json.load(f)

        well_formed = (isinstance(split_data, dict) and
                       isinstance(split_data.get('meta_train_indices'), list) and
                       isinstance(split_data.get('meta_test_indices'), list))
        if not well_formed:
            print(f"Warning: Invalid format in {split_save_path}. Will create new split.")
            return None, None

        loaded_train_indices = split_data['meta_train_indices']
        loaded_test_indices = split_data['meta_test_indices']
        available_set = set(available_original_class_ids)
        invalid_indices = set(loaded_train_indices + loaded_test_indices) - available_set
        if invalid_indices:
            print(f"Warning: Loaded split contains original class IDs ({invalid_indices}) not present in currently detected classes ({available_set}). Will create new split.")
            return None, None
        if set(loaded_train_indices) & set(loaded_test_indices):
            print("Warning: Loaded split has overlapping train/test indices. Will create new split.")
            return None, None
        if not loaded_train_indices or not loaded_test_indices:
            print("Warning: Loaded split has an empty train or test class list. Will create new split.")
            return None, None

        print("Successfully loaded and validated existing split (using original class IDs).")
        return loaded_train_indices, loaded_test_indices
    except (OSError, ValueError, TypeError) as e:  # unreadable file, bad JSON, or unhashable entries
        print(f"Error loading split file {split_save_path}: {e}. Will create new split.")
        return None, None


def _create_and_save_class_split(available_original_class_ids, n_meta_train_ratio, split_save_path):
    """Randomly partition the detected class IDs into meta-train/test lists and try to persist them as JSON.

    Raises ValueError if the ratio leaves either side empty. Failing to write
    the file is reported but not fatal.
    """
    print("Creating new meta-train/test class split...")
    n_total = len(available_original_class_ids)
    n_meta_train_actual = int(n_total * n_meta_train_ratio)
    n_meta_test_actual = n_total - n_meta_train_actual
    if n_meta_train_actual <= 0 or n_meta_test_actual <= 0:
        raise ValueError(f"Cannot split {n_total} detected classes into train/test with ratio {n_meta_train_ratio}. Need more classes or adjust ratio.")

    shuffled_available_ids = random.sample(available_original_class_ids, n_total)
    train_ids = sorted(shuffled_available_ids[:n_meta_train_actual])  # ORIGINAL 0-based IDs
    test_ids = sorted(shuffled_available_ids[n_meta_train_actual:])   # ORIGINAL 0-based IDs

    try:
        split_data_to_save = {'meta_train_indices': train_ids, 'meta_test_indices': test_ids,
                              'comment': f'Split of {n_total} detected classes based on original CUB IDs.'}
        with open(split_save_path, 'w') as f:
            json.dump(split_data_to_save, f, indent=4)
        print(f"Saved new class split (using original IDs) to: {split_save_path}")
    except OSError as e:  # best effort: training can proceed without a persisted split
        print(f"Error saving new split file {split_save_path}: {e}")
    return train_ids, test_ids


def prepare_cub_data_splits(data_dir, n_meta_train_ratio=0.7, split_save_path="cub_meta_split.json", force_resplit=False):
    """
    Build the class-disjoint meta-train / meta-test data used by the episode samplers.

    Steps:
      1. Parse the CUB metadata and detect which classes have an image folder
         under ``<data_dir>/images``.
      2. Load a saved class split from ``split_save_path`` (unless
         ``force_resplit``), or draw a new random split with a fraction
         ``n_meta_train_ratio`` of the classes in meta-train, and save it.
      3. Reduce the module-level ``N_WAY`` so it fits both splits. This is a
         global side effect.
      4. Eagerly load and transform every image: meta-train classes use CUB's
         official *train* images, meta-test classes use the official *test* images.

    Returns:
        ``(meta_train_data, meta_test_data, meta_train_full_dataset)``:
        - ``meta_train_data`` and ``meta_test_data`` are lists of
          ``(local_label, [image_tensor, ...])`` pairs for ``EpisodeSampler``.
        - ``meta_train_full_dataset`` is a ``CubDataset`` over all meta-train images.
        Returns ``(None, None, None)`` if any of them ended up empty.

    Raises:
        FileNotFoundError: missing images directory or metadata file.
        ValueError: no parsable class folders, or the ratio leaves a split empty.
    """
    images_dir_local = os.path.join(data_dir, 'images')
    if not os.path.isdir(images_dir_local):
        raise FileNotFoundError(f"Images directory not found: {images_dir_local}")

    df_all, _ = parse_cub_metadata(data_dir)

    # --- Detect available classes ---
    available_original_class_ids = _detect_available_class_ids(images_dir_local)
    if not available_original_class_ids:
        raise ValueError("No valid class folders found or parsed.")
    print(f"\nDetected {len(available_original_class_ids)} classes based on folder names.")
    print(f"Sample detected original class IDs (0-based): {available_original_class_ids[:5]}...")

    # --- Load or create the class-level meta-split (stored as ORIGINAL 0-based CUB IDs) ---
    meta_train_subset_indices, meta_test_subset_indices = None, None
    if os.path.exists(split_save_path) and not force_resplit:
        meta_train_subset_indices, meta_test_subset_indices = _load_saved_class_split(
            split_save_path, available_original_class_ids)
    if meta_train_subset_indices is None or meta_test_subset_indices is None:
        meta_train_subset_indices, meta_test_subset_indices = _create_and_save_class_split(
            available_original_class_ids, n_meta_train_ratio, split_save_path)

    print(f"Using finalized split: {len(meta_train_subset_indices)} meta-train classes, {len(meta_test_subset_indices)} meta-test classes (based on original CUB IDs).")

    # Keep only detected classes and give them contiguous 0-based "subset" IDs.
    df_subset = df_all[df_all['class_id'].isin(available_original_class_ids)].copy()
    all_detected_subset_map = {orig_id: new_id for new_id, orig_id in enumerate(available_original_class_ids)}
    df_subset.loc[:, 'subset_class_id'] = df_subset['class_id'].map(all_detected_subset_map)

    meta_train_subset_class_ids = sorted(all_detected_subset_map[orig_id] for orig_id in meta_train_subset_indices)
    meta_test_subset_class_ids = sorted(all_detected_subset_map[orig_id] for orig_id in meta_test_subset_indices)

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    transform = transforms.Compose([
        transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
        transforms.ToTensor(),
        normalize,
    ])

    # --- Shrink the global N_WAY so episodes can be drawn from both splits ---
    global N_WAY
    max_way_train = len(meta_train_subset_class_ids)
    max_way_test = len(meta_test_subset_class_ids)
    if N_WAY > max_way_train:
        print(f"WARNING: N_WAY ({N_WAY}) > available meta-train classes ({max_way_train}). Reducing N_WAY to {max_way_train}.")
        N_WAY = max_way_train
    if N_WAY > max_way_test:
        print(f"WARNING: N_WAY ({N_WAY}) > available meta-test classes ({max_way_test}). Reducing N_WAY to {max_way_test}.")
        N_WAY = max_way_test
    if N_WAY <= 1:
        print(f"CRITICAL WARNING: N_WAY reduced to {N_WAY}. Few-shot learning might not be meaningful.")

    # Use CUB's own train/test split *within* the meta-split classes.
    df_meta_train_pool = df_subset[df_subset['subset_class_id'].isin(meta_train_subset_class_ids)]
    df_meta_test_pool = df_subset[df_subset['subset_class_id'].isin(meta_test_subset_class_ids)]
    df_meta_train_from_cub_train = df_meta_train_pool[df_meta_train_pool['is_training'] == 1]
    df_meta_test_from_cub_test = df_meta_test_pool[df_meta_test_pool['is_training'] == 0]

    def load_data_for_sampler(target_subset_class_ids, df_source, split_name):
        """Load every image of each target class into memory as ``(local_label, [tensor, ...])`` pairs; classes with no loadable image are omitted."""
        data = []
        class_map_for_sampler = {subset_id: i for i, subset_id in enumerate(target_subset_class_ids)}

        pbar = tqdm(target_subset_class_ids, desc=f"Processing Meta-{split_name} Classes for Sampler")
        for subset_class_id in pbar:
            local_label = class_map_for_sampler[subset_class_id]
            # Reset per class: previously `images` was only bound for non-empty classes, so the
            # progress postfix below raised UnboundLocalError (or showed a stale count).
            images = []
            class_df = df_source[df_source['subset_class_id'] == subset_class_id]
            if not class_df.empty:
                temp_class_df = class_df.copy()
                temp_class_df.loc[:, 'subset_class_id'] = local_label  # overwrite with 0-based local label
                temp_dataset = CubDataset(temp_class_df, transform=transform)
                for i in range(len(temp_dataset)):
                    img, _ = temp_dataset[i]
                    if img is not None:  # skip images whose loading/transform failed
                        images.append(img)
                if images:
                    data.append((local_label, images))
            pbar.set_postfix({"Class": subset_class_id, "Images": len(images)})

        data.sort(key=lambda entry: entry[0])
        min_samples_needed = K_SHOT + N_QUERY
        valid_classes_count = sum(1 for _, imgs in data if len(imgs) >= min_samples_needed)
        num_classes_in_sampler = len(data)
        if valid_classes_count < num_classes_in_sampler:
            print(f"*** WARNING ({split_name} Sampler): {num_classes_in_sampler - valid_classes_count}/{num_classes_in_sampler} classes have < {min_samples_needed} samples. Replacement needed.")
        if valid_classes_count < N_WAY:
            print(f"*** CRITICAL WARNING ({split_name} Sampler): Only {valid_classes_count} classes have enough samples. N_WAY ({N_WAY}) might be impossible without replacement.")
        return data

    print("\nBuilding meta-train data for sampler (from CUB train split)...")
    meta_train_data = load_data_for_sampler(meta_train_subset_class_ids, df_meta_train_from_cub_train, "Train")
    print("\nBuilding meta-test data for sampler (from CUB test split)...")
    meta_test_data = load_data_for_sampler(meta_test_subset_class_ids, df_meta_test_from_cub_test, "Test")

    # --- Full meta-train dataset (global subset IDs as labels), used for prototype projection ---
    print("\nCreating full meta-train dataset for projection...")
    meta_train_full_dataset = CubDataset(df_meta_train_from_cub_train, transform=transform)
    print(f"Full meta-train dataset size (for projection): {len(meta_train_full_dataset)} images")

    print(f"\nFinal classes available for meta-train sampler: {len(meta_train_data)}")
    print(f"Final classes available for meta-test sampler: {len(meta_test_data)}")

    if not meta_train_data or not meta_test_data or len(meta_train_full_dataset) == 0:
        print("\n*** ERROR: Empty meta_train_data, meta_test_data, or meta_train_full_dataset. Check data. ***")
        return None, None, None

    return meta_train_data, meta_test_data, meta_train_full_dataset

# --- Episode Sampler (Unchanged) ---
class EpisodeSampler:
    """
    Sample N-way K-shot episodes from per-class image lists.

    ``meta_data`` is a list of ``(label, [image_tensor, ...])`` pairs. The
    stored label is not used: classes are addressed by their list position.
    Each call to :meth:`sample` picks ``n_way`` distinct classes uniformly at
    random and relabels them with contiguous episode labels starting at 0.
    """

    def __init__(self, meta_data, n_way, k_shot, n_query):
        self.meta_data = meta_data
        if not meta_data:
            raise ValueError("Meta data cannot be empty")
        self.num_classes = len(meta_data)
        if n_way <= 0 or n_way > self.num_classes:
            raise ValueError(f"Invalid n_way ({n_way}) for {self.num_classes} available classes.")
        self.n_way = n_way
        self.k_shot = k_shot
        self.n_query = n_query

        # Partition class positions by whether they can supply k_shot + n_query distinct images.
        self.min_samples_needed = k_shot + n_query
        self.classes_with_enough_samples = [
            i for i, (_, images) in enumerate(meta_data) if len(images) >= self.min_samples_needed
        ]
        self.classes_requiring_replacement = [
            i for i, (_, images) in enumerate(meta_data) if len(images) < self.min_samples_needed
        ]
        if len(self.classes_with_enough_samples) < self.n_way:
            # Class selection itself never repeats a class; only images are re-drawn.
            print(f"WARNING (Sampler Init): Only {len(self.classes_with_enough_samples)} classes have >= {self.min_samples_needed} samples. "
                  f"Every episode with N_WAY ({self.n_way}) will include classes whose images are drawn *with replacement*.")

    def sample(self):
        """
        Draw one episode.

        Returns:
            ``(support_imgs, support_lbls, query_imgs, query_lbls)``:
            - Support tensors are grouped class by class, ``k_shot`` per class.
            - The query set (``n_query`` per class) is randomly permuted.
            - Labels are contiguous, ``0..C-1``, where C is the number of selected
              classes that actually had images (normally ``n_way``).
            Classes with fewer than ``k_shot + n_query`` images are sampled with
            replacement, so one image may appear several times, even in both
            support and query.

        Raises:
            ValueError: if ``n_way`` distinct classes cannot be drawn, or no
                images could be gathered.
        """
        # All classes are candidates (no prioritisation); under-populated ones use image replacement below.
        available_indices = list(range(self.num_classes))
        try:
            sampled_class_indices = random.sample(available_indices, self.n_way)
        except ValueError as err:
            print(f"Error: Cannot sample {self.n_way} distinct classes from {self.num_classes} available.")
            raise ValueError(f"Not enough unique classes ({self.num_classes}) available to sample N_WAY={self.n_way}.") from err

        n_needed = self.k_shot + self.n_query
        support_imgs, support_lbls, query_imgs, query_lbls = [], [], [], []
        # Advanced only for classes that contribute images, so skipping an empty class
        # never leaves a gap in the episode labels.
        episode_label = 0
        for class_idx in sampled_class_indices:
            _, images = self.meta_data[class_idx]
            n_available = len(images)
            if n_available == 0:
                print(f"Warning: Class index {class_idx} has 0 images. Skipping it; this episode will have fewer than {self.n_way} classes.")
                continue

            if n_available < n_needed:
                indices = random.choices(range(n_available), k=n_needed)
            else:
                indices = random.sample(range(n_available), n_needed)

            support_imgs.extend(images[i] for i in indices[:self.k_shot])
            support_lbls.extend([episode_label] * self.k_shot)
            query_imgs.extend(images[i] for i in indices[self.k_shot:])
            query_lbls.extend([episode_label] * self.n_query)
            episode_label += 1

        if not support_imgs or not query_imgs:
            raise ValueError("Failed to gather any support or query images for the episode.")

        support_imgs = torch.stack(support_imgs)
        support_lbls = torch.LongTensor(support_lbls)
        query_imgs = torch.stack(query_imgs)
        query_lbls = torch.LongTensor(query_lbls)

        # Shuffle the query set so its order does not reveal the class grouping.
        perm = torch.randperm(len(query_lbls))
        return support_imgs, support_lbls, query_imgs[perm], query_lbls[perm]

In [None]:
# -------------------------------------
# --- HEPN Model Architecture ---
# -------------------------------------

class EncoderHEPN(nn.Module):
    """
    ResNet18-based encoder for HEPN that returns two views of each image.

    ``forward(x)`` returns ``(conv_features, global_embedding)``:
      * ``conv_features``: the ``layer4`` feature map [B, D_patch, H, W]. Each
        spatial position serves as a patch feature (level 2).
      * ``global_embedding``: the average-pooled ``layer4`` features passed
        through dropout and a linear layer, giving [B, embedding_dim_global] (level 1).

    Backbone stages up to ``freeze_until`` can be frozen (``requires_grad=False``).
    """

    # Ordered backbone stages that ``freeze_until`` may name (case-insensitive).
    _STAGE_NAMES = ("stem", "layer1", "layer2", "layer3", "layer4")

    def __init__(self, embedding_dim_global=EMBEDDING_DIM_GLOBAL,
                 patch_feature_dim=EMBEDDING_DIM_PATCH,  # informational only; the actual value comes from ResNet18
                 pretrained=True, freeze_until=None, dropout_rate=0.5):
        super().__init__()
        resnet = models.resnet18(weights=models.ResNet18_Weights.DEFAULT if pretrained else None)

        # Backbone conv stages
        self.stem = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool)
        self.layer1 = resnet.layer1
        self.layer2 = resnet.layer2
        self.layer3 = resnet.layer3
        self.layer4 = resnet.layer4  # output feature map used for patches

        # Plain Python list (not nn.ModuleList) so the stages are registered only once,
        # under stem/layer1..layer4; parameter and state_dict names stay unchanged.
        self.conv_feature_extractor_layers = [self.stem, self.layer1, self.layer2, self.layer3, self.layer4]

        if freeze_until:
            self._freeze_layers(freeze_until)

        # Head for the global embedding (level 1)
        resnet_out_dim = resnet.fc.in_features  # 512 for ResNet18
        self.global_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.flatten = nn.Flatten()
        self.global_dropout = nn.Dropout(p=dropout_rate)
        self.global_embedding_layer = nn.Linear(resnet_out_dim, embedding_dim_global)

        self.global_embedding_head = nn.Sequential(
            self.global_pool,
            self.flatten,
            self.global_dropout,
            self.global_embedding_layer
        )

        # Patch features are the layer4 channels; the requested value is only checked, not enforced.
        self.patch_feature_dim = resnet_out_dim
        if self.patch_feature_dim != patch_feature_dim:
            print(f"Info: Requested patch_feature_dim {patch_feature_dim} differs "
                  f"from ResNet output {self.patch_feature_dim}. Using {self.patch_feature_dim}.")

        # ResNet18 downsampling: conv1 (2) x maxpool (2) x layer2 (2) x layer3 (2) x layer4 (2) = 32
        self.total_stride = 32

    def _freeze_layers(self, freeze_until):
        """Disable gradients for the stem and every stage up to and including ``freeze_until``.

        Accepts "stem", "layer1" ... "layer4" (case-insensitive), or None/"none".
        Unknown values are reported and leave every stage trainable.

        NOTE(review): only ``requires_grad`` is turned off. BatchNorm running
        statistics inside frozen stages still update whenever the module is in
        train mode.
        """
        print(f"Freezing ResNet layers up to and including: {freeze_until}")
        key = None if freeze_until is None else str(freeze_until).lower()
        if key is None or key == "none":
            print("No layers frozen.")
            return
        if key not in self._STAGE_NAMES:
            print(f"Warning: Unknown freeze_until value '{freeze_until}'. No layers frozen.")
            return

        stages = [self.stem, self.layer1, self.layer2, self.layer3, self.layer4]
        layers_to_freeze = stages[:self._STAGE_NAMES.index(key) + 1]
        for layer in layers_to_freeze:
            for param in layer.parameters():
                param.requires_grad = False
        print(f"Finished freezing {len(layers_to_freeze)} layer groups.")

    def forward(self, x):
        conv_features = x
        for layer in self.conv_feature_extractor_layers:
            conv_features = layer(conv_features)
        # conv_features is now the layer4 output: [B, D_patch, H, W]

        global_embedding = self.global_embedding_head(conv_features)  # [B, E_global]
        return conv_features, global_embedding

    def get_trainable_parameters(self):
        """Split trainable (``requires_grad``) parameters into ``(backbone_params, head_params)`` for separate learning rates."""
        head_prefixes = ('global_embedding_layer.', 'global_dropout.', 'flatten.', 'global_pool.')
        backbone_params = []
        head_params = []
        # named_parameters() de-duplicates shared modules; the head Linear is registered as
        # global_embedding_layer before global_embedding_head, so it is reported under that name.
        for name, param in self.named_parameters():
            if not param.requires_grad:
                continue  # frozen backbone stages
            if name.startswith(head_prefixes):
                head_params.append(param)
            else:
                backbone_params.append(param)

        print(f"Encoder Param Groups: {len(backbone_params)} backbone, {len(head_params)} global_head")
        return backbone_params, head_params


class PartPrototypeLayer(nn.Module):
    """
    Learns generic part prototypes and calculates part activation profiles.

    The layer holds ``num_parts`` learnable vectors of size ``patch_feature_dim``.
    A conv feature map [B, D, H, W] is treated as B*H*W patch vectors of size D.
    Patch/prototype similarity is the negative squared L2 distance. The layer also
    provides auxiliary losses (cluster, diversity, L1) and a projection step that
    overwrites every prototype with its nearest patch feature from a dataset.

    Buffers:
        projected_patch_locations (LongTensor [num_parts, 3]): (dataset index, h, w)
            of the patch each prototype was last projected onto; -1 until projected.
        projected_image_indices (LongTensor [num_parts]): dataset index of that
            patch (same value as column 0 of ``projected_patch_locations``);
            -1 until projected.
    """
    def __init__(self, num_parts=NUM_PART_PROTOTYPES, patch_feature_dim=EMBEDDING_DIM_PATCH):
        super().__init__()
        self.num_parts = num_parts
        self.patch_feature_dim = patch_feature_dim

        # Learnable part prototypes, initialised from a standard normal.
        self.part_prototypes = nn.Parameter(torch.randn(self.num_parts, self.patch_feature_dim),
                                            requires_grad=True)

        # (dataset index, h, w) of the patch each prototype was projected onto; -1 = not projected yet.
        self.register_buffer('projected_patch_locations', torch.full((self.num_parts, 3), -1, dtype=torch.long))
        # Dataset index of the projected patch (duplicates column 0 above; kept for existing readers).
        self.register_buffer('projected_image_indices', torch.full((self.num_parts,), -1, dtype=torch.long))

    # --- Shared helpers ---
    @staticmethod
    def _flatten_patches(conv_features):
        """Reshape conv features [B, D, H, W] to patch rows [B*H*W, D], ordered by image, then h, then w."""
        B, D, H, W = conv_features.shape
        return conv_features.permute(0, 2, 3, 1).reshape(B * H * W, D)

    @staticmethod
    def _squared_distances(a, b):
        """Pairwise squared L2 distances between rows of ``a`` [N, D] and ``b`` [M, D], in float32 -> [N, M]."""
        return torch.cdist(a.float(), b.float()) ** 2

    def _calculate_similarity(self, patches, prototypes):
        """ Calculates patch-prototype similarity (negative squared L2 distance): [N, D] x [M, D] -> [N, M]. """
        return -self._squared_distances(patches, prototypes)

    def calculate_query_activation(self, query_conv_features):
        """
        Calculates the part activation vector for a query image.

        Args:
            query_conv_features (Tensor): [B, D, H, W]; B is expected to be 1. If
                B != 1 a warning is printed and the maximum is taken over the
                patches of all B images together.

        Returns:
            Tensor [num_parts]: for each prototype, the highest similarity of any
            patch, or zeros if there are no patches.
        """
        B, D, H, W = query_conv_features.shape
        if B != 1: print(f"Warning: calculate_query_activation called with B={B}")
        if B == 0 or H * W <= 0: return torch.zeros(self.num_parts, device=query_conv_features.device)

        query_patches = self._flatten_patches(query_conv_features)                      # [B*HW, D]
        sim_matrix = self._calculate_similarity(query_patches, self.part_prototypes)   # [B*HW, M]

        # Max similarity for each prototype across all query patches
        activation_q, _ = torch.max(sim_matrix, dim=0)  # [M]
        return activation_q

    def calculate_class_profile(self, support_conv_features_list):
        """
        Calculates the average part activation profile for a class's support set.

        Args:
            support_conv_features_list: list of [D, H, W] tensors, or an already
                stacked tensor [K, D, H, W].

        Returns:
            Tensor [num_parts]: mean over the K support images of each image's
            maximum patch similarity per prototype (zeros for an empty set).
        """
        if isinstance(support_conv_features_list, list):
            if not support_conv_features_list:
                return torch.zeros(self.num_parts, device=self.part_prototypes.device)
            support_conv_features = torch.stack(support_conv_features_list, dim=0)  # [K, D, H, W]
        else:
            support_conv_features = support_conv_features_list  # Assume already stacked tensor

        K, D, H, W = support_conv_features.shape
        n_patches = H * W
        if K == 0 or n_patches <= 0:
            return torch.zeros(self.num_parts, device=support_conv_features.device)

        support_patches = self._flatten_patches(support_conv_features)                     # [K*HW, D]
        sim_matrix_s = self._calculate_similarity(support_patches, self.part_prototypes)  # [K*HW, M]
        sim_matrix_s = sim_matrix_s.view(K, n_patches, self.num_parts)                     # [K, HW, M]

        max_sim_per_image, _ = torch.max(sim_matrix_s, dim=1)  # [K, M]
        return torch.mean(max_sim_per_image, dim=0)             # [M]

    def calculate_part_similarity_score(self, activation_q, class_profile):
        """ Cosine similarity (eps=1e-6) between a query activation [M] and a class profile [M]; returns a 0-d tensor. """
        return F.cosine_similarity(activation_q.unsqueeze(0), class_profile.unsqueeze(0), eps=1e-6).squeeze()

    # --- Losses ---
    def calculate_cluster_loss(self, conv_features):
        """ Encourages patches to be close to *some* part prototype: mean over patches of the min squared distance. """
        B, D, H, W = conv_features.shape
        if B == 0 or H * W <= 0: return torch.tensor(0.0, device=conv_features.device)

        dist_matrix_sq = self._squared_distances(self._flatten_patches(conv_features), self.part_prototypes)  # [B*HW, M]
        min_dist_sq, _ = torch.min(dist_matrix_sq, dim=1)  # [B*HW]
        return torch.mean(min_dist_sq)

    def calculate_diversity_loss(self):
        """ Encourage part prototypes to be distinct: 1 / (mean pairwise squared distance + 1e-6); 0 for <= 1 prototype. """
        n = self.num_parts
        if n <= 1: return torch.tensor(0.0, device=self.part_prototypes.device)
        proto_dist_sq = self._squared_distances(self.part_prototypes, self.part_prototypes)  # [M, M]
        # Upper triangle (diagonal excluded) counts each unordered pair once.
        sum_dist_sq = torch.triu(proto_dist_sq, diagonal=1).sum()
        avg_dist_sq = sum_dist_sq / (n * (n - 1) / 2)
        return 1.0 / (avg_dist_sq + 1e-6)

    def calculate_l1_loss(self, activations):
        """ Penalizes dense part activations (mean absolute value); 0 for None/empty input. """
        if activations is None or activations.numel() == 0:
            return torch.tensor(0.0, device=self.part_prototypes.device)
        return torch.mean(torch.abs(activations))

    # --- Prototype Projection ---
    @torch.no_grad()
    def project_part_prototypes(self, dataset_for_proj, encoder_model, device, batch_size=PROJECTION_BATCH_SIZE):
        """
        Replaces each part prototype with the nearest patch feature found in a dataset.

        Each item of ``dataset_for_proj`` is assumed to be an ``(image, label)`` pair.
        Items that are ``None`` or whose image is ``None`` are skipped. Each image is
        passed through ``encoder_model``, and its first output is used as the conv
        feature map [B, D, H, W]. For each prototype, the patch with the smallest
        squared L2 distance is found. The prototype is overwritten in place with that
        patch vector, and the patch origin (dataset index, h, w) is written to
        ``projected_patch_locations`` / ``projected_image_indices``.

        ``encoder_model`` is put in eval mode for the duration. Its previous
        train/eval mode is restored on every exit path, including exceptions. If
        the collected patch features cannot be moved to ``device``, the
        nearest-patch search runs on the CPU instead.

        Args:
            dataset_for_proj: Indexable dataset of (image_tensor, label) items.
            encoder_model: Module returning (conv_features, global_embedding).
            device (torch.device): Device for encoder passes and, when memory
                permits, the distance search.
            batch_size (int): DataLoader batch size for feature extraction.
        """
        print(f"\n--- Starting Part Prototype Projection (Dataset size: {len(dataset_for_proj)}) ---")
        if len(dataset_for_proj) == 0:
            print("Warning: Projection dataset is empty. Skipping projection.")
            return

        original_encoder_mode = encoder_model.training
        encoder_model.eval()
        try:
            self._run_projection(dataset_for_proj, encoder_model, device, batch_size)
        finally:
            encoder_model.train(original_encoder_mode)  # Restore original mode even on error
            if device.type == 'cuda':
                torch.cuda.empty_cache()
        print("--- Finished Part Prototype Projection ---")

    @torch.no_grad()
    def _run_projection(self, dataset_for_proj, encoder_model, device, batch_size):
        """Projection body; the caller sets eval mode and restores it afterwards."""
        dataset_filtered, index_map = self._filter_projection_dataset(dataset_for_proj)
        if len(dataset_filtered) == 0:
            print("Warning: Projection dataset empty after filtering None items. Skipping projection.")
            return

        all_patch_features, all_patch_origins = self._extract_projection_patches(
            dataset_filtered, index_map, encoder_model, device, batch_size)
        if all_patch_features is None:
            print("Warning: No patch features extracted. Skipping projection.")
            return
        print(f"Extracted {all_patch_features.shape[0]:,} patch features "
              f"({all_patch_features.element_size() * all_patch_features.nelement() / 1024**2:.1f} MB on CPU).")

        try:
            all_patch_features = all_patch_features.to(device)
            print(f"Moved all patch features to {device}.")
        except RuntimeError as e:
            # e.g. CUDA out of memory: the features are still on the CPU, so search there instead of giving up.
            print(f"Error moving all patch features to {device}: {e}. Falling back to CPU for the nearest-patch search.")

        if self.num_parts == 0:
            print("Error: No valid nearest patches found for any prototype. Skipping update.")
            return

        nearest_patch_indices, min_dists_sq = self._find_nearest_patches(all_patch_features)
        print(f"Projecting {self.num_parts}/{self.num_parts} prototypes.")

        # Overwrite the prototypes in place with their nearest patch vectors.
        new_prototypes = all_patch_features[nearest_patch_indices]  # [M, D]
        self.part_prototypes.data.copy_(
            new_prototypes.to(device=self.part_prototypes.device, dtype=self.part_prototypes.dtype))

        # Record where each prototype came from: (orig_img_idx, h, w).
        nearest_list = nearest_patch_indices.tolist()
        origins = [all_patch_origins[idx] for idx in nearest_list]
        locations = torch.tensor(origins, dtype=torch.long, device=self.projected_patch_locations.device)
        self.projected_patch_locations.copy_(locations)
        self.projected_image_indices.copy_(locations[:, 0].to(self.projected_image_indices.device))

        self._log_projection_check(all_patch_features, nearest_list, min_dists_sq, origins)

    @staticmethod
    def _filter_projection_dataset(dataset_for_proj):
        """Drop None items / items with a None image; returns (dataset view, list mapping filtered -> original index)."""
        # NOTE: this loads every item once, which can be slow for large datasets.
        valid_indices = []
        for i in range(len(dataset_for_proj)):
            item = dataset_for_proj[i]
            if item is not None and item[0] is not None:
                valid_indices.append(i)
        if len(valid_indices) == len(dataset_for_proj):
            return dataset_for_proj, list(range(len(dataset_for_proj)))
        print(f"Warning: Filtering out {len(dataset_for_proj) - len(valid_indices)} None items from projection dataset.")
        return torch.utils.data.Subset(dataset_for_proj, valid_indices), valid_indices

    @staticmethod
    def _extract_projection_patches(dataset, index_map, encoder_model, device, batch_size):
        """Encode ``dataset``; returns (patch features [N, D] on CPU, list of (orig_idx, h, w)) or (None, [])."""
        pin_memory_flag = device.type == 'cuda'
        cpu_count = os.cpu_count()
        num_workers = min(4, cpu_count // 2 if cpu_count else 1) if pin_memory_flag else 0
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False,
                                num_workers=num_workers, pin_memory=pin_memory_flag)

        feature_chunks = []
        patch_origins = []  # (original_dataset_idx, h, w), aligned with rows of the features
        seen = 0
        pbar = tqdm(dataloader, desc="Extracting Features for Projection", leave=False)
        for images, _ in pbar:
            images = images.to(device, non_blocking=pin_memory_flag)
            conv_features, _ = encoder_model(images)
            B, D, H, W = conv_features.shape
            # Advance the dataset position before any skip so later batches map to the right images.
            batch_start = seen
            seen += B
            if H <= 0 or W <= 0:
                continue

            feature_chunks.append(conv_features.permute(0, 2, 3, 1).reshape(B * H * W, D).cpu())
            # Same (image, h, w) order as the permute/reshape above.
            patch_origins.extend(
                (index_map[batch_start + b], h, w)
                for b in range(B) for h in range(H) for w in range(W)
            )
            pbar.set_postfix({"Patches": f"{len(patch_origins):,}"})

        if not feature_chunks:
            return None, []
        return torch.cat(feature_chunks, dim=0), patch_origins

    def _find_nearest_patches(self, all_patch_features, proto_batch_size=10):
        """
        For every prototype, find the row of ``all_patch_features`` [N, D] with the
        smallest squared L2 distance. Prototypes are processed ``proto_batch_size``
        at a time to bound the size of the [N, batch] distance matrix.

        Returns:
            (LongTensor [num_parts] of row indices, FloatTensor [num_parts] of squared
            distances), both on ``all_patch_features.device``.
        """
        print("Finding nearest patches for prototypes...")
        search_device = all_patch_features.device
        nearest = torch.empty(self.num_parts, dtype=torch.long, device=search_device)
        min_dists_sq = torch.empty(self.num_parts, dtype=torch.float32, device=search_device)
        for start in tqdm(range(0, self.num_parts, proto_batch_size), desc="Projecting Prototypes", leave=False):
            end = min(start + proto_batch_size, self.num_parts)
            prototypes = self.part_prototypes[start:end].to(search_device)
            dist_sq = self._squared_distances(all_patch_features, prototypes)  # [N_Patches, end-start]
            batch_min, batch_idx = torch.min(dist_sq, dim=0)
            min_dists_sq[start:end] = batch_min
            nearest[start:end] = batch_idx
        return nearest, min_dists_sq

    def _log_projection_check(self, all_patch_features, nearest_list, min_dists_sq, origins, proto_idx=0):
        """Print a consistency check (final distance and stored origin) for one projected prototype."""
        if not (0 <= proto_idx < self.num_parts):
            print(f"Projection Check (Proto {proto_idx}): Not projected (or index out of bounds).")
            return
        search_device = all_patch_features.device
        final_dist_check = torch.dist(
            self.part_prototypes[proto_idx].to(search_device).float(),
            all_patch_features[nearest_list[proto_idx]].float()
        ).item()
        min_dist_found = min_dists_sq[proto_idx].sqrt().item()
        intended_img, intended_h, intended_w = origins[proto_idx]
        buffer_img, buffer_h, buffer_w = self.projected_patch_locations[proto_idx].tolist()
        print(f"Projection Check (Proto {proto_idx}): Projected.")
        print(f"  - Final dist: {final_dist_check:.4f} (min found: {min_dist_found:.4f})")
        print(f"  - Intended Origin (Img, h, w): ({intended_img}, {intended_h}, {intended_w})")
        print(f"  - Buffer Read (Img, h, w): ({buffer_img}, {buffer_h}, {buffer_w})")
        if (buffer_img, buffer_h, buffer_w) != (intended_img, intended_h, intended_w):
            print("  - *** WARNING: Final check shows mismatch between intended and buffer read! ***")


class HEPN(nn.Module):
    """
    Hierarchical Explainable Prototypical Network.

    Level 1 (global): negative squared Euclidean distance between each query's
    global embedding and the per-class mean of the support global embeddings.
    Level 2 (parts): each query's part-activation vector is compared with each
    class's average support part profile through
    ``PartPrototypeLayer.calculate_part_similarity_score``.
    The two levels are combined additively (``global + lambda_part_sim * part``)
    when ``combination_method='add'``, or by a small MLP when
    ``combination_method='mlp'``. Any other value makes ``forward`` use the
    global logits only.

    Args:
        encoder: Module returning ``(conv_features, global_embedding)`` and
            exposing ``patch_feature_dim``.
        num_parts (int): Number of part prototypes.
        lambda_part_sim (float): Weight of the part-similarity logits in 'add' mode.
        combination_method (str): 'add', 'mlp', or anything else for global-only.
    """
    def __init__(self, encoder, num_parts=NUM_PART_PROTOTYPES,
                 lambda_part_sim=LAMBDA_PART_SIM, combination_method='add'):
        super().__init__()
        self.encoder = encoder
        self.part_prototype_layer = PartPrototypeLayer(
            num_parts=num_parts,
            patch_feature_dim=encoder.patch_feature_dim  # Use actual dim from encoder
        )
        self.lambda_part_sim = lambda_part_sim
        self.combination_method = combination_method

        if self.combination_method == 'mlp':
            # Simple MLP: combines global logit and part similarity score
            self.combiner_mlp = nn.Sequential(
                nn.Linear(2, 16),
                nn.ReLU(),
                nn.Linear(16, 1)
            )
            print("Using MLP combiner.")
        else:
            self.combiner_mlp = None
            if self.combination_method == 'add':
                print("Using Additive combiner.")
            else:
                # forward() falls back to global logits here; don't claim the additive combiner is used.
                print(f"Warning: Unknown combination_method '{self.combination_method}'. Using global logits only.")

    def _calculate_prototypes(self, support_embeddings, support_labels, n_way):
        """
        Calculates class prototypes (mean support embedding per class).

        Returns:
            Tensor [n_way, E] with the embeddings' dtype and device. A class with no
            support samples keeps an all-zero prototype.
        """
        prototypes = support_embeddings.new_zeros(n_way, support_embeddings.size(1))
        for c in range(n_way):
            class_embeddings = support_embeddings[support_labels == c]
            if class_embeddings.size(0) > 0:
                prototypes[c] = class_embeddings.mean(dim=0)
        return prototypes

    def forward(self, support_images, support_labels, query_images, n_way):
        """
        Runs one few-shot episode.

        Args:
            support_images (Tensor): support batch, concatenated with the queries along dim 0.
            support_labels (Tensor): per-support class ids; classes 0..n_way-1 are used.
            query_images (Tensor): query batch.
            n_way (int): Number of classes in the episode.

        Returns:
            tuple:
                final_logits (Tensor [N_query, n_way]);
                loss_info (dict): 'clst_loss', 'diversity_loss', 'l1_loss' (l1 is currently always 0);
                explanation_info (dict): 'query_part_activations' ([N_query, M] or None)
                    and 'class_part_profiles' ({class_idx: Tensor [M]}).
            If the encoder raises, the error is printed and zero logits / zero losses are
            returned instead (best-effort fallback; these carry no gradient).
        """
        n_support = support_images.size(0)
        n_query = query_images.size(0)
        all_images = torch.cat([support_images, query_images], dim=0)

        # --- Get Features from Encoder ---
        # The caller is responsible for setting the encoder's train/eval mode.
        try:
            all_conv_features, all_global_embeddings = self.encoder(all_images)
        except Exception as e:
            # Report the actual exception type (not every failure is a RuntimeError).
            print(f"{type(e).__name__} in encoder forward: {e}")
            # Return dummy outputs matching expected structure
            dummy_logits = torch.zeros(n_query, n_way, device=all_images.device)
            dummy_loss_info = {'clst_loss': torch.tensor(0.0, device=all_images.device),
                               'diversity_loss': torch.tensor(0.0, device=all_images.device),
                               'l1_loss': torch.tensor(0.0, device=all_images.device)}
            dummy_explain_info = {'query_part_activations': None, 'class_part_profiles': {}}
            return dummy_logits, dummy_loss_info, dummy_explain_info

        support_conv_features = all_conv_features[:n_support]
        query_conv_features = all_conv_features[n_support:]
        support_global_embeddings = all_global_embeddings[:n_support]
        query_global_embeddings = all_global_embeddings[n_support:]

        # --- Level 1: Global Reasoning ---
        global_prototypes = self._calculate_prototypes(support_global_embeddings, support_labels, n_way)
        dist_global = torch.cdist(query_global_embeddings, global_prototypes) ** 2
        logits_global = -dist_global  # Higher score = closer

        # --- Level 2: Part Reasoning ---
        part_layer = self.part_prototype_layer
        logits_part_sim = torch.zeros(n_query, n_way, device=query_images.device)
        query_part_activations_list = []
        class_part_profiles = {}  # {class_idx: profile_tensor [M]}

        # Class profiles are computed once per class.
        for c in range(n_way):
            class_support_indices = torch.where(support_labels == c)[0]
            if len(class_support_indices) > 0:
                class_features = support_conv_features[class_support_indices]
                class_part_profiles[c] = part_layer.calculate_class_profile(class_features)
            else:
                # A class without support images (should not happen with proper sampling).
                class_part_profiles[c] = torch.zeros(part_layer.num_parts, device=query_images.device)

        # Query activations, compared against every class profile.
        for i in range(n_query):
            act_q = part_layer.calculate_query_activation(query_conv_features[i:i + 1])  # [M]
            query_part_activations_list.append(act_q)
            for c in range(n_way):
                logits_part_sim[i, c] = part_layer.calculate_part_similarity_score(act_q, class_part_profiles[c])

        # --- Combine Level 1 and Level 2 ---
        if self.combination_method == 'add':
            # Negative distance plus weighted similarity; relative scales may need tuning via lambda.
            final_logits = logits_global + self.lambda_part_sim * logits_part_sim
        elif self.combination_method == 'mlp' and self.combiner_mlp is not None:
            combined_input = torch.stack([logits_global, logits_part_sim], dim=-1)  # [N_query, N_way, 2]
            flat_output = self.combiner_mlp(combined_input.view(-1, 2))              # [N_query * N_way, 1]
            final_logits = flat_output.view(n_query, n_way)                          # [N_query, N_way]
        else:  # Unknown combination method: global logits only
            final_logits = logits_global

        # --- Part Prototype Losses (regularization during training) ---
        # all_conv_features already holds support followed by query features; no need to re-concatenate.
        clst_loss = part_layer.calculate_cluster_loss(all_conv_features)
        diversity_loss = part_layer.calculate_diversity_loss()
        # L1 on query activations is currently disabled (see PartPrototypeLayer.calculate_l1_loss).
        l1_loss = torch.tensor(0.0, device=final_logits.device)

        loss_info = {'clst_loss': clst_loss, 'diversity_loss': diversity_loss, 'l1_loss': l1_loss}
        explanation_info = {
            'query_part_activations': torch.stack(query_part_activations_list, dim=0) if query_part_activations_list else None,  # [N_query, M]
            'class_part_profiles': class_part_profiles  # Dict {class_idx: tensor[M]}
        }

        return final_logits, loss_info, explanation_info

    # --- Helper for Manual Grad-CAM ---
    def get_part_activation_score(self, image_tensor, target_part_index):
        """
        Activation score of one part prototype for one image, with gradient tracking.

        The score is the maximum over all feature-map patches of the similarity
        (negative squared L2 distance) to prototype ``target_part_index``. It is
        used as the scalar to back-propagate in manual Grad-CAM.

        Args:
            image_tensor (Tensor): [1, C, H, W].
            target_part_index (int): Prototype index in ``[0, num_parts)``.

        Returns:
            0-d Tensor score.

        Raises:
            ValueError: on a non-[1, C, H, W] input or an out-of-range index.
        """
        if image_tensor.dim() != 4 or image_tensor.size(0) != 1:
            raise ValueError("Input tensor must be [1, C, H, W]")
        if not (0 <= target_part_index < self.part_prototype_layer.num_parts):
            raise ValueError(f"Invalid target_part_index: {target_part_index}")

        # The encoder's train/eval mode must be set by the caller.
        conv_features, _ = self.encoder(image_tensor)  # [1, D, H, W]

        B, D, H, W = conv_features.shape
        n_patches = H * W
        if n_patches <= 0: return torch.tensor(0.0, device=image_tensor.device, requires_grad=True)  # Empty feature map

        patches = conv_features.permute(0, 2, 3, 1).reshape(n_patches, D)  # [HW, D]
        target_prototype = self.part_prototype_layer.part_prototypes[target_part_index].unsqueeze(0)  # [1, D]

        # Similarity of all patches to the single target prototype
        sim_matrix = self.part_prototype_layer._calculate_similarity(patches, target_prototype)  # [HW, 1]

        max_similarity_score, _ = torch.max(sim_matrix, dim=0)  # [1]
        return max_similarity_score.squeeze()

In [None]:
# -------------------------------------
# --- Manual Grad-CAM Calculation ---
# -------------------------------------

# Module-level stores filled by the hooks installed in setup_hooks(), keyed by the
# dotted target-layer name. They are cleared in place, never rebound, so importers
# holding a reference keep seeing the live dict.
forward_features = {}
backward_gradients = {}

def setup_hooks(model, target_layer_name):
    """
    Resolve a dotted attribute path on ``model`` and attach capture hooks to it.

    A forward hook stores the layer's output in ``forward_features`` and a full
    backward hook stores ``grad_output[0]`` in ``backward_gradients``. Both are
    keyed by ``target_layer_name``. Both stores are cleared once the layer has
    been found.

    Args:
        model: Root module, e.g. an HEPN instance.
        target_layer_name (str): Dotted path such as ``'encoder.layer4'``.

    Returns:
        tuple: ``(target_layer, forward_handle, backward_handle)``, or
        ``(None, None, None)`` if an attribute along the path does not exist.
        The caller must call ``.remove()`` on both handles.
    """
    node = model
    try:
        for attr in target_layer_name.split('.'):
            node = getattr(node, attr)
    except AttributeError:
        print(f"ERROR: Could not find target layer '{target_layer_name}' in the model.")
        return None, None, None
    target_layer = node

    # Drop anything captured by a previous call.
    forward_features.clear()
    backward_gradients.clear()

    def _capture_activation(module, inputs, output):
        # Keep the live tensor; the consumer detaches it.
        forward_features[target_layer_name] = output

    def _capture_gradient(module, grad_inputs, grad_outputs):
        # Gradient w.r.t. the layer's output (first element of the tuple).
        backward_gradients[target_layer_name] = grad_outputs[0]

    fwd_handle = target_layer.register_forward_hook(_capture_activation)
    bwd_handle = target_layer.register_full_backward_hook(_capture_gradient)

    print(f"Hooks registered on: {target_layer_name} ({type(target_layer)})")
    return target_layer, fwd_handle, bwd_handle

def calculate_gradcam_manual(model, target_layer_name, input_tensor, target_part_index):
    """
    Calculates Grad-CAM manually for a specific part activation.

    The back-propagated score is ``model.get_part_activation_score(...)`` for
    ``target_part_index``. Activations and gradients at ``target_layer_name`` are
    captured through the hooks installed by ``setup_hooks``.

    Args:
        model (HEPN): The trained HEPN model.
        target_layer_name (str): Dotted name of the target conv layer (e.g., 'encoder.layer4').
        input_tensor (torch.Tensor): Input image tensor [1, C, H, W]; not modified.
        target_part_index (int): Index of the part prototype to visualize.

    Returns:
        np.ndarray: The Grad-CAM heatmap (H, W) scaled to [0, 1], or None if failed.

    Side effects:
        The model is switched to eval mode for the computation; its previous
        train/eval mode is restored on exit. Parameter gradients are cleared before
        and after the backward pass, and the hooks are always removed.
    """
    was_training = model.training
    model.eval()  # Deterministic forward (dropout off, BN running stats)

    target_layer, f_handle, b_handle = setup_hooks(model, target_layer_name)
    if target_layer is None:
        model.train(was_training)
        return None  # Hook setup failed

    try:
        model.zero_grad()

        # A detached, grad-requiring copy of the input guarantees the autograd graph
        # (and thus the backward hook) reaches the target layer even when all
        # encoder parameters are frozen. The caller's tensor is left untouched.
        cam_input = input_tensor.detach().clone().requires_grad_(True)

        # enable_grad() so this also works when called from inside a torch.no_grad() block.
        with torch.enable_grad():
            score = model.get_part_activation_score(cam_input, target_part_index)
            score.backward()

        # --- Retrieve activations and gradients from hooks ---
        activations = forward_features.get(target_layer_name)
        gradients = backward_gradients.get(target_layer_name)

        if activations is None or gradients is None:
            print("Error: Failed to capture activations or gradients from hooks.")
            return None

        # float32 on CPU: numpy/cv2 do not handle half-precision resizing.
        activations = activations.detach().float().cpu()  # [1, C, Hf, Wf]
        gradients = gradients.detach().float().cpu()      # [1, C, Hf, Wf]

        # --- Calculate Grad-CAM ---
        # Channel weights = spatially averaged gradients.
        pooled_gradients = torch.mean(gradients, dim=[2, 3], keepdim=True)  # [1, C, 1, 1]
        # Weighted channel sum of activations, then ReLU.
        heatmap = F.relu((activations * pooled_gradients).sum(dim=1, keepdim=True))  # [1, 1, Hf, Wf]

        # Index rather than squeeze() so a 1x1 feature map still gives a 2-D array for cv2.resize.
        heatmap = heatmap[0, 0].numpy()  # [Hf, Wf]

        # --- Post-processing ---
        input_h, input_w = input_tensor.shape[2:]
        heatmap = cv2.resize(heatmap, (int(input_w), int(input_h)))  # cv2 takes (width, height)

        # Normalize heatmap (0 to 1)
        heatmap_min, heatmap_max = np.min(heatmap), np.max(heatmap)
        if heatmap_max > heatmap_min:
            heatmap = (heatmap - heatmap_min) / (heatmap_max - heatmap_min)
        else:
            heatmap = np.zeros_like(heatmap)  # Flat map: avoid division by zero

        return heatmap

    except Exception as e:
        print(f"Error during manual Grad-CAM calculation: {e}")
        import traceback
        traceback.print_exc()
        return None
    finally:
        # --- CRITICAL: Remove hooks ---
        if f_handle: f_handle.remove()
        if b_handle: b_handle.remove()
        forward_features.clear()
        backward_gradients.clear()
        # Don't leave Grad-CAM gradients on the parameters or the model stuck in eval mode.
        model.zero_grad()
        model.train(was_training)
        # print("Hooks removed.")

In [None]:
# -------------------------------------
# --- Training & Evaluation (Adapted for HEPN) ---
# -------------------------------------

# --- Label Smoothing Loss (Unchanged) ---
class LabelSmoothingLoss(nn.Module):
    """
    Cross-entropy against a label-smoothed target distribution.

    The true class receives probability ``1 - smoothing``. The remaining
    ``smoothing`` mass is split evenly over the other ``classes - 1`` classes;
    with a single class there is nowhere to move it, so none is assigned. Note
    that this differs from ``nn.CrossEntropyLoss(label_smoothing=...)``, which
    also gives the true class a share of the smoothing mass.

    Args:
        classes (int): Number of classes along ``dim``.
        smoothing (float): Probability mass moved off the true class.
        dim (int): Class dimension of the prediction tensor.
    """
    def __init__(self, classes, smoothing=0.0, dim=-1):
        super().__init__()
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing
        self.cls = classes
        self.dim = dim

    def forward(self, pred, target):
        """
        Args:
            pred (Tensor): Unnormalised logits with classes along ``self.dim``.
            target (LongTensor): Class indices, shaped like ``pred`` without ``self.dim``.

        Returns:
            0-d Tensor: smoothed cross-entropy averaged over all non-class positions.
        """
        log_probs = pred.log_softmax(dim=self.dim)
        with torch.no_grad():
            true_dist = torch.zeros_like(log_probs)
            if self.cls > 1:  # Avoid ZeroDivisionError for a single class
                true_dist.fill_(self.smoothing / (self.cls - 1))
            # Scatter along the configured class dim (previously hard-coded to dim 1).
            true_dist.scatter_(self.dim, target.unsqueeze(self.dim), self.confidence)
        return torch.mean(torch.sum(-true_dist * log_probs, dim=self.dim))

# --- Setup Optimizer for HEPN ---
def setup_optimizer_hepn(model, lr_backbone, lr_head_global, lr_part_prototypes, lr_combiner, wd):
    """Build an AdamW optimizer with one parameter group per HEPN component.

    Groups, in order: encoder backbone, global embedding head, part
    prototypes (weight decay forced to 0.0) and, when ``model.combiner_mlp``
    exists and is not None, the combiner MLP. Groups whose parameter
    collection is empty are dropped. A per-group summary is printed.

    Returns:
        torch.optim.AdamW: the configured optimizer (weight decay is set per group).
    """
    backbone, global_head = model.encoder.get_trainable_parameters()
    prototypes = [model.part_prototype_layer.part_prototypes]

    combiner = []
    mlp = getattr(model, 'combiner_mlp', None)
    if mlp is not None:
        combiner = list(mlp.parameters())
        print(f"Adding {len(combiner)} combiner MLP parameters to optimizer.")

    # (params, lr, weight_decay, name) for every component.
    specs = [
        (backbone, lr_backbone, wd, 'backbone'),
        (global_head, lr_head_global, wd, 'global_head'),
        (prototypes, lr_part_prototypes, 0.0, 'part_protos'),
    ]
    if combiner:
        specs.append((combiner, lr_combiner, wd, 'combiner'))

    # Skip components that contribute nothing (e.g. a frozen head).
    groups = [
        {'params': params, 'lr': lr, 'weight_decay': decay, 'name': name}
        for params, lr, decay, name in specs
        if params
    ]

    print("Optimizer Groups:")
    grand_total = 0
    for group in groups:
        count = sum(p.numel() for p in group['params'])
        grand_total += count
        print(f"  - {group['name']}: {len(group['params'])} tensors, {count:,} params, "
              f"LR {group['lr']:.1e}, WD {group.get('weight_decay', 0.0):.1e}")
    print(f"Total trainable parameters in optimizer: {grand_total:,}")

    return optim.AdamW(groups)


# --- Modified Training Step ---
def train_step_hepn(model, optimizer, cls_loss_fn, support_images, support_labels, query_images, query_labels, n_way, gradient_clip_norm, loss_weights):
    """Run one meta-training episode: forward, weighted loss, backward, update.

    Args:
        model: HEPN model, called as ``model(support, support_labels, query, n_way)``
            and unpacked as ``(logits, loss_info, _)``.
        optimizer: Optimizer over the model parameters.
        cls_loss_fn: Classification loss applied to ``(logits, query_labels)``.
        support_images, support_labels, query_images, query_labels: Episode
            tensors; moved to ``DEVICE`` here.
        n_way: Number of classes in the episode.
        gradient_clip_norm: ``max_norm`` passed to ``clip_grad_norm_``.
        loss_weights: Dict with optional ``'clst'``, ``'diversity'`` and ``'l1'``
            weights for the ``'clst_loss'``, ``'diversity_loss'`` and
            ``'l1_loss'`` entries of ``loss_info``.

    Returns:
        dict with ``'loss'``, ``'acc'``, ``'cls_loss'`` plus one float per
        ``loss_info`` entry. ``'loss'`` is NaN when the episode was skipped
        because the logits or the total loss were non-finite.
    """
    model.train()  # Ensure model is in training mode
    optimizer.zero_grad()

    support_images = support_images.to(DEVICE, non_blocking=True)
    support_labels = support_labels.to(DEVICE, non_blocking=True)
    query_images = query_images.to(DEVICE, non_blocking=True)
    query_labels = query_labels.to(DEVICE, non_blocking=True)

    # --- Forward Pass ---
    final_logits, loss_info, _ = model(support_images, support_labels, query_images, n_way)
    loss_info = loss_info or {}  # tolerate a model that returns no auxiliary losses

    if final_logits is None or not torch.isfinite(final_logits).all():
        print("Warning: NaN/Inf final logits detected during training. Skipping batch.")
        nan_loss_info = {k: float('nan') for k in loss_info}
        return {'loss': float('nan'), 'acc': 0.0, 'cls_loss': float('nan'), **nan_loss_info}

    # --- Calculate Losses ---
    cls_loss = cls_loss_fn(final_logits, query_labels)

    # Weighted regularizers; a missing entry contributes zero.
    zero = torch.tensor(0.0, device=DEVICE)
    clst_term = loss_weights.get('clst', 0.0) * loss_info.get('clst_loss', zero)
    div_term = loss_weights.get('diversity', 0.0) * loss_info.get('diversity_loss', zero)
    l1_term = loss_weights.get('l1', 0.0) * loss_info.get('l1_loss', zero)

    total_loss = cls_loss + clst_term + div_term + l1_term

    # --- Backward and Optimize ---
    # Reject Inf as well as NaN: backpropagating an infinite loss produces
    # non-finite gradients that the optimizer would write into the weights.
    if not torch.isfinite(total_loss):
        print("Warning: NaN/Inf total loss calculated. Skipping backward pass.")
        debug_loss_info = {k: v.item() if torch.is_tensor(v) and v.numel() == 1 else float('nan') for k, v in loss_info.items()}
        return {'loss': float('nan'), 'acc': 0.0, 'cls_loss': cls_loss.item() if torch.is_tensor(cls_loss) else float('nan'), **debug_loss_info}

    total_loss.backward()

    # Gradient clipping over all parameters. Clipping cannot repair non-finite
    # gradients (it scales them by NaN), so skip the update in that case.
    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=gradient_clip_norm)
    if torch.isfinite(grad_norm):
        optimizer.step()
    else:
        print("Warning: Non-finite gradient norm detected. Skipping optimizer step.")
        optimizer.zero_grad()

    # --- Calculate Accuracy ---
    with torch.no_grad():
        predictions = torch.argmax(final_logits, dim=1)
        accuracy = torch.mean((predictions == query_labels).float())

    # Scalar loss info; non-scalar entries are reported as 0.0.
    loss_dict_items = {k: v.item() if torch.is_tensor(v) and v.numel() == 1 else 0.0 for k, v in loss_info.items()}
    return {
        'loss': total_loss.item(),
        'acc': accuracy.item(),
        'cls_loss': cls_loss.item(),
        **loss_dict_items,
    }


# --- Modified Evaluation Step ---
def evaluate_step_hepn(model, support_images, support_labels, query_images, query_labels, n_way):
    """Score one evaluation episode and return its query accuracy as a float.

    The model is put in eval mode and run without gradient tracking. If the
    model yields no logits, or logits containing NaN/Inf, a warning is printed
    and the episode is scored 0.0.
    """
    model.eval()

    support_images, support_labels, query_images, query_labels = (
        t.to(DEVICE) for t in (support_images, support_labels, query_images, query_labels)
    )

    with torch.no_grad():
        # Only the logits matter here; loss and explanation outputs are discarded.
        final_logits, _, _ = model(support_images, support_labels, query_images, n_way)

        broken = (final_logits is None
                  or torch.isnan(final_logits).any()
                  or torch.isinf(final_logits).any())
        if broken:
            print("Warning: NaN/Inf detected in final_logits during evaluation.")
            return 0.0

    hits = final_logits.argmax(dim=1) == query_labels
    return hits.float().mean().item()


# --- Modified Training Loop ---
def _hepn_metric_abbrev(key):
    """Return a short, unambiguous progress-bar label for a training-history key."""
    # Explicit labels: the generic initials scheme maps both 'cls_loss' and
    # 'clst_loss' to 'CL', which made the progress bar ambiguous.
    known = {
        'loss': 'L',
        'acc': 'Acc',
        'cls_loss': 'Cls',
        'clst_loss': 'Clst',
        'diversity_loss': 'Div',
        'l1_loss': 'L1',
    }
    return known.get(key, "".join(w[:1].upper() for w in key.split('_')))


def _plot_hepn_training_history(train_history, log_interval):
    """Plot every non-empty training-history series (raw + rolling mean) side by side."""
    series = {key: values for key, values in train_history.items() if values}
    if not series:
        return
    num_plots = len(series)  # only non-empty series get a panel (no blank subplots)
    plt.figure(figsize=(min(18, 5 * num_plots), 5))
    for plot_idx, (key, values) in enumerate(series.items(), start=1):
        plt.subplot(1, num_plots, plot_idx)
        plt.plot(values, alpha=0.6, label='Raw')
        if len(values) >= log_interval:
            rolling_avg = pd.Series(values).rolling(log_interval, min_periods=log_interval // 2).mean()
            plt.plot(rolling_avg, label=f'Avg {log_interval}eps', linewidth=2)
        plt.title(f'Training {key.replace("_", " ").title()}')
        # x is the index of the logged value; episodes with a NaN loss are not logged.
        plt.xlabel('Episode')
        plt.ylabel(key)
        if 'acc' in key:
            plt.ylim(0, max(1.0, np.max(values) * 1.1))
        plt.grid(True, linestyle='--', alpha=0.6)
        plt.legend()
    plt.tight_layout()
    plt.show()


def main_training_loop_hepn(model, train_sampler, test_sampler, meta_train_full_dataset,
                           n_train_episodes, n_test_episodes,
                           lr_backbone, lr_head_global, lr_part_prototypes, lr_combiner, wd,
                           label_smoothing, grad_clip_norm,
                           lambda_clst, lambda_diversity, lambda_l1,
                           part_proj_interval):
    """Meta-train HEPN with periodic evaluation, checkpointing and prototype projection.

    Every 500 episodes (and on the last one) the model is evaluated on
    ``test_sampler``; intermediate checks use ``n_test_episodes // 2``
    episodes, the last one all of them. Improvements over the best accuracy
    are saved to ``BASE_PATH/best_hepn_model_manual_cam.pth``. Every
    ``part_proj_interval`` episodes the part prototypes are projected onto
    ``meta_train_full_dataset``; a non-positive interval disables projection.
    After training, a final full evaluation runs and the training curves are
    plotted.

    Returns:
        The model in its state after the last training episode (the best
        checkpoint is only written to disk, not returned).
    """
    optimizer = setup_optimizer_hepn(model, lr_backbone, lr_head_global, lr_part_prototypes, lr_combiner, wd)
    cls_loss_fn = LabelSmoothingLoss(classes=train_sampler.n_way, smoothing=label_smoothing).to(DEVICE)

    loss_weights = {'clst': lambda_clst, 'diversity': lambda_diversity, 'l1': lambda_l1}
    print(f"Using Loss Weights: {loss_weights}")

    # Only track regularizers that actually contribute to the loss.
    train_history = {'loss': [], 'acc': [], 'cls_loss': []}
    for key, weight in loss_weights.items():
        if weight > 0:
            train_history[f'{key}_loss'] = []

    test_accuracies = []
    best_test_acc = 0.0
    test_eval_interval = 500
    log_interval = 100
    best_model_path = os.path.join(BASE_PATH, "best_hepn_model_manual_cam.pth")
    # A non-positive interval disables projection instead of raising ZeroDivisionError.
    projection_enabled = part_proj_interval is not None and part_proj_interval > 0

    print("\n--- Starting HEPN Meta-Training (Manual CAM Version) ---")
    print(f"Config: N={N_WAY}, K={K_SHOT}, Q={N_QUERY}, Train Eps={n_train_episodes}, Test Eps={n_test_episodes}")
    print(f"LRs: BB={lr_backbone:.1e}, GlobalH={lr_head_global:.1e}, Parts={lr_part_prototypes:.1e}, Comb={lr_combiner:.1e}")
    print(f"Loss Lambdas -> Clst:{loss_weights.get('clst',0):.2f}, Div:{loss_weights.get('diversity',0):.2f}, L1:{loss_weights.get('l1',0):.2f}, PartSimCombine:{model.lambda_part_sim:.2f}")
    print(f"Projection Interval: {part_proj_interval} episodes")

    pbar = tqdm(range(n_train_episodes))
    for episode_idx in pbar:
        episode_num = episode_idx + 1
        is_last_episode = episode_idx == n_train_episodes - 1
        try:
            support_images, support_labels, query_images, query_labels = train_sampler.sample()
        except ValueError as e:
            print(f"Skipping train episode {episode_num} due to sampler error: {e}")
            continue  # Skip episode if sampling fails

        # --- Training Step ---
        loss_dict = train_step_hepn(model, optimizer, cls_loss_fn, support_images, support_labels,
                                    query_images, query_labels, train_sampler.n_way, grad_clip_norm, loss_weights)

        # --- Log Training Stats (skipped episodes report a NaN loss) ---
        if not math.isnan(loss_dict['loss']):
            for key, history in train_history.items():
                value = loss_dict.get(key)
                if value is not None and not math.isnan(value):
                    history.append(value)

            if episode_num % log_interval == 0 and train_history['loss']:
                desc = f"Ep {episode_num}"
                for key, values in train_history.items():
                    if values:
                        desc += f" | {_hepn_metric_abbrev(key)}:{np.mean(values[-log_interval:]):.3f}"
                desc += f" | BestT:{best_test_acc:.4f}"
                pbar.set_description(desc)

        # --- Evaluate and Save Best Model ---
        if episode_num % test_eval_interval == 0 or is_last_episode:
            # Intermediate checks use half the test episodes to save time.
            eval_eps = n_test_episodes if is_last_episode else n_test_episodes // 2
            current_test_acc = evaluate_on_test_set_hepn(model, test_sampler, eval_eps)
            if not math.isnan(current_test_acc):
                test_accuracies.append(current_test_acc)
                print(f" | Test Acc @ Ep {episode_num}: {current_test_acc:.4f}")
                if current_test_acc > best_test_acc:
                    best_test_acc = current_test_acc
                    print(f"*** New best test accuracy: {best_test_acc:.4f}. Saving model to {best_model_path} ***")
                    try:
                        torch.save(model.state_dict(), best_model_path)
                    except Exception as e:
                        print(f"Error saving model: {e}")
            else:
                print(f" | Test Acc @ Ep {episode_num}: NaN")
            model.train()  # evaluation switched the model to eval mode

        # --- Part Prototype Projection ---
        if projection_enabled and episode_idx > 0 and episode_num % part_proj_interval == 0:
            if meta_train_full_dataset is not None:
                print(f"\n--- Projecting prototypes at episode {episode_num} ---")
                model.part_prototype_layer.project_part_prototypes(
                    dataset_for_proj=meta_train_full_dataset,
                    encoder_model=model.encoder,
                    device=DEVICE,
                    batch_size=PROJECTION_BATCH_SIZE
                )
                model.train()  # Ensure model back in train mode
            else:
                print(f"Warning: Cannot project prototypes at episode {episode_num}, meta_train_full_dataset is None.")

    print("\n--- Finished HEPN Meta-Training ---")
    print(f"Best meta-test accuracy achieved: {best_test_acc:.4f}")

    # --- Final Evaluation ---
    print("\nRunning final evaluation on full test set...")
    final_test_acc = evaluate_on_test_set_hepn(model, test_sampler, n_test_episodes)
    print(f"Final Meta-Test Accuracy: {final_test_acc:.4f}")

    # --- Plotting ---
    _plot_hepn_training_history(train_history, log_interval)

    return model  # Return the trained model


# --- Evaluation Loop Helper ---
def evaluate_on_test_set_hepn(model, test_sampler, n_episodes):
    """Average query accuracy of ``model`` over ``n_episodes`` sampled episodes.

    Episodes whose sampling raises ``ValueError`` are reported and skipped, as
    are episodes scored NaN. Returns ``float('nan')`` when no episode produced
    a score, otherwise the ``np.mean`` of the collected accuracies.
    """
    scores = []
    progress = tqdm(range(n_episodes), desc="Meta-Testing", leave=False)
    for _ in progress:
        try:
            s_imgs, s_lbls, q_imgs, q_lbls = test_sampler.sample()
        except ValueError as e:
            print(f"Skipping test episode due to sampler error: {e}")
            continue

        acc = evaluate_step_hepn(model, s_imgs, s_lbls, q_imgs, q_lbls, test_sampler.n_way)
        if math.isnan(acc):
            continue
        scores.append(acc)
        progress.set_postfix({"Avg Acc": f"{np.mean(scores):.4f}"})

    return np.mean(scores) if scores else float('nan')

In [None]:
# -------------------------------------
# --- Visualization (with Manual Grad-CAM) ---
# -------------------------------------

# Utility function to overlay CAM - similar to pytorch-gradcam's show_cam_on_image
def show_cam_on_image(img: np.ndarray,
                      mask: np.ndarray,
                      use_rgb: bool = False,
                      colormap: int = cv2.COLORMAP_JET,
                      image_weight: float = 0.5) -> np.ndarray:
    """Blend a colour-mapped CAM mask onto an image.

    Args:
        img: Float image (H, W, C) with values in [0, 1].
        mask: CAM (H, W). NaNs become 0 and values are clipped to [0, 1]
            before colour-mapping; out-of-range values would otherwise wrap
            around when cast to uint8.
        use_rgb: Convert the OpenCV (BGR) heatmap to RGB, e.g. for matplotlib.
        colormap: OpenCV colormap id.
        image_weight: Blend weight of ``img``; the heatmap gets ``1 - image_weight``.

    Returns:
        uint8 array: the blend rescaled so its maximum maps to 255 (all zeros
        if the blend is entirely zero).

    Raises:
        ValueError: if ``img`` has values above 1 or ``image_weight`` is not
            in [0, 1]. ``ValueError`` subclasses ``Exception``, so existing
            ``except Exception`` handlers still catch it.
    """
    if np.max(img) > 1:
        raise ValueError("The input image should be np.float32 in the range [0, 1]")
    if image_weight < 0 or image_weight > 1:
        raise ValueError(f"image_weight should be in the range [0, 1]. Got: {image_weight}")

    safe_mask = np.clip(np.nan_to_num(mask), 0.0, 1.0)
    heatmap = cv2.applyColorMap(np.uint8(255 * safe_mask), colormap)
    if use_rgb:
        heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
    heatmap = np.float32(heatmap) / 255

    cam = (1 - image_weight) * heatmap + image_weight * img
    peak = np.max(cam)
    if peak > 0:  # avoid 0/0 -> NaN for an all-black blend
        cam = cam / peak
    return np.uint8(255 * cam)


# Inverse transform for visualization
# Undo the standard ImageNet normalization: x = x_norm * std + mean is
# expressed as Normalize(mean=-mean/std, std=1/std).
inv_normalize_imagenet = transforms.Normalize(
    mean=[-m / s for m, s in zip((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))],
    std=[1 / s for s in (0.229, 0.224, 0.225)],
)

def preprocess_image_for_viz(tensor_img):
    """Turn a normalized CHW image tensor into an HWC float NumPy array in [0, 1].

    Undoes the ImageNet normalization with ``inv_normalize_imagenet``, moves
    channels last and clamps to [0, 1], ready for ``imshow`` or a CAM
    overlay. Returns None when given None.
    """
    if tensor_img is None:
        return None
    restored = inv_normalize_imagenet(tensor_img.detach().cpu())
    return restored.permute(1, 2, 0).clamp(0, 1).numpy()


def visualize_explanation_hepn_manual_gradcam(model, test_sampler, target_layer_name, num_explanations=1, top_k_parts=3):
    """
    Visualizes HEPN explanations using MANUALLY calculated Grad-CAM.

    For each of ``num_explanations`` freshly sampled test episodes, one random
    query image is classified and a figure is shown with: the query image,
    a bar chart of its strongest part activations, a bar chart of the
    predicted class's part profile, and Grad-CAM overlays (computed by
    ``calculate_gradcam_manual``) for up to ``top_k_parts`` most activated parts.

    Args:
        model: HEPN model returning ``(logits, loss_info, explanation_info)``;
            ``explanation_info`` may contain ``'query_part_activations'`` and
            ``'class_part_profiles'`` (class index -> tensor).
        test_sampler: Episode sampler exposing ``sample()`` and ``n_way``.
        target_layer_name: Name of the layer Grad-CAM hooks into.
        num_explanations: Number of episodes / figures to produce.
        top_k_parts: Maximum number of parts to draw Grad-CAM maps for.
    """
    print(f"\n--- Visualizing HEPN Explanations (Manual Grad-CAM) for {num_explanations} Test Episodes ---")
    model.eval()  # Ensure evaluation mode

    for vis_idx in range(num_explanations):
        print(f"\n--- Explanation Example {vis_idx + 1}/{num_explanations} ---")
        try:
            support_imgs, support_lbls, query_imgs, query_lbls = test_sampler.sample()
        except ValueError as e:
            print(f"Skipping visualization {vis_idx+1} due to sampler error: {e}")
            continue
        if query_imgs is None or query_imgs.size(0) == 0:
            continue

        query_idx_to_explain = random.randrange(query_imgs.size(0))
        support_images_dev = support_imgs.to(DEVICE)
        support_labels_dev = support_lbls.to(DEVICE)
        query_image_single_tensor = query_imgs[query_idx_to_explain]
        query_image_single_input = query_image_single_tensor.unsqueeze(0).to(DEVICE)  # [1, C, H, W]
        true_label = query_lbls[query_idx_to_explain].item()
        n_way = test_sampler.n_way

        # --- Get Model Outputs & Explanation Info (no gradients needed) ---
        predicted_label = -1
        query_part_activation = None
        class_part_profiles = {}
        with torch.no_grad():
            try:
                final_logits, _, explanation_info = model(
                    support_images_dev, support_labels_dev, query_image_single_input, n_way
                )
                if final_logits is not None and final_logits.numel() > 0:
                    predicted_label = torch.argmax(final_logits, dim=1).item()
                    query_part_activation = explanation_info.get('query_part_activations')
                    if query_part_activation is not None:
                        # Drop the single-query batch dim; presumably leaves [M] — TODO confirm.
                        query_part_activation = query_part_activation.squeeze(0).cpu()
                    class_part_profiles = explanation_info.get('class_part_profiles', {}) or {}
                else:
                    print("Warning: Model forward pass returned invalid logits.")
            except Exception as e:
                print(f"Error during model forward pass for visualization: {e}")
                import traceback
                traceback.print_exc()
                continue

        print(f"Explaining Query Img Idx: {query_idx_to_explain} | Pred: {predicted_label} | True: {true_label}")

        pred_class_profile = class_part_profiles.get(predicted_label)

        # --- Figure layout: query image + one slot per available bar chart + one per CAM ---
        num_bar_charts = 0
        if query_part_activation is not None:
            num_bar_charts += 1
        if pred_class_profile is not None:
            num_bar_charts += 1
        can_do_cam = query_part_activation is not None and query_part_activation.numel() > 0
        actual_top_k_parts = min(top_k_parts, query_part_activation.numel()) if can_do_cam else 0
        total_plots = 1 + num_bar_charts + actual_top_k_parts

        if total_plots <= 1:
            n_cols, n_rows = 1, 1
        else:
            n_cols = min(total_plots, 4)
            n_rows = math.ceil(total_plots / n_cols)

        plt.figure(figsize=(max(8, 4 * n_cols), 3.5 * n_rows))
        plot_idx = 1
        suptitle = f"HEPN Expl {vis_idx+1}: Q Img {query_idx_to_explain} (Pred: {predicted_label}, True: {true_label})"
        if can_do_cam:
            suptitle += f" | Top {actual_top_k_parts} CAMs (Manual)"
        plt.suptitle(suptitle, fontsize=12, y=0.99)

        # == Query image ==
        ax = plt.subplot(n_rows, n_cols, plot_idx)
        plot_idx += 1
        query_img_np_viz = preprocess_image_for_viz(query_image_single_tensor)
        if query_img_np_viz is not None:
            ax.imshow(query_img_np_viz)
        ax.set_title(f"Query Img {query_idx_to_explain}\nTrue: {true_label} / Pred: {predicted_label}")
        ax.axis('off')

        # == Query part activations ==
        # Placeholders are drawn only for slots counted in num_bar_charts;
        # previously a missing activation tensor still consumed a slot, pushing
        # later subplots past the grid (matplotlib ValueError).
        top_activated_part_indices = []
        if can_do_cam:
            ax = plt.subplot(n_rows, n_cols, plot_idx)
            plot_idx += 1
            sorted_vals, sorted_indices = torch.sort(query_part_activation, descending=True)
            k_bar = min(10, query_part_activation.numel())
            bar_vals, bar_indices = sorted_vals[:k_bar], sorted_indices[:k_bar]
            top_activated_part_indices = sorted_indices[:actual_top_k_parts].tolist()  # parts to CAM
            ax.barh(range(k_bar), bar_vals.flip(dims=[0]), tick_label=[f"P {i}" for i in bar_indices.flip(dims=[0]).tolist()])
            ax.set_title(f"Query Top {k_bar} Part Acts")
            ax.tick_params(axis='y', labelsize=8)
        elif query_part_activation is not None:
            ax = plt.subplot(n_rows, n_cols, plot_idx)
            plot_idx += 1
            ax.axis('off')
            ax.text(0.5, 0.5, "No activation data", ha='center')

        # == Predicted-class part profile ==
        if pred_class_profile is not None and pred_class_profile.numel() > 0:
            ax = plt.subplot(n_rows, n_cols, plot_idx)
            plot_idx += 1
            pred_class_profile_cpu = pred_class_profile.cpu()
            sorted_prof_vals, sorted_prof_indices = torch.sort(pred_class_profile_cpu, descending=True)
            k_prof_bar = min(10, pred_class_profile_cpu.numel())
            bar_prof_vals, bar_prof_indices = sorted_prof_vals[:k_prof_bar], sorted_prof_indices[:k_prof_bar]
            ax.barh(range(k_prof_bar), bar_prof_vals.flip(dims=[0]), tick_label=[f"P {i}" for i in bar_prof_indices.flip(dims=[0]).tolist()])
            ax.set_title(f"Pred Cls {predicted_label} Top {k_prof_bar} Profile")
            ax.tick_params(axis='y', labelsize=8)
        elif pred_class_profile is not None:
            ax = plt.subplot(n_rows, n_cols, plot_idx)
            plot_idx += 1
            ax.axis('off')
            ax.text(0.5, 0.5, "No profile data", ha='center')

        # == Manual Grad-CAM for the most activated parts ==
        if top_activated_part_indices:
            print(f"  Generating Manual Grad-CAM for top {len(top_activated_part_indices)} activated parts: {top_activated_part_indices}")
            input_image_np = query_img_np_viz  # HWC, 0-1 numpy image
            if input_image_np is None:
                print("  Skipping CAMs as base image processing failed.")
            else:
                for rank, part_idx in enumerate(top_activated_part_indices):
                    if plot_idx > n_rows * n_cols:
                        print(f"  Warning: Not enough plot slots...")
                        break
                    ax = plt.subplot(n_rows, n_cols, plot_idx)
                    plot_idx += 1

                    grayscale_cam = calculate_gradcam_manual(model,
                                                             target_layer_name,
                                                             query_image_single_input,  # [1, C, H, W]
                                                             part_idx)

                    if grayscale_cam is not None:
                        visualization = show_cam_on_image(input_image_np, grayscale_cam, use_rgb=True)
                        ax.imshow(visualization)
                        ax.set_title(f"Part {part_idx} CAM (Rank {rank+1})")
                    else:
                        print(f"  Failed to generate GradCAM for Part {part_idx}.")
                        ax.set_title(f"GradCAM Fail P{part_idx}")
                    ax.axis('off')

        # --- Layout and Show ---
        plt.tight_layout(rect=[0, 0.03, 1, 0.95])  # leave room for the suptitle
        plt.show()

    # End of visualization loop

In [None]:
# -------------------------------------
# --- Main Execution ---
# -------------------------------------
if __name__ == "__main__":
    # Seed every RNG in use for reproducibility (optional).
    seed = 42
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        # Might impact performance, but increases reproducibility
        # torch.backends.cudnn.deterministic = True
        # torch.backends.cudnn.benchmark = False

    print("1. Preparing CUB Data Splits...")
    meta_train_data, meta_test_data, meta_train_full_dataset_for_proj = prepare_cub_data_splits(
        data_dir=DATA_DIR,
        split_save_path=SPLIT_FILE,
        force_resplit=FORCE_RESPLIT
    )

    if meta_train_data and meta_test_data and meta_train_full_dataset_for_proj:
        print(f"\nUsing N_WAY={N_WAY}, K_SHOT={K_SHOT}, N_QUERY={N_QUERY} after potential adjustment.")

        print("\n2. Creating Episode Samplers...")
        try:
            train_sampler = EpisodeSampler(meta_train_data, N_WAY, K_SHOT, N_QUERY)
            test_sampler = EpisodeSampler(meta_test_data, N_WAY, K_SHOT, N_QUERY)
            print(f"Train Sampler: {train_sampler.num_classes} classes.")
            print(f"Test Sampler: {test_sampler.num_classes} classes.")
        except ValueError as e:
            print(f"\n*** ERROR initializing samplers: {e} ***")
            print("Check N_WAY/K_SHOT/N_QUERY vs available classes/samples.")
            # SystemExit(1) reports failure and, unlike the site-provided
            # exit(), does not shut down a Jupyter/Colab kernel.
            raise SystemExit(1)

        print("\n3. Initializing HEPN Model...")
        encoder = EncoderHEPN(
            embedding_dim_global=EMBEDDING_DIM_GLOBAL,
            patch_feature_dim=EMBEDDING_DIM_PATCH,  # Informational, uses ResNet's dim
            pretrained=PRETRAINED, freeze_until=FREEZE_UNTIL_LAYER, dropout_rate=DROPOUT_RATE
        ).to(DEVICE)

        actual_patch_dim = encoder.patch_feature_dim
        print(f"Instantiated Encoder with actual patch feature dim: {actual_patch_dim}")

        model = HEPN(encoder=encoder,
                     num_parts=NUM_PART_PROTOTYPES,
                     lambda_part_sim=LAMBDA_PART_SIM,
                     combination_method='add'  # or 'mlp'
                     ).to(DEVICE)
        print(f"HEPN Model Initialized on {DEVICE}.")
        try:
            trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
            total_params = sum(p.numel() for p in model.parameters())
            print(f"HEPN Trainable Params: {trainable_params:,} / Total Params: {total_params:,}")
        except Exception as e:
            print(f"Could not calculate parameter count: {e}")

        print("\n4. Starting/Loading HEPN Training...")
        LOAD_SAVED_MODEL = False  # Set True to attempt loading
        # Must match the path main_training_loop_hepn saves the best model to.
        model_path = os.path.join(BASE_PATH, "best_hepn_model_manual_cam.pth")
        trained_model = model

        if LOAD_SAVED_MODEL and os.path.exists(model_path):
            print(f"Attempting to load saved model state from: {model_path}")
            # load_state_dict copies every matching tensor before it raises on
            # mismatched keys, so snapshot the fresh weights to restore on failure.
            initial_state = copy.deepcopy(model.state_dict())
            try:
                # strict=True: fail loudly if the architecture differs.
                model.load_state_dict(torch.load(model_path, map_location=DEVICE), strict=True)
                print("Successfully loaded saved model state.")
            except Exception as e:
                model.load_state_dict(initial_state)  # really start from scratch
                print(f"----------------------------------------------------")
                print(f"ERROR loading saved model: {e}")
                print(f"Model state dict NOT loaded. Training from scratch.")
                print(f"----------------------------------------------------")
                LOAD_SAVED_MODEL = False  # Force training if load fails
        elif LOAD_SAVED_MODEL:
            print(f"Saved model path specified ({model_path}) but file not found. Training from scratch.")
            LOAD_SAVED_MODEL = False
        else:
            print("LOAD_SAVED_MODEL is False. Training from scratch.")

        # --- Training ---
        if not LOAD_SAVED_MODEL:
            print("Starting new training run...")
            trained_model = main_training_loop_hepn(
                model=model,
                train_sampler=train_sampler,
                test_sampler=test_sampler,
                meta_train_full_dataset=meta_train_full_dataset_for_proj,
                n_train_episodes=N_TRAIN_EPISODES,
                n_test_episodes=N_TEST_EPISODES,
                lr_backbone=LR_BACKBONE,
                lr_head_global=LR_HEAD_GLOBAL,
                lr_part_prototypes=LR_PART_PROTOTYPES,
                lr_combiner=LR_COMBINER,
                wd=WEIGHT_DECAY,
                label_smoothing=LABEL_SMOOTHING,
                grad_clip_norm=GRADIENT_CLIP_NORM,
                lambda_clst=LAMBDA_CLST,
                lambda_diversity=LAMBDA_DIVERSITY,
                lambda_l1=LAMBDA_L1,
                part_proj_interval=PART_PROJECTION_INTERVAL
            )
            print("Training finished.")
        else:
            print("Skipping training, using loaded model.")

        print("\n5. Visualizing HEPN Explanation with Manual Grad-CAM...")
        if trained_model is not None:
            visualize_explanation_hepn_manual_gradcam(
                model=trained_model,
                test_sampler=test_sampler,
                target_layer_name=TARGET_LAYER_NAME,  # Pass the target layer name string
                num_explanations=5,
                top_k_parts=4
            )
        else:
            print("ERROR: No valid model available for visualization.")

    else:
        print("\n--- ERROR: Failed to prepare data. Check paths, data integrity, and split parameters. Exiting. ---")

    print("\n--- Script Execution Finished ---")

In [None]:
# Re-run the explanation figures for the model produced by the main cell.
visualize_explanation_hepn_manual_gradcam(
    model=trained_model,
    test_sampler=test_sampler,
    target_layer_name=TARGET_LAYER_NAME,  # layer hooked for Grad-CAM
    num_explanations=5,
    top_k_parts=4,
)

In [None]:
import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import math
import random
import cv2 # Make sure cv2 is imported if not already

# Assume all other necessary imports and definitions (like preprocess_image_for_viz,
# calculate_gradcam_manual, show_cam_on_image, DEVICE, etc.) are present above.

def _hepn_plot_top_parts_bar(ax, values, title_prefix, title_suffix, empty_msg, max_bars=10):
    """Draw a horizontal bar chart of the highest-scoring part prototypes.

    Args:
        ax: Matplotlib axes to draw into.
        values: 1-D tensor of per-part scores (one entry per part prototype).
        title_prefix: Text placed before "Top {k}" in the axes title.
        title_suffix: Text placed after "Top {k}" in the axes title.
        empty_msg: Placeholder text shown (with the axes hidden) when ``values``
            has no entries.
        max_bars: Maximum number of parts to draw.

    Returns:
        Tensor of part indices sorted by descending score (all parts, not only
        the ones drawn).
    """
    sorted_vals, sorted_indices = torch.sort(values, descending=True)
    k_bar = min(max_bars, values.numel())
    if k_bar == 0:
        ax.axis('off')
        ax.text(0.5, 0.5, empty_msg, ha='center')
        return sorted_indices
    # barh places y=0 at the bottom, so flip to put the strongest part on top.
    bar_vals = sorted_vals[:k_bar].flip(dims=[0])
    bar_indices = sorted_indices[:k_bar].flip(dims=[0]).tolist()
    ax.barh(range(k_bar), bar_vals, tick_label=[f"P {i}" for i in bar_indices])
    ax.set_title(f"{title_prefix} Top {k_bar} {title_suffix}")
    ax.tick_params(axis='y', labelsize=8)
    return sorted_indices


def _hepn_plot_part_cam(ax, model, target_layer_name, image_input, image_np,
                        part_idx, title, fail_title, fail_msg):
    """Compute a manual Grad-CAM map for one part prototype and draw it on ``ax``.

    Args:
        ax: Matplotlib axes to draw into.
        model: The HEPN model passed to ``calculate_gradcam_manual``.
        target_layer_name: Name of the layer used for Grad-CAM.
        image_input: Model-ready image batch (``[1, C, H, W]`` on ``DEVICE``).
        image_np: Display version of the same image, as returned by
            ``preprocess_image_for_viz``.
        part_idx: Index of the part prototype whose CAM is computed.
        title: Axes title on success.
        fail_title: Axes title when no CAM could be produced.
        fail_msg: Message printed when no CAM could be produced.
    """
    grayscale_cam = calculate_gradcam_manual(model, target_layer_name, image_input, part_idx)
    if grayscale_cam is None:
        print(fail_msg)
        ax.set_title(fail_title)
    else:
        ax.imshow(show_cam_on_image(image_np, grayscale_cam, use_rgb=True))
        ax.set_title(title)
    ax.axis('off')


def visualize_explanation_hepn_manual_gradcam(
    model,
    test_sampler,
    target_layer_name,
    num_explanations=1,
    top_k_parts=3,
    num_support_cams_per_class=2,
    ):
    """Visualize HEPN explanations with manually computed Grad-CAM.

    For each sampled test episode, one random query image is explained. The
    figure shows:

    * the query image with its true and predicted labels;
    * a bar chart of the query's strongest part activations, when available;
    * a bar chart of the predicted class's part profile, when available;
    * Grad-CAM maps of the query for its ``top_k_parts`` most-activated parts;
    * Grad-CAM maps for the same parts on the first
      ``num_support_cams_per_class`` support images of the predicted class.

    Args:
        model: HEPN model, called as
            ``model(support_imgs, support_lbls, query_imgs, n_way)``. It must
            return ``(logits, _, explanation_info)``, where
            ``explanation_info`` is a dict that may contain
            ``'query_part_activations'`` and ``'class_part_profiles'``.
        test_sampler: Episode sampler. ``sample()`` returns
            ``(support_imgs, support_lbls, query_imgs, query_lbls)`` and the
            sampler exposes ``n_way``.
        target_layer_name: Name of the layer used for Grad-CAM.
        num_explanations: Number of episodes (figures) to visualize.
        top_k_parts: Maximum number of top query parts to show CAMs for.
        num_support_cams_per_class: Maximum number of predicted-class support
            images to show CAMs for.

    Returns:
        None. Figures are displayed with ``plt.show()``.
    """
    print(f"\n--- Visualizing HEPN Explanations (Manual Grad-CAM) for {num_explanations} Test Episodes ---")
    print(f"    (Including CAM for up to {num_support_cams_per_class} support images per explanation)")
    model.eval()  # Ensure evaluation mode

    for vis_idx in range(num_explanations):
        print(f"\n--- Explanation Example {vis_idx + 1}/{num_explanations} ---")
        try:
            support_imgs, support_lbls, query_imgs, query_lbls = test_sampler.sample()
        except ValueError as e:
            print(f"Skipping visualization {vis_idx+1} due to sampler error: {e}")
            continue
        if query_imgs is None or query_imgs.size(0) == 0:
            continue

        query_idx_to_explain = random.randrange(query_imgs.size(0))
        support_images_dev = support_imgs.to(DEVICE)
        support_labels_dev = support_lbls.to(DEVICE)
        # CPU copies are used for label matching and for plotting support images.
        support_images_cpu = support_imgs.cpu()
        support_labels_cpu = support_lbls.cpu()

        query_image_single_tensor = query_imgs[query_idx_to_explain]
        query_image_single_input = query_image_single_tensor.unsqueeze(0).to(DEVICE)  # [1, C, H, W]
        true_label = query_lbls[query_idx_to_explain].item()
        n_way = test_sampler.n_way

        # --- Get model outputs and explanation info (no gradients needed) ---
        predicted_label = -1
        query_part_activation = None
        class_part_profiles = {}
        with torch.no_grad():
            try:
                final_logits, _, explanation_info = model(
                    support_images_dev, support_labels_dev, query_image_single_input, n_way
                )
                if final_logits is not None and final_logits.numel() > 0:
                    predicted_label = torch.argmax(final_logits, dim=1).item()
                    query_part_activation = explanation_info.get('query_part_activations')  # [1, M]
                    if query_part_activation is not None:
                        query_part_activation = query_part_activation.squeeze(0).cpu()  # [M]
                    class_part_profiles = explanation_info.get('class_part_profiles', {})
                else:
                    print("Warning: Model forward pass returned invalid logits.")
            except Exception as e:
                # Best-effort visualization: report and move on to the next episode.
                print(f"Error during model forward pass for visualization: {e}")
                import traceback
                traceback.print_exc()
                continue

        print(f"Explaining Query Img Idx: {query_idx_to_explain} | Pred: {predicted_label} | True: {true_label}")

        # --- Select support images of the predicted class ---
        selected_support_indices = []
        if predicted_label != -1:
            support_indices_pred_class = torch.where(support_labels_cpu == predicted_label)[0]
            selected_support_indices = support_indices_pred_class[:num_support_cams_per_class].tolist()
            print(f"  Will visualize CAMs for support indices: {selected_support_indices} (Predicted Class: {predicted_label})")

        # --- Plan the grid: each panel drawn below must be counted here ---
        pred_class_profile = class_part_profiles.get(predicted_label)
        has_query_acts = query_part_activation is not None
        has_profile = pred_class_profile is not None
        num_bar_charts = int(has_query_acts) + int(has_profile)

        can_do_cam = has_query_acts and query_part_activation.numel() > 0
        actual_top_k_parts = min(top_k_parts, query_part_activation.numel()) if can_do_cam else 0
        num_query_cam_plots = actual_top_k_parts
        # One CAM per (selected support image, top part) pair.
        num_support_cam_plots = len(selected_support_indices) * actual_top_k_parts

        total_plots = 1 + num_bar_charts + num_query_cam_plots + num_support_cam_plots
        if total_plots <= 1:
            n_cols, n_rows = 1, 1
        else:
            n_cols = min(total_plots, 5 + num_support_cams_per_class)
            n_rows = math.ceil(total_plots / n_cols)
        max_slots = n_rows * n_cols

        plt.figure(figsize=(max(12, 3.5 * n_cols), 3.5 * n_rows))
        plot_idx = 1
        suptitle = f"HEPN Expl {vis_idx+1}: Q Img {query_idx_to_explain} (Pred: {predicted_label}, True: {true_label})"
        if can_do_cam:
            suptitle += f" | Top {actual_top_k_parts} Parts CAMs (Manual)"
        plt.suptitle(suptitle, fontsize=12, y=0.99)

        # == Query image ==
        ax = plt.subplot(n_rows, n_cols, plot_idx)
        plot_idx += 1
        query_img_np_viz = preprocess_image_for_viz(query_image_single_tensor)
        if query_img_np_viz is not None:
            ax.imshow(query_img_np_viz)
        ax.set_title(f"Query Img {query_idx_to_explain}\nTrue: {true_label} / Pred: {predicted_label}")
        ax.axis('off')

        # == Bar charts: drawn (or placeholdered) only when their slot was counted ==
        top_activated_part_indices = []
        if has_query_acts:
            ax = plt.subplot(n_rows, n_cols, plot_idx)
            plot_idx += 1
            sorted_part_indices = _hepn_plot_top_parts_bar(
                ax, query_part_activation, "Query", "Part Acts", "No activation data"
            )
            top_activated_part_indices = sorted_part_indices[:actual_top_k_parts].tolist()

        if has_profile:
            ax = plt.subplot(n_rows, n_cols, plot_idx)
            plot_idx += 1
            _hepn_plot_top_parts_bar(
                ax, pred_class_profile.cpu(), f"Pred Cls {predicted_label}", "Profile", "No profile data"
            )

        # == Grad-CAM for the QUERY image ==
        if can_do_cam and top_activated_part_indices:
            print(f"  Generating Manual Grad-CAM for QUERY image (idx {query_idx_to_explain}) top {len(top_activated_part_indices)} parts: {top_activated_part_indices}")
            if query_img_np_viz is None:
                print("  Skipping Query CAMs as base image processing failed.")
            else:
                for rank, part_idx in enumerate(top_activated_part_indices):
                    if plot_idx > max_slots:
                        print(f"  Warning: Not enough plot slots for query CAMs...")
                        break
                    ax = plt.subplot(n_rows, n_cols, plot_idx)
                    plot_idx += 1
                    _hepn_plot_part_cam(
                        ax, model, target_layer_name, query_image_single_input, query_img_np_viz, part_idx,
                        title=f"Query P{part_idx} CAM (Rank {rank+1})",
                        fail_title=f"Query P{part_idx} CAM Fail",
                        fail_msg=f"  Failed to generate GradCAM for Query, Part {part_idx}.",
                    )

        # == Grad-CAM for SUPPORT images of the predicted class (query's top parts) ==
        if can_do_cam and top_activated_part_indices and selected_support_indices:
            print(f"  Generating Manual Grad-CAM for {len(selected_support_indices)} SUPPORT images (Pred Class {predicted_label}) using query's top {len(top_activated_part_indices)} parts.")
            for support_idx in selected_support_indices:
                support_img_tensor = support_images_cpu[support_idx]
                support_img_np_viz = preprocess_image_for_viz(support_img_tensor)
                support_img_input = support_img_tensor.unsqueeze(0).to(DEVICE)  # [1, C, H, W]

                if support_img_np_viz is None:
                    print(f"  Skipping CAMs for Support Img {support_idx} as base image processing failed.")
                    continue

                for rank, part_idx in enumerate(top_activated_part_indices):
                    if plot_idx > max_slots:
                        print(f"  Warning: Ran out of plot slots before visualizing all support CAMs.")
                        break
                    ax = plt.subplot(n_rows, n_cols, plot_idx)
                    plot_idx += 1
                    _hepn_plot_part_cam(
                        ax, model, target_layer_name, support_img_input, support_img_np_viz, part_idx,
                        title=f"Supp Idx {support_idx} - P{part_idx} CAM\n(Q Part Rank {rank+1})",
                        fail_title=f"Supp {support_idx} P{part_idx}\nCAM Fail",
                        fail_msg=f"  Failed to generate GradCAM for Support Img {support_idx}, Part {part_idx}.",
                    )

                if plot_idx > max_slots:
                    break

        plt.tight_layout(rect=[0, 0.03, 1, 0.95])
        plt.show()


# --- Main execution: visualize query and support explanations ---
# `trained_model`, `test_sampler` and TARGET_LAYER_NAME must already be defined
# by the training / model-loading cells earlier in the notebook.
print("\n5. Visualizing HEPN Explanation with Manual Grad-CAM (Query & Support)...")
if trained_model is not None:
    visualize_explanation_hepn_manual_gradcam(
        model=trained_model,
        test_sampler=test_sampler,
        target_layer_name=TARGET_LAYER_NAME,
        num_explanations=5,
        top_k_parts=3,                 # Visualize the top 3 parts ranked by query activation
        num_support_cams_per_class=3,  # CAMs for the first 3 support images of the predicted class
    )
else:
    print("ERROR: No valid model available for visualization.")