# Developmental Trajectory RDM Analysis

This notebook creates multiple Representational Dissimilarity Matrices (RDMs) for each individual subject, binned by age in months (age_mo).
This allows tracking how object representations change developmentally within each subject.

## Overview

This analysis:
1. Loads grouped embeddings (averaged by category, subject, and age_mo)
2. Bins embeddings by age_mo for each subject
3. Computes an RDM for each subject at each age_mo bin
4. Handles data density differences (some subjects/ages have more data)
5. Visualizes developmental trajectories
6. Compares RDMs across age bins within subjects

## Key Features

- **Age binning**: Groups embeddings by age_mo to track developmental changes
- **Data density handling**: Minimum category threshold per age bin
- **Trajectory analysis**: Compare RDMs across age bins to see developmental changes
- **Missing data handling**: Only includes age bins with sufficient data


## Setup and Imports


In [None]:
import numpy as np
import pandas as pd
from pathlib import Path
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics.pairwise import cosine_similarity
from scipy.stats import spearmanr, pearsonr
from collections import defaultdict
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')

# Set matplotlib backend
import matplotlib
matplotlib.use('Agg')

print("All imports successful!")


## Configuration


In [None]:
# --- Input / output locations ---
# Root folder holding one sub-folder per category with the grouped embeddings
embeddings_dir = Path("/data2/dataset/babyview/868_hours/outputs/yoloe_cdi_embeddings/clip_embeddings_grouped_by_age-mo")
# Every artifact of this notebook is written below this folder (created on demand)
output_dir = Path("developmental_trajectory_rdms")
output_dir.mkdir(parents=True, exist_ok=True)

# Optional category whitelist (one category name per line); when the file is
# missing, no category filtering is applied
categories_file = Path("../../data/things_bv_overlap_categories_exclude_zero_precisions.txt")

# --- Inclusion thresholds ---
# An age bin needs at least this many categories to get an RDM
min_categories_per_age_bin = 8
# A subject needs at least this many usable age bins to stay in the analysis
min_age_bins_per_subject = 2

# --- Age binning ---
# 'exact' keeps each age_mo as its own bin; 'binned' groups ages into
# fixed-width bins of age_bin_size months (e.g., 3 months)
age_binning_strategy = 'exact'  # or 'binned'
age_bin_size = 3  # only used when age_binning_strategy == 'binned'

# Echo the active configuration, one setting per line
print("\n".join([
    f"Embeddings directory: {embeddings_dir}",
    f"Output directory: {output_dir}",
    f"Min categories per age bin: {min_categories_per_age_bin}",
    f"Min age bins per subject: {min_age_bins_per_subject}",
    f"Age binning strategy: {age_binning_strategy}",
]))


## Load Category List (Optional)


In [None]:
# Optional category whitelist: stays None (meaning "no filtering") unless the
# categories file is present on disk
allowed_categories = None
if not categories_file.exists():
    print(f"Categories file not found, using all categories")
else:
    print(f"Loading categories from {categories_file}...")
    with open(categories_file, 'r') as f:
        # One category per line; surrounding whitespace is removed and
        # blank lines are ignored
        stripped_lines = (raw_line.strip() for raw_line in f)
        allowed_categories = {name for name in stripped_lines if name}
    print(f"Loaded {len(allowed_categories)} categories")


## Load Embeddings by Age


In [None]:
def _combine_embeddings(embeddings):
    """Return a lone embedding unchanged, or the element-wise mean of several."""
    if len(embeddings) == 1:
        return embeddings[0]
    # Flatten before stacking so that (1, D) and (D,) files can be combined;
    # the result is reshaped to the shape of the first embedding.
    stacked = np.stack([np.asarray(e).reshape(-1) for e in embeddings])
    return stacked.mean(axis=0).reshape(np.asarray(embeddings[0]).shape)


def load_embeddings_by_age(embeddings_dir, allowed_categories=None, age_binning_strategy='exact', age_bin_size=3):
    """
    Load embeddings organized by subject, age_mo, and category.

    Each sub-folder of `embeddings_dir` is treated as one category, and each
    `*.npy` file inside it is parsed as `{subject_id}_{age_mo}_...`. Files
    whose second underscore-separated field is not an integer are skipped, and
    files that fail to load are reported and skipped.

    With the 'binned' strategy several months can fall into the same
    (subject, bin, category) cell. Those embeddings are averaged element-wise
    rather than letting the last file read silently replace the others.
    NOTE(review): this is an unweighted mean of the per-month files; if the
    months differ a lot in sample count a weighted mean may be preferable.

    Args:
        embeddings_dir: Path to the folder containing one sub-folder per category.
        allowed_categories: optional collection of category names to keep;
            when None or empty, every category folder is loaded.
        age_binning_strategy: 'binned' rounds age_mo down to a multiple of
            `age_bin_size`; any other value uses the exact age_mo.
        age_bin_size: bin width in months (only used with 'binned').

    Returns:
        subject_age_embeddings: dict[subject_id][age_mo_bin][category] = embedding array

    Raises:
        ValueError: if 'binned' is requested with a non-positive `age_bin_size`,
            or if embeddings falling into the same cell have different sizes.
    """
    if age_binning_strategy == 'binned' and age_bin_size <= 0:
        raise ValueError(f"age_bin_size must be positive, got {age_bin_size}")

    # subject_id -> age bin -> category -> list of every embedding in that cell
    collected = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))

    # Get all category folders
    category_folders = [f for f in embeddings_dir.iterdir() if f.is_dir()]

    if allowed_categories:
        category_folders = [f for f in category_folders if f.name in allowed_categories]

    print(f"Loading embeddings from {len(category_folders)} categories...")

    for category_folder in tqdm(category_folders, desc="Loading categories"):
        category = category_folder.name

        # Sorted so the load order does not depend on filesystem enumeration
        for emb_file in sorted(category_folder.glob("*.npy")):
            # Parse filename: {subject_id}_{age_mo}_month_level_avg.npy
            parts = emb_file.stem.split('_')

            if len(parts) < 2 or not parts[1].isdigit():
                continue

            subject_id = parts[0]
            age_mo = int(parts[1])

            # Apply age binning strategy
            if age_binning_strategy == 'binned':
                age_mo_bin = (age_mo // age_bin_size) * age_bin_size  # Round down to bin
            else:
                age_mo_bin = age_mo  # Use exact age

            # Deliberately best-effort: one unreadable file must not abort the run
            try:
                embedding = np.load(emb_file)
            except Exception as e:
                print(f"Error loading {emb_file}: {e}")
                continue

            collected[subject_id][age_mo_bin][category].append(embedding)

    # Collapse each cell to a single embedding
    subject_age_embeddings = defaultdict(lambda: defaultdict(dict))
    for subject_id, bins in collected.items():
        for age_mo_bin, per_category in bins.items():
            for category, embeddings in per_category.items():
                subject_age_embeddings[subject_id][age_mo_bin][category] = _combine_embeddings(embeddings)

    return subject_age_embeddings

# Load embeddings using the configured binning strategy
subject_age_embeddings = load_embeddings_by_age(
    embeddings_dir,
    allowed_categories,
    age_binning_strategy=age_binning_strategy,
    age_bin_size=age_bin_size
)

print(f"\nLoaded embeddings for {len(subject_age_embeddings)} subjects")

# Show age bin distribution: union of the age bins seen for any subject
all_age_bins = set()
for age_data in subject_age_embeddings.values():
    all_age_bins.update(age_data.keys())

print(f"Age bins found: {sorted(all_age_bins)}")
if all_age_bins:
    print(f"Age range: {min(all_age_bins)} to {max(all_age_bins)} months")
else:
    # min()/max() raise ValueError on an empty set, so report the problem instead
    print("Age range: n/a (no embeddings were loaded; check embeddings_dir and the category filter)")


## Normalize Embeddings

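Before computing RDMs, we z-score the embeddings separately for each subject and age bin: every embedding dimension is centered and scaled (mean=0, std=1) across the categories available in that bin. This makes each RDM reflect the relative structure among categories within its own bin and ensures fair comparisons.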


In [None]:
# Z-score the embeddings independently for every (subject, age bin) pair.
# For a developmental trajectory we care about how the relative structure among
# categories changes within a subject over time, so each bin is standardized
# on its own.
print("Normalizing embeddings per subject and age bin...")
subject_age_embeddings_normalized = {}

for subject_id, age_data in tqdm(subject_age_embeddings.items(), desc="Normalizing"):
    normalized_by_age = subject_age_embeddings_normalized[subject_id] = {}

    for age_mo, categories in age_data.items():
        category_names = list(categories.keys())

        # One row per category; flatten() turns e.g. a (1, 512) array into (512,)
        embedding_matrix = np.array([categories[name].flatten() for name in category_names])

        # Rows must stack into (n_categories, embedding_dim)
        if embedding_matrix.ndim != 2:
            raise ValueError(f"Expected 2D embedding matrix, got shape {embedding_matrix.shape}")

        # Statistics are taken along axis=0, i.e. per embedding dimension across
        # the categories present in this bin; the small epsilon avoids a
        # division by zero for dimensions with no variance.
        dim_mean = embedding_matrix.mean(axis=0)
        dim_std = embedding_matrix.std(axis=0)
        normalized_matrix = (embedding_matrix - dim_mean) / (dim_std + 1e-10)

        # Row i of the normalized matrix belongs to category i
        normalized_by_age[age_mo] = dict(zip(category_names, normalized_matrix))

print(f"Normalized embeddings for {len(subject_age_embeddings_normalized)} subjects")
print("  Note: Each age bin is normalized independently, focusing on relative structure within that age bin")

def compute_rdm_for_age_bin(age_embeddings_dict, categories_list, min_categories=None):
    """
    Compute the RDM (1 - cosine similarity) for a single age bin.

    Args:
        age_embeddings_dict: dict[category] = embedding array (should be normalized)
        categories_list: categories to include; their order defines the
            row/column order of the RDM
        min_categories: minimum number of available categories required to
            compute an RDM. When None, the module-level
            `min_categories_per_age_bin` setting is used.

    Returns:
        rdm: numpy array of shape (n_categories, n_categories), or None when
            fewer than `min_categories` categories are available
        available_categories: list of categories actually present, in the
            order given by `categories_list`

    Raises:
        ValueError: if the stacked embeddings do not form a 2D matrix
    """
    if min_categories is None:
        # Fall back to the notebook-wide threshold
        min_categories = min_categories_per_age_bin

    # Filter to categories that exist for this age bin
    available_categories = [cat for cat in categories_list if cat in age_embeddings_dict]

    if len(available_categories) < min_categories:
        return None, available_categories

    # Build embedding matrix (already normalized)
    # Flatten each embedding to ensure 1D (in case they have shape (1, 512) instead of (512,))
    embedding_matrix = np.array([age_embeddings_dict[cat].flatten() for cat in available_categories])

    # Ensure 2D shape: (n_categories, embedding_dim)
    if embedding_matrix.ndim != 2:
        raise ValueError(f"Expected 2D embedding matrix, got shape {embedding_matrix.shape}")

    # Compute cosine similarity between every pair of category embeddings
    similarity_matrix = cosine_similarity(embedding_matrix)

    # Convert to distance (RDM)
    distance_matrix = 1 - similarity_matrix
    np.fill_diagonal(distance_matrix, 0)  # Ensure diagonal is exactly 0

    # Make symmetric (in case of numerical errors)
    distance_matrix = (distance_matrix + distance_matrix.T) / 2

    return distance_matrix, available_categories

# Union of the category names seen for any subject at any age, in sorted order.
# This fixed ordering is handed to every RDM computation below.
all_categories = set()
for age_data in subject_age_embeddings_normalized.values():
    for categories in age_data.values():
        all_categories.update(categories.keys())

all_categories = sorted(list(all_categories))
print(f"Total unique categories across all subjects and ages: {len(all_categories)}")

# One RDM per (subject, age bin), computed from the normalized embeddings.
# Both result dicts are keyed subject_id -> age_mo.
subject_age_rdms = {}
subject_age_rdm_categories = {}

for subject_id, age_data in tqdm(subject_age_embeddings_normalized.items(), desc="Computing RDMs"):
    rdms_by_age = {}
    categories_by_age = {}

    for age_mo, categories in age_data.items():
        rdm, available_cats = compute_rdm_for_age_bin(categories, all_categories)

        # None means the bin had too few categories to be usable
        if rdm is None:
            continue

        rdms_by_age[age_mo] = rdm
        categories_by_age[age_mo] = available_cats

    # Only subjects with enough usable age bins enter the results
    if len(rdms_by_age) >= min_age_bins_per_subject:
        subject_age_rdms[subject_id] = rdms_by_age
        subject_age_rdm_categories[subject_id] = categories_by_age

print(f"\nComputed RDMs for {len(subject_age_rdms)} subjects")
print(f"  (Excluded subjects with < {min_age_bins_per_subject} age bins with sufficient data)")

# How many usable age bins each retained subject has
age_bin_counts = [len(age_rdms) for age_rdms in subject_age_rdms.values()]
print(f"\nAge bins per subject:")
print(f"  Min: {min(age_bin_counts) if age_bin_counts else 0}")
print(f"  Max: {max(age_bin_counts) if age_bin_counts else 0}")
print(f"  Mean: {np.mean(age_bin_counts):.1f}" if age_bin_counts else "  Mean: 0")
print(f"  Median: {np.median(age_bin_counts):.1f}" if age_bin_counts else "  Median: 0")


## Save RDMs for Each Subject at Each Age Bin


In [None]:
# Write every subject/age RDM to disk: one folder per subject, three files per age bin
print("Saving developmental trajectory RDMs...")

for subject_id, age_rdms in tqdm(subject_age_rdms.items(), desc="Saving RDMs"):
    subject_output_dir = output_dir / subject_id
    subject_output_dir.mkdir(parents=True, exist_ok=True)
    categories_by_age = subject_age_rdm_categories[subject_id]

    for age_mo, rdm in age_rdms.items():
        categories = categories_by_age[age_mo]

        # 1) Raw matrix as a numpy array
        npy_path = subject_output_dir / f"rdm_age_{age_mo}.npy"
        np.save(npy_path, rdm)

        # 2) Same matrix as CSV, with category names on rows and columns
        csv_path = subject_output_dir / f"rdm_age_{age_mo}.csv"
        pd.DataFrame(rdm, index=categories, columns=categories).to_csv(csv_path)

        # 3) One-row metadata table.
        # NOTE: mean/std are taken over the full matrix, diagonal zeros included.
        metadata = {
            'subject_id': subject_id,
            'age_mo': age_mo,
            'n_categories': len(categories),
            'categories': categories,
            'mean_distance': float(rdm.mean()),
            'std_distance': float(rdm.std())
        }
        metadata_path = subject_output_dir / f"metadata_age_{age_mo}.csv"
        pd.DataFrame([metadata]).to_csv(metadata_path, index=False)

print(f"\nSaved RDMs to {output_dir}")


## Analyze Developmental Trajectories


In [None]:
def _category_positions(categories):
    """Map each category to the index of its first occurrence in `categories`."""
    positions = {}
    for index, category in enumerate(categories):
        positions.setdefault(category, index)
    return positions


def compute_rdm_correlation(rdm1, rdm2, categories1, categories2):
    """
    Compute the Spearman correlation between two RDMs.

    Only categories present in both RDMs are used. Both matrices are reduced
    to those categories (in sorted order), and the upper triangles (diagonal
    excluded) are correlated.

    Args:
        rdm1: square numpy array whose rows/columns follow `categories1`
        rdm2: square numpy array whose rows/columns follow `categories2`
        categories1: sequence of category labels for `rdm1`
        categories2: sequence of category labels for `rdm2`

    Returns:
        corr: Spearman correlation, or NaN when fewer than 3 categories are
            shared (two categories give a single pair, for which a correlation
            is undefined) or when scipy cannot compute it (e.g. constant input)
        n_common: number of categories shared by the two RDMs
    """
    # Find common categories
    common_categories = sorted(set(categories1) & set(categories2))
    n_common = len(common_categories)

    # A correlation needs at least two dissimilarity pairs, i.e. 3 categories
    if n_common < 3:
        return np.nan, n_common

    # Dict lookups instead of repeated list.index() scans
    positions1 = _category_positions(categories1)
    positions2 = _category_positions(categories2)
    idx1 = [positions1[cat] for cat in common_categories]
    idx2 = [positions2[cat] for cat in common_categories]

    # Restrict both RDMs to the common categories, in the same order
    rdm1_subset = rdm1[np.ix_(idx1, idx1)]
    rdm2_subset = rdm2[np.ix_(idx2, idx2)]

    # Upper triangle, excluding the diagonal
    upper = np.triu_indices(n_common, k=1)
    rdm1_flat = rdm1_subset[upper]
    rdm2_flat = rdm2_subset[upper]

    # Spearman correlation (rank-based, more robust to outliers)
    corr, _ = spearmanr(rdm1_flat, rdm2_flat)
    return corr, n_common

# Compute RDM correlations across age bins for each subject.
# Columns are declared up front so the DataFrame (and the CSV) keep their
# schema even when no age transition is available.
trajectory_columns = [
    'subject_id',
    'age1',
    'age2',
    'age_diff',
    'rdm_correlation',
    'n_common_categories',
    'n_categories_age1',
    'n_categories_age2',
]
trajectory_data = []

for subject_id, age_rdms in tqdm(subject_age_rdms.items(), desc="Analyzing trajectories"):
    ages = sorted(age_rdms.keys())
    subject_categories = subject_age_rdm_categories[subject_id]

    # Compare each age bin with the next one; subjects with fewer than two
    # bins yield no pairs and are skipped naturally
    for age1, age2 in zip(ages, ages[1:]):
        cats1 = subject_categories[age1]
        cats2 = subject_categories[age2]

        corr, n_common = compute_rdm_correlation(age_rdms[age1], age_rdms[age2], cats1, cats2)

        trajectory_data.append({
            'subject_id': subject_id,
            'age1': age1,
            'age2': age2,
            'age_diff': age2 - age1,
            'rdm_correlation': corr,
            'n_common_categories': n_common,
            'n_categories_age1': len(cats1),
            'n_categories_age2': len(cats2)
        })

trajectory_df = pd.DataFrame(trajectory_data, columns=trajectory_columns)
trajectory_df.to_csv(output_dir / "trajectory_correlations.csv", index=False)

print(f"\nTrajectory analysis:")
print(f"  Total age transitions analyzed: {len(trajectory_df)}")
if trajectory_df.empty:
    # Nothing to summarize; avoid computing statistics on an empty column
    print("  No age transitions available; RDM correlation statistics are undefined")
else:
    print(f"  Mean RDM correlation: {trajectory_df['rdm_correlation'].mean():.3f}")
    print(f"  Std RDM correlation: {trajectory_df['rdm_correlation'].std():.3f}")
print(f"\nSaved trajectory correlations to {output_dir / 'trajectory_correlations.csv'}")


## Visualize Developmental Trajectories


In [None]:
# Draw the RDM heatmaps of up to three subjects, one panel per age bin
n_sample_subjects = min(3, len(subject_age_rdms))
sample_subjects = list(subject_age_rdms)[:n_sample_subjects]

for subject_id in sample_subjects:
    age_rdms = subject_age_rdms[subject_id]
    ages = sorted(age_rdms)
    n_ages = len(ages)

    # squeeze=False always returns a 2D grid of axes, so the single-panel case
    # needs no special handling; row 0 holds the panels
    fig, axes_grid = plt.subplots(1, n_ages, figsize=(6*n_ages, 5), squeeze=False)

    for ax, age_mo in zip(axes_grid[0], ages):
        categories = subject_age_rdm_categories[subject_id][age_mo]

        im = ax.imshow(age_rdms[age_mo], cmap='viridis', aspect='auto')
        ax.set_title(f"Age {age_mo} months\n({len(categories)} categories)", fontsize=12)
        ax.set_xlabel('Category')
        ax.set_ylabel('Category')
        plt.colorbar(im, ax=ax)

    fig.suptitle(f"Developmental Trajectory: {subject_id}", fontsize=14, y=1.02)
    fig.tight_layout()
    fig.savefig(output_dir / f"trajectory_{subject_id}.png", dpi=150, bbox_inches='tight')
    plt.close(fig)

print(f"Saved trajectory visualizations for {n_sample_subjects} sample subjects")


## Plot RDM Stability Across Development


In [None]:
# Two scatter panels of consecutive-bin RDM correlation:
# left against the age gap between the bins, right against their mean age
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Midpoint of the two compared age bins (stored as a new column on trajectory_df)
trajectory_df['mean_age'] = (trajectory_df['age1'] + trajectory_df['age2']) / 2

# (x column, x-axis label, panel title) for each panel, left to right
panel_specs = (
    ('age_diff', 'Age Difference (months)', 'RDM Stability vs Age Gap'),
    ('mean_age', 'Mean Age (months)', 'RDM Stability vs Age'),
)

for ax, (x_column, x_label, panel_title) in zip(axes, panel_specs):
    ax.scatter(trajectory_df[x_column], trajectory_df['rdm_correlation'], alpha=0.6)
    ax.set_xlabel(x_label)
    ax.set_ylabel('RDM Correlation (Spearman)')
    ax.set_title(panel_title)
    ax.grid(True, alpha=0.3)

fig.tight_layout()
fig.savefig(output_dir / "rdm_stability_analysis.png", dpi=150, bbox_inches='tight')
print(f"Saved RDM stability analysis to {output_dir / 'rdm_stability_analysis.png'}")
plt.close(fig)


## Summary Statistics


In [None]:
# One summary row per (subject, age bin), ordered by age within each subject
summary_data = []

for subject_id, age_rdms in subject_age_rdms.items():
    categories_by_age = subject_age_rdm_categories[subject_id]

    for age_mo in sorted(age_rdms):
        rdm = age_rdms[age_mo]
        # Strictly positive entries only, so the zero diagonal cannot be the minimum
        positive_distances = rdm[rdm > 0]
        if positive_distances.size:
            min_distance = float(positive_distances.min())
        else:
            min_distance = np.nan

        # mean/std/max are taken over the full matrix, diagonal included
        summary_data.append({
            'subject_id': subject_id,
            'age_mo': age_mo,
            'n_categories': len(categories_by_age[age_mo]),
            'mean_distance': float(rdm.mean()),
            'std_distance': float(rdm.std()),
            'min_distance': min_distance,
            'max_distance': float(rdm.max())
        })

summary_df = pd.DataFrame(summary_data)
summary_df.to_csv(output_dir / "summary_statistics.csv", index=False)

print("Summary statistics:")
print(summary_df.describe())
print(f"\nSaved summary to {output_dir / 'summary_statistics.csv'}")
