# Individual Subject RDM Analysis

This notebook creates Representational Dissimilarity Matrices (RDMs) for each individual subject.
Each subject's RDM shows the similarity structure of object categories based on their averaged embeddings.

## Overview

This analysis:
1. Loads normalized age-month level embeddings from notebook 05 (normalized grouped embeddings)
2. Aggregates embeddings per subject across all age_mo (simple average across age bins)
3. Organizes categories using either a predefined category list (for consistent ordering) or automatic organization by type
4. Computes RDM for each subject using cosine distance with consistent category ordering
5. Handles missing categories by placing NA values for categories not present for each subject
6. Visualizes and saves individual subject RDMs with NA cells masked out (rendered in white)

## Key Features

- **Normalized embeddings**: Uses pre-normalized embeddings from notebook 05
- **Consistent category ordering**: Supports loading a predefined category list (e.g., from notebook 02) to ensure all subjects' RDMs have the same category order for easy visual comparison
- **Missing category handling**: Places NA values for categories not present for each subject, ensuring all RDMs have the same dimensions
- **NA visualization**: Masks NA cells in RDM visualizations (rendered in white) to clearly indicate missing data
- **Data density handling**: Subjects with more data get more reliable RDMs, but all RDMs maintain the same structure
- **Age-month aggregation**: Averages embeddings across all age_mo bins for each subject-category combination


## Setup and Imports


In [27]:
import numpy as np
import pandas as pd
from pathlib import Path
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics.pairwise import cosine_similarity, cosine_distances
from scipy.cluster.hierarchy import linkage, dendrogram, optimal_leaf_ordering
from scipy.spatial.distance import squareform
from collections import defaultdict
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')

# Set matplotlib backend
import matplotlib
matplotlib.use('Agg')

print("All imports successful!")

All imports successful!


## Configuration


In [28]:
# CDI words CSV file. It is only read when categories are organized automatically
# by type, i.e. when no predefined category list is used.
cdi_path = Path("../../data/cdi_words.csv")

# Hierarchical clustering options (applied in the automatic-organization path)
use_clustering = True  # Enable hierarchical clustering within category groups
save_dendrograms = True  # Save dendrogram plots for each category group

# Predefined category list for consistent RDM ordering (optional).
# With a predefined list every subject's RDM uses the same category order,
# which makes RDMs directly comparable across subjects.
# To fall back to automatic organization, set USE_PREDEFINED_CATEGORY_LIST to False
# or PREDEFINED_CATEGORY_LIST_PATH to None.
USE_PREDEFINED_CATEGORY_LIST = True  # If True, load category order from PREDEFINED_CATEGORY_LIST_PATH
PREDEFINED_CATEGORY_LIST_PATH = "../vss-2026/bv_things_comp_12252025/bv_clip_filtered_zscored_hierarchical_163cats/category_order_reorganized.txt"  # Text file with one category per line, or None
# The file format is one category name per line; lines starting with '#' are skipped when loading.

# Echo the active configuration so it is recorded in the notebook output
print(f"CDI path: {cdi_path}")
print(f"Use clustering: {use_clustering}")
print(f"Use predefined category list: {USE_PREDEFINED_CATEGORY_LIST}")
if USE_PREDEFINED_CATEGORY_LIST and PREDEFINED_CATEGORY_LIST_PATH:
    print(f"Predefined category list path: {PREDEFINED_CATEGORY_LIST_PATH}")

CDI path: ../../data/cdi_words.csv
Use clustering: True
Use predefined category list: True
Predefined category list path: ../vss-2026/bv_things_comp_12252025/bv_clip_filtered_zscored_hierarchical_163cats/category_order_reorganized.txt


In [29]:
def load_category_types(cdi_path):
    """Load per-category type flags from the CDI words CSV.

    Args:
        cdi_path: Path (or string) of a CSV file with a ``uni_lemma`` column and,
            optionally, ``is_animate``, ``is_bodypart``, ``is_small`` and
            ``is_big`` flag columns.

    Returns:
        dict mapping each ``uni_lemma`` value to a dict with the boolean keys
        ``'is_animate'``, ``'is_bodypart'``, ``'is_small'`` and ``'is_big'``.
        A flag is False when its column is absent or its cell is empty (NaN).
        If a ``uni_lemma`` appears in several rows, the last row wins.

    Raises:
        KeyError: If the CSV has no ``uni_lemma`` column.
    """
    print(f"Loading category types from {cdi_path}...")
    cdi_df = pd.read_csv(cdi_path)

    flag_columns = ('is_animate', 'is_bodypart', 'is_small', 'is_big')

    def _as_flag(value):
        # bool(float('nan')) is True, so empty CSV cells must be mapped to
        # False explicitly instead of being passed straight to bool().
        return bool(value) if pd.notna(value) else False

    category_types = {}
    for record in cdi_df.to_dict('records'):
        category_types[record['uni_lemma']] = {
            column: _as_flag(record.get(column, 0)) for column in flag_columns
        }

    print(f"Loaded type information for {len(category_types)} categories")
    return category_types

def cluster_categories_within_group(group_categories, cat_to_embedding, save_dendrogram=False, output_dir=None, group_name=None):
    """
    Order the categories of one group by hierarchical clustering of their embeddings.

    The embeddings are z-scored per embedding dimension across the group, pairwise
    cosine distances are computed, and Ward linkage is applied to them. The returned
    order is the leaf order of the resulting dendrogram.

    NOTE(review): SciPy documents Ward linkage as correctly defined only for
    Euclidean distances; cosine distances are passed here as in the original analysis.

    Args:
        group_categories: List of category names in the group
        cat_to_embedding: Dictionary mapping category names to embeddings
            (each embedding is flattened before use)
        save_dendrogram: Whether to save dendrogram plot (default: False)
        output_dir: Output directory for saving dendrogram (required if save_dendrogram=True)
        group_name: Name of the group for saving dendrogram (required if save_dendrogram=True)

    Returns:
        Tuple ``(clustered_categories, linkage_matrix)``:
        the category names reordered by dendrogram leaf order, and the SciPy linkage
        matrix. For groups with at most one category the input list is returned
        unchanged together with ``None``.

    Raises:
        KeyError: If a category in ``group_categories`` has no entry in ``cat_to_embedding``.
    """
    if len(group_categories) <= 1:
        return group_categories, None

    # Get embeddings for this group: one flattened row per category
    group_embeddings = np.array([cat_to_embedding[cat].flatten() for cat in group_categories])

    # Z-score each embedding dimension across the categories of this group
    # (the small constant avoids division by zero for constant dimensions)
    normalized_embeddings = (group_embeddings - group_embeddings.mean(axis=0)) / (group_embeddings.std(axis=0) + 1e-10)

    # Compute distance matrix (1 - cosine similarity)
    distance_matrix = 1 - cosine_similarity(normalized_embeddings)
    # squareform() rejects matrices that are not exactly symmetric, so remove any
    # floating-point asymmetry first (a no-op when the matrix is already symmetric)
    distance_matrix = (distance_matrix + distance_matrix.T) / 2
    np.fill_diagonal(distance_matrix, 0)

    # Convert to condensed form for linkage
    condensed_distances = squareform(distance_matrix)

    # Perform hierarchical clustering
    linkage_matrix = linkage(condensed_distances, method='ward')

    # Optimal leaf ordering is a best-effort refinement for nicer visualization;
    # on failure keep the original linkage, but report it instead of hiding it.
    try:
        linkage_matrix = optimal_leaf_ordering(linkage_matrix, condensed_distances)
    except Exception as exc:
        print(f"    Warning: optimal leaf ordering failed ({exc}); using original linkage order")

    # Extract the leaf order from the dendrogram (indices into group_categories)
    leaf_order = dendrogram(linkage_matrix, no_plot=True)['leaves']

    # Reorder categories according to clustering
    clustered_categories = [group_categories[i] for i in leaf_order]

    # Save dendrogram if requested
    if save_dendrogram and output_dir is not None and group_name is not None:
        output_dir = Path(output_dir)
        output_dir.mkdir(exist_ok=True, parents=True)

        fig = plt.figure(figsize=(12, 8))
        try:
            # Labels are given in the original observation order; dendrogram()
            # places them according to the leaf order itself.
            dendrogram(linkage_matrix,
                       labels=group_categories,
                       leaf_rotation=90,
                       leaf_font_size=10)
            plt.title(f'Hierarchical Clustering Dendrogram: {group_name.upper()}\n({len(group_categories)} categories)',
                      fontsize=16, pad=20)
            plt.xlabel('Category', fontsize=12)
            plt.ylabel('Distance', fontsize=12)
            plt.tight_layout()

            # Save as PNG
            output_path_png = output_dir / f'dendrogram_{group_name}.png'
            plt.savefig(output_path_png, dpi=300, bbox_inches='tight', pad_inches=0.2)
            print(f"    Saved dendrogram to {output_path_png}")

            # Save as PDF
            output_path_pdf = output_dir / f'dendrogram_{group_name}.pdf'
            plt.savefig(output_path_pdf, bbox_inches='tight', pad_inches=0.2)
            print(f"    Saved dendrogram to {output_path_pdf}")
        finally:
            # Always release the figure, even if plotting or saving fails
            plt.close(fig)

    return clustered_categories, linkage_matrix

