# Individual Subject RDM Analysis

This notebook creates Representational Dissimilarity Matrices (RDMs) for each individual subject.
Each subject's RDM captures the pairwise dissimilarity (cosine distance) between object categories, computed from that subject's averaged category embeddings.

## Overview

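This analysis:
1. Loads grouped embeddings (averaged by category, subject, and age_mo)
2. Aggregates embeddings per subject across all age_mo bins (weighted average when a subject has several age bins for the same category)
3. Z-scores every embedding dimension using statistics pooled across all subjects, then computes each subject's RDM using cosine distance
4. Handles data-density differences by excluding subjects with fewer than `min_categories_per_subject` categories
5. Saves individual subject RDMs and visualizes a sample of them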

## Key Features

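- **Data density handling**: Subjects with fewer than `min_categories_per_subject` categories are excluded; because subjects with more categories contribute more data, RDM statistics are also plotted against category count
- **Missing category handling**: Only includes categories present for each subject
- **Weighted averaging**: If a subject has multiple age_mo bins, embeddings are weighted by sample count. The embedding files do not store sample counts, so each file currently counts as one sample and this reduces to an unweighted mean.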


## Setup and Imports


In [None]:
import warnings
from collections import defaultdict
from pathlib import Path

import matplotlib
# Select the non-interactive Agg backend *before* pyplot is imported: figures in
# this notebook are written to disk (savefig + close) rather than shown inline.
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from sklearn.metrics.pairwise import cosine_similarity, cosine_distances
from tqdm import tqdm

# NOTE(review): this silences *every* warning, including numpy RuntimeWarnings
# such as "Mean of empty slice"; consider narrowing it to specific categories.
warnings.filterwarnings('ignore')

print("All imports successful!")


## Configuration


In [None]:
# --- Input / output locations ------------------------------------------------
# Root folder with one sub-folder per category of averaged embedding files.
embeddings_dir = Path("/data2/dataset/babyview/868_hours/outputs/yoloe_cdi_embeddings/clip_embeddings_grouped_by_age-mo")

# Results are written relative to the current working directory.
output_dir = Path("individual_subject_rdms")
output_dir.mkdir(parents=True, exist_ok=True)

# Optional whitelist of category names (used only if the file exists).
categories_file = Path("../../data/things_bv_overlap_categories_exclude_zero_precisions.txt")

# --- Analysis parameters ----------------------------------------------------
# Subjects with fewer categories than this get no RDM.
min_categories_per_subject = 10

# When collapsing several age_mo bins, weight each bin by its sample count.
weight_by_sample_count = True

for _label, _value in (
    ("Embeddings directory", embeddings_dir),
    ("Output directory", output_dir),
    ("Min categories per subject", min_categories_per_subject),
):
    print(f"{_label}: {_value}")


## Load Category List (Optional)


In [None]:
# Optional category whitelist: one name per line, blank lines ignored.
# `allowed_categories` stays None when the file is missing (no filtering).
allowed_categories = None
if not categories_file.exists():
    print("Categories file not found, using all categories")
else:
    print(f"Loading categories from {categories_file}...")
    # read_text() uses universal newlines, so splitting on '\n' yields the
    # same lines as iterating the open file.
    raw_lines = categories_file.read_text().split('\n')
    allowed_categories = {name.strip() for name in raw_lines if name.strip()}
    print(f"Loaded {len(allowed_categories)} categories")


## Load Embeddings


In [None]:
def load_subject_embeddings(embeddings_dir, allowed_categories=None):
    """
    Load per-file embeddings organized by subject, category and age bin.

    Expects ``embeddings_dir`` to contain one sub-folder per category, each
    holding ``.npy`` files named ``{subject_id}_{age_mo}_...`` (e.g.
    ``S00560001_16_month_level_avg.npy``). Files whose second ``_``-separated
    field is not a decimal integer are skipped; files that fail to load are
    reported and skipped.

    Category folders and files are visited in sorted order, so the result
    (values and subject insertion order) does not depend on filesystem
    listing order.

    Args:
        embeddings_dir: pathlib.Path of the root directory.
        allowed_categories: optional collection of category folder names to
            keep. ``None`` or an empty collection means "keep all folders".

    Returns:
        subject_embeddings: nested defaultdict
            subject_embeddings[subject_id][category][(category, age_mo)] = {
                'embedding': np.ndarray,  # loaded (or running-averaged) array
                'age_mo': int,            # age in months parsed from the filename
                'sample_count': int,      # number of files averaged into this entry
            }
        Sample counts are not stored in the files, so each file counts as 1;
        ``sample_count`` exceeds 1 only when several files map to the same key.
    """
    subject_embeddings = defaultdict(lambda: defaultdict(dict))

    # Sorted for a reproducible, filesystem-independent iteration order.
    category_folders = sorted(f for f in embeddings_dir.iterdir() if f.is_dir())

    # Truthiness check: an empty whitelist behaves like no whitelist.
    if allowed_categories:
        category_folders = [f for f in category_folders if f.name in allowed_categories]

    print(f"Loading embeddings from {len(category_folders)} categories...")

    for category_folder in tqdm(category_folders, desc="Loading categories"):
        category = category_folder.name

        for emb_file in sorted(category_folder.glob("*.npy")):
            # Parse filename stem: {subject_id}_{age_mo}_month_level_avg
            parts = emb_file.stem.split('_')
            if len(parts) < 2:
                continue

            subject_id, age_str = parts[0], parts[1]
            # isdecimal() accepts exactly the characters int() can parse;
            # isdigit() also accepts e.g. '²', which made int() raise here.
            if not age_str.isdecimal():
                continue
            age_mo = int(age_str)

            try:
                embedding = np.load(emb_file)

                key = (category, age_mo)
                entries = subject_embeddings[subject_id][category]
                if key not in entries:
                    entries[key] = {
                        'embedding': embedding,
                        'age_mo': age_mo,
                        'sample_count': 1,  # true sample count is not stored on disk
                    }
                else:
                    # Running mean when several files map to the same (category, age_mo).
                    existing = entries[key]
                    n_existing = existing['sample_count']
                    n_new = 1
                    existing['embedding'] = (existing['embedding'] * n_existing + embedding) / (n_existing + n_new)
                    existing['sample_count'] += n_new
            except Exception as e:
                # Best effort: report the bad file (unreadable, or a shape
                # mismatch while averaging) and keep going.
                print(f"Error loading {emb_file}: {e}")
                continue

    return subject_embeddings

# Run the loader on the configured directory, applying the optional whitelist.
subject_embeddings_raw = load_subject_embeddings(
    embeddings_dir,
    allowed_categories=allowed_categories,
)
print("\nLoaded embeddings for %d subjects" % len(subject_embeddings_raw))


## Aggregate Embeddings Per Subject


In [None]:
def aggregate_subject_embeddings(subject_embeddings_raw, weight_by_sample_count=True):
    """
    Collapse the per-age-bin entries of each subject/category into one vector.

    Args:
        subject_embeddings_raw: nested mapping
            ``[subject_id][category][(category, age_mo)] -> {'embedding', 'sample_count', ...}``.
        weight_by_sample_count: if True, each age bin is weighted by its
            ``sample_count``; otherwise every bin gets weight 1.0.

    Returns:
        dict[subject_id][category] -> np.ndarray holding the weighted mean
        embedding. Every input subject appears in the result, mapped to an
        empty dict when it has no non-empty categories.
    """
    aggregated_by_subject = {}

    for subject_id, per_category in tqdm(subject_embeddings_raw.items(), desc="Aggregating subjects"):
        per_category_mean = {}
        aggregated_by_subject[subject_id] = per_category_mean

        for category, bins in per_category.items():
            # Keys are (category, age_mo) tuples; only the payloads are needed.
            entries = [entry for (_cat, _age_mo), entry in bins.items()]
            if not entries:
                continue

            stacked = np.array([entry['embedding'] for entry in entries])
            raw_weights = np.array(
                [entry['sample_count'] if weight_by_sample_count else 1.0 for entry in entries]
            )
            # Normalize to sum 1 before averaging.
            normalized_weights = raw_weights / raw_weights.sum()

            per_category_mean[category] = np.average(stacked, axis=0, weights=normalized_weights)

    return aggregated_by_subject

# Aggregate embeddings across age bins (one vector per subject x category).
subject_embeddings = aggregate_subject_embeddings(
    subject_embeddings_raw,
    weight_by_sample_count=weight_by_sample_count
)

print(f"\nAggregated embeddings for {len(subject_embeddings)} subjects")

# Show category counts per subject
category_counts = {sid: len(cats) for sid, cats in subject_embeddings.items()}
if not category_counts:
    # min()/max() below would otherwise fail with an opaque "empty sequence"
    # ValueError; fail fast with an actionable message instead.
    raise ValueError(
        f"No subject embeddings were loaded from {embeddings_dir}; "
        "check embeddings_dir and categories_file."
    )

counts = list(category_counts.values())
print("\nCategory counts per subject:")
print(f"  Min: {min(counts)}")
print(f"  Max: {max(counts)}")
print(f"  Mean: {np.mean(counts):.1f}")
print(f"  Median: {np.median(counts):.1f}")


## Normalize Embeddings

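Before computing RDMs, we z-score each embedding dimension (mean=0, std=1) using the mean and standard deviation pooled over all subjects' category embeddings. Using global rather than per-subject statistics puts every subject in the same feature space, so their RDMs are comparable. The same cell then computes one cosine-distance RDM per subject.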


In [None]:
# Global z-scoring: statistics are pooled over every (subject, category)
# embedding, so all subjects are normalized in the same feature space.
print("Computing global normalization statistics across all subjects...")

all_embeddings_list = [
    embedding
    for per_category in subject_embeddings.values()
    for embedding in per_category.values()
]
all_embeddings_matrix = np.array(all_embeddings_list)
print(f"  Collected {len(all_embeddings_list)} embeddings from {len(subject_embeddings)} subjects")

# Per-dimension statistics; the epsilon keeps constant dimensions from
# producing a division by zero.
global_mean = all_embeddings_matrix.mean(axis=0)
global_std = all_embeddings_matrix.std(axis=0) + 1e-10

print(f"  Global mean shape: {global_mean.shape}")
print(f"  Global std shape: {global_std.shape}")
print(f"  Global mean range: [{global_mean.min():.4f}, {global_mean.max():.4f}]")
print(f"  Global std range: [{global_std.min():.4f}, {global_std.max():.4f}]")

# (x - global_mean) / global_std for every subject's category embedding.
print("\nApplying global normalization to each subject...")
subject_embeddings_normalized = {
    subject_id: {
        cat: (embedding - global_mean) / global_std
        for cat, embedding in per_category.items()
    }
    for subject_id, per_category in tqdm(subject_embeddings.items(), desc="Normalizing")
}

print(f"Normalized embeddings for {len(subject_embeddings_normalized)} subjects using global statistics")

def compute_subject_rdm(subject_embeddings_dict, categories_list):
    """
    Build a cosine-distance RDM for one subject.

    Args:
        subject_embeddings_dict: mapping category -> embedding array
            (expected to be normalized already).
        categories_list: candidate categories, in the desired row/column order.

    Returns:
        (rdm, available_categories):
            available_categories: the entries of ``categories_list`` present
                in ``subject_embeddings_dict``, in their original order.
            rdm: symmetric matrix of ``1 - cosine_similarity`` between those
                categories' embeddings, with a zero diagonal; ``None`` when
                fewer than two categories are available.
    """
    present = [c for c in categories_list if c in subject_embeddings_dict]

    if len(present) >= 2:
        vectors = np.array([subject_embeddings_dict[c] for c in present])
        rdm = 1 - cosine_similarity(vectors)
        np.fill_diagonal(rdm, 0)
        # Average with the transpose to remove floating-point asymmetry.
        rdm = 0.5 * (rdm + rdm.T)
        return rdm, present

    return None, present

# Union of categories seen for any subject, sorted so every RDM shares the
# same alphabetical row/column order.
all_categories = sorted({
    category
    for per_category in subject_embeddings_normalized.values()
    for category in per_category
})
print(f"Total unique categories across all subjects: {len(all_categories)}")

subject_rdms = {}
subject_rdm_categories = {}

for subject_id, categories in tqdm(subject_embeddings_normalized.items(), desc="Computing RDMs"):
    # Too few categories for a meaningful RDM -> skip this subject.
    if len(categories) < min_categories_per_subject:
        continue

    rdm, available_cats = compute_subject_rdm(categories, all_categories)
    if rdm is None:
        continue

    subject_rdms[subject_id] = rdm
    subject_rdm_categories[subject_id] = available_cats

n_excluded = len(subject_embeddings_normalized) - len(subject_rdms)
print(f"\nComputed RDMs for {len(subject_rdms)} subjects")
print(f"  (Excluded {n_excluded} subjects with < {min_categories_per_subject} categories)")


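## Compute Individual Subject RDMs

The per-subject RDMs are computed in the normalization cell above (`compute_subject_rdm`); subjects with fewer than `min_categories_per_subject` categories are skipped.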


## Save Individual Subject RDMs


In [None]:
# Save each subject's RDM as .npy (raw matrix) and .csv (category-labelled),
# plus a one-row metadata CSV per subject.
print("Saving individual subject RDMs...")

for subject_id, rdm in tqdm(subject_rdms.items(), desc="Saving RDMs"):
    categories = subject_rdm_categories[subject_id]

    # Save as numpy array
    np.save(output_dir / f"rdm_{subject_id}.npy", rdm)

    # Save as CSV with category labels
    rdm_df = pd.DataFrame(rdm, index=categories, columns=categories)
    rdm_df.to_csv(output_dir / f"rdm_{subject_id}.csv")

    # Upper-triangle (off-diagonal) distances. Whole-matrix statistics include
    # the n zeros on the diagonal, which scale the mean by (n-1)/n and bias
    # comparisons between subjects with different category counts.
    pairwise = rdm[np.triu_indices_from(rdm, k=1)]

    metadata = {
        'subject_id': subject_id,
        'n_categories': len(categories),
        'categories': categories,
        # Whole-matrix statistics (diagonal included), kept for compatibility.
        'mean_distance': float(rdm.mean()),
        'std_distance': float(rdm.std()),
        # Unbiased statistics over distinct category pairs.
        'mean_pairwise_distance': float(pairwise.mean()),
        'std_pairwise_distance': float(pairwise.std()),
    }

    metadata_df = pd.DataFrame([metadata])
    metadata_df.to_csv(output_dir / f"metadata_{subject_id}.csv", index=False)

print(f"\nSaved RDMs to {output_dir}")


## Create Summary Statistics


In [None]:
# One row per subject with RDM summary statistics.
summary_columns = [
    'subject_id', 'n_categories', 'mean_distance', 'std_distance',
    'min_distance', 'max_distance', 'mean_pairwise_distance', 'std_pairwise_distance',
]
summary_data = []

for subject_id, rdm in subject_rdms.items():
    categories = subject_rdm_categories[subject_id]
    # Upper-triangle (off-diagonal) distances: one value per category pair.
    pairwise = rdm[np.triu_indices_from(rdm, k=1)]

    summary_data.append({
        'subject_id': subject_id,
        'n_categories': len(categories),
        # Whole-matrix statistics include the n zero diagonal entries.
        'mean_distance': float(rdm.mean()),
        'std_distance': float(rdm.std()),
        'min_distance': float(rdm[rdm > 0].min()) if (rdm > 0).any() else np.nan,
        'max_distance': float(rdm.max()),
        # Off-diagonal statistics are not biased by the number of categories.
        'mean_pairwise_distance': float(pairwise.mean()),
        'std_pairwise_distance': float(pairwise.std()),
    })

# Explicit columns keep the schema when no subject has an RDM; a column-less
# DataFrame would make sort_values('n_categories') raise KeyError.
summary_df = pd.DataFrame(summary_data, columns=summary_columns)
summary_df = summary_df.sort_values('n_categories', ascending=False)
summary_df.to_csv(output_dir / "summary_statistics.csv", index=False)

print("Summary statistics:")
print(summary_df.describe())
print(f"\nSaved summary to {output_dir / 'summary_statistics.csv'}")


## Visualize Sample RDMs


In [None]:
# Plot up to six subjects' RDMs (the first entries of subject_rdms), one
# colorbar per panel since distance ranges differ between subjects.
n_samples = min(6, len(subject_rdms))
sample_subjects = list(subject_rdms.keys())[:n_samples]

fig, axes = plt.subplots(2, 3, figsize=(18, 12))
axes = axes.flatten()

for ax, subject_id in zip(axes, sample_subjects):
    rdm = subject_rdms[subject_id]
    categories = subject_rdm_categories[subject_id]

    im = ax.imshow(rdm, cmap='viridis', aspect='auto')
    ax.set_title(f"{subject_id}\n({len(categories)} categories)", fontsize=10)
    ax.set_xlabel('Category')
    ax.set_ylabel('Category')
    fig.colorbar(im, ax=ax)

# Hide the panels left empty when fewer than six subjects have an RDM.
for ax in axes[n_samples:]:
    ax.axis('off')

fig.tight_layout()
fig.savefig(output_dir / "sample_rdms.png", dpi=150, bbox_inches='tight')
print(f"Saved sample RDM visualization to {output_dir / 'sample_rdms.png'}")
plt.close(fig)


## Data Density Analysis


In [None]:
# Analyze data density across subjects
fig, axes = plt.subplots(1, 2, figsize=(12, 5))

# Category counts for ALL subjects, including those below the threshold, so the
# red line shows which subjects were excluded. subject_rdm_categories holds only
# included subjects, which all lie at or above the line by construction.
all_subject_category_counts = [len(cats) for cats in subject_embeddings_normalized.values()]
axes[0].hist(all_subject_category_counts, bins=20, edgecolor='black')
axes[0].set_xlabel('Number of Categories per Subject')
axes[0].set_ylabel('Number of Subjects')
axes[0].set_title('Data Density: Categories per Subject')
axes[0].axvline(min_categories_per_subject, color='red', linestyle='--', label=f'Min threshold ({min_categories_per_subject})')
axes[0].legend()

# Mean off-diagonal distance vs category count. rdm.mean() would include the n
# zero diagonal entries, scaling the mean by (n-1)/n and creating an artificial
# dependence on the number of categories.
sids = list(subject_rdms)
mean_distances = [
    subject_rdms[sid][np.triu_indices_from(subject_rdms[sid], k=1)].mean()
    for sid in sids
]
n_categories = [len(subject_rdm_categories[sid]) for sid in sids]

axes[1].scatter(n_categories, mean_distances, alpha=0.6)
axes[1].set_xlabel('Number of Categories')
axes[1].set_ylabel('Mean Off-Diagonal RDM Distance')
axes[1].set_title('RDM Distance vs Data Density')

fig.tight_layout()
fig.savefig(output_dir / "data_density_analysis.png", dpi=150, bbox_inches='tight')
print(f"Saved data density analysis to {output_dir / 'data_density_analysis.png'}")
plt.close(fig)
