# Category-wise Subject Embedding Correlations

This notebook calculates and compares category-wise embedding correlations between subjects.
For each category, it computes the correlation between average category embeddings across different subjects.

## Overview

This analysis:
1. Loads the normalized, age-month-level embeddings produced by notebook 05
2. Averages each subject's embeddings across all age_mo bins, separately for each category (an unweighted mean over age bins)
3. For each category, computes pairwise correlations between all subject pairs
4. Creates correlation matrices showing how similar subjects are in their category representations
5. Visualizes results with heatmaps and summary statistics

## Key Features

- **Category-wise analysis**: Computes correlations separately for each category
- **Subject comparisons**: Compares average category embeddings between all pairs of subjects
- **Multiple correlation metrics**: Supports Pearson and Spearman correlations; any other `correlation_method` value falls back to cosine similarity
- **Visualization**: Creates heatmaps showing subject-subject correlations for each category
- **Summary statistics**: Computes mean correlations per category and per subject pair
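
To make the difference between these metrics concrete, here is a minimal NumPy-only sketch. The helper names `compare_metrics` and `_ranks` and the toy vectors are illustrative assumptions, not part of the notebook. It relies on two standard identities: Pearson correlation equals the cosine similarity of the mean-centered vectors, and Spearman correlation equals Pearson correlation computed on ranks (the rank helper assumes there are no tied values).

```python
import numpy as np


def _ranks(x):
    """Return 0-based ranks of x (assumes no tied values)."""
    order = np.argsort(x)
    ranks = np.empty(len(x), dtype=float)
    ranks[order] = np.arange(len(x))
    return ranks


def compare_metrics(a, b):
    """Hypothetical helper: Pearson, Spearman, and cosine for two 1-D vectors."""
    a = np.asarray(a, dtype=float).ravel()
    b = np.asarray(b, dtype=float).ravel()
    cosine = float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))
    pearson = float(np.corrcoef(a, b)[0, 1])
    spearman = float(np.corrcoef(_ranks(a), _ranks(b))[0, 1])
    return {"pearson": pearson, "spearman": spearman, "cosine": cosine}


# Toy vectors standing in for two subjects' average category embeddings
rng = np.random.default_rng(0)
emb_a = rng.normal(size=512)
emb_b = emb_a + rng.normal(scale=0.5, size=512)
metrics = compare_metrics(emb_a, emb_b)
print(metrics)
```

Pearson and cosine similarity coincide only when both vectors have zero mean across dimensions, so the cosine fallback can give different values from the two correlation options.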

## Setup and Imports

In [1]:
import numpy as np
import pandas as pd
from pathlib import Path
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics.pairwise import cosine_similarity
from scipy.stats import pearsonr, spearmanr
from collections import defaultdict
from itertools import combinations
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')

# Set matplotlib backend
import matplotlib
matplotlib.use('Agg')

print("All imports successful!")

All imports successful!


## Configuration

In [2]:
# Paths
# Path to normalized embeddings from notebook 05 (age-month level normalized embeddings)
# These are saved in category folders: {normalized_embeddings_dir}/{category}/{subject_id}_{age_mo}_month_level_avg.npy
normalized_embeddings_dir = Path("/data2/dataset/babyview/868_hours/outputs/yoloe_cdi_embeddings/clip_embeddings_grouped_by_age-mo_normalized")

# Detect embedding type from path
normalized_embeddings_dir_str = str(normalized_embeddings_dir).lower()
if "dinov" in normalized_embeddings_dir_str:  # substring check also matches "dinov3"
    embedding_type = "dinov3"
elif "clip" in normalized_embeddings_dir_str:
    embedding_type = "clip"
else:
    embedding_type = "unknown"

# Create output directory with embedding type in name
output_dir = Path(f"category_wise_subject_correlations_{embedding_type}")
output_dir.mkdir(exist_ok=True, parents=True)

# Create subdirectories
csv_dir = output_dir / "csv"
plots_dir = output_dir / "plots"
csv_dir.mkdir(exist_ok=True, parents=True)
plots_dir.mkdir(exist_ok=True, parents=True)

# Subject to exclude from analyses
excluded_subject = "00270001"

# Correlation method: 'pearson' or 'spearman'
correlation_method = 'pearson'

# Minimum number of subjects required per category to compute correlations
min_subjects_per_category = 2

print(f"Normalized embeddings directory: {normalized_embeddings_dir}")
print(f"Detected embedding type: {embedding_type}")
print(f"Output directory: {output_dir}")
print(f"CSV subdirectory: {csv_dir}")
print(f"Plots subdirectory: {plots_dir}")
print(f"Excluded subject: {excluded_subject}")
print(f"Correlation method: {correlation_method}")
print(f"Min subjects per category: {min_subjects_per_category}")

Normalized embeddings directory: /data2/dataset/babyview/868_hours/outputs/yoloe_cdi_embeddings/clip_embeddings_grouped_by_age-mo_normalized
Detected embedding type: clip
Output directory: category_wise_subject_correlations_clip
CSV subdirectory: category_wise_subject_correlations_clip/csv
Plots subdirectory: category_wise_subject_correlations_clip/plots
Excluded subject: 00270001
Correlation method: pearson
Min subjects per category: 2
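
The file-name convention noted in the configuration comments (`{subject_id}_{age_mo}_month_level_avg.npy`) can be parsed with a small helper. This is a sketch only: the function name `parse_embedding_filename` and the example file names are hypothetical, and it assumes subject IDs never contain an underscore.

```python
from pathlib import Path


def parse_embedding_filename(path):
    """Return (subject_id, age_mo) for '{subject_id}_{age_mo}_month_level_avg.npy', else None."""
    parts = Path(path).stem.split("_")
    if len(parts) < 2 or not parts[1].isdigit():
        return None
    return parts[0], int(parts[1])


# Hypothetical file names, used only for illustration
print(parse_embedding_filename("crayon/00220001_14_month_level_avg.npy"))  # ('00220001', 14)
print(parse_embedding_filename("crayon/notes.npy"))                        # None
```

Keeping the subject ID as a string preserves its leading zeros, which matters when comparing against `excluded_subject`.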


## Load and Aggregate Embeddings

In [3]:
# Load normalized age-month level embeddings from notebook 05 and aggregate to subject level
print("Loading normalized age-month level embeddings from notebook 05...")
print(f"  Source directory: {normalized_embeddings_dir}")

# Get all category folders
category_folders = [f for f in normalized_embeddings_dir.iterdir() if f.is_dir()]
print(f"  Found {len(category_folders)} category folders")

# Collect all embeddings by subject and category
# Structure: {subject_id: {category: [list of age_mo embeddings]}}
subject_category_embeddings = defaultdict(lambda: defaultdict(list))
all_categories_set = set()

for category_folder in tqdm(category_folders, desc="Loading category folders"):
    category = category_folder.name
    all_categories_set.add(category)
    
    # Get all embedding files in this category
    embedding_files = list(category_folder.glob("*.npy"))
    
    for emb_file in embedding_files:
        # Parse filename: {subject_id}_{age_mo}_month_level_avg.npy
        filename = emb_file.stem  # without .npy
        parts = filename.split('_')
        
        if len(parts) < 2:
            continue
        
        # Extract subject_id and age_mo
        subject_id = parts[0]
        age_mo = int(parts[1]) if parts[1].isdigit() else None
        
        if age_mo is None:
            continue
        
        # Exclude subject if specified
        if excluded_subject and subject_id == excluded_subject:
            continue
        
        try:
            embedding = np.load(emb_file)
            subject_category_embeddings[subject_id][category].append(embedding)
        except Exception as e:
            print(f"Error loading {emb_file}: {e}")
            continue

# Aggregate embeddings per subject: average across age_mo for each category
print(f"\nAggregating embeddings per subject (averaging across age_mo)...")
subject_embeddings_normalized = {}

for subject_id in tqdm(subject_category_embeddings.keys(), desc="Aggregating subjects"):
    subject_embeddings_normalized[subject_id] = {}
    
    for category, age_mo_embeddings in subject_category_embeddings[subject_id].items():
        if len(age_mo_embeddings) > 0:
            # Average across all age_mo embeddings for this category
            # Stack embeddings and compute mean
            stacked = np.array([emb.flatten() for emb in age_mo_embeddings])
            avg_embedding = stacked.mean(axis=0)
            subject_embeddings_normalized[subject_id][category] = avg_embedding

print(f"\nLoaded and aggregated normalized embeddings for {len(subject_embeddings_normalized)} subjects")
print(f"  Total unique categories across all subjects: {len(all_categories_set)}")

# Get list of all subjects
all_subjects = sorted(list(subject_embeddings_normalized.keys()))
print(f"  Subjects: {all_subjects}")

Loading normalized age-month level embeddings from notebook 05...
  Source directory: /data2/dataset/babyview/868_hours/outputs/yoloe_cdi_embeddings/clip_embeddings_grouped_by_age-mo_normalized
  Found 163 category folders


Loading category folders: 100%|██████████| 163/163 [00:02<00:00, 64.85it/s]



Aggregating embeddings per subject (averaging across age_mo)...


Aggregating subjects: 100%|██████████| 31/31 [00:00<00:00, 484.64it/s]


Loaded and aggregated normalized embeddings for 31 subjects
  Total unique categories across all subjects: 163
  Subjects: ['00220001', '00230001', '00240001', '00320001', '00320002', '00320003', '00340002', '00350001', '00350002', '00360001', '00370001', '00370002', '00390001', '00400001', '00400002', '00400003', '00420001', '00430001', '00430002', '00440001', '00460001', '00490001', '00500001', '00510001', '00510002', '00550001', '00560001', '00590001', '00680001', '00720001', '00820001']
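
The aggregation step above reduces to an unweighted mean over age bins. A minimal sketch on toy data follows; the helper name `average_across_age_bins` and the 4-dimensional vectors are assumptions for illustration, and real embeddings will have a different dimensionality.

```python
import numpy as np


def average_across_age_bins(age_mo_embeddings):
    """Unweighted mean of a list of per-age-bin embeddings (each flattened to 1-D)."""
    stacked = np.stack([np.ravel(emb) for emb in age_mo_embeddings])
    return stacked.mean(axis=0)


# Toy example: three age bins, each stored with shape (1, 4)
bins = [
    np.array([[1.0, 0.0, 0.0, 0.0]]),
    np.array([[0.0, 1.0, 0.0, 0.0]]),
    np.array([[1.0, 1.0, 0.0, 0.0]]),
]
avg = average_across_age_bins(bins)
print(avg.shape, avg)
```

Every age bin contributes equally here, regardless of how much data went into it. The averaged vector is generally not unit-length even if the inputs were, which does not affect Pearson or Spearman correlation because both are invariant to positive rescaling.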





## Compute Category-wise Subject Correlations

In [4]:
# For each category, compute correlations between all subject pairs
print("Computing category-wise subject correlations...")

# Store results
# Structure: {category: {subject1: {subject2: correlation}}}
category_subject_correlations = defaultdict(lambda: defaultdict(dict))

# Also store as matrices for easier visualization
category_correlation_matrices = {}

# Get all unique categories that have data for at least min_subjects_per_category subjects
all_categories = sorted(list(all_categories_set))
valid_categories = []

for category in tqdm(all_categories, desc="Processing categories"):
    # Get all subjects that have this category
    subjects_with_category = [
        subj for subj in all_subjects 
        if category in subject_embeddings_normalized[subj]
    ]
    
    if len(subjects_with_category) < min_subjects_per_category:
        continue
    
    valid_categories.append(category)
    
    # Extract embeddings for this category across all subjects
    category_embeddings = {}
    for subject_id in subjects_with_category:
        category_embeddings[subject_id] = subject_embeddings_normalized[subject_id][category]
    
    # Compute pairwise correlations
    correlation_matrix = np.full((len(all_subjects), len(all_subjects)), np.nan)
    
    for i, subj1 in enumerate(all_subjects):
        for j, subj2 in enumerate(all_subjects):
            if subj1 == subj2:
                # Self-correlation is 1.0
                correlation_matrix[i, j] = 1.0
                category_subject_correlations[category][subj1][subj2] = 1.0
            elif subj1 in category_embeddings and subj2 in category_embeddings:
                # Compute correlation between the two embeddings
                emb1 = category_embeddings[subj1]
                emb2 = category_embeddings[subj2]
                
                if correlation_method == 'pearson':
                    corr, _ = pearsonr(emb1, emb2)
                elif correlation_method == 'spearman':
                    corr, _ = spearmanr(emb1, emb2)
                else:
                    # Fallback to cosine similarity
                    corr = cosine_similarity([emb1], [emb2])[0, 0]
                
                correlation_matrix[i, j] = corr
                category_subject_correlations[category][subj1][subj2] = corr
    
    category_correlation_matrices[category] = correlation_matrix

print(f"\nComputed correlations for {len(valid_categories)} categories")
print(f"  Categories with sufficient data: {len(valid_categories)}")

Computing category-wise subject correlations...


Processing categories: 100%|██████████| 163/163 [00:12<00:00, 13.10it/s]


Computed correlations for 163 categories
  Categories with sufficient data: 163
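
For the Pearson case, the nested loop above can be replaced by a single `np.corrcoef` call on the stacked embeddings. The sketch below is an alternative under stated assumptions, not the notebook's implementation: `pairwise_pearson` is a hypothetical helper, and unlike the loop it leaves the diagonal as NaN for subjects that have no embedding for the category.

```python
import numpy as np


def pairwise_pearson(embeddings_by_subject, all_subjects):
    """Return an (n, n) Pearson matrix aligned to all_subjects; NaN where data is missing."""
    n = len(all_subjects)
    matrix = np.full((n, n), np.nan)
    present = [s for s in all_subjects if s in embeddings_by_subject]
    if len(present) < 2:
        return matrix
    # Rows are subjects, columns are embedding dimensions
    stacked = np.stack([np.ravel(embeddings_by_subject[s]) for s in present])
    corr = np.corrcoef(stacked)
    idx = [all_subjects.index(s) for s in present]
    matrix[np.ix_(idx, idx)] = corr
    return matrix


# Toy data: subject "b" has no embedding for this category
rng = np.random.default_rng(1)
toy_embeddings = {"a": rng.normal(size=16), "c": rng.normal(size=16)}
toy_matrix = pairwise_pearson(toy_embeddings, ["a", "b", "c"])
print(toy_matrix)
```

This computes each pair once instead of evaluating both `(i, j)` and `(j, i)`, which may matter as the number of subjects grows.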





## Create Summary DataFrames

In [5]:
# Create a long-format DataFrame with all correlations
print("Creating summary DataFrames...")

correlation_records = []
for category in valid_categories:
    for subj1 in all_subjects:
        for subj2 in all_subjects:
            if subj1 in category_subject_correlations[category] and \
               subj2 in category_subject_correlations[category][subj1]:
                corr = category_subject_correlations[category][subj1][subj2]
                correlation_records.append({
                    'category': category,
                    'subject1': subj1,
                    'subject2': subj2,
                    'correlation': corr
                })

correlations_df = pd.DataFrame(correlation_records)
print(f"Created correlations DataFrame with {len(correlations_df)} records")

# Compute summary statistics
print("\nComputing summary statistics...")

# Mean correlation per category (excluding self-correlations)
category_mean_correlations = []
for category in valid_categories:
    category_corrs = correlations_df[
        (correlations_df['category'] == category) & 
        (correlations_df['subject1'] != correlations_df['subject2'])
    ]['correlation'].values
    if len(category_corrs) > 0:
        category_mean_correlations.append({
            'category': category,
            'mean_correlation': np.nanmean(category_corrs),
            'std_correlation': np.nanstd(category_corrs),
            'min_correlation': np.nanmin(category_corrs),
            'max_correlation': np.nanmax(category_corrs),
            'n_subject_pairs': len(category_corrs)
        })

category_summary_df = pd.DataFrame(category_mean_correlations)
category_summary_df = category_summary_df.sort_values('mean_correlation', ascending=False)
print(f"Created category summary with {len(category_summary_df)} categories")

# Mean correlation per subject pair (across all categories)
subject_pair_mean_correlations = []
for subj1, subj2 in combinations(all_subjects, 2):
    pair_corrs = correlations_df[
        (correlations_df['subject1'] == subj1) & 
        (correlations_df['subject2'] == subj2)
    ]['correlation'].values
    if len(pair_corrs) > 0:
        subject_pair_mean_correlations.append({
            'subject1': subj1,
            'subject2': subj2,
            'mean_correlation': np.nanmean(pair_corrs),
            'std_correlation': np.nanstd(pair_corrs),
            'min_correlation': np.nanmin(pair_corrs),
            'max_correlation': np.nanmax(pair_corrs),
            'n_categories': len(pair_corrs)
        })

subject_pair_summary_df = pd.DataFrame(subject_pair_mean_correlations)
subject_pair_summary_df = subject_pair_summary_df.sort_values('mean_correlation', ascending=False)
print(f"Created subject pair summary with {len(subject_pair_summary_df)} pairs")

# Display top categories by mean correlation
print("\nTop 10 categories by mean correlation:")
print(category_summary_df.head(10))

# Display top subject pairs by mean correlation
print("\nTop 10 subject pairs by mean correlation:")
print(subject_pair_summary_df.head(10))

Creating summary DataFrames...
Created correlations DataFrame with 133531 records

Computing summary statistics...
Created category summary with 163 categories
Created subject pair summary with 465 pairs

Top 10 categories by mean correlation:
    category  mean_correlation  std_correlation  min_correlation  \
42    crayon          0.942488         0.040843         0.756292   
74      hand          0.918849         0.048199         0.681343   
58    finger          0.918638         0.063157         0.705552   
94      nail          0.906610         0.089136         0.579158   
150      toe          0.883368         0.101027         0.510595   
19    bottle          0.846871         0.158516         0.156854   
160    watch          0.844931         0.098908         0.407526   
68   glasses          0.841866         0.091491         0.502670   
135  slipper          0.839791         0.089918         0.449792   
128    shirt          0.832728         0.075015         0.572636   

     ma
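
The per-category summary can also be expressed with a pandas `groupby`. This is a sketch on toy rows, with `summarize_by_category` as a hypothetical helper; it assumes the same long-format columns as `correlations_df` (`category`, `subject1`, `subject2`, `correlation`).

```python
import pandas as pd


def summarize_by_category(correlations_df):
    """Per-category statistics over off-diagonal (subject1 != subject2) rows."""
    off_diag = correlations_df[correlations_df["subject1"] != correlations_df["subject2"]]
    grouped = off_diag.groupby("category")["correlation"]
    summary = pd.DataFrame({
        "mean_correlation": grouped.mean(),
        "std_correlation": grouped.std(ddof=0),  # ddof=0 matches np.nanstd
        "min_correlation": grouped.min(),
        "max_correlation": grouped.max(),
        "n_subject_pairs": grouped.size(),
    }).reset_index()
    return summary.sort_values("mean_correlation", ascending=False)


# Toy rows (values are made up for illustration)
toy_df = pd.DataFrame([
    ("crayon", "a", "a", 1.0), ("crayon", "a", "b", 0.8),
    ("crayon", "b", "a", 0.8), ("crayon", "b", "b", 1.0),
    ("hand", "a", "b", 0.2), ("hand", "b", "a", 0.2),
    ("hand", "a", "c", 0.6), ("hand", "c", "a", 0.6),
], columns=["category", "subject1", "subject2", "correlation"])
toy_summary = summarize_by_category(toy_df)
print(toy_summary)
```

Because the long-format table stores both `(subject1, subject2)` and `(subject2, subject1)`, `n_subject_pairs` counts ordered pairs, so each unordered pair is counted twice.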

## Visualize Correlation Matrices

In [6]:
# Create heatmaps for correlation matrices
print("Creating correlation heatmaps...")

# Set up plotting style
sns.set_style("whitegrid")
plt.rcParams['figure.figsize'] = (12, 10)

# Create heatmaps for top categories (by mean correlation)
n_top_categories = min(20, len(category_summary_df))
top_categories = category_summary_df.head(n_top_categories)['category'].tolist()

for category in tqdm(top_categories, desc="Creating heatmaps"):
    corr_matrix = category_correlation_matrices[category]
    
    # Create figure
    fig, ax = plt.subplots(figsize=(12, 10))
    
    # Create heatmap
    mask = np.isnan(corr_matrix)
    sns.heatmap(
        corr_matrix,
        xticklabels=all_subjects,
        yticklabels=all_subjects,
        annot=False,
        cmap='coolwarm',
        center=0,
        vmin=-1,
        vmax=1,
        square=True,
        mask=mask,
        cbar_kws={'label': f'{correlation_method.capitalize()} Correlation'},
        ax=ax
    )
    
    ax.set_title(f'Subject Correlations: {category}\n(Mean: {category_summary_df[category_summary_df["category"]==category]["mean_correlation"].values[0]:.3f})',
                 fontsize=14, fontweight='bold')
    ax.set_xlabel('Subject', fontsize=12)
    ax.set_ylabel('Subject', fontsize=12)
    
    plt.tight_layout()
    
    # Save figure
    safe_category_name = category.replace('/', '_').replace(' ', '_')
    plt.savefig(plots_dir / f'correlation_heatmap_{safe_category_name}.png', dpi=150, bbox_inches='tight')
    plt.close()

print(f"Created {len(top_categories)} heatmaps for top categories")

Creating correlation heatmaps...


Creating heatmaps: 100%|██████████| 20/20 [00:07<00:00,  2.69it/s]

Created 20 heatmaps for top categories





## Create Summary Visualizations

## Comprehensive Category Visualizations

In [12]:
# Create comprehensive visualizations showing all categories
print("Creating comprehensive category visualizations...")

# 1. All categories sorted bar plot (full view)
fig, ax = plt.subplots(figsize=(14, max(20, len(category_summary_df) * 0.3)))
sorted_cats = category_summary_df.sort_values('mean_correlation', ascending=True)
colors = ['coral' if corr < category_summary_df['mean_correlation'].median() else 'steelblue' 
          for corr in sorted_cats['mean_correlation']]
ax.barh(range(len(sorted_cats)), sorted_cats['mean_correlation'], edgecolor='black', color=colors)
ax.set_yticks(range(len(sorted_cats)))
ax.set_yticklabels(sorted_cats['category'], fontsize=6)
ax.set_xlabel('Mean Correlation (across subject pairs)', fontsize=12)
ax.set_title('All Categories Sorted by Mean Correlation', fontsize=14, fontweight='bold')
ax.axvline(category_summary_df['mean_correlation'].median(), color='red', 
           linestyle='--', linewidth=2, label=f'Median: {category_summary_df["mean_correlation"].median():.3f}')
ax.axvline(category_summary_df['mean_correlation'].mean(), color='orange', 
           linestyle='--', linewidth=2, label=f'Mean: {category_summary_df["mean_correlation"].mean():.3f}')
ax.legend()
ax.invert_yaxis()
plt.tight_layout()
plt.savefig(plots_dir / 'all_categories_sorted_by_correlation.png', dpi=150, bbox_inches='tight')
plt.close()

# 2. Heatmap showing all categories with their correlation statistics
print("Creating category statistics heatmap...")
category_stats_matrix = category_summary_df[['mean_correlation', 'std_correlation', 
                                              'min_correlation', 'max_correlation']].values
category_stats_df = pd.DataFrame(
    category_stats_matrix,
    index=category_summary_df['category'],
    columns=['Mean', 'Std', 'Min', 'Max']
)
category_stats_df = category_stats_df.sort_values('Mean', ascending=False)

fig, ax = plt.subplots(figsize=(max(20, len(category_stats_df) * 0.3), 6))
sns.heatmap(
    category_stats_df.T,
    annot=False,
    cmap='RdYlBu_r',
    center=category_summary_df['mean_correlation'].mean(),
    cbar_kws={'label': 'Correlation Value'},
    ax=ax,
    yticklabels=['Mean Correlation', 'Std Dev', 'Min', 'Max'],
    xticklabels=False  # Don't show all labels initially
)
ax.set_xlabel('Category', fontsize=10)
ax.set_ylabel('Statistic', fontsize=12)
ax.set_title('All Categories: Correlation Statistics\n(Sorted by Mean Correlation)', 
             fontsize=14, fontweight='bold')

# Set x-axis ticks to show a subset of category labels to avoid overcrowding.
# seaborn centers heatmap cells at i + 0.5, so tick positions are offset by 0.5.
n_categories = len(category_stats_df)
if n_categories > 50:
    # Show every `step`-th category label (roughly 50 labels in total)
    step = max(1, n_categories // 50)
    tick_positions = list(range(0, n_categories, step))
    tick_labels = [category_stats_df.index[i] for i in tick_positions]
    ax.set_xticks([pos + 0.5 for pos in tick_positions])
    ax.set_xticklabels(tick_labels, rotation=90, fontsize=6, ha='right')
else:
    # Show all labels if not too many
    ax.set_xticks([pos + 0.5 for pos in range(n_categories)])
    ax.set_xticklabels(category_stats_df.index, rotation=90, fontsize=6, ha='right')

plt.tight_layout()
plt.savefig(plots_dir / 'all_categories_statistics_heatmap.png', dpi=150, bbox_inches='tight')
plt.close()

# 3. Heatmap showing categories vs subject pairs (sample for visualization)
# This could be very large, so we'll create a version with top/bottom categories
print("Creating category vs subject pair heatmap (top and bottom categories)...")
n_sample_cats = 40  # Show top 20 and bottom 20
top_sample = category_summary_df.head(n_sample_cats // 2)
bottom_sample = category_summary_df.tail(n_sample_cats // 2)
sample_categories = pd.concat([top_sample, bottom_sample])

# Create a matrix: categories (rows) x subject pairs (columns)
# Sample subject pairs for visualization (top correlated pairs)
n_sample_pairs = min(30, len(subject_pair_summary_df))
sample_pairs = subject_pair_summary_df.head(n_sample_pairs)

category_pair_matrix = np.zeros((len(sample_categories), len(sample_pairs)))
for i, (_, cat_row) in enumerate(sample_categories.iterrows()):
    category = cat_row['category']
    for j, (_, pair_row) in enumerate(sample_pairs.iterrows()):
        subj1 = pair_row['subject1']
        subj2 = pair_row['subject2']
        # Get correlation for this category and subject pair
        corr_data = correlations_df[
            (correlations_df['category'] == category) &
            (correlations_df['subject1'] == subj1) &
            (correlations_df['subject2'] == subj2)
        ]
        if len(corr_data) > 0:
            category_pair_matrix[i, j] = corr_data['correlation'].values[0]
        else:
            category_pair_matrix[i, j] = np.nan

# Create heatmap
pair_labels = [f"{row['subject1']}-{row['subject2']}" for _, row in sample_pairs.iterrows()]
fig, ax = plt.subplots(figsize=(max(16, len(sample_pairs) * 0.5), max(12, len(sample_categories) * 0.3)))
mask = np.isnan(category_pair_matrix)
sns.heatmap(
    category_pair_matrix,
    xticklabels=False,  # Set to False initially, then set manually below
    yticklabels=False,  # Set to False initially, then set manually below
    annot=False,
    cmap='coolwarm',
    center=0,
    vmin=-1,
    vmax=1,
    square=False,
    mask=mask,
    cbar_kws={'label': f'{correlation_method.capitalize()} Correlation'},
    ax=ax
)
ax.set_xlabel('Subject Pairs', fontsize=10)
ax.set_ylabel('Category', fontsize=10)
ax.set_title(f'Category Correlations: Top {n_sample_cats//2} & Bottom {n_sample_cats//2} Categories\nvs Top {n_sample_pairs} Subject Pairs', 
             fontsize=12, fontweight='bold')

# Set tick positions and labels manually (heatmap cells are centered at i + 0.5)
ax.set_xticks(np.arange(len(pair_labels)) + 0.5)
ax.set_xticklabels(pair_labels, rotation=90, fontsize=6, ha='right')
ax.set_yticks(np.arange(len(sample_categories)) + 0.5)
ax.set_yticklabels(sample_categories['category'].values, fontsize=7)
plt.tight_layout()
plt.savefig(plots_dir / 'category_vs_subject_pair_heatmap.png', dpi=150, bbox_inches='tight')
plt.close()

# 4. Violin plot showing distribution of correlations per category (sample)
print("Creating correlation distribution violin plot...")
# Sample categories for violin plot (too many would be unreadable)
n_violin_cats = min(30, len(category_summary_df))
# DataFrame.append was removed in pandas 2.0; pd.concat works across versions
sample_cats_violin = pd.concat([
    category_summary_df.head(n_violin_cats // 2),
    category_summary_df.tail(n_violin_cats // 2)
])

violin_data = []
for _, cat_row in sample_cats_violin.iterrows():
    category = cat_row['category']
    cat_corrs = correlations_df[
        (correlations_df['category'] == category) &
        (correlations_df['subject1'] != correlations_df['subject2'])
    ]['correlation'].values
    for corr in cat_corrs:
        violin_data.append({'category': category, 'correlation': corr})

violin_df = pd.DataFrame(violin_data)
violin_df = violin_df.sort_values('correlation', ascending=False)

fig, ax = plt.subplots(figsize=(14, 8))
sns.violinplot(data=violin_df, x='category', y='correlation', ax=ax, inner='box')
ax.set_xlabel('Category', fontsize=10)
ax.set_ylabel('Correlation', fontsize=12)
ax.set_title(f'Distribution of Correlations: Top {n_violin_cats//2} & Bottom {n_violin_cats//2} Categories', 
             fontsize=14, fontweight='bold')
ax.set_xticklabels(ax.get_xticklabels(), rotation=90, fontsize=7, ha='right')
plt.tight_layout()
plt.savefig(plots_dir / 'correlation_distribution_violin_plot.png', dpi=150, bbox_inches='tight')
plt.close()

print("Comprehensive visualizations created!")

Creating comprehensive category visualizations...
Creating category statistics heatmap...
Creating category vs subject pair heatmap (top and bottom categories)...


Error: 

In [7]:
# Create summary visualizations
print("Creating summary visualizations...")

# 1. Distribution of mean correlations per category
fig, ax = plt.subplots(figsize=(10, 6))
ax.hist(category_summary_df['mean_correlation'], bins=30, edgecolor='black', alpha=0.7)
ax.set_xlabel('Mean Correlation (across subject pairs)', fontsize=12)
ax.set_ylabel('Number of Categories', fontsize=12)
ax.set_title('Distribution of Mean Category Correlations', fontsize=14, fontweight='bold')
ax.axvline(category_summary_df['mean_correlation'].mean(), color='red', 
           linestyle='--', linewidth=2, label=f'Mean: {category_summary_df["mean_correlation"].mean():.3f}')
ax.legend()
plt.tight_layout()
plt.savefig(plots_dir / 'distribution_mean_category_correlations.png', dpi=150, bbox_inches='tight')
plt.close()

# 2. Distribution of mean correlations per subject pair
fig, ax = plt.subplots(figsize=(10, 6))
ax.hist(subject_pair_summary_df['mean_correlation'], bins=30, edgecolor='black', alpha=0.7)
ax.set_xlabel('Mean Correlation (across categories)', fontsize=12)
ax.set_ylabel('Number of Subject Pairs', fontsize=12)
ax.set_title('Distribution of Mean Subject Pair Correlations', fontsize=14, fontweight='bold')
ax.axvline(subject_pair_summary_df['mean_correlation'].mean(), color='red', 
           linestyle='--', linewidth=2, label=f'Mean: {subject_pair_summary_df["mean_correlation"].mean():.3f}')
ax.legend()
plt.tight_layout()
plt.savefig(plots_dir / 'distribution_mean_subject_pair_correlations.png', dpi=150, bbox_inches='tight')
plt.close()

# 3. Top categories bar plot
fig, ax = plt.subplots(figsize=(12, 8))
top_n = min(30, len(category_summary_df))
top_cats = category_summary_df.head(top_n)
ax.barh(range(len(top_cats)), top_cats['mean_correlation'], edgecolor='black')
ax.set_yticks(range(len(top_cats)))
ax.set_yticklabels(top_cats['category'], fontsize=8)
ax.set_xlabel('Mean Correlation', fontsize=12)
ax.set_title(f'Top {top_n} Categories by Mean Correlation', fontsize=14, fontweight='bold')
ax.invert_yaxis()
plt.tight_layout()
plt.savefig(plots_dir / 'top_categories_by_correlation.png', dpi=150, bbox_inches='tight')
plt.close()

# 3b. Bottom categories bar plot
fig, ax = plt.subplots(figsize=(12, 8))
bottom_n = min(30, len(category_summary_df))
bottom_cats = category_summary_df.tail(bottom_n)
ax.barh(range(len(bottom_cats)), bottom_cats['mean_correlation'], edgecolor='black', color='coral')
ax.set_yticks(range(len(bottom_cats)))
ax.set_yticklabels(bottom_cats['category'], fontsize=8)
ax.set_xlabel('Mean Correlation', fontsize=12)
ax.set_title(f'Bottom {bottom_n} Categories by Mean Correlation', fontsize=14, fontweight='bold')
ax.invert_yaxis()
plt.tight_layout()
plt.savefig(plots_dir / 'bottom_categories_by_correlation.png', dpi=150, bbox_inches='tight')
plt.close()

# 4. Overall correlation matrix (average across all categories)
print("Creating overall average correlation matrix...")
overall_corr_matrix = np.zeros((len(all_subjects), len(all_subjects)))
overall_corr_matrix[:] = np.nan

for i, subj1 in enumerate(all_subjects):
    for j, subj2 in enumerate(all_subjects):
        if subj1 == subj2:
            overall_corr_matrix[i, j] = 1.0
        else:
            # Get all correlations for this pair across all categories
            pair_corrs = correlations_df[
                (correlations_df['subject1'] == subj1) & 
                (correlations_df['subject2'] == subj2)
            ]['correlation'].values
            if len(pair_corrs) > 0:
                overall_corr_matrix[i, j] = np.nanmean(pair_corrs)

fig, ax = plt.subplots(figsize=(12, 10))
mask = np.isnan(overall_corr_matrix)
sns.heatmap(
    overall_corr_matrix,
    xticklabels=all_subjects,
    yticklabels=all_subjects,
    annot=True,
    fmt='.3f',
    cmap='coolwarm',
    center=0,
    vmin=-1,
    vmax=1,
    square=True,
    mask=mask,
    cbar_kws={'label': f'Mean {correlation_method.capitalize()} Correlation (across categories)'},
    ax=ax
)
ax.set_title('Overall Subject Correlation Matrix\n(Averaged across all categories)', 
             fontsize=14, fontweight='bold')
ax.set_xlabel('Subject', fontsize=12)
ax.set_ylabel('Subject', fontsize=12)
plt.tight_layout()
plt.savefig(plots_dir / 'overall_subject_correlation_matrix.png', dpi=150, bbox_inches='tight')
plt.close()

print("Summary visualizations created!")

Creating summary visualizations...
Creating overall average correlation matrix...
Summary visualizations created!
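
The overall matrix built in step 4 averages each subject pair's correlation across categories. A compact alternative is a pandas `pivot_table`, sketched below on toy rows. The helper name `overall_matrix` is hypothetical, and the sketch assumes the same long-format columns as `correlations_df`.

```python
import numpy as np
import pandas as pd


def overall_matrix(correlations_df, all_subjects):
    """Mean correlation per ordered subject pair across categories, diagonal set to 1.0."""
    off_diag = correlations_df[correlations_df["subject1"] != correlations_df["subject2"]]
    pivot = off_diag.pivot_table(
        index="subject1", columns="subject2", values="correlation", aggfunc="mean"
    ).reindex(index=all_subjects, columns=all_subjects)
    values = pivot.to_numpy(dtype=float, copy=True)
    np.fill_diagonal(values, 1.0)
    return pd.DataFrame(values, index=all_subjects, columns=all_subjects)


# Toy rows (values are made up); subject "c" never co-occurs with the others
toy_df = pd.DataFrame([
    ("x", "a", "b", 0.2), ("x", "b", "a", 0.2),
    ("y", "a", "b", 0.6), ("y", "b", "a", 0.6),
], columns=["category", "subject1", "subject2", "correlation"])
toy_overall = overall_matrix(toy_df, ["a", "b", "c"])
print(toy_overall)
```

Pairs with no shared category stay NaN after the `reindex`, which matches how the heatmap masks missing cells.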


## Save Results

In [8]:
# Save all results to CSV files
print("Saving results to CSV files...")

# Save full correlations DataFrame
correlations_df.to_csv(csv_dir / 'all_category_subject_correlations.csv', index=False)
print(f"Saved full correlations to: {csv_dir / 'all_category_subject_correlations.csv'}")

# Save category summary
category_summary_df.to_csv(csv_dir / 'category_summary_statistics.csv', index=False)
print(f"Saved category summary to: {csv_dir / 'category_summary_statistics.csv'}")

# Save subject pair summary
subject_pair_summary_df.to_csv(csv_dir / 'subject_pair_summary_statistics.csv', index=False)
print(f"Saved subject pair summary to: {csv_dir / 'subject_pair_summary_statistics.csv'}")

# Save overall correlation matrix
overall_corr_df = pd.DataFrame(
    overall_corr_matrix,
    index=all_subjects,
    columns=all_subjects
)
overall_corr_df.to_csv(csv_dir / 'overall_subject_correlation_matrix.csv')
print(f"Saved overall correlation matrix to: {csv_dir / 'overall_subject_correlation_matrix.csv'}")

print("\nAll results saved successfully!")
print(f"\nSummary:")
print(f"  Total categories analyzed: {len(valid_categories)}")
print(f"  Total subjects: {len(all_subjects)}")
print(f"  Total subject pairs: {len(subject_pair_summary_df)}")
print(f"  Mean correlation across all categories: {category_summary_df['mean_correlation'].mean():.4f}")
print(f"  Mean correlation across all subject pairs: {subject_pair_summary_df['mean_correlation'].mean():.4f}")

Saving results to CSV files...
Saved full correlations to: category_wise_subject_correlations_clip/csv/all_category_subject_correlations.csv
Saved category summary to: category_wise_subject_correlations_clip/csv/category_summary_statistics.csv
Saved subject pair summary to: category_wise_subject_correlations_clip/csv/subject_pair_summary_statistics.csv
Saved overall correlation matrix to: category_wise_subject_correlations_clip/csv/overall_subject_correlation_matrix.csv

All results saved successfully!

Summary:
  Total categories analyzed: 163
  Total subjects: 31
  Total subject pairs: 465
  Mean correlation across all categories: 0.6012
  Mean correlation across all subject pairs: 0.6208
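
When reloading the saved CSV files, note that subject IDs such as `00220001` have leading zeros, which pandas drops if it infers an integer column. The sketch below uses an in-memory CSV with a made-up row to show the difference; passing `dtype` for the ID columns is one way to keep them as strings.

```python
import io

import pandas as pd

# In-memory stand-in for all_category_subject_correlations.csv (the row is made up)
csv_text = (
    "category,subject1,subject2,correlation\n"
    "crayon,00220001,00230001,0.94\n"
)

naive = pd.read_csv(io.StringIO(csv_text))
typed = pd.read_csv(io.StringIO(csv_text), dtype={"subject1": str, "subject2": str})

print(naive.loc[0, "subject1"])  # 220001
print(typed.loc[0, "subject1"])  # 00220001
```

The same concern applies to the index and column labels of `overall_subject_correlation_matrix.csv`, where the IDs appear as row and column headers.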