print("Helper functions loaded!")

Helper functions loaded!


In [30]:
# Paths
# Path to normalized embeddings from notebook 05 (age-month level normalized embeddings)
# These are saved in category folders: {normalized_embeddings_dir}/{category}/{subject_id}_{age_mo}_month_level_avg.npy
normalized_embeddings_dir = Path("/data2/dataset/babyview/868_hours/outputs/yoloe_cdi_embeddings/facebook_dinov3-vitb16-pretrain-lvd1689m_grouped_by_age-mo_normalized")

# Detect embedding type from path (case-insensitive substring match)
# NOTE(review): the "dinov" test matches any DINO variant (e.g. a "dinov2" path),
# which would then also be labelled "dinov3" - confirm this is intended.
normalized_embeddings_dir_str = str(normalized_embeddings_dir).lower()
if "dinov3" in normalized_embeddings_dir_str or "dinov" in normalized_embeddings_dir_str:
    embedding_type = "dinov3"
elif "clip" in normalized_embeddings_dir_str:
    embedding_type = "clip"
else:
    embedding_type = "unknown"

# Create output directory with embedding type in name
# (relative path, so it is created under the current working directory)
output_dir = Path(f"individual_subject_rdms_{embedding_type}")
output_dir.mkdir(exist_ok=True, parents=True)

# Create subdirectories for organizing files
csv_dir = output_dir / "csv"
npy_dir = output_dir / "npy"
csv_dir.mkdir(exist_ok=True, parents=True)
npy_dir.mkdir(exist_ok=True, parents=True)

# Subject to exclude from analyses (compared against the subject id parsed from file names)
excluded_subject = "00270001"

# Minimum categories required per subject to compute RDM
min_categories_per_subject = 10

# Echo the resolved paths and thresholds so they are recorded in the notebook output
print(f"Normalized embeddings directory: {normalized_embeddings_dir}")
print(f"Detected embedding type: {embedding_type}")
print(f"Output directory: {output_dir}")
print(f"CSV subdirectory: {csv_dir}")
print(f"NPY subdirectory: {npy_dir}")
print(f"Excluded subject: {excluded_subject}")
print(f"Min categories per subject: {min_categories_per_subject}")

Normalized embeddings directory: /data2/dataset/babyview/868_hours/outputs/yoloe_cdi_embeddings/facebook_dinov3-vitb16-pretrain-lvd1689m_grouped_by_age-mo_normalized
Detected embedding type: dinov3
Output directory: individual_subject_rdms_dinov3
CSV subdirectory: individual_subject_rdms_dinov3/csv
NPY subdirectory: individual_subject_rdms_dinov3/npy
Excluded subject: 00270001
Min categories per subject: 10


## Load Normalized Grouped Embeddings from Notebook 05

This section loads normalized age-month level embeddings from notebook 05 and aggregates them to subject level.


In [31]:
# Read every per-category embedding file written by notebook 05, then collapse
# the age-month bins into a single embedding per (subject, category).
print("Loading normalized age-month level embeddings from notebook 05...")
print(f"  Source directory: {normalized_embeddings_dir}")

# One sub-directory per category
category_folders = [entry for entry in normalized_embeddings_dir.iterdir() if entry.is_dir()]
print(f"  Found {len(category_folders)} category folders")

# Nested mapping: subject_id -> category -> list of age_mo embeddings
subject_category_embeddings = defaultdict(lambda: defaultdict(list))
all_categories_set = set()

for category_folder in tqdm(category_folders, desc="Loading category folders"):
    category = category_folder.name
    all_categories_set.add(category)

    for emb_file in category_folder.glob("*.npy"):
        # File stems follow {subject_id}_{age_mo}_month_level_avg
        parts = emb_file.stem.split('_')

        # Skip names without a numeric age_mo token in second position
        if len(parts) < 2 or not parts[1].isdigit():
            continue

        subject_id = parts[0]
        age_mo = int(parts[1])

        # Drop the excluded subject, if one is configured
        if excluded_subject and subject_id == excluded_subject:
            continue

        try:
            embedding = np.load(emb_file)
        except Exception as e:
            print(f"Error loading {emb_file}: {e}")
        else:
            subject_category_embeddings[subject_id][category].append(embedding)

# Collapse age_mo bins: unweighted mean of the flattened embeddings per category
print(f"\nAggregating embeddings per subject (averaging across age_mo)...")
subject_embeddings_normalized = {}

for subject_id, per_category in tqdm(subject_category_embeddings.items(), desc="Aggregating subjects"):
    subject_embeddings_normalized[subject_id] = {
        category: np.array([emb.flatten() for emb in age_mo_embeddings]).mean(axis=0)
        for category, age_mo_embeddings in per_category.items()
        if age_mo_embeddings
    }

print(f"\nLoaded and aggregated normalized embeddings for {len(subject_embeddings_normalized)} subjects")
print(f"  Total unique categories across all subjects: {len(all_categories_set)}")

# Alphabetically sorted list form, used by the cells below
all_categories = sorted(all_categories_set)

Loading normalized age-month level embeddings from notebook 05...
  Source directory: /data2/dataset/babyview/868_hours/outputs/yoloe_cdi_embeddings/facebook_dinov3-vitb16-pretrain-lvd1689m_grouped_by_age-mo_normalized
  Found 163 category folders


Loading category folders: 100%|██████████| 163/163 [00:01<00:00, 130.38it/s]



Aggregating embeddings per subject (averaging across age_mo)...


Aggregating subjects: 100%|██████████| 31/31 [00:00<00:00, 479.20it/s]


Loaded and aggregated normalized embeddings for 31 subjects
  Total unique categories across all subjects: 163





## Organize Categories (with Predefined List Option)

This section organizes categories either by loading a predefined category list (for consistent ordering across subjects) or by automatic organization.


In [32]:
# Organize categories: either load a predefined order (identical for all subjects)
# or derive one automatically from CDI type flags, optionally clustered within groups.
# Produces `ordered_categories` (row/column order of every RDM) and `organized`
# (group name -> categories, used later for visual group boundaries).
print("Organizing categories...")

if USE_PREDEFINED_CATEGORY_LIST and PREDEFINED_CATEGORY_LIST_PATH is not None:
    # Load predefined category list
    predefined_path = Path(PREDEFINED_CATEGORY_LIST_PATH)
    if not predefined_path.exists():
        raise FileNotFoundError(f"Predefined category list file not found: {predefined_path}")

    print(f"  Loading predefined category order from {predefined_path}...")
    with open(predefined_path, 'r') as f:
        # Keep non-empty lines; skip comment lines (lines starting with #)
        listed_categories = [line.strip() for line in f if line.strip() and not line.strip().startswith('#')]

    # A repeated name would add an extra RDM row/column that can never be filled
    # (lookups use the first position only), so keep just the first occurrence.
    ordered_categories = list(dict.fromkeys(listed_categories))
    n_duplicates = len(listed_categories) - len(ordered_categories)
    if n_duplicates:
        print(f"  Warning: dropped {n_duplicates} duplicate entries from predefined list")

    # Verify that all categories in predefined list exist in our data
    predefined_set = set(ordered_categories)
    all_categories_set = set(all_categories)

    if predefined_set != all_categories_set:
        missing_in_predefined = all_categories_set - predefined_set
        extra_in_predefined = predefined_set - all_categories_set
        if missing_in_predefined:
            print(f"  Warning: {len(missing_in_predefined)} categories in data but not in predefined list: {sorted(missing_in_predefined)[:5]}...")
        if extra_in_predefined:
            print(f"  Warning: {len(extra_in_predefined)} categories in predefined list but not in data: {sorted(extra_in_predefined)[:5]}...")
        # Use intersection: only categories that exist in both
        ordered_categories = [cat for cat in ordered_categories if cat in all_categories_set]
        print(f"  Using intersection: {len(ordered_categories)} categories")

    print(f"  Loaded {len(ordered_categories)} categories in predefined order")

    # Create dummy organized dict for compatibility (won't be used for visualization boundaries)
    organized = {'animals': [], 'bodyparts': [], 'big_objects': [], 'small_objects': [], 'others': []}

