# Developmental Trajectory RDM Analysis

This notebook computes two Representational Dissimilarity Matrices (RDMs) for each subject: one from the months at or below a median age threshold and one from the months above it. The threshold is computed once, pooled across all participants.
Comparing each subject's two RDMs tracks how their object representations change over development.

## Overview

This analysis:
1. Loads grouped embeddings (averaged by category, subject, and age_mo)
2. Calculates the overall median age across all participants
3. For each subject, splits data into "younger" (age_mo <= median) and "older" (age_mo > median) bins
4. Computes an RDM for each subject in each age bin (2 RDMs per subject)
5. Handles differences in data density (some subjects and ages have much more data than others) by requiring a minimum number of categories per age bin
6. Visualizes developmental trajectories
7. Compares RDMs between younger and older periods within subjects
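The steps above can be condensed into a minimal, self-contained sketch on toy data. It assumes the same nested layout as the notebook (`dict[subject][age_mo][category] -> embedding`) and computes cosine distance directly with NumPy instead of `cosine_similarity`. The function name `median_split_rdms` and its `min_categories` default are illustrative, not part of the notebook.

```python
import numpy as np

def median_split_rdms(subject_age_embeddings, min_categories=2):
    """Split each subject's months at the pooled median and build one RDM per bin.

    subject_age_embeddings: dict[subject][age_mo][category] -> 1-D embedding
    Returns (median_age, dict[subject][bin] -> (categories, rdm)).
    """
    # Pool every (subject, age_mo) observation, as the notebook does
    all_ages = [age for ages in subject_age_embeddings.values() for age in ages]
    median_age = float(np.median(all_ages))

    rdms = {}
    for subject, ages in subject_age_embeddings.items():
        bins = {
            'younger': {a: c for a, c in ages.items() if a <= median_age},
            'older': {a: c for a, c in ages.items() if a > median_age},
        }
        subject_rdms = {}
        for bin_name, bin_ages in bins.items():
            # Collect each category's embeddings over the months in this bin
            per_cat = {}
            for month_cats in bin_ages.values():
                for cat, emb in month_cats.items():
                    per_cat.setdefault(cat, []).append(emb)
            if len(per_cat) < min_categories:
                continue
            cats = sorted(per_cat)
            X = np.array([np.mean(per_cat[c], axis=0) for c in cats])
            # Cosine distance = 1 - dot product of L2-normalized rows
            Xn = X / np.linalg.norm(X, axis=1, keepdims=True)
            rdm = 1.0 - Xn @ Xn.T
            np.fill_diagonal(rdm, 0.0)
            subject_rdms[bin_name] = (cats, rdm)
        # Keep only subjects with an RDM in both bins
        if len(subject_rdms) == 2:
            rdms[subject] = subject_rdms
    return median_age, rdms

# Toy data: s2 has no months above the median, so it is dropped
toy = {
    's1': {10: {'cat': np.array([1.0, 0.0]), 'dog': np.array([0.0, 1.0])},
           20: {'cat': np.array([1.0, 1.0]), 'dog': np.array([1.0, 0.0])}},
    's2': {12: {'cat': np.array([1.0, 0.0]), 'dog': np.array([1.0, 0.0])}},
}
median_age, rdms = median_split_rdms(toy)
```

Here the pooled ages are 10, 20 and 12, so the threshold is 12 months and only `s1` has data on both sides of it.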

## Key Features

- **Median split**: Uses overall median age across all participants to split each subject's data
- **Two RDMs per subject**: One for "younger" period, one for "older" period
- **Data density handling**: Minimum category threshold per age bin
- **Trajectory analysis**: Compare RDMs between younger and older periods to see developmental changes
- **Missing data handling**: Only includes subjects with sufficient data in both bins
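The younger/older comparison itself is not part of this section. One common way to compare two RDMs that share a category order is a rank correlation over their upper triangles, skipping cells that are NaN because a category is missing in either bin. The sketch below assumes that approach; `compare_rdms` and its 3-pair minimum are hypothetical choices, not necessarily what the notebook does later.

```python
import numpy as np
from scipy.stats import spearmanr

def compare_rdms(rdm_a, rdm_b):
    """Spearman correlation between the upper triangles of two same-ordered RDMs.

    Cells that are NaN in either RDM (missing categories) are dropped first.
    Returns (rho, n_pairs); rho is NaN if fewer than 3 pairs remain.
    """
    iu = np.triu_indices_from(rdm_a, k=1)  # off-diagonal upper triangle
    a, b = rdm_a[iu], rdm_b[iu]
    keep = ~np.isnan(a) & ~np.isnan(b)
    n_pairs = int(keep.sum())
    if n_pairs < 3:
        return float('nan'), n_pairs
    rho, _ = spearmanr(a[keep], b[keep])
    return float(rho), n_pairs

nan = np.nan
# Category 3 is missing in the younger bin
younger = np.array([[0.0, 0.2, 0.8, nan],
                    [0.2, 0.0, 0.5, nan],
                    [0.8, 0.5, 0.0, nan],
                    [nan, nan, nan, nan]])
older = np.array([[0.0, 0.3, 0.9, 0.4],
                  [0.3, 0.0, 0.6, 0.1],
                  [0.9, 0.6, 0.0, 0.7],
                  [0.4, 0.1, 0.7, 0.0]])
rho, n_pairs = compare_rdms(younger, older)
```

Only the three pairs among categories 0 to 2 survive, and they have the same rank order in both bins.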


## Setup and Imports


In [25]:
import numpy as np
import pandas as pd
from pathlib import Path
import matplotlib
matplotlib.use('Agg')  # Non-interactive backend; set before importing pyplot
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics.pairwise import cosine_similarity
from scipy.stats import spearmanr, pearsonr
from scipy.cluster.hierarchy import linkage, dendrogram, optimal_leaf_ordering
from scipy.spatial.distance import squareform
from collections import defaultdict
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')  # Suppress all warnings globally

print("All imports successful!")


All imports successful!


## Configuration


In [26]:
# Paths
# Path to normalized embeddings from notebook 05 (age-month level normalized embeddings)
# These are saved in category folders: {normalized_embeddings_dir}/{category}/{subject_id}_{age_mo}_month_level_avg.npy
# normalized_embeddings_dir = Path("/data2/dataset/babyview/868_hours/outputs/yoloe_cdi_embeddings/facebook_dinov3-vitb16-pretrain-lvd1689m_grouped_by_age-mo_normalized")
normalized_embeddings_dir = Path("/data2/dataset/babyview/868_hours/outputs/yoloe_cdi_embeddings/clip_embeddings_grouped_by_age-mo_normalized")

# Detect embedding type from path
normalized_embeddings_dir_str = str(normalized_embeddings_dir).lower()
if "dinov" in normalized_embeddings_dir_str:  # also matches "dinov3"
    embedding_type = "dinov3"
elif "clip" in normalized_embeddings_dir_str:
    embedding_type = "clip"
else:
    embedding_type = "unknown"

# Create output directory with embedding type in name
output_dir = Path(f"developmental_trajectory_rdms_{embedding_type}")
output_dir.mkdir(exist_ok=True, parents=True)

# Subject to exclude from analyses (should match notebook 06)
excluded_subject = "00270001"

# Categories file (optional - to filter to specific categories)
categories_file = Path("../../data/things_bv_overlap_categories_exclude_zero_precisions.txt")

# CDI words CSV file (required for category type organization)
cdi_path = Path("../../data/cdi_words.csv")

# Hierarchical clustering options
use_clustering = True  # Enable hierarchical clustering within category groups
save_dendrograms = True  # Save dendrogram plots for each category group

# Predefined category list for consistent RDM ordering (optional)
# Set USE_PREDEFINED_CATEGORY_LIST = False (or the path to None) to use automatic organization
# A shared order allows comparing RDMs across subjects cell by cell
USE_PREDEFINED_CATEGORY_LIST = True  # If True, load category order from PREDEFINED_CATEGORY_LIST_PATH
PREDEFINED_CATEGORY_LIST_PATH = "../vss-2026/bv_things_comp_12252025/bv_clip_filtered_zscored_hierarchical_163cats/category_order_reorganized.txt"  # Path to text file with category order (one category per line), or None

# Minimum categories required per age bin to compute RDM
min_categories_per_age_bin = 8

print(f"Normalized embeddings directory: {normalized_embeddings_dir}")
print(f"Detected embedding type: {embedding_type}")
print(f"Output directory: {output_dir}")
print(f"Excluded subject: {excluded_subject}")
print(f"CDI path: {cdi_path}")
print(f"Use clustering: {use_clustering}")
print(f"Use predefined category list: {USE_PREDEFINED_CATEGORY_LIST}")
if USE_PREDEFINED_CATEGORY_LIST and PREDEFINED_CATEGORY_LIST_PATH:
    print(f"Predefined category list path: {PREDEFINED_CATEGORY_LIST_PATH}")


Normalized embeddings directory: /data2/dataset/babyview/868_hours/outputs/yoloe_cdi_embeddings/clip_embeddings_grouped_by_age-mo_normalized
Detected embedding type: clip
Output directory: developmental_trajectory_rdms_clip
Excluded subject: 00270001
CDI path: ../../data/cdi_words.csv
Use clustering: True
Use predefined category list: True
Predefined category list path: ../vss-2026/bv_things_comp_12252025/bv_clip_filtered_zscored_hierarchical_163cats/category_order_reorganized.txt


## Helper Functions


In [27]:
def load_category_types(cdi_path):
    """Load category type information from CDI words CSV"""
    print(f"Loading category types from {cdi_path}...")
    cdi_df = pd.read_csv(cdi_path)
    
    def _flag(row, key):
        # Treat missing/NaN flags as False (bool(np.nan) would be True)
        value = row.get(key, 0)
        return bool(value) if pd.notna(value) else False

    category_types = {}
    for _, row in cdi_df.iterrows():
        category_types[row['uni_lemma']] = {
            key: _flag(row, key)
            for key in ('is_animate', 'is_bodypart', 'is_small', 'is_big')
        }
    
    print(f"Loaded type information for {len(category_types)} categories")
    return category_types

def load_cdi_category_mapping(cdi_path):
    """Load CDI category mapping (uni_lemma -> category) for coloring labels"""
    cdi_df = pd.read_csv(cdi_path)
    category_map = {}
    for _, row in cdi_df.iterrows():
        category_map[row['uni_lemma']] = row.get('category', 'unknown')
    return category_map

def get_category_color(category_name, category_map):
    """Get color for a category based on its CDI category type"""
    # Define color scheme for CDI categories
    category_colors = {
        'animals': '#8B4513',  # Brown
        'body_parts': '#FF6B6B',  # Red
        'food_drink': '#FFA500',  # Orange
        'furniture_rooms': '#4169E1',  # Royal Blue
        'toys': '#FF69B4',  # Hot Pink
        'vehicles': '#32CD32',  # Lime Green
        'clothing': '#9370DB',  # Medium Purple
        'outside': '#228B22',  # Forest Green
        'places': '#4682B4',  # Steel Blue
        'small_things': '#FFD700',  # Gold
        'action_words': '#DC143C',  # Crimson
        'descriptive_words': '#20B2AA',  # Light Sea Green
        'sound_effects': '#FF1493',  # Deep Pink
        'games_routines': '#00CED1',  # Dark Turquoise
    }
    
    # Get the CDI category for this uni_lemma
    cdi_category = category_map.get(category_name, 'unknown')
    
    # Return color, default to gray if not found
    return category_colors.get(cdi_category, '#808080')  # Gray for unknown

def cluster_categories_within_group(group_categories, cat_to_embedding, save_dendrogram=False, output_dir=None, group_name=None):
    """
    Perform hierarchical clustering within a group of categories.
    
    Args:
        group_categories: List of category names in the group
        cat_to_embedding: Dictionary mapping category names to embeddings
        save_dendrogram: Whether to save dendrogram plot (default: False)
        output_dir: Output directory for saving dendrogram (required if save_dendrogram=True)
        group_name: Name of the group for saving dendrogram (required if save_dendrogram=True)
    
    Returns:
        Tuple (clustered_categories, linkage_matrix): category names reordered by the
        dendrogram leaf order, and the SciPy linkage matrix (None for groups with <= 1 category)
    """
    if len(group_categories) <= 1:
        return group_categories, None
    
    # Get embeddings for this group
    group_embeddings = np.array([cat_to_embedding[cat].flatten() for cat in group_categories])
    
    # Z-score each embedding dimension across the categories in this group
    normalized_embeddings = (group_embeddings - group_embeddings.mean(axis=0)) / (group_embeddings.std(axis=0) + 1e-10)
    
    # Compute distance matrix (1 - cosine similarity)
    similarity_matrix = cosine_similarity(normalized_embeddings)
    distance_matrix = 1 - similarity_matrix
    # Symmetrize and zero the diagonal so squareform's symmetry check passes
    distance_matrix = (distance_matrix + distance_matrix.T) / 2
    np.fill_diagonal(distance_matrix, 0)
    
    # Convert to condensed form for linkage
    condensed_distances = squareform(distance_matrix)
    
    # Perform hierarchical clustering
    # Note: SciPy documents 'ward' as correctly defined only for Euclidean distances;
    # here it is applied to cosine distances
    linkage_matrix = linkage(condensed_distances, method='ward')
    
    # Get optimal leaf ordering for better visualization
    try:
        linkage_matrix = optimal_leaf_ordering(linkage_matrix, condensed_distances)
    except Exception:
        # If optimal leaf ordering fails, use original linkage
        pass
    
    # Extract the order from the dendrogram
    dendro_dict = dendrogram(linkage_matrix, no_plot=True)
    leaf_order = dendro_dict['leaves']
    
    # Reorder categories according to clustering
    clustered_categories = [group_categories[i] for i in leaf_order]
    
    # Save dendrogram if requested
    if save_dendrogram and output_dir is not None and group_name is not None:
        output_dir = Path(output_dir)
        output_dir.mkdir(exist_ok=True, parents=True)
        
        plt.figure(figsize=(12, 8))
        dendrogram(linkage_matrix, 
                  labels=group_categories,
                  leaf_rotation=90,
                  leaf_font_size=10)
        plt.title(f'Hierarchical Clustering Dendrogram: {group_name.upper()}\n({len(group_categories)} categories)',
                 fontsize=16, pad=20)
        plt.xlabel('Category', fontsize=12)
        plt.ylabel('Distance', fontsize=12)
        plt.tight_layout()
        
        # Save as PNG
        output_path_png = output_dir / f'dendrogram_{group_name}.png'
        plt.savefig(output_path_png, dpi=300, bbox_inches='tight', pad_inches=0.2)
        print(f"    Saved dendrogram to {output_path_png}")
        
        # Save as PDF
        output_path_pdf = output_dir / f'dendrogram_{group_name}.pdf'
        plt.savefig(output_path_pdf, bbox_inches='tight', pad_inches=0.2)
        print(f"    Saved dendrogram to {output_path_pdf}")
        
        plt.close()
    
    return clustered_categories, linkage_matrix

print("Helper functions loaded!")

Helper functions loaded!
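`cluster_categories_within_group` applies Ward linkage to cosine distances, but SciPy's `linkage` documentation notes that `'ward'`, `'centroid'`, and `'median'` are correctly defined only for Euclidean distances. If that matters, one alternative is average linkage, which accepts any precomputed distance. The sketch below swaps in average linkage and omits the per-dimension z-scoring step; `cluster_order` is a hypothetical name, not a notebook helper.

```python
import numpy as np
from scipy.cluster.hierarchy import linkage, leaves_list, optimal_leaf_ordering
from scipy.spatial.distance import pdist

def cluster_order(names, embeddings, method='average'):
    """Return `names` reordered by hierarchical clustering on cosine distance."""
    if len(names) <= 2:
        return list(names)  # ordering is trivial for 0-2 items
    X = np.asarray(embeddings, dtype=float)
    d = pdist(X, metric='cosine')        # condensed 1 - cosine similarity
    Z = linkage(d, method=method)        # average linkage is valid for any metric
    Z = optimal_leaf_ordering(Z, d)      # flip branches so neighbours are similar
    return [names[i] for i in leaves_list(Z)]

names = ['cat', 'car', 'dog', 'bus']
emb = [[1.0, 0.1], [0.1, 1.0], [0.9, 0.2], [0.2, 0.9]]
order = cluster_order(names, emb)
```

With these toy vectors, `cat`/`dog` and `car`/`bus` end up next to each other in the returned order.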


In [28]:
# Load allowed categories if file exists
allowed_categories = None
if categories_file.exists():
    print(f"Loading categories from {categories_file}...")
    with open(categories_file, 'r') as f:
        allowed_categories = set(line.strip() for line in f if line.strip())
    print(f"Loaded {len(allowed_categories)} categories")
else:
    print("Categories file not found, using all categories")


Loading categories from ../../data/things_bv_overlap_categories_exclude_zero_precisions.txt...
Loaded 163 categories


## Load Embeddings by Age


In [29]:
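def load_embeddings_by_age(embeddings_dir, allowed_categories=None, excluded_subject=None, age_binning_strategy='exact', age_bin_size=3):
    """
    Load pre-normalized embeddings organized by subject, age_mo, and category.
    These embeddings were already normalized in notebook 05.
    
    Args:
        embeddings_dir: Directory with one sub-folder per category containing
            {subject_id}_{age_mo}_month_level_avg.npy files
        allowed_categories: Optional set of category names to keep (None keeps all)
        excluded_subject: Optional subject ID to skip
        age_binning_strategy: 'exact' keeps each age_mo; 'binned' floors age_mo to a
            multiple of age_bin_size. In 'binned' mode, later files for the same
            subject/bin/category overwrite earlier ones rather than being averaged.
        age_bin_size: Bin width in months (used only when age_binning_strategy='binned')
    
    Returns:
        subject_age_embeddings: dict[subject_id][age_mo_bin][category] = embedding array (already normalized)
    """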
    subject_age_embeddings = defaultdict(lambda: defaultdict(dict))
    
    # Get all category folders
    category_folders = [f for f in embeddings_dir.iterdir() if f.is_dir()]
    
    if allowed_categories:
        category_folders = [f for f in category_folders if f.name in allowed_categories]
    
    print(f"Loading pre-normalized embeddings from {len(category_folders)} categories...")
    
    for category_folder in tqdm(category_folders, desc="Loading categories"):
        category = category_folder.name
        
        # Get all embedding files in this category
        embedding_files = list(category_folder.glob("*.npy"))
        
        for emb_file in embedding_files:
            # Parse filename: {subject_id}_{age_mo}_month_level_avg.npy
            filename = emb_file.stem  # without .npy
            parts = filename.split('_')
            
            if len(parts) < 2:
                continue
            
            # Extract subject_id and age_mo
            subject_id = parts[0]
            
            # Exclude subject if specified
            if excluded_subject and subject_id == excluded_subject:
                continue
            
            age_mo = int(parts[1]) if parts[1].isdigit() else None
            
            if age_mo is None:
                continue
            
            # Apply age binning strategy
            if age_binning_strategy == 'binned':
                age_mo_bin = (age_mo // age_bin_size) * age_bin_size  # Round down to bin
            else:
                age_mo_bin = age_mo  # Use exact age
            
            try:
                embedding = np.load(emb_file)
                subject_age_embeddings[subject_id][age_mo_bin][category] = embedding
            except Exception as e:
                print(f"Error loading {emb_file}: {e}")
                continue
    
    return subject_age_embeddings

# Load pre-normalized embeddings from notebook 05 (using exact ages - we'll do median split later)
subject_age_embeddings = load_embeddings_by_age(
    normalized_embeddings_dir,  # Use normalized embeddings from notebook 05
    allowed_categories,
    excluded_subject=excluded_subject,  # Exclude specified subject
    age_binning_strategy='exact',  # Use exact ages
    age_bin_size=1  # Not used when strategy is 'exact'
)

print(f"\nLoaded embeddings for {len(subject_age_embeddings)} subjects")

# Show age bin distribution
all_age_bins = set()
for subject_id, age_data in subject_age_embeddings.items():
    all_age_bins.update(age_data.keys())

print(f"Age bins found: {sorted(all_age_bins)}")
print(f"Age range: {min(all_age_bins)} to {max(all_age_bins)} months")


Loading pre-normalized embeddings from 163 categories...


Loading categories: 100%|██████████| 163/163 [00:01<00:00, 124.70it/s]


Loaded embeddings for 31 subjects
Age bins found: [6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 37]
Age range: 6 to 37 months
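`load_embeddings_by_age` only inspects the first two underscore-separated tokens of each file stem, so any `.npy` whose second token is a number is loaded. If stricter validation is wanted, a regular expression can enforce the full `{subject_id}_{age_mo}_month_level_avg` pattern from the configuration comment. A sketch with the hypothetical helper `parse_embedding_filename`; the subject ID stays a string so leading zeros (e.g. `00270001`) are preserved.

```python
import re
from pathlib import Path

# Expected stem: {subject_id}_{age_mo}_month_level_avg
STEM_RE = re.compile(r'^(?P<subject>[^_]+)_(?P<age>\d+)_month_level_avg$')

def parse_embedding_filename(path):
    """Return (subject_id, age_mo) or None if the file name does not match exactly."""
    m = STEM_RE.match(Path(path).stem)  # stem drops the .npy suffix
    if m is None:
        return None
    return m.group('subject'), int(m.group('age'))
```

Files that return `None` could then be skipped (or logged) instead of being silently loaded under a guessed age.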





## Calculate Overall Median Age

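We split each subject's data at a single age threshold: the median of all (subject, age_mo) observations pooled across participants. Because each subject-month counts once, densely sampled subjects pull the threshold toward their ages, and a month equal to the median is assigned to the "younger" bin. The embeddings were already normalized in notebook 05, so no normalization is performed here.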
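A worked example with hypothetical month lists (not the BabyView data) contrasts the pooled median with a median of per-subject medians, and shows the `<=` assignment of the median month.

```python
import numpy as np

# Hypothetical months observed per subject (illustration only)
months = {
    'A': [8, 9, 10, 11, 12, 13],   # densely sampled early
    'B': [20, 24],
    'C': [30],
}
pooled = [m for ms in months.values() for m in ms]
pooled_median = np.median(pooled)        # one vote per subject-month
subject_median = np.median([np.median(ms) for ms in months.values()])  # one vote per subject

# The median month itself falls into the 'younger' bin (age_mo <= median)
younger_A = [m for m in months['A'] if m <= pooled_median]
older_A = [m for m in months['A'] if m > pooled_median]
```

The pooled median is 12 months, while the median of per-subject medians (10.5, 22, 30) is 22 months; subject A keeps 8 through 12 in "younger" and only 13 in "older".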


In [30]:
## Calculate Overall Median Age Across All Participants

# Collect all age_mo values across all subjects to compute overall median
all_ages = []
for subject_id, age_data in subject_age_embeddings.items():
    all_ages.extend(age_data.keys())

overall_median_age = np.median(all_ages)
print(f"Overall median age across all participants: {overall_median_age:.1f} months")
print(f"Age range: {min(all_ages)} to {max(all_ages)} months")
print(f"Total age observations: {len(all_ages)}")

## Use Pre-Normalized Embeddings from Notebook 05

# Embeddings are already normalized from notebook 05, so we use them directly
# Rename for consistency with rest of code
subject_age_embeddings_normalized = subject_age_embeddings

print(f"\nUsing pre-normalized embeddings for {len(subject_age_embeddings_normalized)} subjects")
print("  Note: Embeddings were normalized in notebook 05 (within each subject across all age bins)")
print("  No additional normalization performed here")

## Aggregate Embeddings by Median Split and Compute RDMs

def aggregate_embeddings_by_bin(age_embeddings_dict, age_bin_name):
    """
    Aggregate embeddings for a bin by averaging across all ages in that bin.
    Each month contributes equally to a category's average, regardless of how
    much data underlies that month.
    
    Args:
        age_embeddings_dict: dict[age_mo][category] = embedding array
        age_bin_name: 'younger' or 'older' (currently unused; kept for readability at call sites)
    
    Returns:
        aggregated_embeddings: dict[category] = averaged embedding array
    """
    # Collect all embeddings for each category across ages in this bin
    category_embeddings = defaultdict(list)
    
    for age_mo, categories in age_embeddings_dict.items():
        for cat, embedding in categories.items():
            category_embeddings[cat].append(embedding)
    
    # Average embeddings for each category
    aggregated = {}
    for cat, embeddings_list in category_embeddings.items():
        if len(embeddings_list) > 0:
            aggregated[cat] = np.mean(embeddings_list, axis=0)
    
    return aggregated

def compute_rdm_for_bin_with_na(bin_embeddings_dict, ordered_categories_list):
    """
    Compute RDM for a single age bin (younger or older) with NA for missing categories.
    This ensures consistent ordering across bins using the predefined category order.
    
    Args:
        bin_embeddings_dict: dict[category] = embedding array (should be normalized and aggregated)
        ordered_categories_list: list of all categories in desired order (may include categories not present for this bin)
    
    Returns:
        rdm: numpy array of shape (n_categories, n_categories) with np.nan for missing categories
        mask: boolean array of shape (n_categories, n_categories) where True indicates NA (missing category)
        available_categories: list of categories actually present for this bin
    """
    n_categories = len(ordered_categories_list)
    
    # Find available categories (categories that exist for this bin)
    available_categories = [cat for cat in ordered_categories_list if cat in bin_embeddings_dict]
    
    if len(available_categories) < min_categories_per_age_bin:
        # Return RDM full of NaN if not enough categories
        rdm = np.full((n_categories, n_categories), np.nan)
        mask = np.ones((n_categories, n_categories), dtype=bool)
        return rdm, mask, available_categories
    
    # Build embedding matrix for available categories (already normalized)
    embedding_matrix = np.array([bin_embeddings_dict[cat].flatten() for cat in available_categories])
    
    # Ensure 2D shape: (n_available_categories, embedding_dim)
    if embedding_matrix.ndim != 2:
        raise ValueError(f"Expected 2D embedding matrix, got shape {embedding_matrix.shape}")
    
    # Compute cosine similarity for available categories
    similarity_matrix_available = cosine_similarity(embedding_matrix)
    
    # Convert to distance (RDM) for available categories
    distance_matrix_available = 1 - similarity_matrix_available
    np.fill_diagonal(distance_matrix_available, 0)  # Ensure diagonal is 0
    
    # Make symmetric (in case of numerical errors)
    distance_matrix_available = (distance_matrix_available + distance_matrix_available.T) / 2
    
    # Create full RDM with NaN for missing categories
    rdm = np.full((n_categories, n_categories), np.nan)
    mask = np.ones((n_categories, n_categories), dtype=bool)
    
    # Map available categories to their indices in ordered_categories_list
    available_indices = [ordered_categories_list.index(cat) for cat in available_categories]
    
    # Fill in the RDM for available categories
    for i, idx_i in enumerate(available_indices):
        for j, idx_j in enumerate(available_indices):
            rdm[idx_i, idx_j] = distance_matrix_available[i, j]
            mask[idx_i, idx_j] = False  # False means not NA (data present)
    
    return rdm, mask, available_categories

# Get all unique categories across all subjects and ages
all_categories = set()
for subject_id, age_data in subject_age_embeddings_normalized.items():
    for age_mo, categories in age_data.items():
        all_categories.update(categories.keys())

all_categories = sorted(list(all_categories))
print(f"\nTotal unique categories across all subjects and ages: {len(all_categories)}")
print("Note: RDMs will be computed with predefined category order after organization step.")

# First, identify which subjects have sufficient data in both bins
# We'll compute actual RDMs after category organization
print(f"\nIdentifying subjects with sufficient data in both age bins...")
subject_age_rdms = {}  # Temporary storage - will be recomputed with predefined order
subject_age_rdm_categories = {}  # Temporary storage
excluded_subjects = []  # Track excluded subjects and reasons

for subject_id, age_data in tqdm(subject_age_embeddings_normalized.items(), desc="Checking subjects"):
    # Split ages at the overall median; a month equal to the median goes to 'younger'
    younger_ages = {age_mo: categories for age_mo, categories in age_data.items()
                    if age_mo <= overall_median_age}
    older_ages = {age_mo: categories for age_mo, categories in age_data.items()
                  if age_mo > overall_median_age}
    
    # Average each category across the months in a bin and count available categories
    younger_aggregated = aggregate_embeddings_by_bin(younger_ages, 'younger') if younger_ages else {}
    older_aggregated = aggregate_embeddings_by_bin(older_ages, 'older') if older_ages else {}
    younger_n_cats = len(younger_aggregated)
    older_n_cats = len(older_aggregated)
    younger_has_rdm = younger_n_cats >= min_categories_per_age_bin
    older_has_rdm = older_n_cats >= min_categories_per_age_bin
    
    if younger_has_rdm and older_has_rdm:
        # Placeholders: the actual RDMs are recomputed later with the predefined category order
        subject_age_rdms[subject_id] = {'younger': True, 'older': True}
        subject_age_rdm_categories[subject_id] = {
            'younger': list(younger_aggregated.keys()),
            'older': list(older_aggregated.keys()),
        }
        continue
    
    # Record exactly one exclusion reason per subject
    if not younger_ages:
        reason = 'no younger ages'
    elif not older_ages:
        reason = 'no older ages'
    elif not younger_has_rdm and not older_has_rdm:
        reason = 'both bins insufficient'
    elif not younger_has_rdm:
        reason = f'younger bin insufficient ({younger_n_cats} < {min_categories_per_age_bin} cats)'
    else:
        reason = f'older bin insufficient ({older_n_cats} < {min_categories_per_age_bin} cats)'
    
    excluded_subjects.append({
        'subject_id': subject_id,
        'reason': reason,
        'younger_n_cats': younger_n_cats,
        'older_n_cats': older_n_cats
    })

print(f"\nIdentified {len(subject_age_rdms)} subjects with sufficient data in both bins")
print(f"  Excluded {len(excluded_subjects)} subjects without sufficient data in both bins")

# Show excluded subjects details
if len(excluded_subjects) > 0:
    print(f"\nExcluded subjects ({len(excluded_subjects)}):")
    excluded_df = pd.DataFrame(excluded_subjects)
    excluded_df = excluded_df.sort_values('subject_id')
    for _, row in excluded_df.iterrows():
        print(f"  {row['subject_id']}: {row['reason']} (younger: {row['younger_n_cats']} cats, older: {row['older_n_cats']} cats)")


Overall median age across all participants: 16.0 months
Age range: 6 to 37 months
Total age observations: 266

Using pre-normalized embeddings for 31 subjects
  Note: Embeddings were normalized in notebook 05 (within each subject across all age bins)
  No additional normalization performed here

Total unique categories across all subjects and ages: 163
Note: RDMs will be computed with predefined category order after organization step.

Identifying subjects with sufficient data in both age bins...


Checking subjects: 100%|██████████| 31/31 [00:00<00:00, 462.22it/s]


Identified 18 subjects with sufficient data in both bins
  Excluded 13 subjects without sufficient data in both bins

Excluded subjects (13):
  00220001: no older ages (younger: 155 cats, older: 0 cats)
  00230001: no older ages (younger: 143 cats, older: 0 cats)
  00340002: no older ages (younger: 99 cats, older: 0 cats)
  00350001: no older ages (younger: 111 cats, older: 0 cats)
  00350002: no older ages (younger: 123 cats, older: 0 cats)
  00360001: no older ages (younger: 153 cats, older: 0 cats)
  00390001: no older ages (younger: 141 cats, older: 0 cats)
  00430002: no older ages (younger: 112 cats, older: 0 cats)
  00440001: no older ages (younger: 134 cats, older: 0 cats)
  00460001: no older ages (younger: 140 cats, older: 0 cats)
  00550001: no older ages (younger: 132 cats, older: 0 cats)
  00720001: no younger ages (younger: 0 cats, older: 154 cats)
  00820001: no younger ages (younger: 0 cats, older: 160 cats)
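`compute_rdm_for_bin_with_na` fills the NaN-padded RDM with a Python double loop and `list.index` lookups. The same placement can be written with a category-to-index dictionary and `np.ix_`, which assigns every `(i, j)` pair of the selected rows and columns at once. A sketch, assuming the sub-RDM rows follow `sub_categories`; `embed_partial_rdm` is a hypothetical helper, and deriving the mask from `np.isnan` assumes the sub-RDM itself contains no NaN.

```python
import numpy as np

def embed_partial_rdm(sub_rdm, sub_categories, ordered_categories):
    """Place an RDM over a subset of categories into a full NaN-padded matrix.

    Returns (rdm, mask) where mask is True for missing (NaN) cells, matching
    the convention used by compute_rdm_for_bin_with_na.
    """
    index = {cat: i for i, cat in enumerate(ordered_categories)}  # O(1) lookups
    idx = np.array([index[c] for c in sub_categories])
    n = len(ordered_categories)
    rdm = np.full((n, n), np.nan)
    rdm[np.ix_(idx, idx)] = sub_rdm  # fills all (idx_i, idx_j) pairs in one step
    return rdm, np.isnan(rdm)

ordered = ['a', 'b', 'c']
rdm, mask = embed_partial_rdm(np.array([[0.0, 0.5], [0.5, 0.0]]), ['a', 'c'], ordered)
```

Category `b` is absent, so its whole row and column stay NaN and are flagged in `mask`.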





In [31]:
# Organize Categories (with Predefined List Option)
# This section organizes categories either by loading a predefined category list (for consistent ordering across subjects) or by automatic organization.

# Get all unique categories across all subjects and ages (needed for organization)
all_categories = set()
for subject_id, age_data in subject_age_embeddings_normalized.items():
    for age_mo, categories in age_data.items():
        all_categories.update(categories.keys())

all_categories = sorted(list(all_categories))
print(f"Total unique categories across all subjects and ages: {len(all_categories)}")

# Organize categories: either load predefined list or organize automatically
print("\nOrganizing categories...")

if USE_PREDEFINED_CATEGORY_LIST and PREDEFINED_CATEGORY_LIST_PATH is not None:
    # Load predefined category list
    predefined_path = Path(PREDEFINED_CATEGORY_LIST_PATH)
    if not predefined_path.exists():
        raise FileNotFoundError(f"Predefined category list file not found: {predefined_path}")
    
    print(f"  Loading predefined category order from {predefined_path}...")
    with open(predefined_path, 'r') as f:
        # Skip comment lines (lines starting with #)
        ordered_categories = [line.strip() for line in f if line.strip() and not line.strip().startswith('#')]
    
    # Check that the predefined list and the data contain the same categories
    predefined_set = set(ordered_categories)
    all_categories_set = set(all_categories)
    
    if predefined_set != all_categories_set:
        missing_in_predefined = all_categories_set - predefined_set
        extra_in_predefined = predefined_set - all_categories_set
        if missing_in_predefined:
            print(f"  Warning: {len(missing_in_predefined)} categories in data but not in predefined list: {sorted(missing_in_predefined)[:5]}...")
        if extra_in_predefined:
            print(f"  Warning: {len(extra_in_predefined)} categories in predefined list but not in data: {sorted(extra_in_predefined)[:5]}...")
        # Use intersection: only categories that exist in both
        ordered_categories = [cat for cat in ordered_categories if cat in all_categories_set]
        print(f"  Using intersection: {len(ordered_categories)} categories")
    
    print(f"  Loaded {len(ordered_categories)} categories in predefined order")
    
    # Still organize into groups for visualization boundaries (even though order is predefined)
    # Load category types for grouping
    if cdi_path.exists():
        category_types = load_category_types(cdi_path)
    else:
        print(f"Warning: CDI path {cdi_path} not found. Cannot compute group boundaries.")
        category_types = {}
    
    # Organize predefined categories into groups for visualization boundaries
    organized = {
        'animals': [],
        'bodyparts': [],
        'big_objects': [],
        'small_objects': [],
        'others': []
    }
    
    for cat in ordered_categories:
        if cat not in category_types:
            organized['others'].append(cat)
            continue
        
        types = category_types[cat]
        if types['is_animate']:
            organized['animals'].append(cat)
        elif types['is_bodypart']:
            organized['bodyparts'].append(cat)
        elif types['is_big']:
            organized['big_objects'].append(cat)
        elif types['is_small']:
            organized['small_objects'].append(cat)
        else:
            organized['others'].append(cat)
    
else:
    # Automatic organization by type (similar to notebook 02)
    print(f"  Organizing categories by type...")
    
    # Load category types for organization
    if cdi_path.exists():
        category_types = load_category_types(cdi_path)
    else:
        print(f"Warning: CDI path {cdi_path} not found. Skipping category organization.")
        category_types = {}
    
    # Get a representative set of embeddings for clustering (average across all subjects and ages)
    representative_embeddings = {}
    for cat in all_categories:
        cat_embeddings = []
        for subject_id, age_data in subject_age_embeddings_normalized.items():
            for age_mo, categories in age_data.items():
                if cat in categories:
                    cat_embeddings.append(categories[cat])
        if len(cat_embeddings) > 0:
            # Average across all subjects and ages for this category
            representative_embeddings[cat] = np.mean(cat_embeddings, axis=0)
    
    # Organize by type
    organized = {
        'animals': [],
        'bodyparts': [],
        'big_objects': [],
        'small_objects': [],
        'others': []
    }
    
    for cat in all_categories:
        if cat not in category_types:
            organized['others'].append(cat)
            continue
        
        types = category_types[cat]
        if types['is_animate']:
            organized['animals'].append(cat)
        elif types['is_bodypart']:
            organized['bodyparts'].append(cat)
        elif types['is_big']:
            organized['big_objects'].append(cat)
        elif types['is_small']:
            organized['small_objects'].append(cat)
        else:
            organized['others'].append(cat)
    
    print(f"  Organized into: {len(organized['animals'])} animals, {len(organized['bodyparts'])} bodyparts, "
          f"{len(organized['big_objects'])} big objects, {len(organized['small_objects'])} small objects, "
          f"{len(organized['others'])} others")
    
    # Apply hierarchical clustering if enabled
    if use_clustering:
        print(f"  Applying hierarchical clustering within groups...")
        for group_name in ['animals', 'bodyparts', 'big_objects', 'small_objects', 'others']:
            if len(organized[group_name]) > 1:
                # Filter to categories that have representative embeddings
                group_cats = [cat for cat in organized[group_name] if cat in representative_embeddings]
                if len(group_cats) > 1:
                    print(f"    Clustering {group_name} ({len(group_cats)} categories)...")
                    organized[group_name], _ = cluster_categories_within_group(
                        group_cats,
                        representative_embeddings,
                        save_dendrogram=save_dendrograms,
                        output_dir=output_dir,
                        group_name=group_name
                    )
                else:
                    organized[group_name] = group_cats
            else:
                organized[group_name] = [cat for cat in organized[group_name] if cat in representative_embeddings]
    else:
        for group_name in organized:
            organized[group_name] = sorted([cat for cat in organized[group_name] if cat in representative_embeddings])
    
    # Create ordered list
    ordered_categories = (
        organized['animals'] +
        organized['bodyparts'] +
        organized['big_objects'] +
        organized['small_objects'] +
        organized['others']
    )

print(f"\nFinal ordered category list: {len(ordered_categories)} categories")

Total unique categories across all subjects and ages: 163

Organizing categories...
  Loading predefined category order from ../vss-2026/bv_things_comp_12252025/bv_clip_filtered_zscored_hierarchical_163cats/category_order_reorganized.txt...
  Loaded 163 categories in predefined order
Loading category types from ../../data/cdi_words.csv...
Loaded type information for 295 categories

Final ordered category list: 163 categories


In [32]:
# No age binning needed - we're using median split (younger/older)

## Grouped Developmental Trajectory Visualization

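Combine every subject's younger and older RDMs into one figure (one row per subject, shared color scale). This cell expects RDMs aligned to the full `ordered_categories` order, which the RDM recomputation cell below produces, so run that cell first.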

In [33]:
# Create a grouped visualization combining N subjects' developmental trajectory plots.
# Expects every RDM to be aligned to the full `ordered_categories` order (NaN for
# missing categories), as produced by the RDM recomputation cell; run that cell first.
print("Creating grouped developmental trajectory visualization...")

BIN_NAMES = ['younger', 'older']
n_cats_total = len(ordered_categories)
expected_shape = (n_cats_total, n_cats_total)

# Load CDI category mapping for label coloring
cdi_category_map = load_cdi_category_mapping(cdi_path)

# If no masks have been built yet, mask the NaN (missing-category) cells
if 'subject_age_rdm_masks' not in globals() or subject_age_rdm_masks is None:
    subject_age_rdm_masks = {
        sid: {b: np.isnan(np.asarray(rdms[b], dtype=float)) for b in BIN_NAMES if b in rdms}
        for sid, rdms in subject_age_rdms.items()
    }

# Font size and label spacing depend only on the number of categories
if n_cats_total <= 50:
    tick_fontsize, tick_step = 8, 1
elif n_cats_total <= 100:
    tick_fontsize, tick_step = 6, 2
else:
    tick_fontsize, tick_step = 4, max(1, n_cats_total // 50)
tick_positions = list(range(0, n_cats_total, tick_step))
tick_labels = [ordered_categories[i] for i in tick_positions]


def plot_rdm_panel(ax, subject_id, bin_name, vmin, vmax):
    """Draw one subject/age-bin RDM on `ax`. Returns the image, or None if the panel is skipped."""
    rdm = np.asarray(subject_age_rdms[subject_id][bin_name], dtype=float)
    mask = np.asarray(subject_age_rdm_masks[subject_id][bin_name], dtype=bool)
    if rdm.shape != expected_shape or mask.shape != expected_shape:
        print(f"  Skipping {subject_id} {bin_name} due to incompatible shapes "
              f"(rdm {rdm.shape}, mask {mask.shape}, expected {expected_shape})")
        ax.axis('off')
        return None

    # Masked (missing-category) cells are drawn in white
    cmap = plt.get_cmap('viridis').copy()
    cmap.set_bad(color='white', alpha=1.0)
    im = ax.imshow(np.ma.masked_where(mask, rdm), cmap=cmap, aspect='auto', vmin=vmin, vmax=vmax)

    # Label every `tick_step`-th category and color each label by its CDI category
    ax.set_xticks(tick_positions)
    ax.set_yticks(tick_positions)
    ax.set_xticklabels(tick_labels, rotation=90, fontsize=tick_fontsize)
    ax.set_yticklabels(tick_labels, fontsize=tick_fontsize)
    for xlabel, ylabel, cat_name in zip(ax.get_xticklabels(), ax.get_yticklabels(), tick_labels):
        color = get_category_color(cat_name, cdi_category_map)
        xlabel.set_color(color)
        ylabel.set_color(color)

    n_cats_available = len(subject_age_rdm_categories[subject_id][bin_name])
    if bin_name == 'younger':
        age_label = f"Younger (≤{overall_median_age:.0f}mo)"
    else:
        age_label = f"Older (>{overall_median_age:.0f}mo)"
    ax.set_title(f"{subject_id} - {age_label}\n({n_cats_available}/{n_cats_total} cats)", fontsize=10, pad=5)
    return im


def plot_grouped_trajectories(subject_ids, title, output_filename, vmin, vmax):
    """Plot one row per subject (younger | older) on a shared color scale and save the figure."""
    n_rows = len(subject_ids)
    # squeeze=False keeps `axes` 2-D even when there is only one subject
    fig, axes = plt.subplots(n_rows, 2, figsize=(16, max(4, n_rows * 2.5)), squeeze=False)
    for row_idx, subject_id in enumerate(subject_ids):
        for col_idx, bin_name in enumerate(BIN_NAMES):
            im = plot_rdm_panel(axes[row_idx, col_idx], subject_id, bin_name, vmin, vmax)
            # Add a colorbar only to the right-hand column
            if im is not None and col_idx == 1:
                fig.colorbar(im, ax=axes[row_idx, col_idx], fraction=0.046, pad=0.04)
    fig.suptitle(title, fontsize=16, y=0.998, fontweight='bold')
    fig.tight_layout(rect=[0, 0, 1, 0.98])
    fig.savefig(output_dir / output_filename, dpi=200, bbox_inches='tight')
    plt.close(fig)
    print(f"Saved grouped visualization to {output_dir / output_filename}")


# Only subjects with both a younger and an older RDM
valid_subjects = [sid for sid, rdms in subject_age_rdms.items()
                  if all(b in rdms for b in BIN_NAMES)]

if len(valid_subjects) == 0:
    print("No subjects with valid data for grouped visualization")
else:
    # Set to an integer (e.g. 6, 9, 12) to limit the number of subjects; None plots all
    n_subjects_to_plot = None
    subjects_to_plot = valid_subjects[:n_subjects_to_plot]
    n_subjects = len(subjects_to_plot)
    print(f"Plotting {n_subjects} subjects in grouped visualization...")

    # Shared color scale: 1st-99th percentile of all non-NaN RDM values
    all_rdm_values = np.concatenate([
        np.asarray(subject_age_rdms[sid][b], dtype=float).ravel()
        for sid in subjects_to_plot for b in BIN_NAMES
    ])
    all_rdm_values = all_rdm_values[~np.isnan(all_rdm_values)]
    if all_rdm_values.size > 0:
        vmin, vmax = np.percentile(all_rdm_values, [1, 99])
    else:
        vmin, vmax = 0, 2

    median_note = f"(Median split at {overall_median_age:.1f} months)"
    plot_grouped_trajectories(subjects_to_plot,
                              f"Grouped Developmental Trajectories: {n_subjects} Subjects\n{median_note}",
                              f"grouped_trajectory_{n_subjects}_subjects.png", vmin, vmax)

    # With many subjects, also save a more readable version with the first 12 only
    if n_subjects > 12:
        print("\nCreating additional grouped visualization with first 12 subjects for better readability...")
        plot_grouped_trajectories(valid_subjects[:12],
                                  f"Grouped Developmental Trajectories: First 12 Subjects\n{median_note}",
                                  "grouped_trajectory_12_subjects.png", vmin, vmax)

print("\nGrouped visualization complete!")

Creating grouped developmental trajectory visualization...
Plotting 18 subjects in grouped visualization...
  Skipping 00320001 younger due to incompatible shapes
  Skipping 00320001 older due to incompatible shapes
  Skipping 00680001 younger due to incompatible shapes
  Skipping 00680001 older due to incompatible shapes
  Skipping 00320002 younger due to incompatible shapes
  Skipping 00320002 older due to incompatible shapes
  Skipping 00500001 younger due to incompatible shapes
  Skipping 00500001 older due to incompatible shapes
  Skipping 00400001 younger due to incompatible shapes
  Skipping 00400001 older due to incompatible shapes
  Skipping 00430001 younger due to incompatible shapes
  Skipping 00430001 older due to incompatible shapes
  Skipping 00560001 younger due to incompatible shapes
  Skipping 00560001 older due to incompatible shapes
  Skipping 00370001 younger due to incompatible shapes
  Skipping 00370001 older due to incompatible shapes
  Skipping 00400003 younger 

In [34]:
# Recompute RDMs using predefined category order with NaN for missing categories
print("\nRecomputing RDMs with predefined category order (including NaN for missing categories)...")
subject_age_rdms_reorganized = {}
subject_age_rdm_masks = {}  # Store masks indicating NA cells
subject_age_rdm_categories_reorganized = {}
subject_age_group_boundaries = {}  # Store group boundaries for visual separators

for subject_id in tqdm(subject_age_rdms.keys(), desc="Recomputing RDMs"):
    subject_age_rdms_reorganized[subject_id] = {}
    subject_age_rdm_masks[subject_id] = {}
    subject_age_rdm_categories_reorganized[subject_id] = {}
    subject_age_group_boundaries[subject_id] = {}
    
    # Get original data for this subject
    original_rdms = subject_age_rdms[subject_id]
    original_categories = subject_age_rdm_categories[subject_id]
    
    # Recompute each bin's RDM using predefined order
    for bin_name in ['younger', 'older']:
        if bin_name not in original_rdms:
            continue
        
        # Get aggregated embeddings for this bin
        if bin_name == 'younger':
            relevant_ages = {age_mo: cats for age_mo, cats in subject_age_embeddings_normalized[subject_id].items() 
                           if age_mo <= overall_median_age}
        else:  # older
            relevant_ages = {age_mo: cats for age_mo, cats in subject_age_embeddings_normalized[subject_id].items() 
                            if age_mo > overall_median_age}
        
        # Aggregate embeddings for this bin
        bin_embeddings = aggregate_embeddings_by_bin(relevant_ages, bin_name)
        
        # Compute RDM with NaN for missing categories using predefined order
        rdm, mask, available_cats = compute_rdm_for_bin_with_na(bin_embeddings, ordered_categories)
        
        if rdm is not None:
            subject_age_rdms_reorganized[subject_id][bin_name] = rdm
            subject_age_rdm_masks[subject_id][bin_name] = mask
            subject_age_rdm_categories_reorganized[subject_id][bin_name] = available_cats
            
            # Compute group boundaries (for visual separators) from the actual positions in
            # ordered_categories: a new segment starts whenever the group changes, so the
            # boundaries stay correct even if the predefined order is not laid out strictly
            # as animals, bodyparts, big_objects, small_objects, others
            cat_to_group = {cat: group for group, cats in organized.items() for cat in cats}
            group_boundaries = []
            for idx, cat in enumerate(ordered_categories):
                group_name = cat_to_group.get(cat, 'others')
                if group_boundaries and group_boundaries[-1]['name'] == group_name:
                    group_boundaries[-1]['end'] = idx + 1
                    group_boundaries[-1]['categories'].append(cat)
                else:
                    group_boundaries.append({
                        'name': group_name,
                        'start': idx,
                        'end': idx + 1,
                        'categories': [cat]
                    })
            
            subject_age_group_boundaries[subject_id][bin_name] = group_boundaries

# Update the main dictionaries
subject_age_rdms = subject_age_rdms_reorganized
subject_age_rdm_categories = subject_age_rdm_categories_reorganized

print(f"Recomputed RDMs for {len(subject_age_rdms)} subjects using predefined category order")
print(f"  All RDMs now use the same {len(ordered_categories)}-category order with NaN for missing categories")


Recomputing RDMs with predefined category order (including NaN for missing categories)...


Recomputing RDMs: 100%|██████████| 18/18 [00:00<00:00, 64.49it/s]

Recomputed RDMs for 18 subjects using predefined category order
  All RDMs now use the same 163-category order with NaN for missing categories





## Save RDMs for Each Subject (Younger and Older Bins)


In [35]:
# Save RDMs for each subject (younger and older bins)
print("Saving developmental trajectory RDMs...")

for subject_id, bin_rdms in tqdm(subject_age_rdms.items(), desc="Saving RDMs"):
    subject_output_dir = output_dir / subject_id
    subject_output_dir.mkdir(exist_ok=True, parents=True)
    
    for bin_name, rdm in bin_rdms.items():
        available_cats = subject_age_rdm_categories[subject_id][bin_name]
        
        # Save as numpy array (includes NaN for missing categories)
        np.save(subject_output_dir / f"rdm_{bin_name}.npy", rdm)
        
        # Save as CSV with category labels (use ordered_categories for full order)
        rdm_df = pd.DataFrame(rdm, index=ordered_categories, columns=ordered_categories)
        rdm_df.to_csv(subject_output_dir / f"rdm_{bin_name}.csv")
        
        # Save metadata.
        # Statistics use only valid (non-NaN) off-diagonal cells from the upper triangle,
        # so the zero diagonal and the duplicated lower triangle do not skew them.
        off_diagonal = rdm[np.triu_indices_from(rdm, k=1)]
        valid_rdm = off_diagonal[~np.isnan(off_diagonal)]
        has_values = valid_rdm.size > 0
        
        metadata = {
            'subject_id': subject_id,
            'age_bin': bin_name,
            'median_age_threshold': overall_median_age,
            'n_categories_total': len(ordered_categories),
            'n_categories_available': len(available_cats),
            'n_categories_missing': len(ordered_categories) - len(available_cats),
            'categories_available': available_cats,
            'mean_distance': float(valid_rdm.mean()) if has_values else np.nan,
            'std_distance': float(valid_rdm.std()) if has_values else np.nan,
            'min_distance': float(valid_rdm.min()) if has_values else np.nan,
            'max_distance': float(valid_rdm.max()) if has_values else np.nan
        }
        
        metadata_df = pd.DataFrame([metadata])
        metadata_df.to_csv(subject_output_dir / f"metadata_{bin_name}.csv", index=False)

        # Create and save an individual dendrogram for this bin (available categories only)
        if len(available_cats) > 1:
            # Rebuild per-category embeddings for this bin from the normalized embeddings,
            # averaging over every age that falls in the bin
            if bin_name == 'younger':
                relevant_ages = {age_mo: cats for age_mo, cats in subject_age_embeddings_normalized[subject_id].items()
                                 if age_mo <= overall_median_age}
            else:  # older
                relevant_ages = {age_mo: cats for age_mo, cats in subject_age_embeddings_normalized[subject_id].items()
                                 if age_mo > overall_median_age}

            bin_embeddings = {}
            for cat in available_cats:
                cat_embeddings = [age_cats[cat] for age_cats in relevant_ages.values() if cat in age_cats]
                if len(cat_embeddings) > 0:
                    bin_embeddings[cat] = np.mean(cat_embeddings, axis=0)

            # Keep only categories with an embedding so matrix rows and labels stay aligned
            dendro_cats = [cat for cat in available_cats if cat in bin_embeddings]
            if len(dendro_cats) > 1:
                embedding_matrix = np.array([bin_embeddings[cat].flatten() for cat in dendro_cats])

                # z-score each embedding dimension across categories
                normalized_embeddings = (embedding_matrix - embedding_matrix.mean(axis=0)) / (embedding_matrix.std(axis=0) + 1e-10)

                # Cosine distance matrix, made exactly symmetric and non-negative for squareform()
                distance_matrix = 1 - cosine_similarity(normalized_embeddings)
                distance_matrix = np.clip((distance_matrix + distance_matrix.T) / 2, 0, None)
                np.fill_diagonal(distance_matrix, 0)

                # Ward linkage is only valid for Euclidean distances. For unit-length vectors,
                # Euclidean distance = sqrt(2 * cosine distance), so cluster on that instead.
                condensed_distances = squareform(np.sqrt(2 * distance_matrix))
                linkage_matrix = linkage(condensed_distances, method='ward')

                # Reorder leaves so that adjacent leaves are as similar as possible
                try:
                    linkage_matrix = optimal_leaf_ordering(linkage_matrix, condensed_distances)
                except Exception as e:
                    print(f"  Optimal leaf ordering skipped for {subject_id} {bin_name}: {e}")

                # Title shows only this bin's age range
                age_range = (f"≤{overall_median_age:.0f}mo" if bin_name == 'younger'
                             else f">{overall_median_age:.0f}mo")
                plt.figure(figsize=(max(16, len(dendro_cats) * 0.5), 10))
                dendrogram(linkage_matrix,
                           labels=dendro_cats,
                           leaf_rotation=90,
                           leaf_font_size=max(8, min(14, 200 // len(dendro_cats))))
                plt.title(f'Dendrogram: {subject_id} {bin_name.capitalize()} ({age_range})\n'
                          f'({len(dendro_cats)}/{len(ordered_categories)} categories)',
                          fontsize=16, pad=20)
                plt.xlabel('Category', fontsize=14)
                plt.ylabel('Distance', fontsize=14)
                plt.tight_layout()

                # Save dendrogram
                dendrogram_dir = subject_output_dir / "dendrograms"
                dendrogram_dir.mkdir(exist_ok=True, parents=True)
                dendrogram_path = dendrogram_dir / f"dendrogram_{bin_name}.png"
                plt.savefig(dendrogram_path, dpi=300, bbox_inches='tight', pad_inches=0.2)
                plt.close()

print(f"\nSaved RDMs to {output_dir}")


Saving developmental trajectory RDMs...


Saving RDMs: 100%|██████████| 18/18 [01:29<00:00,  4.96s/it]


Saved RDMs to developmental_trajectory_rdms_clip





## Analyze Developmental Trajectories


## Detailed Explanation: RDM Correlation Logic

### Overview
This section explains in detail how we compute correlations between younger and older RDMs for each subject, including how we handle missing categories (NaN values) and whether correlations are comparable across subjects.

### RDM Structure
Each subject has two RDMs (younger and older), both with shape (163, 163) corresponding to the full predefined category order:
- **Diagonal elements**: 0 for every category present in the bin (distance from a category to itself)
- **Off-diagonal elements**: Cosine distance values (range 0 to 2) for category pairs that exist in that age bin
- **Missing categories**: Represented as NaN across the category's row and column (white cells in the visualization)
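A minimal sketch of how such a full-order RDM can be assembled, assuming each bin's embeddings are stored as a `{category: vector}` dict. The helper name `rdm_with_nan` is hypothetical (the notebook's own `compute_rdm_for_bin_with_na` is defined elsewhere and may differ), and it assumes a missing category gets NaN across its whole row and column, diagonal included.

```python
import numpy as np

def rdm_with_nan(bin_embeddings, ordered_categories):
    """Cosine-distance RDM in a fixed category order, with NaN rows/columns for missing categories.

    Hypothetical helper for illustration; `bin_embeddings` maps category -> 1-D vector.
    """
    n = len(ordered_categories)
    rdm = np.full((n, n), np.nan)
    present = [i for i, cat in enumerate(ordered_categories) if cat in bin_embeddings]
    if not present:
        return rdm
    vectors = np.array([np.ravel(bin_embeddings[ordered_categories[i]]) for i in present], dtype=float)
    unit = vectors / np.linalg.norm(vectors, axis=1, keepdims=True)
    # Cosine distance = 1 - cosine similarity, which lies in [0, 2]
    dist = np.clip(1.0 - unit @ unit.T, 0.0, 2.0)
    np.fill_diagonal(dist, 0.0)
    # Write the present-category block into the full-size matrix; everything else stays NaN
    rdm[np.ix_(present, present)] = dist
    return rdm

# Tiny example: 'ball' was not observed in this age bin
emb = {"dog": [1.0, 0.0], "cat": [0.0, 1.0], "hand": [-1.0, 0.0]}
rdm = rdm_with_nan(emb, ["dog", "cat", "ball", "hand"])
print(np.round(rdm, 2))
```

Because every bin uses the same `ordered_categories`, cell `[i, j]` refers to the same category pair in every subject's younger and older RDM, which is what makes the later submatrix extraction straightforward.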

### Step-by-Step Correlation Process

#### Step 1: Identify Common Categories
- **Input**: Two lists of available categories (`available_cats1` for younger, `available_cats2` for older) and the full `ordered_categories_list` (predefined order)
- **Process**: Find categories that are in BOTH available lists, preserving the predefined order (NOT alphabetical)
- **Output**: `common_categories` - categories present in BOTH age bins, in predefined order
- **Example**: If younger has 150 categories and older has 155 categories, they might share 140 categories
- **Key Point**: Order matters! We use the predefined order to ensure submatrices are aligned correctly

#### Step 2: Map to Full Category Order
- **Input**: `common_categories` and the full `ordered_categories` list (163 categories)
- **Process**: Find the indices of common categories in the full ordered list
- **Output**: `common_indices` - positions in the 163x163 RDM matrices
- **Purpose**: This ensures we extract the correct submatrices from the full RDMs
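Steps 1 and 2 can be sketched as follows. The function name `common_categories_and_indices` is hypothetical and may not match the notebook's own helper; the key point is that the result is built by iterating over the predefined order rather than over a set, so the order is never alphabetical or arbitrary.

```python
def common_categories_and_indices(available_cats1, available_cats2, ordered_categories):
    """Return categories present in both bins (in the predefined order) and their RDM indices."""
    shared = set(available_cats1) & set(available_cats2)
    # Walk the predefined order so the output order is fixed and identical across subjects
    common_categories = [cat for cat in ordered_categories if cat in shared]
    position = {cat: i for i, cat in enumerate(ordered_categories)}
    common_indices = [position[cat] for cat in common_categories]
    return common_categories, common_indices

ordered = ["dog", "cat", "hand", "car", "ball"]
cats, idx = common_categories_and_indices(["ball", "dog", "car"], ["car", "dog", "hand", "ball"], ordered)
print(cats, idx)  # ['dog', 'car', 'ball'] [0, 3, 4]
```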

#### Step 3: Extract Submatrices
- **Input**: Full 163x163 RDMs and `common_indices`
- **Process**: Extract square submatrices using `rdm[np.ix_(common_indices, common_indices)]`
- **Output**: Two smaller square matrices (e.g., 140x140 if 140 common categories)
- **Key Point**: These submatrices contain ONLY the common categories, but may still have NaN if there are any data issues

#### Step 4: Extract Upper Triangle
- **Process**: Use a triangular mask to extract only the upper triangle (excluding diagonal)
- **Why**: RDMs are symmetric, so we only need half the values to avoid double-counting
- **Output**: Two flattened arrays of pairwise distances
- **Size**: If n common categories, we get n×(n-1)/2 distance values
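A minimal sketch of Steps 3 and 4, using a hypothetical helper `upper_triangle_values`: `np.ix_` selects the square common-category submatrix, and `np.triu_indices(..., k=1)` picks the strict upper triangle, so the diagonal and the mirrored lower triangle are both excluded.

```python
import numpy as np

def upper_triangle_values(rdm, common_indices):
    """Extract the common-category submatrix and return its strict upper triangle as a 1-D array."""
    sub = np.asarray(rdm)[np.ix_(common_indices, common_indices)]
    rows, cols = np.triu_indices(len(common_indices), k=1)  # k=1 excludes the diagonal
    return sub[rows, cols]

# Toy symmetric "RDM" with a zero diagonal
rdm = np.arange(25, dtype=float).reshape(5, 5)
rdm = (rdm + rdm.T) / 2
np.fill_diagonal(rdm, 0.0)

vals = upper_triangle_values(rdm, [0, 3, 4])
print(vals)  # 3 common categories -> 3 * 2 / 2 = 3 distance pairs
```

Both RDMs must be passed the same `common_indices`, so element `k` of the two returned arrays always refers to the same category pair.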

#### Step 5: Filter NaN Values
- **Process**: Create a boolean mask identifying valid (non-NaN) values in BOTH arrays
- **Filter**: Keep only pairs where BOTH RDMs have valid values
- **Output**: Two arrays of the same length with only valid distance pairs
- **Safety Check**: Even though we only use common categories, this ensures no NaN values slip through

#### Step 6: Compute Spearman Correlation
- **Method**: Spearman rank correlation (non-parametric, robust to outliers)
- **Input**: Two arrays of valid distance values (same length, same category pairs)
- **Output**: Correlation coefficient (-1 to 1)
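- **Interpretation** (rough guidelines rather than fixed cutoffs):
  - High correlation (>0.7): Similar representational structure across age bins
  - Moderate correlation (0.5 to 0.7): Structure partly preserved across age bins
  - Low correlation (<0.5): Representational structure changed with development, or one or both RDMs are noisy because of sparse data
  - Near 0: No relationship between structures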
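Steps 5 and 6 amount to a joint NaN filter followed by `scipy.stats.spearmanr`. The sketch below uses a hypothetical helper `spearman_on_valid_pairs`; the 3-pair minimum is an arbitrary guard added for illustration, not a value taken from the notebook.

```python
import numpy as np
from scipy.stats import spearmanr

def spearman_on_valid_pairs(values1, values2, min_pairs=3):
    """Drop pairs where either RDM has NaN, then compute Spearman's rho on the remaining pairs."""
    values1 = np.asarray(values1, dtype=float)
    values2 = np.asarray(values2, dtype=float)
    # A pair is kept only if BOTH RDMs have a valid value for it
    valid = ~np.isnan(values1) & ~np.isnan(values2)
    n_valid = int(valid.sum())
    if n_valid < min_pairs:  # arbitrary guard against near-empty comparisons
        return np.nan, n_valid
    rho = spearmanr(values1[valid], values2[valid])[0]
    return float(rho), n_valid

younger = [0.2, 0.5, np.nan, 0.9, 0.4]
older = [0.3, 0.6, 0.7, 1.1, np.nan]
rho, n_pairs = spearman_on_valid_pairs(younger, older)
print(rho, n_pairs)
```

Here the third and fifth pairs are dropped (NaN in one of the two RDMs), and the three remaining pairs have identical rank order, so rho is 1.0.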

### Handling NaN Values

**Where NaN values come from:**
1. Categories not present in a particular age bin (expected)
2. Categories present but with insufficient data (rare, but possible)

**How we handle them:**
1. **Pre-filtering**: We only use categories present in BOTH bins (common categories)
2. **Submatrix extraction**: We extract only the common category submatrices
3. **Post-filtering**: We filter out any remaining NaN values before correlation
4. **Result**: The correlation is computed only on valid distance pairs

**Why this works:**
- By using only common categories, we ensure we're comparing the same category pairs
- The correlation reflects how similarly those common categories are organized in younger vs older periods
- Missing categories don't affect the correlation (they're simply excluded)
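The claim that missing categories are simply excluded can be checked on synthetic data. The sketch below uses a hypothetical `rdm_spearman` helper that strings Steps 1 to 6 together (it is not the notebook's own function) and shows that blanking one category in the older RDM gives exactly the correlation computed on the remaining categories.

```python
import numpy as np
from scipy.stats import spearmanr

def rdm_spearman(rdm1, rdm2, cats1, cats2, ordered_categories):
    """Spearman's rho between two full-order RDMs, using only categories present in both bins."""
    shared = set(cats1) & set(cats2)
    idx = [i for i, cat in enumerate(ordered_categories) if cat in shared]
    rows, cols = np.triu_indices(len(idx), k=1)
    v1 = np.asarray(rdm1)[np.ix_(idx, idx)][rows, cols]
    v2 = np.asarray(rdm2)[np.ix_(idx, idx)][rows, cols]
    valid = ~np.isnan(v1) & ~np.isnan(v2)
    return float(spearmanr(v1[valid], v2[valid])[0])

# Synthetic 6-category RDMs: "older" is a noisy copy of "younger"
rng = np.random.default_rng(0)
points = rng.normal(size=(6, 3))
younger = np.linalg.norm(points[:, None] - points[None], axis=-1)  # symmetric, zero diagonal
older = younger + rng.normal(scale=0.1, size=younger.shape)
older = (older + older.T) / 2
np.fill_diagonal(older, 0.0)
ordered = [f"c{i}" for i in range(6)]

rho_full = rdm_spearman(younger, older, ordered, ordered, ordered)

# Pretend c5 was never observed in the older bin: its row and column become NaN
older_missing = older.copy()
older_missing[5, :] = np.nan
older_missing[:, 5] = np.nan
rho_missing = rdm_spearman(younger, older_missing, ordered, ordered[:5], ordered)

# Reference: the same comparison restricted to the first five categories from the start
rho_sub = rdm_spearman(younger[:5, :5], older[:5, :5], ordered[:5], ordered[:5], ordered[:5])
print(rho_full, rho_missing, rho_sub)
```

`rho_missing` equals `rho_sub` but generally not `rho_full`: the missing category does not distort the value, yet it does change which pairs are compared.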

### Comparability Across Subjects

**Are correlations comparable across subjects?**

**YES, with important caveats:**

1. **Same correlation metric**: All subjects use Spearman correlation on the same type of data (distance matrices)

2. **Different category sets**: Each subject may have different numbers of common categories:
   - Subject A: 140 common categories → 9,730 distance pairs
   - Subject B: 150 common categories → 11,175 distance pairs
   - Subject C: 130 common categories → 8,385 distance pairs

3. **Interpretation considerations**:
   - **Correlation values are on the same scale**: A correlation of 0.8 indicates strong rank agreement between age bins for any subject, although each subject's value is computed over that subject's own set of common categories
   - **Statistical power varies**: Subjects with more common categories contribute more distance pairs, so their correlations are likely to be more reliable
   - **Missing categories are excluded, not imputed**: They never enter the calculation, but because they change *which* pairs are compared, each correlation describes only the categories that subject has in both bins

4. **What makes correlations comparable**:
   - Same age split (median = 16 months for all)
   - Same normalization (within-subject z-score normalization)
   - Same distance metric (cosine distance)
   - Same correlation method (Spearman)
   - Only common categories used (fair comparison)

5. **What to consider when comparing**:
   - **Number of common categories**: Tracked in `n_common_categories` - more categories = more reliable
   - **Category composition**: Different subjects may have different sets of common categories
   - **Data density**: Subjects with more data in both bins may have more stable RDMs
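The pair counts quoted above follow from n(n-1)/2 unique off-diagonal pairs for n common categories. A quick check (the function name `count_pairs` is only illustrative):

```python
def count_pairs(n_categories):
    """Number of unique off-diagonal category pairs: n * (n - 1) / 2."""
    return n_categories * (n_categories - 1) // 2

for n in (130, 135, 140, 150):
    print(n, "common categories ->", count_pairs(n), "distance pairs")
```

Because every category takes part in n - 1 pairs, these pairs are not independent observations, so the pair count overstates the effective sample size behind each correlation.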

### Example Walkthrough

**Subject 00240001:**
- Younger bin: 155 categories available
- Older bin: 160 categories available
- Common categories: 150 categories
- Extracted submatrices: 150×150 (22,500 cells)
- Upper triangle: 11,175 distance pairs
- After NaN filtering: 11,175 valid pairs (assuming all common categories have data)
- Spearman correlation: 0.756
- **Interpretation**: Strong similarity (0.756) between younger and older representational structures, based on 150 common categories

**Subject 00320001:**
- Younger bin: 140 categories available
- Older bin: 145 categories available
- Common categories: 135 categories
- Extracted submatrices: 135×135 (18,225 cells)
- Upper triangle: 9,045 distance pairs
- After NaN filtering: 9,045 valid pairs
- Spearman correlation: 0.682
- **Interpretation**: Moderate similarity (0.682) between age bins, based on 135 common categories

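**Comparison**: Subject 00240001 has a higher correlation (0.756 vs 0.682), suggesting a more stable representational structure across development. However, Subject 00240001 also has more common categories (150 vs 135), so its correlation is computed over more distance pairs (11,175 vs 9,045). Because each category contributes to many pairs, these pairs are not independent, and the pair count overstates the effective amount of data behind each correlation.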

### Summary

The correlation logic:
1. ✅ Uses only categories present in BOTH age bins (fair comparison)
2. ✅ Filters out all NaN values before correlation
3. ✅ Uses Spearman correlation (robust, non-parametric)
4. ✅ Produces comparable values across subjects
5. ⚠️ But correlations should be interpreted with awareness of the number of common categories

**Key insight**: The correlation tells us how similarly categories are organized in younger vs older periods, but only for the categories that exist in both periods. This is appropriate for developmental trajectory analysis because we want to know: "For the categories this child experienced at both ages, how stable was their representational structure?"

## Null Model: Age-Shuffled Baseline

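To tell whether the within-kid correlations reflect developmental change or just stable individual differences, we implement a null model by randomly shuffling age labels within each subject. Both the real and the shuffled splits compare two halves of the same child's data, so stable individual differences are present in both; only the real split also separates early from late timepoints.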

### What the Null Model Tests

**Real Split (Median Age)**: Uses actual chronological ages → captures both:
- **Developmental changes** (how representations change with age)
- **Stable individual differences** (consistent patterns unique to each child)

**Null Model (Age-Shuffled)**: Randomly shuffles age labels within each subject → destroys temporal structure but preserves:
- **Stable individual differences** (child-specific patterns remain)
- **No developmental signal** (age information is scrambled)

### How to Interpret Results

**Scenario 1: Real < Null (Negative Difference)**
- **Interpretation**: Age-related change signal
- **Meaning**: The chronological halves of a child's data are less similar to each other than random halves. Because shuffling mixes early and late timepoints into both bins, a lower real correlation means representations differ systematically between the younger and older periods.
- **Example**: Real = 0.50, Null = 0.75 → Difference = -0.25 indicates strong age-related structure

**Scenario 2: Real ≈ Null (Small Difference)**
- **Interpretation**: Weak or no developmental signal
- **Meaning**: Splitting by age is no different from splitting at random, so high correlations reflect structure that is stable across ages (including stable individual differences), not developmental change.
- **Example**: Real = 0.75, Null = 0.72 → Difference = 0.03 suggests mostly stable individual patterns

**Scenario 3: Real > Null (Positive Difference)**
- **Interpretation**: Unusual pattern (rare)
- **Meaning**: The chronological halves are more similar than random halves, which a simple age-related drift would not produce. Check for artifacts, such as differences in which categories survive in each shuffled bin.
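A toy simulation makes the logic of the shuffle concrete. It assumes, hypothetically, that each session's vectorized RDM is a stable child-specific component plus an age-dependent drift plus noise. Shuffling age labels places early and late sessions in both bins, so with drift present the shuffled halves resemble each other more than the chronological halves do; with no drift, the real and shuffled splits agree. The function `simulate_split` and all parameter values are illustrative assumptions, not fitted to this dataset.

```python
import numpy as np
from scipy.stats import spearmanr


def simulate_split(drift_scale, n_perm=200, seed=42):
    """Compare a median-age split with age-shuffled splits on toy data.

    Hypothetical model: each session's vectorized RDM is
    stable + drift_scale * (age - median) * drift + noise.
    Returns (real_correlation, array_of_null_correlations).
    """
    rng = np.random.default_rng(seed)
    ages = np.arange(8, 26, 2)                  # nine sessions, 8..24 months
    median_age = 16
    n_pairs = 300                               # length of an RDM's upper triangle
    stable = rng.normal(size=n_pairs)           # structure shared by all sessions
    drift = rng.normal(size=n_pairs)            # direction of age-related change
    sessions = {
        int(a): stable + drift_scale * (a - median_age) * drift + 0.5 * rng.normal(size=n_pairs)
        for a in ages
    }

    def split_corr(age_to_vec):
        # Average sessions within each bin, then correlate the two bin averages
        young = np.mean([v for a, v in age_to_vec.items() if a <= median_age], axis=0)
        old = np.mean([v for a, v in age_to_vec.items() if a > median_age], axis=0)
        return spearmanr(young, old)[0]

    real = split_corr(sessions)
    null = []
    for _ in range(n_perm):
        labels = rng.permutation(list(sessions))  # shuffle age labels within the child
        shuffled = {int(new): sessions[orig] for orig, new in zip(sessions, labels)}
        null.append(split_corr(shuffled))
    return real, np.array(null)


for scale in (0.0, 0.08):
    real, null = simulate_split(scale)
    print(f"drift={scale}: real={real:.3f}, null mean={null.mean():.3f}")
```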

### Key Metrics to Check

1. **Mean Difference (Real - Null)**: 
   - Large negative (< -0.1): Strong age-related change
   - Small negative (-0.1 to -0.01): Weak age-related change
   - Near zero (within ±0.01): No developmental signal, just stable individual differences
   - Positive: Unexpected; check for artifacts in how categories are distributed across bins

2. **Effect Size (Cohen's d)**, judged by magnitude (a negative d means real < null):
   - |d| > 0.8: Large effect
   - 0.5 < |d| < 0.8: Medium effect
   - |d| < 0.5: Small effect

3. **Proportion of Subjects with Real > Null** (about 50% expected if age has no effect):
   - <10%: Consistent age-related change across subjects
   - 10-50%: Mixed pattern
   - >50%: Unexpected (the chronological split makes bins more similar than random splits do)

4. **Statistical Significance (p-value)**, ideally computed on per-subject differences, because permutations within a subject are not independent observations:
   - p < 0.05: Statistically significant difference
   - p > 0.05: No significant difference (null model cannot be rejected)
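Because each subject contributes one real correlation and one null mean, these metrics can be computed on paired per-subject differences instead of pooled correlations. The helper `paired_real_vs_null` below is a hypothetical sketch: it assumes two arrays aligned by subject (like the `real_correlation` and `null_mean_correlation` columns of `comparison_df`) and reports Cohen's d_z (mean difference divided by the SD of the differences), the paired analogue of a pooled-SD d.

```python
import numpy as np
from scipy import stats


def paired_real_vs_null(real, null_mean):
    """Per-subject (paired) summary of real vs null RDM correlations.

    Sketch: real[i] is subject i's real correlation and null_mean[i] is the mean
    of that subject's permutation correlations.
    """
    real = np.asarray(real, dtype=float)
    null_mean = np.asarray(null_mean, dtype=float)
    diff = real - null_mean
    n_pos = int(np.sum(diff > 0))
    n_nonzero = int(np.sum(diff != 0))
    return {
        "mean_difference": diff.mean(),
        "cohens_dz": diff.mean() / diff.std(ddof=1),           # paired effect size
        "prop_real_gt_null": n_pos / len(diff),
        "paired_t_p": stats.ttest_rel(real, null_mean).pvalue,
        "wilcoxon_p": stats.wilcoxon(diff).pvalue,             # signed-rank test on differences
        "sign_test_p": stats.binomtest(n_pos, n_nonzero, 0.5).pvalue,
    }


# Illustration with synthetic per-subject values (not the notebook's data)
rng = np.random.default_rng(1)
null_mean = rng.uniform(0.65, 0.85, size=18)
real = null_mean - 0.026 + rng.normal(0, 0.01, size=18)
print(paired_real_vs_null(real, null_mean))
```

Applied only to the reported count (3 of 18 subjects with real > null), a two-sided sign test, `stats.binomtest(3, 18, 0.5).pvalue`, is about 0.0075.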

In [54]:
# Null Model: Randomly shuffle age labels per subject
print("Computing null model with age-shuffled splits...")
print("This will help interpret whether correlations reflect developmental stability or just stable individual differences.\n")

n_permutations = 100  # Number of random shuffles per subject
null_model_data = []

# Set random seed for reproducibility
np.random.seed(42)

for subject_id, age_data in tqdm(subject_age_rdms.items(), desc="Null model"):
    if 'younger' not in subject_age_rdms[subject_id] or 'older' not in subject_age_rdms[subject_id]:
        continue
    
    # Get original age data for this subject
    original_age_data = subject_age_embeddings_normalized[subject_id]
    original_ages = list(original_age_data.keys())
    
    # Store correlations for this subject across permutations
    subject_null_corrs = []
    
    for perm_idx in range(n_permutations):
        # Randomly shuffle age labels for this subject
        shuffled_ages = original_ages.copy()
        np.random.shuffle(shuffled_ages)
        
        # Create shuffled age mapping: original_age -> shuffled_age
        age_mapping = {orig_age: shuffled_age for orig_age, shuffled_age in zip(original_ages, shuffled_ages)}
        
        # Create shuffled age data structure
        shuffled_age_data = {}
        for orig_age, categories in original_age_data.items():
            shuffled_age = age_mapping[orig_age]
            shuffled_age_data[shuffled_age] = categories
        
        # Split into younger and older using the same median threshold
        shuffled_younger_ages = {age_mo: cats for age_mo, cats in shuffled_age_data.items() 
                                if age_mo <= overall_median_age}
        shuffled_older_ages = {age_mo: cats for age_mo, cats in shuffled_age_data.items() 
                              if age_mo > overall_median_age}
        
        # Aggregate embeddings for shuffled bins
        shuffled_younger_aggregated = aggregate_embeddings_by_bin(shuffled_younger_ages, 'younger')
        shuffled_older_aggregated = aggregate_embeddings_by_bin(shuffled_older_ages, 'older')
        
        # Check if we have enough categories in both bins
        if len(shuffled_younger_aggregated) < min_categories_per_age_bin or len(shuffled_older_aggregated) < min_categories_per_age_bin:
            continue
        
        # Compute RDMs for shuffled bins
        rdm_younger_shuffled, mask_younger, cats_younger_shuffled = compute_rdm_for_bin_with_na(
            shuffled_younger_aggregated, ordered_categories
        )
        rdm_older_shuffled, mask_older, cats_older_shuffled = compute_rdm_for_bin_with_na(
            shuffled_older_aggregated, ordered_categories
        )
        
        if rdm_younger_shuffled is None or rdm_older_shuffled is None:
            continue
        
        # Compute correlation between shuffled bins
        corr_shuffled, n_common_shuffled = compute_rdm_correlation(
            rdm_younger_shuffled, rdm_older_shuffled,
            ordered_categories,
            cats_younger_shuffled, cats_older_shuffled
        )
        
        if not np.isnan(corr_shuffled):
            subject_null_corrs.append(corr_shuffled)
            null_model_data.append({
                'subject_id': subject_id,
                'permutation': perm_idx,
                'correlation': corr_shuffled,
                'n_common_categories': n_common_shuffled
            })
    
    # Store summary statistics for this subject
    if len(subject_null_corrs) > 0:
        null_model_data.append({
            'subject_id': subject_id,
            'permutation': 'mean',
            'correlation': np.mean(subject_null_corrs),
            'n_common_categories': np.nan  # Not applicable for mean
        })

null_model_df = pd.DataFrame(null_model_data)
null_model_df.to_csv(output_dir / "null_model_age_shuffled_correlations.csv", index=False)

# Compute summary statistics
print(f"\nNull model analysis:")
print(f"  Total permutations: {len(null_model_df[null_model_df['permutation'] != 'mean'])}")
print(f"  Subjects analyzed: {len(null_model_df[null_model_df['permutation'] == 'mean'])}")

# Compare real vs null correlations
real_corrs = trajectory_df['rdm_correlation'].dropna().values
null_corrs = null_model_df[null_model_df['permutation'] != 'mean']['correlation'].dropna().values

if len(null_corrs) > 0:
    print(f"\nReal vs Null correlations:")
    print(f"  Real correlation mean: {np.mean(real_corrs):.3f} ± {np.std(real_corrs):.3f}")
    print(f"  Null correlation mean: {np.mean(null_corrs):.3f} ± {np.std(null_corrs):.3f}")
    print(f"  Difference: {np.mean(real_corrs) - np.mean(null_corrs):.3f}")
    
    # Compute effect size (Cohen's d)
    pooled_std = np.sqrt((np.var(real_corrs) + np.var(null_corrs)) / 2)
    if pooled_std > 0:
        cohens_d = (np.mean(real_corrs) - np.mean(null_corrs)) / pooled_std
        print(f"  Cohen's d (effect size): {cohens_d:.3f}")
    
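    # Statistical test. Caveat: this pools every permutation correlation as an independent
    # sample; the per-subject differences (comparison_df below) are the more appropriate unit.
    from scipy import stats
    t_stat, p_value = stats.ttest_ind(real_corrs, null_corrs)
    print(f"  t-test: t={t_stat:.3f}, p={p_value:.2e}")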
    
    # Per-subject comparison
    print(f"\nPer-subject comparison (Real vs Null mean):")
    subject_null_means = null_model_df[null_model_df['permutation'] == 'mean'].set_index('subject_id')['correlation']
    comparison_data = []
    for subject_id in trajectory_df['subject_id']:
        real_corr = trajectory_df[trajectory_df['subject_id'] == subject_id]['rdm_correlation'].values[0]
        if subject_id in subject_null_means.index:
            null_mean = subject_null_means[subject_id]
            diff = real_corr - null_mean
            comparison_data.append({
                'subject_id': subject_id,
                'real_correlation': real_corr,
                'null_mean_correlation': null_mean,
                'difference': diff
            })
    
    comparison_df = pd.DataFrame(comparison_data)
    comparison_df.to_csv(output_dir / "real_vs_null_comparison.csv", index=False)
    
    print(f"  Mean difference per subject: {comparison_df['difference'].mean():.3f} ± {comparison_df['difference'].std():.3f}")
    print(f"  Subjects with real > null: {len(comparison_df[comparison_df['difference'] > 0])} / {len(comparison_df)}")
    
    print(f"\nSaved null model results to {output_dir / 'null_model_age_shuffled_correlations.csv'}")
    print(f"Saved comparison to {output_dir / 'real_vs_null_comparison.csv'}")
else:
    print("  Warning: No valid null correlations computed")

Computing null model with age-shuffled splits...
This will help interpret whether correlations reflect developmental stability or just stable individual differences.



Null model: 100%|██████████| 18/18 [00:30<00:00,  1.69s/it]


Null model analysis:
  Total permutations: 1800
  Subjects analyzed: 18

Real vs Null correlations:
  Real correlation mean: 0.756 ± 0.071
  Null correlation mean: 0.782 ± 0.066
  Difference: -0.026
  Cohen's d (effect size): -0.383
  t-test: t=-1.682, p=9.28e-02

Per-subject comparison (Real vs Null mean):
  Mean difference per subject: -0.026 ± 0.023
  Subjects with real > null: 3 / 18

Saved null model results to developmental_trajectory_rdms_clip/null_model_age_shuffled_correlations.csv
Saved comparison to developmental_trajectory_rdms_clip/real_vs_null_comparison.csv
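The pooled t-test above treats each of the 1,800 permutation correlations as an independent observation. A per-subject alternative compares each child's real correlation with that child's own null distribution. The helper `per_subject_permutation_p` is a hypothetical sketch that assumes a DataFrame shaped like `null_model_df` (columns `subject_id`, `permutation`, `correlation`, plus the `'mean'` summary rows) and a dict of real correlations per subject.

```python
import numpy as np
import pandas as pd


def per_subject_permutation_p(null_df, real_by_subject):
    """One-sided empirical p-value per subject: how often null <= real.

    A small p means the chronological split is less similar than almost all
    age-shuffled splits of that subject's data.
    """
    perms = null_df[null_df["permutation"] != "mean"]   # drop per-subject summary rows
    rows = []
    for sid, group in perms.groupby("subject_id"):
        if sid not in real_by_subject:
            continue
        null = group["correlation"].to_numpy()
        real = real_by_subject[sid]
        # +1 in numerator and denominator counts the observed split as one permutation
        p = (np.sum(null <= real) + 1) / (len(null) + 1)
        rows.append({"subject_id": sid, "real": real, "n_perm": len(null), "p_lower": p})
    return pd.DataFrame(rows)


# Usage in the notebook would look like:
# per_subject_permutation_p(null_model_df,
#     trajectory_df.set_index('subject_id')['rdm_correlation'].to_dict())
```

With 100 permutations per subject, the smallest attainable p-value is 1/101, so more permutations would be needed to resolve very small p-values.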





In [52]:
# Visualize real vs null model correlations
print("Creating visualization for real vs null model correlations...")

if len(null_corrs) > 0 and len(real_corrs) > 0:
    fig, axes = plt.subplots(2, 2, figsize=(14, 10))
    
    # 1. Distribution comparison
    ax1 = axes[0, 0]
    ax1.hist(real_corrs, bins=20, alpha=0.7, label='Real (median split)', color='#4ECDC4', edgecolor='black')
    ax1.hist(null_corrs, bins=20, alpha=0.7, label='Null (age-shuffled)', color='#FF6B6B', edgecolor='black')
    ax1.axvline(np.mean(real_corrs), color='#4ECDC4', linestyle='--', linewidth=2, label=f'Real mean: {np.mean(real_corrs):.3f}')
    ax1.axvline(np.mean(null_corrs), color='#FF6B6B', linestyle='--', linewidth=2, label=f'Null mean: {np.mean(null_corrs):.3f}')
    ax1.set_xlabel('RDM Correlation', fontsize=12)
    ax1.set_ylabel('Frequency', fontsize=12)
    ax1.set_title('Distribution: Real vs Null Correlations', fontsize=13, pad=10)
    ax1.legend()
    ax1.grid(True, alpha=0.3)
    
    # 2. Box plot comparison
    ax2 = axes[0, 1]
    box_data = [real_corrs, null_corrs]
    bp = ax2.boxplot(box_data, labels=['Real', 'Null'], patch_artist=True)
    colors_box = ['#4ECDC4', '#FF6B6B']
    for patch, color in zip(bp['boxes'], colors_box):
        patch.set_facecolor(color)
        patch.set_alpha(0.7)
    ax2.set_ylabel('RDM Correlation', fontsize=12)
    ax2.set_title('Box Plot: Real vs Null', fontsize=13, pad=10)
    ax2.grid(True, alpha=0.3, axis='y')
    ax2.set_ylim([0, 1])
    
    # 3. Per-subject comparison
    if len(comparison_df) > 0:
        ax3 = axes[1, 0]
        x_pos = np.arange(len(comparison_df))
        width = 0.35
        ax3.bar(x_pos - width/2, comparison_df['real_correlation'], width, 
               label='Real', alpha=0.7, color='#4ECDC4', edgecolor='black')
        ax3.bar(x_pos + width/2, comparison_df['null_mean_correlation'], width,
               label='Null (mean)', alpha=0.7, color='#FF6B6B', edgecolor='black')
        ax3.set_xlabel('Subject', fontsize=12)
        ax3.set_ylabel('RDM Correlation', fontsize=12)
        ax3.set_title('Per-Subject: Real vs Null Mean', fontsize=13, pad=10)
        ax3.set_xticks(x_pos)
        ax3.set_xticklabels(comparison_df['subject_id'], rotation=45, ha='right', fontsize=8)
        ax3.legend()
        ax3.grid(True, alpha=0.3, axis='y')
        ax3.set_ylim([0, 1])
        
        # 4. Difference distribution
        ax4 = axes[1, 1]
        differences = comparison_df['difference']
        ax4.hist(differences, bins=20, alpha=0.7, color='#45B7D1', edgecolor='black')
        ax4.axvline(0, color='red', linestyle='--', linewidth=2, label='No difference')
        ax4.axvline(np.mean(differences), color='green', linestyle='--', linewidth=2, 
                   label=f'Mean: {np.mean(differences):.3f}')
        ax4.set_xlabel('Difference (Real - Null)', fontsize=12)
        ax4.set_ylabel('Frequency', fontsize=12)
        ax4.set_title('Distribution of Differences', fontsize=13, pad=10)
        ax4.legend()
        ax4.grid(True, alpha=0.3)
    
    plt.suptitle('Null Model: Real vs Age-Shuffled Correlations\n(Testing Developmental Stability vs Stable Individual Differences)', 
                 fontsize=14, y=0.995, fontweight='bold')
    plt.tight_layout(rect=[0, 0, 1, 0.98])
    plt.savefig(output_dir / "null_model_age_shuffled_comparison.png", dpi=200, bbox_inches='tight')
    print(f"Saved visualization to {output_dir / 'null_model_age_shuffled_comparison.png'}")
    plt.close()
    
    print("\nVisualization complete!")
else:
    print("  Warning: Insufficient data for visualization")

Creating visualization for real vs null model correlations...
Saved visualization to developmental_trajectory_rdms_clip/null_model_age_shuffled_comparison.png

Visualization complete!


## Interpreting Null Model Results

**Your Results:**
- Real correlation: 0.756 ± 0.071
- Null correlation: 0.782 ± 0.066
- Difference: -0.026 (negative!)
- Only 3/18 subjects show real > null

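### What This Means

**The negative difference indicates that the median age split separates each child's data more than a random split does.** Specifically:

1. **Random age shuffling produces HIGHER correlations** (0.782) because each shuffled bin mixes early and late timepoints, which makes the two bins more alike. The chronological split places early and late timepoints in different bins, so anything that changes with age lowers the real correlation.
2. **The age-related component is small but consistent**: only 3 of 18 subjects show real > null, and the mean per-subject difference is -0.026 ± 0.023.
3. **High absolute correlations (0.756 real, 0.782 null) reflect structure that is stable across ages.** Since both splits compare two halves of the same child's data, the null model alone cannot say whether this stable structure is child-specific or shared across children.
4. **The pooled t-test (p = 0.093) is not the right test here.** It treats all 1,800 permutation correlations as independent observations alongside 18 real ones; the subject is the appropriate unit, so a paired comparison of each subject's real correlation with its own null mean is better suited.

### Possible Explanations

1. **Developmental change**: Category relationships shift modestly between the younger and older periods, on top of a large stable component
2. **Temporal proximity**: Sessions recorded close in time share environments, objects, and recording conditions, so part of the age effect may reflect non-developmental drift
3. **Median split may not align with developmental transitions**: The 16-month median might not correspond to a meaningful developmental boundary, which could make the effect look smaller than it is
4. **Age range may be too narrow**: If each subject's data span only a few months, a split can expose little change

### What This Means for Your Analysis

- **Within-kid correlations (0.756) are meaningful** and are dominated by structure that is stable across ages, with a small age-related component on top
- **Cross-kid correlations** (comparing different children) are still valid and can show whether the stable structure is child-specific
- **The median age split** does pick up age-related change, but a binary split is a coarse tool for characterizing it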

### Potential Next Steps

1. **Examine age distribution**: Check if ages are evenly distributed or clustered
2. **Try alternative splits**: 
   - Tertiles or quartiles instead of median
   - Age-based thresholds (e.g., 12mo, 18mo, 24mo)
   - Equal-sized bins rather than median-based
3. **Analyze age effects directly**: Correlate RDMs with continuous age rather than binary splits
4. **Characterize the stable component**: The high correlations show that most structure is stable across ages; comparing within-kid with between-kid correlations would show how much of it is child-specific
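Analyzing age effects with continuous age, one of the next steps above, avoids choosing a split point. A minimal per-subject sketch, assuming each session (age in months) already has a vectorized RDM over a shared category set: correlate every pair of sessions' RDMs, then ask whether similarity falls as the age gap grows. The function `similarity_vs_age_gap` and the toy data are hypothetical.

```python
import itertools

import numpy as np
from scipy.stats import spearmanr


def similarity_vs_age_gap(rdm_vectors):
    """Relate session-to-session RDM similarity to the age gap between sessions.

    rdm_vectors maps age (months) -> vectorized RDM (upper triangle over a
    shared category set, NaN where missing). Returns (rho, age_gaps, similarities);
    rho < 0 means sessions further apart in age are less similar.
    """
    gaps, sims = [], []
    for (age_i, v_i), (age_j, v_j) in itertools.combinations(sorted(rdm_vectors.items()), 2):
        valid = ~(np.isnan(v_i) | np.isnan(v_j))       # pairs present in both sessions
        if valid.sum() < 3:
            continue
        gaps.append(abs(age_j - age_i))
        sims.append(spearmanr(v_i[valid], v_j[valid])[0])
    rho = spearmanr(gaps, sims)[0] if len(gaps) >= 3 else np.nan
    return rho, np.array(gaps), np.array(sims)


# Toy subject whose structure drifts with age (hypothetical values)
rng = np.random.default_rng(7)
stable, drift = rng.normal(size=200), rng.normal(size=200)
vecs = {a: stable + 0.1 * (a - 16) * drift + 0.3 * rng.normal(size=200) for a in range(8, 26, 2)}
rho, gaps, sims = similarity_vs_age_gap(vecs)
print(f"similarity vs age gap: rho={rho:.3f} over {len(gaps)} session pairs")
```

Because session pairs share sessions, the resulting rho is not an independent-samples statistic; a within-subject permutation of age labels would be a natural way to test it.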

In [55]:
# Diagnostic: Investigate why null > real
print("="*70)
print("DIAGNOSTIC: Investigating Age Distribution and Split Characteristics")
print("="*70)

# 1. Check age distribution
print("\n1. AGE DISTRIBUTION:")
all_ages_list = []
for subject_id, age_data in subject_age_embeddings_normalized.items():
    all_ages_list.extend(age_data.keys())

all_ages_array = np.array(all_ages_list)
print(f"   Total age observations: {len(all_ages_array)}")
print(f"   Age range: {all_ages_array.min():.1f} to {all_ages_array.max():.1f} months")
print(f"   Median age: {np.median(all_ages_array):.1f} months")
print(f"   Mean age: {np.mean(all_ages_array):.1f} months")
print(f"   Std age: {np.std(all_ages_array):.1f} months")

# Check distribution around median
below_median = all_ages_array[all_ages_array <= overall_median_age]
above_median = all_ages_array[all_ages_array > overall_median_age]
print(f"\n   Split at {overall_median_age:.1f} months:")
print(f"   Below/equal: {len(below_median)} observations (mean: {np.mean(below_median):.1f} ± {np.std(below_median):.1f})")
print(f"   Above: {len(above_median)} observations (mean: {np.mean(above_median):.1f} ± {np.std(above_median):.1f})")

# 2. Check per-subject age distributions
print("\n2. PER-SUBJECT AGE CHARACTERISTICS:")
subject_age_stats = []
for subject_id in trajectory_df['subject_id']:
    if subject_id in subject_age_embeddings_normalized:
        ages = np.array(list(subject_age_embeddings_normalized[subject_id].keys()))
        n_below = np.sum(ages <= overall_median_age)
        n_above = np.sum(ages > overall_median_age)
        age_range = ages.max() - ages.min()
        subject_age_stats.append({
            'subject_id': subject_id,
            'n_ages': len(ages),
            'age_min': ages.min(),
            'age_max': ages.max(),
            'age_range': age_range,
            'age_mean': ages.mean(),
            'n_below_median': n_below,
            'n_above_median': n_above,
            'split_balance': min(n_below, n_above) / max(n_below, n_above) if max(n_below, n_above) > 0 else 0
        })

age_stats_df = pd.DataFrame(subject_age_stats)
print(f"   Mean age range per subject: {age_stats_df['age_range'].mean():.1f} ± {age_stats_df['age_range'].std():.1f} months")
print(f"   Mean split balance (min/max): {age_stats_df['split_balance'].mean():.3f}")
print(f"   Subjects with very unbalanced splits (<0.3): {len(age_stats_df[age_stats_df['split_balance'] < 0.3])}")

# 3. Check if there's a relationship between age range and correlation difference
if len(comparison_df) > 0:
    print("\n3. RELATIONSHIP BETWEEN AGE RANGE AND CORRELATION DIFFERENCE:")
    merged = comparison_df.merge(age_stats_df, on='subject_id')
    if len(merged) > 0:
        corr_age_range = np.corrcoef(merged['age_range'], merged['difference'])[0, 1]
        corr_split_balance = np.corrcoef(merged['split_balance'], merged['difference'])[0, 1]
        print(f"   Correlation (age_range vs difference): {corr_age_range:.3f}")
        print(f"   Correlation (split_balance vs difference): {corr_split_balance:.3f}")
        
        # Show subjects with largest negative differences
        print(f"\n   Subjects with largest negative differences (real << null):")
        top_negative = merged.nsmallest(5, 'difference')[['subject_id', 'difference', 'age_range', 'split_balance']]
        for _, row in top_negative.iterrows():
            print(f"     {row['subject_id']}: diff={row['difference']:.3f}, range={row['age_range']:.1f}mo, balance={row['split_balance']:.2f}")

# 4. Summary interpretation
print("\n" + "="*70)
print("INTERPRETATION:")
print("="*70)
if np.mean(null_corrs) > np.mean(real_corrs):
    print("⚠️  NULL MODEL SHOWS HIGHER CORRELATIONS THAN REAL MODEL")
    print("\nThis suggests:")
    print(f"  • High correlations ({np.mean(real_corrs):.3f}) mostly reflect structure that is STABLE across ages")
    print("  • Chronological halves differ MORE than random halves: an age-related component is present")
    print("  • Shuffling mixes early and late timepoints into both bins, making them more alike")
    print("\nCaveats:")
    print("  • The age-related component is small relative to the stable component")
    print("  • Temporal proximity (shared sessions/environments) can mimic developmental change")
    print("  • The median split may not align with meaningful developmental transitions")
    print("  • A narrow per-subject age range limits how much change a split can expose")
else:
    print("✓ Real correlations are at least as high as null: the median split detects no age-related change")
print("="*70)

# Save diagnostic data
age_stats_df.to_csv(output_dir / "age_distribution_diagnostics.csv", index=False)
if len(comparison_df) > 0 and len(age_stats_df) > 0:
    merged.to_csv(output_dir / "correlation_age_relationship.csv", index=False)

DIAGNOSTIC: Investigating Age Distribution and Split Characteristics

1. AGE DISTRIBUTION:
   Total age observations: 266
   Age range: 6.0 to 37.0 months
   Median age: 16.0 months
   Mean age: 16.6 months
   Std age: 5.8 months

   Split at 16.0 months:
   Below/equal: 146 observations (mean: 12.5 ± 2.6)
   Above: 120 observations (mean: 21.6 ± 4.5)

2. PER-SUBJECT AGE CHARACTERISTICS:
   Mean age range per subject: 10.7 ± 4.5 months
   Mean split balance (min/max): 0.508
   Subjects with very unbalanced splits (<0.3): 4

3. RELATIONSHIP BETWEEN AGE RANGE AND CORRELATION DIFFERENCE:
   Correlation (age_range vs difference): 0.342
   Correlation (split_balance vs difference): -0.003

   Subjects with largest negative differences (real << null):
     00680001: diff=-0.080, range=4.0mo, balance=0.25
     00510001: diff=-0.062, range=7.0mo, balance=0.60
     00590001: diff=-0.053, range=4.0mo, balance=0.67
     00370001: diff=-0.047, range=10.0mo, balance=0.57
     00400001: diff=-0.037,

## First-Order Similarity: Category-Wise Embedding Correlations

In addition to second-order similarity (RDM correlations that compare the geometric structure), we also compute first-order similarity by directly comparing individual category embeddings.

**Second-order (RDM correlation)**: Compares the structure/geometry of representations
- "Are the relationships between categories similar?"
- Example: If "dog" and "cat" are similar in younger period, are they also similar in older period?

**First-order (category embedding correlation)**: Compares individual category embeddings directly
- "Are the individual category representations similar?"
- Example: Is the embedding for "dog" in younger period similar to "dog" in older period?
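A small synthetic example shows why the two measures can disagree. If every older-bin embedding were the younger-bin embedding rotated by the same orthogonal matrix, all cosine distances would be preserved, so the RDMs would be identical (second-order correlation of 1), while the per-category first-order correlations would generally be far from 1. This is an idealized illustration with random data, not a claim about the embeddings analyzed here.

```python
import numpy as np
from scipy.spatial.distance import pdist
from scipy.stats import spearmanr

rng = np.random.default_rng(0)
emb_younger = rng.normal(size=(20, 16))        # 20 hypothetical categories x 16 dims

# Random orthogonal matrix: rotating every embedding preserves all cosine distances
q, _ = np.linalg.qr(rng.normal(size=(16, 16)))
emb_older = emb_younger @ q

# Second-order: correlate the two RDMs (condensed upper triangles)
rdm_corr = spearmanr(pdist(emb_younger, "cosine"), pdist(emb_older, "cosine"))[0]

# First-order: correlate each category's two embeddings directly, then average
first_order = np.mean([spearmanr(emb_younger[i], emb_older[i])[0] for i in range(20)])

print(f"second-order (RDM) correlation: {rdm_corr:.3f}")
print(f"mean first-order correlation:   {first_order:.3f}")
```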

### Within-Kid First-Order Similarity
For each subject, for each category present in both younger and older bins, compute the correlation between the category embeddings. Then average across categories to get a within-kid first-order similarity score.

### Between-Kid First-Order Similarity
For each pair of subjects, for each category present in both subjects, compute the correlation between the category embeddings. Four pairings are compared:
- Younger-Younger: subject 1 younger vs subject 2 younger
- Older-Older: subject 1 older vs subject 2 older
- Younger-Older: subject 1 younger vs subject 2 older
- Older-Younger: subject 1 older vs subject 2 younger

In [56]:
# First-Order Similarity: Category-wise embedding correlations
print("Computing first-order similarity (category-wise embedding correlations)...")
print("This compares individual category embeddings directly, rather than the geometric structure.\n")

from scipy.stats import spearmanr, pearsonr

def compute_first_order_correlation(embedding1, embedding2):
    """
    Compute correlation between two category embeddings.
    
    Args:
        embedding1: numpy array of embedding for category in first condition
        embedding2: numpy array of embedding for category in second condition
    
    Returns:
        spearman_corr: Spearman correlation coefficient
        pearson_corr: Pearson correlation coefficient
    """
    # Flatten embeddings if needed
    emb1_flat = embedding1.flatten()
    emb2_flat = embedding2.flatten()
    
    # Ensure same length
    if len(emb1_flat) != len(emb2_flat):
        return np.nan, np.nan
    
    # Compute correlations
    spearman_corr, _ = spearmanr(emb1_flat, emb2_flat)
    pearson_corr, _ = pearsonr(emb1_flat, emb2_flat)
    
    return spearman_corr, pearson_corr

# ============================================================================
# WITHIN-KID FIRST-ORDER SIMILARITY (Category-wise)
# ============================================================================
print("1. Computing within-kid first-order similarity (younger vs older for same subject)...")
print("   Computing correlations for EACH category separately, then aggregating across subjects.\n")

within_kid_category_wise_data = []
within_kid_summary_data = []

for subject_id in tqdm(subject_age_rdms.keys(), desc="Within-kid first-order"):
    if 'younger' not in subject_age_rdms[subject_id] or 'older' not in subject_age_rdms[subject_id]:
        continue
    
    # Get aggregated embeddings for younger and older bins
    younger_ages = {age_mo: cats for age_mo, cats in subject_age_embeddings_normalized[subject_id].items() 
                   if age_mo <= overall_median_age}
    older_ages = {age_mo: cats for age_mo, cats in subject_age_embeddings_normalized[subject_id].items() 
                  if age_mo > overall_median_age}
    
    younger_aggregated = aggregate_embeddings_by_bin(younger_ages, 'younger')
    older_aggregated = aggregate_embeddings_by_bin(older_ages, 'older')
    
    # Find common categories
    common_categories = [cat for cat in younger_aggregated.keys() if cat in older_aggregated.keys()]
    
    if len(common_categories) < 2:
        continue
    
    # Compute correlation for each category separately
    category_correlations = []
    category_pearson_correlations = []
    for cat in common_categories:
        emb_younger = younger_aggregated[cat]
        emb_older = older_aggregated[cat]
        
        spearman_corr, pearson_corr = compute_first_order_correlation(emb_younger, emb_older)
        
        if not np.isnan(spearman_corr):
            within_kid_category_wise_data.append({
                'subject_id': subject_id,
                'category': cat,
                'spearman_correlation': spearman_corr,
                'pearson_correlation': pearson_corr
            })
            category_correlations.append(spearman_corr)
            category_pearson_correlations.append(pearson_corr)
    
    if len(category_correlations) > 0:
        # Store summary per subject (Spearman and Pearson lists cover the same categories)
        within_kid_summary_data.append({
            'subject_id': subject_id,
            'n_common_categories': len(category_correlations),
            'mean_spearman_correlation': np.mean(category_correlations),
            'mean_pearson_correlation': np.mean(category_pearson_correlations),
            'std_spearman_correlation': np.std(category_correlations)
        })

# Save category-wise data
within_kid_category_wise_df = pd.DataFrame(within_kid_category_wise_data)
within_kid_category_wise_df.to_csv(output_dir / "within_kid_first_order_category_wise.csv", index=False)

# Aggregate by category across all subjects
category_aggregated = within_kid_category_wise_df.groupby('category').agg({
    'spearman_correlation': ['mean', 'std', 'count'],
    'pearson_correlation': ['mean', 'std']
}).reset_index()
category_aggregated.columns = ['category', 'mean_spearman', 'std_spearman', 'n_subjects', 
                               'mean_pearson', 'std_pearson']
category_aggregated = category_aggregated.sort_values('mean_spearman', ascending=False)
category_aggregated.to_csv(output_dir / "within_kid_first_order_by_category.csv", index=False)

within_kid_summary_df = pd.DataFrame(within_kid_summary_data)
within_kid_summary_df.to_csv(output_dir / "within_kid_first_order_similarity.csv", index=False)

print(f"\nWithin-kid first-order similarity:")
print(f"  Subjects analyzed: {len(within_kid_summary_df)}")
print(f"  Total category-subject pairs: {len(within_kid_category_wise_df)}")
print(f"  Mean Spearman correlation (across all): {within_kid_category_wise_df['spearman_correlation'].mean():.3f} ± {within_kid_category_wise_df['spearman_correlation'].std():.3f}")
print(f"\nTop 10 most correlated categories (across subjects):")
for _, row in category_aggregated.head(10).iterrows():
    print(f"  {row['category']:20s}: {row['mean_spearman']:.3f} ± {row['std_spearman']:.3f} (n={int(row['n_subjects'])})")
print(f"\nBottom 10 least correlated categories:")
for _, row in category_aggregated.tail(10).iterrows():
    print(f"  {row['category']:20s}: {row['mean_spearman']:.3f} ± {row['std_spearman']:.3f} (n={int(row['n_subjects'])})")
print(f"\nSaved category-wise data to {output_dir / 'within_kid_first_order_category_wise.csv'}")
print(f"Saved aggregated by category to {output_dir / 'within_kid_first_order_by_category.csv'}")

# ============================================================================
# BETWEEN-KID FIRST-ORDER SIMILARITY (Category-wise)
# ============================================================================
print("\n2. Computing between-kid first-order similarity...")
print("   Computing correlations for EACH category separately, then aggregating across subject pairs.\n")

valid_subjects = [sid for sid in subject_age_rdms.keys() 
                  if 'younger' in subject_age_rdms[sid] and 'older' in subject_age_rdms[sid]]

# Aggregate each subject's younger/older bins once, instead of re-aggregating inside the pair loop
subject_bins = {}
for sid in valid_subjects:
    age_data = subject_age_embeddings_normalized[sid]
    subject_bins[sid] = {
        'younger': aggregate_embeddings_by_bin(
            {age_mo: cats for age_mo, cats in age_data.items() if age_mo <= overall_median_age}, 'younger'),
        'older': aggregate_embeddings_by_bin(
            {age_mo: cats for age_mo, cats in age_data.items() if age_mo > overall_median_age}, 'older'),
    }

between_kid_category_wise_data = []
between_kid_summary_data = []

for i, subject_id_1 in enumerate(tqdm(valid_subjects, desc="Between-kid first-order")):
    for subject_id_2 in valid_subjects[i+1:]:  # Only upper triangle (each unordered pair once)
        # Look up the pre-aggregated embeddings for both subjects
        younger_agg_1 = subject_bins[subject_id_1]['younger']
        older_agg_1 = subject_bins[subject_id_1]['older']
        younger_agg_2 = subject_bins[subject_id_2]['younger']
        older_agg_2 = subject_bins[subject_id_2]['older']
        
        # Compute correlations for each comparison type
        comparison_types = [
            ('younger_younger', younger_agg_1, younger_agg_2),
            ('older_older', older_agg_1, older_agg_2),
            ('younger_older', younger_agg_1, older_agg_2),
            ('older_younger', older_agg_1, younger_agg_2)
        ]
        
        for comp_type, emb_dict_1, emb_dict_2 in comparison_types:
            # Find common categories
            common_cats = [cat for cat in emb_dict_1.keys() if cat in emb_dict_2.keys()]
            
            if len(common_cats) < 2:
                continue
            
            # Compute correlation for each category separately
            category_corrs = []
            for cat in common_cats:
                emb_1 = emb_dict_1[cat]
                emb_2 = emb_dict_2[cat]
                
                spearman_corr, pearson_corr = compute_first_order_correlation(emb_1, emb_2)
                
                if not np.isnan(spearman_corr):
                    between_kid_category_wise_data.append({
                        'subject_id_1': subject_id_1,
                        'subject_id_2': subject_id_2,
                        'comparison_type': comp_type,
                        'category': cat,
                        'spearman_correlation': spearman_corr,
                        'pearson_correlation': pearson_corr
                    })
                    category_corrs.append(spearman_corr)
            
            if len(category_corrs) > 0:
                between_kid_summary_data.append({
                    'subject_id_1': subject_id_1,
                    'subject_id_2': subject_id_2,
                    'comparison_type': comp_type,
                    'n_common_categories': len(category_corrs),
                    'mean_spearman_correlation': np.mean(category_corrs),
                    'std_spearman_correlation': np.std(category_corrs)
                })

# Save category-wise data
between_kid_category_wise_df = pd.DataFrame(between_kid_category_wise_data)
between_kid_category_wise_df.to_csv(output_dir / "between_kid_first_order_category_wise.csv", index=False)

# Aggregate by category and comparison type
category_comparison_agg = between_kid_category_wise_df.groupby(['category', 'comparison_type']).agg({
    'spearman_correlation': ['mean', 'std', 'count']
}).reset_index()
category_comparison_agg.columns = ['category', 'comparison_type', 'mean_spearman', 'std_spearman', 'n_pairs']

# Also aggregate across all comparison types
category_agg_all = between_kid_category_wise_df.groupby('category').agg({
    'spearman_correlation': ['mean', 'std', 'count']
}).reset_index()
category_agg_all.columns = ['category', 'mean_spearman', 'std_spearman', 'n_pairs']
category_agg_all = category_agg_all.sort_values('mean_spearman', ascending=False)
category_agg_all.to_csv(output_dir / "between_kid_first_order_by_category.csv", index=False)

between_kid_summary_df = pd.DataFrame(between_kid_summary_data)
between_kid_summary_df.to_csv(output_dir / "between_kid_first_order_similarity.csv", index=False)

print(f"\nBetween-kid first-order similarity:")
print(f"  Total comparisons: {len(between_kid_summary_df)}")
print(f"  Total category-pair comparisons: {len(between_kid_category_wise_df)}")
print(f"  Subject pairs: {len(valid_subjects) * (len(valid_subjects) - 1) // 2}")

for comp_type in ['younger_younger', 'older_older', 'younger_older', 'older_younger']:
    type_data = between_kid_summary_df[between_kid_summary_df['comparison_type'] == comp_type]
    if len(type_data) > 0:
        print(f"\n  {comp_type}:")
        print(f"    Mean correlation: {type_data['mean_spearman_correlation'].mean():.3f} ± {type_data['mean_spearman_correlation'].std():.3f}")
        print(f"    N comparisons: {len(type_data)}")

print(f"\nTop 10 most correlated categories (across all subject pairs):")
for _, row in category_agg_all.head(10).iterrows():
    print(f"  {row['category']:20s}: {row['mean_spearman']:.3f} ± {row['std_spearman']:.3f} (n={int(row['n_pairs'])})")

print(f"\nSaved category-wise data to {output_dir / 'between_kid_first_order_category_wise.csv'}")
print(f"Saved aggregated by category to {output_dir / 'between_kid_first_order_by_category.csv'}")

Computing first-order similarity (category-wise embedding correlations)...
This compares individual category embeddings directly, rather than the geometric structure.

1. Computing within-kid first-order similarity (younger vs older for same subject)...


Within-kid first-order: 100%|██████████| 18/18 [00:01<00:00, 16.28it/s]



Within-kid first-order similarity:
  Subjects analyzed: 18
  Mean Spearman correlation: 0.718 ± 0.057
  Mean Pearson correlation: 0.726 ± 0.057
  Mean categories per subject: 144.1

Saved to developmental_trajectory_rdms_clip/within_kid_first_order_similarity.csv

2. Computing between-kid first-order similarity...


Between-kid first-order: 100%|██████████| 18/18 [00:35<00:00,  1.95s/it]


Between-kid first-order similarity:
  Total comparisons: 612
  Subject pairs: 153

  younger_younger:
    Mean correlation: 0.631 ± 0.078
    N comparisons: 153

  older_older:
    Mean correlation: 0.600 ± 0.052
    N comparisons: 153

  younger_older:
    Mean correlation: 0.609 ± 0.077
    N comparisons: 153

  older_younger:
    Mean correlation: 0.620 ± 0.055
    N comparisons: 153

Saved to developmental_trajectory_rdms_clip/between_kid_first_order_similarity.csv
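
The first-order measure correlates each category's embedding vector across two bins directly, category by category, and only then averages. `compute_first_order_correlation` is defined earlier in the notebook and is not shown in this section, so the sketch below assumes it returns the Spearman and Pearson correlations between two embedding vectors. `first_order_by_category` is a hypothetical stand-in, run on synthetic 8-dimensional embeddings rather than real data.

```python
import numpy as np
from scipy.stats import pearsonr, spearmanr


def first_order_by_category(emb_dict_1, emb_dict_2):
    """Hypothetical helper: correlate matching category embeddings from two bins.

    Assumes each dict maps category name -> 1-D embedding vector, and that
    first-order similarity is the Spearman/Pearson correlation between the two vectors.
    """
    rows = []
    for cat in emb_dict_1:
        if cat not in emb_dict_2:
            continue  # categories seen in only one bin are skipped
        rho, _ = spearmanr(emb_dict_1[cat], emb_dict_2[cat])
        r, _ = pearsonr(emb_dict_1[cat], emb_dict_2[cat])
        rows.append({'category': cat, 'spearman': rho, 'pearson': r})
    return rows


rng = np.random.default_rng(0)
base = {cat: rng.normal(size=8) for cat in ['ball', 'cup', 'dog']}
# Second bin: the same categories plus small noise, and one extra category ('car')
noisy = {cat: vec + 0.1 * rng.normal(size=8) for cat, vec in base.items()}
noisy['car'] = rng.normal(size=8)

rows = first_order_by_category(base, noisy)
mean_rho = np.mean([row['spearman'] for row in rows])
print(f"{len(rows)} common categories, mean Spearman = {mean_rho:.3f}")
```

Only categories present in both bins contribute, which mirrors the `common_cats` filter in the between-kid loop.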





In [61]:
# Visualize first-order similarity with category-wise breakdown
print("Creating visualizations for first-order similarity...")

# Check if the data exists (from previous cell)
if 'within_kid_summary_df' not in globals() or len(within_kid_summary_df) == 0:
    print("Warning: within_kid_summary_df not found. Please run the previous cell first.")
else:
    # Pair each subject's first-order (mean category-wise) and second-order (RDM) correlations
    comparison_data = []
    for subject_id in within_kid_summary_df['subject_id']:
        first_order = within_kid_summary_df[within_kid_summary_df['subject_id'] == subject_id]
        second_order = trajectory_df[trajectory_df['subject_id'] == subject_id]
        
        if len(first_order) > 0 and len(second_order) > 0:
            comparison_data.append({
                'subject_id': subject_id,
                'first_order_spearman': first_order['mean_spearman_correlation'].values[0],
                'second_order_spearman': second_order['rdm_correlation'].values[0],
                'difference': first_order['mean_spearman_correlation'].values[0] - second_order['rdm_correlation'].values[0]
            })
    
    comparison_similarity_df = pd.DataFrame(comparison_data)
    
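    if len(comparison_similarity_df) == 0:
        # Without this guard, the plotting code below fails (or draws onto a stale `axes` from an earlier cell)
        raise ValueError("No subjects have both first-order (within_kid_summary_df) and "
                         "second-order (trajectory_df) results to compare.")
    fig, axes = plt.subplots(2, 2, figsize=(14, 10))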
    
    # 1. Scatter plot: First-order vs Second-order
    ax1 = axes[0, 0]
    ax1.scatter(comparison_similarity_df['second_order_spearman'], 
               comparison_similarity_df['first_order_spearman'],
               alpha=0.7, s=100, edgecolors='black', linewidth=1.5)
    
    # Add diagonal line
    min_val = min(comparison_similarity_df['second_order_spearman'].min(), 
                  comparison_similarity_df['first_order_spearman'].min())
    max_val = max(comparison_similarity_df['second_order_spearman'].max(), 
                  comparison_similarity_df['first_order_spearman'].max())
    ax1.plot([min_val, max_val], [min_val, max_val], 'r--', alpha=0.5, label='y=x')
    
    ax1.set_xlabel('Second-Order Similarity (RDM Correlation)', fontsize=12)
    ax1.set_ylabel('First-Order Similarity (Category Embedding Correlation)', fontsize=12)
    ax1.set_title('First-Order vs Second-Order Similarity', fontsize=13, pad=10)
    ax1.legend()
    ax1.grid(True, alpha=0.3)
    
    # Add correlation coefficient
    corr_coef = np.corrcoef(comparison_similarity_df['second_order_spearman'], 
                           comparison_similarity_df['first_order_spearman'])[0, 1]
    ax1.text(0.05, 0.95, f'r = {corr_coef:.3f}', transform=ax1.transAxes,
            fontsize=12, verticalalignment='top', bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))
    
    # 2. Distribution comparison
    ax2 = axes[0, 1]
    ax2.hist(comparison_similarity_df['first_order_spearman'], bins=15, alpha=0.7, 
            label='First-Order', color='#4ECDC4', edgecolor='black')
    ax2.hist(comparison_similarity_df['second_order_spearman'], bins=15, alpha=0.7, 
            label='Second-Order', color='#FF6B6B', edgecolor='black')
    ax2.axvline(comparison_similarity_df['first_order_spearman'].mean(), 
               color='#4ECDC4', linestyle='--', linewidth=2, 
               label=f'First-Order mean: {comparison_similarity_df["first_order_spearman"].mean():.3f}')
    ax2.axvline(comparison_similarity_df['second_order_spearman'].mean(), 
               color='#FF6B6B', linestyle='--', linewidth=2,
               label=f'Second-Order mean: {comparison_similarity_df["second_order_spearman"].mean():.3f}')
    ax2.set_xlabel('Correlation', fontsize=12)
    ax2.set_ylabel('Frequency', fontsize=12)
    ax2.set_title('Distribution Comparison', fontsize=13, pad=10)
    ax2.legend()
    ax2.grid(True, alpha=0.3)
    
    # 3. Box plot comparison
    ax3 = axes[1, 0]
    box_data = [comparison_similarity_df['first_order_spearman'], 
                comparison_similarity_df['second_order_spearman']]
    bp = ax3.boxplot(box_data, labels=['First-Order', 'Second-Order'], patch_artist=True)
    colors_box = ['#4ECDC4', '#FF6B6B']
    for patch, color in zip(bp['boxes'], colors_box):
        patch.set_facecolor(color)
        patch.set_alpha(0.7)
    ax3.set_ylabel('Correlation', fontsize=12)
    ax3.set_title('Box Plot Comparison', fontsize=13, pad=10)
    ax3.grid(True, alpha=0.3, axis='y')
    ax3.set_ylim([0, 1])
    
    # 4. Difference distribution
    ax4 = axes[1, 1]
    differences = comparison_similarity_df['difference']
    ax4.hist(differences, bins=15, alpha=0.7, color='#45B7D1', edgecolor='black')
    ax4.axvline(0, color='red', linestyle='--', linewidth=2, label='No difference')
    ax4.axvline(differences.mean(), color='green', linestyle='--', linewidth=2, 
               label=f'Mean: {differences.mean():.3f}')
    ax4.set_xlabel('Difference (First-Order - Second-Order)', fontsize=12)
    ax4.set_ylabel('Frequency', fontsize=12)
    ax4.set_title('Difference Distribution', fontsize=13, pad=10)
    ax4.legend()
    ax4.grid(True, alpha=0.3)
    
    plt.suptitle('First-Order vs Second-Order Similarity Comparison\n(Within-Kid: Younger vs Older)', 
                 fontsize=14, y=0.995, fontweight='bold')
    plt.tight_layout(rect=[0, 0, 1, 0.98])
    plt.savefig(output_dir / "first_order_vs_second_order_comparison.png", dpi=200, bbox_inches='tight')
    print(f"Saved comparison to {output_dir / 'first_order_vs_second_order_comparison.png'}")
    plt.close()
    
    # Print summary statistics
    print(f"\nSummary comparison:")
    print(f"  First-Order mean: {comparison_similarity_df['first_order_spearman'].mean():.3f} ± {comparison_similarity_df['first_order_spearman'].std():.3f}")
    print(f"  Second-Order mean: {comparison_similarity_df['second_order_spearman'].mean():.3f} ± {comparison_similarity_df['second_order_spearman'].std():.3f}")
    print(f"  Mean difference: {differences.mean():.3f} ± {differences.std():.3f}")
    print(f"  Correlation between measures: {corr_coef:.3f}")
    print(f"  Subjects with first-order > second-order: {len(comparison_similarity_df[differences > 0])} / {len(comparison_similarity_df)}")
    
    comparison_similarity_df.to_csv(output_dir / "first_vs_second_order_comparison.csv", index=False)
    print(f"Saved comparison data to {output_dir / 'first_vs_second_order_comparison.csv'}")
    
    # Visualize between-kid first-order similarity
    if 'between_kid_summary_df' in globals() and len(between_kid_summary_df) > 0:
        fig, axes = plt.subplots(2, 2, figsize=(14, 10))
    
        # 1. Distribution by comparison type
        ax1 = axes[0, 0]
        comparison_types = ['younger_younger', 'older_older', 'younger_older', 'older_younger']
        colors_comp = ['#FF6B6B', '#4ECDC4', '#45B7D1', '#FFA07A']
        
        for comp_type, color in zip(comparison_types, colors_comp):
            type_data = between_kid_summary_df[between_kid_summary_df['comparison_type'] == comp_type]
            if len(type_data) > 0:
                ax1.hist(type_data['mean_spearman_correlation'], bins=20, alpha=0.6, 
                        label=comp_type.replace('_', ' ').title(), color=color, edgecolor='black')
        
        ax1.set_xlabel('First-Order Correlation', fontsize=12)
        ax1.set_ylabel('Frequency', fontsize=12)
        ax1.set_title('Between-Kid First-Order Similarity by Type', fontsize=13, pad=10)
        ax1.legend()
        ax1.grid(True, alpha=0.3)
        
        # 2. Box plot by comparison type
        ax2 = axes[0, 1]
        box_data = []
        labels = []
        for comp_type in comparison_types:
            type_data = between_kid_summary_df[between_kid_summary_df['comparison_type'] == comp_type]
            if len(type_data) > 0:
                box_data.append(type_data['mean_spearman_correlation'].values)
                labels.append(comp_type.replace('_', '\n').title())
        
        if len(box_data) > 0:
            bp = ax2.boxplot(box_data, labels=labels, patch_artist=True)
            for patch, color in zip(bp['boxes'], colors_comp[:len(bp['boxes'])]):
                patch.set_facecolor(color)
                patch.set_alpha(0.7)
            ax2.set_ylabel('First-Order Correlation', fontsize=12)
            ax2.set_title('Box Plot by Comparison Type', fontsize=13, pad=10)
            ax2.grid(True, alpha=0.3, axis='y')
            ax2.set_ylim([0, 1])
            plt.setp(ax2.xaxis.get_majorticklabels(), rotation=45, ha='right')
        
        # 3. Mean correlations by type (bar plot)
        ax3 = axes[1, 0]
        means = []
        stds = []
        type_labels = []
        for comp_type in comparison_types:
            type_data = between_kid_summary_df[between_kid_summary_df['comparison_type'] == comp_type]
            if len(type_data) > 0:
                means.append(type_data['mean_spearman_correlation'].mean())
                stds.append(type_data['mean_spearman_correlation'].std())
                type_labels.append(comp_type.replace('_', ' ').title())
        
        if len(means) > 0:
            bars = ax3.bar(range(len(type_labels)), means, yerr=stds, 
                          color=colors_comp[:len(type_labels)], alpha=0.7, 
                          capsize=5, edgecolor='black')
            ax3.set_xticks(range(len(type_labels)))
            ax3.set_xticklabels(type_labels, rotation=45, ha='right')
            ax3.set_ylabel('Mean First-Order Correlation', fontsize=12)
            ax3.set_title('Mean Correlations by Comparison Type', fontsize=13, pad=10)
            ax3.grid(True, alpha=0.3, axis='y')
            ax3.set_ylim([0, 1])
        
        # 4. Heatmap: subject pairs x comparison type
        ax4 = axes[1, 1]
        pivot_data = between_kid_summary_df.pivot_table(
            index=['subject_id_1', 'subject_id_2'], 
            columns='comparison_type', 
            values='mean_spearman_correlation'
        )
        
        if len(pivot_data) > 0:
            # Create a simplified heatmap (sample if too many pairs)
            if len(pivot_data) > 50:
                # Sample for visualization
                pivot_data_viz = pivot_data.sample(50, random_state=42).sort_index()
            else:
                pivot_data_viz = pivot_data.sort_index()
            
            im = ax4.imshow(pivot_data_viz.values, aspect='auto', cmap='RdYlBu_r', vmin=0, vmax=1)
            ax4.set_xticks(range(len(pivot_data_viz.columns)))
            ax4.set_xticklabels([col.replace('_', '\n').title() for col in pivot_data_viz.columns], 
                               fontsize=9)
            ax4.set_ylabel('Subject Pairs (sample)', fontsize=12)
            ax4.set_title('First-Order Correlations: Subject Pairs × Type', fontsize=13, pad=10)
            plt.colorbar(im, ax=ax4, label='Correlation')
        
        plt.suptitle('Between-Kid First-Order Similarity', 
                     fontsize=14, y=0.995, fontweight='bold')
        plt.tight_layout(rect=[0, 0, 1, 0.98])
        plt.savefig(output_dir / "between_kid_first_order_visualization.png", dpi=200, bbox_inches='tight')
        print(f"Saved between-kid visualization to {output_dir / 'between_kid_first_order_visualization.png'}")
        plt.close()
    
    print("\nFirst-order similarity analysis complete!")
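
The scatter plot above contrasts two different questions. First-order similarity asks whether the same category lands on the same embedding coordinates in both bins; second-order similarity (RDM correlation) asks only whether the pairwise dissimilarity structure is preserved. A minimal sketch on synthetic data, assuming cosine-distance RDMs (1 - cosine similarity) and Spearman correlation of upper triangles: an orthogonal rotation of every embedding leaves the RDM unchanged, so the second-order score stays at 1 while the first-order score drops. The helper names here are illustrative, not the notebook's functions.

```python
import numpy as np
from scipy.stats import spearmanr


def cosine_rdm(embeddings):
    """RDM of cosine distances (1 - cosine similarity) between rows, one row per category."""
    unit = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
    return 1.0 - unit @ unit.T


def second_order(rdm_a, rdm_b):
    """Spearman correlation between the upper triangles of two RDMs (diagonal excluded)."""
    iu = np.triu_indices_from(rdm_a, k=1)
    rho, _ = spearmanr(rdm_a[iu], rdm_b[iu])
    return float(rho)


def first_order(emb_a, emb_b):
    """Mean Spearman correlation between matching category rows of two embedding matrices."""
    rhos = []
    for vec_a, vec_b in zip(emb_a, emb_b):
        rho, _ = spearmanr(vec_a, vec_b)
        rhos.append(rho)
    return float(np.mean(rhos))


rng = np.random.default_rng(1)
younger = rng.normal(size=(10, 16))               # 10 categories x 16 dimensions
q, _ = np.linalg.qr(rng.normal(size=(16, 16)))    # random orthogonal matrix
older = younger @ q                               # same geometry, different coordinates

first = first_order(younger, older)
second = second_order(cosine_rdm(younger), cosine_rdm(older))
print(f"first-order:  {first:.3f}")
print(f"second-order: {second:.3f}")
```

In this toy case the two measures diverge completely; in real data they can diverge partially whenever a change between bins moves embeddings while roughly preserving their relative distances.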


In [60]:
# Visualize category-wise first-order correlations
print("Creating category-wise visualizations...")

# ============================================================================
# WITHIN-KID: Category-wise correlations
# ============================================================================
if len(category_aggregated) > 0:
    fig, axes = plt.subplots(2, 2, figsize=(16, 12))
    
    # 1. Bar plot: Top and bottom categories
    ax1 = axes[0, 0]
    top_n = 20
    top_cats = category_aggregated.head(top_n)
    bottom_cats = category_aggregated.tail(top_n)
    
    # Combine top and bottom
    plot_cats = pd.concat([top_cats, bottom_cats]).sort_values('mean_spearman', ascending=True)
    colors_cats = ['#4ECDC4' if cat in top_cats['category'].values else '#FF6B6B' 
                   for cat in plot_cats['category']]
    
    y_pos = np.arange(len(plot_cats))
    bars = ax1.barh(y_pos, plot_cats['mean_spearman'], xerr=plot_cats['std_spearman'],
                   color=colors_cats, alpha=0.7, edgecolor='black', capsize=3)
    ax1.set_yticks(y_pos)
    ax1.set_yticklabels(plot_cats['category'], fontsize=9)
    ax1.set_xlabel('Mean Spearman Correlation (across subjects)', fontsize=12)
    ax1.set_title(f'Top {top_n} and Bottom {top_n} Categories\n(Within-Kid: Younger vs Older)', fontsize=13, pad=10)
    ax1.axvline(0, color='black', linestyle='-', linewidth=0.5)
    ax1.grid(True, alpha=0.3, axis='x')
    ax1.set_xlim([0, 1])
    
    # 2. Distribution of correlations per category
    ax2 = axes[0, 1]
    # Sample categories for violin plot (too many to show all)
    if len(category_aggregated) > 30:
        sample_cats = pd.concat([category_aggregated.head(15), category_aggregated.tail(15)])
    else:
        sample_cats = category_aggregated
    
    violin_data = []
    violin_labels = []
    for _, row in sample_cats.iterrows():
        cat_data = within_kid_category_wise_df[within_kid_category_wise_df['category'] == row['category']]['spearman_correlation'].values
        if len(cat_data) > 1:  # violinplot's KDE needs at least two values per category
            violin_data.append(cat_data)
            violin_labels.append(row['category'])
    
    if len(violin_data) > 0:
        parts = ax2.violinplot(violin_data, positions=range(len(violin_labels)), 
                              showmeans=True, showmedians=True, widths=0.8)
        for pc in parts['bodies']:
            pc.set_facecolor('#45B7D1')
            pc.set_alpha(0.7)
        ax2.set_xticks(range(len(violin_labels)))
        ax2.set_xticklabels(violin_labels, rotation=90, ha='right', fontsize=8)
        ax2.set_ylabel('Spearman Correlation', fontsize=12)
        ax2.set_title('Distribution of Correlations by Category\n(Within-Kid)', fontsize=13, pad=10)
        ax2.grid(True, alpha=0.3, axis='y')
        ax2.set_ylim([0, 1])
    
    # 3. Heatmap: Categories x Subjects
    ax3 = axes[1, 0]
    # Get top categories for heatmap
    top_cats_for_heatmap = category_aggregated.head(30)['category'].values
    heatmap_data = []
    for cat in top_cats_for_heatmap:
        cat_subj_data = []
        for subject_id in within_kid_summary_df['subject_id']:
            cat_subj_corr = within_kid_category_wise_df[
                (within_kid_category_wise_df['category'] == cat) & 
                (within_kid_category_wise_df['subject_id'] == subject_id)
            ]['spearman_correlation'].values
            if len(cat_subj_corr) > 0:
                cat_subj_data.append(cat_subj_corr[0])
            else:
                cat_subj_data.append(np.nan)
        heatmap_data.append(cat_subj_data)
    
    heatmap_array = np.array(heatmap_data)
    im = ax3.imshow(heatmap_array, aspect='auto', cmap='RdYlBu_r', vmin=0, vmax=1)
    ax3.set_yticks(range(len(top_cats_for_heatmap)))
    ax3.set_yticklabels(top_cats_for_heatmap, fontsize=8)
    ax3.set_xticks(range(len(within_kid_summary_df)))
    ax3.set_xticklabels(within_kid_summary_df['subject_id'], rotation=90, ha='right', fontsize=8)
    ax3.set_xlabel('Subject ID', fontsize=12)
    ax3.set_ylabel('Category', fontsize=12)
    ax3.set_title('Top 30 Categories: Correlation by Subject\n(Within-Kid)', fontsize=13, pad=10)
    plt.colorbar(im, ax=ax3, label='Correlation')
    
    # 4. Scatter: Mean correlation vs number of subjects
    ax4 = axes[1, 1]
    ax4.scatter(category_aggregated['n_subjects'], category_aggregated['mean_spearman'],
               alpha=0.6, s=100, edgecolors='black', linewidth=1)
    ax4.set_xlabel('Number of Subjects with Category', fontsize=12)
    ax4.set_ylabel('Mean Spearman Correlation', fontsize=12)
    ax4.set_title('Correlation vs Category Prevalence\n(Within-Kid)', fontsize=13, pad=10)
    ax4.grid(True, alpha=0.3)
    
    # Add correlation coefficient
    corr_coef = np.corrcoef(category_aggregated['n_subjects'], 
                           category_aggregated['mean_spearman'])[0, 1]
    ax4.text(0.05, 0.95, f'r = {corr_coef:.3f}', transform=ax4.transAxes,
            fontsize=12, verticalalignment='top', 
            bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))
    
    plt.suptitle('Within-Kid First-Order Similarity: Category-Wise Analysis\n(Younger vs Older for Same Subject)', 
                 fontsize=14, y=0.995, fontweight='bold')
    plt.tight_layout(rect=[0, 0, 1, 0.98])
    plt.savefig(output_dir / "within_kid_first_order_category_wise_visualization.png", dpi=200, bbox_inches='tight')
    print(f"Saved within-kid category-wise visualization to {output_dir / 'within_kid_first_order_category_wise_visualization.png'}")
    plt.close()

# ============================================================================
# BETWEEN-KID: Category-wise correlations
# ============================================================================
if len(category_agg_all) > 0:
    fig, axes = plt.subplots(2, 2, figsize=(16, 12))
    
    # 1. Bar plot: Top and bottom categories (across all comparison types)
    ax1 = axes[0, 0]
    top_n = 20
    top_cats = category_agg_all.head(top_n)
    bottom_cats = category_agg_all.tail(top_n)
    
    plot_cats = pd.concat([top_cats, bottom_cats]).sort_values('mean_spearman', ascending=True)
    colors_cats = ['#4ECDC4' if cat in top_cats['category'].values else '#FF6B6B' 
                   for cat in plot_cats['category']]
    
    y_pos = np.arange(len(plot_cats))
    bars = ax1.barh(y_pos, plot_cats['mean_spearman'], xerr=plot_cats['std_spearman'],
                   color=colors_cats, alpha=0.7, edgecolor='black', capsize=3)
    ax1.set_yticks(y_pos)
    ax1.set_yticklabels(plot_cats['category'], fontsize=9)
    ax1.set_xlabel('Mean Spearman Correlation (across subject pairs)', fontsize=12)
    ax1.set_title(f'Top {top_n} and Bottom {top_n} Categories\n(Between-Kid: All Comparison Types)', fontsize=13, pad=10)
    ax1.axvline(0, color='black', linestyle='-', linewidth=0.5)
    ax1.grid(True, alpha=0.3, axis='x')
    ax1.set_xlim([0, 1])
    
    # 2. Comparison by type for top categories
    ax2 = axes[0, 1]
    top_cats_list = category_agg_all.head(15)['category'].values
    comparison_types = ['younger_younger', 'older_older', 'younger_older', 'older_younger']
    x_pos = np.arange(len(top_cats_list))
    width = 0.2
    
    for i, comp_type in enumerate(comparison_types):
        type_means = []
        for cat in top_cats_list:
            cat_type_data = category_comparison_agg[
                (category_comparison_agg['category'] == cat) & 
                (category_comparison_agg['comparison_type'] == comp_type)
            ]
            if len(cat_type_data) > 0:
                type_means.append(cat_type_data['mean_spearman'].values[0])
            else:
                type_means.append(np.nan)
        
        ax2.bar(x_pos + i*width, type_means, width, 
               label=comp_type.replace('_', ' ').title(), alpha=0.7)
    
    ax2.set_xticks(x_pos + width * 1.5)
    ax2.set_xticklabels(top_cats_list, rotation=90, ha='right', fontsize=8)
    ax2.set_ylabel('Mean Correlation', fontsize=12)
    ax2.set_title('Top 15 Categories by Comparison Type\n(Between-Kid)', fontsize=13, pad=10)
    ax2.legend(fontsize=9, ncol=2)
    ax2.grid(True, alpha=0.3, axis='y')
    ax2.set_ylim([0, 1])
    
    # 3. Heatmap: Categories x Comparison Types
    ax3 = axes[1, 0]
    top_cats_heatmap = category_agg_all.head(25)['category'].values
    heatmap_data = []
    for cat in top_cats_heatmap:
        cat_type_means = []
        for comp_type in comparison_types:
            cat_type_data = category_comparison_agg[
                (category_comparison_agg['category'] == cat) & 
                (category_comparison_agg['comparison_type'] == comp_type)
            ]
            if len(cat_type_data) > 0:
                cat_type_means.append(cat_type_data['mean_spearman'].values[0])
            else:
                cat_type_means.append(np.nan)
        heatmap_data.append(cat_type_means)
    
    heatmap_array = np.array(heatmap_data)
    im = ax3.imshow(heatmap_array, aspect='auto', cmap='RdYlBu_r', vmin=0, vmax=1)
    ax3.set_yticks(range(len(top_cats_heatmap)))
    ax3.set_yticklabels(top_cats_heatmap, fontsize=8)
    ax3.set_xticks(range(len(comparison_types)))
    ax3.set_xticklabels([ct.replace('_', '\n').title() for ct in comparison_types], fontsize=10)
    ax3.set_xlabel('Comparison Type', fontsize=12)
    ax3.set_ylabel('Category', fontsize=12)
    ax3.set_title('Top 25 Categories: Correlation by Comparison Type\n(Between-Kid)', fontsize=13, pad=10)
    plt.colorbar(im, ax=ax3, label='Correlation')
    
    # 4. Mean correlations by comparison type (bar plot)
    ax4 = axes[1, 1]
    type_means_all = []
    type_stds_all = []
    type_labels = []
    for comp_type in comparison_types:
        type_data = category_comparison_agg[category_comparison_agg['comparison_type'] == comp_type]
        if len(type_data) > 0:
            type_means_all.append(type_data['mean_spearman'].mean())
            type_stds_all.append(type_data['mean_spearman'].std())
            type_labels.append(comp_type.replace('_', ' ').title())
    
    if len(type_means_all) > 0:
        bars = ax4.bar(range(len(type_labels)), type_means_all, yerr=type_stds_all,
                      color=['#FF6B6B', '#4ECDC4', '#45B7D1', '#FFA07A'], 
                      alpha=0.7, capsize=5, edgecolor='black')
        ax4.set_xticks(range(len(type_labels)))
        ax4.set_xticklabels(type_labels, rotation=45, ha='right')
        ax4.set_ylabel('Mean Correlation (across categories)', fontsize=12)
        ax4.set_title('Mean Correlations by Comparison Type\n(Between-Kid)', fontsize=13, pad=10)
        ax4.grid(True, alpha=0.3, axis='y')
        ax4.set_ylim([0, 1])
    
    plt.suptitle('Between-Kid First-Order Similarity: Category-Wise Analysis', 
                 fontsize=14, y=0.995, fontweight='bold')
    plt.tight_layout(rect=[0, 0, 1, 0.98])
    plt.savefig(output_dir / "between_kid_first_order_category_wise_visualization.png", dpi=200, bbox_inches='tight')
    print(f"Saved between-kid category-wise visualization to {output_dir / 'between_kid_first_order_category_wise_visualization.png'}")
    plt.close()

print("\nCategory-wise visualization complete!")

Creating category-wise visualizations...
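
`category_agg_all` and `category_comparison_agg` are built with `agg({'spearman_correlation': ['mean', 'std', 'count']})` followed by a manual rename of the resulting MultiIndex columns. pandas named aggregation produces the same flat columns in one step. A sketch on a tiny frame shaped like `between_kid_category_wise_df` (the values are made up for illustration, not results):

```python
import pandas as pd

# Illustrative rows shaped like between_kid_category_wise_df (values are made up)
df = pd.DataFrame({
    'category': ['ball', 'ball', 'cup', 'cup', 'cup'],
    'comparison_type': ['younger_younger', 'older_older',
                        'younger_younger', 'older_older', 'younger_older'],
    'spearman_correlation': [0.8, 0.6, 0.5, 0.6, 0.7],
})

# Per category, across all comparison types (same columns as category_agg_all)
category_summary = (
    df.groupby('category', as_index=False)
      .agg(mean_spearman=('spearman_correlation', 'mean'),
           std_spearman=('spearman_correlation', 'std'),
           n_pairs=('spearman_correlation', 'count'))
      .sort_values('mean_spearman', ascending=False)
)

# Per category and comparison type (same columns as category_comparison_agg)
category_type_summary = df.groupby(['category', 'comparison_type'], as_index=False).agg(
    mean_spearman=('spearman_correlation', 'mean'),
    std_spearman=('spearman_correlation', 'std'),
    n_pairs=('spearman_correlation', 'count'),
)
print(category_summary)
print(category_type_summary)
```

As with the original `'std'` aggregation, pandas uses the sample standard deviation (ddof=1), so groups with a single pair get NaN.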



In [58]:
# Demonstration: How RDM Correlation Works with NaN Values
# This cell demonstrates the correlation logic with a simple example

print("="*70)
print("DEMONSTRATION: RDM Correlation Logic with NaN Values")
print("="*70)

# Example: Simple 5-category case
print("\n1. SETUP: Full category order (5 categories)")
ordered_cats = ['cat1', 'cat2', 'cat3', 'cat4', 'cat5']
print(f"   Ordered categories: {ordered_cats}")

print("\n2. EXAMPLE SUBJECT:")
print("   Younger bin has: cat1, cat2, cat3, cat4 (4 categories)")
print("   Older bin has:   cat2, cat3, cat4, cat5 (4 categories)")
available_younger = ['cat1', 'cat2', 'cat3', 'cat4']
available_older = ['cat2', 'cat3', 'cat4', 'cat5']

print("\n3. COMMON CATEGORIES:")
# Preserve predefined order, not alphabetical
common = [cat for cat in ordered_cats if cat in available_younger and cat in available_older]
print(f"   Common categories: {common} ({len(common)} categories)")
print("   Note: cat1 only in younger, cat5 only in older - these are excluded")
print("   IMPORTANT: Categories are in predefined order, not alphabetical!")

print("\n4. RDM STRUCTURE:")
print("   Full RDMs are 5×5 (one row/column per category in ordered_cats)")
print("   Missing categories have NaN in their rows/columns")
print("\n   Younger RDM structure:")
print("   " + " ".join([f"{c:>6}" for c in ordered_cats]))
for cat in ordered_cats:
    if cat in available_younger:
        status = "  data"
    else:
        status = "   NaN"
    print(f"   {cat:>6}: {status}")

print("\n5. SUBMATRIX EXTRACTION:")
common_indices = [ordered_cats.index(c) for c in common]
print(f"   Common category indices in full RDM: {common_indices}")
print(f"   Extract {len(common)}×{len(common)} submatrix using these indices")
print(f"   This gives us only the common categories: {common}")

print("\n6. UPPER TRIANGLE:")
n_common = len(common)
n_pairs = n_common * (n_common - 1) // 2
print(f"   For {n_common} categories, we get {n_pairs} unique pairs")
print(f"   (excluding diagonal: {n_common} self-pairs)")
print(f"   Example pairs: (cat2-cat3), (cat2-cat4), (cat3-cat4)")

print("\n7. CORRELATION:")
print("   - Extract distance values for these pairs from both RDMs")
print("   - Filter out any NaN values (shouldn't be any for common categories)")
print("   - Compute Spearman correlation on the paired distance values")
print("   - Result: Single correlation coefficient (-1 to 1)")

print("\n8. WHY THIS WORKS:")
print("   ✓ Only uses categories present in BOTH bins (fair comparison)")
print("   ✓ Same category pairs compared in both RDMs")
print("   ✓ NaN values are excluded (don't affect correlation)")
print("   ✓ Correlation reflects structural similarity, not data availability")

print("\n" + "="*70)
print("For actual subjects, this process uses 163 categories")
print("Common categories typically range from 130-160 per subject")
print("="*70)

DEMONSTRATION: RDM Correlation Logic with NaN Values

1. SETUP: Full category order (5 categories)
   Ordered categories: ['cat1', 'cat2', 'cat3', 'cat4', 'cat5']

2. EXAMPLE SUBJECT:
   Younger bin has: cat1, cat2, cat3, cat4 (4 categories)
   Older bin has:   cat2, cat3, cat4, cat5 (4 categories)

3. COMMON CATEGORIES:
   Common categories: ['cat2', 'cat3', 'cat4'] (3 categories)
   Note: cat1 only in younger, cat5 only in older - these are excluded
   IMPORTANT: Categories are in predefined order, not alphabetical!

4. RDM STRUCTURE:
   Full RDMs are 5×5 (one row/column per category in ordered_cats)
   Missing categories have NaN in their rows/columns

   Younger RDM structure:
     cat1   cat2   cat3   cat4   cat5
     cat1:   data
     cat2:   data
     cat3:   data
     cat4:   data
     cat5:    NaN

5. SUBMATRIX EXTRACTION:
   Common category indices in full RDM: [1, 2, 3]
   Extract 3×3 submatrix using these indices
   This gives us only the common categories: ['cat2', 'cat3',
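
The printed walkthrough can also be run numerically. The sketch below builds 5×5 toy RDMs with NaN rows and columns for missing categories (cat1 absent from the older bin, cat5 absent from the younger bin), then applies the same steps as `compute_rdm_correlation`: common categories in predefined order, submatrix extraction, upper triangle, Spearman correlation. The distances are random made-up values, and `toy_rdm` and `nan_aware_rdm_corr` are hypothetical standalone helpers written for illustration.

```python
import numpy as np
from scipy.stats import spearmanr

ordered_cats = ['cat1', 'cat2', 'cat3', 'cat4', 'cat5']


def toy_rdm(present, rng):
    """Symmetric random distances with NaN rows/columns for absent categories."""
    n = len(ordered_cats)
    d = rng.uniform(0.2, 1.0, size=(n, n))
    rdm = (d + d.T) / 2
    np.fill_diagonal(rdm, 0.0)
    for i, cat in enumerate(ordered_cats):
        if cat not in present:
            rdm[i, :] = np.nan
            rdm[:, i] = np.nan
    return rdm


def nan_aware_rdm_corr(rdm1, rdm2, present1, present2):
    """Spearman correlation over category pairs present in both RDMs."""
    common = [c for c in ordered_cats if c in present1 and c in present2]  # predefined order
    idx = [ordered_cats.index(c) for c in common]
    sub1 = rdm1[np.ix_(idx, idx)]
    sub2 = rdm2[np.ix_(idx, idx)]
    iu = np.triu_indices(len(idx), k=1)  # unique pairs, diagonal excluded
    rho, _ = spearmanr(sub1[iu], sub2[iu])
    return rho, common, len(iu[0])


rng = np.random.default_rng(2)
younger_cats = ['cat1', 'cat2', 'cat3', 'cat4']
older_cats = ['cat2', 'cat3', 'cat4', 'cat5']
rdm_young = toy_rdm(younger_cats, rng)
rdm_old = toy_rdm(older_cats, rng)

rho, common, n_pairs = nan_aware_rdm_corr(rdm_young, rdm_old, younger_cats, older_cats)
print(common, n_pairs)  # ['cat2', 'cat3', 'cat4'] 3
print(f"Spearman rho over {n_pairs} pairs: {rho:.3f}")
```

With only three pairs the correlation is very coarse; the point is that NaN entries never reach `spearmanr` because the submatrix is restricted to shared categories first.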

In [37]:
def compute_rdm_correlation(rdm1, rdm2, ordered_categories_list, available_cats1, available_cats2):
    """
    Compute the Spearman correlation between two RDMs indexed by the full ordered_categories list,
    where rows/columns of missing categories are NaN. Only categories present in both RDMs are used.
    
    Args:
        rdm1: numpy array of shape (n_categories, n_categories) with NaN for missing categories
        rdm2: numpy array of shape (n_categories, n_categories) with NaN for missing categories
        ordered_categories_list: full list of categories in order (used for indexing)
        available_cats1: list of categories actually present in rdm1
        available_cats2: list of categories actually present in rdm2
    
    Returns:
        corr: Spearman correlation over the shared upper-triangle entries (np.nan if insufficient data)
        n_common: number of common categories
    """
    # Find common categories (categories present in both RDMs)
    # IMPORTANT: Preserve predefined order from ordered_categories_list, NOT alphabetical order
    # This ensures submatrices are extracted in the same order for both RDMs
    # and maintains consistency with visualizations which use the predefined order
    common_categories = [cat for cat in ordered_categories_list 
                        if cat in available_cats1 and cat in available_cats2]
    
    if len(common_categories) < 2:
        return np.nan, len(common_categories)
    
    # Get indices for common categories in the ordered_categories_list
    common_indices = [ordered_categories_list.index(cat) for cat in common_categories]
    
    # Extract submatrices for common categories
    rdm1_subset = rdm1[np.ix_(common_indices, common_indices)]
    rdm2_subset = rdm2[np.ix_(common_indices, common_indices)]
    
    # Get upper triangle (excluding diagonal) for both RDMs
    mask = np.triu(np.ones_like(rdm1_subset, dtype=bool), k=1)
    rdm1_flat = rdm1_subset[mask]
    rdm2_flat = rdm2_subset[mask]
    
    # Filter out NaN values (shouldn't be any if categories are truly common, but check anyway)
    valid_mask = ~(np.isnan(rdm1_flat) | np.isnan(rdm2_flat))
    rdm1_valid = rdm1_flat[valid_mask]
    rdm2_valid = rdm2_flat[valid_mask]
    
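    # Spearman (rank) correlation: depends only on the ordering of the dissimilarities,
    # so it is robust to outliers and to monotone differences in distance scale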
    if len(rdm1_valid) > 0:
        corr, _ = spearmanr(rdm1_valid, rdm2_valid)
        return corr, len(common_categories)
    else:
        return np.nan, len(common_categories)

# Compute RDM correlations between younger and older bins for each subject
trajectory_data = []

for subject_id, bin_rdms in tqdm(subject_age_rdms.items(), desc="Analyzing trajectories"):
    if 'younger' not in bin_rdms or 'older' not in bin_rdms:
        continue
    
    rdm_younger = bin_rdms['younger']
    rdm_older = bin_rdms['older']
    cats_younger = subject_age_rdm_categories[subject_id]['younger']
    cats_older = subject_age_rdm_categories[subject_id]['older']
    
    # Use ordered_categories as reference and available categories for filtering
    corr, n_common = compute_rdm_correlation(
        rdm_younger, rdm_older, 
        ordered_categories,  # Full ordered list for indexing
        cats_younger,  # Available categories in younger bin
        cats_older     # Available categories in older bin
    )
    
    trajectory_data.append({
        'subject_id': subject_id,
        'age_bin_1': 'younger',
        'age_bin_2': 'older',
        'median_age_threshold': overall_median_age,
        'rdm_correlation': corr,
        'n_common_categories': n_common,
        'n_categories_younger': len(cats_younger),
        'n_categories_older': len(cats_older)
    })

trajectory_df = pd.DataFrame(trajectory_data)
trajectory_df.to_csv(output_dir / "trajectory_correlations.csv", index=False)

print(f"\nTrajectory analysis:")
print(f"  Total subjects analyzed: {len(trajectory_df)}")
print(f"  Mean RDM correlation (younger vs older): {trajectory_df['rdm_correlation'].mean():.3f}")
print(f"  Std RDM correlation: {trajectory_df['rdm_correlation'].std():.3f}")
print(f"  Median age threshold: {overall_median_age:.1f} months")
print(f"\nSaved trajectory correlations to {output_dir / 'trajectory_correlations.csv'}")


Analyzing trajectories: 100%|██████████| 18/18 [00:00<00:00, 406.19it/s]


Trajectory analysis:
  Total subjects analyzed: 18
  Mean RDM correlation (younger vs older): 0.756
  Std RDM correlation: 0.073
  Median age threshold: 16.0 months

Saved trajectory correlations to developmental_trajectory_rdms_clip/trajectory_correlations.csv
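Spearman's rho is Pearson's r computed on ranks, so it depends only on the ordering of the pairwise dissimilarities. Any monotone rescaling of one RDM (for example, squaring non-negative distances) therefore leaves the younger-vs-older correlation unchanged. This is a common reason to prefer it when the two RDMs may differ in distance scale. Here is a minimal sketch with random toy vectors standing in for flattened upper triangles (45 pairs, as for a hypothetical 10-category RDM):

```python
import numpy as np
from scipy.stats import pearsonr, rankdata, spearmanr

rng = np.random.default_rng(42)
d_young = rng.uniform(0.0, 1.0, size=45)             # toy flattened upper triangle
d_old = d_young + rng.normal(0.0, 0.1, size=45)      # noisy "older" version

rho, _ = spearmanr(d_young, d_old)
r_on_ranks, _ = pearsonr(rankdata(d_young), rankdata(d_old))

# Squaring is monotone on [0, 1): ranks (and hence rho) are unchanged, Pearson r is not
rho_sq, _ = spearmanr(d_young ** 2, d_old)
r_raw, _ = pearsonr(d_young, d_old)
r_sq, _ = pearsonr(d_young ** 2, d_old)

print(f"Spearman {rho:.4f} | Pearson on ranks {r_on_ranks:.4f} | Spearman after squaring {rho_sq:.4f}")
print(f"Pearson raw {r_raw:.4f} vs after squaring {r_sq:.4f}")
```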





## Category-Based Correlations

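Compute correlations between the younger and older RDMs separately within each broad semantic category group (animals, bodyparts, big_objects, small_objects, others). Each correlation uses only the submatrix of that group's categories that are present in both age bins. This lets us examine whether developmental stability differs across semantic domains. A group needs at least three common categories (three pairwise dissimilarities) for a defined rank correlation. The others group has only two categories, which give a single pairwise dissimilarity, so its correlation is NaN for every subject.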
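Each per-group correlation follows the same recipe as `compute_rdm_correlation`, restricted to one group's categories. The sketch below packages that restriction as a reusable function. It uses random toy embeddings and hypothetical category names. `group_rdm_correlation` and its `min_pairs` threshold are assumptions introduced here, not notebook settings. The threshold makes groups with too few pairwise dissimilarities return NaN explicitly instead of relying on SciPy's handling of degenerate input:

```python
import numpy as np
from scipy.stats import spearmanr


def group_rdm_correlation(rdm1, rdm2, ordered, group, cats1, cats2, min_pairs=3):
    """Spearman correlation of two (possibly NaN-padded) RDMs, restricted to one category group.

    Hypothetical helper; `min_pairs` is an assumed threshold. Returns (rho, n_common).
    """
    common = [c for c in ordered if c in group and c in cats1 and c in cats2]
    if len(common) < 2:
        return np.nan, len(common)
    idx = [ordered.index(c) for c in common]
    iu = np.triu_indices(len(idx), k=1)
    v1 = rdm1[np.ix_(idx, idx)][iu]
    v2 = rdm2[np.ix_(idx, idx)][iu]
    valid = ~(np.isnan(v1) | np.isnan(v2))
    if valid.sum() < min_pairs:
        return np.nan, len(common)
    rho, _ = spearmanr(v1[valid], v2[valid])
    return rho, len(common)


def cosine_rdm(x):
    """1 - cosine similarity between the rows of x."""
    x = x / np.linalg.norm(x, axis=1, keepdims=True)
    return 1.0 - x @ x.T


# Toy data: 6 hypothetical categories, two noisy "age bins" of the same embeddings
ordered = ['dog', 'cat', 'duck', 'cow', 'hi', 'bye']
groups = {'toy_animals': ['dog', 'cat', 'duck', 'cow'], 'toy_others': ['hi', 'bye']}
rng = np.random.default_rng(0)
emb = rng.normal(size=(len(ordered), 16))
rdm_young = cosine_rdm(emb + rng.normal(scale=0.3, size=emb.shape))
rdm_old = cosine_rdm(emb + rng.normal(scale=0.3, size=emb.shape))

results = {name: group_rdm_correlation(rdm_young, rdm_old, ordered, cats, ordered, ordered)
           for name, cats in groups.items()}
for name, (rho, n) in results.items():
    print(f"{name}: n_common={n}, rho={rho:.3f}")
```

The two-category toy group yields one pairwise dissimilarity and therefore NaN. The four-category group yields six pairs and a finite correlation.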

In [38]:
# Compute category-based correlations for each semantic group
category_correlation_data = []

# Get category groups from organized structure
category_groups = {
    'animals': organized['animals'],
    'bodyparts': organized['bodyparts'],
    'big_objects': organized['big_objects'],
    'small_objects': organized['small_objects'],
    'others': organized['others']
}

print("Computing category-based correlations...")
print(f"Category group sizes: {[(name, len(cats)) for name, cats in category_groups.items()]}")

for subject_id, bin_rdms in tqdm(subject_age_rdms.items(), desc="Category correlations"):
    if 'younger' not in bin_rdms or 'older' not in bin_rdms:
        continue
    
    rdm_younger = bin_rdms['younger']
    rdm_older = bin_rdms['older']
    cats_younger = subject_age_rdm_categories[subject_id]['younger']
    cats_older = subject_age_rdm_categories[subject_id]['older']
    
    # Compute correlation for each category group
    for group_name, group_categories in category_groups.items():
        # Find common categories in this group that are present in both age bins
        common_in_group = [cat for cat in group_categories 
                          if cat in cats_younger and cat in ordered_categories and cat in cats_older]
        
        if len(common_in_group) < 2:
            # Not enough categories in this group for correlation
            category_correlation_data.append({
                'subject_id': subject_id,
                'category_group': group_name,
                'n_common_categories': len(common_in_group),
                'correlation': np.nan,
                'n_categories_younger': len([c for c in group_categories if c in cats_younger]),
                'n_categories_older': len([c for c in group_categories if c in cats_older])
            })
            continue
        
        # Get indices for common categories in this group
        common_indices = [ordered_categories.index(cat) for cat in common_in_group]
        
        # Extract submatrices for this group
        rdm_younger_group = rdm_younger[np.ix_(common_indices, common_indices)]
        rdm_older_group = rdm_older[np.ix_(common_indices, common_indices)]
        
        # Get upper triangle (excluding diagonal)
        mask = np.triu(np.ones_like(rdm_younger_group, dtype=bool), k=1)
        rdm_younger_flat = rdm_younger_group[mask]
        rdm_older_flat = rdm_older_group[mask]
        
        # Filter out NaN values
        valid_mask = ~(np.isnan(rdm_younger_flat) | np.isnan(rdm_older_flat))
        rdm_younger_valid = rdm_younger_flat[valid_mask]
        rdm_older_valid = rdm_older_flat[valid_mask]
        
        # Compute Spearman correlation
        if len(rdm_younger_valid) > 0:
            corr, _ = spearmanr(rdm_younger_valid, rdm_older_valid)
        else:
            corr = np.nan
        
        category_correlation_data.append({
            'subject_id': subject_id,
            'category_group': group_name,
            'n_common_categories': len(common_in_group),
            'correlation': corr,
            'n_categories_younger': len([c for c in group_categories if c in cats_younger]),
            'n_categories_older': len([c for c in group_categories if c in cats_older])
        })

category_corr_df = pd.DataFrame(category_correlation_data)
category_corr_df.to_csv(output_dir / "category_group_correlations.csv", index=False)

print(f"\nCategory-based correlation analysis:")
print(f"  Total subject-group combinations: {len(category_corr_df)}")
print(f"\nMean correlations by category group:")
for group_name in category_groups.keys():
    group_data = category_corr_df[category_corr_df['category_group'] == group_name]
    valid_corrs = group_data['correlation'].dropna()
    if len(valid_corrs) > 0:
        print(f"  {group_name:15s}: {valid_corrs.mean():.3f} (n={len(valid_corrs)} valid, {len(group_data)} total)")
    else:
        print(f"  {group_name:15s}: No valid correlations (n={len(group_data)} total)")

print(f"\nSaved category group correlations to {output_dir / 'category_group_correlations.csv'}")

Computing category-based correlations...
Category group sizes: [('animals', 19), ('bodyparts', 14), ('big_objects', 32), ('small_objects', 96), ('others', 2)]


Category correlations: 100%|██████████| 18/18 [00:00<00:00, 445.22it/s]


Category-based correlation analysis:
  Total subject-group combinations: 90

Mean correlations by category group:
  animals        : 0.502 (n=18 valid, 18 total)
  bodyparts      : 0.789 (n=18 valid, 18 total)
  big_objects    : 0.736 (n=18 valid, 18 total)
  small_objects  : 0.773 (n=18 valid, 18 total)
  others         : No valid correlations (n=18 total)

Saved category group correlations to developmental_trajectory_rdms_clip/category_group_correlations.csv





## Visualize Category-Based Correlations

Create visualizations to examine how developmental stability (correlation between younger and older RDMs) varies across different semantic category groups.

In [40]:
# Visualize category-based correlations
print("Creating visualizations for category-based correlations...")

# Filter out NaN correlations for plotting
valid_category_corr_df = category_corr_df[category_corr_df['correlation'].notna()].copy()

# Create figure with multiple subplots
fig = plt.figure(figsize=(18, 12))

# 1. Box plot comparing correlations across category groups
ax1 = plt.subplot(2, 3, 1)
category_order = ['animals', 'bodyparts', 'big_objects', 'small_objects', 'others']

# Keep only groups that have at least one valid correlation (e.g. 'others' is all NaN)
box_data_filtered = []
labels_filtered = []
for group in category_order:
    group_data = valid_category_corr_df[valid_category_corr_df['category_group'] == group]['correlation'].values
    if len(group_data) > 0:
        box_data_filtered.append(group_data)
        labels_filtered.append(group.replace('_', ' ').title())

bp = ax1.boxplot(box_data_filtered, labels=labels_filtered, patch_artist=True)
colors = ['#FF6B6B', '#4ECDC4', '#45B7D1', '#FFA07A', '#98D8C8']
for patch, color in zip(bp['boxes'], colors[:len(bp['boxes'])]):
    patch.set_facecolor(color)
    patch.set_alpha(0.7)

ax1.set_ylabel('RDM Correlation (Spearman)', fontsize=12)
ax1.set_title('Distribution of Correlations by Category Group', fontsize=13, pad=10)
ax1.grid(True, alpha=0.3, axis='y')
ax1.set_ylim([0, 1])
plt.setp(ax1.xaxis.get_majorticklabels(), rotation=45, ha='right')

# 2. Bar plot of mean correlations by group
ax2 = plt.subplot(2, 3, 2)
mean_corrs = []
std_corrs = []
group_labels = []
for group in category_order:
    group_data = valid_category_corr_df[valid_category_corr_df['category_group'] == group]['correlation']
    mean_corrs.append(group_data.mean())
    std_corrs.append(group_data.std())
    group_labels.append(group.replace('_', ' ').title())

bars = ax2.bar(range(len(group_labels)), mean_corrs, yerr=std_corrs, 
               color=colors[:len(group_labels)], alpha=0.7, capsize=5, edgecolor='black')
ax2.set_xticks(range(len(group_labels)))
ax2.set_xticklabels(group_labels, rotation=45, ha='right')
ax2.set_ylabel('Mean RDM Correlation', fontsize=12)
ax2.set_title('Mean Correlations by Category Group', fontsize=13, pad=10)
ax2.grid(True, alpha=0.3, axis='y')
ax2.set_ylim([0, 1])
ax2.axhline(y=valid_category_corr_df['correlation'].mean(), color='red', 
           linestyle='--', linewidth=2, label=f'Overall mean: {valid_category_corr_df["correlation"].mean():.3f}')
ax2.legend()

# 3. Heatmap: subjects x category groups
ax3 = plt.subplot(2, 3, 3)
pivot_data = valid_category_corr_df.pivot(index='subject_id', columns='category_group', values='correlation')
# Reorder columns
pivot_data = pivot_data[[col for col in category_order if col in pivot_data.columns]]
# Sort subjects by overall correlation (average across groups)
pivot_data['mean_corr'] = pivot_data.mean(axis=1)
pivot_data = pivot_data.sort_values('mean_corr', ascending=False)
pivot_data = pivot_data.drop('mean_corr', axis=1)

im = ax3.imshow(pivot_data.values, aspect='auto', cmap='RdYlBu_r', vmin=0, vmax=1)
ax3.set_xticks(range(len(pivot_data.columns)))
ax3.set_xticklabels([col.replace('_', ' ').title() for col in pivot_data.columns], 
                    rotation=45, ha='right')
ax3.set_yticks(range(len(pivot_data.index)))
ax3.set_yticklabels(pivot_data.index, fontsize=8)
ax3.set_title('Correlation Heatmap: Subjects × Category Groups', fontsize=13, pad=10)
plt.colorbar(im, ax=ax3, label='RDM Correlation')

# 4. Violin plot for better distribution visualization
ax4 = plt.subplot(2, 3, 4)
violin_data = []
violin_labels = []
for group in category_order:
    group_data = valid_category_corr_df[valid_category_corr_df['category_group'] == group]['correlation'].values
    # Only add groups with non-empty data (filter out empty arrays)
    if len(group_data) > 0:
        violin_data.append(group_data)
        violin_labels.append(group.replace('_', ' ').title())


if len(violin_data) > 0:
    parts = ax4.violinplot(violin_data, positions=range(len(violin_labels)), showmeans=True, showmedians=True)
    for i, pc in enumerate(parts['bodies']):
        pc.set_facecolor(colors[i % len(colors)])
        pc.set_alpha(0.7)
    ax4.set_xticks(range(len(violin_labels)))
    ax4.set_xticklabels(violin_labels, rotation=45, ha='right')
    ax4.set_ylabel('RDM Correlation (Spearman)', fontsize=12)
    ax4.set_title('Distribution of Correlations (Violin Plot)', fontsize=13, pad=10)
    ax4.grid(True, alpha=0.3, axis='y')
    ax4.set_ylim([0, 1])
else:
    ax4.text(0.5, 0.5, 'No valid data for violin plot', ha='center', va='center', transform=ax4.transAxes)
    ax4.set_title('Distribution of Correlations (Violin Plot)', fontsize=13, pad=10)

# 5. Scatter plot: correlation vs number of common categories
ax5 = plt.subplot(2, 3, 5)
for group in category_order:
    group_data = valid_category_corr_df[valid_category_corr_df['category_group'] == group]
    ax5.scatter(group_data['n_common_categories'], group_data['correlation'], 
                   label=group.replace('_', ' ').title(), alpha=0.6, s=60)

ax5.set_xlabel('Number of Common Categories', fontsize=12)
ax5.set_ylabel('RDM Correlation', fontsize=12)
ax5.set_title('Correlation vs Category Count', fontsize=13, pad=10)
ax5.legend(loc='best', fontsize=9)
ax5.grid(True, alpha=0.3)
ax5.set_ylim([0, 1])

# 6. Individual subject trajectories (bar plot for each subject)
ax6 = plt.subplot(2, 3, 6)
# Get top 10 subjects by overall correlation for cleaner visualization
subject_means = valid_category_corr_df.groupby('subject_id')['correlation'].mean().sort_values(ascending=False)
top_subjects = subject_means.head(10).index

x_pos = np.arange(len(top_subjects))
bar_width = 0.15  # Use different variable name to avoid conflicts
plotted_offsets = []  # x-offsets of the groups actually drawn, used to center the ticks
for i, group in enumerate(category_order):
    if group in valid_category_corr_df['category_group'].values:
        group_corrs = []
        for subj in top_subjects:
            subj_group_data = valid_category_corr_df[
                (valid_category_corr_df['subject_id'] == subj) & 
                (valid_category_corr_df['category_group'] == group)
            ]
            if len(subj_group_data) > 0:
                group_corrs.append(subj_group_data['correlation'].values[0])
            else:
                group_corrs.append(np.nan)
        
        # Convert to numpy array and only plot if we have at least some valid data
        group_corrs = np.array(group_corrs)
        if not np.all(np.isnan(group_corrs)):
            ax6.bar(x_pos + i*bar_width, group_corrs, bar_width, 
                    label=group.replace('_', ' ').title(), alpha=0.7, color=colors[i % len(colors)])
            plotted_offsets.append(i * bar_width)

ax6.set_xlabel('Subject ID', fontsize=12)
ax6.set_ylabel('RDM Correlation', fontsize=12)
ax6.set_title('Top 10 Subjects: Correlations by Category Group', fontsize=13, pad=10)
# Center each tick under the cluster of bars actually drawn (groups without valid data are skipped)
tick_offset = (min(plotted_offsets) + max(plotted_offsets)) / 2 if plotted_offsets else 0.0
ax6.set_xticks(x_pos + tick_offset)
ax6.set_xticklabels(top_subjects, rotation=45, ha='right', fontsize=8)
ax6.legend(loc='upper left', fontsize=8, ncol=2)
ax6.grid(True, alpha=0.3, axis='y')
ax6.set_ylim([0, 1])

plt.suptitle('Category-Based RDM Correlations: Younger vs Older Age Bins', 
             fontsize=16, y=0.995, fontweight='bold')
plt.tight_layout(rect=[0, 0, 1, 0.99])
plt.savefig(output_dir / "category_group_correlations_visualization.png", dpi=200, bbox_inches='tight')
print(f"Saved category correlation visualization to {output_dir / 'category_group_correlations_visualization.png'}")
plt.close()

# Create a separate detailed heatmap with all subjects
fig, ax = plt.subplots(figsize=(10, 14))
pivot_data_all = valid_category_corr_df.pivot(index='subject_id', columns='category_group', values='correlation')
pivot_data_all = pivot_data_all[[col for col in category_order if col in pivot_data_all.columns]]
# Sort by overall mean correlation
pivot_data_all['mean_corr'] = pivot_data_all.mean(axis=1)
pivot_data_all = pivot_data_all.sort_values('mean_corr', ascending=False)
pivot_data_all = pivot_data_all.drop('mean_corr', axis=1)

im = ax.imshow(pivot_data_all.values, aspect='auto', cmap='RdYlBu_r', vmin=0, vmax=1)
ax.set_xticks(range(len(pivot_data_all.columns)))
ax.set_xticklabels([col.replace('_', ' ').title() for col in pivot_data_all.columns],
                   rotation=45, ha='right', fontsize=11)
ax.set_yticks(range(len(pivot_data_all.index)))
ax.set_yticklabels(pivot_data_all.index, fontsize=9)
ax.set_title('Category-Based RDM Correlations: All Subjects\n(Younger vs Older Age Bins)', 
             fontsize=14, pad=15, fontweight='bold')
cbar = plt.colorbar(im, ax=ax, label='RDM Correlation (Spearman)', fraction=0.046, pad=0.04)

# Add text annotations for correlation values (leave missing cells blank instead of printing 'nan')
for i in range(len(pivot_data_all.index)):
    for j in range(len(pivot_data_all.columns)):
        val = pivot_data_all.iloc[i, j]
        if pd.isna(val):
            continue
        ax.text(j, i, f'{val:.2f}', ha='center', va='center',
                fontsize=7, color='white' if val < 0.5 else 'black', fontweight='bold')

plt.tight_layout()
plt.savefig(output_dir / "category_group_correlations_heatmap.png", dpi=200, bbox_inches='tight')
print(f"Saved detailed heatmap to {output_dir / 'category_group_correlations_heatmap.png'}")
plt.close()

print("\nVisualization complete!")



Creating visualizations for category-based correlations...
Saved category correlation visualization to developmental_trajectory_rdms_clip/category_group_correlations_visualization.png
Saved detailed heatmap to developmental_trajectory_rdms_clip/category_group_correlations_heatmap.png

Visualization complete!
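To center a cluster of grouped bars on its tick, compute the offsets from the number of groups actually drawn. A fixed center such as `bar_width * 2` is correct only for exactly five groups, and just four are drawn per subject here because others has no valid correlations. The minimal sketch below centers the bars on the tick instead of shifting the tick. It uses toy values and hypothetical subject IDs, and `grouped_bar_offsets` is a helper introduced for illustration:

```python
import numpy as np
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt


def grouped_bar_offsets(n_groups, bar_width):
    """Hypothetical helper: offsets that center n_groups side-by-side bars on each tick."""
    return (np.arange(n_groups) - (n_groups - 1) / 2) * bar_width


subjects = ['S1', 'S2', 'S3']                                   # hypothetical subject IDs
groups = ['animals', 'bodyparts', 'big_objects', 'small_objects']
values = np.random.default_rng(0).uniform(0, 1, size=(len(groups), len(subjects)))  # toy data

x_pos = np.arange(len(subjects))
bar_width = 0.15
offsets = grouped_bar_offsets(len(groups), bar_width)

fig, ax = plt.subplots()
for k, group in enumerate(groups):
    ax.bar(x_pos + offsets[k], values[k], bar_width, label=group)
ax.set_xticks(x_pos)              # ticks sit at each cluster's center for any group count
ax.set_xticklabels(subjects)
ax.legend(fontsize=8)
plt.close(fig)
print(offsets)
```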


## Visualize Developmental Trajectories


In [41]:
# Create side-by-side RDM visualization for each subject (younger vs older)
print("Creating RDM visualizations for all subjects (younger vs older)...")

for subject_id in tqdm(subject_age_rdms.keys(), desc="Creating RDM plots"):
    bin_rdms = subject_age_rdms[subject_id]
    bin_masks = subject_age_rdm_masks[subject_id]
    
    if 'younger' not in bin_rdms or 'older' not in bin_rdms:
        continue
    
    # Create figure with 2 subplots side by side
    fig, axes = plt.subplots(1, 2, figsize=(16, 7))
    
    # Find global min/max for consistent color scale (excluding NaN)
    all_rdm_values = []
    for rdm in bin_rdms.values():
        valid_values = rdm[~np.isnan(rdm)]
        all_rdm_values.extend(valid_values)
    vmin = np.percentile(all_rdm_values, 1) if len(all_rdm_values) > 0 else 0
    vmax = np.percentile(all_rdm_values, 99) if len(all_rdm_values) > 0 else 2
    
    for idx, bin_name in enumerate(['younger', 'older']):
        rdm = bin_rdms[bin_name]
        mask = bin_masks[bin_name]
        available_cats = subject_age_rdm_categories[subject_id][bin_name]
        group_boundaries = subject_age_group_boundaries[subject_id][bin_name]
        
        ax = axes[idx]
        
        # Determine font sizes based on number of categories in predefined order
        n_cats_total = len(ordered_categories)
        n_cats_available = len(available_cats)
        
        if n_cats_total <= 50:
            label_fontsize = 10
            tick_fontsize = 12
        elif n_cats_total <= 100:
            label_fontsize = 8
            tick_fontsize = 10
        else:
            label_fontsize = 6
            tick_fontsize = 8
        
        # Create masked array for visualization (white cells for missing categories)
        rdm_masked = np.ma.masked_where(mask, rdm)
        cmap = matplotlib.colormaps['viridis'].copy()  # Get a copy to avoid modifying global colormap
        cmap.set_bad(color='white', alpha=1.0)  # White for NA cells
        im = ax.imshow(rdm_masked, cmap=cmap, aspect='auto', vmin=vmin, vmax=vmax)
        
        
        # Set category names as axis labels (full predefined order), showing only every Nth label to avoid overlap
        n_cats = len(ordered_categories)
        if n_cats <= 50:
            tick_step = 1
        elif n_cats <= 100:
            tick_step = 2
        else:
            tick_step = max(1, n_cats // 50)  # Show ~50 labels max
        
        visible_labels = [ordered_categories[j] for j in range(0, n_cats, tick_step)]
        ax.set_xticks(range(0, n_cats, tick_step))
        ax.set_yticks(range(0, n_cats, tick_step))
        ax.set_xticklabels(visible_labels, rotation=90, ha='right', fontsize=tick_fontsize)
        ax.set_yticklabels(visible_labels, fontsize=tick_fontsize)
        
        # Color each visible label by its own CDI category (labels are subsampled,
        # so look up the label text rather than indexing into ordered_categories)
        for label in ax.get_xticklabels():
            label.set_color(get_category_color(label.get_text(), cdi_category_map))
        for label in ax.get_yticklabels():
            label.set_color(get_category_color(label.get_text(), cdi_category_map))
        
        
        # Create title with age range info and category count
        if bin_name == "younger":
            title = f"Younger (≤{overall_median_age:.0f} months)\n({n_cats_available}/{n_cats_total} categories)"
        else:
            title = f"Older (>{overall_median_age:.0f} months)\n({n_cats_available}/{n_cats_total} categories)"
        
        ax.set_title(title, fontsize=12, pad=10)
        
        # Add colorbar
        plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
    
    plt.suptitle(f"Developmental Trajectory: {subject_id}\n(Median split at {overall_median_age:.1f} months)", 
                 fontsize=14, y=0.995)
    plt.tight_layout(rect=[0, 0, 1, 0.99])
    plt.savefig(output_dir / f"trajectory_{subject_id}.png", dpi=200, bbox_inches='tight')
    plt.close()

print(f"\nSaved RDM visualizations for {len(subject_age_rdms)} subjects")



Creating RDM visualizations for all subjects (younger vs older)...


Creating RDM plots: 100%|██████████| 18/18 [00:00<00:00, 364722.09it/s]


Saved RDM visualizations for 18 subjects
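Missing categories appear as white cells because each RDM is wrapped in a masked array and the colormap's "bad" color is set to white. The minimal sketch below shows that mechanism on a toy 4×4 RDM with one absent category; the values are made up. It uses `Colormap.with_extremes`, which returns a modified copy and leaves the registered viridis colormap untouched:

```python
import numpy as np
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from matplotlib.colors import Normalize

# Toy RDM: category index 2 is absent in this age bin, so its row and column are NaN
rdm = np.array([
    [0.0, 0.4, np.nan, 0.7],
    [0.4, 0.0, np.nan, 0.5],
    [np.nan, np.nan, np.nan, np.nan],
    [0.7, 0.5, np.nan, 0.0],
])
rdm_masked = np.ma.masked_invalid(rdm)                  # mask every NaN cell
cmap = matplotlib.colormaps['viridis'].with_extremes(bad='white')
norm = Normalize(vmin=0.0, vmax=1.0)

fig, ax = plt.subplots()
ax.imshow(rdm_masked, cmap=cmap, norm=norm)
plt.close(fig)

rgba = cmap(norm(rdm_masked))                           # colors assigned to each cell
print(int(rdm_masked.mask.sum()), "masked cells")
```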





## Grouped Developmental Trajectory Visualization

Create a grouped visualization combining multiple subjects' developmental trajectory plots together in one figure.
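The group sizes printed earlier sum to 163 categories, so each panel shows only every third label. A label's color therefore has to come from its own text, not from its position in `ordered_categories`. The minimal sketch below uses hypothetical category names and a hypothetical two-color group map in place of `get_category_color` and the CDI mapping:

```python
import numpy as np
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt

# Hypothetical stand-ins for ordered_categories and the CDI-based color lookup
categories = [f'cat{i}' for i in range(163)]
group_color = {'animals': 'tab:red', 'other': 'black'}
group_of = {c: ('animals' if i < 19 else 'other') for i, c in enumerate(categories)}

n = len(categories)
tick_step = 1 if n <= 50 else 2 if n <= 100 else max(1, n // 50)
ticks = list(range(0, n, tick_step))
visible = [categories[k] for k in ticks]

fig, ax = plt.subplots()
ax.imshow(np.zeros((n, n)))
ax.set_xticks(ticks)
ax.set_yticks(ticks)
ax.set_xticklabels(visible, rotation=90)
ax.set_yticklabels(visible)

# Color from each label's text, so subsampling cannot misalign labels and colors
for label in ax.get_xticklabels() + ax.get_yticklabels():
    label.set_color(group_color[group_of[label.get_text()]])

x_colors = {lab.get_text(): lab.get_color() for lab in ax.get_xticklabels()}
plt.close(fig)
print(tick_step, len(ticks))
```

Setting both tick-label lists before coloring also matters: calling `set_yticklabels` after the coloring loop would reset the y-label colors.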

In [42]:
# Create grouped visualization combining N subjects' developmental trajectory plots
print("Creating grouped developmental trajectory visualization...")

# Load CDI category mapping for label coloring
cdi_category_map = load_cdi_category_mapping(cdi_path)


# Get all subjects with valid data
valid_subjects = [sid for sid in subject_age_rdms.keys() 
                  if 'younger' in subject_age_rdms[sid] and 'older' in subject_age_rdms[sid]]

if len(valid_subjects) == 0:
    print("No subjects with valid data for grouped visualization")
else:
    # Number of subjects to plot (can be adjusted)
    # Set to None to plot all subjects, or specify a number
    n_subjects_to_plot = None  # Change to a number like 6, 9, 12, etc. to limit
    
    subjects_to_plot = valid_subjects[:n_subjects_to_plot] if n_subjects_to_plot else valid_subjects
    n_subjects = len(subjects_to_plot)
    
    print(f"Plotting {n_subjects} subjects in grouped visualization...")
    
    # Calculate global min/max for consistent color scale across all subjects
    all_rdm_values = []
    for subject_id in subjects_to_plot:
        bin_rdms = subject_age_rdms[subject_id]
        for bin_name in ['younger', 'older']:
            rdm = bin_rdms[bin_name]
            # Ensure rdm is a numpy array and handle NaN values
            if isinstance(rdm, np.ndarray):
                valid_values = rdm[~np.isnan(rdm)]
                if len(valid_values) > 0:
                    all_rdm_values.extend(valid_values.tolist())
    
    vmin = np.percentile(all_rdm_values, 1) if len(all_rdm_values) > 0 else 0
    vmax = np.percentile(all_rdm_values, 99) if len(all_rdm_values) > 0 else 2
    
    # Create figure with n_subjects rows and 2 columns (younger, older)
    # Adjust figure size based on number of subjects
    fig_height = max(4, n_subjects * 2.5)  # At least 4 inches, 2.5 inches per subject
    fig_width = 16  # Keep consistent width
    
    fig, axes = plt.subplots(n_subjects, 2, figsize=(fig_width, fig_height))
    
    # Handle case where there's only one subject (axes would be 1D)
    if n_subjects == 1:
        axes = axes.reshape(1, -1)
    
    # Determine font sizes based on number of categories
    n_cats_total = len(ordered_categories)
    if n_cats_total <= 50:
        tick_fontsize = 8
    elif n_cats_total <= 100:
        tick_fontsize = 6
    else:
        tick_fontsize = 4
    
    # Plot each subject
    for row_idx, subject_id in enumerate(subjects_to_plot):
        bin_rdms = subject_age_rdms[subject_id]
        bin_masks = subject_age_rdm_masks[subject_id]
        
        for col_idx, bin_name in enumerate(['younger', 'older']):
            rdm = bin_rdms[bin_name]
            mask = bin_masks[bin_name]
            available_cats = subject_age_rdm_categories[subject_id][bin_name]
            group_boundaries = subject_age_group_boundaries[subject_id][bin_name]
            
            ax = axes[row_idx, col_idx]
            
            # Ensure rdm is a numpy array with correct shape
            if not isinstance(rdm, np.ndarray):
                rdm = np.array(rdm)
            # Ensure mask is a numpy array with correct shape
            if not isinstance(mask, np.ndarray):
                mask = np.array(mask)
            # Ensure shapes match
            if rdm.shape != mask.shape:
                print(f"Warning: Shape mismatch for {subject_id} {bin_name}: rdm {rdm.shape} vs mask {mask.shape}")
                # Try to reshape or skip if incompatible
                if rdm.size == 0 or mask.size == 0:
                    continue
                if rdm.size == mask.size:
                    rdm = rdm.reshape(mask.shape)
                else:
                    print(f"  Skipping {subject_id} {bin_name} due to incompatible shapes")
                    continue
            
            # Create masked array for visualization
            rdm_masked = np.ma.masked_where(mask, rdm)
            cmap = matplotlib.colormaps['viridis'].copy()
            cmap.set_bad(color='white', alpha=1.0)
            im = ax.imshow(rdm_masked, cmap=cmap, aspect='auto', vmin=vmin, vmax=vmax)
            
            
            # Set category names as axis labels (only show every Nth label to avoid crowding)
            n_cats = len(ordered_categories)
            if n_cats <= 50:
                tick_step = 1
            elif n_cats <= 100:
                tick_step = 2
            else:
                tick_step = max(1, n_cats // 50)
            
            visible_labels = [ordered_categories[i] for i in range(0, n_cats, tick_step)]
            ax.set_xticks(range(0, n_cats, tick_step))
            ax.set_yticks(range(0, n_cats, tick_step))
            ax.set_xticklabels(visible_labels, rotation=90, ha="right", fontsize=tick_fontsize)
            ax.set_yticklabels(visible_labels, fontsize=tick_fontsize)

            # Color each visible label by its own CDI category (labels are subsampled,
            # so look up the label text rather than indexing into ordered_categories)
            for label in ax.get_xticklabels() + ax.get_yticklabels():
                label.set_color(get_category_color(label.get_text(), cdi_category_map))
            
            # Create title with subject ID and age info
            n_cats_available = len(available_cats)
            if bin_name == 'younger':
                title = f"{subject_id} - Younger (≤{overall_median_age:.0f}mo)\n({n_cats_available}/{n_cats_total} cats)"
            else:
                title = f"{subject_id} - Older (>{overall_median_age:.0f}mo)\n({n_cats_available}/{n_cats_total} cats)"
            
            ax.set_title(title, fontsize=10, pad=5)
            
            # Add colorbar only to the rightmost plots
            if col_idx == 1:  # Right column
                plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
    
    # Add overall title
    plt.suptitle(f'Grouped Developmental Trajectories: {n_subjects} Subjects\n(Median split at {overall_median_age:.1f} months)', 
                 fontsize=16, y=0.998, fontweight='bold')
    
    plt.tight_layout(rect=[0, 0, 1, 0.98])
    
    # Save the grouped visualization
    output_filename = f"grouped_trajectory_{n_subjects}_subjects.png"
    plt.savefig(output_dir / output_filename, dpi=200, bbox_inches='tight')
    print(f"Saved grouped visualization to {output_dir / output_filename}")
    plt.close()
    
    # Also create a version with fewer subjects if there are many (for better readability)
    if n_subjects > 12:
        print(f"\nCreating additional grouped visualization with first 12 subjects for better readability...")
        subjects_to_plot_12 = valid_subjects[:12]
        
        fig_height_12 = 12 * 2.5
        fig, axes = plt.subplots(12, 2, figsize=(fig_width, fig_height_12))
        
        for row_idx, subject_id in enumerate(subjects_to_plot_12):
            bin_rdms = subject_age_rdms[subject_id]
            bin_masks = subject_age_rdm_masks[subject_id]
            
            for col_idx, bin_name in enumerate(['younger', 'older']):
                rdm = bin_rdms[bin_name]
                mask = bin_masks[bin_name]
                available_cats = subject_age_rdm_categories[subject_id][bin_name]
                group_boundaries = subject_age_group_boundaries[subject_id][bin_name]
                
                ax = axes[row_idx, col_idx]
                
                # Ensure rdm is a numpy array with correct shape
                if not isinstance(rdm, np.ndarray):
                    rdm = np.array(rdm)
                # Ensure mask is a numpy array with correct shape
                if not isinstance(mask, np.ndarray):
                    mask = np.array(mask)
                # Ensure shapes match
                if rdm.shape != mask.shape:
                    print(f"Warning: Shape mismatch for {subject_id} {bin_name}: rdm {rdm.shape} vs mask {mask.shape}")
                    # Try to reshape or skip if incompatible
                    if rdm.size == 0 or mask.size == 0:
                        continue
                    if rdm.size == mask.size:
                        rdm = rdm.reshape(mask.shape)
                    else:
                        print(f"  Skipping {subject_id} {bin_name} due to incompatible shapes")
                        continue
                
                rdm_masked = np.ma.masked_where(mask, rdm)
                # matplotlib.colormaps replaces plt.cm.get_cmap, which was deprecated and later removed
                cmap = matplotlib.colormaps['viridis'].copy()
                cmap.set_bad(color='white', alpha=1.0)
                im = ax.imshow(rdm_masked, cmap=cmap, aspect='auto', vmin=vmin, vmax=vmax)
                
                n_cats = len(ordered_categories)
                if n_cats <= 50:
                    tick_step = 1
                elif n_cats <= 100:
                    tick_step = 2
                else:
                    tick_step = max(1, n_cats // 50)
                
                # Label every tick_step-th category and keep the matching names, so each
                # label is colored by its own category rather than by its tick position
                tick_indices = list(range(0, n_cats, tick_step))
                tick_names = [ordered_categories[i] for i in tick_indices]
                ax.set_xticks(tick_indices)
                ax.set_yticks(tick_indices)
                ax.set_xticklabels(tick_names, rotation=90, ha="right", fontsize=tick_fontsize)
                ax.set_yticklabels(tick_names, fontsize=tick_fontsize)
                # Apply colors to labels based on CDI category
                for cat_name, xlabel, ylabel in zip(tick_names, ax.get_xticklabels(), ax.get_yticklabels()):
                    color = get_category_color(cat_name, cdi_category_map)
                    xlabel.set_color(color)
                    ylabel.set_color(color)
                
                n_cats_available = len(available_cats)
                if bin_name == 'younger':
                    title = f"{subject_id} - Younger (≤{overall_median_age:.0f}mo)\n({n_cats_available}/{n_cats_total} cats)"
                else:
                    title = f"{subject_id} - Older (>{overall_median_age:.0f}mo)\n({n_cats_available}/{n_cats_total} cats)"
                
                ax.set_title(title, fontsize=10, pad=5)
                
                if col_idx == 1:
                    plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
        
        plt.suptitle(f'Grouped Developmental Trajectories: First 12 Subjects\n(Median split at {overall_median_age:.1f} months)', 
                     fontsize=16, y=0.998, fontweight='bold')
        
        plt.tight_layout(rect=[0, 0, 1, 0.98])
        
        output_filename_12 = f"grouped_trajectory_12_subjects.png"
        plt.savefig(output_dir / output_filename_12, dpi=200, bbox_inches='tight')
        print(f"Saved 12-subject grouped visualization to {output_dir / output_filename_12}")
        plt.close()

print("\nGrouped visualization complete!")



Creating grouped developmental trajectory visualization...
Plotting 18 subjects in grouped visualization...
Saved grouped visualization to developmental_trajectory_rdms_clip/grouped_trajectory_18_subjects.png

Creating additional grouped visualization with first 12 subjects for better readability...
Saved 12-subject grouped visualization to developmental_trajectory_rdms_clip/grouped_trajectory_12_subjects.png

Grouped visualization complete!


## Plot RDM Stability Across Development


In [43]:
# Plot RDM correlation distribution between younger and older bins
# Filter out NaN correlations
valid_correlations = trajectory_df['rdm_correlation'].dropna()

if len(valid_correlations) > 0:
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))
    
    # Histogram of RDM correlations
    axes[0].hist(valid_correlations, bins=20, alpha=0.7, edgecolor='black')
    mean_corr = valid_correlations.mean()
    axes[0].axvline(mean_corr, color='red', linestyle='--', 
                    label=f'Mean: {mean_corr:.3f}')
    axes[0].set_xlabel('RDM Correlation (Spearman)')
    axes[0].set_ylabel('Number of Subjects')
    axes[0].set_title(f'Distribution of Younger vs Older RDM Correlations\n(n={len(valid_correlations)} valid)')
    axes[0].legend()
    axes[0].grid(True, alpha=0.3)
    
    # Box plot
    axes[1].boxplot(valid_correlations, vert=True)
    axes[1].set_ylabel('RDM Correlation (Spearman)')
    axes[1].set_title(f'RDM Correlation: Younger vs Older\n(n={len(valid_correlations)} valid)')
    axes[1].set_xticklabels(['All Subjects'])
    axes[1].grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.savefig(output_dir / "rdm_stability_analysis.png", dpi=150, bbox_inches='tight')
    print(f"Saved RDM stability analysis to {output_dir / 'rdm_stability_analysis.png'}")
    plt.close()
else:
    print("Warning: No valid correlations to plot (all are NaN)")


Saved RDM stability analysis to developmental_trajectory_rdms_clip/rdm_stability_analysis.png


## Summary Statistics


In [44]:
# Create summary statistics
summary_data = []

for subject_id, bin_rdms in subject_age_rdms.items():
    for bin_name in ['younger', 'older']:
        if bin_name not in bin_rdms:
            continue
            
        rdm = bin_rdms[bin_name]
        categories = subject_age_rdm_categories[subject_id][bin_name]
        
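        # Use nan-aware functions to handle NaN values (missing categories).
        # Note: mean/std/max below are taken over the full matrix (both triangles and
        # the zero diagonal); only min_distance excludes the diagonal zeros.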
        valid_rdm = rdm[~np.isnan(rdm)]
        valid_rdm_positive = valid_rdm[valid_rdm > 0]  # Exclude diagonal zeros
        
        summary_data.append({
            'subject_id': subject_id,
            'age_bin': bin_name,
            'median_age_threshold': overall_median_age,
            'n_categories': len(categories),
            'mean_distance': float(np.nanmean(rdm)) if len(valid_rdm) > 0 else np.nan,
            'std_distance': float(np.nanstd(rdm)) if len(valid_rdm) > 0 else np.nan,
            'min_distance': float(valid_rdm_positive.min()) if len(valid_rdm_positive) > 0 else np.nan,
            'max_distance': float(np.nanmax(rdm)) if len(valid_rdm) > 0 else np.nan
        })

summary_df = pd.DataFrame(summary_data)
summary_df.to_csv(output_dir / "summary_statistics.csv", index=False)

print("Summary statistics:")
print(summary_df.describe())
print(f"\nSaved summary to {output_dir / 'summary_statistics.csv'}")


Summary statistics:
       median_age_threshold  n_categories  mean_distance  std_distance  \
count                  36.0     36.000000      36.000000     36.000000   
mean                   16.0    150.916667       0.970143      0.269097   
std                     0.0     12.748950       0.015751      0.012754   
min                    16.0     88.000000       0.907006      0.229432   
25%                    16.0    149.000000       0.962922      0.263846   
50%                    16.0    154.500000       0.973334      0.271045   
75%                    16.0    157.000000       0.980742      0.277700   
max                    16.0    161.000000       0.989174      0.287583   

       min_distance  max_distance  
count     36.000000     36.000000  
mean       0.040484      1.628654  
std        0.020743      0.046012  
min        0.012637      1.428745  
25%        0.022898      1.621835  
50%        0.037880      1.636737  
75%        0.056602      1.655754  
max        0.078758      

## Cross-Kid Correlations

In addition to within-kid correlations (comparing the younger and older RDMs of the same child), we compute cross-kid correlations to quantify how similar different children's representational structures are. Each correlation is computed only over the categories that both RDMs contain. This analysis includes:

1. **Younger-Younger correlations**: Compare younger timepoints across different children
   - Measures similarity in representational structure during early development
   - Helps identify whether children have similar early object representations

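2. **Younger-Older and Older-Younger correlations**: Compare the younger timepoint of one child with the older timepoint of another child. Because each unordered pair of children is visited only once, both directions are computed (subject 1 younger vs. subject 2 older, and subject 1 older vs. subject 2 younger)
   - Measures whether one child's early representations resemble another child's later representations
   - Can reveal developmental patterns and individual differences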

3. **Older-Older correlations**: Compare older timepoints across different children
   - Measures similarity in representational structure during later development
   - Helps identify whether children converge to similar representations as they develop

These cross-kid correlations complement the within-kid analysis by providing insights into:
- **Individual differences**: How much do children vary in their object representations?
- **Developmental timing**: Do some children's early representations resemble other children's later representations?
- **Common developmental patterns**: Are there shared trajectories across children?
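The notebook's own `compute_rdm_correlation` is defined in an earlier cell and is not shown in this section. As a rough, self-contained illustration of the computation these comparisons rely on, the sketch below correlates two RDMs over the categories both children share. It assumes, as the category-group cell in this notebook does, that every RDM is a square array indexed by one global category order, with NaN where a category is missing. The names `rdm_spearman_shared`, `toy_rdm`, and the `min_pairs=3` cutoff are hypothetical choices for illustration, not the notebook's implementation.

```python
import numpy as np
from scipy.stats import spearmanr


def rdm_spearman_shared(rdm1, rdm2, ordered, cats1, cats2, min_pairs=3):
    """Hypothetical helper: Spearman correlation of two RDMs over shared categories.

    Both RDMs are square arrays indexed by the global `ordered` category list
    (NaN where a category is missing). Only the off-diagonal upper triangle is
    compared, so each category pair counts once. Returns (rho, n_shared).
    """
    present1, present2 = set(cats1), set(cats2)
    shared = [c for c in ordered if c in present1 and c in present2]
    idx = [ordered.index(c) for c in shared]
    sub1 = np.asarray(rdm1, dtype=float)[np.ix_(idx, idx)]
    sub2 = np.asarray(rdm2, dtype=float)[np.ix_(idx, idx)]
    iu = np.triu_indices(len(shared), k=1)
    v1, v2 = sub1[iu], sub2[iu]
    keep = ~(np.isnan(v1) | np.isnan(v2))
    if keep.sum() < min_pairs:  # too few pairs for a meaningful rank correlation
        return np.nan, len(shared)
    rho, _ = spearmanr(v1[keep], v2[keep])
    return float(rho), len(shared)


ordered = ['cat', 'dog', 'ball', 'cup', 'shoe']


def toy_rdm(positions):
    # Hypothetical 1-D "embedding" per category; absent categories give NaN rows/columns
    x = np.array([positions.get(c, np.nan) for c in ordered], dtype=float)
    return np.abs(x[:, None] - x[None, :])


pos_a = {'cat': 0, 'dog': 1, 'ball': 5, 'cup': 3}
pos_b = {'cat': 0, 'dog': 1, 'cup': 3, 'shoe': 10}  # same geometry on the shared categories
pos_c = {'cat': 0, 'dog': 3, 'cup': 1}              # pair distances rank in the opposite order

rho_ab, n_ab = rdm_spearman_shared(toy_rdm(pos_a), toy_rdm(pos_b), ordered, pos_a, pos_b)
rho_ac, n_ac = rdm_spearman_shared(toy_rdm(pos_a), toy_rdm(pos_c), ordered, pos_a, pos_c)
print(n_ab, rho_ab, rho_ac)
```

Categories that only one child has (`ball`, `shoe`) drop out before correlating, which is why the number of common categories is worth reporting alongside each correlation.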

In [45]:
# Compute cross-kid correlations
print("Computing cross-kid correlations...")

# Get list of subjects with both younger and older RDMs
valid_subjects = [sid for sid in subject_age_rdms.keys() 
                  if 'younger' in subject_age_rdms[sid] and 'older' in subject_age_rdms[sid]]

print(f"Computing correlations for {len(valid_subjects)} subjects")
print(f"Total pairs: {len(valid_subjects) * (len(valid_subjects) - 1) // 2}")

# Store cross-kid correlation data
cross_kid_data = []

# Compute all pairwise correlations
for i, subject_id_1 in enumerate(tqdm(valid_subjects, desc="Cross-kid correlations")):
    for subject_id_2 in valid_subjects[i+1:]:  # Only compute upper triangle (avoid duplicates)
        
        # Get RDMs and categories for both subjects
        rdm1_younger = subject_age_rdms[subject_id_1]['younger']
        rdm1_older = subject_age_rdms[subject_id_1]['older']
        cats1_younger = subject_age_rdm_categories[subject_id_1]['younger']
        cats1_older = subject_age_rdm_categories[subject_id_1]['older']
        
        rdm2_younger = subject_age_rdms[subject_id_2]['younger']
        rdm2_older = subject_age_rdms[subject_id_2]['older']
        cats2_younger = subject_age_rdm_categories[subject_id_2]['younger']
        cats2_older = subject_age_rdm_categories[subject_id_2]['older']
        
        # 1. Younger-Younger correlation (subject 1 younger vs subject 2 younger)
        corr_yy, n_common_yy = compute_rdm_correlation(
            rdm1_younger, rdm2_younger,
            ordered_categories,
            cats1_younger, cats2_younger
        )
        
        # 2. Younger-Older correlation (subject 1 younger vs subject 2 older)
        corr_yo, n_common_yo = compute_rdm_correlation(
            rdm1_younger, rdm2_older,
            ordered_categories,
            cats1_younger, cats2_older
        )
        
        # 3. Older-Younger correlation (subject 1 older vs subject 2 younger)
        corr_oy, n_common_oy = compute_rdm_correlation(
            rdm1_older, rdm2_younger,
            ordered_categories,
            cats1_older, cats2_younger
        )
        
        # 4. Older-Older correlation (subject 1 older vs subject 2 older)
        corr_oo, n_common_oo = compute_rdm_correlation(
            rdm1_older, rdm2_older,
            ordered_categories,
            cats1_older, cats2_older
        )
        
        # Store results
        cross_kid_data.append({
            'subject_id_1': subject_id_1,
            'subject_id_2': subject_id_2,
            'correlation_type': 'younger_younger',
            'correlation': corr_yy,
            'n_common_categories': n_common_yy,
            'n_categories_subject1': len(cats1_younger),
            'n_categories_subject2': len(cats2_younger)
        })
        
        cross_kid_data.append({
            'subject_id_1': subject_id_1,
            'subject_id_2': subject_id_2,
            'correlation_type': 'younger_older',
            'correlation': corr_yo,
            'n_common_categories': n_common_yo,
            'n_categories_subject1': len(cats1_younger),
            'n_categories_subject2': len(cats2_older)
        })
        
        cross_kid_data.append({
            'subject_id_1': subject_id_1,
            'subject_id_2': subject_id_2,
            'correlation_type': 'older_younger',
            'correlation': corr_oy,
            'n_common_categories': n_common_oy,
            'n_categories_subject1': len(cats1_older),
            'n_categories_subject2': len(cats2_younger)
        })
        
        cross_kid_data.append({
            'subject_id_1': subject_id_1,
            'subject_id_2': subject_id_2,
            'correlation_type': 'older_older',
            'correlation': corr_oo,
            'n_common_categories': n_common_oo,
            'n_categories_subject1': len(cats1_older),
            'n_categories_subject2': len(cats2_older)
        })

# Create DataFrame
cross_kid_df = pd.DataFrame(cross_kid_data)
cross_kid_df.to_csv(output_dir / "cross_kid_correlations.csv", index=False)

# Print summary statistics
print(f"\nCross-kid correlation analysis:")
print(f"  Total subject pairs: {len(valid_subjects) * (len(valid_subjects) - 1) // 2}")
print(f"  Total correlation measurements: {len(cross_kid_df)}")

print(f"\nMean correlations by type:")
for corr_type in ['younger_younger', 'younger_older', 'older_younger', 'older_older']:
    type_data = cross_kid_df[cross_kid_df['correlation_type'] == corr_type]
    valid_corrs = type_data['correlation'].dropna()
    if len(valid_corrs) > 0:
        print(f"  {corr_type:20s}: {valid_corrs.mean():.3f} ± {valid_corrs.std():.3f} (n={len(valid_corrs)} valid, {len(type_data)} total)")
    else:
        print(f"  {corr_type:20s}: No valid correlations (n={len(type_data)} total)")

print(f"\nSaved cross-kid correlations to {output_dir / 'cross_kid_correlations.csv'}")

Computing cross-kid correlations...
Computing correlations for 18 subjects
Total pairs: 153


Cross-kid correlations: 100%|██████████| 18/18 [00:01<00:00, 12.83it/s]


Cross-kid correlation analysis:
  Total subject pairs: 153
  Total correlation measurements: 612

Mean correlations by type:
  younger_younger     : 0.675 ± 0.101 (n=153 valid, 153 total)
  younger_older       : 0.660 ± 0.096 (n=153 valid, 153 total)
  older_younger       : 0.672 ± 0.064 (n=153 valid, 153 total)
  older_older         : 0.659 ± 0.056 (n=153 valid, 153 total)

Saved cross-kid correlations to developmental_trajectory_rdms_clip/cross_kid_correlations.csv





In [46]:
# Visualize cross-kid correlations
print("Creating visualizations for cross-kid correlations...")

# Filter out NaN correlations for plotting
valid_cross_kid_df = cross_kid_df[cross_kid_df['correlation'].notna()].copy()

# Create figure with multiple subplots
fig = plt.figure(figsize=(16, 10))

# 1. Box plot comparing correlation types
ax1 = plt.subplot(2, 3, 1)
correlation_types = ['younger_younger', 'younger_older', 'older_younger', 'older_older']
box_data = [valid_cross_kid_df[valid_cross_kid_df['correlation_type'] == ct]['correlation'].values 
            for ct in correlation_types]
labels = ['Younger-Younger', 'Younger-Older', 'Older-Younger', 'Older-Older']

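# Set tick labels explicitly (boxplot's `labels=` argument is deprecated in newer matplotlib)
bp = ax1.boxplot(box_data, patch_artist=True)
ax1.set_xticks(range(1, len(labels) + 1))
ax1.set_xticklabels(labels)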
colors = ['#FF6B6B', '#4ECDC4', '#95E1D3', '#FFA07A']
for patch, color in zip(bp['boxes'], colors):
    patch.set_facecolor(color)
    patch.set_alpha(0.7)

ax1.set_ylabel('RDM Correlation (Spearman)', fontsize=12)
ax1.set_title('Distribution of Cross-Kid Correlations', fontsize=13, pad=10)
ax1.grid(True, alpha=0.3, axis='y')
ax1.set_ylim([0, 1])

# 2. Violin plot
ax2 = plt.subplot(2, 3, 2)
violin_data = [valid_cross_kid_df[valid_cross_kid_df['correlation_type'] == ct]['correlation'].values 
               for ct in correlation_types]
parts = ax2.violinplot(violin_data, positions=range(len(labels)), showmeans=True, showmedians=True)
for i, pc in enumerate(parts['bodies']):
    pc.set_facecolor(colors[i])
    pc.set_alpha(0.7)
ax2.set_xticks(range(len(labels)))
ax2.set_xticklabels(labels, rotation=45, ha='right')
ax2.set_ylabel('RDM Correlation (Spearman)', fontsize=12)
ax2.set_title('Distribution of Cross-Kid Correlations (Violin)', fontsize=13, pad=10)
ax2.grid(True, alpha=0.3, axis='y')
ax2.set_ylim([0, 1])

# 3. Histogram overlay
ax3 = plt.subplot(2, 3, 3)
for i, ct in enumerate(correlation_types):
    type_data = valid_cross_kid_df[valid_cross_kid_df['correlation_type'] == ct]['correlation'].values
    ax3.hist(type_data, bins=30, alpha=0.6, label=labels[i], color=colors[i], density=True)
ax3.set_xlabel('RDM Correlation', fontsize=12)
ax3.set_ylabel('Density', fontsize=12)
ax3.set_title('Distribution of Cross-Kid Correlations (Histogram)', fontsize=13, pad=10)
ax3.legend()
ax3.grid(True, alpha=0.3)
ax3.set_xlim([0, 1])

# 4. Scatter: correlation vs number of common categories
ax4 = plt.subplot(2, 3, 4)
for i, ct in enumerate(correlation_types):
    type_data = valid_cross_kid_df[valid_cross_kid_df['correlation_type'] == ct]
    ax4.scatter(type_data['n_common_categories'], type_data['correlation'], 
               label=labels[i], alpha=0.6, s=60, color=colors[i])
ax4.set_xlabel('Number of Common Categories', fontsize=12)
ax4.set_ylabel('RDM Correlation', fontsize=12)
ax4.set_title('Correlation vs Common Categories', fontsize=13, pad=10)
ax4.legend()
ax4.grid(True, alpha=0.3)
ax4.set_ylim([0, 1])

# 5. Comparison: Cross-kid vs Within-kid correlations (all types)
ax5 = plt.subplot(2, 3, 5)
# Get within-kid correlations from trajectory_df
within_kid_corrs = trajectory_df['rdm_correlation'].dropna().values

# Collect cross-kid correlations per type (correlation_types and labels are defined for panel 1).
# A new name is used so the cross_kid_data records built in the previous cell are not overwritten.
cross_kid_corrs_by_type = [valid_cross_kid_df[valid_cross_kid_df['correlation_type'] == ct]['correlation'].values 
                           for ct in correlation_types]

# Create grouped box plot: within-kid + all cross-kid types
all_data = [within_kid_corrs] + cross_kid_corrs_by_type
all_labels = ['Within-Kid\n(Younger-Older)'] + [f'Cross-Kid\n({label})' for label in labels]

# Set tick labels explicitly (boxplot's `labels=` argument is deprecated in newer matplotlib)
bp2 = ax5.boxplot(all_data, patch_artist=True)
ax5.set_xticks(range(1, len(all_labels) + 1))
ax5.set_xticklabels(all_labels)

# Color within-kid differently from cross-kid
bp2['boxes'][0].set_facecolor('#FFD93D')
bp2['boxes'][0].set_alpha(0.7)
for i in range(1, len(bp2['boxes'])):
    bp2['boxes'][i].set_facecolor('#FF6B6B')
    bp2['boxes'][i].set_alpha(0.7)

ax5.set_ylabel('RDM Correlation (Spearman)', fontsize=12)
ax5.set_title('Within-Kid vs Cross-Kid Correlations\n(All Types)', fontsize=13, pad=10)
ax5.grid(True, alpha=0.3, axis='y')
ax5.set_ylim([0, 1])
plt.setp(ax5.xaxis.get_majorticklabels(), rotation=45, ha='right', fontsize=9)

# 6. Heatmap of cross-kid correlations (younger-younger only, for clarity)
ax6 = plt.subplot(2, 3, 6)
yy_df = valid_cross_kid_df[valid_cross_kid_df['correlation_type'] == 'younger_younger'].copy()
# Pivot the upper-triangle pairs, then reindex to the full subject list so the matrix is square
# (the first subject never appears as subject_id_2 and the last never as subject_id_1)
yy_subjects = sorted(set(yy_df['subject_id_1']) | set(yy_df['subject_id_2']))
pivot_data = yy_df.pivot(index='subject_id_1', columns='subject_id_2', values='correlation')
pivot_data = pivot_data.reindex(index=yy_subjects, columns=yy_subjects)
# Make symmetric (fill lower triangle from the transpose); the diagonal stays NaN
pivot_data = pivot_data.fillna(pivot_data.T)
# Sort by mean correlation and apply the same order to rows and columns
subject_order = pivot_data.mean(axis=1).sort_values(ascending=False).index
pivot_data = pivot_data.loc[subject_order, subject_order]

im = ax6.imshow(pivot_data.values, aspect='auto', cmap='RdYlBu_r', vmin=0, vmax=1)
ax6.set_xticks(range(len(pivot_data.columns)))
ax6.set_xticklabels(pivot_data.columns, rotation=45, ha='right', fontsize=8)
ax6.set_yticks(range(len(pivot_data.index)))
ax6.set_yticklabels(pivot_data.index, fontsize=8)
ax6.set_xlabel('Subject ID', fontsize=11)
ax6.set_ylabel('Subject ID', fontsize=11)
ax6.set_title('Cross-Kid Correlations (Younger-Younger)\nHeatmap', fontsize=13, pad=10)
cbar = plt.colorbar(im, ax=ax6, label='RDM Correlation', fraction=0.046, pad=0.04)

plt.suptitle('Cross-Kid RDM Correlations Analysis', 
             fontsize=16, y=0.995, fontweight='bold')
plt.tight_layout(rect=[0, 0, 1, 0.99])
plt.savefig(output_dir / "cross_kid_correlations_visualization.png", dpi=200, bbox_inches='tight')
print(f"Saved cross-kid correlation visualization to {output_dir / 'cross_kid_correlations_visualization.png'}")
plt.close()

print("\nVisualization complete!")

Creating visualizations for cross-kid correlations...
Saved cross-kid correlation visualization to developmental_trajectory_rdms_clip/cross_kid_correlations_visualization.png

Visualization complete!


## Cross-Kid Category Group Correlations

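Compute cross-kid correlations separately for each semantic category group (animals, bodyparts, big_objects, small_objects, others). This shows whether similarity between children varies across semantic domains. Within each group, only categories present in both RDMs are used, and a group needs at least three shared categories for the Spearman correlation to be defined in practice. The `others` group has only 2 categories, i.e. a single distance pair, so it yields no valid correlations.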
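To make the per-group restriction concrete, here is a minimal sketch on toy data. It assumes, as in this notebook, that both RDMs are indexed by one global category order; `group_rdm_spearman`, the toy categories, the 1-D positions, and the `min_pairs=3` cutoff are all hypothetical. The sketch extracts the within-group submatrix with `np.ix_`, keeps the off-diagonal upper triangle, and shows why a two-category group cannot produce a correlation.

```python
import numpy as np
from scipy.stats import spearmanr


def group_rdm_spearman(rdm1, rdm2, ordered, group, cats1, cats2, min_pairs=3):
    """Hypothetical helper: Spearman correlation of two RDMs within one category group.

    Both RDMs are square arrays indexed by the global `ordered` category list.
    Only group members present for both children are used, and only the
    off-diagonal upper triangle is compared. Returns (rho, n_members).
    """
    present = set(cats1) & set(cats2)
    members = [c for c in group if c in present and c in ordered]
    idx = [ordered.index(c) for c in members]
    iu = np.triu_indices(len(members), k=1)
    v1 = np.asarray(rdm1, dtype=float)[np.ix_(idx, idx)][iu]
    v2 = np.asarray(rdm2, dtype=float)[np.ix_(idx, idx)][iu]
    keep = ~(np.isnan(v1) | np.isnan(v2))
    if keep.sum() < min_pairs:  # too few category pairs for a meaningful rank correlation
        return np.nan, len(members)
    return float(spearmanr(v1[keep], v2[keep])[0]), len(members)


# Toy data: hypothetical 1-D "embeddings" for two children over the same 8 categories
ordered = ['cat', 'dog', 'cow', 'hand', 'foot', 'nose', 'cup', 'ball']
groups = {
    'animals': ['cat', 'dog', 'cow'],
    'bodyparts': ['hand', 'foot', 'nose'],
    'others': ['cup', 'ball'],  # 2 categories -> a single distance pair
}
x1 = np.array([0, 1, 3, 10, 11, 14, 20, 25], dtype=float)
x2 = np.array([0, 2, 3, 10, 12, 13, 20, 21], dtype=float)
rdm1 = np.abs(x1[:, None] - x1[None, :])
rdm2 = np.abs(x2[:, None] - x2[None, :])

results = {name: group_rdm_spearman(rdm1, rdm2, ordered, cats, ordered, ordered)
           for name, cats in groups.items()}
for name, (rho, n) in results.items():
    print(f"{name:10s} n={n} rho={rho}")
```

Because each group's correlation rests on far fewer pairs than the full RDM (3 categories give only 3 pairs), group-level estimates are noisier, which is worth keeping in mind when comparing groups of very different sizes.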

In [47]:
# Compute cross-kid category group correlations
print("Computing cross-kid category group correlations...")

# Get category groups from organized structure
category_groups = {
    'animals': organized['animals'],
    'bodyparts': organized['bodyparts'],
    'big_objects': organized['big_objects'],
    'small_objects': organized['small_objects'],
    'others': organized['others']
}

print(f"Category group sizes: {[(name, len(cats)) for name, cats in category_groups.items()]}")

# Get list of subjects with both younger and older RDMs
valid_subjects = [sid for sid in subject_age_rdms.keys() 
                  if 'younger' in subject_age_rdms[sid] and 'older' in subject_age_rdms[sid]]

# Store cross-kid category group correlation data
cross_kid_category_data = []

# Compute all pairwise correlations by category group
for i, subject_id_1 in enumerate(tqdm(valid_subjects, desc="Cross-kid category correlations")):
    for subject_id_2 in valid_subjects[i+1:]:  # Only compute upper triangle (avoid duplicates)
        
        # Get RDMs and categories for both subjects
        rdm1_younger = subject_age_rdms[subject_id_1]['younger']
        rdm1_older = subject_age_rdms[subject_id_1]['older']
        cats1_younger = subject_age_rdm_categories[subject_id_1]['younger']
        cats1_older = subject_age_rdm_categories[subject_id_1]['older']
        
        rdm2_younger = subject_age_rdms[subject_id_2]['younger']
        rdm2_older = subject_age_rdms[subject_id_2]['older']
        cats2_younger = subject_age_rdm_categories[subject_id_2]['younger']
        cats2_older = subject_age_rdm_categories[subject_id_2]['older']
        
        # Compute correlation for each category group and each correlation type
        for group_name, group_categories in category_groups.items():
            for corr_type, (rdm1, rdm2, cats1, cats2) in [
                ('younger_younger', (rdm1_younger, rdm2_younger, cats1_younger, cats2_younger)),
                ('younger_older', (rdm1_younger, rdm2_older, cats1_younger, cats2_older)),
                ('older_younger', (rdm1_older, rdm2_younger, cats1_older, cats2_younger)),
                ('older_older', (rdm1_older, rdm2_older, cats1_older, cats2_older))
            ]:
                # Find common categories in this group that are present in both RDMs
                common_in_group = [cat for cat in group_categories 
                                  if cat in cats1 and cat in ordered_categories and cat in cats2]
                
                if len(common_in_group) < 2:
                    # Not enough categories in this group for correlation
                    cross_kid_category_data.append({
                        'subject_id_1': subject_id_1,
                        'subject_id_2': subject_id_2,
                        'correlation_type': corr_type,
                        'category_group': group_name,
                        'correlation': np.nan,
                        'n_common_categories': len(common_in_group),
                        'n_categories_subject1': len([c for c in group_categories if c in cats1]),
                        'n_categories_subject2': len([c for c in group_categories if c in cats2])
                    })
                    continue
                
                # Get indices for common categories in this group
                common_indices = [ordered_categories.index(cat) for cat in common_in_group]
                
                # Extract submatrices for this group
                rdm1_group = rdm1[np.ix_(common_indices, common_indices)]
                rdm2_group = rdm2[np.ix_(common_indices, common_indices)]
                
                # Get upper triangle (excluding diagonal)
                mask = np.triu(np.ones_like(rdm1_group, dtype=bool), k=1)
                rdm1_flat = rdm1_group[mask]
                rdm2_flat = rdm2_group[mask]
                
                # Filter out NaN values
                valid_mask = ~(np.isnan(rdm1_flat) | np.isnan(rdm2_flat))
                rdm1_valid = rdm1_flat[valid_mask]
                rdm2_valid = rdm2_flat[valid_mask]
                
                # Compute Spearman correlation
                if len(rdm1_valid) > 0:
                    corr, _ = spearmanr(rdm1_valid, rdm2_valid)
                else:
                    corr = np.nan
                
                cross_kid_category_data.append({
                    'subject_id_1': subject_id_1,
                    'subject_id_2': subject_id_2,
                    'correlation_type': corr_type,
                    'category_group': group_name,
                    'correlation': corr,
                    'n_common_categories': len(common_in_group),
                    'n_categories_subject1': len([c for c in group_categories if c in cats1]),
                    'n_categories_subject2': len([c for c in group_categories if c in cats2])
                })

# Create DataFrame
cross_kid_category_df = pd.DataFrame(cross_kid_category_data)
cross_kid_category_df.to_csv(output_dir / "cross_kid_category_group_correlations.csv", index=False)

# Print summary statistics
print(f"\nCross-kid category group correlation analysis:")
print(f"  Total measurements: {len(cross_kid_category_df)}")

print(f"\nMean correlations by category group and correlation type:")
for group_name in category_groups.keys():
    group_data = cross_kid_category_df[cross_kid_category_df['category_group'] == group_name]
    valid_corrs = group_data['correlation'].dropna()
    if len(valid_corrs) > 0:
        print(f"\n  {group_name}:")
        for corr_type in ['younger_younger', 'younger_older', 'older_younger', 'older_older']:
            type_data = group_data[group_data['correlation_type'] == corr_type]['correlation'].dropna()
            if len(type_data) > 0:
                print(f"    {corr_type:20s}: {type_data.mean():.3f} ± {type_data.std():.3f} (n={len(type_data)})")
    else:
        print(f"  {group_name}: No valid correlations")

print(f"\nSaved cross-kid category group correlations to {output_dir / 'cross_kid_category_group_correlations.csv'}")

Computing cross-kid category group correlations...
Category group sizes: [('animals', 19), ('bodyparts', 14), ('big_objects', 32), ('small_objects', 96), ('others', 2)]


Cross-kid category correlations: 100%|██████████| 18/18 [00:01<00:00, 13.81it/s]


Cross-kid category group correlation analysis:
  Total measurements: 3060

Mean correlations by category group and correlation type:

  animals:
    younger_younger     : 0.458 ± 0.149 (n=153)
    younger_older       : 0.420 ± 0.156 (n=153)
    older_younger       : 0.417 ± 0.121 (n=153)
    older_older         : 0.403 ± 0.130 (n=153)

  bodyparts:
    younger_younger     : 0.782 ± 0.099 (n=153)
    younger_older       : 0.750 ± 0.107 (n=153)
    older_younger       : 0.743 ± 0.107 (n=153)
    older_older         : 0.715 ± 0.113 (n=153)

  big_objects:
    younger_younger     : 0.631 ± 0.100 (n=153)
    younger_older       : 0.602 ± 0.099 (n=153)
    older_younger       : 0.607 ± 0.089 (n=153)
    older_older         : 0.579 ± 0.094 (n=153)

  small_objects:
    younger_younger     : 0.689 ± 0.110 (n=153)
    younger_older       : 0.681 ± 0.109 (n=153)
    older_younger       : 0.696 ± 0.069 (n=153)
    older_older         : 0.690 ± 0.061 (n=153)
  others: No valid correlations

Saved




## Null Model Comparison

To check that the observed cross-kid correlations reflect shared category structure rather than chance, we compare them to a permutation null model. For each pair of RDMs, the null model applies one random permutation to both the rows and the columns of one RDM, which scrambles the category labels while preserving the distribution of distance values. If real correlations are reliably higher than null correlations, this indicates that children's object representations share meaningful structure. To limit computation, the null model is run on the first 50 subject pairs with 100 permutations each.
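A minimal, self-contained sketch of this label-permutation idea follows. It assumes NaN-free RDMs over the same categories in the same order; `upper` and `rdm_permutation_test` are hypothetical names, and it uses NumPy's `default_rng` with 1000 permutations rather than the global seed and 100 permutations of the notebook cell. The p-value uses the common add-one correction, (k + 1) / (n + 1), so that a finite number of permutations never reports p = 0.

```python
import numpy as np
from scipy.stats import spearmanr


def upper(rdm):
    """Off-diagonal upper triangle of a square matrix as a 1-D vector."""
    return rdm[np.triu_indices(rdm.shape[0], k=1)]


def rdm_permutation_test(rdm1, rdm2, n_permutations=1000, seed=0):
    """Hypothetical sketch: label-permutation null for the Spearman RDM correlation.

    Assumes NaN-free square RDMs over the same categories in the same order.
    Each permutation applies one random reordering to both the rows and the
    columns of rdm2, so it stays a valid distance matrix with the same values
    but with the category labels shuffled.
    """
    rng = np.random.default_rng(seed)
    v1 = upper(rdm1)
    real = spearmanr(v1, upper(rdm2))[0]
    n = rdm2.shape[0]
    null = np.empty(n_permutations)
    for k in range(n_permutations):
        p = rng.permutation(n)
        null[k] = spearmanr(v1, upper(rdm2[np.ix_(p, p)]))[0]
    # One-sided p-value with the +1 correction, so it is never exactly 0
    p_value = (np.sum(null >= real) + 1) / (n_permutations + 1)
    z_score = (real - null.mean()) / null.std()
    return real, null, p_value, z_score


# Usage on toy data: child B's RDM is child A's plus small symmetric noise
rng = np.random.default_rng(1)
points = rng.normal(size=(12, 2))
rdm_a = np.linalg.norm(points[:, None, :] - points[None, :, :], axis=-1)
noise = rng.normal(scale=0.05, size=rdm_a.shape)
rdm_b = np.abs(rdm_a + (noise + noise.T) / 2)
np.fill_diagonal(rdm_b, 0.0)

real, null, p_value, z_score = rdm_permutation_test(rdm_a, rdm_b)
print(f"real={real:.3f}  null mean={null.mean():.3f}  p={p_value:.4f}  z={z_score:.1f}")
```

Permuting rows and columns together (rather than shuffling the flattened distances) keeps the dependencies among distances that share a category, which is what makes this a Mantel-style test rather than a naive shuffle.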

In [48]:
# Null model: Permute RDM rows/columns to destroy structure
print("Computing null model correlations...")
print("This may take a few minutes...")

n_permutations = 100  # Number of permutations for null distribution
np.random.seed(42)  # For reproducibility

# Get list of subjects with both younger and older RDMs
valid_subjects = [sid for sid in subject_age_rdms.keys() 
                  if 'younger' in subject_age_rdms[sid] and 'older' in subject_age_rdms[sid]]

# Store null model data
null_model_data = []

# Use a subset of subject pairs for the null model (to save computation time):
# the first n_pairs_to_sample pairs in enumeration order, not a random sample
n_pairs_to_sample = min(50, len(valid_subjects) * (len(valid_subjects) - 1) // 2)
sampled_pairs = []

for i, subject_id_1 in enumerate(valid_subjects):
    for subject_id_2 in valid_subjects[i+1:]:
        sampled_pairs.append((subject_id_1, subject_id_2))
        if len(sampled_pairs) >= n_pairs_to_sample:
            break
    if len(sampled_pairs) >= n_pairs_to_sample:
        break

print(f"Computing null model for {len(sampled_pairs)} subject pairs with {n_permutations} permutations each...")

for subject_id_1, subject_id_2 in tqdm(sampled_pairs, desc="Null model"):
    # Get RDMs and categories for both subjects
    rdm1_younger = subject_age_rdms[subject_id_1]['younger']
    rdm1_older = subject_age_rdms[subject_id_1]['older']
    cats1_younger = subject_age_rdm_categories[subject_id_1]['younger']
    cats1_older = subject_age_rdm_categories[subject_id_1]['older']
    
    rdm2_younger = subject_age_rdms[subject_id_2]['younger']
    rdm2_older = subject_age_rdms[subject_id_2]['older']
    cats2_younger = subject_age_rdm_categories[subject_id_2]['younger']
    cats2_older = subject_age_rdm_categories[subject_id_2]['older']
    
    # For each correlation type, compute null distribution
    for corr_type, (rdm1, rdm2, cats1, cats2) in [
        ('younger_younger', (rdm1_younger, rdm2_younger, cats1_younger, cats2_younger)),
        ('younger_older', (rdm1_younger, rdm2_older, cats1_younger, cats2_older)),
        ('older_younger', (rdm1_older, rdm2_younger, cats1_older, cats2_younger)),
        ('older_older', (rdm1_older, rdm2_older, cats1_older, cats2_older))
    ]:
        # Find common categories
        common_categories = [cat for cat in ordered_categories 
                           if cat in cats1 and cat in cats2]
        
        if len(common_categories) < 2:
            continue
        
        # Get indices for common categories
        common_indices = [ordered_categories.index(cat) for cat in common_categories]
        
        # Extract submatrices for common categories
        rdm1_subset = rdm1[np.ix_(common_indices, common_indices)].copy()
        rdm2_subset = rdm2[np.ix_(common_indices, common_indices)].copy()
        
        # Get upper triangle for real correlation
        mask = np.triu(np.ones_like(rdm1_subset, dtype=bool), k=1)
        rdm1_flat = rdm1_subset[mask]
        rdm2_flat = rdm2_subset[mask]
        valid_mask = ~(np.isnan(rdm1_flat) | np.isnan(rdm2_flat))
        rdm1_valid = rdm1_flat[valid_mask]
        rdm2_valid = rdm2_flat[valid_mask]
        
        if len(rdm1_valid) < 2:
            continue
        
        # Compute real correlation
        real_corr, _ = spearmanr(rdm1_valid, rdm2_valid)
        
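        # Compute null correlations by permuting rows and columns of rdm2 together
        # (Mantel-style: relabels categories while keeping each dissimilarity tied to its pair)
        null_corrs = []
        n_common = len(common_indices)
        
        for perm in range(n_permutations):
            # Random relabelling of the common categories
            perm_indices = np.random.permutation(n_common)
            
            # Permute both rows and columns of rdm2_subset
            rdm2_permuted = rdm2_subset[np.ix_(perm_indices, perm_indices)]
            
            # Upper triangle of the permuted RDM; NaN cells move with the
            # permutation, so the validity mask must be recomputed each time
            rdm2_perm_flat = rdm2_permuted[mask]
            perm_valid = ~(np.isnan(rdm1_flat) | np.isnan(rdm2_perm_flat))
            
            # Compute correlation on the cells valid in both RDMs
            if perm_valid.sum() >= 2:
                null_corr, _ = spearmanr(rdm1_flat[perm_valid], rdm2_perm_flat[perm_valid])
                null_corrs.append(null_corr)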
        
        # Store results
        if len(null_corrs) > 0:
            null_model_data.append({
                'subject_id_1': subject_id_1,
                'subject_id_2': subject_id_2,
                'correlation_type': corr_type,
                'real_correlation': real_corr,
                'null_mean': np.mean(null_corrs),
                'null_std': np.std(null_corrs),
                'null_min': np.min(null_corrs),
                'null_max': np.max(null_corrs),
                'null_median': np.median(null_corrs),
                'n_permutations': len(null_corrs),
                'n_common_categories': len(common_categories),
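                'z_score': (real_corr - np.mean(null_corrs)) / (np.std(null_corrs) + 1e-10),  # small epsilon avoids division by zero
                # One-sided add-one permutation p-value: the observed labelling counts as one
                # permutation, so the smallest attainable value is 1 / (n_permutations + 1)
                'p_value_approx': (np.sum(np.array(null_corrs) >= real_corr) + 1) / (len(null_corrs) + 1)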
            })

# Create DataFrame
null_model_df = pd.DataFrame(null_model_data)
null_model_df.to_csv(output_dir / "null_model_correlations.csv", index=False)

# Print summary statistics
print(f"\nNull model analysis:")
print(f"  Total comparisons: {len(null_model_df)}")
print(f"  Permutations per comparison: {n_permutations}")

print(f"\nReal vs Null correlations:")
print(f"  Real correlation mean: {null_model_df['real_correlation'].mean():.3f} ± {null_model_df['real_correlation'].std():.3f}")
print(f"  Null correlation mean: {null_model_df['null_mean'].mean():.3f} ± {null_model_df['null_std'].mean():.3f}")
print(f"  Difference: {(null_model_df['real_correlation'].mean() - null_model_df['null_mean'].mean()):.3f}")

print(f"\nZ-scores (how many std devs above null):")
print(f"  Mean z-score: {null_model_df['z_score'].mean():.3f} ± {null_model_df['z_score'].std():.3f}")
print(f"  Min z-score: {null_model_df['z_score'].min():.3f}")
print(f"  Max z-score: {null_model_df['z_score'].max():.3f}")
print(f"  Proportion with z > 2: {(null_model_df['z_score'] > 2).mean():.1%}")
print(f"  Proportion with z > 3: {(null_model_df['z_score'] > 3).mean():.1%}")

print(f"\nBy correlation type:")
for corr_type in ['younger_younger', 'younger_older', 'older_younger', 'older_older']:
    type_data = null_model_df[null_model_df['correlation_type'] == corr_type]
    if len(type_data) > 0:
        print(f"  {corr_type:20s}: Real={type_data['real_correlation'].mean():.3f}, Null={type_data['null_mean'].mean():.3f}, Z={type_data['z_score'].mean():.3f}")

print(f"\nSaved null model results to {output_dir / 'null_model_correlations.csv'}")

Computing null model correlations...
This may take a few minutes...
Computing null model for 50 subject pairs with 100 permutations each...


Null model: 100%|██████████| 50/50 [00:36<00:00,  1.37it/s]


Null model analysis:
  Total comparisons: 200
  Permutations per comparison: 100

Real vs Null correlations:
  Real correlation mean: 0.641 ± 0.115
  Null correlation mean: 0.000 ± 0.011
  Difference: 0.641

Z-scores (how many std devs above null):
  Mean z-score: 63.915 ± 20.278
  Min z-score: 18.150
  Max z-score: 102.882
  Proportion with z > 2: 100.0%
  Proportion with z > 3: 100.0%

By correlation type:
  younger_younger     : Real=0.630, Null=-0.000, Z=61.445
  younger_older       : Real=0.617, Null=0.000, Z=59.793
  older_younger       : Real=0.665, Null=0.000, Z=69.146
  older_older         : Real=0.654, Null=-0.000, Z=65.277

Saved null model results to developmental_trajectory_rdms_clip/null_model_correlations.csv
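The null model above is a Mantel-style test: rows and columns of one RDM are permuted together, which relabels the categories while keeping every dissimilarity attached to the same pair of items. A minimal standalone sketch of the same idea follows, assuming two square RDMs that already share one category order; `upper_triangle` and `rdm_permutation_test` are hypothetical helper names, not functions defined elsewhere in this notebook. It rebuilds the NaN mask for every permuted view and reports the one-sided add-one permutation p-value.

```python
import numpy as np
from scipy.stats import spearmanr


def upper_triangle(rdm):
    """Strictly-upper-triangular entries of a square RDM as a 1-D array."""
    rows, cols = np.triu_indices(rdm.shape[0], k=1)
    return rdm[rows, cols]


def rdm_permutation_test(rdm1, rdm2, n_permutations=1000, seed=0):
    """Hypothetical helper: Spearman RDM correlation with a Mantel-style null.

    Assumes rdm1 and rdm2 are square and use the same category order.
    Returns (observed_rho, null_rhos, add_one_p_value).
    """
    rng = np.random.default_rng(seed)  # seeded for reproducibility
    v1 = upper_triangle(rdm1)

    def corr_with(rdm2_view):
        v2 = upper_triangle(rdm2_view)
        # NaN cells move when rdm2 is permuted, so rebuild the mask each time
        valid = ~(np.isnan(v1) | np.isnan(v2))
        if valid.sum() < 3:
            return np.nan
        rho, _ = spearmanr(v1[valid], v2[valid])
        return rho

    observed = corr_with(rdm2)
    n = rdm2.shape[0]
    null = []
    for _ in range(n_permutations):
        perm = rng.permutation(n)
        # Relabel categories: rows and columns are permuted together
        null.append(corr_with(rdm2[np.ix_(perm, perm)]))
    null = np.asarray(null, dtype=float)
    null = null[~np.isnan(null)]
    # One-sided add-one estimator; the smallest attainable value is 1 / (len(null) + 1)
    p_value = (np.sum(null >= observed) + 1) / (len(null) + 1)
    return observed, null, p_value
```

For example, `rdm_permutation_test(rdm1_subset, rdm2_subset, n_permutations=100)` would run the same kind of test on one pair of sub-RDMs. Permuting whole rows and columns, rather than shuffling the flattened upper triangle, respects the fact that cells sharing a category are not independent. With 100 permutations the resolution of any permutation p-value is limited to about 0.01 (at best 1/101), so very large z-scores indicate that the observed correlation lies far outside the permutation distribution rather than giving a precise tail probability.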





In [49]:
# Visualize null model comparison
print("Creating null model visualization...")

# Create figure
fig = plt.figure(figsize=(16, 10))

# 1. Histogram: Real vs Null distributions
ax1 = plt.subplot(2, 3, 1)
ax1.hist(null_model_df['real_correlation'], bins=30, alpha=0.7, label='Real', color='#FF6B6B', density=True)
ax1.hist(null_model_df['null_mean'], bins=30, alpha=0.7, label='Null (mean)', color='#95A5A6', density=True)
ax1.axvline(null_model_df['real_correlation'].mean(), color='#FF6B6B', linestyle='--', linewidth=2, label=f"Real mean: {null_model_df['real_correlation'].mean():.3f}")
ax1.axvline(null_model_df['null_mean'].mean(), color='#95A5A6', linestyle='--', linewidth=2, label=f"Null mean: {null_model_df['null_mean'].mean():.3f}")
ax1.set_xlabel('RDM Correlation', fontsize=12)
ax1.set_ylabel('Density', fontsize=12)
ax1.set_title('Real vs Null Correlation Distributions', fontsize=13, pad=10)
ax1.legend()
ax1.grid(True, alpha=0.3)

# 2. Scatter: Real vs Null mean
ax2 = plt.subplot(2, 3, 2)
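ax2.scatter(null_model_df['null_mean'], null_model_df['real_correlation'], alpha=0.6, s=60)
# Shared limits on both axes so the y=x reference is visible and no point is clipped;
# points above the line have a real correlation higher than their null mean
ax2_min = min(null_model_df['null_mean'].min(), null_model_df['real_correlation'].min()) - 0.05
ax2_max = max(null_model_df['null_mean'].max(), null_model_df['real_correlation'].max()) + 0.05
ax2.plot([ax2_min, ax2_max], [ax2_min, ax2_max], 'r--', alpha=0.5, label='y=x')
ax2.set_xlabel('Null Correlation (mean)', fontsize=12)
ax2.set_ylabel('Real Correlation', fontsize=12)
ax2.set_title('Real vs Null Correlations', fontsize=13, pad=10)
ax2.legend()
ax2.grid(True, alpha=0.3)
ax2.set_xlim([ax2_min, ax2_max])
ax2.set_ylim([ax2_min, ax2_max])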

# 3. Z-score distribution
ax3 = plt.subplot(2, 3, 3)
ax3.hist(null_model_df['z_score'], bins=30, alpha=0.7, color='#4ECDC4', edgecolor='black')
ax3.axvline(0, color='red', linestyle='--', linewidth=2, label='Null (z=0)')
ax3.axvline(2, color='orange', linestyle='--', linewidth=1, label='z=2')
ax3.axvline(3, color='darkorange', linestyle='--', linewidth=1, label='z=3')
ax3.set_xlabel('Z-Score', fontsize=12)
ax3.set_ylabel('Frequency', fontsize=12)
ax3.set_title('Z-Score Distribution\n(Real correlations vs Null)', fontsize=13, pad=10)
ax3.legend()
ax3.grid(True, alpha=0.3)

# 4. Box plot: Real vs Null by correlation type
ax4 = plt.subplot(2, 3, 4)
correlation_types = ['younger_younger', 'younger_older', 'older_younger', 'older_older']
labels = ['Younger-Younger', 'Younger-Older', 'Older-Younger', 'Older-Older']

real_data = [null_model_df[null_model_df['correlation_type'] == ct]['real_correlation'].values for ct in correlation_types]
null_data = [null_model_df[null_model_df['correlation_type'] == ct]['null_mean'].values for ct in correlation_types]

x_pos = np.arange(len(correlation_types))
width = 0.35

bp1 = ax4.boxplot(real_data, positions=x_pos - width/2, widths=width, patch_artist=True)
bp2 = ax4.boxplot(null_data, positions=x_pos + width/2, widths=width, patch_artist=True)

for patch in bp1['boxes']:
    patch.set_facecolor('#FF6B6B')
    patch.set_alpha(0.7)
for patch in bp2['boxes']:
    patch.set_facecolor('#95A5A6')
    patch.set_alpha(0.7)

ax4.set_ylabel('RDM Correlation', fontsize=12)
ax4.set_title('Real vs Null by Correlation Type', fontsize=13, pad=10)
ax4.legend([bp1['boxes'][0], bp2['boxes'][0]], ['Real', 'Null'], loc='upper left')
ax4.grid(True, alpha=0.3, axis='y')
ax4.set_xticks(x_pos)
ax4.set_xticklabels(labels, rotation=45, ha='right')
ax4.set_ylim([-0.1, 1.0])

# 5. Difference (Real - Null) distribution
ax5 = plt.subplot(2, 3, 5)
diff = null_model_df['real_correlation'] - null_model_df['null_mean']
ax5.hist(diff, bins=30, alpha=0.7, color='#FFA07A', edgecolor='black')
ax5.axvline(0, color='red', linestyle='--', linewidth=2, label='No difference')
ax5.axvline(diff.mean(), color='blue', linestyle='--', linewidth=2, label=f"Mean: {diff.mean():.3f}")
ax5.set_xlabel('Real - Null Correlation', fontsize=12)
ax5.set_ylabel('Frequency', fontsize=12)
ax5.set_title('Difference Distribution\n(Real - Null)', fontsize=13, pad=10)
ax5.legend()
ax5.grid(True, alpha=0.3)

# 6. Z-score by correlation type
ax6 = plt.subplot(2, 3, 6)
z_data = [null_model_df[null_model_df['correlation_type'] == ct]['z_score'].values for ct in correlation_types]
bp3 = ax6.boxplot(z_data, patch_artist=True)  # tick labels are applied below with set_xticklabels
for patch in bp3['boxes']:
    patch.set_facecolor('#4ECDC4')
    patch.set_alpha(0.7)
ax6.axhline(0, color='red', linestyle='--', linewidth=1, label='Null (z=0)')
ax6.axhline(2, color='orange', linestyle='--', linewidth=1, label='z=2')
ax6.axhline(3, color='darkorange', linestyle='--', linewidth=1, label='z=3')
ax6.set_ylabel('Z-Score', fontsize=12)
ax6.set_title('Z-Score by Correlation Type', fontsize=13, pad=10)
ax6.legend(loc='upper left')
ax6.grid(True, alpha=0.3, axis='y')
ax6.set_xticklabels(labels, rotation=45, ha='right')

plt.suptitle('Null Model Comparison: Real vs Permuted RDMs', 
             fontsize=16, y=0.995, fontweight='bold')
plt.tight_layout(rect=[0, 0, 1, 0.99])
plt.savefig(output_dir / "null_model_comparison.png", dpi=200, bbox_inches='tight')
print(f"Saved null model visualization to {output_dir / 'null_model_comparison.png'}")
plt.close()

print("\nVisualization complete!")

Creating null model visualization...
Saved null model visualization to developmental_trajectory_rdms_clip/null_model_comparison.png

Visualization complete!
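The per-type averages printed by the null-model cell, and drawn as boxplots in the figure, can also be collected in a single table. A minimal sketch, assuming a DataFrame with the `correlation_type`, `real_correlation`, `null_mean` and `z_score` columns written to `null_model_correlations.csv`; `summarize_by_type` is a hypothetical helper name.

```python
import pandas as pd

# Fixed display order, matching the plots in this notebook
CORRELATION_TYPES = ['younger_younger', 'younger_older', 'older_younger', 'older_older']


def summarize_by_type(df):
    """Hypothetical helper: one summary row per correlation type.

    Assumes the columns written to null_model_correlations.csv.
    """
    return (
        df.groupby('correlation_type')
          .agg(n=('real_correlation', 'size'),
               real_mean=('real_correlation', 'mean'),
               null_mean=('null_mean', 'mean'),
               z_mean=('z_score', 'mean'),
               prop_z_gt_2=('z_score', lambda z: (z > 2).mean()))
          .reindex(CORRELATION_TYPES)  # absent types become NaN rows
    )


# Usage on the notebook's table (assumes null_model_df exists):
# summarize_by_type(null_model_df).round(3)
```

The `reindex` keeps the row order identical to the plots and turns a correlation type with no comparisons into a row of NaN instead of dropping it silently.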


In [50]:
# Visualize cross-kid category group correlations
print("Creating visualizations for cross-kid category group correlations...")

# Filter out NaN correlations for plotting
valid_category_df = cross_kid_category_df[cross_kid_category_df['correlation'].notna()].copy()

# Create figure
fig = plt.figure(figsize=(18, 12))

# 1. Box plot comparing category groups (all correlation types combined)
ax1 = plt.subplot(2, 3, 1)
category_order = ['animals', 'bodyparts', 'big_objects', 'small_objects', 'others']

labels_filtered = []
box_data_filtered = []
for group in category_order:
    group_data = valid_category_df[valid_category_df['category_group'] == group]['correlation'].values
    if len(group_data) > 0:
        box_data_filtered.append(group_data)
        labels_filtered.append(group.replace('_', ' ').title())

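# Tick labels are set separately: boxplot's `labels` argument is deprecated in newer Matplotlib releases
bp = ax1.boxplot(box_data_filtered, patch_artist=True)
ax1.set_xticks(range(1, len(labels_filtered) + 1))
ax1.set_xticklabels(labels_filtered)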
colors = ['#FF6B6B', '#4ECDC4', '#45B7D1', '#FFA07A', '#98D8C8']
for patch, color in zip(bp['boxes'], colors[:len(bp['boxes'])]):
    patch.set_facecolor(color)
    patch.set_alpha(0.7)

ax1.set_ylabel('RDM Correlation (Spearman)', fontsize=12)
ax1.set_title('Cross-Kid Correlations by Category Group\n(All correlation types)', fontsize=13, pad=10)
ax1.grid(True, alpha=0.3, axis='y')
ax1.set_ylim([0, 1])
plt.setp(ax1.xaxis.get_majorticklabels(), rotation=45, ha='right')

# 2. Heatmap: Mean correlations by category group and correlation type
ax2 = plt.subplot(2, 3, 2)
pivot_data = valid_category_df.groupby(['category_group', 'correlation_type'])['correlation'].mean().unstack(fill_value=np.nan)
pivot_data = pivot_data.reindex(category_order)
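# reindex (rather than column selection) keeps a fixed column order and fills a missing
# correlation type with NaN instead of raising a KeyError
pivot_data = pivot_data.reindex(columns=['younger_younger', 'younger_older', 'older_younger', 'older_older'])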

im = ax2.imshow(pivot_data.values, aspect='auto', cmap='RdYlBu_r', vmin=0, vmax=1)
ax2.set_xticks(range(len(pivot_data.columns)))
ax2.set_xticklabels([col.replace('_', '-\n').title() for col in pivot_data.columns], fontsize=9)
ax2.set_yticks(range(len(pivot_data.index)))
ax2.set_yticklabels([idx.replace('_', ' ').title() for idx in pivot_data.index], fontsize=10)
ax2.set_xlabel('Correlation Type', fontsize=11)
ax2.set_ylabel('Category Group', fontsize=11)
ax2.set_title('Mean Cross-Kid Correlations\nby Category Group & Type', fontsize=13, pad=10)
cbar = plt.colorbar(im, ax=ax2, label='Mean Correlation', fraction=0.046, pad=0.04)

# Add text annotations
for i in range(len(pivot_data.index)):
    for j in range(len(pivot_data.columns)):
        val = pivot_data.iloc[i, j]
        if not np.isnan(val):
            ax2.text(j, i, f'{val:.2f}', ha='center', va='center', 
                   fontsize=8, color='white' if val < 0.5 else 'black', fontweight='bold')

# 3. Violin plot by category group
ax3 = plt.subplot(2, 3, 3)
violin_data = [valid_category_df[valid_category_df['category_group'] == group]['correlation'].values 
               for group in category_order if group in valid_category_df['category_group'].values]
parts = ax3.violinplot(violin_data, positions=range(len(labels_filtered)), showmeans=True, showmedians=True)
for i, pc in enumerate(parts['bodies']):
    pc.set_facecolor(colors[i % len(colors)])
    pc.set_alpha(0.7)
ax3.set_xticks(range(len(labels_filtered)))
ax3.set_xticklabels(labels_filtered, rotation=45, ha='right')
ax3.set_ylabel('RDM Correlation (Spearman)', fontsize=12)
ax3.set_title('Distribution by Category Group\n(Violin Plot)', fontsize=13, pad=10)
ax3.grid(True, alpha=0.3, axis='y')
ax3.set_ylim([0, 1])

# 4. Comparison: Cross-kid category groups vs Within-kid category groups
ax4 = plt.subplot(2, 3, 4)
# Load within-kid category correlations if available
try:
    within_category_df = pd.read_csv(output_dir / "category_group_correlations.csv")
    within_valid = within_category_df[within_category_df['correlation'].notna()]
    
    comparison_data = []
    comparison_labels = []
    for group in category_order:
        cross_data = valid_category_df[valid_category_df['category_group'] == group]['correlation'].values
        within_data = within_valid[within_valid['category_group'] == group]['correlation'].values
        if len(cross_data) > 0 and len(within_data) > 0:
            comparison_data.append(cross_data)
            comparison_data.append(within_data)
            comparison_labels.append(f"{group.replace('_', ' ').title()}\n(Cross)")
            comparison_labels.append(f"{group.replace('_', ' ').title()}\n(Within)")
    
    if len(comparison_data) > 0:
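        # Tick labels are set separately: boxplot's `labels` argument is deprecated in newer Matplotlib releases
        bp2 = ax4.boxplot(comparison_data, patch_artist=True)
        ax4.set_xticks(range(1, len(comparison_labels) + 1))
        ax4.set_xticklabels(comparison_labels)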
        # Color cross-kid in one color, within-kid in another
        for i, patch in enumerate(bp2['boxes']):
            if i % 2 == 0:  # Cross-kid
                patch.set_facecolor('#FF6B6B')
            else:  # Within-kid
                patch.set_facecolor('#4ECDC4')
            patch.set_alpha(0.7)
        ax4.set_ylabel('RDM Correlation', fontsize=12)
        ax4.set_title('Cross-Kid vs Within-Kid\nby Category Group', fontsize=13, pad=10)
        ax4.legend([bp2['boxes'][0], bp2['boxes'][1]], ['Cross-Kid', 'Within-Kid'], loc='upper left')
        ax4.grid(True, alpha=0.3, axis='y')
        ax4.set_ylim([0, 1])
        plt.setp(ax4.xaxis.get_majorticklabels(), rotation=45, ha='right', fontsize=8)
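except (FileNotFoundError, KeyError, pd.errors.EmptyDataError):
    # Missing or malformed within-kid results file: show a placeholder panel instead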
    ax4.text(0.5, 0.5, 'Within-kid category\ndata not available', 
            ha='center', va='center', transform=ax4.transAxes, fontsize=12)
    ax4.set_title('Cross-Kid vs Within-Kid\n(Data not available)', fontsize=13, pad=10)

# 5. Scatter: Correlation vs number of common categories by group
ax5 = plt.subplot(2, 3, 5)
for i, group in enumerate(category_order):
    group_data = valid_category_df[valid_category_df['category_group'] == group]
    if len(group_data) > 0:
        ax5.scatter(group_data['n_common_categories'], group_data['correlation'], 
                   label=group.replace('_', ' ').title(), alpha=0.6, s=40, color=colors[i % len(colors)])
ax5.set_xlabel('Number of Common Categories', fontsize=12)
ax5.set_ylabel('RDM Correlation', fontsize=12)
ax5.set_title('Correlation vs Common Categories\nby Category Group', fontsize=13, pad=10)
ax5.legend(loc='best', fontsize=8)
ax5.grid(True, alpha=0.3)
ax5.set_ylim([0, 1])

# 6. Bar plot: Mean correlations by category group
ax6 = plt.subplot(2, 3, 6)
mean_corrs = []
std_corrs = []
group_labels = []
for group in category_order:
    group_data = valid_category_df[valid_category_df['category_group'] == group]['correlation']
    if len(group_data) > 0:
        mean_corrs.append(group_data.mean())
        std_corrs.append(group_data.std())
        group_labels.append(group.replace('_', ' ').title())

bars = ax6.bar(range(len(group_labels)), mean_corrs, yerr=std_corrs, 
              color=colors[:len(group_labels)], alpha=0.7, edgecolor='black')
ax6.set_xticks(range(len(group_labels)))
ax6.set_xticklabels(group_labels, rotation=45, ha='right')
ax6.set_ylabel('Mean RDM Correlation', fontsize=12)
ax6.set_title('Mean Cross-Kid Correlations\nby Category Group', fontsize=13, pad=10)
ax6.grid(True, alpha=0.3, axis='y')
ax6.set_ylim([0, 1])

# Add value labels on bars
for i, (bar, mean, std) in enumerate(zip(bars, mean_corrs, std_corrs)):
    height = bar.get_height()
    ax6.text(bar.get_x() + bar.get_width()/2., height + std + 0.02,
            f'{mean:.3f}', ha='center', va='bottom', fontsize=9, fontweight='bold')

plt.suptitle('Cross-Kid Category Group Correlations Analysis', 
             fontsize=16, y=0.995, fontweight='bold')
plt.tight_layout(rect=[0, 0, 1, 0.99])
plt.savefig(output_dir / "cross_kid_category_group_correlations_visualization.png", dpi=200, bbox_inches='tight')
print(f"Saved cross-kid category group visualization to {output_dir / 'cross_kid_category_group_correlations_visualization.png'}")
plt.close()

print("\nVisualization complete!")

Creating visualizations for cross-kid category group correlations...
Saved cross-kid category group visualization to developmental_trajectory_rdms_clip/cross_kid_category_group_correlations_visualization.png

Visualization complete!
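How `cross_kid_category_df` is built is not shown in this section. One plausible way to obtain a per-group cross-kid correlation is to keep only the categories of one group that both kids have, then correlate the upper triangles of the two sub-RDMs, as the null-model cell does for the full category set. A minimal sketch under that assumption; `group_rdm_correlation`, the `category_to_group` mapping and the `min_pairs=3` cut-off are hypothetical, not taken from the notebook.

```python
import numpy as np
from scipy.stats import spearmanr


def group_rdm_correlation(rdm1, rdm2, categories, cats1, cats2,
                          category_to_group, group, min_pairs=3):
    """Hypothetical helper: Spearman correlation of two RDMs within one category group.

    `categories` is the shared row/column order of both RDMs; cats1 and cats2
    are the categories present for each kid. Returns (rho, n_common).
    """
    # Categories of this group that both kids have, in the RDM order
    common = [c for c in categories
              if c in cats1 and c in cats2 and category_to_group.get(c) == group]
    idx = [categories.index(c) for c in common]
    if len(idx) < 2:
        return np.nan, len(idx)
    sub1 = rdm1[np.ix_(idx, idx)]
    sub2 = rdm2[np.ix_(idx, idx)]
    # Compare only the unique off-diagonal pairs
    rows, cols = np.triu_indices(len(idx), k=1)
    v1, v2 = sub1[rows, cols], sub2[rows, cols]
    valid = ~(np.isnan(v1) | np.isnan(v2))
    if valid.sum() < min_pairs:
        return np.nan, len(idx)
    rho, _ = spearmanr(v1[valid], v2[valid])
    return rho, len(idx)
```

A group with k shared categories contributes only k(k-1)/2 pairs, so a group of 4 categories yields 6 pairs and its Spearman value rests on far fewer points than a whole-RDM correlation. This is worth keeping in mind when reading the scatter of correlation against number of common categories.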
