# Individual Subject RDM Analysis

This notebook creates Representational Dissimilarity Matrices (RDMs) for each individual subject.
Each subject's RDM shows the similarity structure of object categories based on their averaged embeddings.

## Overview

This analysis:
1. Loads grouped embeddings (averaged by category, subject, and age_mo)
2. Aggregates embeddings per subject across all age_mo (weighted average if multiple age bins)
3. Computes RDM for each subject using cosine distance
4. Handles data density differences between subjects
5. Visualizes and saves individual subject RDMs

## Key Features

- **Data density handling**: Subjects with fewer than `min_categories_per_subject` categories are skipped; RDMs from subjects with more data are more reliable
- **Missing category handling**: Only includes categories present for each subject
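- **Weighted averaging**: If a subject has multiple age_mo bins, embeddings are averaged with weights given by each bin's sample count (the loader currently assigns every file a count of 1, so bins are weighted equally)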


## Setup and Imports


In [25]:
import os
import warnings
from collections import defaultdict
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from scipy.cluster.hierarchy import linkage, dendrogram, optimal_leaf_ordering
from scipy.spatial.distance import squareform
from sklearn.metrics.pairwise import cosine_similarity, cosine_distances
from tqdm import tqdm
warnings.filterwarnings('ignore')

# Set matplotlib backend
import matplotlib
matplotlib.use('Agg')

print("All imports successful!")


All imports successful!


## Configuration


In [26]:
# CDI words CSV; its is_* flag columns decide which group each category joins.
# Later code falls back to a single 'others' group when this file is missing.
cdi_path: Path = Path("../../data/cdi_words.csv")

# Hierarchical clustering switches
save_dendrograms: bool = True  # Write one dendrogram plot per category group
use_clustering: bool = True  # Reorder categories inside each group by clustering

print(f"CDI path: {cdi_path}")
print(f"Use clustering: {use_clustering}")

CDI path: ../../data/cdi_words.csv
Use clustering: True


In [27]:
def load_category_types(cdi_path):
    """Load per-category type flags from the CDI words CSV.

    Args:
        cdi_path: Path to a CSV with a 'uni_lemma' column and, optionally, the
            flag columns 'is_animate', 'is_bodypart', 'is_small' and 'is_big'.

    Returns:
        dict mapping each uni_lemma to a dict of the four flags as booleans.
        A missing column or an empty cell counts as False. If a uni_lemma
        appears on several rows, the last row wins.
    """
    print(f"Loading category types from {cdi_path}...")
    cdi_df = pd.read_csv(cdi_path)

    flag_columns = ('is_animate', 'is_bodypart', 'is_small', 'is_big')

    def _as_flag(value):
        # Empty CSV cells are read as NaN, and bool(NaN) is True in Python, so
        # missing values must be mapped to False explicitly.
        return bool(value) if pd.notna(value) else False

    category_types = {}
    for _, row in cdi_df.iterrows():
        category_types[row['uni_lemma']] = {
            column: _as_flag(row.get(column, 0)) for column in flag_columns
        }

    print(f"Loaded type information for {len(category_types)} categories")
    return category_types

def cluster_categories_within_group(group_categories, cat_to_embedding, save_dendrogram=False, output_dir=None, group_name=None):
    """
    Perform hierarchical clustering within a group of categories.

    Args:
        group_categories: List of category names in the group
        cat_to_embedding: Dictionary mapping category names to embeddings
        save_dendrogram: Whether to save dendrogram plot (default: False)
        output_dir: Output directory for saving dendrogram (required if save_dendrogram=True)
        group_name: Name of the group for saving dendrogram (required if save_dendrogram=True)

    Returns:
        Tuple (clustered_categories, linkage_matrix):
            clustered_categories: category names reordered to follow the dendrogram leaves
            linkage_matrix: the linkage matrix, or None when the group has fewer
                than two categories (the input list is then returned unchanged)
    """
    if len(group_categories) <= 1:
        return group_categories, None

    # Get embeddings for this group, one flattened row per category
    group_embeddings = np.array([cat_to_embedding[cat].flatten() for cat in group_categories])

    # z-score every embedding dimension across the categories of this group
    normalized_embeddings = (group_embeddings - group_embeddings.mean(axis=0)) / (group_embeddings.std(axis=0) + 1e-10)

    # Compute distance matrix (1 - cosine similarity)
    similarity_matrix = cosine_similarity(normalized_embeddings)
    distance_matrix = 1 - similarity_matrix
    # squareform() rejects input that is not exactly symmetric with a zero
    # diagonal, so remove floating-point asymmetry before converting.
    distance_matrix = (distance_matrix + distance_matrix.T) / 2
    np.fill_diagonal(distance_matrix, 0)

    # Convert to condensed form for linkage
    condensed_distances = squareform(distance_matrix)

    # Perform hierarchical clustering
    # NOTE(review): SciPy documents 'ward' as correctly defined only for Euclidean
    # distances; it is applied to cosine distances here. Kept as-is so results do
    # not change, but confirm this is intended.
    linkage_matrix = linkage(condensed_distances, method='ward')

    # Get optimal leaf ordering for better visualization (best effort)
    try:
        linkage_matrix = optimal_leaf_ordering(linkage_matrix, condensed_distances)
    except Exception as exc:
        # Keep the unoptimized linkage, but say so instead of failing silently
        print(f"    Warning: optimal leaf ordering failed ({exc}); using default leaf order")

    # Extract the order from the dendrogram
    dendro_dict = dendrogram(linkage_matrix, no_plot=True)
    leaf_order = dendro_dict['leaves']

    # Reorder categories according to clustering
    clustered_categories = [group_categories[i] for i in leaf_order]

    # Save dendrogram if requested
    if save_dendrogram and output_dir is not None and group_name is not None:
        output_dir = Path(output_dir)
        output_dir.mkdir(exist_ok=True, parents=True)

        fig = plt.figure(figsize=(12, 8))
        dendrogram(linkage_matrix,
                  labels=group_categories,
                  leaf_rotation=90,
                  leaf_font_size=10)
        # A real newline here: the escaped form printed a literal backslash-n in the title
        plt.title(f'Hierarchical Clustering Dendrogram: {group_name.upper()}\n({len(group_categories)} categories)',
                 fontsize=16, pad=20)
        plt.xlabel('Category', fontsize=12)
        plt.ylabel('Distance', fontsize=12)
        plt.tight_layout()

        # Save as PNG
        output_path_png = output_dir / f'dendrogram_{group_name}.png'
        plt.savefig(output_path_png, dpi=300, bbox_inches='tight', pad_inches=0.2)
        print(f"    Saved dendrogram to {output_path_png}")

        # Save as PDF
        output_path_pdf = output_dir / f'dendrogram_{group_name}.pdf'
        plt.savefig(output_path_pdf, bbox_inches='tight', pad_inches=0.2)
        print(f"    Saved dendrogram to {output_path_pdf}")

        plt.close(fig)

    return clustered_categories, linkage_matrix

print("Helper functions loaded!")

Helper functions loaded!


In [28]:
# Paths
# The cluster location stays the default; set BABYVIEW_EMBEDDINGS_DIR to point the
# notebook at another copy of the grouped embeddings without editing this cell.
embeddings_dir = Path(os.environ.get(
    "BABYVIEW_EMBEDDINGS_DIR",
    "/data2/dataset/babyview/868_hours/outputs/yoloe_cdi_embeddings/clip_embeddings_grouped_by_age-mo",
))
output_dir = Path("individual_subject_rdms")
output_dir.mkdir(exist_ok=True, parents=True)

# Categories file (optional - to filter to specific categories)
categories_file = Path("../../data/things_bv_overlap_categories_exclude_zero_precisions.txt")

# Minimum categories required per subject to compute RDM
min_categories_per_subject = 10

# Whether to weight by sample count when aggregating across age_mo
weight_by_sample_count = True

print(f"Embeddings directory: {embeddings_dir}")
print(f"Output directory: {output_dir}")
print(f"Min categories per subject: {min_categories_per_subject}")


Embeddings directory: /data2/dataset/babyview/868_hours/outputs/yoloe_cdi_embeddings/clip_embeddings_grouped_by_age-mo
Output directory: individual_subject_rdms
Min categories per subject: 10


## Load Category List (Optional)


In [29]:
# Optional whitelist of category names: one name per line, blank lines ignored.
# allowed_categories stays None when the file is absent, which means "keep everything".
if not categories_file.exists():
    allowed_categories = None
    print("Categories file not found, using all categories")
else:
    print(f"Loading categories from {categories_file}...")
    with open(categories_file, 'r') as f:
        allowed_categories = {entry for entry in map(str.strip, f) if entry}
    print(f"Loaded {len(allowed_categories)} categories")


Loading categories from ../../data/things_bv_overlap_categories_exclude_zero_precisions.txt...
Loaded 163 categories


## Load Embeddings


In [30]:
def load_subject_embeddings(embeddings_dir, allowed_categories=None):
    """
    Load embeddings organized by subject and category.

    Each sub-folder of embeddings_dir is treated as one category. Files inside
    are expected to be named {subject_id}_{age_mo}_... .npy (for example
    S00560001_16_month_level_avg.npy); files whose second underscore-separated
    field is not an integer are skipped.

    Folders and files are visited in sorted order so that the resulting subject
    order does not depend on the filesystem's listing order.

    Args:
        embeddings_dir: Directory (Path or str) containing one folder per category
        allowed_categories: Optional collection of category names to keep;
            None or an empty collection keeps every folder

    Returns:
        subject_embeddings: dict[subject_id][category][(category, age_mo)] = {
            'embedding': np.array,  # loaded (or running-average) embedding
            'age_mo': int,  # age in months
            'sample_count': int  # number of files averaged into 'embedding'
        }
    """
    embeddings_dir = Path(embeddings_dir)
    subject_embeddings = defaultdict(lambda: defaultdict(dict))

    # Get all category folders (sorted for a reproducible traversal order)
    category_folders = sorted(
        (f for f in embeddings_dir.iterdir() if f.is_dir()),
        key=lambda folder: folder.name,
    )

    if allowed_categories:
        category_folders = [f for f in category_folders if f.name in allowed_categories]

    print(f"Loading embeddings from {len(category_folders)} categories...")

    for category_folder in tqdm(category_folders, desc="Loading categories"):
        category = category_folder.name

        for emb_file in sorted(category_folder.glob("*.npy")):
            # Parse filename: {subject_id}_{age_mo}_month_level_avg.npy
            parts = emb_file.stem.split('_')
            if len(parts) < 2 or not parts[1].isdigit():
                continue

            subject_id = parts[0]
            age_mo = int(parts[1])

            try:
                embedding = np.load(emb_file)

                key = (category, age_mo)
                bins = subject_embeddings[subject_id][category]
                if key not in bins:
                    bins[key] = {
                        'embedding': embedding,
                        'age_mo': age_mo,
                        'sample_count': 1  # Per-file sample counts are not available; assume 1
                    }
                else:
                    # Several files map to the same (category, age_mo): keep a running mean
                    existing = bins[key]
                    n_existing = existing['sample_count']
                    existing['embedding'] = (existing['embedding'] * n_existing + embedding) / (n_existing + 1)
                    existing['sample_count'] += 1
            except Exception as e:
                # Best effort: report the unreadable/incompatible file and keep going
                print(f"Error loading {emb_file}: {e}")
                continue

    return subject_embeddings

# Load embeddings
# The result is nested as subject_id -> category -> (category, age_mo) -> record,
# where each record holds 'embedding', 'age_mo' and 'sample_count'.
subject_embeddings_raw = load_subject_embeddings(embeddings_dir, allowed_categories)
print(f"\nLoaded embeddings for {len(subject_embeddings_raw)} subjects")


Loading embeddings from 163 categories...


Loading categories: 100%|██████████| 163/163 [00:01<00:00, 121.29it/s]


Loaded embeddings for 32 subjects





## Aggregate Embeddings Per Subject


In [31]:
def aggregate_subject_embeddings(subject_embeddings_raw, weight_by_sample_count=True):
    """
    Collapse each subject's per-age_mo embeddings into one embedding per category.

    Args:
        subject_embeddings_raw: dict[subject_id][category][(category, age_mo)] of
            records carrying 'embedding' and 'sample_count'
        weight_by_sample_count: weight every age_mo bin by its 'sample_count'
            when True, otherwise weight all bins equally

    Returns:
        dict[subject_id][category] = averaged embedding; categories without any
        bin are left out.
    """
    subject_embeddings_agg = {}

    for subject_id, per_category in tqdm(subject_embeddings_raw.items(), desc="Aggregating subjects"):
        merged = {}

        for category, bins in per_category.items():
            vectors, weights = [], []
            for (_cat, _age_mo), record in bins.items():
                vectors.append(record['embedding'])
                weights.append(record['sample_count'] if weight_by_sample_count else 1.0)

            # Nothing recorded for this subject/category pair
            if not vectors:
                continue

            # Weights are rescaled to sum to one before averaging over the bins
            weight_array = np.array(weights)
            weight_array = weight_array / weight_array.sum()
            merged[category] = np.average(np.array(vectors), axis=0, weights=weight_array)

        subject_embeddings_agg[subject_id] = merged

    return subject_embeddings_agg

# Collapse the age_mo bins: one embedding per (subject, category)
subject_embeddings = aggregate_subject_embeddings(
    subject_embeddings_raw, weight_by_sample_count=weight_by_sample_count
)
print(f"\nAggregated embeddings for {len(subject_embeddings)} subjects")

# Summarize how many categories each subject contributes
category_counts = {sid: len(cats) for sid, cats in subject_embeddings.items()}
per_subject_counts = list(category_counts.values())
print(f"\nCategory counts per subject:")
print(f"  Min: {min(per_subject_counts)}")
print(f"  Max: {max(per_subject_counts)}")
print(f"  Mean: {np.mean(per_subject_counts):.1f}")
print(f"  Median: {np.median(per_subject_counts):.1f}")


Aggregating subjects: 100%|██████████| 32/32 [00:00<00:00, 235.56it/s]


Aggregated embeddings for 32 subjects

Category counts per subject:
  Min: 55
  Max: 162
  Mean: 145.3
  Median: 154.5





## Normalize Embeddings

Before computing RDMs, we z-score each embedding dimension (mean=0, std=1) using the mean and standard deviation pooled over all subjects and categories, so that every subject's embeddings are on a common scale.


In [32]:
# NOTE: Category organization code has been moved to after normalization.
# This cell is kept for reference but is no longer executed.
# See the normalization cell for the organization code.

In [33]:
# Global normalization: normalize across ALL embeddings from ALL subjects
print("Computing global normalization statistics across all subjects...")

# Collect all embeddings from all subjects
all_embeddings_list = []
for subject_id, categories in subject_embeddings.items():
    for cat, embedding in categories.items():
        all_embeddings_list.append(embedding)

# Stack the embeddings as flat 1D rows. Stacking the raw arrays directly can fail
# when shapes are mixed (e.g. (512,) and (1, 512)), so flatten first.
all_embeddings_matrix_flat = np.array([emb.flatten() for emb in all_embeddings_list])
all_embeddings_matrix = all_embeddings_matrix_flat  # kept under its original name
print(f"  Collected {len(all_embeddings_list)} embeddings from {len(subject_embeddings)} subjects")

# Per-dimension mean and std across every (subject, category) embedding
global_mean = all_embeddings_matrix_flat.mean(axis=0)
global_std = all_embeddings_matrix_flat.std(axis=0) + 1e-10  # Add small epsilon to avoid division by zero

print(f"  Global mean shape: {global_mean.shape}")
print(f"  Global std shape: {global_std.shape}")
print(f"  Global mean range: [{global_mean.min():.4f}, {global_mean.max():.4f}]")
print(f"  Global std range: [{global_std.min():.4f}, {global_std.max():.4f}]")

# Apply global normalization to each subject's embeddings
print("\nApplying global normalization to each subject...")
subject_embeddings_normalized = {}

for subject_id, categories in tqdm(subject_embeddings.items(), desc="Normalizing"):
    subject_embeddings_normalized[subject_id] = {}

    for cat, embedding in categories.items():
        # Flatten BEFORE subtracting so the vector lines up element-wise with the
        # statistics (computed on flattened embeddings) instead of relying on
        # broadcasting of shapes such as (1, 512).
        normalized_embedding = (embedding.flatten() - global_mean) / global_std
        subject_embeddings_normalized[subject_id][cat] = normalized_embedding

print(f"Normalized embeddings for {len(subject_embeddings_normalized)} subjects using global statistics")

## Organize Categories and Apply Hierarchical Clustering

# CDI type flags drive the grouping; without the CSV every category lands in 'others'
if not cdi_path.exists():
    print(f"Warning: CDI path {cdi_path} not found. Skipping category organization.")
    category_types = {}
else:
    category_types = load_category_types(cdi_path)

# Union of the category names seen for any subject, in alphabetical order
all_categories = sorted(
    set().union(*(cats.keys() for cats in subject_embeddings_normalized.values()))
)
print(f"Total unique categories across all subjects: {len(all_categories)}")

print("\nOrganizing categories by type and applying hierarchical clustering...")

# One representative embedding per category: the mean of that category's
# normalized embedding over every subject that has it
representative_embeddings = {}
for cat in all_categories:
    per_subject_vectors = [
        cats[cat] for cats in subject_embeddings_normalized.values() if cat in cats
    ]
    if per_subject_vectors:
        representative_embeddings[cat] = np.mean(per_subject_vectors, axis=0)

# Assign each category to the first group whose flag is set, checked in this
# priority order; categories without CDI information go to 'others'
flag_to_group = (
    ('is_animate', 'animals'),
    ('is_bodypart', 'bodyparts'),
    ('is_big', 'big_objects'),
    ('is_small', 'small_objects'),
)
group_order = ('animals', 'bodyparts', 'big_objects', 'small_objects', 'others')
organized = {group: [] for group in group_order}

for cat in all_categories:
    if cat in category_types:
        types = category_types[cat]
        target_group = next((group for flag, group in flag_to_group if types[flag]), 'others')
    else:
        target_group = 'others'
    organized[target_group].append(cat)

print(f"  Organized into: {len(organized['animals'])} animals, {len(organized['bodyparts'])} bodyparts, "
      f"{len(organized['big_objects'])} big objects, {len(organized['small_objects'])} small objects, "
      f"{len(organized['others'])} others")

# Order the members of every group: by dendrogram leaf order when clustering is
# enabled (groups with fewer than two members are left as they are), otherwise
# alphabetically
if use_clustering:
    print("\nApplying hierarchical clustering within groups...")
    for key in organized:
        members = [cat for cat in organized[key] if cat in representative_embeddings]
        if len(members) > 1:
            print(f"  Clustering {key} ({len(members)} categories)...")
            organized[key], _ = cluster_categories_within_group(
                members,
                representative_embeddings,
                save_dendrogram=save_dendrograms,
                output_dir=output_dir,
                group_name=key
            )
        else:
            organized[key] = members
else:
    for key in organized:
        organized[key] = sorted(cat for cat in organized[key] if cat in representative_embeddings)

# Concatenate the groups in their fixed display order
ordered_categories = [cat for group in group_order for cat in organized[group]]

print(f"\nFinal ordered category list: {len(ordered_categories)} categories")

def compute_subject_rdm(subject_embeddings_dict, categories_list):
    """
    Build the cosine-distance RDM of one subject.

    Args:
        subject_embeddings_dict: dict[category] = embedding array (expected to be
            normalized already)
        categories_list: categories to consider, in the desired row/column order

    Returns:
        (rdm, available_categories) where rdm is a symmetric
        (n_categories, n_categories) array with a zero diagonal, or None when
        fewer than two of the requested categories exist for this subject;
        available_categories lists the categories actually used, in order.

    Raises:
        ValueError: if the stacked embeddings do not form a 2D matrix.
    """
    # Keep only the requested categories this subject actually has
    available_categories = [name for name in categories_list if name in subject_embeddings_dict]
    if len(available_categories) < 2:
        return None, available_categories

    # One flattened row per category, so (1, 512)-shaped inputs are accepted too
    rows = [subject_embeddings_dict[name].flatten() for name in available_categories]
    embedding_matrix = np.array(rows)
    if embedding_matrix.ndim != 2:
        raise ValueError(f"Expected 2D embedding matrix, got shape {embedding_matrix.shape}")

    # Dissimilarity = 1 - cosine similarity, with an exact zero diagonal
    rdm = 1 - cosine_similarity(embedding_matrix)
    np.fill_diagonal(rdm, 0)

    # Average with the transpose to cancel floating-point asymmetry
    rdm = (rdm + rdm.T) / 2

    return rdm, available_categories

# Compute RDMs for each subject using normalized embeddings
# Note: all_categories is already defined in the organization code above
subject_rdms = {}
subject_rdm_categories = {}

for subject_id, categories in tqdm(subject_embeddings_normalized.items(), desc="Computing RDMs"):
    # Subjects with too few categories are skipped entirely
    if len(categories) < min_categories_per_subject:
        continue

    rdm, available_cats = compute_subject_rdm(categories, all_categories)

    if rdm is not None:
        subject_rdms[subject_id] = rdm
        subject_rdm_categories[subject_id] = available_cats

print(f"\nComputed RDMs for {len(subject_rdms)} subjects")
print(f"  (Excluded {len(subject_embeddings_normalized) - len(subject_rdms)} subjects with < {min_categories_per_subject} categories)")

# Reorganize each subject's RDM according to the new ordering
print("\nReorganizing individual subject RDMs according to new category ordering...")
subject_rdms_reorganized = {}
subject_rdm_categories_reorganized = {}
subject_group_boundaries = {}  # Store group boundaries for visual separators

for subject_id, rdm in tqdm(subject_rdms.items(), desc="Reorganizing RDMs"):
    available_cats = subject_rdm_categories[subject_id]

    # Row/column index of every category in the RDM as it was computed.
    # A dict gives constant-time lookups instead of repeated list scans.
    old_to_new_index = {cat: i for i, cat in enumerate(available_cats)}

    # Categories of this subject, in the global display order
    subject_ordered_cats = [cat for cat in ordered_categories if cat in old_to_new_index]

    # Positions in the original RDM, listed in the new order
    new_indices = [old_to_new_index[cat] for cat in subject_ordered_cats]

    # Permute rows and columns together
    rdm_reorganized = rdm[np.ix_(new_indices, new_indices)]

    # Compute group boundaries for this subject: [start, end) row ranges per group
    group_boundaries = []
    current_idx = 0
    for group_name in ['animals', 'bodyparts', 'big_objects', 'small_objects', 'others']:
        group_cats = [cat for cat in organized[group_name] if cat in old_to_new_index]
        if len(group_cats) > 0:
            group_end = current_idx + len(group_cats)
            group_boundaries.append({
                'name': group_name,
                'start': current_idx,
                'end': group_end,
                'categories': group_cats
            })
            current_idx = group_end

    subject_rdms_reorganized[subject_id] = rdm_reorganized
    subject_rdm_categories_reorganized[subject_id] = subject_ordered_cats
    subject_group_boundaries[subject_id] = group_boundaries

# Update the main dictionaries
subject_rdms = subject_rdms_reorganized
subject_rdm_categories = subject_rdm_categories_reorganized

print(f"Reorganized RDMs for {len(subject_rdms)} subjects")


Computing global normalization statistics across all subjects...
  Collected 4651 embeddings from 32 subjects
  Global mean shape: (512,)
  Global std shape: (512,)
  Global mean range: [-9.0121, 1.2518]
  Global std range: [0.0573, 0.6508]

Applying global normalization to each subject...


Normalizing:   0%|          | 0/32 [00:00<?, ?it/s]

Normalizing: 100%|██████████| 32/32 [00:00<00:00, 2028.25it/s]


Normalized embeddings for 32 subjects using global statistics
Loading category types from ../../data/cdi_words.csv...
Loaded type information for 295 categories
Total unique categories across all subjects: 163

Organizing categories by type and applying hierarchical clustering...
  Organized into: 19 animals, 14 bodyparts, 32 big objects, 96 small objects, 2 others

Applying hierarchical clustering within groups...
  Clustering animals (19 categories)...
    Saved dendrogram to individual_subject_rdms/dendrogram_animals.png
    Saved dendrogram to individual_subject_rdms/dendrogram_animals.pdf
  Clustering bodyparts (14 categories)...
    Saved dendrogram to individual_subject_rdms/dendrogram_bodyparts.png
    Saved dendrogram to individual_subject_rdms/dendrogram_bodyparts.pdf
  Clustering big_objects (32 categories)...
    Saved dendrogram to individual_subject_rdms/dendrogram_big_objects.png
    Saved dendrogram to individual_subject_rdms/dendrogram_big_objects.pdf
  Clustering smal

Computing RDMs: 100%|██████████| 32/32 [00:00<00:00, 1342.94it/s]



Computed RDMs for 32 subjects
  (Excluded 0 subjects with < 10 categories)

Reorganizing individual subject RDMs according to new category ordering...


Reorganizing RDMs: 100%|██████████| 32/32 [00:00<00:00, 2988.53it/s]

Reorganized RDMs for 32 subjects





## Compute Individual Subject RDMs


## Save Individual Subject RDMs


In [None]:
## Visualize All Individual Subject RDMs

# Plot all individual subject RDMs in a grid
n_subjects = len(subject_rdms)
subject_ids = list(subject_rdms.keys())

# Calculate grid dimensions
n_cols = 6  # Number of columns
n_rows = int(np.ceil(n_subjects / n_cols))

# Create figure with appropriate size
fig, axes = plt.subplots(n_rows, n_cols, figsize=(n_cols * 3, n_rows * 3))
axes = axes.flatten() if n_subjects > 1 else [axes]

# Find global min/max for consistent color scale across all RDMs
all_rdm_values = []
for rdm in subject_rdms.values():
    all_rdm_values.extend(rdm.flatten())
vmin = np.percentile(all_rdm_values, 1)  # Use 1st percentile to exclude outliers
vmax = np.percentile(all_rdm_values, 99)  # Use 99th percentile to exclude outliers

print(f"Plotting {n_subjects} individual subject RDMs...")
print(f"Color scale range: [{vmin:.4f}, {vmax:.4f}]")

for idx, subject_id in enumerate(subject_ids):
    rdm = subject_rdms[subject_id]
    categories = subject_rdm_categories[subject_id]
    
    ax = axes[idx]
    im = ax.imshow(rdm, cmap='viridis', aspect='auto', vmin=vmin, vmax=vmax)
    ax.set_title(f"{subject_id}\n({len(categories)} cats)", fontsize=9, pad=5)
    ax.set_xlabel('Category', fontsize=7)
    ax.set_ylabel('Category', fontsize=7)
    ax.tick_params(labelsize=6)
    
    # Add colorbar for each subplot
    plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)

# Hide unused subplots
for idx in range(n_subjects, len(axes)):
    axes[idx].axis('off')

plt.suptitle(f'All Individual Subject RDMs (n={n_subjects})', fontsize=16, y=0.995)
plt.tight_layout(rect=[0, 0, 1, 0.99])
plt.savefig(output_dir / "all_individual_rdms.png", dpi=200, bbox_inches='tight')
print(f"\nSaved all individual RDM visualization to {output_dir / 'all_individual_rdms.png'}")
plt.close()

# Also create a version with coolwarm colormap
fig, axes = plt.subplots(n_rows, n_cols, figsize=(n_cols * 3, n_rows * 3))
axes = axes.flatten() if n_subjects > 1 else [axes]

for idx, subject_id in enumerate(subject_ids):
    rdm = subject_rdms[subject_id]
    categories = subject_rdm_categories[subject_id]
    
    ax = axes[idx]
    im = ax.imshow(rdm, cmap='coolwarm', aspect='auto', vmin=vmin, vmax=vmax)
    ax.set_title(f"{subject_id}\n({len(categories)} cats)", fontsize=9, pad=5)
    ax.set_xlabel('Category', fontsize=7)
    ax.set_ylabel('Category', fontsize=7)
    ax.tick_params(labelsize=6)
    
    # Add colorbar for each subplot
    plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)

# Hide unused subplots
for idx in range(n_subjects, len(axes)):
    axes[idx].axis('off')

plt.suptitle(f'All Individual Subject RDMs (n={n_subjects}) - Coolwarm', fontsize=16, y=0.995)
plt.tight_layout(rect=[0, 0, 1, 0.99])
plt.savefig(output_dir / "all_individual_rdms_coolwarm.png", dpi=200, bbox_inches='tight')
print(f"Saved all individual RDM visualization (coolwarm) to {output_dir / 'all_individual_rdms_coolwarm.png'}")
plt.close()

# Save each individual RDM separately with category names as axis labels
print("\nSaving individual RDM plots...")
individual_rdm_dir = output_dir / "individual_rdm_plots"
individual_rdm_dir.mkdir(exist_ok=True, parents=True)

# Find global min/max for consistent color scale
all_rdm_values = []
for rdm in subject_rdms.values():
    all_rdm_values.extend(rdm.flatten())
vmin = np.percentile(all_rdm_values, 1)
vmax = np.percentile(all_rdm_values, 99)

for subject_id in tqdm(subject_rdms.keys(), desc="Saving individual RDMs"):
    rdm = subject_rdms[subject_id]
    categories = subject_rdm_categories[subject_id]
    
    # Determine figure size based on number of categories
    n_cats = len(categories)
    fig_size = max(10, n_cats * 0.3)
    
    # Set font size for category labels (adaptive) - much larger for readability
    if n_cats <= 50:
        label_fontsize = 12
        tick_fontsize = 20
    elif n_cats <= 100:
        label_fontsize = 10
        tick_fontsize = 18
    else:
        label_fontsize = 8
        tick_fontsize = 16
    
    # Create figure with viridis colormap
    fig, ax = plt.subplots(figsize=(fig_size, fig_size))
    im = ax.imshow(rdm, cmap='viridis', aspect='auto', vmin=vmin, vmax=vmax)
    
    # Add visual separators between category groups
    group_boundaries = subject_group_boundaries.get(subject_id, [])
    for boundary in group_boundaries:
        # Draw vertical line
        if boundary['start'] > 0:  # Don't draw line at the very start
            ax.axvline(x=boundary['start'] - 0.5, color='white', linewidth=2, linestyle='--', alpha=0.7)
        # Draw horizontal line
        if boundary['start'] > 0:  # Don't draw line at the very start
            ax.axhline(y=boundary['start'] - 0.5, color='white', linewidth=2, linestyle='--', alpha=0.7)
    
    # Set category names as axis labels
    ax.set_xticks(range(len(categories)))
    ax.set_yticks(range(len(categories)))
    ax.set_xticklabels(categories, rotation=90, ha='right', fontsize=tick_fontsize)
    ax.set_yticklabels(categories, fontsize=tick_fontsize)
    
    ax.set_xlabel('Category', fontsize=label_fontsize)
    ax.set_ylabel('Category', fontsize=label_fontsize)
    ax.set_title(f'RDM: {subject_id}\n({n_cats} categories)', fontsize=14, pad=10)
    
    # Add colorbar
    cbar = plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
    cbar.set_label('Distance (1 - Cosine Similarity)', fontsize=12)
    
    plt.tight_layout()
    
    # Save as PNG only
    output_path_png = individual_rdm_dir / f"rdm_{subject_id}.png"
    plt.savefig(output_path_png, dpi=200, bbox_inches='tight')
    
    plt.close()
    
    # Create and save individual dendrogram for this subject
    if len(categories) > 1:
        # Get embeddings for this subject's categories
        subject_embeddings = {cat: subject_embeddings_normalized[subject_id][cat] 
                             for cat in categories if cat in subject_embeddings_normalized[subject_id]}
        
        if len(subject_embeddings) > 1:
            # Build embedding matrix
            embedding_matrix = np.array([subject_embeddings[cat].flatten() for cat in categories])
            
            # Normalize embeddings
            normalized_embeddings = (embedding_matrix - embedding_matrix.mean(axis=0)) / (embedding_matrix.std(axis=0) + 1e-10)
            
            # Compute distance matrix
            similarity_matrix = cosine_similarity(normalized_embeddings)
            distance_matrix = 1 - similarity_matrix
            np.fill_diagonal(distance_matrix, 0)
            
            # Convert to condensed form for linkage
            condensed_distances = squareform(distance_matrix)
            
            # Perform hierarchical clustering
            linkage_matrix = linkage(condensed_distances, method='ward')
            
            # Get optimal leaf ordering
            try:
                linkage_matrix = optimal_leaf_ordering(linkage_matrix, condensed_distances)
            except:
                pass
            
            # Create dendrogram
            plt.figure(figsize=(max(16, len(categories) * 0.5), 10))
            dendrogram(linkage_matrix, 
                      labels=categories,
                      leaf_rotation=90,
                      leaf_font_size=max(8, min(14, 200 // len(categories))))
            plt.title(f'Individual Subject Dendrogram: {subject_id}\n({len(categories)} categories)',
                     fontsize=16, pad=20)
            plt.xlabel('Category', fontsize=14)
            plt.ylabel('Distance', fontsize=14)
            plt.tight_layout()
            
            # Save dendrogram
            dendrogram_dir = individual_rdm_dir / "dendrograms"
            dendrogram_dir.mkdir(exist_ok=True, parents=True)
            dendrogram_path = dendrogram_dir / f"dendrogram_{subject_id}.png"
            plt.savefig(dendrogram_path, dpi=300, bbox_inches='tight', pad_inches=0.2)
            plt.close()

# Report where the per-subject figures were written.
print(f"\nSaved {len(subject_rdms)} individual RDM plots to {individual_rdm_dir}")
# The filename pattern is printed literally (plain string, no f-prefix).
# Interpolating `subject_id` here would print a single concrete subject's
# filename -- presumably whichever value the preceding loop left behind -- and
# would raise NameError if that name was never bound.
print("  Each subject has 1 RDM file: rdm_{subject_id}.png")
print(f"  Individual dendrograms saved to {individual_rdm_dir / 'dendrograms'}")

Plotting 32 individual subject RDMs...
Color scale range: [0.2053, 1.4904]

Saved all individual RDM visualization to individual_subject_rdms/all_individual_rdms.png
Saved all individual RDM visualization (coolwarm) to individual_subject_rdms/all_individual_rdms_coolwarm.png

Saving individual RDM plots...


Saving individual RDMs:  69%|██████▉   | 22/32 [03:13<01:21,  8.19s/it]

In [None]:
# Write every subject's RDM to disk in three forms:
#   rdm_<id>.npy       raw matrix
#   rdm_<id>.csv       matrix with category names on both axes
#   metadata_<id>.csv  one-row table of descriptive statistics
print("Saving individual subject RDMs...")

for subject_id, rdm in tqdm(subject_rdms.items(), desc="Saving RDMs"):
    categories = subject_rdm_categories[subject_id]
    rdm_file_stem = f"rdm_{subject_id}"

    # Raw array
    np.save(output_dir / f"{rdm_file_stem}.npy", rdm)

    # Labelled table: the same category list is used for rows and columns
    rdm_df = pd.DataFrame(rdm, index=categories, columns=categories)
    rdm_df.to_csv(output_dir / f"{rdm_file_stem}.csv")

    # Descriptive statistics are taken over all entries of the matrix
    metadata = dict(
        subject_id=subject_id,
        n_categories=len(categories),
        categories=categories,
        mean_distance=float(rdm.mean()),
        std_distance=float(rdm.std()),
    )

    metadata_df = pd.DataFrame([metadata])
    metadata_path = output_dir / f"metadata_{subject_id}.csv"
    metadata_df.to_csv(metadata_path, index=False)

print(f"\nSaved RDMs to {output_dir}")


Saving individual subject RDMs...


Saving RDMs: 100%|██████████| 32/32 [00:01<00:00, 27.39it/s]


Saved RDMs to individual_subject_rdms





## Create Summary Statistics


In [None]:
# Build one summary row per subject, then write the table sorted by category count
summary_data = []

for subject_id, rdm in subject_rdms.items():
    categories = subject_rdm_categories[subject_id]

    # Smallest strictly positive entry; NaN when the matrix has no positive entry
    positive_entries = rdm > 0
    if positive_entries.any():
        smallest_positive = float(rdm[positive_entries].min())
    else:
        smallest_positive = np.nan

    summary_data.append(dict(
        subject_id=subject_id,
        n_categories=len(categories),
        mean_distance=float(rdm.mean()),
        std_distance=float(rdm.std()),
        min_distance=smallest_positive,
        max_distance=float(rdm.max()),
    ))

# Subjects with the most categories come first
summary_df = pd.DataFrame(summary_data).sort_values('n_categories', ascending=False)
summary_csv_path = output_dir / "summary_statistics.csv"
summary_df.to_csv(summary_csv_path, index=False)

print("Summary statistics:")
print(summary_df.describe())
print(f"\nSaved summary to {output_dir / 'summary_statistics.csv'}")


Summary statistics:
       n_categories  mean_distance  std_distance  min_distance  max_distance
count     32.000000      32.000000     32.000000     32.000000     32.000000
mean     145.343750       0.965700      0.267005      0.038118      1.607634
std       23.350913       0.018806      0.015772      0.017234      0.058626
min       55.000000       0.914233      0.239251      0.006570      1.408912
25%      140.750000       0.958276      0.249782      0.026638      1.575900
50%      154.500000       0.969105      0.270802      0.039205      1.623180
75%      160.000000       0.980300      0.279752      0.048356      1.636637
max      162.000000       0.990634      0.290838      0.092241      1.689233

Saved summary to individual_subject_rdms/summary_statistics.csv


## Visualize Sample RDMs


In [None]:
# Visualize a few sample RDMs: up to six subjects, taken in the iteration order
# of `subject_rdms`, drawn on a fixed 2x3 grid.
n_samples = min(6, len(subject_rdms))
sample_subjects = list(subject_rdms.keys())[:n_samples]

fig, axes = plt.subplots(2, 3, figsize=(18, 12))
axes = axes.flatten()

for idx, subject_id in enumerate(sample_subjects):
    rdm = subject_rdms[subject_id]
    categories = subject_rdm_categories[subject_id]

    ax = axes[idx]
    # No vmin/vmax is passed, so each panel is scaled to its own data range and
    # gets its own colorbar; colours are not comparable across panels.
    im = ax.imshow(rdm, cmap='viridis', aspect='auto')
    ax.set_title(f"{subject_id}\n({len(categories)} categories)", fontsize=10)
    ax.set_xlabel('Category')
    ax.set_ylabel('Category')
    fig.colorbar(im, ax=ax)

# The grid always has six cells. With fewer than six subjects the leftover cells
# would be saved as empty, tick-marked axes, so hide them.
for unused_ax in axes[n_samples:]:
    unused_ax.set_visible(False)

fig.tight_layout()
fig.savefig(output_dir / "sample_rdms.png", dpi=150, bbox_inches='tight')
print(f"Saved sample RDM visualization to {output_dir / 'sample_rdms.png'}")
# Close this specific figure rather than relying on pyplot's "current figure".
plt.close(fig)


Saved sample RDM visualization to individual_subject_rdms/sample_rdms.png


## Data Density Analysis


In [None]:
# Two-panel data density figure:
#   left  - histogram of how many categories each subject has
#   right - each subject's mean RDM value against its category count
fig, axes = plt.subplots(1, 2, figsize=(12, 5))
density_hist_ax, density_scatter_ax = axes

# Left panel: category counts per subject, with the minimum-category threshold marked
categories_per_subject = list(map(len, subject_rdm_categories.values()))
density_hist_ax.hist(categories_per_subject, bins=20, edgecolor='black')
density_hist_ax.set_xlabel('Number of Categories per Subject')
density_hist_ax.set_ylabel('Number of Subjects')
density_hist_ax.set_title('Data Density: Categories per Subject')
density_hist_ax.axvline(
    min_categories_per_subject,
    color='red',
    linestyle='--',
    label=f'Min threshold ({min_categories_per_subject})',
)
density_hist_ax.legend()

# Right panel: one point per subject; both lists follow the iteration order of subject_rdms
mean_distances = [subject_matrix.mean() for subject_matrix in subject_rdms.values()]
n_categories = [len(subject_rdm_categories[sid]) for sid in subject_rdms]

density_scatter_ax.scatter(n_categories, mean_distances, alpha=0.6)
density_scatter_ax.set_xlabel('Number of Categories')
density_scatter_ax.set_ylabel('Mean RDM Distance')
density_scatter_ax.set_title('RDM Distance vs Data Density')

fig.tight_layout()
density_figure_path = output_dir / "data_density_analysis.png"
fig.savefig(density_figure_path, dpi=150, bbox_inches='tight')
print(f"Saved data density analysis to {output_dir / 'data_density_analysis.png'}")
plt.close(fig)


Saved data density analysis to individual_subject_rdms/data_density_analysis.png