else:
    # Automatic organization by type (similar to notebook 02)
    print(f"  Organizing categories by type...")
    cdi_path = Path(cdi_path)

    if cdi_path.exists():
        category_types = load_category_types(cdi_path)

        # Organize by type
        organized = {
            'animals': [],
            'bodyparts': [],
            'big_objects': [],
            'small_objects': [],
            'others': []
        }

        # The first matching flag decides the group; the order of this list is the priority.
        flag_to_group = [
            ('is_animate', 'animals'),
            ('is_bodypart', 'bodyparts'),
            ('is_big', 'big_objects'),
            ('is_small', 'small_objects'),
        ]
        for cat in all_categories:
            types = category_types.get(cat)
            target_group = 'others'  # categories unknown to the CDI file or without any flag
            if types is not None:
                for flag_name, candidate_group in flag_to_group:
                    if types[flag_name]:
                        target_group = candidate_group
                        break
            organized[target_group].append(cat)

        print(f"  Organized into: {len(organized['animals'])} animals, {len(organized['bodyparts'])} bodyparts, "
              f"{len(organized['big_objects'])} big objects, {len(organized['small_objects'])} small objects, "
              f"{len(organized['others'])} others")

        # Apply hierarchical clustering if enabled
        if use_clustering:
            print(f"  Applying hierarchical clustering within groups...")
            # Create a representative embedding dict for clustering (use first subject that has all categories)
            cat_to_embedding = {}
            for subject_id, subject_embeddings in subject_embeddings_normalized.items():
                if all(cat in subject_embeddings for cat in all_categories):
                    cat_to_embedding = {cat: subject_embeddings[cat] for cat in all_categories}
                    break

            # If no subject has all categories, use average across subjects
            if not cat_to_embedding:
                print(f"    No subject has all categories, computing average embeddings for clustering...")
                for cat in all_categories:
                    cat_embeddings = [
                        subject_embeddings[cat]
                        for subject_embeddings in subject_embeddings_normalized.values()
                        if cat in subject_embeddings
                    ]
                    if cat_embeddings:
                        cat_to_embedding[cat] = np.array(cat_embeddings).mean(axis=0)

            for group_name in ['animals', 'bodyparts', 'big_objects', 'small_objects', 'others']:
                # Only categories that have an embedding can be clustered; any others
                # keep a deterministic (alphabetical) position at the end of their group
                # instead of raising KeyError inside the clustering helper.
                clusterable = [cat for cat in organized[group_name] if cat in cat_to_embedding]
                unclusterable = sorted(cat for cat in organized[group_name] if cat not in cat_to_embedding)
                if unclusterable:
                    print(f"    Warning: {len(unclusterable)} {group_name} categories have no embedding and are not clustered")
                if len(clusterable) > 1:
                    print(f"    Clustering {group_name} ({len(clusterable)} categories)...")
                    clusterable, _ = cluster_categories_within_group(
                        clusterable,
                        cat_to_embedding,
                        save_dendrogram=save_dendrograms,
                        output_dir=output_dir,
                        group_name=group_name
                    )
                organized[group_name] = clusterable + unclusterable
        else:
            # Without clustering, order each group alphabetically
            for group_name in organized:
                organized[group_name] = sorted(organized[group_name])

        # Create ordered list: groups concatenated in a fixed group order
        ordered_categories = (
            organized['animals'] +
            organized['bodyparts'] +
            organized['big_objects'] +
            organized['small_objects'] +
            organized['others']
        )
    else:
        print(f"  Warning: CDI path {cdi_path} not found. Using alphabetical order.")
        organized = {'animals': [], 'bodyparts': [], 'big_objects': [], 'small_objects': [], 'others': all_categories}
        ordered_categories = sorted(all_categories)

print(f"\nFinal ordered category list: {len(ordered_categories)} categories")

Organizing categories...
  Loading predefined category order from ../vss-2026/bv_things_comp_12252025/bv_clip_filtered_zscored_hierarchical_163cats/category_order_reorganized.txt...
  Loaded 163 categories in predefined order

Final ordered category list: 163 categories


## Compute Individual Subject RDMs


In [33]:
def compute_subject_rdm_with_na(subject_embeddings_dict, ordered_categories_list):
    """
    Compute RDM for a single subject with NA for missing categories.

    The RDM holds cosine distances (1 - cosine similarity) between the subject's
    category embeddings, laid out in the order of ``ordered_categories_list``.
    Rows and columns of categories the subject does not have stay NaN.

    Args:
        subject_embeddings_dict: dict[category] = embedding array (should be normalized);
            each embedding is flattened before use
        ordered_categories_list: list of all categories in desired order (may include
            categories not present for this subject)

    Returns:
        rdm: numpy array of shape (n_categories, n_categories) with np.nan for missing categories
        mask: boolean array of shape (n_categories, n_categories) where True indicates NA (missing category)
        available_categories: list of categories actually present for this subject,
            in the order of ``ordered_categories_list``

        When fewer than two categories are available, ``rdm`` is all NaN and
        ``mask`` is all True.

    Raises:
        ValueError: If the stacked embeddings do not form a 2D matrix.
    """
    n_categories = len(ordered_categories_list)

    # Find available categories (categories that exist for this subject)
    available_categories = [cat for cat in ordered_categories_list if cat in subject_embeddings_dict]

    # Start from an all-NA RDM; only cells between available categories get filled
    rdm = np.full((n_categories, n_categories), np.nan)
    mask = np.ones((n_categories, n_categories), dtype=bool)

    if len(available_categories) < 2:
        # Not enough categories for any pairwise distance
        return rdm, mask, available_categories

    # Build embedding matrix for available categories (already normalized)
    embedding_matrix = np.array([subject_embeddings_dict[cat].flatten() for cat in available_categories])

    # Ensure 2D shape: (n_available_categories, embedding_dim)
    if embedding_matrix.ndim != 2:
        raise ValueError(f"Expected 2D embedding matrix, got shape {embedding_matrix.shape}")

    # Cosine distance between all available categories
    distance_matrix_available = 1 - cosine_similarity(embedding_matrix)
    np.fill_diagonal(distance_matrix_available, 0)  # Ensure diagonal is 0

    # Make symmetric (in case of numerical errors)
    distance_matrix_available = (distance_matrix_available + distance_matrix_available.T) / 2

    # Position of each category in the full ordering, built once instead of calling
    # list.index() per category; setdefault keeps the first occurrence like list.index().
    first_position = {}
    for position, cat in enumerate(ordered_categories_list):
        first_position.setdefault(cat, position)
    available_indices = [first_position[cat] for cat in available_categories]

    # Fill the sub-block of available categories in one vectorized assignment
    block = np.ix_(available_indices, available_indices)
    rdm[block] = distance_matrix_available
    mask[block] = False  # False means not NA (data present)

    return rdm, mask, available_categories

# Compute RDMs for each subject using normalized embeddings with NA for missing categories
print("\nComputing RDMs for each subject (with NA for missing categories)...")
subject_rdms = {}
subject_rdm_masks = {}  # Store masks indicating NA cells
subject_rdm_categories = {}  # Store available categories for each subject

for subject_id, subject_embeddings in tqdm(subject_embeddings_normalized.items(), desc="Computing RDMs"):
    # Cheap pre-filter: too few categories overall
    if len(subject_embeddings) < min_categories_per_subject:
        continue

    rdm, mask, available_cats = compute_subject_rdm_with_na(subject_embeddings, ordered_categories)

    # Only categories that are part of ordered_categories count towards the minimum
    if len(available_cats) < min_categories_per_subject:
        continue

    subject_rdms[subject_id] = rdm
    subject_rdm_masks[subject_id] = mask
    subject_rdm_categories[subject_id] = available_cats

print(f"\nComputed RDMs for {len(subject_rdms)} subjects")
print(f"  (Excluded {len(subject_embeddings_normalized) - len(subject_rdms)} subjects with < {min_categories_per_subject} categories)")

# Compute group boundaries for visualization (based on ordered_categories)
print("\nComputing group boundaries for visualization...")

# The boundaries depend only on `organized` and `ordered_categories`, not on the
# subject, so they are computed once and then copied for every subject.
ordered_category_set = set(ordered_categories)
group_boundaries = []
current_idx = 0
for group_name in ['animals', 'bodyparts', 'big_objects', 'small_objects', 'others']:
    group_cats = [cat for cat in organized[group_name] if cat in ordered_category_set]
    if not group_cats:
        continue
    group_start = current_idx
    group_end = current_idx + len(group_cats)
    group_boundaries.append({
        'name': group_name,
        'start': group_start,
        'end': group_end,
        'categories': group_cats
    })
    current_idx = group_end

# Store group boundaries for visual separators; every subject gets its own copy
# so that later in-place changes for one subject cannot affect another.
subject_group_boundaries = {
    subject_id: [dict(boundary, categories=list(boundary['categories'])) for boundary in group_boundaries]
    for subject_id in subject_rdms
}

print(f"Computed group boundaries for {len(subject_rdms)} subjects")


Computing RDMs for each subject (with NA for missing categories)...


Computing RDMs: 100%|██████████| 31/31 [00:00<00:00, 161.50it/s]


Computed RDMs for 31 subjects
  (Excluded 0 subjects with < 10 categories)

Computing group boundaries for visualization...
Computed group boundaries for 31 subjects





## Visualize and Save Individual Subject RDMs


In [34]:
## Visualize All Individual Subject RDMs

# Plot all individual subject RDMs in a grid (one figure per colormap)
n_subjects = len(subject_rdms)
subject_ids = list(subject_rdms.keys())

# Calculate grid dimensions
n_cols = 6  # Number of columns
n_rows = int(np.ceil(n_subjects / n_cols))

# Find global min/max for consistent color scale across all RDMs (excluding NaN).
# The 1st/99th percentiles are used so that a few extreme cells do not dominate the scale.
finite_value_blocks = [rdm[~np.isnan(rdm)] for rdm in subject_rdms.values()]
all_rdm_values = np.concatenate(finite_value_blocks) if finite_value_blocks else np.array([])
vmin = np.percentile(all_rdm_values, 1) if len(all_rdm_values) > 0 else 0
vmax = np.percentile(all_rdm_values, 99) if len(all_rdm_values) > 0 else 2


def _plot_rdm_grid(cmap_name, suptitle, out_path):
    """Draw every subject's RDM on one grid figure with a shared color scale and save it.

    Args:
        cmap_name: Name of the matplotlib colormap to use.
        suptitle: Title placed above the whole grid.
        out_path: File path the figure is saved to.
    """
    # squeeze=False always returns a 2D array of Axes, so flattening is valid
    # for any grid shape, including a single row with a single subject.
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(n_cols * 3, n_rows * 3), squeeze=False)
    axes = axes.ravel()

    # plt.get_cmap is available across matplotlib versions (plt.cm.get_cmap was
    # removed in matplotlib 3.9). Copy so set_bad does not modify the registered colormap.
    cmap = plt.get_cmap(cmap_name).copy()
    cmap.set_bad(color='white', alpha=1.0)  # White for NA cells - highly visible

    for ax, subject_id in zip(axes, subject_ids):
        available_cats = subject_rdm_categories[subject_id]

        # Masked cells (missing categories) are drawn with the colormap's "bad" color
        rdm_masked = np.ma.masked_where(subject_rdm_masks[subject_id], subject_rdms[subject_id])
        im = ax.imshow(rdm_masked, cmap=cmap, aspect='auto', vmin=vmin, vmax=vmax)

        ax.set_title(f"{subject_id}\n({len(available_cats)}/{len(ordered_categories)} cats)", fontsize=9, pad=5)
        ax.set_xlabel('Category', fontsize=7)
        ax.set_ylabel('Category', fontsize=7)
        ax.tick_params(labelsize=6)

        # Set ticks and labels to show all categories in ordered_categories
        ax.set_xticks(range(len(ordered_categories)))
        ax.set_yticks(range(len(ordered_categories)))
        ax.set_xticklabels(ordered_categories, rotation=90, ha='right', fontsize=4)
        ax.set_yticklabels(ordered_categories, fontsize=4)

        # Add colorbar for each subplot
        fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)

    # Hide unused subplots
    for ax in axes[n_subjects:]:
        ax.axis('off')

    fig.suptitle(suptitle, fontsize=16, y=0.995)
    fig.tight_layout(rect=[0, 0, 1, 0.99])
    fig.savefig(out_path, dpi=200, bbox_inches='tight')
    plt.close(fig)


if n_subjects == 0:
    # plt.subplots cannot create a grid with zero rows
    print("No subject RDMs available; skipping grid visualization.")
else:
    print(f"Plotting {n_subjects} individual subject RDMs...")
    print(f"Color scale range: [{vmin:.4f}, {vmax:.4f}]")

    _plot_rdm_grid(
        'viridis',
        f'All Individual Subject RDMs (n={n_subjects})',
        output_dir / "all_individual_rdms.png",
    )
    print(f"\nSaved all individual RDM visualization to {output_dir / 'all_individual_rdms.png'}")

    # Also create a version with coolwarm colormap
    _plot_rdm_grid(
        'coolwarm',
        f'All Individual Subject RDMs (n={n_subjects}) - Coolwarm',
        output_dir / "all_individual_rdms_coolwarm.png",
    )
    print(f"Saved all individual RDM visualization (coolwarm) to {output_dir / 'all_individual_rdms_coolwarm.png'}")

# Save each individual RDM separately with category names as axis labels
print("\nSaving individual RDM plots...")
# Per-subject figures go into their own sub-directory of the output directory
individual_rdm_dir = output_dir / "individual_rdm_plots"
individual_rdm_dir.mkdir(exist_ok=True, parents=True)

# Find global min/max for consistent color scale (excluding NaN)
# The color limits are the 1st and 99th percentiles of all non-NaN RDM values
# pooled over subjects, so every per-subject figure shares the same scale.
all_rdm_values = []
for rdm in subject_rdms.values():
    valid_values = rdm[~np.isnan(rdm)]
    if len(valid_values) > 0:
        all_rdm_values.extend(valid_values)
# Fall back to [0, 2] when there are no valid values at all
vmin = np.percentile(all_rdm_values, 1) if len(all_rdm_values) > 0 else 0
vmax = np.percentile(all_rdm_values, 99) if len(all_rdm_values) > 0 else 2

# Save one RDM heatmap and one dendrogram per subject.
# Reads notebook globals defined in earlier cells: subject_rdms, subject_rdm_masks,
# subject_rdm_categories, subject_group_boundaries, subject_embeddings_normalized,
# ordered_categories, individual_rdm_dir, vmin and vmax.

# Everything below depends only on the full category list, so compute it once
# instead of once per subject.
n_cats = len(ordered_categories)
fig_size = max(10, n_cats * 0.3)

# Set font size for category labels (adaptive)
if n_cats <= 50:
    label_fontsize = 12
    tick_fontsize = 20
elif n_cats <= 100:
    label_fontsize = 10
    tick_fontsize = 18
else:
    label_fontsize = 8
    tick_fontsize = 16

# plt.get_cmap replaces plt.cm.get_cmap, which was deprecated in Matplotlib 3.7
# and removed in 3.9. Copy so the registered colormap is not modified.
cmap = plt.get_cmap('viridis').copy()
cmap.set_bad(color='white', alpha=1.0)  # masked (NA) cells are drawn in white

dendrogram_dir = individual_rdm_dir / "dendrograms"
dendrogram_dir.mkdir(exist_ok=True, parents=True)

for subject_id in tqdm(subject_rdms.keys(), desc="Saving individual RDMs"):
    rdm = subject_rdms[subject_id]
    mask = subject_rdm_masks[subject_id]
    available_cats = subject_rdm_categories[subject_id]

    # --- RDM heatmap -------------------------------------------------------
    fig, ax = plt.subplots(figsize=(fig_size, fig_size))

    # Mask NA cells so they are rendered with the colormap's "bad" colour
    rdm_masked = np.ma.masked_where(mask, rdm)
    im = ax.imshow(rdm_masked, cmap=cmap, aspect='auto', vmin=vmin, vmax=vmax)

    # Add visual separators between category groups (none before the first group)
    for boundary in subject_group_boundaries.get(subject_id, []):
        if boundary['start'] > 0:
            ax.axvline(x=boundary['start'] - 0.5, color='white', linewidth=2, linestyle='--', alpha=0.7)
            ax.axhline(y=boundary['start'] - 0.5, color='white', linewidth=2, linestyle='--', alpha=0.7)

    # Set category names as axis labels (use ordered_categories for all subjects)
    ax.set_xticks(range(n_cats))
    ax.set_yticks(range(n_cats))
    ax.set_xticklabels(ordered_categories, rotation=90, ha='right', fontsize=tick_fontsize)
    ax.set_yticklabels(ordered_categories, fontsize=tick_fontsize)

    ax.set_xlabel('Category', fontsize=label_fontsize)
    ax.set_ylabel('Category', fontsize=label_fontsize)
    ax.set_title(f'RDM: {subject_id}\n({len(available_cats)}/{n_cats} categories)', fontsize=14, pad=10)

    # Add colorbar
    cbar = fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
    cbar.set_label('Distance (1 - Cosine Similarity)', fontsize=12)

    fig.tight_layout()

    # Save as PNG
    output_path_png = individual_rdm_dir / f"rdm_{subject_id}.png"
    fig.savefig(output_path_png, dpi=200, bbox_inches='tight')

    # Close this specific figure so figures do not accumulate across subjects
    plt.close(fig)

    # --- Dendrogram (only for categories that have an embedding) -----------
    subject_embeddings = subject_embeddings_normalized[subject_id]
    # Use one filtered list for both the matrix rows and the leaf labels. Indexing
    # with the unfiltered available_cats would raise KeyError (or mislabel leaves)
    # whenever a listed category has no embedding.
    dendro_cats = [cat for cat in available_cats if cat in subject_embeddings]
    if len(dendro_cats) < 2:
        continue  # clustering needs at least two categories

    # Build embedding matrix (one flattened embedding per row)
    embedding_matrix = np.array([subject_embeddings[cat].flatten() for cat in dendro_cats])

    # Standardise every feature across categories; epsilon avoids division by zero
    normalized_embeddings = (embedding_matrix - embedding_matrix.mean(axis=0)) / (embedding_matrix.std(axis=0) + 1e-10)

    # Compute distance matrix
    similarity_matrix = cosine_similarity(normalized_embeddings)
    distance_matrix = 1 - similarity_matrix
    np.fill_diagonal(distance_matrix, 0)

    # Convert to condensed form for linkage
    condensed_distances = squareform(distance_matrix)

    # Perform hierarchical clustering
    # NOTE(review): SciPy documents Ward linkage as correctly defined only for
    # Euclidean distances, while these are cosine distances - confirm this is intended.
    linkage_matrix = linkage(condensed_distances, method='ward')

    # Optimal leaf ordering is cosmetic, so a failure must not abort the export,
    # but report it instead of silently swallowing every exception.
    try:
        linkage_matrix = optimal_leaf_ordering(linkage_matrix, condensed_distances)
    except Exception as exc:
        tqdm.write(f"  Optimal leaf ordering failed for {subject_id}: {exc!r}; using default leaf order")

    # Create dendrogram
    dendrogram_fig = plt.figure(figsize=(max(16, len(dendro_cats) * 0.5), 10))
    dendrogram(linkage_matrix,
               labels=dendro_cats,
               leaf_rotation=90,
               leaf_font_size=max(8, min(14, 200 // len(dendro_cats))))
    plt.title(f'Individual Subject Dendrogram: {subject_id}\n({len(dendro_cats)} categories)',
              fontsize=16, pad=20)
    plt.xlabel('Category', fontsize=14)
    plt.ylabel('Distance', fontsize=14)
    plt.tight_layout()

    # Save dendrogram
    dendrogram_path = dendrogram_dir / f"dendrogram_{subject_id}.png"
    dendrogram_fig.savefig(dendrogram_path, dpi=300, bbox_inches='tight', pad_inches=0.2)
    plt.close(dendrogram_fig)

print(f"\nSaved {len(subject_rdms)} individual RDM plots to {individual_rdm_dir}")
# Describe the naming pattern literally: interpolating the loop variable here would
# print only the last subject's id and fail with NameError when there are no subjects.
print("  Each subject has 1 RDM file: rdm_<subject_id>.png")
print(f"  Individual dendrograms saved to {dendrogram_dir}")

Plotting 31 individual subject RDMs...
Color scale range: [0.2783, 1.3280]

Saved all individual RDM visualization to individual_subject_rdms_dinov3/all_individual_rdms.png
Saved all individual RDM visualization (coolwarm) to individual_subject_rdms_dinov3/all_individual_rdms_coolwarm.png

Saving individual RDM plots...


Saving individual RDMs:  65%|██████▍   | 20/31 [03:15<01:47,  9.78s/it]


: 

In [None]:
# Persist every subject's RDM as .npy and labelled .csv, plus a one-row metadata CSV.
print("Saving individual subject RDMs...")
for subject_id, rdm in tqdm(subject_rdms.items(), desc="Saving RDMs"):
    available_cats = subject_rdm_categories[subject_id]
    n_total = len(ordered_categories)
    n_available = len(available_cats)

    # Raw matrix (NaN marks categories missing for this subject) -> npy subdirectory
    np.save(npy_dir / f"rdm_{subject_id}.npy", rdm)

    # Labelled matrix in the shared category order -> csv subdirectory
    labelled_rdm = pd.DataFrame(rdm, index=ordered_categories, columns=ordered_categories)
    labelled_rdm.to_csv(csv_dir / f"rdm_{subject_id}.csv")

    # The minimum is taken over strictly positive, non-NaN entries only,
    # so the zero diagonal never wins; it is NaN when no such entry exists.
    finite_values = rdm[~np.isnan(rdm)]
    positive_values = finite_values[finite_values > 0]
    if len(positive_values) > 0:
        smallest_positive = float(positive_values.min())
    else:
        smallest_positive = np.nan

    # Metadata row -> csv subdirectory
    metadata = {
        'subject_id': subject_id,
        'n_categories_total': n_total,
        'n_categories_available': n_available,
        'n_categories_missing': n_total - n_available,
        'mean_distance': float(np.nanmean(rdm)),
        'std_distance': float(np.nanstd(rdm)),
        'min_distance': smallest_positive,
        'max_distance': float(np.nanmax(rdm))
    }
    pd.DataFrame([metadata]).to_csv(csv_dir / f"metadata_{subject_id}.csv", index=False)

print(f"\nSaved RDMs to {output_dir}")
print(f"  CSV files: {csv_dir}")
print(f"  NPY files: {npy_dir}")

## Create Summary Statistics


In [None]:
# Build one summary row per subject, then save the table sorted by category coverage.
def _summarize_rdm(subject_id, rdm):
    """Return the summary-statistics row (a dict) for one subject's RDM."""
    n_total = len(ordered_categories)
    n_available = len(subject_rdm_categories[subject_id])

    # Minimum over strictly positive, non-NaN entries; NaN when there are none
    finite_values = rdm[~np.isnan(rdm)]
    positive_values = finite_values[finite_values > 0]
    smallest_positive = float(positive_values.min()) if len(positive_values) > 0 else np.nan

    return {
        'subject_id': subject_id,
        'n_categories_total': n_total,
        'n_categories_available': n_available,
        'n_categories_missing': n_total - n_available,
        'mean_distance': float(np.nanmean(rdm)),
        'std_distance': float(np.nanstd(rdm)),
        'min_distance': smallest_positive,
        'max_distance': float(np.nanmax(rdm))
    }


summary_data = [_summarize_rdm(sid, subject_rdm) for sid, subject_rdm in subject_rdms.items()]

# Subjects with the most available categories come first
summary_df = pd.DataFrame(summary_data).sort_values('n_categories_available', ascending=False)
summary_path = csv_dir / "summary_statistics.csv"
summary_df.to_csv(summary_path, index=False)

print("Summary statistics:")
print(summary_df.describe())
print(f"\nSaved summary to {csv_dir / 'summary_statistics.csv'}")

## Visualize Sample RDMs


In [None]:
# Visualize up to six sample RDMs on a shared colour scale in a fixed 2x3 grid.
n_samples = min(6, len(subject_rdms))
sample_subjects = list(subject_rdms.keys())[:n_samples]

# Pool all non-NaN values of the sampled RDMs to get a consistent colour scale.
# Concatenating arrays avoids extending a Python list one element at a time.
sample_values = [subject_rdms[sid][~np.isnan(subject_rdms[sid])] for sid in sample_subjects]
all_rdm_values = np.concatenate(sample_values) if sample_values else np.array([])
has_values = all_rdm_values.size > 0
# 1st/99th percentiles keep a few extreme cells from flattening the colour range;
# fall back to [0, 2] when there is nothing to plot.
vmin = np.percentile(all_rdm_values, 1) if has_values else 0
vmax = np.percentile(all_rdm_values, 99) if has_values else 2

# plt.get_cmap replaces plt.cm.get_cmap (deprecated in Matplotlib 3.7, removed in 3.9).
# Copy so the registered colormap is not modified; built once for all panels.
cmap = plt.get_cmap('viridis').copy()
cmap.set_bad(color='white', alpha=1.0)  # masked (NA) cells are drawn in white

fig, axes = plt.subplots(2, 3, figsize=(18, 12))
axes = axes.flatten()

for idx, subject_id in enumerate(sample_subjects):
    rdm = subject_rdms[subject_id]
    mask = subject_rdm_masks[subject_id]
    available_cats = subject_rdm_categories[subject_id]

    ax = axes[idx]

    # Mask NA cells so they are rendered with the colormap's "bad" colour
    rdm_masked = np.ma.masked_where(mask, rdm)
    im = ax.imshow(rdm_masked, cmap=cmap, aspect='auto', vmin=vmin, vmax=vmax)

    ax.set_title(f"{subject_id}\n({len(available_cats)}/{len(ordered_categories)} cats)", fontsize=10)
    ax.set_xlabel('Category')
    ax.set_ylabel('Category')
    fig.colorbar(im, ax=ax)

# Hide panels that have no subject (fewer than six subjects available)
for unused_ax in axes[n_samples:]:
    unused_ax.axis('off')

fig.tight_layout()
fig.savefig(output_dir / "sample_rdms.png", dpi=150, bbox_inches='tight')
print(f"Saved sample RDM visualization to {output_dir / 'sample_rdms.png'}")
plt.close(fig)

## Data Density Analysis


In [None]:
# Two-panel data-density figure: categories per subject, and mean distance vs. coverage.
fig, (hist_ax, scatter_ax) = plt.subplots(1, 2, figsize=(12, 5))

# Left panel: distribution of the number of categories available per subject
categories_per_subject = list(map(len, subject_rdm_categories.values()))
hist_ax.hist(categories_per_subject, bins=20, edgecolor='black')
hist_ax.set_xlabel('Number of Categories per Subject')
hist_ax.set_ylabel('Number of Subjects')
hist_ax.set_title('Data Density: Categories per Subject')
threshold_label = f'Min threshold ({min_categories_per_subject})'
hist_ax.axvline(min_categories_per_subject, color='red', linestyle='--', label=threshold_label)
hist_ax.legend()

# Right panel: NaN-aware mean RDM value against the subject's category count
mean_distances = [np.nanmean(subject_rdm) for subject_rdm in subject_rdms.values()]
n_categories = [len(subject_rdm_categories[sid]) for sid in subject_rdms]
scatter_ax.scatter(n_categories, mean_distances, alpha=0.6)
scatter_ax.set_xlabel('Number of Categories')
scatter_ax.set_ylabel('Mean RDM Distance')
scatter_ax.set_title('RDM Distance vs Data Density')

fig.tight_layout()
fig.savefig(output_dir / "data_density_analysis.png", dpi=150, bbox_inches='tight')
print(f"Saved data density analysis to {output_dir / 'data_density_analysis.png'}")
plt.close(fig)

## Category Intersection Analysis

This section analyzes which categories are shared across subjects and helps determine which subjects and categories to include for intersection-based analysis.


In [None]:
# Analyze category intersections across subjects: for every n, count the categories
# that are available for at least n subjects; plot, print and save the result.
print("Analyzing category intersections...")

# Get all subject category sets
subject_category_sets = {sid: set(cats) for sid, cats in subject_rdm_categories.items()}

# Number of subjects in which each category occurs (computed once)
category_counts = {}
for cat_set in subject_category_sets.values():
    for cat in cat_set:
        category_counts[cat] = category_counts.get(cat, 0) + 1

# For each n, find categories that appear in at least n subjects
n_subjects_total = len(subject_category_sets)
intersection_analysis = []

for n in range(1, n_subjects_total + 1):
    # Sorted so the saved list does not depend on set iteration order
    # (key=str keeps the sort well-defined even for non-string labels).
    intersecting_cats = sorted((cat for cat, count in category_counts.items() if count >= n), key=str)
    intersection_analysis.append({
        'n_subjects': n,
        'intersection_size': len(intersecting_cats),
        'categories': intersecting_cats
    })

# Explicit columns keep the frame usable even when there are no subjects
intersection_df = pd.DataFrame(intersection_analysis,
                               columns=['n_subjects', 'intersection_size', 'categories'])

# Direct lookup: threshold -> number of categories shared by at least that many subjects
size_by_n = {entry['n_subjects']: entry['intersection_size'] for entry in intersection_analysis}

print(f"\nIntersection Analysis:")
print(f"  Total subjects: {n_subjects_total}")
print(f"\nIntersection sizes by number of subjects:")
for entry in intersection_analysis:
    print(f"  {entry['n_subjects']:2d} subjects: {entry['intersection_size']:3d} categories")

# Plot: Number of subjects vs intersection size
fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(intersection_df['n_subjects'], intersection_df['intersection_size'],
        marker='o', linewidth=2, markersize=8)
ax.set_xlabel('Number of Subjects (Minimum)', fontsize=12)
ax.set_ylabel('Intersection Size (Number of Categories)', fontsize=12)
ax.set_title('Category Intersection Analysis\n(Number of Categories Shared by at Least N Subjects)', fontsize=14)
ax.grid(True, alpha=0.3)
ax.set_xticks(range(1, n_subjects_total + 1))

# Annotate first, middle and last point. With very few subjects the middle index
# can be 0 (not a valid threshold) or coincide with an end point, so keep only
# valid, distinct values.
annotation_points = sorted({n for n in (1, n_subjects_total // 2, n_subjects_total)
                            if 1 <= n <= n_subjects_total})
for n in annotation_points:
    ax.annotate(f"{size_by_n[n]} cats",
                xy=(n, size_by_n[n]),
                xytext=(5, 5), textcoords='offset points',
                fontsize=9, bbox=dict(boxstyle='round,pad=0.3', facecolor='yellow', alpha=0.5))

fig.tight_layout()
fig.savefig(output_dir / "category_intersection_analysis.png", dpi=150, bbox_inches='tight')
print(f"\nSaved intersection analysis plot to {output_dir / 'category_intersection_analysis.png'}")
plt.close(fig)

# Save intersection data
intersection_df.to_csv(csv_dir / "category_intersection_analysis.csv", index=False)
print(f"Saved intersection data to {csv_dir / 'category_intersection_analysis.csv'}")

# Display intersection sizes for different thresholds
print("\n" + "="*60)
print("Recommended thresholds for intersection-based analysis:")
print("="*60)
# 100%, 90%, 80%, 70% and 50% of the subjects. Truncation can produce 0 or
# repeated values for small cohorts; drop those while keeping descending order.
candidate_thresholds = [n_subjects_total] + [int(n_subjects_total * frac) for frac in (0.9, 0.8, 0.7, 0.5)]
for threshold in dict.fromkeys(candidate_thresholds):
    if 1 <= threshold <= n_subjects_total:
        print(f"  At least {threshold:2d} subjects ({threshold/n_subjects_total*100:.1f}%): {size_by_n[threshold]:3d} categories")


## RDMs with Intersecting Categories Only

This section computes and visualizes RDMs using only the categories that are shared across all (or a specified subset of) subjects, maintaining consistent category ordering.


In [None]:
# Configuration for intersection-based RDMs
# Set minimum number of subjects that must have a category for it to be included
# Can be a single value or a list of values to iterate through
MIN_SUBJECTS_FOR_CATEGORY = list(range(20, n_subjects_total + 1))  # 20+ subjects
# Examples:
# MIN_SUBJECTS_FOR_CATEGORY = [20, 24, 27, 31]  # Specific thresholds
# MIN_SUBJECTS_FOR_CATEGORY = n_subjects_total  # Single value (all subjects)
# MIN_SUBJECTS_FOR_CATEGORY = list(range(20, n_subjects_total + 1))  # All thresholds from 20 to total

# Normalise to a list of ints. Tuples, ranges and arrays are accepted as containers
# too; previously anything that was not a list was wrapped as one single "threshold".
if isinstance(MIN_SUBJECTS_FOR_CATEGORY, (list, tuple, range, np.ndarray)):
    MIN_SUBJECTS_FOR_CATEGORY = [int(threshold) for threshold in MIN_SUBJECTS_FOR_CATEGORY]
else:
    MIN_SUBJECTS_FOR_CATEGORY = [int(MIN_SUBJECTS_FOR_CATEGORY)]

print(f"Computing intersection-based RDMs for {len(MIN_SUBJECTS_FOR_CATEGORY)} threshold(s)...")
print(f"  Thresholds: {MIN_SUBJECTS_FOR_CATEGORY}")

# Store RDMs for each threshold
all_intersection_rdms = {}  # {threshold: {subject_id: rdm}}
all_intersection_masks = {}  # {threshold: {subject_id: mask}}
all_intersection_categories = {}  # {threshold: {subject_id: available_cats}}
all_intersection_category_lists = {}  # {threshold: intersecting_categories_ordered}

# Lookup tables built once so the loops below use O(1) membership tests
# instead of scanning lists / filtering the DataFrame on every iteration.
if intersection_df.empty:
    categories_by_threshold = {}
else:
    categories_by_threshold = dict(zip(intersection_df['n_subjects'], intersection_df['categories']))
subject_category_lookup = {sid: set(subject_rdm_categories[sid]) for sid in subject_rdms}

# Iterate through each threshold
for min_subjects in MIN_SUBJECTS_FOR_CATEGORY:
    print(f"\n{'='*60}")
    print(f"Processing threshold: {min_subjects} subjects")
    print(f"{'='*60}")

    # Get the intersecting categories for this threshold
    if min_subjects not in categories_by_threshold:
        print(f"  Warning: Threshold {min_subjects} not found in intersection_df. Skipping.")
        continue

    intersecting_categories = categories_by_threshold[min_subjects]

    print(f"  Intersecting categories: {len(intersecting_categories)}")

    # Filter ordered_categories to only include intersecting categories (maintain order)
    intersecting_set = set(intersecting_categories)
    intersecting_categories_ordered = [cat for cat in ordered_categories if cat in intersecting_set]

    print(f"  Intersecting categories (ordered): {len(intersecting_categories_ordered)}")

    # Compute RDMs using only intersecting categories
    subject_rdms_intersection = {}
    subject_rdm_masks_intersection = {}
    subject_rdm_categories_intersection = {}

    for subject_id in tqdm(subject_rdms.keys(), desc=f"Computing RDMs (threshold={min_subjects})"):
        # Categories of the intersection that this subject actually has
        subject_cats = subject_category_lookup[subject_id]
        available_cats = [cat for cat in intersecting_categories_ordered if cat in subject_cats]

        # Only include subjects with ALL intersecting categories
        if len(available_cats) < len(intersecting_categories_ordered):
            continue

        # Get embeddings for intersecting categories
        subject_embeddings = subject_embeddings_normalized[subject_id]
        intersection_embeddings = {cat: subject_embeddings[cat] for cat in available_cats}

        # Subjects reaching this point have every intersecting category, so the
        # call is made with a complete set of embeddings for the category list.
        rdm, mask, _ = compute_subject_rdm_with_na(intersection_embeddings, intersecting_categories_ordered)

        if rdm is not None:
            subject_rdms_intersection[subject_id] = rdm
            subject_rdm_masks_intersection[subject_id] = mask
            subject_rdm_categories_intersection[subject_id] = available_cats

    print(f"  Computed intersection-based RDMs for {len(subject_rdms_intersection)} subjects")
    print(f"    (Excluded {len(subject_rdms) - len(subject_rdms_intersection)} subjects without all intersecting categories)")

    # Check which subjects were excluded
    excluded_subjects = set(subject_rdms.keys()) - set(subject_rdms_intersection.keys())
    if excluded_subjects:
        print(f"    Excluded subjects: {sorted(excluded_subjects)}")

    # Store results for this threshold
    all_intersection_rdms[min_subjects] = subject_rdms_intersection
    all_intersection_masks[min_subjects] = subject_rdm_masks_intersection
    all_intersection_categories[min_subjects] = subject_rdm_categories_intersection
    all_intersection_category_lists[min_subjects] = intersecting_categories_ordered

print(f"\n{'='*60}")
print(f"Completed processing {len(all_intersection_rdms)} threshold(s)")
print(f"{'='*60}")


In [None]:
# Visualize and save all intersection-based RDMs for each threshold.
# For every threshold this writes: a viridis grid, a coolwarm grid, one PNG per
# subject, and per-subject .npy / .csv / metadata files.


def _save_intersection_rdm_grid(rdms, masks, cats_by_subject, category_labels,
                                cmap_name, vmin, vmax, title, output_path, n_cols=6):
    """Draw one small heatmap per subject in a grid and save the figure.

    Args:
        rdms: dict mapping subject_id to its RDM array.
        masks: dict mapping subject_id to the mask of cells to render as NA.
        cats_by_subject: dict mapping subject_id to its available categories
            (only the count is shown in the panel title).
        category_labels: category names used as tick labels on both axes.
        cmap_name: name of the Matplotlib colormap.
        vmin, vmax: shared colour-scale limits.
        title: figure-level title.
        output_path: destination of the saved PNG.
        n_cols: number of grid columns.
    """
    n_subjects = len(rdms)
    n_rows = int(np.ceil(n_subjects / n_cols))

    # squeeze=False always yields a 2-D axes array, so flattening works for any grid
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(n_cols * 3, n_rows * 3), squeeze=False)
    axes = axes.ravel()

    # plt.get_cmap replaces plt.cm.get_cmap (deprecated in Matplotlib 3.7, removed in 3.9)
    cmap = plt.get_cmap(cmap_name).copy()
    cmap.set_bad(color='white', alpha=1.0)  # masked (NA) cells are drawn in white

    tick_positions = range(len(category_labels))
    for ax, (subject_id, rdm) in zip(axes, rdms.items()):
        rdm_masked = np.ma.masked_where(masks[subject_id], rdm)
        im = ax.imshow(rdm_masked, cmap=cmap, aspect='auto', vmin=vmin, vmax=vmax)

        ax.set_title(f"{subject_id}\n({len(cats_by_subject[subject_id])}/{len(category_labels)} cats)",
                     fontsize=9, pad=5)
        ax.set_xlabel('Category', fontsize=7)
        ax.set_ylabel('Category', fontsize=7)
        ax.tick_params(labelsize=6)

        # Set ticks and labels to show all intersecting categories
        ax.set_xticks(tick_positions)
        ax.set_yticks(tick_positions)
        ax.set_xticklabels(category_labels, rotation=90, ha='right', fontsize=4)
        ax.set_yticklabels(category_labels, fontsize=4)

        # Add colorbar for each subplot
        fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)

    # Hide unused subplots
    for unused_ax in axes[n_subjects:]:
        unused_ax.axis('off')

    fig.suptitle(title, fontsize=16, y=0.995)
    fig.tight_layout(rect=[0, 0, 1, 0.99])
    fig.savefig(output_path, dpi=200, bbox_inches='tight')
    plt.close(fig)


# Iterate through each threshold
for min_subjects in sorted(all_intersection_rdms.keys()):
    print(f"\n{'='*60}")
    print(f"Visualizing RDMs for threshold: {min_subjects} subjects")
    print(f"{'='*60}")

    subject_rdms_intersection = all_intersection_rdms[min_subjects]
    subject_rdm_masks_intersection = all_intersection_masks[min_subjects]
    subject_rdm_categories_intersection = all_intersection_categories[min_subjects]
    intersecting_categories_ordered = all_intersection_category_lists[min_subjects]

    n_subjects_intersection = len(subject_rdms_intersection)

    if n_subjects_intersection == 0:
        print(f"No subjects have sufficient intersecting categories for threshold {min_subjects}.")
        continue

    # Shared colour scale for this threshold: 1st/99th percentile of all non-NaN values
    all_rdm_values = np.concatenate([rdm[~np.isnan(rdm)] for rdm in subject_rdms_intersection.values()])
    vmin = np.percentile(all_rdm_values, 1) if all_rdm_values.size > 0 else 0
    vmax = np.percentile(all_rdm_values, 99) if all_rdm_values.size > 0 else 2

    print(f"Plotting {n_subjects_intersection} intersection-based RDMs...")
    print(f"Color scale range: [{vmin:.4f}, {vmax:.4f}]")
    print(f"Categories in intersection: {len(intersecting_categories_ordered)}")

    # Grid of all subjects (viridis)
    output_filename = f"all_individual_rdms_intersection_threshold_{min_subjects}.png"
    _save_intersection_rdm_grid(
        subject_rdms_intersection, subject_rdm_masks_intersection,
        subject_rdm_categories_intersection, intersecting_categories_ordered,
        'viridis', vmin, vmax,
        f'Intersection-Based RDMs (n={n_subjects_intersection}, {len(intersecting_categories_ordered)} shared categories, threshold={min_subjects})',
        output_dir / output_filename)
    print(f"\nSaved intersection-based RDM visualization to {output_dir / output_filename}")

    # Same grid with the coolwarm colormap
    output_filename = f"all_individual_rdms_intersection_threshold_{min_subjects}_coolwarm.png"
    _save_intersection_rdm_grid(
        subject_rdms_intersection, subject_rdm_masks_intersection,
        subject_rdm_categories_intersection, intersecting_categories_ordered,
        'coolwarm', vmin, vmax,
        f'Intersection-Based RDMs (n={n_subjects_intersection}, threshold={min_subjects}) - Coolwarm',
        output_dir / output_filename)
    print(f"Saved intersection-based RDM visualization (coolwarm) to {output_dir / output_filename}")

    # Save individual intersection-based RDMs
    print(f"\nSaving individual intersection-based RDM plots for threshold {min_subjects}...")
    intersection_rdm_dir = output_dir / f"individual_rdm_plots_intersection_threshold_{min_subjects}"
    intersection_rdm_dir.mkdir(exist_ok=True, parents=True)

    # Figure and font sizes depend only on the category count, so set them once per threshold
    n_cats = len(intersecting_categories_ordered)
    fig_size = max(10, n_cats * 0.3)
    if n_cats <= 50:
        label_fontsize = 12
        tick_fontsize = 20
    elif n_cats <= 100:
        label_fontsize = 10
        tick_fontsize = 18
    else:
        label_fontsize = 8
        tick_fontsize = 16

    cmap = plt.get_cmap('viridis').copy()
    cmap.set_bad(color='white', alpha=1.0)  # masked (NA) cells are drawn in white

    for subject_id in tqdm(subject_rdms_intersection.keys(), desc=f"Saving intersection RDMs (threshold={min_subjects})"):
        rdm = subject_rdms_intersection[subject_id]
        mask = subject_rdm_masks_intersection[subject_id]
        available_cats = subject_rdm_categories_intersection[subject_id]

        fig, ax = plt.subplots(figsize=(fig_size, fig_size))

        rdm_masked = np.ma.masked_where(mask, rdm)
        im = ax.imshow(rdm_masked, cmap=cmap, aspect='auto', vmin=vmin, vmax=vmax)

        ax.set_xticks(range(n_cats))
        ax.set_yticks(range(n_cats))
        ax.set_xticklabels(intersecting_categories_ordered, rotation=90, ha='right', fontsize=tick_fontsize)
        ax.set_yticklabels(intersecting_categories_ordered, fontsize=tick_fontsize)

        ax.set_xlabel('Category', fontsize=label_fontsize)
        ax.set_ylabel('Category', fontsize=label_fontsize)
        ax.set_title(f'RDM (Intersection, threshold={min_subjects}): {subject_id}\n({len(available_cats)}/{n_cats} categories)',
                     fontsize=14, pad=10)

        cbar = fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
        cbar.set_label('Distance (1 - Cosine Similarity)', fontsize=12)

        fig.tight_layout()

        output_path_png = intersection_rdm_dir / f"rdm_intersection_threshold_{min_subjects}_{subject_id}.png"
        fig.savefig(output_path_png, dpi=200, bbox_inches='tight')
        plt.close(fig)

    print(f"\nSaved {len(subject_rdms_intersection)} intersection-based RDM plots to {intersection_rdm_dir}")

    # Save intersection-based RDMs as files
    # Get number of overlapping categories for this threshold
    n_overlapping_cats = len(intersecting_categories_ordered)
    print(f"\nSaving intersection-based RDMs to files for threshold {min_subjects} (n_categories={n_overlapping_cats})...")
    for subject_id, rdm in tqdm(subject_rdms_intersection.items(), desc=f"Saving RDMs (threshold={min_subjects})"):
        available_cats = subject_rdm_categories_intersection[subject_id]

        # Save as numpy array - include threshold and number of overlapping categories in filename
        np.save(npy_dir / f"rdm_intersection_threshold_{min_subjects}_ncat{n_overlapping_cats}_{subject_id}.npy", rdm)

        # Save as CSV - include threshold and number of overlapping categories in filename
        rdm_df = pd.DataFrame(rdm, index=intersecting_categories_ordered, columns=intersecting_categories_ordered)
        rdm_df.to_csv(csv_dir / f"rdm_intersection_threshold_{min_subjects}_ncat{n_overlapping_cats}_{subject_id}.csv")

        # Statistics for the metadata row; the minimum is taken over strictly
        # positive, non-NaN entries and is NaN when there are none.
        valid_rdm = rdm[~np.isnan(rdm)]
        valid_rdm_positive = valid_rdm[valid_rdm > 0]

        metadata = {
            'subject_id': subject_id,
            'threshold_min_subjects': min_subjects,
            'n_categories_total': len(intersecting_categories_ordered),
            'n_categories_available': len(available_cats),
            'n_categories_missing': len(intersecting_categories_ordered) - len(available_cats),
            'mean_distance': float(np.nanmean(rdm)),
            'std_distance': float(np.nanstd(rdm)),
            'min_distance': float(valid_rdm_positive.min()) if len(valid_rdm_positive) > 0 else np.nan,
            'max_distance': float(np.nanmax(rdm))
        }

        # Save metadata - include threshold and number of overlapping categories in filename
        metadata_df = pd.DataFrame([metadata])
        metadata_df.to_csv(csv_dir / f"metadata_intersection_threshold_{min_subjects}_ncat{n_overlapping_cats}_{subject_id}.csv", index=False)

    print(f"Saved intersection-based RDMs (threshold {min_subjects}, n_categories={n_overlapping_cats}) to {output_dir}")
    print(f"  CSV files: {csv_dir}")
    print(f"  NPY files: {npy_dir}")

print(f"\n{'='*60}")
print(f"Completed visualization for all thresholds")
print(f"{'='*60}")
