# Universal Population-Level Brain Vessel Analysis Pipeline
## Flexible Analysis Framework for Variable Feature Sets

**Purpose:** Comprehensive vessel analysis pipeline that automatically adapts to available features in any CSV format:
- **Dynamic Feature Detection** - Automatically identifies available metrics
- **Modular Analysis** - Runs appropriate analyses based on detected features
- **Robust Handling** - Gracefully handles missing regions, tortuosity, or other features
- **Comprehensive Output** - Generates publication-ready figures and statistics

**Key Features:**
- Works with any vessel feature CSV structure
- Fixed metadata file format (IXI_METADATA.xls)
- Automatic feature categorization (morphometric, topological, curvature)
- Conditional analysis execution based on available data
- Export-ready results for ISBI paper
  
**Date:** October 23, 2025

---

### Notebook Structure:
1. Setup and Configuration
2. Data Loading with Feature Detection
3. Automatic Feature Categorization
4. Data Quality Assessment
5. Descriptive Statistics
6. Age-Related Analysis
7. Sex-Based Analysis
8. Anthropometric Correlations (Height, Weight, BMI)
9. Multi-Center Analysis
10. Regional Analysis (if available)
11. Hemispheric Asymmetry (if available)
12. Advanced Analyses (ML, interactions, stratification)
13. Summary and Export for Paper
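
Sections marked "(if available)" only make sense when the columns they need are present. A minimal sketch of that gating is shown here; the registry `ANALYSIS_REQUIREMENTS` and the helper `runnable_analyses` are hypothetical names introduced for illustration, and the toy values are not real data.

```python
import pandas as pd

# Hypothetical registry: analysis name -> columns it needs.
# The names and requirements are illustrative, not taken from the notebook.
ANALYSIS_REQUIREMENTS = {
    "age": ["AGE"],
    "sex": ["SEX_ID"],
    "anthropometric": ["HEIGHT", "WEIGHT"],
    "regional": ["region"],
    "asymmetry": ["hemisphere"],
}


def runnable_analyses(df, requirements=ANALYSIS_REQUIREMENTS):
    """Return the analyses whose required columns are all present in df."""
    available = set(df.columns)
    return [name for name, cols in requirements.items()
            if all(c in available for c in cols)]


# Toy table: only AGE and SEX_ID are available besides one vessel feature.
toy = pd.DataFrame({
    "subject_id": ["IXI001", "IXI002"],
    "AGE": [30.0, 40.0],
    "SEX_ID": [1, 2],
    "total_length": [1000.0, 1100.0],
})
print(runnable_analyses(toy))
```

With this toy table only the age and sex analyses are selected; the anthropometric, regional, and asymmetry analyses are skipped because their columns are absent.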

---
## 1. Setup and Configuration

In [None]:
# Import required libraries
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from scipy import stats
from scipy.stats import pearsonr, spearmanr, linregress, ttest_ind, f_oneway
from scipy.stats import mannwhitneyu, kruskal, chi2_contingency
from pathlib import Path
import warnings
from glob import glob
import re
from datetime import datetime
import json

# Statistical modeling
import statsmodels.api as sm
import statsmodels.formula.api as smf
from statsmodels.stats.multitest import multipletests
try:
    from statsmodels.regression.mixed_linear_model import MixedLM
    MIXEDLM_AVAILABLE = True
except ImportError:
    MIXEDLM_AVAILABLE = False

# Machine learning
from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import cross_val_score, cross_val_predict, KFold
from sklearn.preprocessing import StandardScaler, PolynomialFeatures
from sklearn.linear_model import LinearRegression, Ridge
from sklearn.metrics import mean_absolute_error, r2_score, mean_squared_error

# Configure display and plotting
warnings.filterwarnings('ignore')
pd.set_option('display.max_columns', None)
pd.set_option('display.width', None)
pd.set_option('display.max_rows', 100)
plt.style.use('seaborn-v0_8-darkgrid')
sns.set_palette("husl")
np.random.seed(42)

# Publication-quality figure settings
plt.rcParams['figure.dpi'] = 300
plt.rcParams['savefig.dpi'] = 300
plt.rcParams['font.size'] = 10
plt.rcParams['axes.labelsize'] = 11
plt.rcParams['axes.titlesize'] = 12
plt.rcParams['xtick.labelsize'] = 9
plt.rcParams['ytick.labelsize'] = 9
plt.rcParams['legend.fontsize'] = 9

# Professional styling
plt.rcParams['axes.linewidth'] = 0.8
plt.rcParams['axes.spines.top'] = False
plt.rcParams['axes.spines.right'] = False
plt.rcParams['xtick.major.width'] = 0.8
plt.rcParams['ytick.major.width'] = 0.8
plt.rcParams['xtick.minor.visible'] = True
plt.rcParams['ytick.minor.visible'] = True


print("✓ Libraries imported successfully")
print(f"Analysis date: {datetime.now().strftime('%Y-%m-%d %H:%M')}")
print(f"Mixed-effects models available: {MIXEDLM_AVAILABLE}")

### 1.1 Configure Data Paths

**IMPORTANT:** Update these paths to match your data structure.

**Required Files:**
- **DEMOGRAPHICS_FILE**: Fixed format (IXI_METADATA.xls) with required columns:
  - IXI_ID, AGE, SEX_ID, HEIGHT, WEIGHT, ETHNIC_ID, etc.
- **FEATURES_CSV**: Any vessel feature CSV with columns:
  - Must have 'subject_id' column matching IXI format (e.g., 'IXI001')
  - Can contain any vessel metrics (morphometric, topological, curvature)
  - Optional 'region' column for regional analysis
  - Optional hemisphere indicators for asymmetry analysis
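
The requirements above can be checked before any analysis runs. The sketch below is one possible pre-flight check, not part of the pipeline itself: `check_features_table` is a hypothetical helper, and the ID pattern assumes three-digit IXI IDs, mirroring `SUBJECT_ID_REGEX` in the configuration cells.

```python
import re
import pandas as pd

# Assumption: IDs are 'IXI' followed by exactly three digits (e.g. 'IXI001').
ID_PATTERN = re.compile(r"IXI\d{3}")


def check_features_table(features):
    """Return a list of problems found in a features table (hypothetical helper)."""
    if "subject_id" not in features.columns:
        return ["missing 'subject_id' column"]
    problems = []
    ids = features["subject_id"].astype(str)
    is_valid = ids.map(lambda s: bool(ID_PATTERN.fullmatch(s)))
    bad = ids[~is_valid]
    if len(bad) > 0:
        problems.append(
            f"{len(bad)} subject_id value(s) not in IXI format: {sorted(bad.unique())}"
        )
    # At least one numeric column is needed for any vessel metric analysis.
    numeric = features.drop(columns=["subject_id"]).select_dtypes(include="number")
    if numeric.shape[1] == 0:
        problems.append("no numeric feature columns")
    return problems


# Toy table with one malformed ID.
toy_features = pd.DataFrame({
    "subject_id": ["IXI001", "ixi2"],
    "total_length": [1.0, 2.0],
})
print(check_features_table(toy_features))
```

An empty list means the table satisfies the checks; the optional `region` and hemisphere columns are deliberately not required here.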

In [None]:
VESSEL_METRIC_FLAG = True  # True: load per-subject files from VESSEL_DATA_DIR; False: load a single FEATURES_CSV

### General CSV

In [None]:
if not VESSEL_METRIC_FLAG:
    # ============================================================================
    # CONFIGURATION: Update these paths for your local setup
    # ============================================================================

    # Path to demographics file (FIXED FORMAT - required)
    DEMOGRAPHICS_FILE = "/path/to/IXI_METADATA.xls"

    # Path to vessel features CSV (FLEXIBLE FORMAT - any structure)
    # This can be:
    #   - A single aggregated CSV with one row per subject
    #   - A regional CSV with multiple rows per subject
    #   - Any CSV with vessel features and a 'subject_id' column
    FEATURES_CSV = "/path/to/vessel_features.csv"

    # Output directory for results
    OUTPUT_DIR = "outputs_universal_analysis/"

    # Subject ID pattern (for extracting IDs from filenames if needed)
    SUBJECT_ID_REGEX = r"(IXI\d{3})"

    # ============================================================================

    # Create output directories
    Path(OUTPUT_DIR).mkdir(parents=True, exist_ok=True)
    FIGURES_DIR = Path(OUTPUT_DIR) / 'figures'
    FIGURES_DIR.mkdir(parents=True, exist_ok=True)
    TABLES_DIR = Path(OUTPUT_DIR) / 'tables'
    TABLES_DIR.mkdir(parents=True, exist_ok=True)

    print("✓ Configuration set")
    print(f"  Demographics: {DEMOGRAPHICS_FILE}")
    print(f"  Features CSV: {FEATURES_CSV}")
    print(f"  Output:       {OUTPUT_DIR}")
    print(f"  Figures:      {FIGURES_DIR}")
    print(f"  Tables:       {TABLES_DIR}")
else:
    print("✓ Vessel metrics analysis mode enabled")

### VESSEL METRICS

In [None]:
if VESSEL_METRIC_FLAG:
    
    # ============================================================================
    # CONFIGURATION: Update these paths for your local setup
    # ============================================================================

    # Path to demographics file (FIXED FORMAT - required)
    DEMOGRAPHICS_FILE = "/home/falcetta/ISBI2025/METADATA/IXI_METADATA.xls"

    # Path to vessel segmentation data directory.
    # Keep exactly one assignment active; the commented alternatives point to other feature sets.
    # VESSEL_DATA_DIR = "/home/falcetta/ISBI2025/VESSELVIO_FEATURES"
    # VESSEL_DATA_DIR = "/home/falcetta/ISBI2025/VESSELEXP_FEATURES"
    VESSEL_DATA_DIR = "/home/falcetta/ISBI2025/IXI_EXTRACTED_FEATURES"

    # File naming patterns (what to look for in VESSEL_DATA_DIR)
    # Option 1: Use region summary files (one CSV per subject with regional data)
    REGION_SUMMARY_PATTERN = "*region_summary*.csv"

    # Option 2: Use component files (optional, for component-level analysis)
    COMPONENTS_PATTERN = "*all_components*.csv"

    # Option 3: Use any custom CSV pattern that contains vessel features
    # CUSTOM_PATTERN = "*vessel_features*.csv"

    # Which file type to use as primary features?
    # Options: 'region_summary', 'components', 'custom', 'VesselExpress', or 'VesselVio'.
    # The type is inferred from the directory name and falls back to 'components'.
    if 'VESSELEXP' in VESSEL_DATA_DIR:
        FEATURE_FILE_TYPE = 'VesselExpress'
    elif 'VESSELVIO' in VESSEL_DATA_DIR:
        FEATURE_FILE_TYPE = 'VesselVio'
    else:
        FEATURE_FILE_TYPE = 'components'  # or 'region_summary'

    # Output directory for results
    OUTPUT_DIR = f"outputs_universal_analysis_VESSEL_METRICS_{FEATURE_FILE_TYPE}_REMOVE_IOP/"

    # Subject ID pattern (for extracting IDs from filenames)
    SUBJECT_ID_REGEX = r"(IXI\d{3})"

    # ============================================================================

    # Create output directories
    Path(OUTPUT_DIR).mkdir(parents=True, exist_ok=True)
    FIGURES_DIR = Path(OUTPUT_DIR) / 'figures'
    FIGURES_DIR.mkdir(parents=True, exist_ok=True)
    TABLES_DIR = Path(OUTPUT_DIR) / 'tables'
    TABLES_DIR.mkdir(parents=True, exist_ok=True)

    print("✓ Configuration set")
    print(f"  Demographics: {DEMOGRAPHICS_FILE}")
    print(f"  Vessel data:  {VESSEL_DATA_DIR}")
    print(f"  File type:    {FEATURE_FILE_TYPE}")
    print(f"  Output:       {OUTPUT_DIR}")
    print(f"  Figures:      {FIGURES_DIR}")
    print(f"  Tables:       {TABLES_DIR}")
else:
    print("✓ Universal analysis mode enabled")

---
## 2. Data Loading with Automatic Feature Detection

### 2.1 Load Demographics (Fixed Format)

In [None]:
# Load demographics
print("Loading demographics...")
try:
    demographics = pd.read_excel(DEMOGRAPHICS_FILE)
    print(f"✓ Loaded demographics: {len(demographics)} subjects")
except Exception as e:
    print(f"❌ Error loading demographics: {e}")
    raise

# Standardize subject ID column
if 'IXI_ID' in demographics.columns:
    demographics['subject_id'] = 'IXI' + demographics['IXI_ID'].astype(str).str.zfill(3)
elif 'subject_id' not in demographics.columns:
    raise ValueError("Demographics must have 'IXI_ID' or 'subject_id' column")

# Data quality: Remove implausible height/weight values
if 'HEIGHT' in demographics.columns:
    demographics.loc[(demographics['HEIGHT'] < 50) | (demographics['HEIGHT'] > 250), 'HEIGHT'] = np.nan
if 'WEIGHT' in demographics.columns:
    demographics.loc[(demographics['WEIGHT'] < 20) | (demographics['WEIGHT'] > 300), 'WEIGHT'] = np.nan

# Calculate BMI if height and weight are available
if 'HEIGHT' in demographics.columns and 'WEIGHT' in demographics.columns:
    demographics['BMI'] = demographics['WEIGHT'] / (demographics['HEIGHT'] / 100) ** 2
    demographics.loc[(demographics['BMI'] < 10) | (demographics['BMI'] > 60), 'BMI'] = np.nan
    print(f"✓ BMI calculated for subjects with valid height/weight")

# Rename SEX_ID (1=m, 2=f) to 'SEX_ID'
if 'SEX_ID (1=m, 2=f)' in demographics.columns:
    demographics.rename(columns={'SEX_ID (1=m, 2=f)': 'SEX_ID'}, inplace=True)
    print("✓ Renamed 'SEX_ID (1=m, 2=f)' to 'SEX_ID'")

# Identify available demographic variables
DEMOGRAPHIC_VARS = [col for col in demographics.columns if col not in ['subject_id', 'IXI_ID', 'DOB', 'STUDY_DATE', 'DATE_AVAILABLE']]

# Remove DOB, DATE_AVAILABLE, and STUDY_DATE from the demographics table
demographics = demographics.drop(columns=['DOB', 'DATE_AVAILABLE', 'STUDY_DATE'], errors='ignore')
print(f"\nAvailable demographic variables: {DEMOGRAPHIC_VARS}")
print(f"\nFirst 3 subjects:")
display(demographics.head(3))

DEMOGRAPHIC_VARS

In [None]:
demographics

### 2.2 Load Vessel Features (Flexible Format)

In [None]:
if not VESSEL_METRIC_FLAG:
    #Load vessel features CSV
    print("\nLoading vessel features...")
    try:
        features_df = pd.read_csv(FEATURES_CSV)
        print(f"✓ Loaded features CSV: {len(features_df)} rows, {len(features_df.columns)} columns")
    except Exception as e:
        print(f"❌ Error loading features CSV: {e}")
        raise

    # Verify subject_id column exists
    if 'subject_id' not in features_df.columns:
        raise ValueError("Features CSV must contain 'subject_id' column")

    # Check if this is regional data (multiple rows per subject)
    rows_per_subject = features_df.groupby('subject_id').size()
    IS_REGIONAL_DATA = (rows_per_subject > 1).any()

    print(f"\nData structure:")
    print(f"  Unique subjects: {features_df['subject_id'].nunique()}")
    print(f"  Regional data: {IS_REGIONAL_DATA}")
    if IS_REGIONAL_DATA:
        print(f"  Rows per subject: {rows_per_subject.describe()}")

    # Display first rows
    print(f"\nFirst 5 rows of features:")
    display(features_df.head())

    print(f"\nColumn names:")
    print(features_df.columns.tolist())
else:
    print("\nVessel metrics analysis mode enabled; skipping generic features loading.")

### VESSEL METRICS

In [None]:
if VESSEL_METRIC_FLAG:
    # Load vessel features from directory
    print("\nLoading vessel features from directory...")

    # Select the appropriate file pattern
    if FEATURE_FILE_TYPE == 'region_summary':
        pattern = REGION_SUMMARY_PATTERN
    elif FEATURE_FILE_TYPE == 'components':
        pattern = COMPONENTS_PATTERN
    elif FEATURE_FILE_TYPE == 'VesselExpress':
        pattern = "*VesselExpress*.csv"
    elif FEATURE_FILE_TYPE == 'VesselVio':
        pattern = "*VesselVio*.csv"
    else:
        pattern = CUSTOM_PATTERN if 'CUSTOM_PATTERN' in globals() else REGION_SUMMARY_PATTERN

    # Find all matching files
    feature_files = glob(str(Path(VESSEL_DATA_DIR) / "**" / pattern), recursive=True)
    print(f"Found {len(feature_files)} files matching pattern '{pattern}'")

    if len(feature_files) == 0:
        raise ValueError(f"No files found matching pattern {pattern} in {VESSEL_DATA_DIR}")

    # Load all files
    feature_data_list = []
    for file in feature_files:
        # Extract subject ID from filename
        match = re.search(SUBJECT_ID_REGEX, str(file))
        if match:
            subject_id = match.group(1)
            try:
                # VesselExpress/VesselVio files are read as semicolon-delimited; other formats use commas
                sep = ';' if FEATURE_FILE_TYPE in ('VesselExpress', 'VesselVio') else ','
                df_temp = pd.read_csv(file, sep=sep)
                df_temp['subject_id'] = subject_id
                feature_data_list.append(df_temp)
            except Exception as e:
                print(f"  ⚠️  Error loading {file}: {e}")
                continue

    if len(feature_data_list) == 0:
        raise ValueError("No valid feature files could be loaded")

    # Concatenate all data
    features_df = pd.concat(feature_data_list, ignore_index=True)
    print(f"✓ Loaded features from {len(feature_data_list)} subjects")

    # Check if this is regional data (multiple rows per subject)
    rows_per_subject = features_df.groupby('subject_id').size()
    IS_REGIONAL_DATA = (rows_per_subject > 1).any()

    print(f"\nData structure:")
    print(f"  Total rows: {len(features_df)}")
    print(f"  Unique subjects: {features_df['subject_id'].nunique()}")
    print(f"  Regional data: {IS_REGIONAL_DATA}")
    if IS_REGIONAL_DATA:
        print(f"  Rows per subject: min={rows_per_subject.min()}, max={rows_per_subject.max()}, mean={rows_per_subject.mean():.1f}")

    # Display first rows
    print(f"\nFirst 5 rows of features:")
    display(features_df.head())

    print(f"\nColumn names ({len(features_df.columns)} columns):")
    print(features_df.columns.tolist())
else:
    print("✓ Universal analysis mode enabled")

In [None]:
#features_df.to_csv(Path(OUTPUT_DIR) / "loaded_features_preview.csv", index=False)
#print(f'\nSaved preview of loaded features to {Path(OUTPUT_DIR) / "loaded_features_preview.csv"}')

In [None]:
# List all features whose name contains 'curv' (curvature metrics)
curvature_features = [col for col in features_df.columns if 'curv' in col.lower()]
print(f"\nFeatures containing 'curv': {curvature_features}")

# Aggregate root_mean_curvature columns into a single per-row average.
# The '_avg' output column is skipped so that re-running this cell does not drop it.
root_mean_curvature_cols = [col for col in features_df.columns
                            if 'root_mean_curvature' in col.lower() and not col.endswith('_avg')]
if root_mean_curvature_cols:
    features_df['root_mean_curvature_avg'] = features_df[root_mean_curvature_cols].mean(axis=1)
    print("✓ Aggregated root_mean_curvature into 'root_mean_curvature_avg'")
    # Drop the individual columns now that they have been averaged
    features_df.drop(columns=root_mean_curvature_cols, inplace=True)

# Aggregate mean_square_curvature columns the same way
mean_squared_curvature_cols = [col for col in features_df.columns
                               if 'mean_square_curvature' in col.lower() and not col.endswith('_avg')]
if mean_squared_curvature_cols:
    features_df['mean_squared_curvature_avg'] = features_df[mean_squared_curvature_cols].mean(axis=1)
    print("✓ Aggregated mean_square_curvature into 'mean_squared_curvature_avg'")
    # Drop the individual columns now that they have been averaged
    features_df.drop(columns=mean_squared_curvature_cols, inplace=True)


| Present In           | Attributes                                                                 |
|-----------------------|----------------------------------------------------------------------------|
| **Both**              | region_label, total_length, num_bifurcations, volume, num_loops, num_abnormal_degree_nodes, subject_id, root_mean_curvature_avg, mean_squared_curvature_avg |
| **Only in COMPONENT** | bifurcation_density, fractal_dimension, lacunarity                         |
| **Only in REGION SUMMARY** | num_components, betti_0, betti_1, betti_2                                 |
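
A table like the one above can be derived rather than maintained by hand. The following is a minimal sketch, assuming the component-level and region-summary data are each available as a DataFrame; `compare_columns` is a hypothetical helper and the toy frames carry only a few of the listed column names.

```python
import pandas as pd


def compare_columns(component_df, region_df):
    """Split column names into shared and table-specific groups."""
    comp_cols, region_cols = set(component_df.columns), set(region_df.columns)
    return {
        "both": sorted(comp_cols & region_cols),
        "only_component": sorted(comp_cols - region_cols),
        "only_region_summary": sorted(region_cols - comp_cols),
    }


# Toy frames (no rows) using a subset of the attribute names from the table.
component_df = pd.DataFrame(columns=["subject_id", "volume", "fractal_dimension", "lacunarity"])
region_df = pd.DataFrame(columns=["subject_id", "volume", "betti_0", "betti_1"])

for group, cols in compare_columns(component_df, region_df).items():
    print(f"{group}: {cols}")
```

Running the comparison on the real tables after loading would show whether the attribute lists have drifted from what is documented here.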


## REMOVE IOP

In [None]:
import os
import pandas as pd

# # Directory you listed earlier (adjust if different)
# ixi_dir = '/data/galati/brain_data/IXI_FINAL/fold1/imagesTs'

# # Build list (or reuse existing IXI_LIST variable if already defined)
# try:
#     IXI_LIST = os.listdir(ixi_dir)
# except Exception as e:
#     IXI_LIST = globals().get('IXI_LIST', [])
#     print(f"Warning: could not list {ixi_dir}: {e}. Using existing IXI_LIST (len={len(IXI_LIST)})")

# # Prepare DataFrame and save
# out_dir = '/home/falcetta/ISBI2025/METADATA'
# os.makedirs(out_dir, exist_ok=True)
# out_path = os.path.join(out_dir, 'IXI_LIST.csv')

# df = pd.DataFrame({'filename': IXI_LIST})
# df.to_csv(out_path, header=True)
# print(f"Saved {len(df)} rows to {out_path}")

#Read csv from /home/falcetta/ISBI2025/METADATA/IXI_LIST.csv

import pandas as pd

# Define the path to your CSV file
file_path = '/home/falcetta/ISBI2025/METADATA/IXI_LIST.csv'

try:
    # Read the CSV file into a pandas DataFrame
    COMPLETE_IXI = pd.read_csv(file_path)
    
    # Display the first 5 rows of the DataFrame
    print("Successfully loaded the CSV file. Here are the first rows:")
    print(COMPLETE_IXI.head())

except FileNotFoundError:
    print(f"Error: The file was not found at the path: {file_path}")
    print("Please make sure the file path is correct and the file exists.")
except Exception as e:
    print(f"An error occurred: {e}")


COMPLETE_IXI

In [None]:
# # --- Step 2: Function to extract site from a full filename ---
# def identify_site(filename):
#     """Identifies the IXI imaging site (Guys, HH, or IOP) from a filename string."""
#     match = re.search(r'(Guys|HH|IOP)', filename)
#     if match:
#         return match.group(1)
#     return 'Unknown'


# # --- Step 3: Main function to find the site from a short ID ---
# def get_site_from_id(subject_id, COMPLETE_IXI=COMPLETE_IXI):
#     """
#     Finds the site for a short subject ID (e.g., 'IXI002') by searching a DataFrame.
    
#     Args:
#         subject_id (str): The short ID to search for (e.g., 'IXI002').
#         df (pd.DataFrame): The DataFrame containing a 'filename' column.

#     Returns:
#         str: The corresponding site name or a message if not found.
#     """
#     df = COMPLETE_IXI
#     # Find all rows where the filename starts with the subject_id
#     # This is more precise than 'contains' as it avoids matching 'IXI012' with 'IXI0123'
#     matching_rows = df[df['filename'].str.startswith(subject_id)]
    
#     # If we found at least one match...
#     if not matching_rows.empty:
#         # Get the full filename from the first result
#         full_filename = matching_rows.iloc[0]['filename']
#         # Extract and return the site from that filename
#         return identify_site(full_filename)
#     else:
#         # If no filenames matched, return a 'not found' message
#         return f"Subject ID '{subject_id}' not found in the DataFrame."

# # Add site information
# features_df['site'] = features_df['subject_id'].apply(get_site_from_id)

# # Remove 'Unknown' site entries
# features_df = features_df[features_df['site'] != 'Unknown']
# site_counts = features_df['site'].value_counts()
# features_df
    

In [None]:
# --- Optimized version ---
import re

# Create a lookup dictionary once (instead of searching repeatedly)
def create_site_lookup(df):
    """Creates a fast lookup dictionary mapping subject_id to site."""
    site_lookup = {}
    pattern = re.compile(r'IXI\d+')  # Match subject IDs like IXI002, IXI123, etc.
    
    for filename in df['filename'].unique():
        # Extract subject ID from filename
        match = pattern.match(filename)
        if match:
            subject_id = match.group()
            # Extract site from filename
            site_match = re.search(r'(Guys|HH|IOP)', filename)
            if site_match and subject_id not in site_lookup:
                site_lookup[subject_id] = site_match.group(1)
    
    return site_lookup

# Build the lookup dictionary once
site_lookup = create_site_lookup(COMPLETE_IXI)

# Apply the lookup (vectorized operation)
features_df['site'] = features_df['subject_id'].map(site_lookup).fillna('Unknown')

# Remove 'Unknown' site entries
features_df = features_df[features_df['site'] != 'Unknown']
site_counts = features_df['site'].value_counts()

In [None]:
features_df

In [None]:
# REMOVE IOP
features_df = features_df[features_df['site'] != 'IOP']
# REMOVE GUYS
#features_df = features_df[features_df['site'] != 'Guys']
# REMOVE HH
#features_df = features_df[features_df['site'] != 'HH']

# JUST IOP
#features_df = features_df[features_df['site'] == 'IOP']
# JUST GUYS
#features_df = features_df[features_df['site'] == 'Guys']
# JUST HH
#features_df = features_df[features_df['site'] == 'HH']
site_counts = features_df['site'].value_counts()

features_df

In [None]:
site_counts

---
## 3. Automatic Feature Categorization and Detection

In [None]:
# Feature categorization keywords
MORPHOMETRIC_KEYWORDS = ['length', 'volume', 'area', 'diameter', 'radius', 'thickness', 'density', 'count', 'num_']
TOPOLOGICAL_KEYWORDS = ['fractal', 'lacunarity', 'betti', 'bifurcation', 'loop', 'component', 'branch', 'node', 'degree']
CURVATURE_KEYWORDS = ['curvature', 'tortuosity', 'sinuosity', 'curl']
# 'area' is omitted here: it is already a morphometric keyword, which is checked first
REGION_KEYWORDS = ['region', 'territory', 'hemisphere', 'lobe']

def categorize_feature(col_name):
    """Categorize a feature column based on its name."""
    col_lower = col_name.lower()
    
    if any(kw in col_lower for kw in CURVATURE_KEYWORDS):
        return 'curvature'
    elif any(kw in col_lower for kw in TOPOLOGICAL_KEYWORDS):
        return 'topological'
    elif any(kw in col_lower for kw in MORPHOMETRIC_KEYWORDS):
        return 'morphometric'
    elif any(kw in col_lower for kw in REGION_KEYWORDS):
        return 'region_identifier'
    else:
        return 'other'

# Identify feature columns (exclude subject_id and known identifiers)
EXCLUDE_COLS = ['subject_id', 'region', 'region_id', 'hemisphere', 'territory', 'site', 'label']
feature_columns = [col for col in features_df.columns if col not in EXCLUDE_COLS]

# Categorize all features
feature_categories = {}
for col in feature_columns:
    # Only include numeric columns as features
    if pd.api.types.is_numeric_dtype(features_df[col]):
        feature_categories[col] = categorize_feature(col)

# Organize features by category
MORPHOMETRIC_FEATURES = [f for f, cat in feature_categories.items() if cat == 'morphometric']
TOPOLOGICAL_FEATURES = [f for f, cat in feature_categories.items() if cat == 'topological']
CURVATURE_FEATURES = [f for f, cat in feature_categories.items() if cat == 'curvature']
OTHER_FEATURES = [f for f, cat in feature_categories.items() if cat == 'other']
ALL_FEATURES = list(feature_categories.keys())
# Remove region_label from ALL_FEATURES if present
if 'region_label' in ALL_FEATURES:
    ALL_FEATURES.remove('region_label')

# Check for regional indicators
HAS_REGIONS = 'region' in features_df.columns or 'region_id' in features_df.columns or any('region' in str(col).lower() for col in features_df.columns)
HAS_HEMISPHERE = 'hemisphere' in features_df.columns or any('hemisphere' in str(col).lower() for col in features_df.columns) 

print("\n" + "="*80)
print("FEATURE DETECTION SUMMARY")
print("="*80)
print(f"\nTotal features detected: {len(ALL_FEATURES)}")
print(f"\nFeature categories:")
print(f"  Morphometric features: {len(MORPHOMETRIC_FEATURES)}")
if len(MORPHOMETRIC_FEATURES) > 0:
    print(f"    Examples: {MORPHOMETRIC_FEATURES[:5]}")
print(f"  Topological features:  {len(TOPOLOGICAL_FEATURES)}")
if len(TOPOLOGICAL_FEATURES) > 0:
    print(f"    Examples: {TOPOLOGICAL_FEATURES[:5]}")
print(f"  Curvature features:    {len(CURVATURE_FEATURES)}")
if len(CURVATURE_FEATURES) > 0:
    print(f"    Examples: {CURVATURE_FEATURES[:5]}")
print(f"  Other numeric features: {len(OTHER_FEATURES)}")
if len(OTHER_FEATURES) > 0:
    print(f"    Examples: {OTHER_FEATURES[:5]}")

print(f"\nData structure flags:")
print(f"  Regional data:     {IS_REGIONAL_DATA}")
print(f"  Has regions:       {HAS_REGIONS}")
print(f"  Has hemispheres:   {HAS_HEMISPHERE}")

# Save feature catalog
feature_catalog = pd.DataFrame([
    {'feature': f, 'category': cat, 'dtype': features_df[f].dtype}
    for f, cat in feature_categories.items()
])
feature_catalog.to_csv(TABLES_DIR / 'feature_catalog.csv', index=False)
print(f"\n✓ Feature catalog saved to {TABLES_DIR / 'feature_catalog.csv'}")

### 3.1 Aggregate Regional Data (if applicable)

In [None]:
# If regional data, create whole-brain aggregates
if IS_REGIONAL_DATA:
    print("\nCreating whole-brain aggregates from regional data...")
    
    # Save original regional data
    regional_df = features_df.copy()
    
    # Aggregate strategies by feature type
    agg_dict = {}
    for feature in ALL_FEATURES:
        # Sum for extensive properties (length, volume, counts)
        if any(kw in feature.lower() for kw in ['length', 'volume', 'count', 'num_', 'area']):
            agg_dict[feature] = 'sum'
        # Mean for intensive properties (density, curvature, etc.)
        else:
            agg_dict[feature] = 'mean'
    
    # Create aggregated dataset
    features_df = regional_df.groupby('subject_id').agg(agg_dict).reset_index()
    
    print(f"✓ Aggregated to whole-brain features")
    print(f"  Subjects: {len(features_df)}")
    print(f"  Features: {len(ALL_FEATURES)}")
else:
    print("\nData already at subject level (one row per subject)")
    regional_df = None

In [None]:
features_df

### 3.2 Merge Demographics with Features

In [None]:
# Merge demographics with vessel features
print("\nMerging demographics with vessel features...")
df = demographics.merge(features_df, on='subject_id', how='inner')

print(f"✓ Merged dataset:")
print(f"  Total subjects: {len(df)}")
print(f"  Vessel features: {len(ALL_FEATURES)}")
print(f"  Demographic variables: {len(DEMOGRAPHIC_VARS)}")
print(f"  Total columns: {len(df.columns)}")

# Check for missing data
missing_summary = df[ALL_FEATURES + DEMOGRAPHIC_VARS].isnull().sum()
missing_summary = missing_summary[missing_summary > 0].sort_values(ascending=False)
if len(missing_summary) > 0:
    print(f"\nVariables with missing data:")
    print(missing_summary)
else:
    print(f"\n✓ No missing data detected")

# Display merged data
print(f"\nFirst 3 subjects (merged data):")
display(df.head(3))

---
## 4. Data Quality Assessment

In [None]:
# Quality checks
print("\n" + "="*80)
print("DATA QUALITY ASSESSMENT")
print("="*80)

# 1. Check for duplicates
n_duplicates = df.duplicated(subset=['subject_id']).sum()
print(f"\n1. Duplicate subjects: {n_duplicates}")
if n_duplicates > 0:
    print("   ⚠️  Warning: Duplicate subject IDs detected")
    df = df.drop_duplicates(subset=['subject_id'], keep='first')
    print(f"   Removed duplicates, {len(df)} subjects remaining")

# 2. Check age range
if 'AGE' in df.columns:
    print(f"\n2. Age distribution:")
    print(f"   Range: {df['AGE'].min():.1f} - {df['AGE'].max():.1f} years")
    print(f"   Mean ± SD: {df['AGE'].mean():.1f} ± {df['AGE'].std():.1f} years")
    print(f"   Median: {df['AGE'].median():.1f} years")

# 3. Check sex distribution
if 'SEX_ID' in df.columns:
    print(f"\n3. Sex distribution:")
    sex_counts = df['SEX_ID'].value_counts()
    for sex_id, count in sex_counts.items():
        sex_label = 'Male' if sex_id == 1 else 'Female' if sex_id == 2 else f'Unknown ({sex_id})'
        print(f"   {sex_label}: {count} ({100*count/len(df):.1f}%)")

# 4. Feature distributions
print(f"\n4. Feature value ranges:")
feature_stats = df[ALL_FEATURES].describe()
print(f"   Min values: {(feature_stats.loc['min'] == 0).sum()} features have a minimum of 0")
print(f"   Max values: Largest = {feature_stats.loc['max'].max():.2e}")
print(f"   Missing: {df[ALL_FEATURES].isnull().any().sum()} features have missing values")

# 5. Outlier detection (IQR method)
def detect_outliers_iqr(series):
    Q1 = series.quantile(0.25)
    Q3 = series.quantile(0.75)
    IQR = Q3 - Q1
    lower = Q1 - 3 * IQR
    upper = Q3 + 3 * IQR
    return ((series < lower) | (series > upper)).sum()

outliers_per_feature = df[ALL_FEATURES].apply(detect_outliers_iqr)
features_with_outliers = outliers_per_feature[outliers_per_feature > 0]
print(f"\n5. Outliers (3×IQR method):")
print(f"   Features with outliers: {len(features_with_outliers)}")
if len(features_with_outliers) > 0:
    print(f"   Top 5 features by outlier count:")
    print(features_with_outliers.sort_values(ascending=False).head())

print(f"\n✓ Quality assessment complete")

---
## 5. Descriptive Statistics

In [None]:
# Comprehensive descriptive statistics
print("\n" + "="*80)
print("DESCRIPTIVE STATISTICS")
print("="*80)

# Demographics summary
print("\nDemographic Variables:")
demo_summary = df[DEMOGRAPHIC_VARS].describe()
display(demo_summary)

# Features summary (by category)
if len(MORPHOMETRIC_FEATURES) > 0:
    print(f"\nMorphometric Features (n={len(MORPHOMETRIC_FEATURES)}):")
    display(df[MORPHOMETRIC_FEATURES].describe())

if len(TOPOLOGICAL_FEATURES) > 0:
    print(f"\nTopological Features (n={len(TOPOLOGICAL_FEATURES)}):")
    display(df[TOPOLOGICAL_FEATURES].describe())

if len(CURVATURE_FEATURES) > 0:
    print(f"\nCurvature Features (n={len(CURVATURE_FEATURES)}):")
    display(df[CURVATURE_FEATURES].describe())

# Save descriptive statistics
all_stats = df[DEMOGRAPHIC_VARS + ALL_FEATURES].describe()
all_stats.to_csv(TABLES_DIR / 'descriptive_statistics.csv')
print(f"\n✓ Descriptive statistics saved to {TABLES_DIR / 'descriptive_statistics.csv'}")

### 5.1 Visualize Demographic Distributions

In [None]:
# Create demographic distribution plots
fig, axes = plt.subplots(2, 2, figsize=(12, 10))
fig.suptitle('Demographic Distributions', fontsize=14, fontweight='bold')

# Age distribution
if 'AGE' in df.columns:
    ax = axes[0, 0]
    ax.hist(df['AGE'].dropna(), bins=30, edgecolor='black', alpha=0.7)
    ax.axvline(df['AGE'].mean(), color='red', linestyle='--', label=f'Mean: {df["AGE"].mean():.1f}')
    ax.set_xlabel('Age (years)')
    ax.set_ylabel('Frequency')
    ax.set_title('Age Distribution')
    ax.legend()
    ax.grid(True, alpha=0.3)

# Sex distribution
if 'SEX_ID' in df.columns:
    ax = axes[0, 1]
    sex_counts = df['SEX_ID'].value_counts()
    sex_labels = ['Male' if x==1 else 'Female' for x in sex_counts.index]
    bars = ax.bar(sex_labels, sex_counts.values, edgecolor='black', alpha=0.7)
    ax.set_ylabel('Count')
    ax.set_title('Sex Distribution')
    ax.grid(True, alpha=0.3, axis='y')
    # Add value labels on bars
    for bar in bars:
        height = bar.get_height()
        ax.text(bar.get_x() + bar.get_width()/2., height,
                f'{int(height)}\n({100*height/len(df):.1f}%)',
                ha='center', va='bottom')

# BMI distribution
if 'BMI' in df.columns:
    ax = axes[1, 0]
    bmi_data = df['BMI'].dropna()
    ax.hist(bmi_data, bins=30, edgecolor='black', alpha=0.7)
    ax.axvline(bmi_data.mean(), color='red', linestyle='--', label=f'Mean: {bmi_data.mean():.1f}')
    ax.set_xlabel('BMI (kg/m²)')
    ax.set_ylabel('Frequency')
    ax.set_title('BMI Distribution')
    ax.legend()
    ax.grid(True, alpha=0.3)

# Age by sex
if 'AGE' in df.columns and 'SEX_ID' in df.columns:
    ax = axes[1, 1]
    for sex_id in sorted(df['SEX_ID'].dropna().unique()):
        sex_label = 'Male' if sex_id == 1 else 'Female'
        age_data = df[df['SEX_ID'] == sex_id]['AGE'].dropna()
        ax.hist(age_data, bins=20, alpha=0.6, label=sex_label, edgecolor='black')
    ax.set_xlabel('Age (years)')
    ax.set_ylabel('Frequency')
    ax.set_title('Age Distribution by Sex')
    ax.legend()
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig(FIGURES_DIR / 'demographic_distributions.png', dpi=300, bbox_inches='tight')
plt.show()
print(f"✓ Demographic distributions saved to {FIGURES_DIR / 'demographic_distributions.png'}")

In [None]:
# Create categorical demographic plots (ethnicity, marital status, occupation, qualification)
fig, axes = plt.subplots(2, 2, figsize=(12, 10))
fig.suptitle('Categorical Demographic Distributions', fontsize=14, fontweight='bold')


if 'ETHNIC_ID' in df.columns:
    ax = axes[0, 0]
    ethnic_counts = df['ETHNIC_ID'].value_counts()
    ethnic_labels = [str(x) for x in ethnic_counts.index]
    bars = ax.bar(ethnic_labels, ethnic_counts.values, edgecolor='black', alpha=0.7)
    ax.set_ylabel('Count')
    ax.set_title('Ethnic Distribution')
    ax.grid(True, alpha=0.3, axis='y')
    # Add value labels on bars
    for bar in bars:
        height = bar.get_height()
        ax.text(bar.get_x() + bar.get_width()/2., height,
                f'{int(height)}\n({100*height/len(df):.1f}%)',
                ha='center', va='bottom')
    

if 'MARITAL_ID' in df.columns:
    ax = axes[0, 1]
    marital_counts = df['MARITAL_ID'].value_counts()
    marital_labels = [str(x) for x in marital_counts.index]
    bars = ax.bar(marital_labels, marital_counts.values, edgecolor='black', alpha=0.7)
    ax.set_ylabel('Count')
    ax.set_title('Marital Status Distribution')
    ax.grid(True, alpha=0.3, axis='y')
    # Add value labels on bars
    for bar in bars:
        height = bar.get_height()
        ax.text(bar.get_x() + bar.get_width()/2., height,
                f'{int(height)}\n({100*height/len(df):.1f}%)',
                ha='center', va='bottom')
if 'OCCUPATION_ID' in df.columns:
    ax = axes[1, 0]
    occupation_counts = df['OCCUPATION_ID'].value_counts()
    occupation_labels = [str(x) for x in occupation_counts.index]
    bars = ax.bar(occupation_labels, occupation_counts.values, edgecolor='black', alpha=0.7)
    ax.set_ylabel('Count')
    ax.set_title('Occupation Distribution')
    ax.grid(True, alpha=0.3, axis='y')
    # Add value labels on bars
    for bar in bars:
        height = bar.get_height()
        ax.text(bar.get_x() + bar.get_width()/2., height,
                f'{int(height)}\n({100*height/len(df):.1f}%)',
                ha='center', va='bottom')
if 'QUALIFICATION_ID' in df.columns:
    ax = axes[1, 1]
    qualification_counts = df['QUALIFICATION_ID'].value_counts()
    qualification_labels = [str(x) for x in qualification_counts.index]
    bars = ax.bar(qualification_labels, qualification_counts.values, edgecolor='black', alpha=0.7)
    ax.set_ylabel('Count')
    ax.set_title('Qualification Distribution')
    ax.grid(True, alpha=0.3, axis='y')
    # Add value labels on bars
    for bar in bars:
        height = bar.get_height()
        ax.text(bar.get_x() + bar.get_width()/2., height,
                f'{int(height)}\n({100*height/len(df):.1f}%)',
                ha='center', va='bottom')



        
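plt.tight_layout()
# Use a separate filename so this figure does not overwrite the age/sex/BMI figure saved above
categorical_fig_path = FIGURES_DIR / 'demographic_distributions_categorical.png'
plt.savefig(categorical_fig_path, dpi=300, bbox_inches='tight')
plt.show()
print(f"✓ Categorical demographic distributions saved to {categorical_fig_path}")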

In [None]:
df

---
## 6. Age-Related Analysis

### 6.1 Age Correlations for All Features

In [None]:
# Check if age is available
if 'AGE' not in df.columns:
    print("⚠️  AGE variable not found. Skipping age-related analyses.")
else:
    print("\n" + "="*80)
    print("AGE-RELATED ANALYSIS")
    print("="*80)
    
    # Calculate correlations with age for all features
    age_correlations = []
    
    for feature in ALL_FEATURES:
        # Remove missing values
        valid_data = df[['AGE', feature]].dropna()
        
        if len(valid_data) < 10:
            continue
        
        # Calculate Pearson and Spearman correlations
        pearson_r, pearson_p = pearsonr(valid_data['AGE'], valid_data[feature])
        spearman_r, spearman_p = spearmanr(valid_data['AGE'], valid_data[feature])
        
        # Linear regression for slope
        slope, intercept, r_value, p_value, std_err = linregress(valid_data['AGE'], valid_data[feature])
        
        age_correlations.append({
            'feature': feature,
            'category': feature_categories[feature],
            'n': len(valid_data),
            'pearson_r': pearson_r,
            'pearson_p': pearson_p,
            'spearman_r': spearman_r,
            'spearman_p': spearman_p,
            'slope': slope,
            'r_squared': r_value**2
        })
    
    age_corr_df = pd.DataFrame(age_correlations)
    
    # Multiple testing correction
    if len(age_corr_df) > 0:
        age_corr_df['pearson_p_fdr'] = multipletests(age_corr_df['pearson_p'], method='fdr_bh')[1]
        age_corr_df['spearman_p_fdr'] = multipletests(age_corr_df['spearman_p'], method='fdr_bh')[1]
        age_corr_df['significant_pearson'] = age_corr_df['pearson_p_fdr'] < 0.05
        age_corr_df['significant_spearman'] = age_corr_df['spearman_p_fdr'] < 0.05
        
        # Sort by absolute correlation strength
        age_corr_df = age_corr_df.sort_values('pearson_r', key=abs, ascending=False)
        
        # Summary
        n_sig_pearson = age_corr_df['significant_pearson'].sum()
        n_sig_spearman = age_corr_df['significant_spearman'].sum()
        
        print(f"\nAge Correlation Summary:")
        print(f"  Total features tested: {len(age_corr_df)}")
        print(f"  Significant (Pearson, FDR<0.05): {n_sig_pearson} ({100*n_sig_pearson/len(age_corr_df):.1f}%)")
        print(f"  Significant (Spearman, FDR<0.05): {n_sig_spearman} ({100*n_sig_spearman/len(age_corr_df):.1f}%)")
        
        print(f"\nTop 10 age-correlated features (by absolute Pearson r):")
        display(age_corr_df.head(10)[['feature', 'category', 'pearson_r', 'pearson_p_fdr', 'r_squared']])
        
        # Save results
        age_corr_df.to_csv(TABLES_DIR / 'age_correlations.csv', index=False)
        print(f"\n✓ Age correlations saved to {TABLES_DIR / 'age_correlations.csv'}")
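
The `fdr_bh` option used above applies the Benjamini-Hochberg procedure. As a rough illustration of what that adjustment does, the sketch below re-implements it in plain Python on made-up p-values. The helper `bh_adjust` and the toy values are hypothetical and exist only for intuition; the notebook itself keeps relying on `multipletests` from statsmodels.

```python
def bh_adjust(pvals):
    """Benjamini-Hochberg adjusted p-values (illustrative helper, not part of the pipeline)."""
    m = len(pvals)
    # Indices of the p-values ordered from smallest to largest
    order = sorted(range(m), key=lambda i: pvals[i])
    adjusted = [0.0] * m
    running_min = 1.0
    # Walk from the largest p-value down so adjusted values stay monotone
    for rank in range(m, 0, -1):
        i = order[rank - 1]
        running_min = min(running_min, pvals[i] * m / rank)
        adjusted[i] = running_min
    return adjusted


toy_pvals = [0.01, 0.04, 0.03, 0.005]  # made-up values, not pipeline output
toy_adjusted = bh_adjust(toy_pvals)
toy_significant = [p < 0.05 for p in toy_adjusted]
print(toy_adjusted)
print(toy_significant)
```

Each adjusted value is the raw p-value scaled by `m / rank`, then capped by the adjusted value of the next larger p-value, which is why two features with different raw p-values can end up with the same adjusted one.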

### 6.2 Visualize Top Age Correlations

In [None]:
if 'AGE' in df.columns and len(age_corr_df) > 0:
    # Select top 6 most strongly correlated features
    top_features = age_corr_df.head(6)['feature'].tolist()
    
    # Create scatter plots
    fig, axes = plt.subplots(2, 3, figsize=(15, 10))
    fig.suptitle('Top Age-Correlated Vessel Features', fontsize=14, fontweight='bold')
    axes = axes.flatten()
    
    for idx, feature in enumerate(top_features):
        ax = axes[idx]
        
        # Get data
        valid_data = df[['AGE', feature]].dropna()
        
        # Scatter plot
        ax.scatter(valid_data['AGE'], valid_data[feature], alpha=0.5, s=20)
        
        # Regression line
        z = np.polyfit(valid_data['AGE'], valid_data[feature], 1)
        p = np.poly1d(z)
        ax.plot(valid_data['AGE'], p(valid_data['AGE']), 'r--', linewidth=2, alpha=0.8)
        
        # Get correlation stats
        feat_stats = age_corr_df[age_corr_df['feature'] == feature].iloc[0]
        
        ax.set_xlabel('Age (years)')
        ax.set_ylabel(feature)
        ax.set_title(f"{feature}\nr={feat_stats['pearson_r']:.3f}, p_FDR={feat_stats['pearson_p_fdr']:.2e}")
        ax.grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.savefig(FIGURES_DIR / 'age_correlations_top6.png', dpi=300, bbox_inches='tight')
    plt.show()
    print(f"✓ Top age correlations plot saved to {FIGURES_DIR / 'age_correlations_top6.png'}")

### 6.3 Age Correlation Heatmap by Feature Category

In [None]:
age_corr_df

In [None]:
# Publication-quality figure settings
plt.rcParams['figure.dpi'] = 1000
plt.rcParams['savefig.dpi'] = 1000
plt.rcParams['font.size'] = 20
plt.rcParams['axes.labelsize'] = 25
plt.rcParams['axes.titlesize'] = 20
plt.rcParams['xtick.labelsize'] = 25
plt.rcParams['ytick.labelsize'] = 25
plt.rcParams['legend.fontsize'] = 9

In [None]:
if 'AGE' in df.columns and len(age_corr_df) > 0:
    # Create heatmap of correlations grouped by category
    fig, ax = plt.subplots(figsize=(10, max(6, len(age_corr_df) * 0.3)))
    
    # Prepare data for the bar chart, keeping the Spearman significance flag with each row
    heatmap_data = age_corr_df[['feature', 'category', 'spearman_r', 'significant_spearman']].copy()
    heatmap_data = heatmap_data.sort_values(['category', 'spearman_r'], ascending=[True, False])
    
    # Color bars by FDR significance of the plotted (Spearman) correlation
    colors = ['red' if sig else 'gray' for sig in heatmap_data['significant_spearman']]
    
    # Horizontal bar plot
    y_pos = np.arange(len(heatmap_data))
    ax.barh(y_pos, heatmap_data['spearman_r'], color=colors, alpha=0.6, edgecolor='black')
    ax.set_yticks(y_pos)
    ax.set_yticklabels(heatmap_data['feature'], fontsize=8)
    ax.set_xlabel('Spearman Correlation with Age', fontweight='bold')
    ax.set_title('Age Correlations for All Features\n(Red = FDR-significant, Gray = Non-significant)', fontweight='bold')
    ax.axvline(0, color='black', linewidth=1)
    ax.grid(True, alpha=0.3, axis='x')
    
    # Add category separators
    category_changes = heatmap_data['category'].ne(heatmap_data['category'].shift())
    for idx in category_changes[category_changes].index[1:]:
        ax.axhline(y=y_pos[heatmap_data.index.get_loc(idx)] - 0.5, color='black', linestyle='--', linewidth=1, alpha=0.5)
    
    plt.tight_layout()
    plt.savefig(FIGURES_DIR / 'age_correlations_all.png', dpi=300, bbox_inches='tight')
    plt.show()
    print(f"✓ Complete age correlations plot saved to {FIGURES_DIR / 'age_correlations_all.png'}")

In [None]:
all_regions = regional_df['region_label'].unique() if regional_df is not None and 'region_label' in regional_df.columns else None
all_regions

In [None]:
regional_df_demo = regional_df.merge(demographics, on='subject_id', how='inner')
regional_df_demo

In [None]:
# Check if age is available
if 'AGE' not in regional_df_demo.columns:
    print("⚠️  AGE variable not found. Skipping age-related analyses.")
else:
    print("\n" + "="*80)
    print("AGE-RELATED ANALYSIS")
    print("="*80)
    
    # Calculate correlations with age for all features
    age_correlations_region = []
    
    for feature in ALL_FEATURES:
        # Correlate each feature with age separately within every region
        for region in all_regions:
            print(f"Analyzing feature '{feature}' in region '{region}'...")
            # Keep this region's rows where both AGE and the feature are present
            region_mask = regional_df_demo['region_label'] == region
            valid_data = regional_df_demo.loc[region_mask, ['AGE', feature]].dropna()
            
            if len(valid_data) < 10:
                continue
            
            # Calculate Pearson and Spearman correlations
            pearson_r, pearson_p = pearsonr(valid_data['AGE'], valid_data[feature])
            spearman_r, spearman_p = spearmanr(valid_data['AGE'], valid_data[feature])
            
            # Linear regression for slope
            slope, intercept, r_value, p_value, std_err = linregress(valid_data['AGE'], valid_data[feature])
            
            age_correlations_region.append({
                'feature': feature,
                'category': feature_categories[feature],
                'n': len(valid_data),
                'pearson_r': pearson_r,
                'pearson_p': pearson_p,
                'spearman_r': spearman_r,
                'spearman_p': spearman_p,
                'slope': slope,
                'r_squared': r_value**2,
                'region': region
            })
        
    # Build the results table once, after all features and regions have been processed
    age_corr_df_region = pd.DataFrame(age_correlations_region)
    
    # Multiple testing correction across all feature-region pairs
    if len(age_corr_df_region) > 0:
        age_corr_df_region['pearson_p_fdr'] = multipletests(age_corr_df_region['pearson_p'], method='fdr_bh')[1]
        age_corr_df_region['spearman_p_fdr'] = multipletests(age_corr_df_region['spearman_p'], method='fdr_bh')[1]
        age_corr_df_region['significant_pearson'] = age_corr_df_region['pearson_p_fdr'] < 0.05
        age_corr_df_region['significant_spearman'] = age_corr_df_region['spearman_p_fdr'] < 0.05

        # Sort by absolute correlation strength
        age_corr_df_region = age_corr_df_region.sort_values('pearson_r', key=abs, ascending=False)
        
        # Summary
        n_sig_pearson = age_corr_df_region['significant_pearson'].sum()
        n_sig_spearman = age_corr_df_region['significant_spearman'].sum()

        print(f"\nRegional Age Correlation Summary:")
        print(f"  Total feature-region pairs tested: {len(age_corr_df_region)}")
        print(f"  Significant (Pearson, FDR<0.05): {n_sig_pearson} ({100*n_sig_pearson/len(age_corr_df_region):.1f}%)")
        print(f"  Significant (Spearman, FDR<0.05): {n_sig_spearman} ({100*n_sig_spearman/len(age_corr_df_region):.1f}%)")

        print(f"\nTop 10 age-correlated feature-region pairs (by absolute Pearson r):")
        display(age_corr_df_region.head(10)[['feature', 'region', 'category', 'pearson_r', 'pearson_p_fdr', 'r_squared']])

        # Save results
        #age_corr_df.to_csv(TABLES_DIR / 'age_correlations.csv', index=False)
        #print(f"\n✓ Age correlations saved to {TABLES_DIR / 'age_correlations.csv'}")

In [None]:
age_corr_df_region

In [None]:
if 'AGE' in regional_df_demo.columns and len(age_corr_df_region) > 0:
    # One bar chart of age correlations per region, grouped by feature category
    for region in all_regions:
        # Prepare this region's rows, keeping the significance flag aligned with the sort order
        heatmap_data = age_corr_df_region.loc[age_corr_df_region['region'] == region,
                                              ['feature', 'category', 'pearson_r', 'significant_pearson']].copy()
        heatmap_data = heatmap_data.sort_values(['category', 'pearson_r'], ascending=[True, False])
        if heatmap_data.empty:
            continue
        
        # Scale figure height with the number of features shown for this region
        fig, ax = plt.subplots(figsize=(10, max(6, len(heatmap_data) * 0.3)))
        
        # Color bars by FDR significance
        colors = ['red' if sig else 'gray' for sig in heatmap_data['significant_pearson']]

        # Horizontal bar plot
        y_pos = np.arange(len(heatmap_data))
        ax.barh(y_pos, heatmap_data['pearson_r'], color=colors, alpha=0.6, edgecolor='black')
        ax.set_yticks(y_pos)
        ax.set_yticklabels(heatmap_data['feature'], fontsize=8)
        ax.set_xlabel('Pearson Correlation with Age', fontweight='bold')
        ax.set_title(f'Age Correlations for All Features in Region: {region}\n(Red = FDR-significant, Gray = Non-significant)', fontweight='bold')
        ax.axvline(0, color='black', linewidth=1)
        ax.grid(True, alpha=0.3, axis='x')
        
        # Add category separators
        category_changes = heatmap_data['category'].ne(heatmap_data['category'].shift())
        for idx in category_changes[category_changes].index[1:]:
            ax.axhline(y=y_pos[heatmap_data.index.get_loc(idx)] - 0.5, color='black', linestyle='--', linewidth=1, alpha=0.5)
        
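        plt.tight_layout()
        # Save one file per region so the figures do not overwrite each other
        region_tag = re.sub(r'[^A-Za-z0-9]+', '_', str(region)).strip('_')
        region_fig_path = FIGURES_DIR / f'age_correlations_region_{region_tag}.png'
        plt.savefig(region_fig_path, dpi=300, bbox_inches='tight')
        plt.show()
        print(f"✓ Age correlations plot for region '{region}' saved to {region_fig_path}")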

In [None]:
# ============================================================================
# 6.4 AGE GROUP BOXPLOTS
# ============================================================================

if 'AGE' in df.columns and len(age_corr_df) > 0:
    print("\n" + "="*80)
    print("AGE GROUP ANALYSIS - BOXPLOT VISUALIZATION")
    print("="*80)
    
    # Set matplotlib parameters to prevent memory issues
    plt.rcParams['figure.max_open_warning'] = 0
    plt.rcParams['agg.path.chunksize'] = 10000
    
    # Define age groups from the quartiles of the observed age distribution
    # (the lower edge is min-1 so the youngest subject falls inside the first right-closed bin)
    age_bins = [df['AGE'].min()-1] + list(df['AGE'].quantile([0.25, 0.5, 0.75])) + [df['AGE'].max()]
    age_labels = [f"{int(age_bins[i]+1)}-{int(age_bins[i+1])}" for i in range(len(age_bins)-1)]
    
    # Create age group column; bins are right-closed, (low, high], to match the labels
    df['age_group'] = pd.cut(df['AGE'], bins=age_bins, labels=age_labels, right=True)
    
    # Display age group distribution
    print(f"\nAge group distribution:")
    age_group_counts = df['age_group'].value_counts().sort_index()
    for group, count in age_group_counts.items():
        print(f"  {group}: {count} subjects ({100*count/len(df[df['AGE'].notna()]):.1f}%)")
    
    # ========================================================================
    # Statistical Testing Between Age Groups
    # ========================================================================
    print("\n" + "-"*80)
    print("STATISTICAL TESTING BETWEEN AGE GROUPS")
    print("-"*80)
    
    anova_results = []
    
    for feature in ALL_FEATURES:
        # Get data for each age group
        age_group_data = []
        for group in age_labels:
            group_data = df[df['age_group'] == group][feature].dropna().values
            if len(group_data) >= 5:
                age_group_data.append(group_data)
        
        if len(age_group_data) < 2:
            continue
        
        # Perform ANOVA
        f_stat, anova_p = f_oneway(*age_group_data)
        
        # Perform Kruskal-Wallis (non-parametric alternative)
        h_stat, kw_p = kruskal(*age_group_data)
        
        # Calculate effect size (eta-squared)
        grand_mean = np.mean(np.concatenate(age_group_data))
        ss_between = sum(len(d) * (np.mean(d) - grand_mean)**2 for d in age_group_data)
        ss_total = sum(np.sum((d - grand_mean)**2) for d in age_group_data)
        eta_squared = ss_between / ss_total if ss_total > 0 else 0
        
        anova_results.append({
            'feature': feature,
            'category': feature_categories[feature],
            'f_statistic': f_stat,
            'anova_pvalue': anova_p,
            'kruskal_h': h_stat,
            'kw_pvalue': kw_p,
            'eta_squared': eta_squared
        })
    
    anova_df = pd.DataFrame(anova_results)
    
    if len(anova_df) > 0:
        # Multiple testing correction
        anova_df['anova_p_fdr'] = multipletests(anova_df['anova_pvalue'], method='fdr_bh')[1]
        anova_df['kw_p_fdr'] = multipletests(anova_df['kw_pvalue'], method='fdr_bh')[1]
        anova_df['significant_anova'] = anova_df['anova_p_fdr'] < 0.05
        anova_df['significant_kw'] = anova_df['kw_p_fdr'] < 0.05
        
        anova_df = anova_df.sort_values('eta_squared', ascending=False)
        
        n_sig = anova_df['significant_anova'].sum()
        
        print(f"\nANOVA results:")
        print(f"  Features tested: {len(anova_df)}")
        print(f"  Significant differences (ANOVA, FDR<0.05): {n_sig} ({100*n_sig/len(anova_df):.1f}%)")
        print(f"  Significant differences (Kruskal-Wallis, FDR<0.05): {anova_df['significant_kw'].sum()}")
        
        print("\nTop 10 features with largest age group effects:")
        display(anova_df.head(10)[['feature', 'category', 'f_statistic', 'eta_squared', 
                         'anova_p_fdr', 'kw_p_fdr']].round(4))
        
        # Save results
        anova_df.to_csv(TABLES_DIR / 'age_group_anova.csv', index=False)
        print(f"\n✓ Age group ANOVA results saved to {TABLES_DIR / 'age_group_anova.csv'}")
        
        # ====================================================================
        # VISUALIZATION
        # ====================================================================
        
        if n_sig > 0:
            top_age_features = anova_df[anova_df['significant_anova']].head(6)['feature'].tolist()
            
            if len(top_age_features) > 0:
                print(f"\nVisualizing top {len(top_age_features)} age-correlated features...")
                
                # Define consistent color mapping for age groups
                unique_ages = age_labels
                age_colors = dict(zip(unique_ages, 
                                     sns.color_palette('YlOrRd', n_colors=len(unique_ages))))
                
                n_plots = min(6, len(top_age_features))
                fig, axes = plt.subplots(2, 3, figsize=(15, 10))
                fig.suptitle('Top Age Group Differences in Vessel Features',
                            fontsize=14, fontweight='bold')
                axes = axes.flatten()
                
                for idx, feature in enumerate(top_age_features[:n_plots]):
                    ax = axes[idx]
                    
                    plot_data = df[['age_group', feature]].dropna()
                    
                    # Age groups are already in natural order
                    age_order = age_labels
                    
                    # Create color palette in the correct order for this plot
                    plot_colors = [age_colors[age] for age in age_order]
                    
                    # Sample data if too large to prevent memory issues
                    if len(plot_data) > 500:
                        plot_data_sample = plot_data.sample(n=500, random_state=42)
                    else:
                        plot_data_sample = plot_data
                    
                    sns.boxplot(data=plot_data, x='age_group', y=feature, ax=ax,
                               order=age_order, palette=plot_colors)
                    sns.stripplot(data=plot_data_sample, x='age_group', y=feature, ax=ax,
                                 color='black', alpha=0.2, size=2, order=age_order)
                    
                    feat_stats = anova_df[anova_df['feature'] == feature].iloc[0]
                    eta_sq = feat_stats['eta_squared']
                    p_val = feat_stats['anova_p_fdr']
                    
                    if p_val < 0.001:
                        sig_stars = '***'
                        p_text = 'p < 0.001'
                    elif p_val < 0.01:
                        sig_stars = '**'
                        p_text = f'p = {p_val:.3f}'
                    elif p_val < 0.05:
                        sig_stars = '*'
                        p_text = f'p = {p_val:.3f}'
                    else:
                        sig_stars = 'ns'
                        p_text = f'p = {p_val:.2f}'
                    
                    clean_feature = feature.replace('_', ' ').replace('total ', '').replace('mean ', '').title()
                    
                    ax.set_title(f"{clean_feature} {sig_stars}\nη² = {eta_sq:.3f}, {p_text}",
                                fontsize=9)
                    ax.set_ylabel(feature.replace('_', ' ').title(), fontsize=9)
                    ax.set_xlabel('Age Group (years)', fontsize=9)
                    ax.tick_params(axis='x', labelsize=8)
                    ax.grid(True, alpha=0.3, axis='y')
                
                for idx in range(n_plots, 6):
                    axes[idx].axis('off')
                
                plt.tight_layout()
                plt.savefig(FIGURES_DIR / 'age_groups_boxplots_top6.png', dpi=150, bbox_inches='tight')
                print("✓ Age group boxplots saved")
                plt.show()
                plt.close('all')
    
    print("\n" + "="*80)
    print("✓ AGE GROUP ANALYSIS COMPLETE")
    print("="*80)
else:
    print("\n⚠️  Skipping age group analysis - no age correlation data available")
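
The effect size reported above is eta-squared, the share of total variance explained by group membership (between-group sum of squares divided by total sum of squares). A minimal sketch in plain Python, using made-up numbers rather than vessel features, may help when checking the values in `age_group_anova.csv`; the helper name `eta_squared` is hypothetical and not part of the pipeline.

```python
def eta_squared(groups):
    """Eta-squared = SS_between / SS_total for a list of numeric groups (illustrative helper)."""
    all_values = [x for g in groups for x in g]
    grand_mean = sum(all_values) / len(all_values)
    # Between-group variation: group size times squared distance of the group mean
    ss_between = sum(len(g) * (sum(g) / len(g) - grand_mean) ** 2 for g in groups)
    # Total variation of every observation around the grand mean
    ss_total = sum((x - grand_mean) ** 2 for x in all_values)
    return ss_between / ss_total if ss_total > 0 else 0.0


toy_groups = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]  # made-up values
toy_eta = eta_squared(toy_groups)
print(f"eta squared = {toy_eta:.3f}")
```

For these toy groups the between-group sum of squares is 13.5 and the total is 17.5, so roughly 77% of the variance is attributable to group membership.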

---
## 7. Sex-Based Analysis

In [None]:
if 'SEX_ID' not in df.columns:
    print("⚠️  SEX_ID variable not found. Skipping sex-based analyses.")
else:
    print("\n" + "="*80)
    print("SEX-BASED ANALYSIS")
    print("="*80)
    
    # Sex comparison for all features
    sex_comparisons = []
    
    for feature in ALL_FEATURES:
        # Get data for each sex
        male_data = df[(df['SEX_ID'] == 1) & df[feature].notna()][feature]
        female_data = df[(df['SEX_ID'] == 2) & df[feature].notna()][feature]
        
        if len(male_data) < 5 or len(female_data) < 5:
            continue
        
        # T-test and Mann-Whitney U test
        t_stat, t_pval = ttest_ind(male_data, female_data)
        u_stat, u_pval = mannwhitneyu(male_data, female_data, alternative='two-sided')
        
        # Effect size (Cohen's d)
        pooled_std = np.sqrt(((len(male_data)-1)*male_data.std()**2 + (len(female_data)-1)*female_data.std()**2) / (len(male_data)+len(female_data)-2))
        cohens_d = (male_data.mean() - female_data.mean()) / pooled_std
        
        sex_comparisons.append({
            'feature': feature,
            'category': feature_categories[feature],
            'n_male': len(male_data),
            'n_female': len(female_data),
            'male_mean': male_data.mean(),
            'female_mean': female_data.mean(),
            'male_std': male_data.std(),
            'female_std': female_data.std(),
            't_statistic': t_stat,
            't_pvalue': t_pval,
            'mwu_pvalue': u_pval,
            'cohens_d': cohens_d,
            'percent_diff': 100 * (male_data.mean() - female_data.mean()) / female_data.mean()
        })
    
    sex_comp_df = pd.DataFrame(sex_comparisons)
    
    if len(sex_comp_df) > 0:
        # Multiple testing correction
        sex_comp_df['t_pvalue_fdr'] = multipletests(sex_comp_df['t_pvalue'], method='fdr_bh')[1]
        sex_comp_df['mwu_pvalue_fdr'] = multipletests(sex_comp_df['mwu_pvalue'], method='fdr_bh')[1]
        sex_comp_df['significant_ttest'] = sex_comp_df['t_pvalue_fdr'] < 0.05
        sex_comp_df['significant_mwu'] = sex_comp_df['mwu_pvalue_fdr'] < 0.05
        
        # Sort by effect size
        sex_comp_df = sex_comp_df.sort_values('cohens_d', key=abs, ascending=False)
        
        # Summary
        n_sig = sex_comp_df['significant_ttest'].sum()
        print(f"\nSex Comparison Summary:")
        print(f"  Total features tested: {len(sex_comp_df)}")
        print(f"  Significant differences (t-test, FDR<0.05): {n_sig} ({100*n_sig/len(sex_comp_df):.1f}%)")
        
        print(f"\nTop 10 sex differences (by absolute Cohen's d):")
        display(sex_comp_df.head(10)[['feature', 'category', 'male_mean', 'female_mean', 'cohens_d', 't_pvalue_fdr']])
        
        # Save results
        sex_comp_df.to_csv(TABLES_DIR / 'sex_comparisons.csv', index=False)
        print(f"\n✓ Sex comparisons saved to {TABLES_DIR / 'sex_comparisons.csv'}")
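
Cohen's d in the cell above is the difference in group means divided by the pooled standard deviation. The sketch below reproduces that formula with the standard library on made-up values, as a quick way to sanity-check the sign and scale of `cohens_d` in the table; the function and variable names here are hypothetical.

```python
from math import sqrt
from statistics import mean, stdev


def cohens_d(a, b):
    """Cohen's d with a pooled sample standard deviation (illustrative helper)."""
    n_a, n_b = len(a), len(b)
    # Pool the two sample variances, weighting each by its degrees of freedom
    pooled_var = ((n_a - 1) * stdev(a) ** 2 + (n_b - 1) * stdev(b) ** 2) / (n_a + n_b - 2)
    return (mean(a) - mean(b)) / sqrt(pooled_var)


toy_male = [5.0, 6.0, 7.0]    # made-up values
toy_female = [3.0, 4.0, 5.0]  # made-up values
toy_d = cohens_d(toy_male, toy_female)
print(f"Cohen's d = {toy_d:.2f}")
```

A positive d means the first group (male, in the notebook's convention) has the larger mean; swapping the arguments flips the sign but not the magnitude.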

### 7.1 Visualize Sex Differences

In [None]:
if 'SEX_ID' in df.columns and len(sex_comp_df) > 0:
    # Select top 6 features with largest effect sizes
    top_sex_features = sex_comp_df.head(6)['feature'].tolist()
    
    fig, axes = plt.subplots(2, 3, figsize=(15, 10))
    fig.suptitle('Top Sex Differences in Vessel Features', fontsize=14, fontweight='bold')
    axes = axes.flatten()
    
    for idx, feature in enumerate(top_sex_features):
        ax = axes[idx]
        
        # Prepare data
        plot_data = df[['SEX_ID', feature]].dropna()
        plot_data['Sex'] = plot_data['SEX_ID'].map({1: 'Male', 2: 'Female'})
        
        # Box plot
        sns.boxplot(data=plot_data, x='Sex', y=feature, ax=ax, palette='Set2')
        sns.stripplot(data=plot_data, x='Sex', y=feature, ax=ax, color='black', alpha=0.3, size=3)
        
        # Get stats
        feat_stats = sex_comp_df[sex_comp_df['feature'] == feature].iloc[0]
        
        # Determine significance asterisks from the FDR-corrected p-value
        p_val = feat_stats['t_pvalue_fdr']
        if p_val < 0.001:
            sig_marker = '***'
        elif p_val < 0.01:
            sig_marker = '**'
        elif p_val < 0.05:
            sig_marker = '*'
        else:
            sig_marker = 'ns'
        
        ax.set_title(f"{feature}\nd={feat_stats['cohens_d']:.3f}, p_FDR={p_val:.2e} {sig_marker}")
        ax.set_ylabel(feature)
        ax.grid(True, alpha=0.3, axis='y')
    
    plt.tight_layout()
    plt.savefig(FIGURES_DIR / 'sex_differences_top6.png', dpi=300, bbox_inches='tight')
    plt.show()
    print(f"✓ Sex differences plot saved to {FIGURES_DIR / 'sex_differences_top6.png'}")

---
### 7.2 Sex-Balanced Analysis
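
The balancing step in this section downsamples the larger sex within each age bin, so that age and sex are no longer confounded in the comparison. The idea can be sketched on a handful of made-up records in plain Python. This toy version is an assumption-laden illustration: `balance_by_bin` is a hypothetical helper, and it bins by floor division from zero rather than from the minimum observed age as the notebook function does.

```python
import random
from collections import defaultdict


def balance_by_bin(records, bin_width=5, seed=42):
    """Downsample to equal male/female counts per age bin.

    records: list of (age, sex) tuples with sex coded 1=male, 2=female.
    """
    rng = random.Random(seed)
    bins = defaultdict(lambda: {1: [], 2: []})
    for age, sex in records:
        if sex in (1, 2):
            bins[int(age // bin_width)][sex].append((age, sex))
    balanced = []
    for key in sorted(bins):
        males, females = bins[key][1], bins[key][2]
        # Keep as many subjects of each sex as the smaller group provides
        k = min(len(males), len(females))
        balanced += rng.sample(males, k) + rng.sample(females, k)
    return balanced


toy_records = [(21, 1), (22, 1), (23, 1), (24, 2),
               (31, 1), (32, 2), (33, 2), (41, 2)]  # made-up subjects
toy_balanced = balance_by_bin(toy_records)
print(toy_balanced)
```

Bins that contain only one sex contribute nothing to the balanced set, which is why balancing can shrink the sample noticeably at the extremes of the age range.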

In [None]:

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

def balance_sex_distribution(df, age_bin_width=5):
    """
    Create a balanced dataset with equal male/female counts in each age bin.
    
    Parameters:
    -----------
    df : pandas.DataFrame
        Input dataframe with 'AGE' and 'SEX_ID (1=m, 2=f)' columns
    age_bin_width : int
        Width of age bins in years (default: 5)
    
    Returns:
    --------
    df_balanced : pandas.DataFrame
        Balanced dataframe with equal male/female representation per age bin
    balance_report : pandas.DataFrame
        Report showing original and balanced counts per age bin
    """
    
    # Get the correct sex column name
    sex_col = 'SEX_ID (1=m, 2=f)' if 'SEX_ID (1=m, 2=f)' in df.columns else 'SEX_ID'
    
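    # Create age bins; the last edge must lie strictly above the maximum age,
    # because bins are left-closed/right-open and would otherwise drop the oldest subjects
    min_age = df['AGE'].min()
    max_age = df['AGE'].max()
    bins = np.arange(np.floor(min_age), np.floor(max_age) + age_bin_width + 1, age_bin_width)
    
    df_copy = df.copy()
    df_copy['age_bin'] = pd.cut(df_copy['AGE'], bins=bins, include_lowest=True, right=False)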
    
    # Initialize list to store balanced samples
    balanced_samples = []
    balance_info = []
    
    # Process each age bin in ascending order (bins missing either sex are skipped below)
    for age_bin in df_copy['age_bin'].cat.categories:
        bin_data = df_copy[df_copy['age_bin'] == age_bin]
        
        # Count males and females in this bin
        males = bin_data[bin_data[sex_col] == 1]
        females = bin_data[bin_data[sex_col] == 2]
        
        n_males = len(males)
        n_females = len(females)
        
        # Take minimum count
        min_count = min(n_males, n_females)
        
        if min_count > 0:
            # Randomly sample min_count from each sex
            males_sampled = males.sample(n=min_count, random_state=42)
            females_sampled = females.sample(n=min_count, random_state=42)
            
            balanced_samples.append(males_sampled)
            balanced_samples.append(females_sampled)
            
            balance_info.append({
                'age_bin': str(age_bin),
                'original_male': n_males,
                'original_female': n_females,
                'balanced_male': min_count,
                'balanced_female': min_count,
                'total_kept': 2 * min_count,
                'total_removed': (n_males - min_count) + (n_females - min_count)
            })
    
    # Combine all balanced samples
    df_balanced = pd.concat(balanced_samples, ignore_index=True)
    df_balanced = df_balanced.drop('age_bin', axis=1)
    df_balanced = df_balanced.sort_values('AGE').reset_index(drop=True)
    
    # Create balance report
    balance_report = pd.DataFrame(balance_info)
    
    return df_balanced, balance_report


def plot_balanced_comparison(df_original, df_balanced, age_bin_width=5):
    """
    Visualize the original vs balanced sex distribution by age.
    """
    sex_col = 'SEX_ID (1=m, 2=f)' if 'SEX_ID (1=m, 2=f)' in df_original.columns else 'SEX_ID'
    
    fig, axes = plt.subplots(1, 2, figsize=(16, 6))
    
    # Create age bins for visualization (last edge strictly above the maximum age)
    min_age = np.floor(df_original['AGE'].min())
    max_age = np.floor(df_original['AGE'].max())
    bins = np.arange(min_age, max_age + age_bin_width + 1, age_bin_width)
    
    for idx, (df_plot, title) in enumerate([(df_original, 'Original Distribution'),
                                              (df_balanced, 'Balanced Distribution')]):
        ax = axes[idx]
        
        # Separate male and female data
        males = df_plot[df_plot[sex_col] == 1]['AGE']
        females = df_plot[df_plot[sex_col] == 2]['AGE']
        
        # Grouped (side-by-side) histogram of male and female ages
        ax.hist([males, females], bins=bins, label=['Male', 'Female'],
                color=['#ff7f0e', '#ff69b4'], alpha=0.8, edgecolor='black', linewidth=0.5)
        
        ax.set_xlabel('Age (years)', fontsize=12)
        ax.set_ylabel('Frequency', fontsize=12)
        ax.set_title(title, fontsize=14, fontweight='bold')
        ax.legend(fontsize=11)
        ax.grid(axis='y', alpha=0.3)
        
        # Add sample size annotation
        n_total = len(df_plot)
        n_males = len(males)
        n_females = len(females)
        ax.text(0.02, 0.98, f'N = {n_total}\nMale: {n_males}\nFemale: {n_females}',
                transform=ax.transAxes, fontsize=10, verticalalignment='top',
                bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))
    
    plt.tight_layout()
    # Save next to the other figures instead of the current working directory
    out_path = FIGURES_DIR / 'age_sex_distribution_comparison.png'
    plt.savefig(out_path, dpi=300, bbox_inches='tight')
    print(f"✓ Saved: {out_path}")
    plt.show()


# ============================================================================
# USAGE EXAMPLE
# ============================================================================

# Balance the dataset
df_balanced, balance_report = balance_sex_distribution(df, age_bin_width=5)

print("=" * 80)
print("SEX DISTRIBUTION BALANCING REPORT")
print("=" * 80)
print(f"\nOriginal dataset: {len(df)} subjects")
print(f"Balanced dataset: {len(df_balanced)} subjects")
print(f"Subjects removed: {len(df) - len(df_balanced)} ({(len(df) - len(df_balanced))/len(df)*100:.1f}%)")

print("\n" + "=" * 80)
print("PER-AGE-BIN BALANCING DETAILS")
print("=" * 80)
print(balance_report.to_string(index=False))

# Summary statistics
sex_col = 'SEX_ID (1=m, 2=f)' if 'SEX_ID (1=m, 2=f)' in df.columns else 'SEX_ID'

print("\n" + "=" * 80)
print("BALANCED DATASET CHARACTERISTICS")
print("=" * 80)
print(f"\nSex distribution:")
print(f"  Male:   {(df_balanced[sex_col] == 1).sum()} ({(df_balanced[sex_col] == 1).sum()/len(df_balanced)*100:.1f}%)")
print(f"  Female: {(df_balanced[sex_col] == 2).sum()} ({(df_balanced[sex_col] == 2).sum()/len(df_balanced)*100:.1f}%)")

print(f"\nAge statistics:")
print(f"  Original - Mean: {df['AGE'].mean():.1f}, SD: {df['AGE'].std():.1f}, Range: [{df['AGE'].min():.0f}, {df['AGE'].max():.0f}]")
print(f"  Balanced - Mean: {df_balanced['AGE'].mean():.1f}, SD: {df_balanced['AGE'].std():.1f}, Range: [{df_balanced['AGE'].min():.0f}, {df_balanced['AGE'].max():.0f}]")

# Visualize comparison
plot_balanced_comparison(df, df_balanced, age_bin_width=5)
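
The balancing step can be sanity-checked on synthetic data. The following is a minimal sketch, assuming that `balance_sex_distribution` downsamples the larger sex within each age bin to the smaller count (as the `min_count` fields of the balance report suggest). The helper `downsample_by_bin` and the toy frame are hypothetical and are not the notebook's implementation.

```python
import numpy as np
import pandas as pd

def downsample_by_bin(frame, age_col="AGE", sex_col="SEX_ID", width=5, seed=42):
    """Hypothetical helper: equalize male (1) and female (2) counts per age bin."""
    lo, hi = int(frame[age_col].min()), int(frame[age_col].max())
    # Left-closed bins of `width` years that cover the full age range
    edges = np.arange(lo, hi + width + 1, width)
    age_bin = pd.cut(frame[age_col], bins=edges, right=False)

    kept = []
    for _, grp in frame.groupby(age_bin, observed=True):
        counts = grp[sex_col].value_counts().reindex([1, 2], fill_value=0)
        n = int(counts.min())
        if n == 0:
            continue  # a bin with only one sex cannot be balanced
        for sex in (1, 2):
            kept.append(grp[grp[sex_col] == sex].sample(n=n, random_state=seed))
    return pd.concat(kept, ignore_index=True)

# Synthetic cohort with a deliberate 40/60 sex imbalance
rng = np.random.default_rng(0)
toy = pd.DataFrame({
    "AGE": rng.integers(20, 80, size=300),
    "SEX_ID": rng.choice([1, 2], size=300, p=[0.4, 0.6]),
})
balanced_toy = downsample_by_bin(toy)
print(balanced_toy["SEX_ID"].value_counts())
```

Because each bin contributes the same number of males and females, the overall split is exactly 50/50, at the cost of discarding subjects from the larger group.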

In [None]:
df_balanced

In [None]:
if 'SEX_ID' not in df_balanced.columns:
    print("⚠️  SEX_ID variable not found. Skipping sex-based analyses.")
else:
    print("\n" + "="*80)
    print("SEX-BASED ANALYSIS")
    print("="*80)
    
    # Sex comparison for all features
    sex_comparisons = []
    
    for feature in ALL_FEATURES:
        # Get data for each sex
        male_data = df_balanced[(df_balanced['SEX_ID'] == 1) & df_balanced[feature].notna()][feature]
        female_data = df_balanced[(df_balanced['SEX_ID'] == 2) & df_balanced[feature].notna()][feature]
        
        if len(male_data) < 5 or len(female_data) < 5:
            continue
        
        # T-test and Mann-Whitney U test
        t_stat, t_pval = ttest_ind(male_data, female_data)
        u_stat, u_pval = mannwhitneyu(male_data, female_data, alternative='two-sided')
        
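        # Effect size (Cohen's d with pooled standard deviation)
        n_m, n_f = len(male_data), len(female_data)
        pooled_std = np.sqrt(((n_m - 1) * male_data.var() + (n_f - 1) * female_data.var()) / (n_m + n_f - 2))
        # Guard against constant features, where the pooled SD is zero
        cohens_d = (male_data.mean() - female_data.mean()) / pooled_std if pooled_std > 0 else np.nan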
        
        sex_comparisons.append({
            'feature': feature,
            'category': feature_categories[feature],
            'n_male': len(male_data),
            'n_female': len(female_data),
            'male_mean': male_data.mean(),
            'female_mean': female_data.mean(),
            'male_std': male_data.std(),
            'female_std': female_data.std(),
            't_statistic': t_stat,
            't_pvalue': t_pval,
            'mwu_pvalue': u_pval,
            'cohens_d': cohens_d,
            'percent_diff': 100 * (male_data.mean() - female_data.mean()) / female_data.mean()
        })
    
    sex_comp_df = pd.DataFrame(sex_comparisons)
    
    if len(sex_comp_df) > 0:
        # Multiple testing correction
        sex_comp_df['t_pvalue_fdr'] = multipletests(sex_comp_df['t_pvalue'], method='fdr_bh')[1]
        sex_comp_df['mwu_pvalue_fdr'] = multipletests(sex_comp_df['mwu_pvalue'], method='fdr_bh')[1]
        sex_comp_df['significant_ttest'] = sex_comp_df['t_pvalue_fdr'] < 0.05
        sex_comp_df['significant_mwu'] = sex_comp_df['mwu_pvalue_fdr'] < 0.05
        
        # Sort by effect size
        sex_comp_df = sex_comp_df.sort_values('cohens_d', key=abs, ascending=False)
        
        # Summary
        n_sig = sex_comp_df['significant_ttest'].sum()
        print(f"\nSex Comparison Summary:")
        print(f"  Total features tested: {len(sex_comp_df)}")
        print(f"  Significant differences (t-test, FDR<0.05): {n_sig} ({100*n_sig/len(sex_comp_df):.1f}%)")
        
        print(f"\nTop 10 sex differences (by absolute Cohen's d):")
        display(sex_comp_df.head(10)[['feature', 'category', 'male_mean', 'female_mean', 'cohens_d', 't_pvalue_fdr']])
        
        # Save results
        sex_comp_df.to_csv(TABLES_DIR / 'sex_comparisons.csv', index=False)
        print(f"\n✓ Sex comparisons saved to {TABLES_DIR / 'sex_comparisons.csv'}")
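
The effect size used in this cell can be verified by hand on a small example. The block below is a sketch of Cohen's d with a pooled standard deviation, written as a standalone function; the name `cohens_d` and the sample values are illustrative assumptions.

```python
import numpy as np

def cohens_d(a, b):
    """Cohen's d using the pooled sample standard deviation (ddof=1)."""
    a = np.asarray(a, dtype=float)
    b = np.asarray(b, dtype=float)
    n_a, n_b = len(a), len(b)
    pooled_var = ((n_a - 1) * a.var(ddof=1) + (n_b - 1) * b.var(ddof=1)) / (n_a + n_b - 2)
    if pooled_var == 0:
        return float("nan")  # undefined when both groups are constant
    return (a.mean() - b.mean()) / np.sqrt(pooled_var)

# Means differ by one unit and both groups have unit variance, so d = -1
d_example = cohens_d([1.0, 2.0, 3.0], [2.0, 3.0, 4.0])
print(round(d_example, 3))  # -1.0
```

A negative value here means the first group (males in the cell above) has the lower mean.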

In [None]:
if 'SEX_ID' in df_balanced.columns and len(sex_comp_df) > 0:
    # Select top 6 features with largest effect sizes
    top_sex_features = sex_comp_df.head(6)['feature'].tolist()
    
    fig, axes = plt.subplots(2, 3, figsize=(15, 10))
    fig.suptitle('Top Sex Differences in Vessel Features', fontsize=14, fontweight='bold')
    axes = axes.flatten()
    
    def get_significance_stars(p_value):
        """Convert p-value to significance stars"""
        if p_value < 0.001:
            return '***'
        elif p_value < 0.01:
            return '**'
        elif p_value < 0.05:
            return '*'
        else:
            return 'ns'
    
    for idx, feature in enumerate(top_sex_features):
        ax = axes[idx]
        
        # Prepare data
        plot_data = df_balanced[['SEX_ID', feature]].dropna()
        plot_data['Sex'] = plot_data['SEX_ID'].map({1: 'Male', 2: 'Female'})
        
        # Box plot
        sns.boxplot(data=plot_data, x='Sex', y=feature, ax=ax, palette='Set2')
        sns.stripplot(data=plot_data, x='Sex', y=feature, ax=ax, color='black', alpha=0.3, size=3)
        
        # Get stats (FDR-corrected p-value, consistent with the summary table above)
        feat_stats = sex_comp_df[sex_comp_df['feature'] == feature].iloc[0]
        p_value = feat_stats['t_pvalue_fdr']
        cohens_d = feat_stats['cohens_d']
        
        # Get significance stars
        sig_stars = get_significance_stars(p_value)
        sig_indicator = f" {sig_stars}" if sig_stars != 'ns' else " (ns)"
        
        # Format p-value
        if p_value < 0.001:
            p_text = 'p < 0.001'
        elif p_value < 0.01:
            p_text = f'p = {p_value:.3f}'
        else:
            p_text = f'p = {p_value:.2f}'
        
        # Clean feature name for display
        clean_feature = feature.replace('_', ' ').replace('total ', '').replace('mean ', '').title()
        
        # Title with significance indicator
        ax.set_title(f"{clean_feature}{sig_indicator}\nd = {cohens_d:.3f}, {p_text}", 
                    fontsize=10)
        ax.set_ylabel(feature.replace('_', ' ').title(), fontsize=9)
        ax.set_xlabel('')
        ax.grid(True, alpha=0.3, axis='y')
    
    plt.tight_layout()
    plt.savefig(FIGURES_DIR / 'sex_differences_top6.png', dpi=300, bbox_inches='tight')
    plt.show()
    print(f"✓ Sex differences plot saved to {FIGURES_DIR / 'sex_differences_top6.png'}")
    
    # Print legend for significance levels
    print("\nSignificance levels:")
    print("  *** p < 0.001")
    print("  **  p < 0.01")
    print("  *   p < 0.05")
    print("  ns  not significant (p ≥ 0.05)")

---
## 8. Anthropometric Correlations (Height, Weight, BMI)

In [None]:
# Check which anthropometric variables are available
anthro_vars = [v for v in ['HEIGHT', 'WEIGHT', 'BMI'] if v in df.columns]

if len(anthro_vars) == 0:
    print("⚠️  No anthropometric variables found. Skipping anthropometric analyses.")
else:
    print("\n" + "="*80)
    print("ANTHROPOMETRIC CORRELATIONS")
    print("="*80)
    print(f"\nAvailable anthropometric variables: {anthro_vars}")
    
    # Calculate correlations for each anthropometric variable
    anthro_results = {}
    
    for anthro_var in anthro_vars:
        correlations = []
        
        for feature in ALL_FEATURES:
            valid_data = df[[anthro_var, feature]].dropna()
            
            if len(valid_data) < 10:
                continue
            
            r, p = pearsonr(valid_data[anthro_var], valid_data[feature])
            
            correlations.append({
                'feature': feature,
                'category': feature_categories[feature],
                'n': len(valid_data),
                'correlation': r,
                'pvalue': p
            })
        
        corr_df = pd.DataFrame(correlations)
        
        if len(corr_df) > 0:
            # Multiple testing correction
            corr_df['pvalue_fdr'] = multipletests(corr_df['pvalue'], method='fdr_bh')[1]
            corr_df['significant'] = corr_df['pvalue_fdr'] < 0.05
            corr_df = corr_df.sort_values('correlation', key=abs, ascending=False)
            anthro_results[anthro_var] = corr_df
            
            print(f"\n{anthro_var} Correlations:")
            print(f"  Significant correlations (FDR<0.05): {corr_df['significant'].sum()}")
            print(f"  Top 5 correlations:")
            display(corr_df.head(5)[['feature', 'category', 'correlation', 'pvalue_fdr']])
            
            # Save results
            corr_df.to_csv(TABLES_DIR / f'{anthro_var.lower()}_correlations.csv', index=False)
    
    if len(anthro_results) > 0:
        print(f"\n✓ Anthropometric correlations saved to {TABLES_DIR}")
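
For readers unfamiliar with the `fdr_bh` option, the following sketch shows what a Benjamini-Hochberg adjustment computes. `bh_adjust` is a hypothetical NumPy-only helper written for illustration; the notebook itself relies on `multipletests` from statsmodels.

```python
import numpy as np

def bh_adjust(pvals):
    """Benjamini-Hochberg adjusted p-values, returned in the input order."""
    p = np.asarray(pvals, dtype=float)
    n = len(p)
    order = np.argsort(p)
    # Scale each sorted p-value by n / rank
    scaled = p[order] * n / np.arange(1, n + 1)
    # Enforce monotonicity, working down from the largest p-value
    scaled = np.minimum.accumulate(scaled[::-1])[::-1]
    adjusted = np.empty(n)
    adjusted[order] = np.clip(scaled, 0.0, 1.0)
    return adjusted

raw = [0.01, 0.04, 0.03, 0.005]
adjusted = bh_adjust(raw)
print(adjusted)
```

Note that the cell above applies the correction separately for each anthropometric variable, so each variable's set of features is treated as its own family of tests.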

In [None]:
# ============================================================================
# ANTHROPOMETRIC BOXPLOTS BY GROUPS
# ============================================================================

if len(anthro_results) > 0:
    print("\n" + "="*80)
    print("ANTHROPOMETRIC BOXPLOT VISUALIZATIONS")
    print("="*80)
    
    # Silence the open-figure warning and chunk long paths for the Agg renderer
    plt.rcParams['figure.max_open_warning'] = 0
    plt.rcParams['agg.path.chunksize'] = 10000
    
    # For each anthropometric variable, create boxplots by groups
    for anthro_var in anthro_vars:
        if anthro_var not in anthro_results:
            continue
            
        corr_df = anthro_results[anthro_var]
        
        print(f"\n" + "-"*80)
        print(f"{anthro_var} GROUP ANALYSIS")
        print("-"*80)
        
        # Create groups based on the anthropometric variable
        if anthro_var == 'HEIGHT':
            # Height groups (in cm)
            anthro_bins = [140, 160, 170, 180, 200]
            anthro_labels = ['<160cm', '160-169cm', '170-179cm', '≥180cm']
            group_name = 'Height Group'
        elif anthro_var == 'WEIGHT':
            # Weight groups (in kg)
            anthro_bins = [40, 60, 75, 90, 150]
            anthro_labels = ['<60kg', '60-74kg', '75-89kg', '≥90kg']
            group_name = 'Weight Group'
        elif anthro_var == 'BMI':
            # BMI categories (WHO classification)
            #anthro_bins = [10, 18.5, 25, 30, 50]
            #anthro_labels = ['Underweight', 'Normal', 'Overweight', 'Obese']
            anthro_bins = [18.5, 25, 30, 50]
            anthro_labels = ['Normal', 'Overweight', 'Obese']
            group_name = 'BMI Category'
        else:
            continue
        
        # Create group column
        df[f'{anthro_var}_group'] = pd.cut(df[anthro_var], bins=anthro_bins, 
                                            labels=anthro_labels, right=False)
        
        # Display group distribution
        group_counts = df[f'{anthro_var}_group'].value_counts().sort_index()
        print(f"\n{group_name} distribution:")
        for group, count in group_counts.items():
            print(f"  {group}: {count} subjects ({100*count/len(df[df[anthro_var].notna()]):.1f}%)")
        
        # ====================================================================
        # Statistical Testing Between Groups
        # ====================================================================
        print(f"\nStatistical testing between {anthro_var} groups...")
        
        anova_results = []
        
        for feature in ALL_FEATURES:
            group_data = []
            for group in anthro_labels:
                data = df[df[f'{anthro_var}_group'] == group][feature].dropna().values
                if len(data) >= 5:
                    group_data.append(data)
            
            if len(group_data) < 2:
                continue
            
            # Perform ANOVA
            f_stat, anova_p = f_oneway(*group_data)
            
            # Perform Kruskal-Wallis
            h_stat, kw_p = kruskal(*group_data)
            
            # Calculate effect size (eta-squared)
            grand_mean = np.mean(np.concatenate(group_data))
            ss_between = sum(len(d) * (np.mean(d) - grand_mean)**2 for d in group_data)
            ss_total = sum(np.sum((d - grand_mean)**2) for d in group_data)
            eta_squared = ss_between / ss_total if ss_total > 0 else 0
            
            anova_results.append({
                'feature': feature,
                'category': feature_categories[feature],
                'f_statistic': f_stat,
                'anova_pvalue': anova_p,
                'kruskal_h': h_stat,
                'kw_pvalue': kw_p,
                'eta_squared': eta_squared
            })
        
        anova_df = pd.DataFrame(anova_results)
        
        if len(anova_df) > 0:
            # Multiple testing correction
            anova_df['anova_p_fdr'] = multipletests(anova_df['anova_pvalue'], method='fdr_bh')[1]
            anova_df['kw_p_fdr'] = multipletests(anova_df['kw_pvalue'], method='fdr_bh')[1]
            anova_df['significant_anova'] = anova_df['anova_p_fdr'] < 0.05
            anova_df['significant_kw'] = anova_df['kw_p_fdr'] < 0.05
            
            anova_df = anova_df.sort_values('eta_squared', ascending=False)
            
            n_sig = anova_df['significant_anova'].sum()
            
            print(f"\nANOVA results:")
            print(f"  Features tested: {len(anova_df)}")
            print(f"  Significant differences (ANOVA, FDR<0.05): {n_sig} ({100*n_sig/len(anova_df):.1f}%)")
            print(f"  Significant differences (Kruskal-Wallis, FDR<0.05): {anova_df['significant_kw'].sum()}")
            
            print(f"\nTop 10 features with largest {anthro_var} group effects:")
            display(anova_df.head(10)[['feature', 'category', 'f_statistic', 'eta_squared', 
                             'anova_p_fdr', 'kw_p_fdr']].round(4))
            
            # Save results
            anova_df.to_csv(TABLES_DIR / f'{anthro_var.lower()}_group_anova.csv', index=False)
            print(f"\n✓ {anthro_var} group ANOVA results saved to {TABLES_DIR / f'{anthro_var.lower()}_group_anova.csv'}")
            
            # ================================================================
            # VISUALIZATION
            # ================================================================
            
            if n_sig > 0:
                top_features = anova_df[anova_df['significant_anova']].head(6)['feature'].tolist()
                
                if len(top_features) > 0:
                    print(f"\nVisualizing top {len(top_features)} {anthro_var}-related features...")
                    
                    # Define consistent color mapping
                    unique_groups = anthro_labels
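                    if anthro_var == 'BMI':
                        # Use meaningful colors for BMI, keyed by label so that each category
                        # keeps its color even when 'Underweight' is not among the groups
                        bmi_palette = {'Underweight': '#3498db', 'Normal': '#2ecc71',
                                       'Overweight': '#f39c12', 'Obese': '#e74c3c'}
                        group_colors = {grp: bmi_palette[grp] for grp in unique_groups}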
                    elif anthro_var == 'HEIGHT':
                        group_colors = dict(zip(unique_groups, 
                                              sns.color_palette('Blues', n_colors=len(unique_groups))))
                    else:  # WEIGHT
                        group_colors = dict(zip(unique_groups, 
                                              sns.color_palette('Oranges', n_colors=len(unique_groups))))
                    
                    n_plots = min(6, len(top_features))
                    fig, axes = plt.subplots(2, 3, figsize=(15, 10))
                    fig.suptitle(f'Top {group_name} Differences in Vessel Features',
                                fontsize=14, fontweight='bold')
                    axes = axes.flatten()
                    
                    for idx, feature in enumerate(top_features[:n_plots]):
                        ax = axes[idx]
                        
                        plot_data = df[[f'{anthro_var}_group', feature]].dropna()
                        
                        # Order groups naturally
                        group_order = anthro_labels
                        
                        # Create color palette in the correct order for this plot
                        plot_colors = [group_colors[grp] for grp in group_order]
                        
                        # Sample data if too large to prevent memory issues
                        if len(plot_data) > 500:
                            plot_data_sample = plot_data.sample(n=500, random_state=42)
                        else:
                            plot_data_sample = plot_data
                        
                        sns.boxplot(data=plot_data, x=f'{anthro_var}_group', y=feature, ax=ax,
                                   order=group_order, palette=plot_colors)
                        sns.stripplot(data=plot_data_sample, x=f'{anthro_var}_group', y=feature, ax=ax,
                                     order=group_order, color='black', alpha=0.2, size=2)
                        
                        feat_stats = anova_df[anova_df['feature'] == feature].iloc[0]
                        eta_sq = feat_stats['eta_squared']
                        p_val = feat_stats['anova_p_fdr']
                        
                        if p_val < 0.001:
                            sig_stars = '***'
                            p_text = 'p < 0.001'
                        elif p_val < 0.01:
                            sig_stars = '**'
                            p_text = f'p = {p_val:.3f}'
                        elif p_val < 0.05:
                            sig_stars = '*'
                            p_text = f'p = {p_val:.3f}'
                        else:
                            sig_stars = 'ns'
                            p_text = f'p = {p_val:.2f}'
                        
                        clean_feature = feature.replace('_', ' ').replace('total ', '').replace('mean ', '').title()
                        
                        ax.set_title(f"{clean_feature} {sig_stars}\nη² = {eta_sq:.3f}, {p_text}",
                                    fontsize=9)
                        ax.set_ylabel(feature.replace('_', ' ').title(), fontsize=9)
                        ax.set_xlabel(group_name, fontsize=9)
                        ax.tick_params(axis='x', rotation=45, labelsize=8)
                        ax.grid(True, alpha=0.3, axis='y')
                    
                    for idx in range(n_plots, 6):
                        axes[idx].axis('off')
                    
                    plt.tight_layout()
                    plt.savefig(FIGURES_DIR / f'{anthro_var.lower()}_groups_boxplots.png', 
                               dpi=150, bbox_inches='tight')
                    print(f"✓ {anthro_var} group boxplots saved")
                    plt.show()
                    plt.close('all')
    
    print("\n" + "="*80)
    print("✓ ANTHROPOMETRIC BOXPLOT ANALYSIS COMPLETE")
    print("="*80)
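
The eta-squared computation appears in several cells of this notebook. As a sketch, it could be factored into one function; `eta_squared` below is a hypothetical helper that mirrors the inline formula (between-group sum of squares divided by total sum of squares).

```python
import numpy as np

def eta_squared(groups):
    """Eta-squared = SS_between / SS_total for a list of 1-D samples."""
    groups = [np.asarray(g, dtype=float) for g in groups]
    all_values = np.concatenate(groups)
    grand_mean = all_values.mean()
    ss_between = sum(len(g) * (g.mean() - grand_mean) ** 2 for g in groups)
    ss_total = np.sum((all_values - grand_mean) ** 2)
    return ss_between / ss_total if ss_total > 0 else 0.0

# Two well-separated groups: SS_between = 13.5, SS_total = 17.5
eta_example = eta_squared([[1, 2, 3], [4, 5, 6]])
print(round(eta_example, 4))  # 0.7714
```

Values close to 1 indicate that group membership explains most of the variance; identical groups give 0.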

In [None]:
# ============================================================================
# AGE DISTRIBUTION ACROSS BMI CATEGORIES
# ============================================================================

if 'BMI' in df.columns and 'BMI_group' in df.columns and 'AGE' in df.columns:
    print("\n" + "="*80)
    print("AGE DISTRIBUTION ACROSS BMI CATEGORIES")
    print("="*80)
    
    # Set matplotlib parameters
    plt.rcParams['figure.max_open_warning'] = 0
    plt.rcParams['agg.path.chunksize'] = 10000
    
    # Get data
    plot_data = df[['BMI_group', 'AGE']].dropna()
    
    print(f"\nAnalyzing age distribution across BMI categories...")
    print(f"Total subjects with both BMI and age data: {len(plot_data)}")
    
    # Calculate summary statistics
    bmi_age_stats = plot_data.groupby('BMI_group', observed=True)['AGE'].agg(['count', 'mean', 'std', 'min', 'max'])
    print("\nAge statistics by BMI category:")
    display(bmi_age_stats.round(2))
    
    # Statistical test - ANOVA for age differences across BMI groups
    bmi_categories = ['Underweight', 'Normal', 'Overweight', 'Obese']
    age_by_bmi = [plot_data[plot_data['BMI_group'] == cat]['AGE'].dropna().values 
                  for cat in bmi_categories]
    age_by_bmi = [d for d in age_by_bmi if len(d) >= 5]
    
    if len(age_by_bmi) >= 2:
        f_stat, anova_p = f_oneway(*age_by_bmi)
        h_stat, kw_p = kruskal(*age_by_bmi)
        
        # Calculate effect size
        grand_mean = np.mean(np.concatenate(age_by_bmi))
        ss_between = sum(len(d) * (np.mean(d) - grand_mean)**2 for d in age_by_bmi)
        ss_total = sum(np.sum((d - grand_mean)**2) for d in age_by_bmi)
        eta_squared = ss_between / ss_total if ss_total > 0 else 0
        
        print(f"\nStatistical tests:")
        print(f"  ANOVA: F = {f_stat:.3f}, p = {anova_p:.4f}")
        print(f"  Kruskal-Wallis: H = {h_stat:.3f}, p = {kw_p:.4f}")
        print(f"  Effect size (η²): {eta_squared:.3f}")
        
        if anova_p < 0.05:
            print(f"  → Significant age differences across BMI categories")
        else:
            print(f"  → No significant age differences across BMI categories")
    
    # Visualization
    print("\nGenerating age distribution visualizations...")
    
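    # Define BMI colors (same as before) and keep only the categories present in the data:
    # the current BMI bins start at 18.5, so 'Underweight' is never produced, and an
    # empty group would leave a blank slot in the box plot and can break the violin plot
    bmi_colors = {'Underweight': '#3498db', 'Normal': '#2ecc71',
                  'Overweight': '#f39c12', 'Obese': '#e74c3c'}
    bmi_labels = [cat for cat in bmi_colors if (plot_data['BMI_group'] == cat).any()]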
    
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))
    fig.suptitle('Age Distribution Across BMI Categories', fontsize=14, fontweight='bold')
    
    # Plot 1: Box plot
    ax = axes[0]
    plot_colors = [bmi_colors[cat] for cat in bmi_labels]
    
    # Sample if too many points
    if len(plot_data) > 500:
        plot_data_sample = plot_data.sample(n=500, random_state=42)
    else:
        plot_data_sample = plot_data
    
    sns.boxplot(data=plot_data, x='BMI_group', y='AGE', ax=ax,
               order=bmi_labels, palette=plot_colors)
    sns.stripplot(data=plot_data_sample, x='BMI_group', y='AGE', ax=ax,
                 order=bmi_labels, color='black', alpha=0.2, size=3)
    
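    # Add significance annotation if significant; test the group count rather than
    # locals(), because an earlier cell also assigns a variable named anova_p
    if len(age_by_bmi) >= 2 and anova_p < 0.05: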
        if anova_p < 0.001:
            sig_text = '***'
        elif anova_p < 0.01:
            sig_text = '**'
        elif anova_p < 0.05:
            sig_text = '*'
        ax.text(0.5, 0.98, f'ANOVA: p = {anova_p:.4f} {sig_text}', 
               ha='center', va='top', transform=ax.transAxes,
               bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))
    
    ax.set_xlabel('BMI Category', fontsize=11)
    ax.set_ylabel('Age (years)', fontsize=11)
    ax.set_title('Age by BMI Category', fontsize=12, fontweight='bold')
    ax.grid(True, alpha=0.3, axis='y')
    ax.tick_params(axis='x', rotation=45)
    
    # Plot 2: Violin plot with quartiles
    ax = axes[1]
    
    parts = ax.violinplot([plot_data[plot_data['BMI_group'] == cat]['AGE'].dropna().values 
                           for cat in bmi_labels],
                          positions=range(len(bmi_labels)),
                          showmeans=True, showmedians=True, showextrema=True)
    
    # Color the violins
    for i, pc in enumerate(parts['bodies']):
        pc.set_facecolor(plot_colors[i])
        pc.set_alpha(0.7)
        pc.set_edgecolor('black')
        pc.set_linewidth(1)
    
    # Style the lines
    for partname in ('cbars', 'cmins', 'cmaxes', 'cmedians', 'cmeans'):
        if partname in parts:
            vp = parts[partname]
            vp.set_edgecolor('black')
            vp.set_linewidth(1.5)
    
    ax.set_xticks(range(len(bmi_labels)))
    ax.set_xticklabels(bmi_labels, rotation=45)
    ax.set_xlabel('BMI Category', fontsize=11)
    ax.set_ylabel('Age (years)', fontsize=11)
    ax.set_title('Age Distribution (Violin Plot)', fontsize=12, fontweight='bold')
    ax.grid(True, alpha=0.3, axis='y')
    
    plt.tight_layout()
    plt.savefig(FIGURES_DIR / 'age_by_bmi_categories.png', dpi=150, bbox_inches='tight')
    print(f"✓ Age by BMI visualization saved to {FIGURES_DIR / 'age_by_bmi_categories.png'}")
    plt.show()
    plt.close('all')
    
    # Additional plot: Age × BMI scatter with category colors
    print("\nGenerating age-BMI scatter plot...")
    
    fig, ax = plt.subplots(figsize=(10, 6))
    
    scatter_data = df[['BMI', 'AGE', 'BMI_group']].dropna()
    
    for bmi_cat in bmi_labels:
        cat_data = scatter_data[scatter_data['BMI_group'] == bmi_cat]
        ax.scatter(cat_data['AGE'], cat_data['BMI'], 
                  label=bmi_cat, alpha=0.5, s=30, color=bmi_colors[bmi_cat])
    
    # Add BMI category threshold lines
    ax.axhline(18.5, color='gray', linestyle='--', linewidth=1, alpha=0.5)
    ax.axhline(25, color='gray', linestyle='--', linewidth=1, alpha=0.5)
    ax.axhline(30, color='gray', linestyle='--', linewidth=1, alpha=0.5)
    
    # Add text labels for thresholds
    ax.text(scatter_data['AGE'].max() * 0.98, 18.5, 'Underweight/Normal', 
           va='bottom', ha='right', fontsize=8, style='italic')
    ax.text(scatter_data['AGE'].max() * 0.98, 25, 'Normal/Overweight', 
           va='bottom', ha='right', fontsize=8, style='italic')
    ax.text(scatter_data['AGE'].max() * 0.98, 30, 'Overweight/Obese', 
           va='bottom', ha='right', fontsize=8, style='italic')
    
    # Calculate and plot trend line
    valid_data = scatter_data[['AGE', 'BMI']].dropna()
    if len(valid_data) > 10:
        z = np.polyfit(valid_data['AGE'], valid_data['BMI'], 1)
        p = np.poly1d(z)
        age_range = np.linspace(valid_data['AGE'].min(), valid_data['AGE'].max(), 100)
        ax.plot(age_range, p(age_range), 'k--', linewidth=2, alpha=0.8, label='Trend')
        
        # Calculate correlation
        r_corr, p_corr = pearsonr(valid_data['AGE'], valid_data['BMI'])
        ax.text(0.02, 0.98, f'r = {r_corr:.3f}, p = {p_corr:.4f}',
               transform=ax.transAxes, va='top', ha='left',
               bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))
    
    ax.set_xlabel('Age (years)', fontsize=11)
    ax.set_ylabel('BMI (kg/m²)', fontsize=11)
    ax.set_title('Age vs BMI with Category Classification', fontsize=13, fontweight='bold')
    ax.legend(loc='upper right', framealpha=0.9)
    ax.grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.savefig(FIGURES_DIR / 'age_bmi_scatter.png', dpi=150, bbox_inches='tight')
    print(f"✓ Age-BMI scatter plot saved to {FIGURES_DIR / 'age_bmi_scatter.png'}")
    plt.show()
    plt.close('all')
    
    # Save summary statistics
    bmi_age_stats.to_csv(TABLES_DIR / 'age_by_bmi_statistics.csv')
    print(f"\n✓ Age by BMI statistics saved to {TABLES_DIR / 'age_by_bmi_statistics.csv'}")
    
    print("\n" + "="*80)
    print("✓ AGE-BMI ANALYSIS COMPLETE")
    print("="*80)
    
else:
    print("\n⚠️  Skipping age-BMI analysis - BMI groups or age data not available")
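
The BMI grouping depends on how `pd.cut` treats the bin edges. The sketch below uses the same bins and labels as the BMI branch above with made-up BMI values, to show which values fall outside the bins and therefore become missing.

```python
import pandas as pd

# Illustrative BMI values, including both bin boundaries
bmi = pd.Series([17.0, 18.5, 24.9, 25.0, 29.9, 30.0, 50.0])

# Same bins and labels as the BMI branch of the grouping code
groups = pd.cut(bmi, bins=[18.5, 25, 30, 50],
                labels=['Normal', 'Overweight', 'Obese'], right=False)

print(pd.DataFrame({'BMI': bmi, 'group': groups}))
print("Unassigned:", bmi[groups.isna()].tolist())  # [17.0, 50.0]
```

With `right=False` each bin is closed on the left, so 25.0 counts as Overweight and 30.0 as Obese, while values below 18.5 or at and above 50 receive no category and drop out of the group analyses.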

---
## 9. Multi-Center Analysis

In [None]:
df

In [None]:
# Try to identify site/center from subject ID
# IXI naming convention: IXIXXX where XXX is a number
# Different centers may have different ID ranges


# Build the lookup dictionary once
site_lookup = create_site_lookup(COMPLETE_IXI)

# Apply the lookup (vectorized operation)
df['site'] = df['subject_id'].map(site_lookup).fillna('Unknown')

# Remove 'Unknown' site entries (copy to avoid SettingWithCopyWarning on later column assignments)
df = df[df['site'] != 'Unknown'].copy()
site_counts = df['site'].value_counts()
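
`create_site_lookup` is defined elsewhere in the notebook and is not shown here. Purely as an illustration of the lookup-and-map pattern, the sketch below assumes file names of the form `IXI<number>-<site>-<scan id>-...`; the helper `build_site_lookup`, the regular expression, and the example names are all hypothetical.

```python
import re
import pandas as pd

def build_site_lookup(filenames):
    """Hypothetical helper: map 'IXI<digits>' to the site token in each file name."""
    pattern = re.compile(r"^(IXI\d+)-([A-Za-z]+)-")
    lookup = {}
    for name in filenames:
        match = pattern.match(name)
        if match:
            lookup[match.group(1)] = match.group(2)
    return lookup

# Illustrative file names; the last one does not follow the assumed pattern
names = ["IXI002-Guys-0828-T1.nii.gz", "IXI012-HH-1211-T1.nii.gz", "notes.txt"]
lookup = build_site_lookup(names)

# Same pattern as the cell above: map IDs, then label unmatched IDs
ids = pd.Series(["IXI002", "IXI012", "IXI999"])
sites = ids.map(lookup).fillna("Unknown")
print(sites.tolist())  # ['Guys', 'HH', 'Unknown']
```

Building the dictionary once and applying it with `Series.map` avoids a per-row search, and IDs without a match are easy to count before they are dropped.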

In [None]:
df

In [None]:
print("\n" + "="*80)
print("MULTI-CENTER ANALYSIS")
print("="*80)
print(f"\nSite distribution:")
for site, count in site_counts.items():
    print(f"  {site}: {count} subjects ({100*count/len(df):.1f}%)")


if len(site_counts) > 1:
    # Perform ANOVA for each feature across sites
    site_comparisons = []
    
    for feature in ALL_FEATURES:
        # Get data for each site
        site_data = [df[df['site'] == site][feature].dropna() for site in site_counts.index]
        
        # Remove groups with <5 samples
        site_data = [d for d in site_data if len(d) >= 5]
        
        if len(site_data) < 2:
            continue
        
        # ANOVA and Kruskal-Wallis test
        f_stat, f_pval = f_oneway(*site_data)
        h_stat, h_pval = kruskal(*site_data)
        
        # Effect size (eta-squared)
        grand_mean = np.mean(np.concatenate(site_data))
        ss_between = sum(len(d) * (np.mean(d) - grand_mean)**2 for d in site_data)
        ss_total = sum(np.sum((d - grand_mean)**2) for d in site_data)
        eta_squared = ss_between / ss_total if ss_total > 0 else 0
        
        site_comparisons.append({
            'feature': feature,
            'category': feature_categories[feature],
            'f_statistic': f_stat,
            'anova_pvalue': f_pval,
            'kw_pvalue': h_pval,
            'eta_squared': eta_squared
        })
    
    site_comp_df = pd.DataFrame(site_comparisons)
    
    if len(site_comp_df) > 0:
        # Multiple testing correction
        site_comp_df['anova_pvalue_fdr'] = multipletests(site_comp_df['anova_pvalue'], method='fdr_bh')[1]
        site_comp_df['significant'] = site_comp_df['anova_pvalue_fdr'] < 0.05
        site_comp_df = site_comp_df.sort_values('eta_squared', ascending=False)
        
        n_sig = site_comp_df['significant'].sum()
        print(f"\nMulti-center comparison:")
        print(f"  Features tested: {len(site_comp_df)}")
        print(f"  Significant site differences (FDR<0.05): {n_sig} ({100*n_sig/len(site_comp_df):.1f}%)")
        
        print(f"\nTop 10 features with site differences (by eta-squared):")
        display(site_comp_df.head(10)[['feature', 'category', 'eta_squared', 'anova_pvalue_fdr']])
        
        site_comp_df.to_csv(TABLES_DIR / 'site_comparisons.csv', index=False)
        print(f"\n✓ Site comparisons saved to {TABLES_DIR / 'site_comparisons.csv'}")
else:
    print("\n⚠️  Insufficient data for multi-center analysis")

In [None]:
df

In [None]:
import scipy.stats as stats
from scipy.stats import f_oneway, kruskal, levene, shapiro
from statsmodels.stats.multicomp import pairwise_tukeyhsd
from statsmodels.formula.api import ols
from statsmodels.stats.anova import anova_lm
import statsmodels.api as sm

# ============================================================================
# COMPREHENSIVE MULTI-CENTER ANALYSIS
# ============================================================================
# Add site information
print("=" * 80)
print("COMPREHENSIVE MULTI-CENTER ANALYSIS")
print("=" * 80)


In [None]:
# ============================================================================
# 1. DEMOGRAPHIC CHARACTERISTICS BY SITE
# ============================================================================

print("\n" + "=" * 80)
print("1. DEMOGRAPHIC CHARACTERISTICS BY SITE")
print("=" * 80)

site_counts = df['site'].value_counts()
print(f"\nTotal subjects across sites: {len(df)}")
print(f"\nSite distribution:")
for site, count in site_counts.items():
    print(f"  {site}: {count} subjects ({100*count/len(df):.1f}%)")

# Demographic comparison by site
sex_col = 'SEX_ID (1=m, 2=f)' if 'SEX_ID (1=m, 2=f)' in df.columns else 'SEX_ID'

demo_by_site = df.groupby('site').agg({
    'subject_id': 'count',
    'AGE': ['mean', 'std', 'min', 'max'],
    sex_col: lambda x: (x == 1).sum(),  # Count males
    'HEIGHT': lambda x: x[x > 0].mean() if (x > 0).any() else np.nan,
    'WEIGHT': lambda x: x[x > 0].mean() if (x > 0).any() else np.nan,
    'BMI': lambda x: x[x > 0].mean() if (x > 0).any() else np.nan
})

demo_by_site.columns = ['N', 'Age_Mean', 'Age_SD', 'Age_Min', 'Age_Max', 
                        'N_Males', 'Height_Mean', 'Weight_Mean', 'BMI_Mean']
demo_by_site['Pct_Male'] = 100 * demo_by_site['N_Males'] / demo_by_site['N']

print("\nDemographic characteristics by site:")
display(demo_by_site.round(2))

# Test for demographic differences across sites
print("\n" + "-" * 80)
print("Testing for demographic differences across sites:")
print("-" * 80)

# Age differences
age_by_site = [df[df['site'] == site]['AGE'].dropna() for site in site_counts.index]
f_age, p_age = f_oneway(*age_by_site)
print(f"\nAge difference across sites:")
print(f"  ANOVA F-statistic: {f_age:.3f}, p-value: {p_age:.4f}")

# Sex distribution differences (Chi-square)
sex_site_table = pd.crosstab(df['site'], df[sex_col])
chi2, p_chi2, dof, expected = stats.chi2_contingency(sex_site_table)
print(f"\nSex distribution across sites:")
print(f"  Chi-square: {chi2:.3f}, p-value: {p_chi2:.4f}")

In [None]:
# ============================================================================
# 2. SITE EFFECTS ON VESSEL FEATURES
# ============================================================================

print("\n" + "=" * 80)
print("2. SITE EFFECTS ON VESSEL FEATURES")
print("=" * 80)

site_comparisons = []

for feature in ALL_FEATURES:
    # Get data for each site
    site_data = [df[df['site'] == site][feature].dropna() for site in site_counts.index]
    site_data = [d for d in site_data if len(d) >= 5]
    
    if len(site_data) < 2:
        continue
    
    # One-way ANOVA
    f_stat, f_pval = f_oneway(*site_data)
    
    # Kruskal-Wallis (non-parametric alternative)
    h_stat, h_pval = kruskal(*site_data)
    
    # Levene's test for homogeneity of variance
    levene_stat, levene_pval = levene(*site_data)
    
    # Effect size (eta-squared)
    grand_mean = np.mean(np.concatenate(site_data))
    ss_between = sum(len(d) * (np.mean(d) - grand_mean)**2 for d in site_data)
    ss_total = sum(np.sum((d - grand_mean)**2) for d in site_data)
    eta_squared = ss_between / ss_total if ss_total > 0 else 0
    
    # Omega-squared (less biased than eta-squared):
    # ω² = (SS_between - (k - 1) * MS_within) / (SS_total + MS_within)
    n_total = sum(len(d) for d in site_data)
    k_groups = len(site_data)
    ms_within = (ss_total - ss_between) / (n_total - k_groups)
    omega_denom = ss_total + ms_within
    omega_squared = (ss_between - (k_groups - 1) * ms_within) / omega_denom if omega_denom > 0 else 0
    omega_squared = max(0, omega_squared)  # Negative estimates are truncated to 0
    
    # Calculate mean and SD for each site
    site_means = {site: df[df['site'] == site][feature].mean() 
                  for site in site_counts.index}
    site_stds = {site: df[df['site'] == site][feature].std() 
                 for site in site_counts.index}
    
    site_comparisons.append({
        'feature': feature,
        'category': feature_categories[feature],
        'f_statistic': f_stat,
        'anova_pvalue': f_pval,
        'kw_statistic': h_stat,
        'kw_pvalue': h_pval,
        'levene_pvalue': levene_pval,
        'eta_squared': eta_squared,
        'omega_squared': omega_squared,
        **{f'{site}_mean': site_means[site] for site in site_counts.index},
        **{f'{site}_std': site_stds[site] for site in site_counts.index}
    })

site_comp_df = pd.DataFrame(site_comparisons)

if len(site_comp_df) > 0:
    # Multiple testing correction
    site_comp_df['anova_pvalue_fdr'] = multipletests(site_comp_df['anova_pvalue'], method='fdr_bh')[1]
    site_comp_df['kw_pvalue_fdr'] = multipletests(site_comp_df['kw_pvalue'], method='fdr_bh')[1]
    site_comp_df['significant_anova'] = site_comp_df['anova_pvalue_fdr'] < 0.05
    site_comp_df['significant_kw'] = site_comp_df['kw_pvalue_fdr'] < 0.05
    site_comp_df = site_comp_df.sort_values('omega_squared', ascending=False)
    
    n_sig_anova = site_comp_df['significant_anova'].sum()
    n_sig_kw = site_comp_df['significant_kw'].sum()
    
    print(f"\nFeatures tested: {len(site_comp_df)}")
    print(f"Significant site differences (ANOVA, FDR<0.05): {n_sig_anova} ({100*n_sig_anova/len(site_comp_df):.1f}%)")
    print(f"Significant site differences (Kruskal-Wallis, FDR<0.05): {n_sig_kw} ({100*n_sig_kw/len(site_comp_df):.1f}%)")
    
    print(f"\nTop 15 features with largest site effects (by ω²):")
    display_cols = ['feature', 'category', 'omega_squared', 'eta_squared', 
                    'anova_pvalue_fdr', 'kw_pvalue_fdr']
    display(site_comp_df.head(15)[display_cols].round(4))
    
    # Save detailed results
    site_comp_df.to_csv(TABLES_DIR / 'site_comparisons_detailed.csv', index=False)
    print(f"\n✓ Detailed site comparisons saved")
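
The effect-size arithmetic inside the loop above can be isolated into a small helper, which makes it easier to check against a hand calculation. This is a minimal sketch: the function name `anova_effect_sizes` and the toy groups are assumptions for illustration, and the formulas follow the one-way ANOVA definitions used in this cell.

```python
import numpy as np

def anova_effect_sizes(groups):
    """Return (eta_squared, omega_squared) for a list of 1-D samples (illustrative helper)."""
    groups = [np.asarray(g, dtype=float) for g in groups]
    all_values = np.concatenate(groups)
    grand_mean = all_values.mean()
    n_total, k = len(all_values), len(groups)

    ss_between = sum(len(g) * (g.mean() - grand_mean) ** 2 for g in groups)
    ss_total = ((all_values - grand_mean) ** 2).sum()
    ms_within = (ss_total - ss_between) / (n_total - k)

    eta_sq = ss_between / ss_total if ss_total > 0 else 0.0
    denom = ss_total + ms_within
    omega_sq = (ss_between - (k - 1) * ms_within) / denom if denom > 0 else 0.0
    return eta_sq, max(0.0, omega_sq)  # negative ω² estimates are truncated to 0

# Toy data: three groups of three made-up observations each
toy_groups = [[1, 2, 3], [2, 3, 4], [5, 6, 7]]
eta_sq, omega_sq = anova_effect_sizes(toy_groups)
print(f"eta² = {eta_sq:.4f}, omega² = {omega_sq:.4f}")
```

For these toy groups SS_between is 26 and SS_within is 6, so η² is 26/32 and ω² is 24/33. The gap between the two shows why ω² is the more conservative choice for ranking features.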


In [None]:
# ============================================================================
# 3. POST-HOC PAIRWISE COMPARISONS (for significant features)
# ============================================================================

print("\n" + "=" * 80)
print("3. POST-HOC PAIRWISE COMPARISONS")
print("=" * 80)

# Select the top 5 significant features (site_comp_df is sorted by ω², largest first)
if len(site_comp_df) > 0:
    top_site_features = site_comp_df[site_comp_df['significant_anova']].head(5)['feature'].tolist()
else:
    top_site_features = []

if len(top_site_features) > 0:
    print(f"\nPerforming Tukey HSD post-hoc tests for top {len(top_site_features)} features...")
    
    posthoc_results = []
    
    for feature in top_site_features:
        # Prepare data
        data_for_tukey = df[['site', feature]].dropna()
        
        # Tukey HSD
        tukey = pairwise_tukeyhsd(endog=data_for_tukey[feature], 
                                   groups=data_for_tukey['site'], 
                                   alpha=0.05)
        
        # Parse results
        tukey_df = pd.DataFrame(data=tukey.summary().data[1:], 
                               columns=tukey.summary().data[0])
        tukey_df['feature'] = feature
        posthoc_results.append(tukey_df)
        
        print(f"\n{feature}:")
        print(tukey)
    
    # Combine all post-hoc results
    posthoc_df = pd.concat(posthoc_results, ignore_index=True)
    posthoc_df.to_csv(TABLES_DIR / 'site_posthoc_tukey.csv', index=False)
    print(f"\n✓ Post-hoc comparisons saved")
else:
    print("\nNo significant site differences found for post-hoc testing")

In [None]:
# ============================================================================
# 4. SITE AS COVARIATE IN AGE-VESSEL RELATIONSHIPS
# ============================================================================

print("\n" + "=" * 80)
print("4. SITE AS COVARIATE IN AGE-VESSEL RELATIONSHIPS")
print("=" * 80)

print("\nTesting whether site affects age-vessel correlations...")

# Test interaction between age and site for top features
site_age_interactions = []

for feature in ALL_FEATURES[:20]:  # First 20 features only, to limit computation
    # Prepare data
    model_data = df[['AGE', 'site', sex_col, feature]].dropna()
    
    if len(model_data) < 30:
        continue
    
    # Model 1: Age + Sex (baseline)
    formula1 = f'Q("{feature}") ~ AGE + C(Q("{sex_col}"))'
    model1 = ols(formula1, data=model_data).fit()
    
    # Model 2: Age + Sex + Site (main effect)
    formula2 = f'Q("{feature}") ~ AGE + C(Q("{sex_col}")) + C(site)'
    model2 = ols(formula2, data=model_data).fit()
    
    # Model 3: Age + Sex + Site + Age×Site (interaction)
    formula3 = f'Q("{feature}") ~ AGE * C(site) + C(Q("{sex_col}"))'
    model3 = ols(formula3, data=model_data).fit()
    
    # F-test for site main effect (model1 is nested in model2)
    df_diff_site = model1.df_resid - model2.df_resid
    f_site = ((model1.ssr - model2.ssr) / df_diff_site) / (model2.ssr / model2.df_resid)
    p_site = stats.f.sf(f_site, df_diff_site, model2.df_resid)
    
    # F-test for age×site interaction (model2 is nested in model3)
    df_diff_inter = model2.df_resid - model3.df_resid
    f_interaction = ((model2.ssr - model3.ssr) / df_diff_inter) / (model3.ssr / model3.df_resid)
    p_interaction = stats.f.sf(f_interaction, df_diff_inter, model3.df_resid)
    
    site_age_interactions.append({
        'feature': feature,
        'category': feature_categories[feature],
        'r2_baseline': model1.rsquared,
        'r2_with_site': model2.rsquared,
        'r2_with_interaction': model3.rsquared,
        'delta_r2_site': model2.rsquared - model1.rsquared,
        'delta_r2_interaction': model3.rsquared - model2.rsquared,
        'p_site_effect': p_site,
        'p_age_site_interaction': p_interaction,
        'age_coef_model1': model1.params['AGE'],
        'age_coef_model2': model2.params['AGE']
    })

site_age_df = pd.DataFrame(site_age_interactions)

if len(site_age_df) > 0:
    # Multiple testing correction
    site_age_df['p_site_fdr'] = multipletests(site_age_df['p_site_effect'], method='fdr_bh')[1]
    site_age_df['p_interaction_fdr'] = multipletests(site_age_df['p_age_site_interaction'], method='fdr_bh')[1]
    site_age_df['significant_site'] = site_age_df['p_site_fdr'] < 0.05
    site_age_df['significant_interaction'] = site_age_df['p_interaction_fdr'] < 0.05
    
    site_age_df = site_age_df.sort_values('delta_r2_site', ascending=False)
    
    n_sig_site = site_age_df['significant_site'].sum()
    n_sig_interaction = site_age_df['significant_interaction'].sum()
    
    print(f"\nFeatures tested: {len(site_age_df)}")
    print(f"Features with significant site effect: {n_sig_site} ({100*n_sig_site/len(site_age_df):.1f}%)")
    print(f"Features with significant Age×Site interaction: {n_sig_interaction} ({100*n_sig_interaction/len(site_age_df):.1f}%)")
    
    print(f"\nTop 10 features where site explains additional variance:")
    display_cols = ['feature', 'category', 'r2_baseline', 'r2_with_site', 
                    'delta_r2_site', 'p_site_fdr']
    display(site_age_df.head(10)[display_cols].round(4))
    
    if n_sig_interaction > 0:
        print(f"\nFeatures with significant Age×Site interaction:")
        interaction_features = site_age_df[site_age_df['significant_interaction']]
        display(interaction_features[['feature', 'delta_r2_interaction', 'p_interaction_fdr']].round(4))
    
    site_age_df.to_csv(TABLES_DIR / 'site_age_interactions.csv', index=False)
    print(f"\n✓ Site-age interaction analysis saved")
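
The nested-model F-tests in this cell compare the residual sum of squares of a restricted model with that of a fuller model. The sketch below shows the arithmetic on its own, using made-up residual values. The helper name `nested_f_test` is hypothetical. The numerator degrees of freedom are the restricted model's residual degrees of freedom minus the full model's, which is always positive when the full model has more parameters.

```python
from scipy import stats

def nested_f_test(ssr_restricted, df_resid_restricted, ssr_full, df_resid_full):
    """F-test for a restricted model nested in a fuller one (illustrative helper)."""
    df_diff = df_resid_restricted - df_resid_full  # number of extra parameters
    f_value = ((ssr_restricted - ssr_full) / df_diff) / (ssr_full / df_resid_full)
    p_value = stats.f.sf(f_value, df_diff, df_resid_full)  # upper-tail probability
    return f_value, p_value, df_diff

# Made-up numbers: adding 2 site dummies lowers the residual SS from 100 to 80
f_toy, p_toy, df_toy = nested_f_test(100.0, 97, 80.0, 95)
print(f"F({df_toy}, 95) = {f_toy:.3f}, p = {p_toy:.2e}")
```

statsmodels OLS results also provide `compare_f_test`, called on the fuller model with the restricted model as its argument, which should return the same F statistic, p-value and degrees-of-freedom difference.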

In [None]:
# ============================================================================
# 5. VISUALIZATIONS
# ============================================================================

print("\n" + "=" * 80)
print("5. GENERATING VISUALIZATIONS")
print("=" * 80)

# 5.0: Site demographics overview
if 'site' in df.columns:
    print("\nGenerating site demographics visualizations...")
    
    fig, axes = plt.subplots(2, 2, figsize=(14, 10))
    fig.suptitle('Site Demographics Overview', fontsize=14, fontweight='bold')
    
    # 1. Site distribution
    ax = axes[0, 0]
    site_counts = df['site'].value_counts().sort_values(ascending=False)
    site_order = site_counts.index.tolist()  # Shared by all per-site panels below
    bars = ax.barh(range(len(site_counts)), site_counts.values, 
                  color='steelblue', alpha=0.7, edgecolor='black')
    ax.set_yticks(range(len(site_counts)))
    ax.set_yticklabels(site_counts.index, fontsize=9)
    ax.set_xlabel('Number of Subjects', fontsize=10)
    ax.set_title('Sample Size by Site', fontsize=11, fontweight='bold')
    ax.grid(axis='x', alpha=0.3)
    
    # Add counts on bars
    for i, v in enumerate(site_counts.values):
        ax.text(v + 1, i, f'n={int(v)}', va='center', fontsize=9)
    
    # 2. Age distribution by site
    ax = axes[0, 1]
    if 'AGE' in df.columns:
        plot_data = df[['site', 'AGE']].dropna()
        site_order = site_counts.index.tolist()
        sns.boxplot(data=plot_data, y='site', x='AGE', ax=ax,
                   order=site_order, palette='Set2')
        ax.set_xlabel('Age (years)', fontsize=10)
        ax.set_ylabel('')
        ax.set_title('Age Distribution by Site', fontsize=11, fontweight='bold')
        ax.grid(axis='x', alpha=0.3)
    else:
        ax.text(0.5, 0.5, 'Age data not available', ha='center', va='center',
               transform=ax.transAxes)
        ax.axis('off')
    
    # 3. Sex distribution by site
    ax = axes[1, 0]
    sex_col = 'SEX_ID (1=m, 2=f)' if 'SEX_ID (1=m, 2=f)' in df.columns else 'SEX_ID'
    if sex_col in df.columns:
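        sex_by_site = df.groupby('site')[sex_col].value_counts(normalize=True).unstack(fill_value=0)
        # Fix row and column order so colors and legend labels match (1 = male, 2 = female)
        sex_by_site = sex_by_site.reindex(index=site_counts.index, columns=[1, 2], fill_value=0)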
        sex_by_site.plot(kind='barh', stacked=True, ax=ax, 
                        color=['#ff7f0e', '#ff69b4'], alpha=0.8)
        ax.set_xlabel('Proportion', fontsize=10)
        ax.set_ylabel('')
        ax.set_title('Sex Distribution by Site', fontsize=11, fontweight='bold')
        ax.legend(['Male', 'Female'], fontsize=9)
        ax.grid(axis='x', alpha=0.3)
    else:
        ax.text(0.5, 0.5, 'Sex data not available', ha='center', va='center',
               transform=ax.transAxes)
        ax.axis('off')
    
    # 4. BMI by site (if available)
    ax = axes[1, 1]
    if 'BMI' in df.columns:
        plot_data_bmi = df[df['BMI'] > 0][['site', 'BMI']].dropna()
        if len(plot_data_bmi) > 0:
            sns.boxplot(data=plot_data_bmi, y='site', x='BMI', ax=ax,
                       order=site_order, palette='Set3')
            ax.axvline(25, color='orange', linestyle='--', linewidth=2, alpha=0.7, label='Overweight')
            ax.axvline(30, color='red', linestyle='--', linewidth=2, alpha=0.7, label='Obese')
            ax.set_xlabel('BMI (kg/m²)', fontsize=10)
            ax.set_ylabel('')
            ax.set_title('BMI Distribution by Site', fontsize=11, fontweight='bold')
            ax.legend(fontsize=8)
            ax.grid(axis='x', alpha=0.3)
        else:
            ax.text(0.5, 0.5, 'Insufficient BMI data', ha='center', va='center',
                   transform=ax.transAxes)
            ax.axis('off')
    else:
        ax.text(0.5, 0.5, 'BMI data not available', ha='center', va='center',
               transform=ax.transAxes)
        ax.axis('off')
    
    plt.tight_layout()
    plt.savefig(FIGURES_DIR / 'site_demographics.png', dpi=300, bbox_inches='tight')
    print("✓ Site demographics plot saved")
    plt.show()


In [None]:
# ============================================================================
# 5. VISUALIZATIONS (continued)
# ============================================================================

print("\n" + "=" * 80)
print("5. GENERATING VISUALIZATIONS (continued)")
print("=" * 80)

# 5.1: Site effect sizes visualization
if len(site_comp_df) > 0:
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))
    
    # Effect size distribution
    ax = axes[0]
    ax.hist(site_comp_df['omega_squared'], bins=30, edgecolor='black', alpha=0.7)
    ax.axvline(0.01, color='orange', linestyle='--', label='Small effect (0.01)')
    ax.axvline(0.06, color='red', linestyle='--', label='Medium effect (0.06)')
    ax.axvline(0.14, color='darkred', linestyle='--', label='Large effect (0.14)')
    ax.set_xlabel('Omega-squared (ω²)', fontsize=11)
    ax.set_ylabel('Number of Features', fontsize=11)
    ax.set_title('Distribution of Site Effect Sizes', fontsize=12, fontweight='bold')
    ax.legend()
    ax.grid(alpha=0.3, axis='y')
    
    # Volcano plot
    ax = axes[1]
    site_comp_df['-log10_p'] = -np.log10(site_comp_df['anova_pvalue_fdr'])
    colors = ['red' if sig else 'gray' for sig in site_comp_df['significant_anova']]
    ax.scatter(site_comp_df['omega_squared'], site_comp_df['-log10_p'], 
              c=colors, alpha=0.6, s=30)
    ax.axhline(-np.log10(0.05), color='blue', linestyle='--', label='FDR = 0.05')
    ax.axvline(0.06, color='orange', linestyle='--', label='Medium effect')
    ax.set_xlabel('Effect Size (ω²)', fontsize=11)
    ax.set_ylabel('-log₁₀(FDR-adjusted p-value)', fontsize=11)
    ax.set_title('Site Effect Volcano Plot', fontsize=12, fontweight='bold')
    ax.legend()
    ax.grid(alpha=0.3)
    
    plt.tight_layout()
    plt.savefig(FIGURES_DIR / 'site_effects_overview.png', dpi=300, bbox_inches='tight')
    print("✓ Site effects overview saved")
    plt.show()

# 5.2: Top site differences visualization
if len(top_site_features) > 0:
    n_features = min(6, len(top_site_features))
    fig, axes = plt.subplots(2, 3, figsize=(15, 10))
    fig.suptitle('Top Site Differences in Vessel Features', fontsize=14, fontweight='bold')
    axes = axes.flatten()
    
    for idx, feature in enumerate(top_site_features[:n_features]):
        ax = axes[idx]
        
        # Prepare data
        plot_data = df[['site', feature]].dropna()
        
        # Box plot
        sns.boxplot(data=plot_data, x='site', y=feature, ax=ax, palette='Set2')
        sns.stripplot(data=plot_data, x='site', y=feature, ax=ax, 
                     color='black', alpha=0.3, size=3)
        
        # Get stats
        feat_stats = site_comp_df[site_comp_df['feature'] == feature].iloc[0]
        omega_sq = feat_stats['omega_squared']
        p_val = feat_stats['anova_pvalue_fdr']
        
        # Significance stars
        if p_val < 0.001:
            sig_stars = '***'
            p_text = 'p < 0.001'
        elif p_val < 0.01:
            sig_stars = '**'
            p_text = f'p = {p_val:.3f}'
        elif p_val < 0.05:
            sig_stars = '*'
            p_text = f'p = {p_val:.3f}'
        else:
            sig_stars = 'ns'
            p_text = f'p = {p_val:.2f}'
        
        clean_feature = feature.replace('_', ' ').replace('total ', '').replace('mean ', '').title()
        
        ax.set_title(f"{clean_feature} {sig_stars}\nω² = {omega_sq:.3f}, {p_text}", 
                    fontsize=10)
        ax.set_ylabel(feature.replace('_', ' ').title(), fontsize=9)
        ax.set_xlabel('Site', fontsize=9)
        ax.grid(True, alpha=0.3, axis='y')
    
    # Hide unused subplots
    for idx in range(n_features, 6):
        axes[idx].axis('off')
    
    plt.tight_layout()
    plt.savefig(FIGURES_DIR / 'site_differences_top6.png', dpi=300, bbox_inches='tight')
    print("✓ Site differences visualization saved")
    plt.show()

# 5.3: Age-vessel relationships by site (if interactions found)
if len(site_age_df) > 0 and site_age_df['significant_interaction'].any():
    interaction_features = site_age_df[site_age_df['significant_interaction']].head(4)['feature'].tolist()
    
    if len(interaction_features) > 0:
        fig, axes = plt.subplots(2, 2, figsize=(12, 10))
        fig.suptitle('Age-Vessel Relationships by Site (Significant Interactions)', 
                    fontsize=14, fontweight='bold')
        axes = axes.flatten()
        
        for idx, feature in enumerate(interaction_features[:4]):
            ax = axes[idx]
            
            for site_idx, site in enumerate(site_counts.index):
                site_data = df[df['site'] == site]
                site_color = f'C{site_idx % 10}'  # Same color for points and fitted line
                ax.scatter(site_data['AGE'], site_data[feature], 
                          label=site, alpha=0.5, s=20, color=site_color)
                
                # Fit line for each site
                valid_data = site_data[['AGE', feature]].dropna()
                if len(valid_data) > 10:
                    z = np.polyfit(valid_data['AGE'], valid_data[feature], 1)
                    p = np.poly1d(z)
                    age_range = np.linspace(valid_data['AGE'].min(), valid_data['AGE'].max(), 100)
                    ax.plot(age_range, p(age_range), linewidth=2, color=site_color)
            
            clean_feature = feature.replace('_', ' ').replace('total ', '').replace('mean ', '').title()
            ax.set_xlabel('Age (years)', fontsize=10)
            ax.set_ylabel(clean_feature, fontsize=10)
            ax.set_title(clean_feature, fontsize=11, fontweight='bold')
            ax.legend(fontsize=8)
            ax.grid(alpha=0.3)
        
        plt.tight_layout()
        plt.savefig(FIGURES_DIR / 'age_site_interactions.png', dpi=300, bbox_inches='tight')
        print("✓ Age×Site interaction visualization saved")
        plt.show()
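
The significance labels in the box plot titles use the same thresholds in both the site and the ethnicity figures. One way to keep them consistent is a shared helper, sketched below. The name `format_significance` is hypothetical and the thresholds are copied from the plotting code in this notebook.

```python
def format_significance(p_val):
    """Return (stars, p_text) for an FDR-adjusted p-value (illustrative helper)."""
    if p_val < 0.001:
        return '***', 'p < 0.001'
    if p_val < 0.01:
        return '**', f'p = {p_val:.3f}'
    if p_val < 0.05:
        return '*', f'p = {p_val:.3f}'
    return 'ns', f'p = {p_val:.2f}'

# Example: build a plot title fragment from a made-up adjusted p-value
sig_stars, p_text = format_significance(0.02)
print(f"Example Feature {sig_stars} ({p_text})")
```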

In [None]:
print("\n" + "=" * 80)
print("MULTI-CENTER ANALYSIS COMPLETE")
print("=" * 80)
print("\nKey findings:")
if 'n_sig_anova' in locals():
    print(f"1. Site effects detected in {n_sig_anova}/{len(site_comp_df)} features")
if 'n_sig_site' in locals():
    print(f"2. Site explains additional variance in {n_sig_site}/{len(site_age_df)} age-vessel relationships")
if 'n_sig_interaction' in locals():
    print(f"3. Age×Site interactions found in {n_sig_interaction}/{len(site_age_df)} features")
print("\nRecommendations for your paper:")
print("- Include site as covariate in all age-vessel analyses")
print("- Report site effects in supplementary materials")
print("- Discuss multi-center design as strength (generalizability)")

---
## 9.2 Ethnicity Analysis

In [None]:
# ============================================================================
# ETHNICITY ANALYSIS
# ============================================================================

print("=" * 80)
print("ETHNICITY ANALYSIS")
print("=" * 80)

if 'ETHNIC_ID' in df.columns:
    # Remove missing values and check distribution
    df_ethnic = df[df['ETHNIC_ID'].notna()].copy()
    ethnic_counts = df_ethnic['ETHNIC_ID'].value_counts().sort_index()
    
    print(f"\nEthnicity distribution:")
    print(f"  Total with ethnicity data: {len(df_ethnic)} ({100*len(df_ethnic)/len(df):.1f}% of cohort)")
    print(f"  Missing ethnicity data: {df['ETHNIC_ID'].isna().sum()} ({100*df['ETHNIC_ID'].isna().sum()/len(df):.1f}%)")
    
    print(f"\nEthnicity categories:")
    for ethnic_id, count in ethnic_counts.items():
        print(f"  Category {int(ethnic_id)}: {count} subjects ({100*count/len(df_ethnic):.1f}%)")
    
    # Check if we have sufficient data for analysis (at least 2 groups with n≥10)
    valid_groups = ethnic_counts[ethnic_counts >= 10]
    
    if len(valid_groups) >= 2:
        print(f"\n✓ Sufficient data for analysis: {len(valid_groups)} groups with n≥10")
        
        # Filter to valid groups only
        df_ethnic = df_ethnic[df_ethnic['ETHNIC_ID'].isin(valid_groups.index)].copy()
        
        # Check demographic differences by ethnicity
        print("\n" + "-" * 80)
        print("Demographic characteristics by ethnicity:")
        print("-" * 80)
        
        sex_col = 'SEX_ID (1=m, 2=f)' if 'SEX_ID (1=m, 2=f)' in df_ethnic.columns else 'SEX_ID'
        
        ethnic_demo = df_ethnic.groupby('ETHNIC_ID').agg({
            'subject_id': 'count',
            'AGE': ['mean', 'std', 'min', 'max'],
            sex_col: lambda x: (x == 1).sum(),
            'HEIGHT': lambda x: x[x > 0].mean() if (x > 0).any() else np.nan,
            'WEIGHT': lambda x: x[x > 0].mean() if (x > 0).any() else np.nan,
            'BMI': lambda x: x[x > 0].mean() if (x > 0).any() else np.nan
        })
        
        ethnic_demo.columns = ['N', 'Age_Mean', 'Age_SD', 'Age_Min', 'Age_Max',
                               'N_Males', 'Height_Mean', 'Weight_Mean', 'BMI_Mean']
        ethnic_demo['Pct_Male'] = 100 * ethnic_demo['N_Males'] / ethnic_demo['N']
        
        display(ethnic_demo.round(2))
        
        # Test for age differences
        age_by_ethnic = [df_ethnic[df_ethnic['ETHNIC_ID'] == eid]['AGE'].dropna() 
                        for eid in valid_groups.index]
        f_age, p_age = f_oneway(*age_by_ethnic)
        print(f"\nAge difference across ethnicities: F={f_age:.3f}, p={p_age:.4f}")
        
        # Ethnicity effects on vessel features
        print("\n" + "-" * 80)
        print("Ethnicity effects on vessel features:")
        print("-" * 80)
        
        ethnic_comparisons = []
        
        for feature in ALL_FEATURES:
            # Get data for each ethnicity
            ethnic_data = [df_ethnic[df_ethnic['ETHNIC_ID'] == eid][feature].dropna() 
                          for eid in valid_groups.index]
            ethnic_data = [d for d in ethnic_data if len(d) >= 5]
            
            if len(ethnic_data) < 2:
                continue
            
            # ANOVA and Kruskal-Wallis
            f_stat, f_pval = f_oneway(*ethnic_data)
            h_stat, h_pval = kruskal(*ethnic_data)
            
            # Levene's test
            levene_stat, levene_pval = levene(*ethnic_data)
            
            # Effect sizes
            grand_mean = np.mean(np.concatenate(ethnic_data))
            ss_between = sum(len(d) * (np.mean(d) - grand_mean)**2 for d in ethnic_data)
            ss_total = sum(np.sum((d - grand_mean)**2) for d in ethnic_data)
            eta_squared = ss_between / ss_total if ss_total > 0 else 0
            
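            # Omega-squared: (SS_between - (k - 1) * MS_within) / (SS_total + MS_within)
            n_total = sum(len(d) for d in ethnic_data)
            k_groups = len(ethnic_data)
            ms_within = (ss_total - ss_between) / (n_total - k_groups)
            omega_denom = ss_total + ms_within
            omega_squared = (ss_between - (k_groups - 1) * ms_within) / omega_denom if omega_denom > 0 else 0
            omega_squared = max(0, omega_squared)  # Negative estimates are truncated to 0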
            
            ethnic_comparisons.append({
                'feature': feature,
                'category': feature_categories[feature],
                'f_statistic': f_stat,
                'anova_pvalue': f_pval,
                'kw_statistic': h_stat,
                'kw_pvalue': h_pval,
                'levene_pvalue': levene_pval,
                'eta_squared': eta_squared,
                'omega_squared': omega_squared
            })
        
        ethnic_comp_df = pd.DataFrame(ethnic_comparisons)
        
        if len(ethnic_comp_df) > 0:
            # Multiple testing correction
            ethnic_comp_df['anova_pvalue_fdr'] = multipletests(ethnic_comp_df['anova_pvalue'], method='fdr_bh')[1]
            ethnic_comp_df['kw_pvalue_fdr'] = multipletests(ethnic_comp_df['kw_pvalue'], method='fdr_bh')[1]
            ethnic_comp_df['significant_anova'] = ethnic_comp_df['anova_pvalue_fdr'] < 0.05
            ethnic_comp_df['significant_kw'] = ethnic_comp_df['kw_pvalue_fdr'] < 0.05
            ethnic_comp_df = ethnic_comp_df.sort_values('omega_squared', ascending=False)
            
            n_sig_anova = ethnic_comp_df['significant_anova'].sum()
            n_sig_kw = ethnic_comp_df['significant_kw'].sum()
            
            print(f"\nFeatures tested: {len(ethnic_comp_df)}")
            print(f"Significant ethnicity differences (ANOVA, FDR<0.05): {n_sig_anova} ({100*n_sig_anova/len(ethnic_comp_df):.1f}%)")
            print(f"Significant ethnicity differences (Kruskal-Wallis, FDR<0.05): {n_sig_kw} ({100*n_sig_kw/len(ethnic_comp_df):.1f}%)")
            
            if n_sig_anova > 0:
                print(f"\nTop 10 features with ethnicity effects (by ω²):")
                display_cols = ['feature', 'category', 'omega_squared', 'eta_squared', 
                               'anova_pvalue_fdr', 'kw_pvalue_fdr']
                display(ethnic_comp_df.head(10)[display_cols].round(4))
            else:
                print("\n⚠️  No significant ethnicity effects detected after FDR correction")
            
            # Save results
            ethnic_comp_df.to_csv(TABLES_DIR / 'ethnicity_comparisons.csv', index=False)
            print(f"\n✓ Ethnicity comparisons saved")
            
            # Visualization if significant effects found
            if n_sig_anova > 0:
                top_ethnic_features = ethnic_comp_df[ethnic_comp_df['significant_anova']].head(6)['feature'].tolist()
                
                if len(top_ethnic_features) > 0:
                    n_plots = min(6, len(top_ethnic_features))
                    fig, axes = plt.subplots(2, 3, figsize=(15, 10))
                    fig.suptitle('Top Ethnicity Differences in Vessel Features', 
                                fontsize=14, fontweight='bold')
                    axes = axes.flatten()
                    
                    for idx, feature in enumerate(top_ethnic_features[:n_plots]):
                        ax = axes[idx]
                        
                        plot_data = df_ethnic[['ETHNIC_ID', feature]].dropna()
                        # Cast to int first so labels read '1', '2' rather than '1.0', '2.0'
                        plot_data['Ethnicity'] = plot_data['ETHNIC_ID'].astype(int).astype(str)
                        
                        sns.boxplot(data=plot_data, x='Ethnicity', y=feature, ax=ax, palette='Set3')
                        sns.stripplot(data=plot_data, x='Ethnicity', y=feature, ax=ax,
                                     color='black', alpha=0.3, size=3)
                        
                        feat_stats = ethnic_comp_df[ethnic_comp_df['feature'] == feature].iloc[0]
                        omega_sq = feat_stats['omega_squared']
                        p_val = feat_stats['anova_pvalue_fdr']
                        
                        if p_val < 0.001:
                            sig_stars = '***'
                            p_text = 'p < 0.001'
                        elif p_val < 0.01:
                            sig_stars = '**'
                            p_text = f'p = {p_val:.3f}'
                        elif p_val < 0.05:
                            sig_stars = '*'
                            p_text = f'p = {p_val:.3f}'
                        else:
                            sig_stars = 'ns'
                            p_text = f'p = {p_val:.2f}'
                        
                        clean_feature = feature.replace('_', ' ').replace('total ', '').replace('mean ', '').title()
                        
                        ax.set_title(f"{clean_feature} {sig_stars}\nω² = {omega_sq:.3f}, {p_text}", 
                                    fontsize=10)
                        ax.set_ylabel(feature.replace('_', ' ').title(), fontsize=9)
                        ax.set_xlabel('Ethnicity Category', fontsize=9)
                        ax.grid(True, alpha=0.3, axis='y')
                    
                    for idx in range(n_plots, 6):
                        axes[idx].axis('off')
                    
                    plt.tight_layout()
                    plt.savefig(FIGURES_DIR / 'ethnicity_differences_top6.png', dpi=300, bbox_inches='tight')
                    print("✓ Ethnicity differences visualization saved")
                    plt.show()
    else:
        print(f"\n⚠️  Insufficient data for analysis: only {len(valid_groups)} group(s) with n≥10")
        print("    Ethnicity analysis skipped")
else:
    print("\n⚠️  ETHNIC_ID column not found in dataset")

---
## 9.3 Marital Status Analysis


In [None]:
# ============================================================================
# MARITAL STATUS ANALYSIS
# ============================================================================

print("\n" + "=" * 80)
print("MARITAL STATUS ANALYSIS")
print("=" * 80)

if 'MARITAL_ID' in df.columns:
    df_marital = df[df['MARITAL_ID'].notna()].copy()
    marital_counts = df_marital['MARITAL_ID'].value_counts().sort_index()
    
    print(f"\nMarital status distribution:")
    print(f"  Total with marital data: {len(df_marital)} ({100*len(df_marital)/len(df):.1f}% of cohort)")
    print(f"  Missing marital data: {df['MARITAL_ID'].isna().sum()} ({100*df['MARITAL_ID'].isna().sum()/len(df):.1f}%)")
    
    # Map marital status codes (assumed coding; verify against your dataset documentation)
    marital_labels = {
        1: 'Single',
        2: 'Married/Partnership',
        3: 'Divorced/Separated',
        4: 'Widowed'
    }
    
    print(f"\nMarital status categories:")
    for marital_id, count in marital_counts.items():
        label = marital_labels.get(int(marital_id), f'Category {int(marital_id)}')
        print(f"  {label}: {count} subjects ({100*count/len(df_marital):.1f}%)")
    
    valid_groups = marital_counts[marital_counts >= 10]
    
    if len(valid_groups) >= 2:
        print(f"\n✓ Sufficient data for analysis: {len(valid_groups)} groups with n≥10")
        
        df_marital = df_marital[df_marital['MARITAL_ID'].isin(valid_groups.index)].copy()
        df_marital['MARITAL_LABEL'] = df_marital['MARITAL_ID'].map(marital_labels)
        
        # Demographic characteristics
        print("\n" + "-" * 80)
        print("Demographic characteristics by marital status:")
        print("-" * 80)
        
        sex_col = 'SEX_ID (1=m, 2=f)' if 'SEX_ID (1=m, 2=f)' in df_marital.columns else 'SEX_ID'
        
        marital_demo = df_marital.groupby('MARITAL_LABEL').agg({
            'subject_id': 'count',
            'AGE': ['mean', 'std', 'min', 'max'],
            sex_col: lambda x: (x == 1).sum(),
            'HEIGHT': lambda x: x[x > 0].mean() if (x > 0).any() else np.nan,
            'WEIGHT': lambda x: x[x > 0].mean() if (x > 0).any() else np.nan,
            'BMI': lambda x: x[x > 0].mean() if (x > 0).any() else np.nan
        })
        
        marital_demo.columns = ['N', 'Age_Mean', 'Age_SD', 'Age_Min', 'Age_Max',
                                'N_Males', 'Height_Mean', 'Weight_Mean', 'BMI_Mean']
        marital_demo['Pct_Male'] = 100 * marital_demo['N_Males'] / marital_demo['N']
        
        display(marital_demo.round(2))
        
        # Note: Marital status is strongly confounded with age
        print("\n⚠️  NOTE: Marital status is highly correlated with age")
        print("    Consider age-adjusted analyses or stratified analyses")
        
        # Age differences
        age_by_marital = [df_marital[df_marital['MARITAL_ID'] == mid]['AGE'].dropna()
                         for mid in valid_groups.index]
        f_age, p_age = f_oneway(*age_by_marital)
        print(f"\nAge difference across marital status: F={f_age:.3f}, p={p_age:.4e}")
        
        # Marital status effects on vessel features (unadjusted)
        print("\n" + "-" * 80)
        print("Marital status effects on vessel features (UNADJUSTED):")
        print("-" * 80)
        
        marital_comparisons = []
        
        for feature in ALL_FEATURES:
            marital_data = [df_marital[df_marital['MARITAL_ID'] == mid][feature].dropna()
                           for mid in valid_groups.index]
            marital_data = [d for d in marital_data if len(d) >= 5]
            
            if len(marital_data) < 2:
                continue
            
            f_stat, f_pval = f_oneway(*marital_data)
            h_stat, h_pval = kruskal(*marital_data)
            # Homogeneity-of-variance check; called through scipy.stats (imported as `stats` in Setup)
            levene_stat, levene_pval = stats.levene(*marital_data)
            
            grand_mean = np.mean(np.concatenate(marital_data))
            ss_between = sum(len(d) * (np.mean(d) - grand_mean)**2 for d in marital_data)
            ss_total = sum(np.sum((d - grand_mean)**2) for d in marital_data)
            eta_squared = ss_between / ss_total if ss_total > 0 else 0
            
            n_total = sum(len(d) for d in marital_data)
            k_groups = len(marital_data)
            omega_squared = (ss_between - (k_groups - 1) * (ss_total - ss_between) / (n_total - k_groups)) / \
                           (ss_total + (ss_total - ss_between) / (n_total - k_groups))
            omega_squared = max(0, omega_squared)
            
            marital_comparisons.append({
                'feature': feature,
                'category': feature_categories[feature],
                'f_statistic': f_stat,
                'anova_pvalue': f_pval,
                'kw_pvalue': h_pval,
                'levene_pvalue': levene_pval,
                'eta_squared': eta_squared,
                'omega_squared': omega_squared
            })
        
        marital_comp_df = pd.DataFrame(marital_comparisons)
        
        if len(marital_comp_df) > 0:
            marital_comp_df['anova_pvalue_fdr'] = multipletests(marital_comp_df['anova_pvalue'], method='fdr_bh')[1]
            marital_comp_df['significant'] = marital_comp_df['anova_pvalue_fdr'] < 0.05
            marital_comp_df = marital_comp_df.sort_values('omega_squared', ascending=False)
            
            n_sig = marital_comp_df['significant'].sum()
            
            print(f"\nFeatures tested: {len(marital_comp_df)}")
            print(f"Significant marital status differences (ANOVA, FDR<0.05): {n_sig} ({100*n_sig/len(marital_comp_df):.1f}%)")
            
            if n_sig > 0:
                print(f"\nTop 10 features with marital status effects (by ω²):")
                display(marital_comp_df.head(10)[['feature', 'category', 'omega_squared', 'anova_pvalue_fdr']].round(4))
                
                print("\n⚠️  WARNING: These differences may be confounded by age!")
                print("    Perform age-adjusted analysis below for proper interpretation")
            
            marital_comp_df.to_csv(TABLES_DIR / 'marital_comparisons_unadjusted.csv', index=False)
            
            # AGE-ADJUSTED ANALYSIS
            print("\n" + "-" * 80)
            print("Marital status effects on vessel features (AGE-ADJUSTED):")
            print("-" * 80)
            
            marital_adjusted = []
            
            for feature in ALL_FEATURES[:30]:  # Only the first 30 features are modeled, to limit runtime
                model_data = df_marital[['AGE', 'MARITAL_ID', sex_col, feature]].dropna()
                
                if len(model_data) < 30:
                    continue
                
                # Model 1: Age + Sex only
                formula1 = f'Q("{feature}") ~ AGE + C(Q("{sex_col}"))'
                model1 = smf.ols(formula1, data=model_data).fit()
                
                # Model 2: Age + Sex + Marital Status
                formula2 = f'Q("{feature}") ~ AGE + C(Q("{sex_col}")) + C(MARITAL_ID)'
                model2 = smf.ols(formula2, data=model_data).fit()
                
                # Nested-model F-test for the marital status terms.
                # model1 (reduced) has more residual df than model2 (full),
                # so the numerator df is model1.df_resid - model2.df_resid.
                df_num = model1.df_resid - model2.df_resid
                if df_num <= 0:
                    continue  # marital status added no estimable terms
                f_marital = ((model1.ssr - model2.ssr) / df_num) / \
                           (model2.ssr / model2.df_resid)
                p_marital = stats.f.sf(f_marital, df_num, model2.df_resid)
                
                marital_adjusted.append({
                    'feature': feature,
                    'category': feature_categories[feature],
                    'r2_baseline': model1.rsquared,
                    'r2_with_marital': model2.rsquared,
                    'delta_r2': model2.rsquared - model1.rsquared,
                    'p_marital_adjusted': p_marital
                })
            
            marital_adj_df = pd.DataFrame(marital_adjusted)
            
            if len(marital_adj_df) > 0:
                marital_adj_df['p_fdr'] = multipletests(marital_adj_df['p_marital_adjusted'], method='fdr_bh')[1]
                marital_adj_df['significant_adjusted'] = marital_adj_df['p_fdr'] < 0.05
                marital_adj_df = marital_adj_df.sort_values('delta_r2', ascending=False)
                
                n_sig_adj = marital_adj_df['significant_adjusted'].sum()
                
                print(f"\nFeatures tested: {len(marital_adj_df)}")
                print(f"Significant age-adjusted marital effects: {n_sig_adj} ({100*n_sig_adj/len(marital_adj_df):.1f}%)")
                
                if n_sig_adj > 0:
                    print(f"\nTop features with age-adjusted marital status effects:")
                    display(marital_adj_df[marital_adj_df['significant_adjusted']][
                        ['feature', 'delta_r2', 'p_fdr']].round(4))
                else:
                    print("\n✓ No significant marital status effects after adjusting for age and sex")
                    print("    → Unadjusted differences are likely explained by age and/or sex")
                
                marital_adj_df.to_csv(TABLES_DIR / 'marital_comparisons_age_adjusted.csv', index=False)
                print(f"\n✓ Age-adjusted marital status analysis saved")
    else:
        print(f"\n⚠️  Insufficient data: only {len(valid_groups)} group(s) with n≥10")
else:
    print("\n⚠️  MARITAL_ID column not found in dataset")

---
## 9.4 Occupation Analysis
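Features in this section are ranked by omega-squared (ω²), which corrects the upward bias of eta-squared (η²) by subtracting the expected between-group variation under the null. A minimal sketch of both effect sizes for a one-way layout follows. The helper name `eta_omega_squared` and the toy samples are hypothetical and exist only for this illustration.

```python
import numpy as np

def eta_omega_squared(groups):
    """Return (eta², ω²) for a one-way layout given a list of 1-D samples."""
    groups = [np.asarray(g, dtype=float) for g in groups]
    pooled = np.concatenate(groups)
    grand_mean = pooled.mean()
    n_total, k = pooled.size, len(groups)

    ss_between = sum(g.size * (g.mean() - grand_mean) ** 2 for g in groups)
    ss_total = np.sum((pooled - grand_mean) ** 2)
    if ss_total == 0:
        return 0.0, 0.0  # no variance at all, so no effect to measure

    ms_within = (ss_total - ss_between) / (n_total - k)
    eta_sq = ss_between / ss_total
    omega_sq = (ss_between - (k - 1) * ms_within) / (ss_total + ms_within)
    # Negative estimates are truncated at zero, as in the analysis cells.
    return eta_sq, max(0.0, omega_sq)

# Hypothetical toy samples: SS_between = 26, SS_total = 32, MS_within = 1
eta_sq, omega_sq = eta_omega_squared([[1, 2, 3], [2, 3, 4], [5, 6, 7]])
print(f"eta² = {eta_sq:.4f}, ω² = {omega_sq:.4f}")  # eta² = 0.8125, ω² = 0.7273
```

In small samples ω² is noticeably smaller than η², which is why it is the more conservative ranking criterion.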

In [None]:
# ============================================================================
# OCCUPATION ANALYSIS
# ============================================================================

print("\n" + "=" * 80)
print("OCCUPATION ANALYSIS")
print("=" * 80)

if 'OCCUPATION_ID' in df.columns:
    df_occup = df[df['OCCUPATION_ID'].notna()].copy()
    occup_counts = df_occup['OCCUPATION_ID'].value_counts().sort_index()
    
    print(f"\nOccupation distribution:")
    print(f"  Total with occupation data: {len(df_occup)} ({100*len(df_occup)/len(df):.1f}% of cohort)")
    print(f"  Missing occupation data: {df['OCCUPATION_ID'].isna().sum()} ({100*df['OCCUPATION_ID'].isna().sum()/len(df):.1f}%)")
    
    # Define occupation labels
    occupation_labels = {
        1: 'Full-time employed',
        2: 'Part-time employed',
        3: 'Student',
        4: 'Housework',
        5: 'Retired',
        6: 'Unemployed',
        7: 'Work at home',
        8: 'Other'
    }
    
    print(f"\nOccupation categories:")
    for occup_id, count in occup_counts.items():
        label = occupation_labels.get(int(occup_id), f'Category {int(occup_id)}')
        print(f"  {label}: {count} subjects ({100*count/len(df_occup):.1f}%)")
    
    valid_groups = occup_counts[occup_counts >= 10]
    
    if len(valid_groups) >= 2:
        print(f"\n✓ Sufficient data for analysis: {len(valid_groups)} groups with n≥10")
        
        df_occup = df_occup[df_occup['OCCUPATION_ID'].isin(valid_groups.index)].copy()
        df_occup['OCCUPATION_LABEL'] = df_occup['OCCUPATION_ID'].map(occupation_labels)
        
        # Define consistent color mapping for ALL occupation visualizations
        unique_occupations = df_occup['OCCUPATION_LABEL'].unique()
        occupation_colors = dict(zip(unique_occupations, 
                                    sns.color_palette('tab10', n_colors=len(unique_occupations))))
        
        # Demographic characteristics
        print("\n" + "-" * 80)
        print("Demographic characteristics by occupation:")
        print("-" * 80)
        
        sex_col = 'SEX_ID (1=m, 2=f)' if 'SEX_ID (1=m, 2=f)' in df_occup.columns else 'SEX_ID'
        
        occup_demo = df_occup.groupby('OCCUPATION_LABEL').agg({
            'subject_id': 'count',
            'AGE': ['mean', 'std', 'min', 'max'],
            sex_col: lambda x: (x == 1).sum(),
            'HEIGHT': lambda x: x[x > 0].mean() if (x > 0).any() else np.nan,
            'WEIGHT': lambda x: x[x > 0].mean() if (x > 0).any() else np.nan,
            'BMI': lambda x: x[x > 0].mean() if (x > 0).any() else np.nan
        })
        
        occup_demo.columns = ['N', 'Age_Mean', 'Age_SD', 'Age_Min', 'Age_Max',
                              'N_Males', 'Height_Mean', 'Weight_Mean', 'BMI_Mean']
        occup_demo['Pct_Male'] = 100 * occup_demo['N_Males'] / occup_demo['N']
        
        display(occup_demo.round(2))
        
        print("\n⚠️  NOTE: Occupation may be confounded with age, sex, and socioeconomic status")
        
        # Occupation effects (unadjusted)
        print("\n" + "-" * 80)
        print("Occupation effects on vessel features (UNADJUSTED):")
        print("-" * 80)
        
        occup_comparisons = []
        
        for feature in ALL_FEATURES:
            occup_data = [df_occup[df_occup['OCCUPATION_ID'] == oid][feature].dropna()
                         for oid in valid_groups.index]
            occup_data = [d for d in occup_data if len(d) >= 5]
            
            if len(occup_data) < 2:
                continue
            
            f_stat, f_pval = f_oneway(*occup_data)
            h_stat, h_pval = kruskal(*occup_data)
            
            grand_mean = np.mean(np.concatenate(occup_data))
            ss_between = sum(len(d) * (np.mean(d) - grand_mean)**2 for d in occup_data)
            ss_total = sum(np.sum((d - grand_mean)**2) for d in occup_data)
            eta_squared = ss_between / ss_total if ss_total > 0 else 0
            
            n_total = sum(len(d) for d in occup_data)
            k_groups = len(occup_data)
            omega_squared = (ss_between - (k_groups - 1) * (ss_total - ss_between) / (n_total - k_groups)) / \
                           (ss_total + (ss_total - ss_between) / (n_total - k_groups))
            omega_squared = max(0, omega_squared)
            
            occup_comparisons.append({
                'feature': feature,
                'category': feature_categories[feature],
                'f_statistic': f_stat,
                'anova_pvalue': f_pval,
                'kw_pvalue': h_pval,
                'omega_squared': omega_squared
            })
        
        occup_comp_df = pd.DataFrame(occup_comparisons)
        
        if len(occup_comp_df) > 0:
            occup_comp_df['anova_pvalue_fdr'] = multipletests(occup_comp_df['anova_pvalue'], method='fdr_bh')[1]
            occup_comp_df['significant'] = occup_comp_df['anova_pvalue_fdr'] < 0.05
            occup_comp_df = occup_comp_df.sort_values('omega_squared', ascending=False)
            
            n_sig = occup_comp_df['significant'].sum()
            
            print(f"\nFeatures tested: {len(occup_comp_df)}")
            print(f"Significant occupation differences (ANOVA, FDR<0.05): {n_sig} ({100*n_sig/len(occup_comp_df):.1f}%)")
            
            if n_sig > 0:
                print(f"\nTop 10 features with occupation effects:")
                display(occup_comp_df.head(10)[['feature', 'category', 'omega_squared', 'anova_pvalue_fdr']].round(4))
            
            occup_comp_df.to_csv(TABLES_DIR / 'occupation_comparisons.csv', index=False)
            print(f"\n✓ Occupation analysis saved")
            
            # ========================================================================
            # VISUALIZATIONS
            # ========================================================================
            
            # Plot 1: Occupation distribution with demographics
            print("\nGenerating occupation visualizations...")
            
            fig, axes = plt.subplots(2, 2, figsize=(14, 10))
            fig.suptitle('Occupation Demographics Overview', fontsize=14, fontweight='bold')
            
            # 1. Occupation distribution
            ax = axes[0, 0]
            occup_demo_sorted = occup_demo.sort_values('N', ascending=False)
            occupation_order_demo = occup_demo_sorted.index.tolist()
            bar_colors = [occupation_colors[occ] for occ in occupation_order_demo]
            bars = ax.barh(range(len(occup_demo_sorted)), occup_demo_sorted['N'], 
                          color=bar_colors, alpha=0.7, edgecolor='black')
            ax.set_yticks(range(len(occup_demo_sorted)))
            ax.set_yticklabels(occupation_order_demo, fontsize=9)
            ax.set_xlabel('Number of Subjects', fontsize=10)
            ax.set_title('Sample Size by Occupation', fontsize=11, fontweight='bold')
            ax.grid(axis='x', alpha=0.3)
            
            # Add counts on bars
            for i, v in enumerate(occup_demo_sorted['N']):
                ax.text(v + 1, i, f'n={int(v)}', va='center', fontsize=9)
            
            # 2. Age distribution by occupation
            ax = axes[0, 1]
            plot_data = df_occup[['OCCUPATION_LABEL', 'AGE']].dropna()
            box_colors = [occupation_colors[occ] for occ in occupation_order_demo]
            sns.boxplot(data=plot_data, y='OCCUPATION_LABEL', x='AGE', ax=ax,
                       order=occupation_order_demo, palette=box_colors)
            ax.set_xlabel('Age (years)', fontsize=10)
            ax.set_ylabel('')
            ax.set_title('Age Distribution by Occupation', fontsize=11, fontweight='bold')
            ax.grid(axis='x', alpha=0.3)
            
            # 3. Sex distribution by occupation
            ax = axes[1, 0]
            sex_by_occup = df_occup.groupby('OCCUPATION_LABEL')[sex_col].value_counts(normalize=True).unstack()
            sex_by_occup = sex_by_occup.reindex(occupation_order_demo)
            sex_by_occup.plot(kind='barh', stacked=True, ax=ax, 
                             color=['#ff7f0e', '#ff69b4'], alpha=0.8)
            ax.set_xlabel('Proportion', fontsize=10)
            ax.set_ylabel('')
            ax.set_title('Sex Distribution by Occupation', fontsize=11, fontweight='bold')
            ax.legend(['Male', 'Female'], fontsize=9)
            ax.grid(axis='x', alpha=0.3)
            
            # 4. BMI by occupation (if available)
            ax = axes[1, 1]
            if 'BMI' in df_occup.columns:
                plot_data_bmi = df_occup[df_occup['BMI'] > 0][['OCCUPATION_LABEL', 'BMI']].dropna()
                if len(plot_data_bmi) > 0:
                    bmi_colors = [occupation_colors[occ] for occ in occupation_order_demo]
                    sns.boxplot(data=plot_data_bmi, y='OCCUPATION_LABEL', x='BMI', ax=ax,
                               order=occupation_order_demo, palette=bmi_colors)
                    ax.axvline(25, color='orange', linestyle='--', linewidth=2, alpha=0.7, label='Overweight')
                    ax.axvline(30, color='red', linestyle='--', linewidth=2, alpha=0.7, label='Obese')
                    ax.set_xlabel('BMI (kg/m²)', fontsize=10)
                    ax.set_ylabel('')
                    ax.set_title('BMI Distribution by Occupation', fontsize=11, fontweight='bold')
                    ax.legend(fontsize=8)
                    ax.grid(axis='x', alpha=0.3)
                else:
                    ax.text(0.5, 0.5, 'Insufficient BMI data', ha='center', va='center',
                           transform=ax.transAxes)
                    ax.axis('off')
            else:
                ax.text(0.5, 0.5, 'BMI data not available', ha='center', va='center',
                       transform=ax.transAxes)
                ax.axis('off')
            
            plt.tight_layout()
            plt.savefig(FIGURES_DIR / 'occupation_demographics.png', dpi=300, bbox_inches='tight')
            print("✓ Occupation demographics plot saved")
            plt.show()
            
            # Plot 2: Top vessel feature differences by occupation
            if n_sig > 0:
                top_occup_features = occup_comp_df[occup_comp_df['significant']].head(6)['feature'].tolist()
                
                if len(top_occup_features) > 0:
                    n_plots = min(6, len(top_occup_features))
                    fig, axes = plt.subplots(2, 3, figsize=(16, 10))
                    fig.suptitle('Top Occupation Differences in Vessel Features',
                                fontsize=14, fontweight='bold')
                    axes = axes.flatten()
                    
                    for idx, feature in enumerate(top_occup_features[:n_plots]):
                        ax = axes[idx]
                        
                        plot_data = df_occup[['OCCUPATION_LABEL', feature]].dropna()
                        
                        # Sort by median value for better visualization
                        median_values = plot_data.groupby('OCCUPATION_LABEL')[feature].median().sort_values()
                        occupation_order_plot = median_values.index.tolist()
                        
                        # Create color palette in the correct order for this plot
                        plot_colors = [occupation_colors[occ] for occ in occupation_order_plot]
                        
                        sns.boxplot(data=plot_data, y='OCCUPATION_LABEL', x=feature, ax=ax,
                                   order=occupation_order_plot, palette=plot_colors)
                        sns.stripplot(data=plot_data, y='OCCUPATION_LABEL', x=feature, ax=ax,
                                     color='black', alpha=0.3, size=2, order=occupation_order_plot)
                        
                        feat_stats = occup_comp_df[occup_comp_df['feature'] == feature].iloc[0]
                        omega_sq = feat_stats['omega_squared']
                        p_val = feat_stats['anova_pvalue_fdr']
                        
                        # p_val is the FDR-adjusted ANOVA p-value, so label it as such
                        if p_val < 0.001:
                            sig_stars = '***'
                            p_text = 'p(FDR) < 0.001'
                        elif p_val < 0.01:
                            sig_stars = '**'
                            p_text = f'p(FDR) = {p_val:.3f}'
                        elif p_val < 0.05:
                            sig_stars = '*'
                            p_text = f'p(FDR) = {p_val:.3f}'
                        else:
                            sig_stars = 'ns'
                            p_text = f'p(FDR) = {p_val:.2f}'
                        
                        clean_feature = feature.replace('_', ' ').replace('total ', '').replace('mean ', '').title()
                        
                        ax.set_title(f"{clean_feature} {sig_stars}\nω² = {omega_sq:.3f}, {p_text}",
                                    fontsize=9)
                        ax.set_xlabel(feature.replace('_', ' ').title(), fontsize=9)
                        ax.set_ylabel('')
                        ax.tick_params(axis='y', labelsize=8)
                        ax.grid(True, alpha=0.3, axis='x')
                    
                    for idx in range(n_plots, 6):
                        axes[idx].axis('off')
                    
                    plt.tight_layout()
                    plt.savefig(FIGURES_DIR / 'occupation_differences_top6.png', dpi=300, bbox_inches='tight')
                    print("✓ Occupation differences visualization saved")
                    plt.show()
            
            # Plot 3: Heatmap of z-scored group means for the top features by effect size
            # (drawn only when at least 10 features are significant)
            if n_sig >= 10:
                print("\nGenerating occupation effect size heatmap...")
                
                # Get top 20 features by effect size
                top_features = occup_comp_df.head(20)['feature'].tolist()
                
                # Create matrix of mean values for each occupation
                heatmap_data = []
                for feature in top_features:
                    feature_means = []
                    for occup_id in sorted(valid_groups.index):
                        mean_val = df_occup[df_occup['OCCUPATION_ID'] == occup_id][feature].mean()
                        feature_means.append(mean_val)
                    heatmap_data.append(feature_means)
                
                heatmap_columns = [occupation_labels.get(oid, str(oid)) for oid in sorted(valid_groups.index)]
                heatmap_df = pd.DataFrame(
                    heatmap_data,
                    index=[f.replace('_', ' ').replace('total ', '').replace('mean ', '').title()[:30] 
                           for f in top_features],
                    columns=heatmap_columns
                )
                
                # Normalize by row (z-score) for better visualization
                heatmap_df_norm = heatmap_df.sub(heatmap_df.mean(axis=1), axis=0).div(heatmap_df.std(axis=1), axis=0)
                
                fig, ax = plt.subplots(figsize=(10, 12))
                sns.heatmap(heatmap_df_norm, cmap='RdBu_r', center=0, 
                           cbar_kws={'label': 'Normalized Value (z-score)'},
                           linewidths=0.5, ax=ax, fmt='.2f',
                           xticklabels=True, yticklabels=True)
                ax.set_title('Vessel Features by Occupation\n(Top 20 by Effect Size, Z-score normalized)',
                            fontsize=12, fontweight='bold', pad=20)
                ax.set_xlabel('Occupation', fontsize=11)
                ax.set_ylabel('Vessel Feature', fontsize=11)
                plt.xticks(rotation=45, ha='right', fontsize=9)
                plt.yticks(fontsize=8)
                
                plt.tight_layout()
                plt.savefig(FIGURES_DIR / 'occupation_heatmap.png', dpi=300, bbox_inches='tight')
                print("✓ Occupation heatmap saved")
                plt.show()
                
    else:
        print(f"\n⚠️  Insufficient data: only {len(valid_groups)} group(s) with n≥10")
else:
    print("\n⚠️  OCCUPATION_ID column not found in dataset")

---
## 9.5 Education/Qualification Analysis
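Because qualification codes are ordered, a rank correlation can test for a monotonic trend in addition to the omnibus ANOVA. A rank-based test treats every code as a position on the scale, so a code that stands for missing data (0 in the label map used in this section) does not belong in the ranking. The sketch below illustrates this on synthetic data. The column names `LEVEL` and `feature` and the simulated effect are assumptions made for the example.

```python
import numpy as np
import pandas as pd
from scipy import stats

# Synthetic data: codes 1-5 are ordered levels, 0 plays the role of "Missing".
rng = np.random.default_rng(1)
levels = rng.integers(0, 6, 300)
toy = pd.DataFrame({
    "LEVEL": levels,
    "feature": 0.3 * levels + rng.normal(0, 1, 300),
})

# Keep only true ordinal levels before ranking.
ranked = toy[toy["LEVEL"] > 0]
rho, p_trend = stats.spearmanr(ranked["LEVEL"], ranked["feature"])
print(f"Spearman rho = {rho:.3f}, p = {p_trend:.2e}, n = {len(ranked)}")
```

A significant ANOVA without a significant trend suggests group differences that are not ordered by level; the two tests answer different questions and are worth reporting side by side.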

In [None]:
# ============================================================================
# EDUCATIONAL QUALIFICATION ANALYSIS
# ============================================================================

print("\n" + "=" * 80)
print("EDUCATIONAL QUALIFICATION ANALYSIS")
print("=" * 80)

if 'QUALIFICATION_ID' in df.columns:
    df_qual = df[df['QUALIFICATION_ID'].notna()].copy()
    qual_counts = df_qual['QUALIFICATION_ID'].value_counts().sort_index()
    
    print(f"\nEducational qualification distribution:")
    print(f"  Total with education data: {len(df_qual)} ({100*len(df_qual)/len(df):.1f}% of cohort)")
    print(f"  Missing education data: {df['QUALIFICATION_ID'].isna().sum()} ({100*df['QUALIFICATION_ID'].isna().sum()/len(df):.1f}%)")
    
    # Define qualification labels
    qualification_labels = {
        0: 'Missing',
        1: 'None',
        2: 'O-levels/GCSEs',
        3: 'A-levels',
        4: 'Further Ed.',
        5: 'University'
    }
    
    print(f"\nEducation level categories:")
    for qual_id, count in qual_counts.items():
        label = qualification_labels.get(int(qual_id), f'Level {int(qual_id)}')
        print(f"  {label}: {count} subjects ({100*count/len(df_qual):.1f}%)")
    
    valid_groups = qual_counts[qual_counts >= 10]
    
    if len(valid_groups) >= 2:
        print(f"\n✓ Sufficient data for analysis: {len(valid_groups)} groups with n≥10")
        
        df_qual = df_qual[df_qual['QUALIFICATION_ID'].isin(valid_groups.index)].copy()
        df_qual['QUALIFICATION_LABEL'] = df_qual['QUALIFICATION_ID'].map(qualification_labels)
        
        # Demographic characteristics
        print("\n" + "-" * 80)
        print("Demographic characteristics by education level:")
        print("-" * 80)
        
        sex_col = 'SEX_ID (1=m, 2=f)' if 'SEX_ID (1=m, 2=f)' in df_qual.columns else 'SEX_ID'
        
        qual_demo = df_qual.groupby('QUALIFICATION_LABEL').agg({
            'subject_id': 'count',
            'AGE': ['mean', 'std', 'min', 'max'],
            sex_col: lambda x: (x == 1).sum(),
            'HEIGHT': lambda x: x[x > 0].mean() if (x > 0).any() else np.nan,
            'WEIGHT': lambda x: x[x > 0].mean() if (x > 0).any() else np.nan,
            'BMI': lambda x: x[x > 0].mean() if (x > 0).any() else np.nan
        })
        
        qual_demo.columns = ['N', 'Age_Mean', 'Age_SD', 'Age_Min', 'Age_Max',
                            'N_Males', 'Height_Mean', 'Weight_Mean', 'BMI_Mean']
        qual_demo['Pct_Male'] = 100 * qual_demo['N_Males'] / qual_demo['N']
        
        display(qual_demo.round(2))
        
        print("\n⚠️  NOTE: Education is often a proxy for socioeconomic status")
        print("    May be associated with health behaviors and vascular risk factors")
        
        # Education effects
        print("\n" + "-" * 80)
        print("Education effects on vessel features:")
        print("-" * 80)
        
        qual_comparisons = []
        
        for feature in ALL_FEATURES:
            qual_data = [df_qual[df_qual['QUALIFICATION_ID'] == qid][feature].dropna()
                        for qid in valid_groups.index]
            qual_data = [d for d in qual_data if len(d) >= 5]
            
            if len(qual_data) < 2:
                continue
            
            f_stat, f_pval = f_oneway(*qual_data)
            h_stat, h_pval = kruskal(*qual_data)
            
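            # Test for a monotonic trend across the ordinal levels.
            # Code 0 is labeled 'Missing' above, so it is left out of the ranking.
            all_qual_data = df_qual[df_qual['QUALIFICATION_ID'] > 0][['QUALIFICATION_ID', feature]].dropna()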
            if len(all_qual_data) > 20:
                spearman_r, spearman_p = stats.spearmanr(all_qual_data['QUALIFICATION_ID'], 
                                                         all_qual_data[feature])
            else:
                spearman_r, spearman_p = np.nan, np.nan
            
            grand_mean = np.mean(np.concatenate(qual_data))
            ss_between = sum(len(d) * (np.mean(d) - grand_mean)**2 for d in qual_data)
            ss_total = sum(np.sum((d - grand_mean)**2) for d in qual_data)
            eta_squared = ss_between / ss_total if ss_total > 0 else 0
            
            n_total = sum(len(d) for d in qual_data)
            k_groups = len(qual_data)
            omega_squared = (ss_between - (k_groups - 1) * (ss_total - ss_between) / (n_total - k_groups)) / \
                           (ss_total + (ss_total - ss_between) / (n_total - k_groups))
            omega_squared = max(0, omega_squared)
            
            qual_comparisons.append({
                'feature': feature,
                'category': feature_categories[feature],
                'f_statistic': f_stat,
                'anova_pvalue': f_pval,
                'kw_pvalue': h_pval,
                'spearman_r': spearman_r,
                'spearman_pvalue': spearman_p,
                'omega_squared': omega_squared
            })
        
        qual_comp_df = pd.DataFrame(qual_comparisons)
        
        if len(qual_comp_df) > 0:
            qual_comp_df['anova_pvalue_fdr'] = multipletests(qual_comp_df['anova_pvalue'], method='fdr_bh')[1]
            qual_comp_df['spearman_pvalue_fdr'] = multipletests(qual_comp_df['spearman_pvalue'].fillna(1), 
                                                                 method='fdr_bh')[1]
            qual_comp_df['significant_anova'] = qual_comp_df['anova_pvalue_fdr'] < 0.05
            qual_comp_df['significant_trend'] = qual_comp_df['spearman_pvalue_fdr'] < 0.05
            qual_comp_df = qual_comp_df.sort_values('omega_squared', ascending=False)
            
            n_sig = qual_comp_df['significant_anova'].sum()
            n_sig_trend = qual_comp_df['significant_trend'].sum()
            
            print(f"\nFeatures tested: {len(qual_comp_df)}")
            print(f"Significant education differences (ANOVA, FDR<0.05): {n_sig} ({100*n_sig/len(qual_comp_df):.1f}%)")
            print(f"Significant monotonic trends (Spearman): {n_sig_trend} ({100*n_sig_trend/len(qual_comp_df):.1f}%)")
            
            if n_sig > 0:
                print(f"\nTop 10 features with education effects:")
                display(qual_comp_df.head(10)[['feature', 'category', 'omega_squared', 'anova_pvalue_fdr', 'spearman_r']].round(4))
            
            qual_comp_df.to_csv(TABLES_DIR / 'education_comparisons.csv', index=False)
            print(f"\n✓ Education analysis saved")
            
            # ========================================================================
            # VISUALIZATIONS
            # ========================================================================
            
            # Plot 1: Education distribution with demographics
            print("\nGenerating education visualizations...")
            
            fig, axes = plt.subplots(2, 2, figsize=(14, 10))
            fig.suptitle('Education Demographics Overview', fontsize=14, fontweight='bold')
            
            # 1. Education distribution
            ax = axes[0, 0]
            qual_demo_sorted = qual_demo.sort_values('N', ascending=False)
            bars = ax.barh(range(len(qual_demo_sorted)), qual_demo_sorted['N'], 
                          color='steelblue', alpha=0.7, edgecolor='black')
            ax.set_yticks(range(len(qual_demo_sorted)))
            ax.set_yticklabels(qual_demo_sorted.index, fontsize=9)
            ax.set_xlabel('Number of Subjects', fontsize=10)
            ax.set_title('Sample Size by Education Level', fontsize=11, fontweight='bold')
            ax.grid(axis='x', alpha=0.3)
            
            # Add counts on bars
            for i, v in enumerate(qual_demo_sorted['N']):
                ax.text(v + 1, i, f'n={int(v)}', va='center', fontsize=9)
            
            # 2. Age distribution by education
            ax = axes[0, 1]
            plot_data = df_qual[['QUALIFICATION_LABEL', 'AGE']].dropna()
            # Order by education level
            education_order = [qualification_labels[i] for i in sorted(df_qual['QUALIFICATION_ID'].unique())]
            sns.boxplot(data=plot_data, y='QUALIFICATION_LABEL', x='AGE', ax=ax,
                       order=education_order, palette='viridis')
            ax.set_xlabel('Age (years)', fontsize=10)
            ax.set_ylabel('')
            ax.set_title('Age Distribution by Education Level', fontsize=11, fontweight='bold')
            ax.grid(axis='x', alpha=0.3)
            
            # 3. Sex distribution by education
            ax = axes[1, 0]
            sex_by_qual = df_qual.groupby('QUALIFICATION_LABEL')[sex_col].value_counts(normalize=True).unstack()
            sex_by_qual = sex_by_qual.reindex(education_order)
            sex_by_qual.plot(kind='barh', stacked=True, ax=ax, 
                           color=['#ff7f0e', '#ff69b4'], alpha=0.8)
            ax.set_xlabel('Proportion', fontsize=10)
            ax.set_ylabel('')
            ax.set_title('Sex Distribution by Education Level', fontsize=11, fontweight='bold')
            ax.legend(['Male', 'Female'], fontsize=9)
            ax.grid(axis='x', alpha=0.3)
            
            # 4. BMI by education (if available)
            ax = axes[1, 1]
            if 'BMI' in df_qual.columns:
                plot_data_bmi = df_qual[df_qual['BMI'] > 0][['QUALIFICATION_LABEL', 'BMI']].dropna()
                if len(plot_data_bmi) > 0:
                    sns.boxplot(data=plot_data_bmi, y='QUALIFICATION_LABEL', x='BMI', ax=ax,
                               order=education_order, palette='viridis')
                    ax.axvline(25, color='orange', linestyle='--', linewidth=2, alpha=0.7, label='Overweight')
                    ax.axvline(30, color='red', linestyle='--', linewidth=2, alpha=0.7, label='Obese')
                    ax.set_xlabel('BMI (kg/m²)', fontsize=10)
                    ax.set_ylabel('')
                    ax.set_title('BMI Distribution by Education Level', fontsize=11, fontweight='bold')
                    ax.legend(fontsize=8)
                    ax.grid(axis='x', alpha=0.3)
                else:
                    ax.text(0.5, 0.5, 'Insufficient BMI data', ha='center', va='center',
                           transform=ax.transAxes)
                    ax.axis('off')
            else:
                ax.text(0.5, 0.5, 'BMI data not available', ha='center', va='center',
                       transform=ax.transAxes)
                ax.axis('off')
            
            plt.tight_layout()
            plt.savefig(FIGURES_DIR / 'education_demographics.png', dpi=300, bbox_inches='tight')
            print("✓ Education demographics plot saved")
            plt.show()
            
            # Plot 2: Top vessel feature differences by education
            if n_sig > 0:
                top_qual_features = qual_comp_df[qual_comp_df['significant_anova']].head(6)['feature'].tolist()
                
                if len(top_qual_features) > 0:
                    # Define consistent color mapping for education levels
                    unique_education = df_qual['QUALIFICATION_LABEL'].unique()
                    education_colors = dict(zip(unique_education, 
                                              sns.color_palette('viridis', n_colors=len(unique_education))))
                    
                    n_plots = min(6, len(top_qual_features))
                    fig, axes = plt.subplots(2, 3, figsize=(16, 10))
                    fig.suptitle('Top Education Differences in Vessel Features',
                                fontsize=14, fontweight='bold')
                    axes = axes.flatten()
                    
                    for idx, feature in enumerate(top_qual_features[:n_plots]):
                        ax = axes[idx]
                        
                        plot_data = df_qual[['QUALIFICATION_LABEL', feature]].dropna()
                        
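                        # Order by education level (natural ordering), keeping only labels
                        # present in the data so the color lookup below cannot raise a KeyError
                        natural_order = ['None', 'O-levels/GCSEs', 'A-levels', 'Further Ed.', 'University']
                        education_order = [edu for edu in natural_order if edu in education_colors]
                        if not education_order:
                            # Fall back to the ID-sorted labels if none of the expected labels match
                            education_order = [qualification_labels[i] for i in sorted(df_qual['QUALIFICATION_ID'].unique())]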

                        # Create color palette in the correct order for this plot
                        plot_colors = [education_colors[edu] for edu in education_order]
                        
                        sns.boxplot(data=plot_data, x='QUALIFICATION_LABEL', y=feature, ax=ax,
                                order=education_order, palette=plot_colors)
                        sns.stripplot(data=plot_data, x='QUALIFICATION_LABEL', y=feature, ax=ax,
                                    color='black', alpha=0.3, size=2, order=education_order)
                        
                        feat_stats = qual_comp_df[qual_comp_df['feature'] == feature].iloc[0]
                        omega_sq = feat_stats['omega_squared']
                        p_val = feat_stats['anova_pvalue_fdr']
                        
                        if p_val < 0.001:
                            sig_stars = '***'
                            p_text = 'p < 0.001'
                        elif p_val < 0.01:
                            sig_stars = '**'
                            p_text = f'p = {p_val:.3f}'
                        elif p_val < 0.05:
                            sig_stars = '*'
                            p_text = f'p = {p_val:.3f}'
                        else:
                            sig_stars = 'ns'
                            p_text = f'p = {p_val:.2f}'
                        
                        clean_feature = feature.replace('_', ' ').replace('total ', '').replace('mean ', '').title()
                        
                        # Add trend info if significant
                        if feat_stats['significant_trend'] and not pd.isna(feat_stats['spearman_r']):
                            spearman_r = feat_stats['spearman_r']
                            trend_text = f", ρ={spearman_r:.3f}"
                        else:
                            trend_text = ""
                        
                        ax.set_title(f"{clean_feature} {sig_stars}\nω² = {omega_sq:.3f}, {p_text}{trend_text}",
                                    fontsize=9)
                        ax.set_ylabel(feature.replace('_', ' ').title(), fontsize=9)
                        ax.set_xlabel('')
                        ax.tick_params(axis='x', labelsize=8, rotation=45)
                        ax.grid(True, alpha=0.3, axis='y')  # values are on the y-axis in these vertical boxplots
                    
                    for idx in range(n_plots, 6):
                        axes[idx].axis('off')
                    
                    plt.tight_layout()
                    plt.savefig(FIGURES_DIR / 'education_differences_top6.png', dpi=300, bbox_inches='tight')
                    print("✓ Education differences visualization saved")
                    plt.show()
            
            # Plot 3: Effect size heatmap (if multiple significant features)
            if n_sig >= 10:
                print("\nGenerating education effect size heatmap...")
                
                # Get top 20 features by effect size
                top_features = qual_comp_df.head(20)['feature'].tolist()
                
                # Create matrix of mean values for each education level
                heatmap_data = []
                education_order = [qualification_labels[i] for i in sorted(valid_groups.index)]
                
                for feature in top_features:
                    feature_means = []
                    for qual_id in sorted(valid_groups.index):
                        mean_val = df_qual[df_qual['QUALIFICATION_ID'] == qual_id][feature].mean()
                        feature_means.append(mean_val)
                    heatmap_data.append(feature_means)
                
                heatmap_df = pd.DataFrame(
                    heatmap_data,
                    index=[f.replace('_', ' ').replace('total ', '').replace('mean ', '').title()[:30] 
                           for f in top_features],
                    columns=education_order
                )
                
                # Normalize by row (z-score) for better visualization
                heatmap_df_norm = heatmap_df.sub(heatmap_df.mean(axis=1), axis=0).div(heatmap_df.std(axis=1), axis=0)
                
                fig, ax = plt.subplots(figsize=(10, 12))
                sns.heatmap(heatmap_df_norm, cmap='RdBu_r', center=0, 
                           cbar_kws={'label': 'Normalized Value (z-score)'},
                           linewidths=0.5, ax=ax, fmt='.2f',
                           xticklabels=True, yticklabels=True)
                ax.set_title('Vessel Features by Education Level\n(Top 20 by Effect Size, Z-score normalized)',
                            fontsize=12, fontweight='bold', pad=20)
                ax.set_xlabel('Education Level', fontsize=11)
                ax.set_ylabel('Vessel Feature', fontsize=11)
                plt.xticks(rotation=45, ha='right', fontsize=9)
                plt.yticks(fontsize=8)
                
                plt.tight_layout()
                plt.savefig(FIGURES_DIR / 'education_heatmap.png', dpi=300, bbox_inches='tight')
                print("✓ Education heatmap saved")
                plt.show()
                
    else:
        print(f"\n⚠️  Insufficient data: only {len(valid_groups)} group(s) with n≥10")
else:
    print("\n⚠️  QUALIFICATION_ID column not found in dataset")
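
The panel titles in the education plots report ω² values read from `qual_comp_df`. For reference, here is a minimal sketch of how omega-squared can be computed for a one-way design, together with a Spearman trend across ordered levels. The helper `omega_squared` and the synthetic groups are illustrative assumptions, not the notebook's own implementation.

```python
import numpy as np
from scipy.stats import f_oneway, spearmanr

def omega_squared(groups):
    """Omega-squared effect size for a one-way ANOVA (hypothetical helper)."""
    groups = [np.asarray(g, dtype=float) for g in groups]
    k = len(groups)
    n_total = sum(len(g) for g in groups)
    grand_mean = np.concatenate(groups).mean()
    ss_between = sum(len(g) * (g.mean() - grand_mean) ** 2 for g in groups)
    ss_within = sum(((g - g.mean()) ** 2).sum() for g in groups)
    ms_within = ss_within / (n_total - k)
    ss_total = ss_between + ss_within
    # Less biased than eta-squared; can be slightly negative for tiny effects
    return (ss_between - (k - 1) * ms_within) / (ss_total + ms_within)

# Synthetic data: five ordered levels with a small upward shift per level
rng = np.random.default_rng(0)
levels = [1, 2, 3, 4, 5]
toy_groups = [rng.normal(loc=10 + 0.5 * lvl, scale=1.0, size=30) for lvl in levels]

f_stat, p_val = f_oneway(*toy_groups)
w2 = omega_squared(toy_groups)

# Monotonic trend across the ordered levels
level_codes = np.repeat(levels, 30)
rho, p_trend = spearmanr(level_codes, np.concatenate(toy_groups))

print(f"F = {f_stat:.2f}, p = {p_val:.3g}, omega^2 = {w2:.3f}, rho = {rho:.3f}")
```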

---
## 9.X Summary

In [None]:
# ============================================================================
# SUMMARY OF ALL DEMOGRAPHIC FACTORS
# ============================================================================

print("\n" + "=" * 80)
print("SUMMARY: DEMOGRAPHIC FACTORS AND VESSEL FEATURES")
print("=" * 80)

summary_results = []

# Collect results from all analyses
demographic_analyses = {
    'Sex': ('sex_comp_df', 'significant'),
    'Site': ('site_comp_df', 'significant_anova'),
    'Ethnicity': ('ethnic_comp_df', 'significant_anova'),
    'Marital Status (unadj)': ('marital_comp_df', 'significant'),
    'Marital Status (adj)': ('marital_adj_df', 'significant_adjusted'),
    'Occupation': ('occup_comp_df', 'significant'),
    'Education': ('qual_comp_df', 'significant_anova')
}

print("\nNumber of vessel features with significant associations:")
print("-" * 60)

for demo_name, (df_name, sig_col) in demographic_analyses.items():
    if df_name in globals():
        df_temp = globals()[df_name]
        n_total = len(df_temp)
        # Cast to int so the ':3d' format below works even if the sum comes back as a float
        n_sig = int(df_temp[sig_col].sum()) if sig_col in df_temp.columns else 0
        pct_sig = 100 * n_sig / n_total if n_total > 0 else 0
        
        print(f"{demo_name:25s}: {n_sig:3d}/{n_total:3d} ({pct_sig:5.1f}%)")
        
        summary_results.append({
            'demographic_factor': demo_name,
            'n_features_tested': n_total,
            'n_significant': n_sig,
            'percent_significant': pct_sig
        })

if len(summary_results) > 0:
    summary_df = pd.DataFrame(summary_results)
    summary_df.to_csv(TABLES_DIR / 'demographic_factors_summary.csv', index=False)
    print(f"\n✓ Summary saved to {TABLES_DIR / 'demographic_factors_summary.csv'}")

print("\n" + "=" * 80)
print("KEY RECOMMENDATIONS FOR YOUR PAPER:")
print("=" * 80)
print("""
1. PRIMARY COVARIATES (include in main analyses):
   - Age (strongest predictor)
   - Sex (biological differences)
   - Site/Scanner (technical variation)

2. SECONDARY FACTORS (mention in discussion):
   - Ethnicity (if significant)
   - Education/Occupation (socioeconomic proxies)

3. CONFOUNDED FACTORS (interpret cautiously):
   - Marital status (age-confounded)
   - Analyze with age adjustment

4. STATISTICAL APPROACH:
   - Use multiple regression including age, sex, and site
   - Report both unadjusted and adjusted analyses
   - Consider interaction terms for key factors
   - Apply FDR correction for multiple comparisons

5. LIMITATIONS TO DISCUSS:
   - Missing data in some demographic variables
   - Potential residual confounding
   - Cross-sectional design (no causality)
""")

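The statistical approach recommended above (regression with age, sex, and site as covariates, followed by FDR correction) can be sketched as follows. The data are synthetic, and the column name `SITE` and the two feature names are placeholders rather than names taken from the dataset.

```python
import numpy as np
import pandas as pd
import statsmodels.formula.api as smf
from statsmodels.stats.multitest import multipletests

# Synthetic cohort (placeholder column names)
rng = np.random.default_rng(42)
n = 200
toy = pd.DataFrame({
    'AGE': rng.uniform(20, 80, n),
    'SEX_ID': rng.integers(1, 3, n),        # coded 1 or 2
    'SITE': rng.choice(['A', 'B', 'C'], n),
})
toy['feat_age_related'] = 100 - 0.5 * toy['AGE'] + rng.normal(0, 5, n)
toy['feat_noise'] = rng.normal(50, 5, n)

# One covariate-adjusted model per feature; keep the age coefficient
rows = []
for feature in ['feat_age_related', 'feat_noise']:
    fit = smf.ols(f'{feature} ~ AGE + C(SEX_ID) + C(SITE)', data=toy).fit()
    rows.append({
        'feature': feature,
        'beta_age': fit.params['AGE'],
        'p_age': fit.pvalues['AGE'],
    })

results = pd.DataFrame(rows)

# Benjamini-Hochberg correction across all tested features
results['p_age_fdr'] = multipletests(results['p_age'], method='fdr_bh')[1]
results['significant'] = results['p_age_fdr'] < 0.05
print(results)
```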
---
## 10. Regional Analysis (if applicable)

In [None]:
# Publication-quality figure settings
plt.rcParams['figure.dpi'] = 1000
plt.rcParams['savefig.dpi'] = 1000
plt.rcParams['font.size'] = 20
plt.rcParams['axes.labelsize'] = 25
plt.rcParams['axes.titlesize'] = 20
plt.rcParams['xtick.labelsize'] = 25
plt.rcParams['ytick.labelsize'] = 25
plt.rcParams['legend.fontsize'] = 9

In [None]:
if not IS_REGIONAL_DATA or regional_df is None:
    print("\n⚠️  No regional data available. Skipping regional analysis.")
else:
    print("\n" + "="*80)
    print("REGIONAL ANALYSIS")
    print("="*80)
    
    # Identify region column ('region', 'region_id', or 'region_label')
    region_col = next(
        (c for c in ('region', 'region_id', 'region_label') if c in regional_df.columns),
        None
    )
    
    if region_col:
        n_regions = regional_df[region_col].nunique()
        print(f"\nNumber of regions: {n_regions}")
        print(f"Regions: {sorted(regional_df[region_col].unique())}")
        
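        # Merge demographics for regional analysis (only columns that exist, to avoid a KeyError)
        demo_cols = ['subject_id'] + [c for c in ['AGE', 'SEX_ID', 'ETHNIC_ID', 'MARITAL_ID', 'OCCUPATION_ID', 'QUALIFICATION_ID'] if c in demographics.columns]
        regional_with_demo = regional_df.merge(demographics[demo_cols], on='subject_id', how='left')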
        
        # Regional variability analysis
        print(f"\nRegional variability (coefficient of variation across regions):")
        
        regional_cv = []
        for feature in ALL_FEATURES:
            if feature in regional_df.columns:
                # Calculate mean and std per region
                region_stats = regional_df.groupby(region_col)[feature].agg(['mean', 'std'])
                
                # Coefficient of variation across regions
                cv = region_stats['mean'].std() / region_stats['mean'].mean() if region_stats['mean'].mean() != 0 else 0
                
                regional_cv.append({
                    'feature': feature,
                    'category': feature_categories[feature],
                    'cv_across_regions': cv,
                    'min_regional_mean': region_stats['mean'].min(),
                    'max_regional_mean': region_stats['mean'].max()
                })
        
        regional_cv_df = pd.DataFrame(regional_cv).sort_values('cv_across_regions', ascending=False)
        
        print(f"\nTop 10 features by regional variability:")
        display(regional_cv_df.head(10))
        
        regional_cv_df.to_csv(TABLES_DIR / 'regional_variability.csv', index=False)
        print(f"\n✓ Regional variability analysis saved to {TABLES_DIR / 'regional_variability.csv'}")
    else:
        print("\n⚠️  No region identifier column found")
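
The regional variability metric above is the coefficient of variation of the per-region means. A minimal sketch on toy data follows; the frame `toy_regional`, the feature name `total_length`, and the helper `cv_across_regions` are illustrative assumptions. Note that this ratio becomes unstable when the mean of the regional means is close to zero.

```python
import pandas as pd

# Toy regional table: two subjects per region
toy_regional = pd.DataFrame({
    'region': ['A', 'A', 'B', 'B', 'C', 'C'],
    'total_length': [9.0, 11.0, 19.0, 21.0, 29.0, 31.0],
})

def cv_across_regions(df, region_col, feature):
    """Std of the per-region means divided by their mean (hypothetical helper)."""
    region_means = df.groupby(region_col)[feature].mean()
    overall = region_means.mean()
    # pandas .std() uses ddof=1 (sample standard deviation)
    return region_means.std() / overall if overall != 0 else 0.0

cv = cv_across_regions(toy_regional, 'region', 'total_length')
print(f"CV across regions: {cv:.2f}")  # regional means 10, 20, 30 -> 0.50
```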

---
## 10.2 In-depth Regional Analysis

In [None]:
if not IS_REGIONAL_DATA or regional_df is None:
    print("\n⚠️  No regional data available. Skipping regional analysis.")
else:
    print("\n" + "="*80)
    print("COMPREHENSIVE REGIONAL ANALYSIS")
    print("="*80)
    
    # Identify region column
    region_col = None
    for col_name in ['region', 'region_id', 'region_label', 'Region', 'REGION']:
        if col_name in regional_df.columns:
            region_col = col_name
            break
    
    if region_col is None:
        print("\n⚠️  No region identifier column found")
        print(f"Available columns: {regional_df.columns.tolist()}")
    else:
        n_regions = regional_df[region_col].nunique()
        regions = sorted(regional_df[region_col].unique())
        print(f"\nNumber of regions: {n_regions}")
        print(f"Regions: {regions}")
        
        # Merge demographics for regional analysis
        demo_cols = ['subject_id'] + [c for c in ['AGE', 'SEX_ID', 'BMI', 'ETHNIC_ID', 'MARITAL_ID', 'OCCUPATION_ID', 'QUALIFICATION_ID'] if c in demographics.columns]
        regional_with_demo = regional_df.merge(demographics[demo_cols], on='subject_id', how='inner')
        
        print(f"Regional data with demographics: {len(regional_with_demo)} rows, {regional_with_demo['subject_id'].nunique()} subjects")
        
        # ====================================================================
        # 10.1 Regional Feature Distributions
        # ====================================================================
        print("\n" + "-"*80)
        print("10.1 REGIONAL FEATURE DISTRIBUTIONS")
        print("-"*80)
        
        # Calculate statistics per region for all features
        regional_stats = []
        for feature in ALL_FEATURES:
            if feature not in regional_df.columns:
                continue
                
            for region in regions:
                region_data = regional_df[regional_df[region_col] == region][feature].dropna()
                
                if len(region_data) > 0:
                    regional_stats.append({
                        'region': region,
                        'feature': feature,
                        'category': feature_categories[feature],
                        'n': len(region_data),
                        'mean': region_data.mean(),
                        'std': region_data.std(),
                        'median': region_data.median(),
                        'min': region_data.min(),
                        'max': region_data.max(),
                        'cv': region_data.std() / region_data.mean() if region_data.mean() != 0 else 0
                    })
        
        regional_stats_df = pd.DataFrame(regional_stats)
        
        if len(regional_stats_df) > 0:
            # Calculate coefficient of variation across regions for each feature
            regional_variability = []
            for feature in ALL_FEATURES:
                if feature not in regional_df.columns:
                    continue
                    
                feature_stats = regional_stats_df[regional_stats_df['feature'] == feature]
                if len(feature_stats) > 0:
                    cv_across_regions = feature_stats['mean'].std() / feature_stats['mean'].mean() if feature_stats['mean'].mean() != 0 else 0
                    
                    regional_variability.append({
                        'feature': feature,
                        'category': feature_categories[feature],
                        'cv_across_regions': cv_across_regions,
                        'min_regional_mean': feature_stats['mean'].min(),
                        'max_regional_mean': feature_stats['mean'].max(),
                        'fold_change': feature_stats['mean'].max() / feature_stats['mean'].min() if feature_stats['mean'].min() > 0 else np.inf
                    })
            
            regional_var_df = pd.DataFrame(regional_variability).sort_values('cv_across_regions', ascending=False)
            
            print(f"\nTop 15 features by regional variability (coefficient of variation):")
            display(regional_var_df.head(15))
            
            regional_var_df.to_csv(TABLES_DIR / 'regional_variability.csv', index=False)
            print(f"✓ Regional variability saved to {TABLES_DIR / 'regional_variability.csv'}")
            
            # Save complete regional statistics
            regional_stats_df.to_csv(TABLES_DIR / 'regional_statistics_all.csv', index=False)
            print(f"✓ Complete regional statistics saved to {TABLES_DIR / 'regional_statistics_all.csv'}")
        
        # ====================================================================
        # 10.2 Regional Heatmaps - Feature Means by Region
        # ====================================================================
        print("\n" + "-"*80)
        print("10.2 REGIONAL HEATMAPS")
        print("-"*80)
        
        # Create heatmaps for each feature category
        for category_name, features_list in [
            ('Morphometric', MORPHOMETRIC_FEATURES),
            ('Topological', TOPOLOGICAL_FEATURES),
            ('Curvature', CURVATURE_FEATURES)
        ]:
            if len(features_list) == 0:
                continue
            
            # Select top features by variability in this category
            cat_var = regional_var_df[regional_var_df['category'] == category_name.lower()]
            top_features = cat_var.head(min(20, len(cat_var)))['feature'].tolist()
            
            if len(top_features) == 0:
                continue
            
            # Create pivot table for heatmap
            heatmap_data = []
            for feature in top_features:
                if feature in regional_df.columns:
                    means_by_region = regional_df.groupby(region_col)[feature].mean()
                    heatmap_data.append(means_by_region)
            
            if len(heatmap_data) > 0:
                heatmap_df = pd.DataFrame(heatmap_data, index=top_features)
                
                # Normalize each row (z-score) for better visualization
                heatmap_normalized = heatmap_df.sub(heatmap_df.mean(axis=1), axis=0).div(heatmap_df.std(axis=1), axis=0)
                
                fig, ax = plt.subplots(figsize=(max(10, n_regions * 0.5), max(8, len(top_features) * 0.3)))
                sns.heatmap(heatmap_normalized, cmap='RdBu_r', center=0, 
                           cbar_kws={'label': 'Z-score'}, ax=ax,
                           linewidths=0.5, linecolor='gray')
                ax.set_title(f'{category_name} Features Across Regions (Z-scored)', 
                            fontweight='bold', fontsize=14)
                ax.set_xlabel('Region', fontweight='bold')
                ax.set_ylabel('Feature', fontweight='bold')
                plt.tight_layout()
                plt.savefig(FIGURES_DIR / f'regional_heatmap_{category_name.lower()}.png', 
                           dpi=300, bbox_inches='tight')
                plt.show()
                
                print(f"✓ {category_name} regional heatmap saved")
        
        # ====================================================================
        # 10.3 Regional Comparisons - ANOVA/Kruskal-Wallis
        # ====================================================================
        print("\n" + "-"*80)
        print("10.3 REGIONAL COMPARISONS (ANOVA)")
        print("-"*80)
        
        regional_comparisons = []
        
        for feature in ALL_FEATURES:
            if feature not in regional_df.columns:
                continue
            
            # Get data for each region
            region_data = [regional_df[regional_df[region_col] == r][feature].dropna() 
                          for r in regions]
            
            # Remove groups with <5 samples
            region_data = [d for d in region_data if len(d) >= 5]
            
            if len(region_data) < 2:
                continue
            
            try:
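                # ANOVA and Kruskal-Wallis test
                f_stat, f_pval = f_oneway(*region_data)
                h_stat, h_pval = kruskal(*region_data)
                
                # Constant inputs yield a NaN p-value, which would corrupt the FDR correction below
                if not np.isfinite(f_pval):
                    continue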
                
                # Effect size (eta-squared)
                grand_mean = np.mean(np.concatenate(region_data))
                ss_between = sum(len(d) * (np.mean(d) - grand_mean)**2 for d in region_data)
                ss_total = sum(np.sum((d - grand_mean)**2) for d in region_data)
                eta_squared = ss_between / ss_total if ss_total > 0 else 0
                
                regional_comparisons.append({
                    'feature': feature,
                    'category': feature_categories[feature],
                    'n_regions': len(region_data),
                    'f_statistic': f_stat,
                    'anova_pvalue': f_pval,
                    'kruskal_pvalue': h_pval,
                    'eta_squared': eta_squared
                })
            except Exception:
                # Skip features where a test cannot be computed (e.g. identical values in every group)
                continue
        
        regional_comp_df = pd.DataFrame(regional_comparisons)
        
        if len(regional_comp_df) > 0:
            # Multiple testing correction
            regional_comp_df['anova_pvalue_fdr'] = multipletests(regional_comp_df['anova_pvalue'], method='fdr_bh')[1]
            regional_comp_df['significant'] = regional_comp_df['anova_pvalue_fdr'] < 0.05
            regional_comp_df = regional_comp_df.sort_values('eta_squared', ascending=False)
            
            n_sig = regional_comp_df['significant'].sum()
            print(f"\nRegional comparison summary:")
            print(f"  Features tested: {len(regional_comp_df)}")
            print(f"  Significant regional differences (ANOVA, FDR<0.05): {n_sig} ({100*n_sig/len(regional_comp_df):.1f}%)")
            
            print(f"\nTop 20 features with regional differences (by eta-squared):")
            display(regional_comp_df.head(20)[['feature', 'category', 'eta_squared', 'anova_pvalue_fdr']])
            
            regional_comp_df.to_csv(TABLES_DIR / 'regional_anova_comparisons.csv', index=False)
            print(f"✓ Regional ANOVA results saved to {TABLES_DIR / 'regional_anova_comparisons.csv'}")
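
The eta-squared reported in the regional comparison is the ratio of between-group to total sum of squares. A small sketch, assuming a hypothetical helper `eta_squared` and toy groups, shows the computation and cross-checks it against the identity η² = F·df_between / (F·df_between + df_within).

```python
import numpy as np
from scipy.stats import f_oneway

def eta_squared(groups):
    """Eta-squared for a one-way design: SS_between / SS_total (hypothetical helper)."""
    groups = [np.asarray(g, dtype=float) for g in groups]
    grand_mean = np.concatenate(groups).mean()
    ss_between = sum(len(g) * (g.mean() - grand_mean) ** 2 for g in groups)
    ss_total = sum(((g - grand_mean) ** 2).sum() for g in groups)
    return ss_between / ss_total if ss_total > 0 else 0.0

# Three toy "regions" with shifted means
toy_groups = [[1.0, 2.0, 3.0], [2.0, 3.0, 4.0], [3.0, 4.0, 5.0]]
eta2 = eta_squared(toy_groups)

# Cross-check against the F statistic from scipy
f_stat, _ = f_oneway(*toy_groups)
df_between = len(toy_groups) - 1
df_within = sum(len(g) for g in toy_groups) - len(toy_groups)
eta2_from_f = f_stat * df_between / (f_stat * df_between + df_within)

print(f"eta^2 = {eta2:.3f}, from F = {eta2_from_f:.3f}")
```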

In [None]:
atlas_path = '/home/falcetta/ISBI2025/LIANE/ArterialAtlas.nii.gz'  # Set to your atlas file path
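
The visualization cell that follows paints one value per region onto the atlas label volume. The mapping step can be tried without a NIfTI file on a synthetic label array, as in this minimal sketch. The names `toy_atlas`, `toy_values`, and `map_values_to_labels` are illustrative assumptions. The sketch masks on the label rather than on the value, which is one way to keep a region visible when its value is exactly 0.

```python
import numpy as np

def map_values_to_labels(label_volume, region_values):
    """Return a float volume where each labelled voxel holds its region's value."""
    labels = np.asarray(label_volume)
    value_map = np.zeros(labels.shape, dtype=float)
    for region_id, value in region_values.items():
        # IDs that do not occur in the volume simply match no voxels
        value_map[labels == region_id] = value
    return value_map

# Synthetic 4x4x4 label volume: 0 = background, 1 and 2 = regions
toy_atlas = np.zeros((4, 4, 4), dtype=int)
toy_atlas[:2] = 1
toy_atlas[2:, :2] = 2

toy_values = {1: 0.8, 2: -0.3, 99: 5.0}  # 99 is absent from the volume

value_map = map_values_to_labels(toy_atlas, toy_values)

# Hide background voxels only, based on the label rather than the value
masked = np.ma.masked_where(toy_atlas == 0, value_map)
print(f"Mapped voxels: {masked.count()}, background voxels: {int(masked.mask.sum())}")
```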

In [None]:
# ============================================================================
# 10.4 BRAIN ATLAS VISUALIZATION OF REGIONAL DIFFERENCES
# ============================================================================

if not IS_REGIONAL_DATA or regional_df is None or region_col is None:
    print("\n⚠️  Skipping atlas visualization - no regional data available")
else:
    print("\n" + "-"*80)
    print("10.4 BRAIN ATLAS VISUALIZATION")
    print("-"*80)
    
    # Import required libraries for atlas visualization
    from scipy import ndimage
    
    def plot_atlas_panel(ax, atlas_slice, value_slice, title):
        """Create a single atlas visualization panel."""
        # 1. WHITE background
        ax.imshow(np.ones_like(atlas_slice), cmap='gray', vmin=0, vmax=1)
        # 2. LIGHT GREY fill inside regions
        mask = atlas_slice > 0
        bg = np.ones_like(atlas_slice)
        bg[mask] = 0.85
        ax.imshow(bg, cmap='gray', vmin=0, vmax=1)
        # 3. GREY contours along region boundaries
        regions_arr = np.unique(atlas_slice)[1:]
        if regions_arr.size:
            ax.contour(
                atlas_slice.astype(float), levels=regions_arr + 0.5, colors='#BBBBBB',
                linewidths=0.5, alpha=0.6, antialiased=True
            )
        # 4. Value overlay
        vals = np.ma.masked_where(value_slice == 0, value_slice)
        im = ax.imshow(vals, cmap='RdBu_r', alpha=0.9)
        ax.set_title(title, fontsize=14, pad=6)
        ax.axis('off')
        return im
    
    def create_regional_atlas_visualization(atlas_path, regional_data_dict, 
                                            output_prefix='regional_atlas',
                                            title='Regional Feature Distribution',
                                            vmin=None, vmax=None,
                                            colorbar_label='Feature Value',
                                            use_diverging=False):
        """
        Create atlas visualization for regional data.
        
        Parameters:
        -----------
        atlas_path : str or None
            Path to atlas NIfTI file (if available)
        regional_data_dict : dict
            Dictionary mapping region IDs to values
        output_prefix : str
            Prefix for output files
        title : str
            Figure title
        vmin, vmax : float or None
            Value range for colormap
        colorbar_label : str
            Label for colorbar
        use_diverging : bool
            If True, use diverging colormap centered at 0 (for z-scores, correlations, effect sizes)
        """
        if len(regional_data_dict) == 0:
            print(f"⚠️  No data to visualize for {output_prefix}")
            return None

        # Check if atlas file exists
        if atlas_path is None or not Path(atlas_path).exists():
            print(f"⚠️  Atlas file not found at {atlas_path}")
            print("   Creating bar plot visualization instead...")
            
            # Create bar plot as alternative
            fig, ax = plt.subplots(figsize=(12, max(6, len(regional_data_dict) * 0.3)))
            
            regions_list = sorted(regional_data_dict.keys())
            values = [regional_data_dict[r] for r in regions_list]
            
            # Color bars by value
            if vmin is None:
                vmin = min(values)
            if vmax is None:
                vmax = max(values)
            
            # COLORS
            if use_diverging:
                # Center at 0 for z-scores
                max_abs = max(abs(vmin), abs(vmax))
                vmin, vmax = -max_abs, max_abs
                norm = plt.Normalize(vmin=vmin, vmax=vmax)
                cmap = plt.cm.RdBu_r
            else:
                norm = plt.Normalize(vmin=vmin, vmax=vmax)
                cmap = plt.cm.OrRd
            
            colors = [cmap(norm(val)) for val in values]
            bars = ax.barh(range(len(regions_list)), values, color=colors, alpha=0.8, edgecolor='black')
            ax.set_yticks(range(len(regions_list)))
            ax.set_yticklabels([f'Region {r}' for r in regions_list], fontsize=9)
            ax.set_xlabel(colorbar_label, fontweight='bold', fontsize=11)
            ax.set_title(title, fontweight='bold', fontsize=13)
            ax.grid(True, alpha=0.3, axis='x')
            
            # Add zero line for z-scores
            if use_diverging:
                ax.axvline(0, color='black', linestyle='-', linewidth=1.5, alpha=0.8)
            
            # Add colorbar
            # Named 'mappable' to avoid shadowing the statsmodels alias 'sm'
            mappable = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
            mappable.set_array([])
            cbar = plt.colorbar(mappable, ax=ax, pad=0.02)
            cbar.set_label(colorbar_label, fontsize=10)
            
            plt.tight_layout()
            plt.savefig(FIGURES_DIR / f'{output_prefix}.png', dpi=150, bbox_inches='tight')
            plt.show()
            plt.close('all')
            
            return None

        try:
            import nibabel as nib
            
            print(f"Loading atlas from {atlas_path}...")
            atlas_img = nib.load(atlas_path)
            atlas_data = atlas_img.get_fdata()
            
            # Create value map
            value_map = np.zeros_like(atlas_data)
            
            print("Mapping values to regions...")
            for region_id, value in regional_data_dict.items():
                mask = atlas_data == region_id
                if np.any(mask):
                    value_map[mask] = value
                    print(f"  Region {region_id}: value={value:.3f}, voxels={np.sum(mask)}")
            
            # Crop to brain region
            atlas_cropped = atlas_data[50:-50, 40:-40, 50:-50]
            value_cropped = value_map[50:-50, 40:-40, 50:-50]
            
            # Get slice positions
            x, y, z = atlas_cropped.shape
            x_slice, y_slice, z_slice = x//2, y//2, z//2
            
            slices = {
                'Sagittal': (np.rot90(atlas_cropped[x_slice, :, :]), np.rot90(value_cropped[x_slice, :, :])),
                'Coronal': (np.rot90(atlas_cropped[:, y_slice, :]), np.rot90(value_cropped[:, y_slice, :])),
                'Axial': (np.rot90(atlas_cropped[:, :, z_slice]), np.rot90(value_cropped[:, :, z_slice]))
            }
            
            # Create figure
            fig, axes = plt.subplots(1, 3, figsize=(15, 6))
            fig.suptitle(title, fontsize=16, fontweight='bold')
            
            # Determine value range if not provided
            non_zero_values = value_map[value_map != 0]
            if vmin is None and len(non_zero_values) > 0:
                vmin = np.percentile(non_zero_values, 5)
            if vmax is None and len(non_zero_values) > 0:
                vmax = np.percentile(non_zero_values, 95)
            
            # For diverging colormaps, ensure symmetric range around 0
            if use_diverging and vmin is not None and vmax is not None:
                max_abs = max(abs(vmin), abs(vmax))
                vmin, vmax = -max_abs, max_abs
            
            # Plot each view with the same colormap and limits on every panel,
            # so the shared colorbar below describes all three slices
            panel_cmap = 'RdBu_r' if use_diverging else 'OrRd'
            for ax, (view_name, (atlas_sl, val_sl)) in zip(axes, slices.items()):
                im = plot_atlas_panel(ax, atlas_sl, val_sl, view_name)
                im.set_cmap(panel_cmap)
                im.set_clim(vmin, vmax)
            
            # Colorbar
            cbar = fig.colorbar(im, ax=axes, orientation='horizontal', 
                                fraction=0.05, pad=0.05)
            cbar.set_label(colorbar_label, fontsize=12)
            
            plt.tight_layout()
            plt.savefig(FIGURES_DIR / f'{output_prefix}.png', dpi=150, bbox_inches='tight')
            plt.show()
            plt.close('all')
            
            print(f"✓ Atlas visualization saved to {FIGURES_DIR / f'{output_prefix}.png'}")
            
            return value_map
            
        except ImportError:
            print("⚠️  nibabel not installed. Install with: pip install nibabel")
            return None
        except Exception as e:
            print(f"⚠️  Error creating atlas visualization: {e}")
            import traceback
            traceback.print_exc()
            return None
    
    # ========================================================================
    # Check if we have significant regional differences to visualize
    # ========================================================================
    
    if 'regional_comp_df' in locals() and len(regional_comp_df) > 0 and regional_comp_df['significant'].sum() > 0:
        print("\nGenerating brain atlas visualizations for top features with regional differences...")
        print(f"Using atlas file: {atlas_path}")
        
        # Get top features with significant regional differences
        top_regional_features = regional_comp_df[regional_comp_df['significant']].head(6)['feature'].tolist()
        
        if len(top_regional_features) > 0:
            # Visualize up to 6 features with Z-SCORES
            for idx, feature in enumerate(top_regional_features[:6]):
                print(f"\n{'='*60}")
                print(f"Visualizing feature {idx+1}/{len(top_regional_features)}: {feature}")
                print('='*60)
                
                # Get regional means for this feature
                regional_means = regional_df.groupby(region_col)[feature].mean()
                
                # CALCULATE Z-SCORES (normalize across regions)
                mean_val = regional_means.mean()
                std_val = regional_means.std()
                
                if std_val > 0:
                    regional_zscores = (regional_means - mean_val) / std_val
                else:
                    regional_zscores = regional_means - mean_val
                
                # Create dictionary mapping region to Z-SCORE
                regional_data_dict = regional_zscores.to_dict()
                
                print(f"Regional z-scores for {feature}:")
                for region, zscore in sorted(regional_data_dict.items()):
                    print(f"  Region {region}: z={zscore:.3f}")
                
                # Clean feature name
                clean_feature = feature.replace('_', ' ').replace('total ', '').replace('mean ', '').title()
                
                # Get statistics
                feat_stats = regional_comp_df[regional_comp_df['feature'] == feature].iloc[0]
                eta_sq = feat_stats['eta_squared']
                p_val = feat_stats['anova_pvalue_fdr']
                
                if p_val < 0.001:
                    sig_text = '***'
                    p_full = 'p < 0.001'
                elif p_val < 0.01:
                    sig_text = '**'
                    p_full = f'p = {p_val:.3f}'
                else:
                    sig_text = '*'
                    p_full = f'p = {p_val:.3f}'
                
                # Create visualization with Z-SCORES
                create_regional_atlas_visualization(
                    atlas_path=atlas_path,
                    regional_data_dict=regional_data_dict,
                    output_prefix=f'regional_atlas_{feature}_zscore',
                    title=f'{clean_feature} (Z-scored) {sig_text}\nη² = {eta_sq:.3f}, {p_full}',
                    colorbar_label='Z-score',
                    use_diverging=True  # Use RdBu_r colormap centered at 0
                )
            
            print("\n" + "="*80)
            print("✓ BRAIN ATLAS VISUALIZATION COMPLETE")
            print("="*80)
        else:
            print("\n⚠️  No significant regional differences found to visualize")
    else:
        print("\n⚠️  No regional comparison data available or no significant differences to visualize")
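
The z-scoring and symmetric color limits used in the cell above can be sketched in isolation. This is a minimal illustration on a made-up table of regional means; `zscore_regions` and `symmetric_limits` are hypothetical helper names, not functions defined elsewhere in this notebook.

```python
import pandas as pd

def zscore_regions(regional_means: pd.Series) -> pd.Series:
    """Center regional means and scale by their sample std (ddof=1, the pandas default)."""
    centered = regional_means - regional_means.mean()
    std = regional_means.std()
    # With zero (or undefined) spread, return the centered values unscaled
    return centered / std if std > 0 else centered

def symmetric_limits(values) -> tuple:
    """Return (-m, m), where m is the largest absolute value, for a colormap centered at 0."""
    m = max(abs(float(min(values))), abs(float(max(values))))
    return -m, m

# Hypothetical regional means keyed by atlas label
means = pd.Series({1: 2.0, 2: 4.0, 3: 6.0})
z = zscore_regions(means)
vmin, vmax = symmetric_limits(z.values)
print(z.to_dict(), vmin, vmax)
```

Because the limits are symmetric, a z-score of 0 maps to the neutral midpoint of a diverging colormap such as `RdBu_r`.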

In [None]:
skip_it = True

In [None]:
if not IS_REGIONAL_DATA or regional_df is None:
    print("\n⚠️  No regional data available. Skipping regional analysis.")
elif skip_it:
    print("This cell takes too long to run; skipping it.")
else:
    print("\n" + "="*80)
    print("COMPREHENSIVE REGIONAL ANALYSIS")
    print("="*80)
    # Identify region column
    region_col = None
    for col_name in ['region', 'region_id', 'region_label', 'Region', 'REGION']:
        if col_name in regional_df.columns:
            region_col = col_name
            break
    
    if region_col is None:
        print("\n⚠️  No region identifier column found")
        print(f"Available columns: {regional_df.columns.tolist()}")
    else:
        # ====================================================================
        # 10.4 Age-Region Interactions
        # ====================================================================
        if 'AGE' in regional_with_demo.columns:
            print("\n" + "-"*80)
            print("10.4 AGE × REGION INTERACTIONS")
            print("-"*80)
            
            age_region_interactions = []
            
            for feature in ALL_FEATURES:
                if feature not in regional_with_demo.columns:
                    continue
                
                # Prepare data
                data_for_model = regional_with_demo[[region_col, 'AGE', feature]].dropna()
                
                if len(data_for_model) < 50:  # Need sufficient data
                    continue
                
                try:
                    # Fit model with interaction
                    formula = f'{feature} ~ AGE + C({region_col}) + AGE:C({region_col})'
                    model = smf.ols(formula, data=data_for_model).fit()
                    
                    # Check if any interaction terms are significant
                    interaction_terms = [p for p in model.pvalues.index if 'AGE:C(' in p]
                    
                    if len(interaction_terms) > 0:
                        min_interaction_p = model.pvalues[interaction_terms].min()
                        
                        age_region_interactions.append({
                            'feature': feature,
                            'category': feature_categories[feature],
                            'n': len(data_for_model),
                            'min_interaction_pvalue': min_interaction_p,
                            'model_r_squared': model.rsquared,
                            'n_regions_tested': len(data_for_model[region_col].unique())
                        })
                except Exception:
                    # Skip features whose model cannot be fitted
                    continue
            
            if len(age_region_interactions) > 0:
                age_region_df = pd.DataFrame(age_region_interactions)
                age_region_df['min_interaction_pvalue_fdr'] = multipletests(
                    age_region_df['min_interaction_pvalue'], method='fdr_bh')[1]
                age_region_df['significant'] = age_region_df['min_interaction_pvalue_fdr'] < 0.05
                age_region_df = age_region_df.sort_values('min_interaction_pvalue')
                
                n_sig = age_region_df['significant'].sum()
                print(f"\nAge×Region interaction summary:")
                print(f"  Features tested: {len(age_region_df)}")
                print(f"  Significant interactions (FDR<0.05): {n_sig}")
                
                if n_sig > 0:
                    print(f"\nFeatures with significant Age×Region interaction:")
                    display(age_region_df[age_region_df['significant']][
                        ['feature', 'category', 'min_interaction_pvalue_fdr', 'model_r_squared']
                    ].head(15))
                
                age_region_df.to_csv(TABLES_DIR / 'age_region_interactions.csv', index=False)
                print(f"✓ Age×Region interactions saved to {TABLES_DIR / 'age_region_interactions.csv'}")
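
The cell above summarizes each feature by its smallest interaction p-value, which depends on how many regions are compared and on which region serves as the reference level. One alternative, sketched here on synthetic data, is a single joint F-test that compares the model with and without the interaction terms. The column names (`region`, `AGE`, `feature_x`) and the simulated slopes are assumptions for illustration only.

```python
import numpy as np
import pandas as pd
import statsmodels.formula.api as smf
from statsmodels.stats.anova import anova_lm

rng = np.random.default_rng(0)
n_per_region = 100
frames = []
# Simulate three regions whose age slopes differ (0.0, 0.5, 1.0)
for region, slope in zip(["A", "B", "C"], [0.0, 0.5, 1.0]):
    age = rng.uniform(20, 80, n_per_region)
    y = 10 + slope * age + rng.normal(0, 5, n_per_region)
    frames.append(pd.DataFrame({"region": region, "AGE": age, "feature_x": y}))
data = pd.concat(frames, ignore_index=True)

reduced = smf.ols("feature_x ~ AGE + C(region)", data=data).fit()
full = smf.ols("feature_x ~ AGE + C(region) + AGE:C(region)", data=data).fit()

# For nested OLS fits, row 1 of the table holds the F-test for the added terms
comparison = anova_lm(reduced, full)
joint_p = comparison["Pr(>F)"].iloc[1]
print(f"Joint Age x Region F-test p-value: {joint_p:.3g}")
```

The joint test yields one p-value per feature, which can then be passed to the same FDR correction across features.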

In [None]:
if not IS_REGIONAL_DATA or regional_df is None:
    print("\n⚠️  No regional data available. Skipping regional analysis.")
else:
    print("\n" + "="*80)
    print("COMPREHENSIVE REGIONAL ANALYSIS")
    print("="*80)
    
    # Identify region column
    region_col = None
    for col_name in ['region', 'region_id', 'region_label', 'Region', 'REGION']:
        if col_name in regional_df.columns:
            region_col = col_name
            break
    
    if region_col is None:
        print("\n⚠️  No region identifier column found")
        print(f"Available columns: {regional_df.columns.tolist()}")
    else:
        # ====================================================================
        # 10.5 Regional Age Correlations
        # ====================================================================
        if 'AGE' in regional_with_demo.columns:
            print("\n" + "-"*80)
            print("10.5 REGIONAL AGE CORRELATIONS")
            print("-"*80)
            
            regional_age_corr = []
            
            for region in regions:
                region_data = regional_with_demo[regional_with_demo[region_col] == region]
                
                for feature in ALL_FEATURES:
                    if feature not in region_data.columns:
                        continue
                    
                    valid_data = region_data[['AGE', feature]].dropna()
                    
                    if len(valid_data) < 10:
                        continue
                    
                    r, p = pearsonr(valid_data['AGE'], valid_data[feature])
                    
                    regional_age_corr.append({
                        'region': region,
                        'feature': feature,
                        'category': feature_categories[feature],
                        'n': len(valid_data),
                        'correlation': r,
                        'pvalue': p
                    })
            
            if len(regional_age_corr) > 0:
                regional_age_corr_df = pd.DataFrame(regional_age_corr)
                
                # FDR correction within each region
                for region in regions:
                    region_mask = regional_age_corr_df['region'] == region
                    if region_mask.sum() > 0:
                        regional_age_corr_df.loc[region_mask, 'pvalue_fdr'] = multipletests(
                            regional_age_corr_df.loc[region_mask, 'pvalue'], method='fdr_bh')[1]
                
                regional_age_corr_df['significant'] = regional_age_corr_df['pvalue_fdr'] < 0.05
                
                # Summary by region
                print("\nSignificant age correlations by region:")
                for region in regions:
                    region_data = regional_age_corr_df[regional_age_corr_df['region'] == region]
                    n_sig = region_data['significant'].sum()
                    print(f"  Region {region}: {n_sig} significant correlations")
                
                # Find features with consistent age effects across regions
                feature_consistency = []
                for feature in ALL_FEATURES:
                    feature_data = regional_age_corr_df[regional_age_corr_df['feature'] == feature]
                    if len(feature_data) >= n_regions * 0.5:  # Present in at least half the regions
                        n_sig = feature_data['significant'].sum()
                        mean_r = feature_data['correlation'].mean()
                        std_r = feature_data['correlation'].std()
                        
                        feature_consistency.append({
                            'feature': feature,
                            'category': feature_categories[feature],
                            'n_regions': len(feature_data),
                            'n_significant': n_sig,
                            'mean_correlation': mean_r,
                            'std_correlation': std_r,
                            'consistency': 1 - std_r  # Higher = more consistent
                        })
                
                if len(feature_consistency) > 0:
                    consistency_df = pd.DataFrame(feature_consistency)
                    consistency_df = consistency_df.sort_values('n_significant', ascending=False)
                    
                    print(f"\nFeatures with most consistent age effects across regions:")
                    display(consistency_df.head(15))
                    
                    consistency_df.to_csv(TABLES_DIR / 'regional_age_correlation_consistency.csv', index=False)
                    print(f"✓ Regional age correlation consistency saved")
                
                regional_age_corr_df.to_csv(TABLES_DIR / 'regional_age_correlations.csv', index=False)
                print(f"✓ Regional age correlations saved to {TABLES_DIR / 'regional_age_correlations.csv'}")
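
The per-region FDR step relies on `multipletests(..., method='fdr_bh')`. As a reference for what that adjustment computes, here is a minimal NumPy sketch of the Benjamini-Hochberg procedure. `bh_adjust` is a hypothetical helper written for illustration, and the p-values are made up.

```python
import numpy as np

def bh_adjust(pvalues):
    """Benjamini-Hochberg adjusted p-values, returned in the input order."""
    p = np.asarray(pvalues, dtype=float)
    m = p.size
    order = np.argsort(p)
    # Scale the i-th smallest p-value by m / i
    scaled = p[order] * m / np.arange(1, m + 1)
    # Enforce monotonicity from the largest rank downwards, then cap at 1
    adjusted_sorted = np.minimum.accumulate(scaled[::-1])[::-1]
    adjusted_sorted = np.minimum(adjusted_sorted, 1.0)
    adjusted = np.empty(m)
    adjusted[order] = adjusted_sorted
    return adjusted

raw = [0.01, 0.04, 0.03, 0.20]
print(bh_adjust(raw))
```

Note that correcting within each region controls the false discovery rate per region, not across the full region-by-feature table.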

In [None]:
if not IS_REGIONAL_DATA or regional_df is None:
    print("\n⚠️  No regional data available. Skipping regional analysis.")
else:
    print("\n" + "="*80)
    print("COMPREHENSIVE REGIONAL ANALYSIS")
    print("="*80)
    
    # Identify region column
    region_col = None
    for col_name in ['region', 'region_id', 'region_label', 'Region', 'REGION']:
        if col_name in regional_df.columns:
            region_col = col_name
            break
    
    if region_col is None:
        print("\n⚠️  No region identifier column found")
        print(f"Available columns: {regional_df.columns.tolist()}")
    else:
        # ====================================================================
        # 10.6 Visualize Top Regional Differences
        # ====================================================================
        print("\n" + "-"*80)
        print("10.6 VISUALIZING TOP REGIONAL DIFFERENCES")
        print("-"*80)

        if 'regional_comp_df' in locals() and len(regional_comp_df) > 0:
            # Select top 6 features with largest regional differences
            top_regional_features = regional_comp_df.head(6)['feature'].tolist()
            
            fig, axes = plt.subplots(2, 3, figsize=(18, 12))
            fig.suptitle('Top Features with Regional Differences', fontsize=16, fontweight='bold')
            axes = axes.flatten()
            
            for idx, feature in enumerate(top_regional_features):
                ax = axes[idx]
                
                # Extract data as numpy arrays first, then rebuild DataFrame
                region_values = regional_df[region_col].values
                feature_values = regional_df[feature].values
                
                # Create a clean DataFrame from scratch
                plot_data = pd.DataFrame({
                    'region': region_values,
                    'feature_value': feature_values
                })
                
                # Remove NaN values
                plot_data = plot_data.dropna()
                
                # Convert region to string
                plot_data['region'] = plot_data['region'].astype(str)
                
                # Get unique regions and sort
                unique_regions = plot_data['region'].unique()
                
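                try:
                    region_order = sorted(unique_regions, 
                                        key=lambda x: int(x) if x.isdigit() else x)
                except TypeError:
                    # Mixed numeric and non-numeric labels cannot be compared;
                    # fall back to plain string ordering
                    region_order = sorted(unique_regions, key=str)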
                
                # Box plot with explicit order
                sns.boxplot(data=plot_data, x='region', y='feature_value', 
                        order=region_order, ax=ax, palette='Set3')
                
                # Get stats
                feat_stats = regional_comp_df[regional_comp_df['feature'] == feature].iloc[0]
                
                ax.set_title(f"{feature}\nη²={feat_stats['eta_squared']:.3f}, p={feat_stats['anova_pvalue_fdr']:.2e}",
                        fontsize=10)
                ax.set_xlabel('Region', fontweight='bold')
                ax.set_ylabel(feature, fontsize=9)
                ax.tick_params(axis='x', rotation=45)
                ax.grid(True, alpha=0.3, axis='y')
                
                # Reduce number of x-tick labels if too many regions
                if len(region_order) > 15:
                    current_labels = ax.get_xticklabels()
                    for i, label in enumerate(current_labels):
                        if i % 2 != 0:  # Hide every other label
                            label.set_visible(False)
            
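            # Hide any unused panels when fewer than 6 features are available
            for unused_ax in axes[len(top_regional_features):]:
                unused_ax.set_visible(False)
            
            plt.tight_layout()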
            plt.savefig(FIGURES_DIR / 'regional_differences_top6.png', dpi=300, bbox_inches='tight')
            plt.show()
            print(f"✓ Regional differences visualization saved to {FIGURES_DIR / 'regional_differences_top6.png'}")
        else:
            print("⚠️  No regional comparison results available for visualization")
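
The region ordering above falls back to plain string sorting when labels mix digits and text. A minimal sketch of a sort key that handles both cases without raising is shown below. `region_sort_key` is a hypothetical name and the labels are invented.

```python
def region_sort_key(label):
    """Sort numeric labels by value first, then non-numeric labels alphabetically."""
    text = str(label)
    if text.isdigit():
        return (0, int(text), "")
    return (1, 0, text)

labels = ["10", "2", "ACA_left", "1", "MCA_right"]
region_order = sorted(labels, key=region_sort_key)
print(region_order)
```

Labels such as `'1.0'` are not treated as numeric by `str.isdigit`, so region IDs stored as floats would need to be cast to integers before the conversion to strings.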

In [None]:
if not IS_REGIONAL_DATA or regional_df is None:
    print("\n⚠️  No regional data available. Skipping regional analysis.")
else:
    print("\n" + "="*80)
    print("COMPREHENSIVE REGIONAL ANALYSIS")
    print("="*80)
    
    # Identify region column
    region_col = None
    for col_name in ['region', 'region_id', 'region_label', 'Region', 'REGION']:
        if col_name in regional_df.columns:
            region_col = col_name
            break
    
    if region_col is None:
        print("\n⚠️  No region identifier column found")
        print(f"Available columns: {regional_df.columns.tolist()}")
    else:
        # ====================================================================
        # 10.7 Regional Profiles - Feature Fingerprints
        # ====================================================================
        print("\n" + "-"*80)
        print("10.7 REGIONAL FEATURE PROFILES")
        print("-"*80)
        
        # Create "fingerprint" plots showing normalized feature values per region
        if 'regional_var_df' in locals() and len(regional_var_df) > 0:
            # Select most variable features
            top_variable_features = regional_var_df.head(min(30, len(regional_var_df)))['feature'].tolist()
            
            # Create normalized profiles
            profile_data = []
            for region in regions:
                region_data = regional_df[regional_df[region_col] == region]
                profile = {}
                for feature in top_variable_features:
                    if feature in region_data.columns:
                        profile[feature] = region_data[feature].mean()
                profile['region'] = region
                profile_data.append(profile)
            
            profile_df = pd.DataFrame(profile_data).set_index('region')
            
            # Z-score normalization
            profile_normalized = (profile_df - profile_df.mean()) / profile_df.std()
            
            # Plot
            fig, ax = plt.subplots(figsize=(14, max(8, n_regions * 0.1)))
            sns.heatmap(profile_normalized.T, cmap='RdBu_r', center=0, 
                       cbar_kws={'label': 'Z-score'}, ax=ax,
                       linewidths=0.5, linecolor='gray')
            ax.set_title('Regional Feature Profiles (Top Variable Features)', 
                        fontweight='bold', fontsize=14)
            ax.set_xlabel('Region', fontweight='bold')
            ax.set_ylabel('Feature', fontweight='bold')
            plt.tight_layout()
            plt.savefig(FIGURES_DIR / 'regional_feature_profiles.png', dpi=300, bbox_inches='tight')
            plt.show()
            print(f"✓ Regional feature profiles saved to {FIGURES_DIR / 'regional_feature_profiles.png'}")
        
        print("\n" + "="*80)
        print("✅ COMPREHENSIVE REGIONAL ANALYSIS COMPLETE")
        print("="*80)
        print(f"\nGenerated outputs:")
        print(f"  • Regional variability analysis")
        print(f"  • Regional heatmaps by feature category")
        print(f"  • Regional ANOVA comparisons")
        if 'AGE' in regional_with_demo.columns:
            print(f"  • Age×Region interaction analysis")
            print(f"  • Regional age correlations")
            print(f"  • Regional age correlation consistency")
        print(f"  • Top regional differences visualization")
        print(f"  • Regional feature profiles")
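
The profile heatmap divides each feature column by its standard deviation across regions, which yields NaN for a feature that is constant across regions. A minimal sketch of one way to guard against that, using an invented three-region table, follows. Dropping constant columns is an assumption about the desired behavior, not the notebook's current logic.

```python
import pandas as pd

# Hypothetical region-by-feature table of mean values
profile_df = pd.DataFrame(
    {"length": [1.0, 2.0, 3.0], "radius": [5.0, 5.0, 5.0], "tortuosity": [0.2, 0.4, 0.9]},
    index=pd.Index(["R1", "R2", "R3"], name="region"),
)

std = profile_df.std()
# Keep only features that actually vary across regions
varying = std[std > 0].index
profile_normalized = (profile_df[varying] - profile_df[varying].mean()) / std[varying]
print(profile_normalized.round(3))
```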

In [None]:
# ====================================================================
# 10.8 Atlas Visualization of Regional Results - All Demographics
# ====================================================================
print("\n" + "="*80)
print("10.8 ATLAS VISUALIZATION OF REGIONAL RESULTS - ALL DEMOGRAPHICS")
print("="*80)

# Import required libraries for atlas visualization
from scipy import ndimage

def plot_atlas_panel(ax, atlas_slice, value_slice, title):
    """Create a single atlas visualization panel."""
    # 1. WHITE background
    ax.imshow(np.ones_like(atlas_slice), cmap='gray', vmin=0, vmax=1)
    # 2. LIGHT GREY fill inside regions
    mask = atlas_slice > 0
    bg = np.ones_like(atlas_slice)
    bg[mask] = 0.85
    ax.imshow(bg, cmap='gray', vmin=0, vmax=1)
    # 3. GREY contours along region boundaries
    regions_arr = np.unique(atlas_slice)[1:]
    if regions_arr.size:
        ax.contour(
            atlas_slice.astype(float), levels=regions_arr + 0.5, colors='#BBBBBB',
            linewidths=0.5, alpha=0.6, antialiased=True
        )
    # 4. Value overlay
    vals = np.ma.masked_where(value_slice == 0, value_slice)
    im = ax.imshow(vals, cmap='OrRd', alpha=0.9)
    ax.set_title(title, fontsize=14, pad=6)
    ax.axis('off')
    return im

def create_regional_value_map(regional_stats_df, feature_name, region_col='region', 
                               aggregation='mean'):
    """
    Create a dictionary mapping region IDs to feature values.
    
    Parameters:
    -----------
    regional_stats_df : pd.DataFrame
        DataFrame with regional statistics
    feature_name : str
        Name of the feature to map
    region_col : str
        Name of the region column
    aggregation : str
        How to aggregate ('mean', 'median', 'std', etc.)
    
    Returns:
    --------
    dict : Region ID -> feature value mapping
    """
    feature_data = regional_stats_df[regional_stats_df['feature'] == feature_name]
    
    if len(feature_data) == 0:
        return {}
    
    # Each supported aggregation is stored in a column of the same name
    if aggregation not in ('mean', 'median', 'std', 'cv'):
        raise ValueError(f"Unknown aggregation: {aggregation}")
    
    # Honor region_col when that column exists; otherwise fall back to 'region'
    key_col = region_col if region_col in feature_data.columns else 'region'
    return dict(zip(feature_data[key_col], feature_data[aggregation]))

def create_regional_atlas_visualization(atlas_path, regional_data_dict, 
                                        output_prefix='regional_atlas',
                                        title='Regional Feature Distribution',
                                        vmin=None, vmax=None,
                                        colorbar_label='Feature Value',
                                        use_diverging=False):
    """
    Create atlas visualization for regional data.
    
    Parameters:
    -----------
    atlas_path : str or None
        Path to atlas NIfTI file (if available)
    regional_data_dict : dict
        Dictionary mapping region IDs to values
    output_prefix : str
        Prefix for output files
    title : str
        Figure title
    vmin, vmax : float or None
        Value range for colormap
    colorbar_label : str
        Label for colorbar
    use_diverging : bool
        If True, use diverging colormap centered at 0 (for correlations, effect sizes)
    """
    if len(regional_data_dict) == 0:
        print(f"⚠️  No data to visualize for {output_prefix}")
        return None

    # Check if atlas file exists
    if atlas_path is None or not Path(atlas_path).exists():
        print(f"⚠️  Atlas file not found at {atlas_path}")
        print("   Creating bar plot visualization instead...")
        
        # Create bar plot as alternative
        fig, ax = plt.subplots(figsize=(12, max(6, len(regional_data_dict) * 0.3)))
        
        regions_list = sorted(regional_data_dict.keys())
        values = [regional_data_dict[r] for r in regions_list]
        
        bars = ax.barh(range(len(regions_list)), values, color='steelblue', alpha=0.7, edgecolor='black')
        ax.set_yticks(range(len(regions_list)))
        ax.set_yticklabels([f'Region {r}' for r in regions_list])
        ax.set_xlabel(colorbar_label, fontweight='bold')
        ax.set_title(title, fontweight='bold', fontsize=14)
        ax.grid(True, alpha=0.3, axis='x')
        
        # Color bars by value
        if vmin is None:
            vmin = min(values)
        if vmax is None:
            vmax = max(values)
        
        # Choose colormap and normalization
        if use_diverging:
            # Center at 0 for correlations
            max_abs = max(abs(vmin), abs(vmax))
            vmin, vmax = -max_abs, max_abs
            norm = plt.Normalize(vmin=vmin, vmax=vmax)
            cmap = plt.cm.RdBu_r
        else:
            norm = plt.Normalize(vmin=vmin, vmax=vmax)
            cmap = plt.cm.OrRd
        
        for bar, val in zip(bars, values):
            bar.set_color(cmap(norm(val)))
        
        # Add colorbar
        sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
        sm.set_array([])
        cbar = plt.colorbar(sm, ax=ax, pad=0.02)
        cbar.set_label(colorbar_label)
        
        plt.tight_layout()
        plt.savefig(FIGURES_DIR / f'{output_prefix}.png', dpi=300, bbox_inches='tight')
        plt.show()
        
        return None

    try:
        import nibabel as nib
        
        print(f"Loading atlas from {atlas_path}...")
        atlas_img = nib.load(atlas_path)
        atlas_data = atlas_img.get_fdata()
        
        # Create value map
        value_map = np.zeros_like(atlas_data)
        
        print("Mapping values to regions...")
        for region_id, value in regional_data_dict.items():
            mask = atlas_data == region_id
            if np.any(mask):
                value_map[mask] = value
                print(f"  Region {region_id}: value={value:.3f}, voxels={np.sum(mask)}")
        
        # Crop to brain region
        atlas_cropped = atlas_data[50:-50, 40:-40, 50:-50]
        value_cropped = value_map[50:-50, 40:-40, 50:-50]
        
        # Get slice positions
        x, y, z = atlas_cropped.shape
        x_slice, y_slice, z_slice = x//2, y//2, z//2
        
        slices = {
            'Sagittal': (np.rot90(atlas_cropped[x_slice, :, :]), np.rot90(value_cropped[x_slice, :, :])),
            'Coronal': (np.rot90(atlas_cropped[:, y_slice, :]), np.rot90(value_cropped[:, y_slice, :])),
            'Axial': (np.rot90(atlas_cropped[:, :, z_slice]), np.rot90(value_cropped[:, :, z_slice]))
        }
        
        # Create figure
        fig, axes = plt.subplots(1, 3, figsize=(15, 6))
        fig.suptitle(title, fontsize=16, fontweight='bold')
        
        # Determine value range if not provided
        non_zero_values = value_map[value_map != 0]
        if vmin is None and len(non_zero_values) > 0:
            vmin = np.percentile(non_zero_values, 5)
        if vmax is None and len(non_zero_values) > 0:
            vmax = np.percentile(non_zero_values, 95)
        
        # For diverging colormaps, ensure symmetric range around 0
        if use_diverging and vmin is not None and vmax is not None:
            max_abs = max(abs(vmin), abs(vmax))
            vmin, vmax = -max_abs, max_abs
        
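        # Plot each view with the same colormap and limits on every panel,
        # so the shared colorbar below describes all three slices
        panel_cmap = 'RdBu_r' if use_diverging else 'OrRd'
        for ax, (view_name, (atlas_sl, val_sl)) in zip(axes, slices.items()):
            im = plot_atlas_panel(ax, atlas_sl, val_sl, view_name)
            im.set_cmap(panel_cmap)
            im.set_clim(vmin, vmax)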
        
        # Colorbar
        cbar = fig.colorbar(im, ax=axes, orientation='horizontal', 
                            fraction=0.05, pad=0.05)
        cbar.set_label(colorbar_label, fontsize=12)
        
        plt.tight_layout()
        plt.savefig(FIGURES_DIR / f'{output_prefix}.png', dpi=300, bbox_inches='tight')
        plt.show()
        
        print(f"✓ Atlas visualization saved to {FIGURES_DIR / f'{output_prefix}.png'}")
        
        return value_map
        
    except ImportError:
        print("⚠️  nibabel not installed. Install with: pip install nibabel")
        return None
    except Exception as e:
        print(f"⚠️  Error creating atlas visualization: {e}")
        return None
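
The value map above is filled with one boolean mask per region. For integer-labelled atlases, a lookup-table approach gives the same result without a separate pass over the volume for each region. This is a sketch on a tiny synthetic array, assuming non-negative integer labels with 0 as background; `build_value_map` is a hypothetical helper.

```python
import numpy as np

def build_value_map(atlas_labels: np.ndarray, region_values: dict) -> np.ndarray:
    """Map each voxel's integer label to its region value; unlisted labels stay 0."""
    labels = atlas_labels.astype(int)
    lut = np.zeros(labels.max() + 1, dtype=float)
    for region_id, value in region_values.items():
        # Ignore region IDs outside the atlas label range
        if 0 <= int(region_id) < lut.size:
            lut[int(region_id)] = value
    return lut[labels]

atlas = np.array([[0, 1, 1], [2, 2, 0], [3, 0, 1]])
value_map = build_value_map(atlas, {1: 0.5, 2: -1.25, 7: 9.0})
print(value_map)
```

As in the plotting code, a region whose value is exactly 0 cannot be distinguished from background, because the overlay masks zero-valued voxels.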

In [None]:
ATLAS_PATH = '/home/falcetta/ISBI2025/LIANE/ArterialAtlas.nii.gz'  # Set to your atlas file path

In [None]:
skip_it = False
# ====================================================================
# Create atlas visualizations for all demographic analyses
# ====================================================================
if skip_it:
    print("Skipping")
else:
    if len(regional_stats_df) > 0:
        print("\nCreating comprehensive atlas visualizations...")
        
        # ================================================================
        # 1. Regional Feature Mean Values
        # ================================================================
        print("\n1. REGIONAL FEATURE MEAN VALUES")
        print("-"*80)
        
        # Select all features sorted by variability
        top_variable = regional_var_df['feature'].tolist()
        top_n_features = len(top_variable)
        
        for feature in top_variable:
            print(f"\n  Creating visualization for {feature}...")
            
            # Get regional means for this feature
            regional_values = create_regional_value_map(
                regional_stats_df, 
                feature, 
                region_col=region_col, 
                aggregation='mean'
            )
            
            # Create visualization
            create_regional_atlas_visualization(
                atlas_path=ATLAS_PATH,
                regional_data_dict=regional_values,
                output_prefix=f'regional_atlas_{feature}',
                title=f'{feature} - Mean by Region',
                colorbar_label=f'{feature} (mean)'
            )
        
        # ================================================================
        # 2. Demographic Correlations by Region
        # ================================================================
        if 'regional_demo_correlations' in locals() and len(regional_demo_correlations) > 0:
            print("\n2. DEMOGRAPHIC CORRELATIONS BY REGION")
            print("-"*80)
            
            for demo_var, regional_corr_df in regional_demo_correlations.items():
                print(f"\n  Processing {demo_var} correlations...")
                
                # Find features with strongest average correlation
                avg_corr = regional_corr_df.groupby('feature')['correlation'].agg(['mean', 'std'])
                avg_corr['abs_mean'] = avg_corr['mean'].abs()
                top_corr_features = avg_corr.nlargest(min(10, len(avg_corr)), 'abs_mean').index.tolist()
                
                for feature in top_corr_features:
                    print(f"    Creating {demo_var} correlation map for {feature}...")
                    
                    # Get correlations by region
                    feature_corr = regional_corr_df[regional_corr_df['feature'] == feature]
                    corr_dict = dict(zip(feature_corr['region'], feature_corr['correlation']))
                    
                    create_regional_atlas_visualization(
                        atlas_path=ATLAS_PATH,
                        regional_data_dict=corr_dict,
                        output_prefix=f'regional_{demo_var.lower()}_corr_{feature}',
                        title=f'{feature} - {demo_var} Correlation by Region',
                        vmin=-1, vmax=1,
                        colorbar_label=f'Pearson r with {demo_var}',
                        use_diverging=True
                    )
        
        # ================================================================
        # 3. Visualize Sex Differences (Cohen's d) by Region
        # ================================================================
        if 'regional_sex_df' in locals() and len(regional_sex_df) > 0:
            print("\n3. SEX DIFFERENCES (COHEN'S D) BY REGION")
            print("-"*80)
            
            # Find features with largest average effect size
            avg_effect = regional_sex_df.groupby('feature')['cohens_d'].agg(['mean', 'std'])
            avg_effect['abs_mean'] = avg_effect['mean'].abs()
            top_sex_features = avg_effect.nlargest(min(10, len(avg_effect)), 'abs_mean').index.tolist()
            
            for feature in top_sex_features:
                print(f"\n  Creating sex difference map for {feature}...")
                
                # Get Cohen's d by region
                feature_sex = regional_sex_df[regional_sex_df['feature'] == feature]
                sex_dict = dict(zip(feature_sex['region'], feature_sex['cohens_d']))
                
                create_regional_atlas_visualization(
                    atlas_path=ATLAS_PATH,
                    regional_data_dict=sex_dict,
                    output_prefix=f'regional_sex_cohens_d_{feature}',
                    title=f'{feature} - Sex Difference (Cohen\'s d) by Region',
                    vmin=-2, vmax=2,
                    colorbar_label='Cohen\'s d (Male - Female)',
                    use_diverging=True
                )
        
        # ================================================================
        # 4. Export Comprehensive Regional Data
        # ================================================================
        print("\n4. EXPORTING COMPREHENSIVE REGIONAL DATA")
        print("-"*80)
        
        # Create comprehensive export with all demographics
        comprehensive_export = []
        
        for feature in ALL_FEATURES:
            if feature not in regional_df.columns:
                continue
                
            for region in regions:
                region_data = regional_df[regional_df[region_col] == region][feature].dropna()
                
                if len(region_data) > 0:
                    export_row = {
                        'region': region,
                        'feature': feature,
                        'category': feature_categories[feature],
                        'mean': region_data.mean(),
                        'std': region_data.std(),
                        'median': region_data.median(),
                        'n_subjects': len(region_data)
                    }
                    
                    # Add correlations for all demographic variables
                    if 'regional_demo_correlations' in locals():
                        for demo_var, regional_corr_df in regional_demo_correlations.items():
                            corr_match = regional_corr_df[
                                (regional_corr_df['region'] == region) & 
                                (regional_corr_df['feature'] == feature)
                            ]
                            if len(corr_match) > 0:
                                export_row[f'{demo_var}_correlation'] = corr_match.iloc[0]['correlation']
                                export_row[f'{demo_var}_pvalue'] = corr_match.iloc[0]['pvalue']
                                export_row[f'{demo_var}_significant'] = corr_match.iloc[0]['significant']
                    
                    # Add sex differences
                    if 'regional_sex_df' in locals():
                        sex_match = regional_sex_df[
                            (regional_sex_df['region'] == region) & 
                            (regional_sex_df['feature'] == feature)
                        ]
                        if len(sex_match) > 0:
                            export_row['sex_cohens_d'] = sex_match.iloc[0]['cohens_d']
                            export_row['sex_pvalue'] = sex_match.iloc[0]['pvalue']
                            export_row['sex_significant'] = sex_match.iloc[0]['significant']
                    
                    comprehensive_export.append(export_row)
        
        comprehensive_export_df = pd.DataFrame(comprehensive_export)
        comprehensive_export_df.to_csv(TABLES_DIR / 'regional_comprehensive_demographics.csv', index=False)
        print(f"✓ Comprehensive regional data exported to {TABLES_DIR / 'regional_comprehensive_demographics.csv'}")
        
        # ================================================================
        # 5. Create JSON Export for Custom Atlas Tools
        # ================================================================
        print("\n5. CREATING JSON EXPORTS")
        print("-"*80)
        
        json_exports = {}
        
        # Export mean values for top features
        json_exports['feature_means'] = {}
        for feature in top_variable[:10]:
            regional_values = create_regional_value_map(
                regional_stats_df, feature, region_col=region_col, aggregation='mean'
            )
            json_exports['feature_means'][feature] = {str(k): float(v) for k, v in regional_values.items()}
        
        # Export demographic correlations
        if 'regional_demo_correlations' in locals():
            json_exports['demographic_correlations'] = {}
            for demo_var, regional_corr_df in regional_demo_correlations.items():
                json_exports['demographic_correlations'][demo_var] = {}
                
                # Top 5 features for each demographic variable
                avg_corr = regional_corr_df.groupby('feature')['correlation'].agg('mean')
                top_features = avg_corr.abs().nlargest(5).index.tolist()
                
                for feature in top_features:
                    feature_corr = regional_corr_df[regional_corr_df['feature'] == feature]
                    corr_dict = dict(zip(feature_corr['region'], feature_corr['correlation']))
                    json_exports['demographic_correlations'][demo_var][feature] = {
                        str(k): float(v) for k, v in corr_dict.items()
                    }
        
        # Export sex differences
        if 'regional_sex_df' in locals():
            json_exports['sex_differences'] = {}
            avg_effect = regional_sex_df.groupby('feature')['cohens_d'].agg('mean')
            top_sex_features = avg_effect.abs().nlargest(5).index.tolist()
            
            for feature in top_sex_features:
                feature_sex = regional_sex_df[regional_sex_df['feature'] == feature]
                sex_dict = dict(zip(feature_sex['region'], feature_sex['cohens_d']))
                json_exports['sex_differences'][feature] = {
                    str(k): float(v) for k, v in sex_dict.items()
                }
        
        # Save JSON
        import json
        with open(TABLES_DIR / 'regional_atlas_data_all_demographics.json', 'w') as f:
            json.dump(json_exports, f, indent=2)
        
        print(f"✓ JSON export saved to {TABLES_DIR / 'regional_atlas_data_all_demographics.json'}")
        
        # ================================================================
        # Summary
        # ================================================================
        print("\n" + "="*80)
        print("✅ COMPREHENSIVE ATLAS VISUALIZATION COMPLETE")
        print("="*80)
        print("\nGenerated outputs:")
        print(f"  • Regional mean value visualizations for all {top_n_features} features")
        if 'regional_demo_correlations' in locals():
            print(f"  • Demographic correlation maps for {len(regional_demo_correlations)} variables:")
            for demo_var in regional_demo_correlations.keys():
                print(f"    - {demo_var}")
        if 'regional_sex_df' in locals():
            print(f"  • Sex difference (Cohen's d) maps")
        print(f"  • Comprehensive CSV export with all demographics")
        print(f"  • JSON export for custom atlas visualization")
        print("\n💡 TIPS:")
        print("  • Set ATLAS_PATH to your atlas NIfTI file for proper brain overlays")
        print("  • All correlation maps use diverging colormaps centered at 0")
        print("  • Effect sizes (Cohen's d) also use diverging colormaps")
        print("  • CSV file contains correlations with ALL demographic variables")
        print("  • JSON file can be used with custom atlas visualization scripts")
        print("  • Use the JSON/CSV exports with 3D Slicer, FSLeyes, or other tools")
        
    else:
        print("\n⚠️  No regional statistics available for atlas visualization")

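The JSON export above converts region keys to strings and values to floats. One caveat: `json.dump` writes a bare `NaN` token for missing values by default, which strict JSON parsers reject. A minimal sketch of a JSON-safe conversion follows, assuming a hypothetical helper `to_json_safe` and an illustrative feature name `mean_radius`; neither comes from the notebook.

```python
import json
import math

import numpy as np


def to_json_safe(regional_values):
    """Convert {region: value} to {str: float}, dropping non-finite values.

    Hypothetical helper for illustration; not part of the original notebook.
    """
    safe = {}
    for region, value in regional_values.items():
        value = float(value)
        if math.isfinite(value):
            # NumPy integer keys become plain strings such as "1"
            safe[str(region)] = value
    return safe


# Illustrative values only: region 2 has no valid measurement
example = {np.int64(1): np.float64(0.42), np.int64(2): np.nan, 3: 1.5}
payload = {'feature_means': {'mean_radius': to_json_safe(example)}}

# allow_nan=False makes json.dumps raise instead of emitting a bare NaN token
print(json.dumps(payload, indent=2, allow_nan=False))
```

Dropping non-finite entries is one option; writing `None` (JSON `null`) instead would keep the region visible to downstream tools.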
In [None]:
# Import required libraries for atlas visualization
from scipy import ndimage

def plot_atlas_panel(ax, atlas_slice, value_slice, title, cmap='OrRd', vmin=None, vmax=None):
    """Create a single atlas visualization panel."""
    # 1. WHITE background
    ax.imshow(np.ones_like(atlas_slice), cmap='gray', vmin=0, vmax=1)
    # 2. LIGHT GREY fill inside regions
    mask = atlas_slice > 0
    bg = np.ones_like(atlas_slice)
    bg[mask] = 0.85
    ax.imshow(bg, cmap='gray', vmin=0, vmax=1)
    # 3. GREY region contours, drawn on the raw label slice
    regions_arr = np.unique(atlas_slice)[1:]
    if regions_arr.size:
        ax.contour(
            atlas_slice.astype(float), levels=regions_arr + 0.5, colors='#BBBBBB',
            linewidths=0.5, alpha=0.6, antialiased=True
        )
    # 4. Value overlay
    vals = np.ma.masked_where(value_slice == 0, value_slice)
    im = ax.imshow(vals, cmap=cmap, vmin=vmin, vmax=vmax, alpha=0.9)
    ax.set_title(title, fontsize=12, pad=4)
    ax.axis('off')
    return im

def calculate_regional_means_by_group(regional_df, feature, region_col, group_col, group_value):
    """Calculate mean feature values by region for a specific group."""
    group_data = regional_df[regional_df[group_col] == group_value]
    regional_means = {}
    
    for region in group_data[region_col].unique():
        region_data = group_data[group_data[region_col] == region][feature].dropna()
        if len(region_data) > 0:
            regional_means[region] = region_data.mean()
    
    return regional_means

def create_group_comparison_atlas(atlas_path, regional_df, feature, region_col, 
                                  group_col, group_labels, output_prefix,
                                  main_title, colorbar_label):
    """
    Create side-by-side atlas comparison for different demographic groups.
    
    Parameters:
    -----------
    atlas_path : str
        Path to atlas NIfTI file
    regional_df : pd.DataFrame
        Regional data with demographics
    feature : str
        Feature to visualize
    region_col : str
        Column name for regions
    group_col : str
        Column name for grouping variable
    group_labels : dict
        Dictionary mapping group values to display labels
    output_prefix : str
        Prefix for output filename
    main_title : str
        Main figure title
    colorbar_label : str
        Label for colorbar
    """
    
    if atlas_path is None or not Path(atlas_path).exists():
        print(f"⚠️  Atlas file not found. Creating bar plots instead...")
        
        # Create bar plot comparison
        n_groups = len(group_labels)
        fig, axes = plt.subplots(1, n_groups, figsize=(6*n_groups, 8))
        if n_groups == 1:
            axes = [axes]
        
        fig.suptitle(main_title, fontsize=14, fontweight='bold')
        
        # Get all regions
        all_regions = sorted(regional_df[region_col].unique())
        
        # Determine global value range
        all_values = []
        for group_val in group_labels.keys():
            regional_means = calculate_regional_means_by_group(
                regional_df, feature, region_col, group_col, group_val
            )
            all_values.extend(regional_means.values())
        
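        if not all_values:
            print(f"⚠️  No data available for {feature}; skipping bar plots")
            plt.close(fig)
            return None
        vmin, vmax = min(all_values), max(all_values)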
        norm = plt.Normalize(vmin=vmin, vmax=vmax)
        cmap = plt.cm.OrRd
        
        for idx, (group_val, group_label) in enumerate(group_labels.items()):
            ax = axes[idx]
            
            regional_means = calculate_regional_means_by_group(
                regional_df, feature, region_col, group_col, group_val
            )
            
            regions = sorted(regional_means.keys())
            values = [regional_means[r] for r in regions]
            
            bars = ax.barh(range(len(regions)), values, edgecolor='black')
            ax.set_yticks(range(len(regions)))
            ax.set_yticklabels([f'R{r}' for r in regions], fontsize=8)
            ax.set_xlabel(colorbar_label, fontweight='bold')
            ax.set_title(group_label, fontweight='bold')
            ax.grid(True, alpha=0.3, axis='x')
            
            # Color bars
            for bar, val in zip(bars, values):
                bar.set_color(cmap(norm(val)))
        
        # Add shared colorbar (named `mappable` to avoid shadowing the statsmodels alias `sm`)
        mappable = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
        mappable.set_array([])
        cbar = fig.colorbar(mappable, ax=axes, orientation='horizontal', 
                           fraction=0.05, pad=0.08)
        cbar.set_label(colorbar_label, fontsize=10)
        
        plt.tight_layout()
        plt.savefig(FIGURES_DIR / f'{output_prefix}.png', dpi=300, bbox_inches='tight')
        plt.show()
        
        return None
    
    try:
        import nibabel as nib
        
        print(f"Loading atlas from {atlas_path}...")
        atlas_img = nib.load(atlas_path)
        atlas_data = atlas_img.get_fdata()
        
        # Calculate regional means for each group
        group_value_maps = {}
        all_values = []
        
        for group_val, group_label in group_labels.items():
            regional_means = calculate_regional_means_by_group(
                regional_df, feature, region_col, group_col, group_val
            )
            
            # Create value map
            value_map = np.zeros_like(atlas_data)
            for region_id, value in regional_means.items():
                mask = atlas_data == region_id
                if np.any(mask):
                    value_map[mask] = value
            
            group_value_maps[group_label] = value_map
            non_zero = value_map[value_map != 0]
            if len(non_zero) > 0:
                all_values.extend(non_zero)
        
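        # Determine global value range (5th-95th percentile of painted voxels)
        if len(all_values) == 0:
            print(f"⚠️  No regional values for {feature} matched the atlas labels")
            return None
        vmin = np.percentile(all_values, 5)
        vmax = np.percentile(all_values, 95)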
        
        # Create figure with subplots for each group
        n_groups = len(group_labels)
        fig, axes = plt.subplots(n_groups, 3, figsize=(15, 5*n_groups))
        if n_groups == 1:
            axes = axes.reshape(1, -1)
        
        fig.suptitle(main_title, fontsize=16, fontweight='bold', y=0.98)
        
        # Sort groups by display label
        sorted_groups = sorted(group_value_maps.items(), key=lambda x: x[0])
        
        for group_idx, (group_label, value_map) in enumerate(sorted_groups):
            print(f"  Plotting group: {group_label}...")
            # Crop to brain region
            atlas_cropped = atlas_data[50:-50, 40:-40, 50:-50]
            value_cropped = value_map[50:-50, 40:-40, 50:-50]
            
            # Get slice positions
            x, y, z = atlas_cropped.shape
            x_slice, y_slice, z_slice = x//2, y//2, z//2
            
            slices = {
                'Sagittal': (np.rot90(atlas_cropped[x_slice, :, :]), 
                           np.rot90(value_cropped[x_slice, :, :])),
                'Coronal': (np.rot90(atlas_cropped[:, y_slice, :]), 
                          np.rot90(value_cropped[:, y_slice, :])),
                'Axial': (np.rot90(atlas_cropped[:, :, z_slice]), 
                        np.rot90(value_cropped[:, :, z_slice]))
            }
            
            # Plot each view for this group
            for view_idx, (view_name, (atlas_sl, val_sl)) in enumerate(slices.items()):
                ax = axes[group_idx, view_idx]
                title = f"{group_label} - {view_name}"
                im = plot_atlas_panel(ax, atlas_sl, val_sl, title, vmin=vmin, vmax=vmax)
        
        # Add shared colorbar
        cbar = fig.colorbar(im, ax=axes, orientation='horizontal', 
                           fraction=0.03, pad=0.04)
        cbar.set_label(colorbar_label, fontsize=12)
        
        plt.tight_layout()
        plt.savefig(FIGURES_DIR / f'{output_prefix}.png', dpi=300, bbox_inches='tight')
        plt.show()
        
        print(f"✓ Group comparison atlas saved to {FIGURES_DIR / f'{output_prefix}.png'}")
        
        return group_value_maps
        
    except ImportError:
        print("⚠️  nibabel not installed. Install with: pip install nibabel")
        return None
    except Exception as e:
        print(f"⚠️  Error creating atlas visualization: {e}")
        import traceback
        traceback.print_exc()
        return None

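The core of `create_group_comparison_atlas` is painting one value per region onto the label volume and then choosing color limits from the painted voxels. A minimal sketch of that step on a tiny synthetic label array is shown here; `paint_regional_values`, the array, and the values are illustrative assumptions, not notebook data.

```python
import numpy as np


def paint_regional_values(atlas_data, regional_values):
    """Fill each atlas label with its regional value; other voxels stay 0.

    Illustrative helper mirroring the loop in the function above.
    """
    value_map = np.zeros(atlas_data.shape, dtype=float)
    for region_id, value in regional_values.items():
        # Labels absent from the atlas produce an empty mask and are skipped
        value_map[atlas_data == region_id] = value
    return value_map


# Synthetic 4x4x1 "atlas": background 0 and labels 1-3
atlas = np.array([[0, 1, 1, 0],
                  [2, 2, 3, 3],
                  [2, 2, 3, 3],
                  [0, 0, 0, 0]])[..., np.newaxis]
means = {1: 0.5, 2: 1.0, 3: 2.0, 99: 7.0}  # label 99 is not in the atlas

value_map = paint_regional_values(atlas, means)
painted = value_map[value_map != 0]

# Robust color limits computed over painted voxels
vmin, vmax = np.percentile(painted, [5, 95])
print(value_map[..., 0])
print(vmin, vmax)
```

Two consequences of this design are worth keeping in mind: the percentiles are weighted by region volume, since every voxel counts once, and a regional mean of exactly 0 is indistinguishable from background because zero doubles as the "no data" marker.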
In [None]:
# ====================================================================
# ENHANCED ATLAS VISUALIZATION - ALL DEMOGRAPHICS WITH GROUP COMPARISONS
# ====================================================================
"""
This enhanced version automatically:
1. Detects all available demographic variables
2. Creates visualizations for each demographic
3. Shows group differences with statistical tests
4. Handles both categorical and continuous variables
"""

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from pathlib import Path
from scipy import ndimage, stats
from scipy.stats import f_oneway, ttest_ind, mannwhitneyu
import seaborn as sns
from itertools import combinations

print("\n" + "="*80)
print("ENHANCED ATLAS VISUALIZATION - ALL DEMOGRAPHICS")
print("="*80)

# ====================================================================
# HELPER FUNCTIONS
# ====================================================================

# plot_atlas_panel and calculate_regional_means_by_group are defined in the
# previous cell and reused here unchanged.

def detect_demographic_variables(df, interactive=True):
    """
    Automatically detect all available demographic variables in the dataframe.
    
    Args:
        df: Input dataframe
        interactive: If True, prompts user to select variables. If False, includes all detected.
    
    Returns dict with:
        - 'categorical': dict mapping each categorical demographic variable to its sorted unique values
        - 'continuous': list of continuous demographic variables
    """
    demographics = {
        'categorical': {},
        'continuous': []
    }
    
    # Known demographic columns
    demo_candidates = {
        'categorical': ['SEX_ID', 'sex', 'ETHNIC_ID', 'ethnicity', 'MARITAL_ID', 
                       'marital_status', 'OCCUPATION_ID', 'occupation', 'site',
                       'QUALIFICATION_ID', 'qualification', 'age_group', 'bmi_category'],
        'continuous': ['AGE', 'age', 'HEIGHT', 'height', 'WEIGHT', 'weight', 'BMI', 'bmi']
    }
    # Detect available variables
    available_categorical = {}
    available_continuous = []
    
    print("Detecting available demographic variables...")
    print("\nCategorical variables:")
    for col in demo_candidates['categorical']:
        if col in df.columns:
            unique_vals = df[col].dropna().unique()
            if 2 <= len(unique_vals) <= 10:
                value_counts = df[col].value_counts()
                if all(value_counts >= 10):
                    available_categorical[col] = sorted(unique_vals)
                    print(f"  ✓ {col} ({len(unique_vals)} groups)")
    
    print("\nContinuous variables:")
    for col in demo_candidates['continuous']:
        if col in df.columns:
            if df[col].dropna().nunique() > 10 and pd.api.types.is_numeric_dtype(df[col]):
                available_continuous.append(col)
                print(f"  ✓ {col}")
            else:
                print(f"  ✗ {col} (not enough unique values or not numeric)")
    
    # Interactive selection
    if interactive and (available_categorical or available_continuous):
        print("\n" + "="*60)
        print("SELECT VARIABLES TO INCLUDE")
        print("="*60)
        
        # Select categorical variables
        if available_categorical:
            print("\nCategorical variables (enter numbers separated by commas, or 'all'):")
            cat_list = list(available_categorical.keys())
            for i, col in enumerate(cat_list, 1):
                print(f"  {i}. {col}")
            
            cat_input = input("Select categorical variables: ").strip()
            if cat_input.lower() == 'all':
                demographics['categorical'] = available_categorical
            elif cat_input:
                # Ignore non-numeric entries instead of raising ValueError
                selected_indices = [int(x.strip()) - 1 for x in cat_input.split(',') if x.strip().isdigit()]
                demographics['categorical'] = {
                    cat_list[i]: available_categorical[cat_list[i]] 
                    for i in selected_indices if 0 <= i < len(cat_list)
                }
        
        # Select continuous variables
        if available_continuous:
            print("\nContinuous variables (enter numbers separated by commas, or 'all'):")
            for i, col in enumerate(available_continuous, 1):
                print(f"  {i}. {col}")
            
            cont_input = input("Select continuous variables: ").strip()
            if cont_input.lower() == 'all':
                demographics['continuous'] = available_continuous
            elif cont_input:
                # Ignore non-numeric entries instead of raising ValueError
                selected_indices = [int(x.strip()) - 1 for x in cont_input.split(',') if x.strip().isdigit()]
                demographics['continuous'] = [
                    available_continuous[i] 
                    for i in selected_indices if 0 <= i < len(available_continuous)
                ]
        
        print("\n✓ Selection complete!")
    else:
        # Non-interactive mode: include all
        demographics['categorical'] = available_categorical
        demographics['continuous'] = available_continuous
    
    return demographics


def create_group_labels(df, group_col):
    """Create human-readable labels for group values."""
    labels = {}
    
    # Special handling for common variables
    if group_col in ['SEX_ID', 'sex']:
        labels = {1: 'Male', 2: 'Female'}
    elif ('age_group' in group_col.lower() or 'bmi' in group_col.lower()
          or group_col == 'site'):
        # Age groups, BMI categories, and sites: use the category values as labels
        for val in df[group_col].dropna().unique():
            labels[val] = str(val)
    else:
        # Generic labeling
        for val in sorted(df[group_col].dropna().unique()):
            labels[val] = f"{group_col}={val}"
    
    return labels


def calculate_group_differences(regional_df, feature, region_col, group_col):
    """
    Calculate statistical differences between groups for each region.
    
    Returns DataFrame with regional p-values and effect sizes.
    """
    results = []
    groups = regional_df[group_col].dropna().unique()
    
    for region in regional_df[region_col].unique():
        region_data = regional_df[regional_df[region_col] == region]
        
        # Get data for each group
        group_data = []
        for grp in groups:
            grp_vals = region_data[region_data[group_col] == grp][feature].dropna()
            if len(grp_vals) >= 3:  # Minimum sample size
                group_data.append(grp_vals.values)
        
        # Skip if insufficient groups
        if len(group_data) < 2:
            continue
        
        # Statistical test
        if len(groups) == 2:
            # Two groups: Welch's t-test (does not assume equal variances)
            stat, p_val = ttest_ind(group_data[0], group_data[1], equal_var=False)
            test_name = 't-test'
            
            # Calculate Cohen's d (effect size) from sample variances (ddof=1);
            # the simple average of the two variances assumes similar group sizes
            pooled_std = np.sqrt((np.var(group_data[0], ddof=1) + np.var(group_data[1], ddof=1)) / 2)
            cohens_d = (np.mean(group_data[0]) - np.mean(group_data[1])) / pooled_std if pooled_std > 0 else 0
            effect_size = abs(cohens_d)
        else:
            # Multiple groups: ANOVA
            stat, p_val = f_oneway(*group_data)
            test_name = 'ANOVA'
            
            # Calculate eta-squared (effect size)
            grand_mean = np.mean(np.concatenate(group_data))
            ss_between = sum(len(g) * (np.mean(g) - grand_mean)**2 for g in group_data)
            ss_total = sum((x - grand_mean)**2 for g in group_data for x in g)
            effect_size = ss_between / ss_total if ss_total > 0 else 0
        
        results.append({
            'region': region,
            'p_value': p_val,
            'test': test_name,
            'effect_size': effect_size,
            'significant': p_val < 0.05,
            'n_groups': len(group_data),
            'total_n': sum(len(g) for g in group_data)
        })
    
    return pd.DataFrame(results)


def create_difference_visualization(atlas_path, regional_df, feature, region_col, 
                                   group_col, output_prefix, main_title):
    """
    Create visualization showing WHERE groups differ (p-value heatmap on atlas).
    """
    # Calculate differences
    diff_results = calculate_group_differences(regional_df, feature, region_col, group_col)
    
    if len(diff_results) == 0:
        print(f"    ⚠️  No statistical results for {feature}")
        return
    
    # Create figure
    fig, axes = plt.subplots(2, 2, figsize=(16, 14))
    fig.suptitle(f"{main_title}\nRegional Group Differences", 
                 fontsize=14, fontweight='bold')
    
    # 1. P-value map on atlas (if atlas available)
    if atlas_path and Path(atlas_path).exists():
        try:
            import nibabel as nib
            atlas_img = nib.load(atlas_path)
            atlas_data = atlas_img.get_fdata()
            
            # Create p-value map
            pval_map = np.zeros_like(atlas_data)
            for _, row in diff_results.iterrows():
                mask = atlas_data == row['region']
                if np.any(mask):
                    # Use -log10(p) for better visualization
                    pval_map[mask] = -np.log10(row['p_value'] + 1e-10)

            
            # Plot
            atlas_cropped = atlas_data[50:-50, 40:-40, 50:-50]
            pval_cropped = pval_map[50:-50, 40:-40, 50:-50]
            x, y, z = atlas_cropped.shape
            
            ax = axes[0, 0]
            atlas_slice = np.rot90(atlas_cropped[:, :, z//2])
            pval_slice = np.rot90(pval_cropped[:, :, z//2])
            
            im = plot_atlas_panel(ax, atlas_slice, pval_slice, 
                                 "Significance Map (-log10 p-value)",
                                 cmap='YlOrRd', vmin=0, vmax=3)
            plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
            
            # Annotate the significance threshold on the -log10 scale
            ax.text(0.02, 0.98, "Red = more significant\n-log10(0.05) ≈ 1.3", 
                   transform=ax.transAxes, fontsize=9, va='top',
                   bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))
            
        except Exception as e:
            print(f"    ⚠️  Could not create atlas visualization: {e}")
            axes[0, 0].text(0.5, 0.5, "Atlas not available", 
                          ha='center', va='center', transform=axes[0, 0].transAxes)
            axes[0, 0].axis('off')
    else:
        axes[0, 0].text(0.5, 0.5, "Atlas not available", 
                       ha='center', va='center', transform=axes[0, 0].transAxes)
        axes[0, 0].axis('off')
    
    # 2. Effect size by region
    ax = axes[0, 1]
    sorted_diff = diff_results.sort_values('effect_size', ascending=False).head(20)
    colors = ['red' if sig else 'gray' for sig in sorted_diff['significant']]
    ax.barh(range(len(sorted_diff)), sorted_diff['effect_size'], color=colors, alpha=0.7)
    ax.set_yticks(range(len(sorted_diff)))
    ax.set_yticklabels([f"R{int(r)}" for r in sorted_diff['region']], fontsize=8)
    ax.set_xlabel('Effect Size', fontweight='bold')
    ax.set_title('Top 20 Regions by Effect Size\n(Red = significant p<0.05)', fontsize=10)
    ax.grid(True, alpha=0.3, axis='x')
    ax.invert_yaxis()
    
    # 3. P-value distribution
    ax = axes[1, 0]
    ax.hist(diff_results['p_value'], bins=50, edgecolor='black', alpha=0.7)
    ax.axvline(0.05, color='red', linestyle='--', linewidth=2, label='p = 0.05')
    ax.set_xlabel('P-value', fontweight='bold')
    ax.set_ylabel('Number of Regions')
    ax.set_title(f'Distribution of P-values\n{sum(diff_results["significant"])} significant regions', 
                fontsize=10)
    ax.legend()
    ax.grid(True, alpha=0.3)
    
    # 4. Summary statistics table
    ax = axes[1, 1]
    ax.axis('off')
    
    summary_text = f"""
STATISTICAL SUMMARY

Test: {diff_results['test'].iloc[0]}
Total Regions Tested: {len(diff_results)}
Significant Regions (p<0.05): {sum(diff_results['significant'])} ({100*sum(diff_results['significant'])/len(diff_results):.1f}%)
Bonferroni Threshold: {0.05/len(diff_results):.4f}
Significant (Bonferroni): {sum(diff_results['p_value'] < 0.05/len(diff_results))}

EFFECT SIZES
Mean: {diff_results['effect_size'].mean():.3f}
Median: {diff_results['effect_size'].median():.3f}
Max: {diff_results['effect_size'].max():.3f}

TOP 5 MOST DIFFERENT REGIONS:
"""
    top_5 = diff_results.nlargest(5, 'effect_size')
    for idx, row in top_5.iterrows():
        sig_marker = "***" if row['p_value'] < 0.001 else "**" if row['p_value'] < 0.01 else "*" if row['p_value'] < 0.05 else ""
        summary_text += f"\nR{int(row['region'])}: ES={row['effect_size']:.3f}, p={row['p_value']:.4f} {sig_marker}"
    
    ax.text(0.1, 0.9, summary_text, transform=ax.transAxes, 
           fontsize=10, verticalalignment='top', family='monospace',
           bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))
    
    plt.tight_layout()
    plt.savefig(FIGURES_DIR / f'{output_prefix}_differences.png', dpi=300, bbox_inches='tight')
    plt.show()
    
    return diff_results


def create_pairwise_comparison_plots(regional_df, feature, region_col, group_col, 
                                    output_prefix, main_title):
    """
    For categorical variables with 2+ groups, create pairwise comparison plots.
    Shows violin/box plots for each region comparing groups.
    """
    groups = sorted(regional_df[group_col].dropna().unique())
    n_groups = len(groups)
    
    if n_groups < 2:
        return
    
    # Get top 12 most variable regions
    regional_variance = []
    for region in regional_df[region_col].unique():
        region_data = regional_df[regional_df[region_col] == region][feature].dropna()
        if len(region_data) > 0:
            regional_variance.append({
                'region': region,
                'variance': region_data.var(),
                'mean': region_data.mean()
            })
    
    top_regions = sorted(regional_variance, key=lambda x: x['variance'], reverse=True)[:12]
    
    # Create figure
    fig, axes = plt.subplots(3, 4, figsize=(20, 12))
    axes = axes.flatten()
    fig.suptitle(f"{main_title}\nGroup Comparisons by Region", 
                 fontsize=14, fontweight='bold')
    
    group_labels = create_group_labels(regional_df, group_col)
    
    for idx, reg_info in enumerate(top_regions):
        ax = axes[idx]
        region = reg_info['region']
        
        # Prepare data for this region
        plot_data = []
        for grp in groups:
            grp_data = regional_df[
                (regional_df[region_col] == region) & 
                (regional_df[group_col] == grp)
            ][feature].dropna()
            
            for val in grp_data:
                plot_data.append({
                    'Region': f"R{int(region)}",
                    'Group': group_labels.get(grp, str(grp)),
                    'Value': val
                })
        
        if len(plot_data) == 0:
            continue
        
        plot_df = pd.DataFrame(plot_data)
        
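        # Fix the category order once so the plots and the mean lines agree
        group_order = list(dict.fromkeys(plot_df['Group']))
        
        # Violin + box plot
        sns.violinplot(data=plot_df, x='Group', y='Value', ax=ax, 
                      order=group_order, inner=None, alpha=0.3)
        sns.boxplot(data=plot_df, x='Group', y='Value', ax=ax,
                   order=group_order, width=0.3, showcaps=True, boxprops=dict(alpha=0.7))
        
        # Add mean line (reindexed to plotting order; groupby alone sorts the labels)
        group_means = plot_df.groupby('Group')['Value'].mean().reindex(group_order)
        for i, (grp, mean_val) in enumerate(group_means.items()):
            ax.plot([i-0.4, i+0.4], [mean_val, mean_val], 
                   'r-', linewidth=2, alpha=0.7)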
        
        ax.set_title(f"Region {int(region)}", fontweight='bold')
        ax.set_xlabel('')
        ax.set_ylabel(feature if idx % 4 == 0 else '')
        ax.grid(True, alpha=0.3, axis='y')
        
        # Rotate x labels if needed
        if n_groups > 2:
            ax.tick_params(axis='x', rotation=45)
    
    # Remove empty subplots
    for idx in range(len(top_regions), len(axes)):
        fig.delaxes(axes[idx])
    
    plt.tight_layout()
    plt.savefig(FIGURES_DIR / f'{output_prefix}_pairwise.png', dpi=300, bbox_inches='tight')
    plt.show()

def create_continuous_correlation_atlas(atlas_path, regional_df, feature, region_col,
                                       continuous_var, output_prefix, main_title,
                                       method='pearson'):
    """
    For continuous demographic variables, show correlation strength on atlas.
    
    Parameters:
    -----------
    atlas_path : str
        Path to the atlas NIfTI file
    regional_df : pd.DataFrame
        DataFrame with regional data
    feature : str
        Name of the feature column to correlate
    region_col : str
        Name of the region identifier column
    continuous_var : str
        Name of the continuous variable to correlate with
    output_prefix : str
        Prefix for output filename
    main_title : str
        Main title for the figure
    method : str, optional (default='pearson')
        Correlation method to use. Options:
        - 'pearson': Pearson correlation (linear relationships, parametric)
        - 'spearman': Spearman rank correlation (monotonic relationships, non-parametric)
        - 'kendall': Kendall tau correlation (ordinal data, non-parametric, robust to outliers)
    
    Returns:
    --------
    corr_df : pd.DataFrame or None
        DataFrame with correlation results by region, or None if no region
        has at least 10 paired observations
    """
    # Validate method
    valid_methods = ['pearson', 'spearman', 'kendall']
    if method not in valid_methods:
        raise ValueError(f"method must be one of {valid_methods}, got '{method}'")
    
    # Map method to scipy function and coefficient name
    method_info = {
        'pearson': {'func': stats.pearsonr, 'coef_name': 'r', 'full_name': 'Pearson'},
        'spearman': {'func': stats.spearmanr, 'coef_name': 'ρ', 'full_name': 'Spearman'},
        'kendall': {'func': stats.kendalltau, 'coef_name': 'τ', 'full_name': 'Kendall'}
    }
    
    corr_func = method_info[method]['func']
    coef_name = method_info[method]['coef_name']
    full_name = method_info[method]['full_name']
    
    # Calculate correlations by region
    corr_results = []
    for region in regional_df[region_col].unique():
        region_data = regional_df[regional_df[region_col] == region][[feature, continuous_var]].dropna()
        
        if len(region_data) >= 10:  # Minimum sample size
            corr, p_val = corr_func(region_data[continuous_var], region_data[feature])
            corr_results.append({
                'region': region,
                'correlation': corr,
                'p_value': p_val,
                'significant': p_val < 0.05,
                'n': len(region_data)
            })
    
    if len(corr_results) == 0:
        print(f"    ⚠️  No correlation results for {feature} vs {continuous_var}")
        return
    
    corr_df = pd.DataFrame(corr_results)
    
    # Create figure
    fig, axes = plt.subplots(2, 2, figsize=(16, 14))
    fig.suptitle(f"{main_title}\n{full_name} Correlation with {continuous_var}", 
                 fontsize=14, fontweight='bold')
    
    # 1. Correlation map on atlas
    if atlas_path and Path(atlas_path).exists():
        try:
            import nibabel as nib
            atlas_img = nib.load(atlas_path)
            atlas_data = atlas_img.get_fdata()
            
            # Create correlation map
            corr_map = np.zeros_like(atlas_data)
            for _, row in corr_df.iterrows():
                mask = atlas_data == row['region']
                if np.any(mask):
                    corr_map[mask] = row['correlation']
            
            # Plot
            atlas_cropped = atlas_data[50:-50, 40:-40, 50:-50]
            corr_cropped = corr_map[50:-50, 40:-40, 50:-50]
            x, y, z = atlas_cropped.shape
            
            ax = axes[0, 0]
            atlas_slice = np.rot90(atlas_cropped[:, :, z//2])
            corr_slice = np.rot90(corr_cropped[:, :, z//2])
            
            vmax = max(abs(corr_df['correlation'].min()), abs(corr_df['correlation'].max()))
            im = plot_atlas_panel(ax, atlas_slice, corr_slice,
                                "Correlation Map",
                                cmap='RdBu_r', vmin=-vmax, vmax=vmax)
            cbar = plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
            cbar.set_label(f'{full_name} {coef_name}', fontsize=10)
            
            ax.text(0.02, 0.98, "Blue = negative\nRed = positive", 
                   transform=ax.transAxes, fontsize=9, va='top',
                   bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))
            
        except Exception as e:
            print(f"    ⚠️  Could not create atlas visualization: {e}")
            axes[0, 0].text(0.5, 0.5, "Atlas not available", 
                          ha='center', va='center', transform=axes[0, 0].transAxes)
            axes[0, 0].axis('off')
    else:
        axes[0, 0].text(0.5, 0.5, "Atlas not available", 
                       ha='center', va='center', transform=axes[0, 0].transAxes)
        axes[0, 0].axis('off')
    
    # 2. Correlation strength by region
    ax = axes[0, 1]
    sorted_corr = corr_df.reindex(corr_df['correlation'].abs().sort_values(ascending=False).index).head(20)
    colors = ['red' if sig else 'gray' for sig in sorted_corr['significant']]
    ax.barh(range(len(sorted_corr)), sorted_corr['correlation'], color=colors, alpha=0.7)
    ax.set_yticks(range(len(sorted_corr)))
    ax.set_yticklabels([f"R{int(r)}" for r in sorted_corr['region']], fontsize=8)
    ax.set_xlabel(f'Correlation ({coef_name})', fontweight='bold')
    ax.axvline(0, color='black', linestyle='-', linewidth=0.5)
    ax.set_title(f'Top 20 Regions by |Correlation| ({full_name})\n(Red = significant p<0.05)', fontsize=10)
    ax.grid(True, alpha=0.3, axis='x')
    ax.invert_yaxis()
    
    # 3. Scatter plots for top correlated regions
    ax = axes[1, 0]
    top_positive = corr_df.nlargest(1, 'correlation').iloc[0]
    top_negative = corr_df.nsmallest(1, 'correlation').iloc[0]
    
    # Plot both on same axis with different colors
    for region_info, color, label in [(top_positive, 'red', f"R{int(top_positive['region'])} ({coef_name}={top_positive['correlation']:.3f})"),
                                       (top_negative, 'blue', f"R{int(top_negative['region'])} ({coef_name}={top_negative['correlation']:.3f})")]:
        region_data = regional_df[regional_df[region_col] == region_info['region']][[continuous_var, feature]].dropna()
        ax.scatter(region_data[continuous_var], region_data[feature], 
                  alpha=0.5, s=30, color=color, label=label)
        
        # Add regression line (only for visualization, even for non-parametric tests)
        z = np.polyfit(region_data[continuous_var], region_data[feature], 1)
        p = np.poly1d(z)
        x_line = np.linspace(region_data[continuous_var].min(), region_data[continuous_var].max(), 100)
        ax.plot(x_line, p(x_line), color=color, linestyle='--', linewidth=2, alpha=0.7)
    
    ax.set_xlabel(continuous_var, fontweight='bold')
    ax.set_ylabel(feature, fontweight='bold')
    ax.set_title('Most Positive & Negative Correlations\n(Line shown for reference)', fontsize=10)
    ax.legend()
    ax.grid(True, alpha=0.3)
    
    # 4. Summary statistics
    ax = axes[1, 1]
    ax.axis('off')
    
    summary_text = f"""
CORRELATION SUMMARY ({full_name.upper()})

Method: {full_name} ({coef_name})
Variable: {continuous_var}
Total Regions Tested: {len(corr_df)}
Significant Correlations (p<0.05): {sum(corr_df['significant'])} ({100*sum(corr_df['significant'])/len(corr_df):.1f}%)

CORRELATION STATISTICS
Mean |{coef_name}|: {corr_df['correlation'].abs().mean():.3f}
Median |{coef_name}|: {corr_df['correlation'].abs().median():.3f}
Max |{coef_name}|: {corr_df['correlation'].abs().max():.3f}

POSITIVE CORRELATIONS
Count: {sum(corr_df['correlation'] > 0)}
Mean {coef_name}: {corr_df[corr_df['correlation'] > 0]['correlation'].mean():.3f}
Significant: {sum((corr_df['correlation'] > 0) & corr_df['significant'])}

NEGATIVE CORRELATIONS
Count: {sum(corr_df['correlation'] < 0)}
Mean {coef_name}: {corr_df[corr_df['correlation'] < 0]['correlation'].mean():.3f}
Significant: {sum((corr_df['correlation'] < 0) & corr_df['significant'])}

STRONGEST CORRELATIONS:
"""
    top_5_abs = corr_df.iloc[corr_df['correlation'].abs().argsort()[-5:][::-1]]
    for _, row in top_5_abs.iterrows():
        sig_marker = "***" if row['p_value'] < 0.001 else "**" if row['p_value'] < 0.01 else "*" if row['p_value'] < 0.05 else ""
        summary_text += f"\nR{int(row['region'])}: {coef_name}={row['correlation']:.3f}, p={row['p_value']:.4f} {sig_marker}"
    
    ax.text(0.1, 0.9, summary_text, transform=ax.transAxes, 
           fontsize=10, verticalalignment='top', family='monospace',
           bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))
    
    plt.tight_layout()
    plt.savefig(FIGURES_DIR / f'{output_prefix}_correlation_{method}.png', dpi=300, bbox_inches='tight')
    plt.show()
    
    return corr_df

# ====================================================================
# MAIN ANALYSIS WORKFLOW
# ====================================================================

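def run_comprehensive_demographic_analysis(regional_with_demo, atlas_path, region_col, 
                                          top_features, FIGURES_DIR, TABLES_DIR,
                                          correlation_method='pearson'):
    """
    Main function to run analysis for ALL available demographics automatically.
    
    Parameters:
    -----------
    regional_with_demo : pd.DataFrame
        Regional data merged with demographics
    atlas_path : str or Path
        Path to atlas NIfTI file
    region_col : str
        Column name for brain regions
    top_features : list
        List of features to analyze (e.g., top 10 by variance)
    FIGURES_DIR : Path
        Directory to save figures
    TABLES_DIR : Path
        Directory to save tables
    correlation_method : str, optional (default='pearson')
        Correlation method passed to create_continuous_correlation_atlas:
        'pearson', 'spearman', or 'kendall'
    
    Returns:
    --------
    all_statistical_results : dict
        Per-region result DataFrames keyed by '{demo_var}_{feature}' (group
        comparisons) or '{demo_var}_{feature}_corr' (correlations)
    demographics : dict
        Detected variables under the keys 'categorical' and 'continuous'
    """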
    print("\n" + "="*80)
    print("STEP 1: AUTO-DETECTING DEMOGRAPHIC VARIABLES")
    print("="*80)
    
    demographics = detect_demographic_variables(regional_with_demo)
    print(f"\nFound {len(demographics['categorical'])} categorical and "
          f"{len(demographics['continuous'])} continuous demographic variables")
    
    # Store all results
    all_statistical_results = {}
    
    # ====================================================================
    # PART 1: CATEGORICAL DEMOGRAPHIC ANALYSES
    # ====================================================================
    
    if len(demographics['categorical']) > 0:
        print("\n" + "="*80)
        print("STEP 2: CATEGORICAL DEMOGRAPHIC COMPARISONS")
        print("="*80)
        
        for demo_var, unique_vals in demographics['categorical'].items():
            print(f"\n{'-'*80}")
            
            print(f"Analyzing: {demo_var.upper()}")
            print(f"{'-'*80}")
            
            # Get group labels
            group_labels = create_group_labels(regional_with_demo, demo_var)
            print(f"Groups: {', '.join([f'{k}={v}' for k,v in group_labels.items()])}")
            
            # Count subjects per group
            group_counts = regional_with_demo[demo_var].value_counts()
            print(f"Sample sizes: {dict(group_counts)}")
            
            for feature_idx, feature in enumerate(top_features, 1):
                if feature not in regional_with_demo.columns:
                    continue
                
                print(f"\n  [{feature_idx}/{len(top_features)}] Processing: {feature}")
                
                # 1. Create group comparison atlas visualization
                print(f"    → Creating group comparison atlas...")
                create_group_comparison_atlas(
                    atlas_path=atlas_path,
                    regional_df=regional_with_demo,
                    feature=feature,
                    region_col=region_col,
                    group_col=demo_var,
                    group_labels=group_labels,
                    output_prefix=f'{demo_var}_{feature}',
                    main_title=f'{feature} - {demo_var.replace("_", " ").title()} Comparison',
                    colorbar_label=f'{feature}'
                )
                
                # 2. Create difference visualization (statistical testing)
                print(f"    → Creating statistical difference maps...")
                diff_results = create_difference_visualization(
                    atlas_path=atlas_path,
                    regional_df=regional_with_demo,
                    feature=feature,
                    region_col=region_col,
                    group_col=demo_var,
                    output_prefix=f'{demo_var}_{feature}',
                    main_title=f'{feature} - {demo_var.replace("_", " ").title()}'
                )
                
                # 3. Create pairwise comparison plots
                print(f"    → Creating pairwise comparison plots...")
                create_pairwise_comparison_plots(
                    regional_df=regional_with_demo,
                    feature=feature,
                    region_col=region_col,
                    group_col=demo_var,
                    output_prefix=f'{demo_var}_{feature}',
                    main_title=f'{feature} - {demo_var.replace("_", " ").title()}'
                )
                
                # Store results
                if diff_results is not None:
                    key = f"{demo_var}_{feature}"
                    all_statistical_results[key] = diff_results
                    
                    # Save to CSV
                    diff_results.to_csv(
                        TABLES_DIR / f'stats_{demo_var}_{feature}.csv', 
                        index=False
                    )
                    print(f"    ✓ Saved statistical results")
    
    # ====================================================================
    # PART 2: CONTINUOUS DEMOGRAPHIC ANALYSES
    # ====================================================================
    
    if len(demographics['continuous']) > 0:
        print("\n" + "="*80)
        print("STEP 3: CONTINUOUS DEMOGRAPHIC CORRELATIONS")
        print("="*80)
        
        for demo_var in demographics['continuous']:
            print(f"\n{'-'*80}")
            print(f"Analyzing: {demo_var.upper()}")
            print(f"{'-'*80}")
            
            # Get variable statistics
            var_data = regional_with_demo[demo_var].dropna()
            print(f"Range: {var_data.min():.1f} - {var_data.max():.1f}")
            print(f"Mean ± SD: {var_data.mean():.1f} ± {var_data.std():.1f}")
            print(f"N = {len(var_data)}")
            
            for feature_idx, feature in enumerate(top_features, 1):
                if feature not in regional_with_demo.columns:
                    continue
                
                print(f"\n  [{feature_idx}/{len(top_features)}] Processing: {feature}")
                
                # Create correlation analysis
                print(f"    → Creating correlation atlas...")
                corr_results = create_continuous_correlation_atlas(
                    atlas_path=atlas_path,
                    regional_df=regional_with_demo,
                    feature=feature,
                    region_col=region_col,
                    continuous_var=demo_var,
                    output_prefix=f'{demo_var}_{feature}',
                    main_title=f'{feature}',
                    method=correlation_method
                )
                
                # Store results
                if corr_results is not None:
                    key = f"{demo_var}_{feature}_corr"
                    all_statistical_results[key] = corr_results
                    
                    # Save to CSV
                    corr_results.to_csv(
                        TABLES_DIR / f'corr_{demo_var}_{feature}.csv',
                        index=False
                    )
                    print(f"    ✓ Saved correlation results")
    
    # ====================================================================
    # PART 3: SUMMARY REPORT
    # ====================================================================
    
    print("\n" + "="*80)
    print("ANALYSIS COMPLETE - GENERATING SUMMARY")
    print("="*80)
    
    summary_data = {
        'demographics_analyzed': {
            'categorical': list(demographics['categorical'].keys()),
            'continuous': demographics['continuous']
        },
        'features_analyzed': top_features,
        'total_comparisons': len(all_statistical_results),
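        # Note: the counts below include every matching file already in the
        # output folders, not only the files written by this run
        'figures_generated': len(list(FIGURES_DIR.glob('*.png'))),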
        'tables_generated': len(list(TABLES_DIR.glob('stats_*.csv'))) + 
                           len(list(TABLES_DIR.glob('corr_*.csv')))
    }
    
    # Save summary
    import json
    with open(TABLES_DIR / 'demographic_analysis_summary.json', 'w') as f:
        json.dump(summary_data, f, indent=2)
    
    print(f"\n✅ ANALYSIS SUMMARY:")
    print(f"   • Categorical variables: {len(demographics['categorical'])}")
    print(f"   • Continuous variables: {len(demographics['continuous'])}")
    print(f"   • Features analyzed: {len(top_features)}")
    print(f"   • Total statistical tests: {len(all_statistical_results)}")
    print(f"   • Figures generated: {summary_data['figures_generated']}")
    print(f"   • Tables generated: {summary_data['tables_generated']}")
    
    return all_statistical_results, demographics
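
The `method` argument of `create_continuous_correlation_atlas` switches between three SciPy functions. As a rough illustration of why the choice can matter, the sketch below uses synthetic data (not taken from the cohort) with a monotonic but non-linear relationship. The two rank-based coefficients reach 1, while Pearson's r stays below it.

```python
import numpy as np
from scipy import stats

# Synthetic, illustrative data: y grows monotonically but non-linearly with x
x = np.arange(1, 11, dtype=float)
y = x ** 3

# Each function returns a (coefficient, p-value) pair, as used in the atlas function
r, r_p = stats.pearsonr(x, y)
rho, rho_p = stats.spearmanr(x, y)
tau, tau_p = stats.kendalltau(x, y)

print(f"Pearson r    = {r:.3f}")
print(f"Spearman rho = {rho:.3f}")
print(f"Kendall tau  = {tau:.3f}")
```

If a feature is expected to saturate or accelerate with age, `correlation_method='spearman'` may therefore be a more suitable choice than the Pearson default.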


In [None]:
# Add an 'age_group' column based on AGE quartiles.
# The labels (25, 40, 55, 75) are nominal ages that tag the four quartile bins;
# they are not the computed bin edges.
regional_with_demo['age_group'] = pd.qcut(
    regional_with_demo['AGE'], 
    q=4, 
    labels=[25, 40, 55, 75]
)
# Remove rows with a missing age_group (.copy() guards against SettingWithCopyWarning)
regional_with_demo = regional_with_demo.dropna(subset=['age_group']).copy()
# Store as float
regional_with_demo['age_group'] = regional_with_demo['age_group'].astype(float)

print("Number of patients per age quartile bin:")
print(f"  {regional_with_demo['age_group'].value_counts().to_dict()}")


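# Add a 'bmi_category' column with bins (18.5, 25], (25, 30], (30, 50].
# The labels 22, 27, 35 are nominal values for each bin. A BMI <= 18.5 or > 50
# falls outside the bins, becomes NaN, and is dropped below.
regional_with_demo['bmi_category'] = pd.cut(
    regional_with_demo['BMI'],
    bins=[18.5, 25, 30, 50],
    labels=[22, 27, 35]
)

regional_with_demo = regional_with_demo.dropna(subset=['bmi_category']).copy()
regional_with_demo['bmi_category'] = regional_with_demo['bmi_category'].astype(float)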

regional_with_demo
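
The binning cell relies on two different pandas functions: `pd.qcut` derives its edges from the data, while `pd.cut` uses fixed edges. The following sketch, with toy values chosen only for illustration, shows both behaviours, including how values outside the fixed BMI bins turn into NaN.

```python
import pandas as pd

# Toy values, chosen only to show the binning behaviour
ages = pd.Series([20, 30, 40, 50, 60, 70, 80, 90])
bmis = pd.Series([17.0, 22.0, 27.5, 33.0, 55.0])

# qcut: data-driven edges, so each quartile receives the same number of values
age_group = pd.qcut(ages, q=4, labels=[25, 40, 55, 75]).astype(float)

# cut: fixed, right-inclusive edges: (18.5, 25], (25, 30], (30, 50]
bmi_category = pd.cut(bmis, bins=[18.5, 25, 30, 50], labels=[22, 27, 35]).astype(float)

print(age_group.value_counts().sort_index().to_dict())
print(bmi_category.tolist())
```

Here the BMI values 17.0 and 55.0 lie outside every bin and come back as NaN, which is why the cell above drops rows with a missing `bmi_category` before the analysis.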

In [None]:
# Get top features (or use all features)
top_features = regional_var_df.head(10)['feature'].tolist()

# SELECT OCCUPATION (4,6,7) + AGE+BMI (1,2)

# BMI ==> only len
# OCCUPATION, AGE ==> ....

In [None]:
# Run comprehensive analysis
results, detected_demographics = run_comprehensive_demographic_analysis(
    regional_with_demo=regional_with_demo,
    atlas_path=ATLAS_PATH,
    region_col=region_col,
    top_features=top_features,
    FIGURES_DIR=FIGURES_DIR,
    TABLES_DIR=TABLES_DIR,
    correlation_method='pearson' # 'pearson', 'spearman', or 'kendall'
)
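
When reading the difference figures produced by this run, it may help to relate the three thresholds that appear in them: the uncorrected p < 0.05, the Bonferroni threshold in the summary panel, and the -log10 scale of the significance map. The sketch below works through the arithmetic for a hypothetical count of 100 tested regions (`n_regions` is an assumed value, not one taken from the data).

```python
import math

def neg_log10(p):
    """Transform a p-value to the -log10 scale used by the significance map."""
    return -math.log10(p)

alpha = 0.05
n_regions = 100  # hypothetical number of tested regions

# Bonferroni: divide the significance level by the number of tests
bonferroni_alpha = alpha / n_regions

print(f"-log10({alpha}) = {neg_log10(alpha):.2f}")
print(f"Bonferroni threshold = {bonferroni_alpha:.2e}")
print(f"-log10(Bonferroni threshold) = {neg_log10(bonferroni_alpha):.2f}")
```

Under this assumption the Bonferroni threshold sits at about 3.3 on the -log10 scale, above the `vmax=3` used for the colour bar, so regions that survive correction would all appear at the saturated end of the map.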

In [None]:
# ====================================================================
# 10.9 Statistical Testing - Regional Features vs Demographics
# ====================================================================
print("\n" + "="*80)
print("10.9 STATISTICAL TESTING - REGIONAL FEATURES VS DEMOGRAPHICS")
print("="*80)

if 'regional_with_demo' not in locals() or len(regional_with_demo) == 0:
    print("\n⚠️  Regional data with demographics not available")
else:
    
    # ================================================================
    # Helper Functions for Statistical Testing
    # ================================================================
    
    def test_continuous_demographic(regional_df, feature, region_col, demo_var):
        """Test correlation between feature and continuous demographic variable by region."""
        results = []
        
        for region in regional_df[region_col].unique():
            region_data = regional_df[regional_df[region_col] == region][[demo_var, feature]].dropna()
            
            if len(region_data) < 10:
                continue
            
            # Pearson correlation
            r, p = pearsonr(region_data[demo_var], region_data[feature])
            
            # Spearman correlation (non-parametric)
            rho, p_spearman = spearmanr(region_data[demo_var], region_data[feature])
            
            results.append({
                'region': region,
                'feature': feature,
                'demographic': demo_var,
                'n': len(region_data),
                'pearson_r': r,
                'pearson_p': p,
                'spearman_rho': rho,
                'spearman_p': p_spearman
            })
        
        return results
    
    def test_categorical_demographic(regional_df, feature, region_col, demo_var):
        """Test differences between demographic groups by region."""
        results = []
        
        # Get unique categories
        categories = regional_df[demo_var].dropna().unique()
        
        if len(categories) < 2:
            return results
        
        for region in regional_df[region_col].unique():
            region_data = regional_df[regional_df[region_col] == region]
            
            # Get data for each category
            group_data = []
            group_labels = []
            group_sizes = []
            group_means = []
            
            for cat in categories:
                cat_data = region_data[region_data[demo_var] == cat][feature].dropna()
                if len(cat_data) >= 5:  # Minimum group size
                    group_data.append(cat_data)
                    group_labels.append(cat)
                    group_sizes.append(len(cat_data))
                    group_means.append(cat_data.mean())
            
            if len(group_data) < 2:
                continue
            
            # Perform statistical tests
            try:
                # ANOVA (parametric)
                f_stat, f_pval = f_oneway(*group_data)
                
                # Kruskal-Wallis (non-parametric)
                h_stat, h_pval = kruskal(*group_data)
                
                # Effect size (eta-squared for ANOVA)
                grand_mean = np.mean(np.concatenate(group_data))
                ss_between = sum(len(d) * (np.mean(d) - grand_mean)**2 for d in group_data)
                ss_total = sum(np.sum((d - grand_mean)**2) for d in group_data)
                eta_squared = ss_between / ss_total if ss_total > 0 else 0
                
                # For binary comparisons, also calculate Cohen's d
                cohens_d = np.nan
                if len(group_data) == 2:
                    pooled_std = np.sqrt(
                        ((len(group_data[0])-1)*group_data[0].std()**2 + 
                         (len(group_data[1])-1)*group_data[1].std()**2) / 
                        (len(group_data[0])+len(group_data[1])-2)
                    )
                    cohens_d = (group_data[0].mean() - group_data[1].mean()) / pooled_std if pooled_std > 0 else 0
                
                result = {
                    'region': region,
                    'feature': feature,
                    'demographic': demo_var,
                    'n_groups': len(group_data),
                    'total_n': sum(group_sizes),
                    'f_statistic': f_stat,
                    'anova_p': f_pval,
                    'kruskal_h': h_stat,
                    'kruskal_p': h_pval,
                    'eta_squared': eta_squared,
                    'cohens_d': cohens_d
                }
                
                # Add group-specific info
                for idx, label in enumerate(group_labels):
                    result[f'group_{idx+1}_label'] = str(label)
                    result[f'group_{idx+1}_n'] = group_sizes[idx]
                    result[f'group_{idx+1}_mean'] = group_means[idx]
                
                results.append(result)
                
            except Exception as e:
                print(f"  Warning: Could not test region {region} for {demo_var}: {e}")
                continue
        
        return results
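
As a hand-checkable illustration of the two effect sizes computed in `test_categorical_demographic`, the sketch below reproduces the same formulas with the standard library on made-up measurements for two groups. The numbers are illustrative only.

```python
from math import sqrt
from statistics import mean, variance

# Made-up measurements for two groups in one region
group_a = [2.0, 4.0, 6.0]
group_b = [5.0, 7.0, 9.0]

# Cohen's d with a pooled standard deviation (sample variances, ddof = 1)
n_a, n_b = len(group_a), len(group_b)
pooled_std = sqrt(((n_a - 1) * variance(group_a) + (n_b - 1) * variance(group_b))
                  / (n_a + n_b - 2))
cohens_d = (mean(group_a) - mean(group_b)) / pooled_std

# Eta-squared: between-group sum of squares over total sum of squares
all_values = group_a + group_b
grand_mean = mean(all_values)
ss_between = sum(len(g) * (mean(g) - grand_mean) ** 2 for g in (group_a, group_b))
ss_total = sum((v - grand_mean) ** 2 for v in all_values)
eta_squared = ss_between / ss_total

print(f"Cohen's d   = {cohens_d:.3f}")
print(f"eta-squared = {eta_squared:.3f}")
```

The sign of Cohen's d depends on which group comes first. In the helper above that order follows `unique()`, so the `group_1_label` and `group_2_label` columns are needed to interpret the direction.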
    

In [None]:
# ================================================================
# 1. IDENTIFY DEMOGRAPHIC VARIABLES
# ================================================================
print("\n" + "-"*80)
print("1. IDENTIFYING DEMOGRAPHIC VARIABLES")
print("-"*80)

# Continuous variables (for correlation testing)
continuous_vars = []
for var in DEMOGRAPHIC_VARS:
    if var in regional_with_demo.columns:
        if pd.api.types.is_numeric_dtype(regional_with_demo[var]):
            var_std = regional_with_demo[var].std()
            var_nunique = regional_with_demo[var].nunique()
            if var_std > 0 and var_nunique > 10:
                continuous_vars.append(var)
            else:
                print(f"Excluding {var}: std={var_std}, unique={var_nunique} (not continuous)")
        else:
            print(f"Excluding {var}: not numeric dtype")
    else:
        print(f"Excluding {var}: not in data columns")
            
# Categorical variables (for group comparison testing)
categorical_vars = []
for var in DEMOGRAPHIC_VARS:
    if var in regional_with_demo.columns:
        var_nunique = regional_with_demo[var].nunique()
        if 2 <= var_nunique <= 10:  # Between 2 and 10 categories
            categorical_vars.append(var)
        else:
            print(f"Excluding {var}: {var_nunique} unique values (not categorical)")
    else:
        print(f"Excluding {var}: not in data columns")
        
print(f"\nContinuous variables (correlation tests): {continuous_vars}")
print(f"Categorical variables (group tests): {categorical_vars}")
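
The two loops above apply independent rules, so a column can qualify as continuous, as categorical, or as neither. The sketch below condenses those rules into one hypothetical helper, `classify_variable`, and applies it to invented columns. The helper name, its `max_categories` parameter, and the example data are assumptions made for illustration.

```python
import pandas as pd

def classify_variable(series, max_categories=10):
    """Return the set of roles a column qualifies for under the rules above."""
    roles = set()
    n_unique = series.nunique()
    # Continuous: numeric, non-constant, and more than max_categories distinct values
    if (pd.api.types.is_numeric_dtype(series)
            and series.std() > 0 and n_unique > max_categories):
        roles.add('continuous')
    # Categorical: between 2 and max_categories distinct values, any dtype
    if 2 <= n_unique <= max_categories:
        roles.add('categorical')
    return roles

# Hypothetical demographic columns
demo = pd.DataFrame({
    'AGE': [21, 25, 30, 34, 38, 42, 47, 51, 55, 60, 64, 70],
    'SEX_ID': [1, 2] * 6,
    'SITE': ['A'] * 12,
})

for col in demo.columns:
    print(col, classify_variable(demo[col]))
```

In this toy frame `AGE` is continuous only, `SEX_ID` is categorical only, and the constant `SITE` column is excluded from both lists.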



In [None]:
# ================================================================
# 2. TEST CONTINUOUS DEMOGRAPHICS (CORRELATIONS)
# ================================================================
print("\n" + "-"*80)
print("2. TESTING CONTINUOUS DEMOGRAPHIC CORRELATIONS BY REGION")
print("-"*80)

all_continuous_results = []

for demo_var in continuous_vars:
    print(f"\n  Testing {demo_var}...")
    
    for feature in ALL_FEATURES:
        if feature not in regional_with_demo.columns or 'region' in feature:
            continue
        
        results = test_continuous_demographic(
            regional_with_demo, feature, region_col, demo_var
        )
        all_continuous_results.extend(results)
    
    print(f"    Completed {len([r for r in all_continuous_results if r['demographic'] == demo_var])} region-feature tests")

if len(all_continuous_results) > 0:
    continuous_results_df = pd.DataFrame(all_continuous_results)
    
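    # FDR correction for each demographic variable.
    # NaN p-values (e.g. from a constant feature) are left out of the correction,
    # since a NaN in the input can turn the whole corrected set into NaN.
    continuous_results_df['pearson_p_fdr'] = np.nan
    continuous_results_df['spearman_p_fdr'] = np.nan
    for demo_var in continuous_vars:
        demo_mask = continuous_results_df['demographic'] == demo_var
        for p_col in ['pearson_p', 'spearman_p']:
            valid = demo_mask & continuous_results_df[p_col].notna()
            if valid.sum() > 0:
                continuous_results_df.loc[valid, f'{p_col}_fdr'] = multipletests(
                    continuous_results_df.loc[valid, p_col], method='fdr_bh')[1]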
    
    continuous_results_df['pearson_significant'] = continuous_results_df['pearson_p_fdr'] < 0.05
    continuous_results_df['spearman_significant'] = continuous_results_df['spearman_p_fdr'] < 0.05
    
    # Save results
    continuous_results_df.to_csv(
        TABLES_DIR / 'regional_continuous_demographics_tests.csv', index=False
    )
    
    # Summary
    print(f"\n  CONTINUOUS DEMOGRAPHICS SUMMARY:")
    print(f"  Total tests performed: {len(continuous_results_df)}")
    for demo_var in continuous_vars:
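        demo_results = continuous_results_df[continuous_results_df['demographic'] == demo_var]
        if demo_results.empty:
            # No region-feature tests were produced for this variable; skip it
            # to avoid dividing by zero in the percentages below.
            print(f"\n  {demo_var}: no tests performed")
            continue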
        n_sig_pearson = demo_results['pearson_significant'].sum()
        n_sig_spearman = demo_results['spearman_significant'].sum()
        print(f"\n  {demo_var}:")
        print(f"    Significant Pearson correlations: {n_sig_pearson} / {len(demo_results)} ({100*n_sig_pearson/len(demo_results):.1f}%)")
        print(f"    Significant Spearman correlations: {n_sig_spearman} / {len(demo_results)} ({100*n_sig_spearman/len(demo_results):.1f}%)")
        
        # Top correlations
        if n_sig_pearson > 0:
            top_corr = demo_results[demo_results['pearson_significant']].nlargest(5, 'pearson_r', keep='all')
            print(f"    Top 5 correlations by r (largest first):")
            for idx, row in top_corr.head(5).iterrows():
                print(f"      Region {row['region']}, {row['feature']}: r={row['pearson_r']:.3f}, p={row['pearson_p_fdr']:.2e}")
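
The FDR step above applies Benjamini-Hochberg to whole columns of p-values, one demographic variable at a time. If the test helper ever returns NaN for a region-feature pair (for instance when a feature is constant within a region; this is an assumption, since the body of `test_continuous_demographic` is not shown here), it is worth checking how those NaNs are treated before correcting. One option is to adjust only the finite p-values and leave the rest as NaN. A minimal NumPy sketch of that idea follows; `bh_adjust_nan_safe` is a hypothetical helper name, not part of the pipeline.

```python
import numpy as np

def bh_adjust_nan_safe(pvals):
    """Benjamini-Hochberg adjusted p-values; NaN inputs stay NaN.

    Hypothetical helper for illustration only.
    """
    p = np.asarray(pvals, dtype=float)
    adjusted = np.full(p.shape, np.nan)
    finite = np.isfinite(p)
    m = int(finite.sum())
    if m == 0:
        return adjusted

    order = np.argsort(p[finite])                 # ascending p-values
    ranked = p[finite][order] * m / np.arange(1, m + 1)
    # Enforce monotonicity from the largest p-value downwards, then cap at 1
    ranked = np.minimum.accumulate(ranked[::-1])[::-1]
    ranked = np.clip(ranked, 0.0, 1.0)

    restored = np.empty(m)
    restored[order] = ranked                      # undo the sort
    adjusted[finite] = restored
    return adjusted

adj = bh_adjust_nan_safe([0.01, np.nan, 0.04, 0.03, 0.20])
print(adj)
```

The returned array has the same length as the input, so it can be assigned back through the same boolean mask used above. A NaN adjusted p-value compares as `False` against the 0.05 threshold, so such rows are simply reported as not significant.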


In [None]:

# ================================================================
# 3. TEST CATEGORICAL DEMOGRAPHICS (GROUP DIFFERENCES)
# ================================================================
print("\n" + "-"*80)
print("3. TESTING CATEGORICAL DEMOGRAPHIC GROUP DIFFERENCES BY REGION")
print("-"*80)

all_categorical_results = []

for demo_var in categorical_vars:
    print(f"\n  Testing {demo_var}...")
    
    for feature in ALL_FEATURES:
        if feature not in regional_with_demo.columns or 'region' in feature:
            continue
        
        results = test_categorical_demographic(
            regional_with_demo, feature, region_col, demo_var
        )
        all_categorical_results.extend(results)
    
    print(f"    Completed {len([r for r in all_categorical_results if r['demographic'] == demo_var])} region-feature tests")

if len(all_categorical_results) > 0:
    categorical_results_df = pd.DataFrame(all_categorical_results)
    
    # FDR correction for each demographic variable
    for demo_var in categorical_vars:
        mask = categorical_results_df['demographic'] == demo_var
        if mask.sum() > 0:
            categorical_results_df.loc[mask, 'anova_p_fdr'] = multipletests(
                categorical_results_df.loc[mask, 'anova_p'], method='fdr_bh')[1]
            categorical_results_df.loc[mask, 'kruskal_p_fdr'] = multipletests(
                categorical_results_df.loc[mask, 'kruskal_p'], method='fdr_bh')[1]
    
    categorical_results_df['anova_significant'] = categorical_results_df['anova_p_fdr'] < 0.05
    categorical_results_df['kruskal_significant'] = categorical_results_df['kruskal_p_fdr'] < 0.05
    
    # Save results
    categorical_results_df.to_csv(
        TABLES_DIR / 'regional_categorical_demographics_tests.csv', index=False
    )
    
    # Summary
    print(f"\n  CATEGORICAL DEMOGRAPHICS SUMMARY:")
    print(f"  Total tests performed: {len(categorical_results_df)}")
    for demo_var in categorical_vars:
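        demo_results = categorical_results_df[categorical_results_df['demographic'] == demo_var]
        if demo_results.empty:
            # No region-feature tests were produced for this variable; skip it
            # to avoid dividing by zero in the percentages below.
            print(f"\n  {demo_var}: no tests performed")
            continue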
        n_sig_anova = demo_results['anova_significant'].sum()
        n_sig_kruskal = demo_results['kruskal_significant'].sum()
        print(f"\n  {demo_var}:")
        print(f"    Significant ANOVA differences: {n_sig_anova} / {len(demo_results)} ({100*n_sig_anova/len(demo_results):.1f}%)")
        print(f"    Significant Kruskal-Wallis differences: {n_sig_kruskal} / {len(demo_results)} ({100*n_sig_kruskal/len(demo_results):.1f}%)")
        
        # Top effect sizes
        if n_sig_anova > 0:
            top_effects = demo_results[demo_results['anova_significant']].nlargest(5, 'eta_squared', keep='all')
            print(f"    Top 5 largest effect sizes:")
            for idx, row in top_effects.head(5).iterrows():
                print(f"      Region {row['region']}, {row['feature']}: η²={row['eta_squared']:.3f}, p={row['anova_p_fdr']:.2e}")
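
The categorical summary ranks results by `eta_squared`, which is produced by `test_categorical_demographic` (not shown in this section). Assuming it follows the usual one-way ANOVA definition, η² = SS_between / SS_total, the calculation can be sketched as below. `eta_squared_oneway` is a hypothetical helper name and the numbers are toy values.

```python
import numpy as np

def eta_squared_oneway(groups):
    """Eta squared for a one-way layout: SS_between / SS_total."""
    groups = [np.asarray(g, dtype=float) for g in groups]
    all_values = np.concatenate(groups)
    grand_mean = all_values.mean()
    # Between-group sum of squares: group size times squared mean offset
    ss_between = sum(len(g) * (g.mean() - grand_mean) ** 2 for g in groups)
    # Total sum of squares around the grand mean
    ss_total = ((all_values - grand_mean) ** 2).sum()
    return ss_between / ss_total if ss_total > 0 else np.nan

eta2 = eta_squared_oneway([[1, 2, 3], [4, 5, 6]])
print(f"eta squared = {eta2:.3f}")
```

Because η² is a proportion of variance, it stays between 0 and 1 regardless of the units of the vessel feature, which is what makes it usable for ranking features against each other within a region.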



In [None]:
# ================================================================
# 4. CREATE SUMMARY HEATMAPS
# ================================================================
print("\n" + "-"*80)
print("4. CREATING SUMMARY HEATMAPS")
print("-"*80)

# Heatmap 1: Significant correlations (continuous variables)
if len(all_continuous_results) > 0 and len(continuous_vars) > 0:
    print("\n  Creating continuous demographics heatmap...")
    
    # Create significance matrix
    sig_matrix = []
    features_in_order = []
    
    for feature in ALL_FEATURES:
        if 'region' in feature:
            continue
        feature_results = continuous_results_df[continuous_results_df['feature'] == feature]
        if len(feature_results) == 0:
            continue
        
        row = []
        for demo_var in continuous_vars:
            demo_feature = feature_results[feature_results['demographic'] == demo_var]
            n_sig = demo_feature['pearson_significant'].sum()
            pct_sig = 100 * n_sig / len(demo_feature) if len(demo_feature) > 0 else 0
            row.append(pct_sig)
        
        if sum(row) > 0:  # Only include features with some significance
            sig_matrix.append(row)
            features_in_order.append(feature)
    
    if len(sig_matrix) > 0:
        sig_df = pd.DataFrame(sig_matrix, index=features_in_order, columns=continuous_vars)
        
        # Plot
        fig, ax = plt.subplots(figsize=(max(8, len(continuous_vars)*1.5), 
                                        max(10, len(features_in_order)*0.3)))
        sns.heatmap(sig_df, annot=True, fmt='.0f', cmap='YlOrRd', 
                    cbar_kws={'label': '% Regions Significant'}, ax=ax,
                    linewidths=0.5, linecolor='gray')
        ax.set_title('Percentage of Regions with Significant Correlations\n(Continuous Demographics)', 
                    fontweight='bold', fontsize=14)
        ax.set_xlabel('Demographic Variable', fontweight='bold')
        ax.set_ylabel('Vessel Feature', fontweight='bold')
        plt.tight_layout()
        plt.savefig(FIGURES_DIR / 'regional_continuous_demographics_heatmap.png', 
                    dpi=300, bbox_inches='tight')
        plt.show()
        print("    ✓ Continuous demographics heatmap saved")

# Heatmap 2: Significant group differences (categorical variables)
if len(all_categorical_results) > 0 and len(categorical_vars) > 0:
    print("\n  Creating categorical demographics heatmap...")
    
    # Create significance matrix
    sig_matrix = []
    features_in_order = []
    
    for feature in ALL_FEATURES:
        if 'region' in feature:
            continue
        feature_results = categorical_results_df[categorical_results_df['feature'] == feature]
        if len(feature_results) == 0:
            continue
        
        row = []
        for demo_var in categorical_vars:
            demo_feature = feature_results[feature_results['demographic'] == demo_var]
            n_sig = demo_feature['anova_significant'].sum()
            pct_sig = 100 * n_sig / len(demo_feature) if len(demo_feature) > 0 else 0
            row.append(pct_sig)
        
        if sum(row) > 0:
            sig_matrix.append(row)
            features_in_order.append(feature)
    
    if len(sig_matrix) > 0:
        sig_df = pd.DataFrame(sig_matrix, index=features_in_order, columns=categorical_vars)
        
        # Plot
        fig, ax = plt.subplots(figsize=(max(8, len(categorical_vars)*1.5), 
                                        max(10, len(features_in_order)*0.3)))
        sns.heatmap(sig_df, annot=True, fmt='.0f', cmap='YlOrRd', 
                    cbar_kws={'label': '% Regions Significant'}, ax=ax,
                    linewidths=0.5, linecolor='gray')
        ax.set_title('Percentage of Regions with Significant Group Differences\n(Categorical Demographics)', 
                    fontweight='bold', fontsize=14)
        ax.set_xlabel('Demographic Variable', fontweight='bold')
        ax.set_ylabel('Vessel Feature', fontweight='bold')
        plt.tight_layout()
        plt.savefig(FIGURES_DIR / 'regional_categorical_demographics_heatmap.png', 
                    dpi=300, bbox_inches='tight')
        plt.show()
        print("    ✓ Categorical demographics heatmap saved")
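
The two significance matrices above are built with nested loops. The same table (percentage of regions with a significant result, per feature and demographic) can also be obtained with a groupby, because the mean of a boolean column is the fraction of `True` values. The sketch below uses a toy frame with the same column names; the feature names, demographic names and values are made up for illustration.

```python
import pandas as pd

toy = pd.DataFrame({
    'feature':     ['length', 'length', 'length', 'length', 'radius', 'radius'],
    'demographic': ['AGE',    'AGE',    'HEIGHT', 'HEIGHT', 'AGE',    'AGE'],
    'region':      [1, 2, 1, 2, 1, 2],
    'pearson_significant': [True, False, True, True, False, False],
})

# Mean of a boolean column = fraction significant; scale to a percentage
pct_sig = (
    toy.groupby(['feature', 'demographic'])['pearson_significant']
       .mean()
       .mul(100)
       .unstack('demographic')
)

# Keep only features with at least one significant region, as in the loops above
pct_sig = pct_sig[pct_sig.fillna(0).sum(axis=1) > 0]
print(pct_sig)
```

Two differences from the loop version are worth noting: feature/demographic pairs with no tests appear as NaN rather than 0 (use `fillna(0)` to match), and rows come out in sorted order rather than in the order of `ALL_FEATURES` (use `reindex` if that order matters for the figure).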


In [None]:
# ================================================================
# ATLAS-BASED HEATMAP VISUALIZATION
# Projects per-region significance percentages onto the atlas volume
# ================================================================
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
from scipy import ndimage

# ================================================================
# HELPER FUNCTIONS FOR ATLAS VISUALIZATION
# ================================================================

def plot_atlas_panel(ax, atlas_slice, value_slice, title, cmap='YlOrRd', vmin=None, vmax=None):
    """Create a single atlas visualization panel with smooth rendering."""
    # 1. WHITE background
    ax.imshow(np.ones_like(atlas_slice), cmap='gray', vmin=0, vmax=1)
    # 2. LIGHT GREY fill inside regions
    mask = atlas_slice > 0
    bg = np.ones_like(atlas_slice)
    bg[mask] = 0.85
    ax.imshow(bg, cmap='gray', vmin=0, vmax=1)
    # 3. SMOOTH GREY contours
    regions_arr = np.unique(atlas_slice)[1:]
    if regions_arr.size:
        ax.contour(
            atlas_slice.astype(float), levels=regions_arr + 0.5, colors='#BBBBBB',
            linewidths=0.5, alpha=0.6, antialiased=True
        )
    # 4. Value overlay
    vals = np.ma.masked_where(value_slice == 0, value_slice)
    im = ax.imshow(vals, cmap=cmap, vmin=vmin, vmax=vmax, alpha=0.9)
    ax.set_title(title, fontsize=12, pad=4)
    ax.axis('off')
    return im


def create_atlas_heatmap_from_dataframe(atlas_path, results_df, demo_var, 
                                        is_continuous=True, output_prefix='',
                                        figures_dir=Path('./figures')):
    """
    Create atlas-based heatmap directly from your results dataframe.
    
    Parameters:
    -----------
    atlas_path : str
        Path to atlas NIfTI file
    results_df : pd.DataFrame
        Your continuous_results_df or categorical_results_df
    demo_var : str
        Demographic variable name (e.g., 'AGE', 'SEX_ID')
    is_continuous : bool
        True for continuous (correlation), False for categorical (ANOVA)
    output_prefix : str
        Prefix for output filename
    figures_dir : Path
        Directory to save figures
    """
    
    # Filter for this demographic variable
    demo_results = results_df[results_df['demographic'] == demo_var]
    
    # Calculate percentage significant per region
    sig_col = 'pearson_significant' if is_continuous else 'anova_significant'
    regional_sig = {}
    
    for region in demo_results['region'].unique():
        region_data = demo_results[demo_results['region'] == region]
        n_sig = region_data[sig_col].sum()
        total = len(region_data)
        pct_sig = 100 * n_sig / total if total > 0 else 0
        regional_sig[region] = pct_sig
    
    # Load atlas
    try:
        # Import inside the try block so that a missing nibabel install also
        # triggers the bar-plot fallback instead of raising ImportError.
        import nibabel as nib
        
        print(f"  Loading atlas from {atlas_path}...")
        atlas_img = nib.load(atlas_path)
        atlas_data = atlas_img.get_fdata()
        
        # Create value map
        value_map = np.zeros_like(atlas_data)
        for region_id, pct_sig in regional_sig.items():
            mask = atlas_data == region_id
            if np.any(mask):
                value_map[mask] = pct_sig
        
        # Crop fixed margins around the brain (assumes a grid larger than 100 x 80 x 100 voxels)
        atlas_cropped = atlas_data[50:-50, 40:-40, 50:-50]
        value_cropped = value_map[50:-50, 40:-40, 50:-50]
        
        # Get slice positions
        x, y, z = atlas_cropped.shape
        x_slice, y_slice, z_slice = x//2, y//2, z//2
        
        # Create figure with three views
        fig, axes = plt.subplots(1, 3, figsize=(18, 6))
        
        slices = {
            'Sagittal': (np.rot90(atlas_cropped[x_slice, :, :]), 
                        np.rot90(value_cropped[x_slice, :, :])),
            'Coronal': (np.rot90(atlas_cropped[:, y_slice, :]), 
                       np.rot90(value_cropped[:, y_slice, :])),
            'Axial': (np.rot90(atlas_cropped[:, :, z_slice]), 
                     np.rot90(value_cropped[:, :, z_slice]))
        }
        
        # Plot each view
        vmin, vmax = 0, 100
        for ax, (view_name, (atlas_sl, val_sl)) in zip(axes, slices.items()):
            im = plot_atlas_panel(ax, atlas_sl, val_sl, view_name, 
                                 cmap='YlOrRd', vmin=vmin, vmax=vmax)
        
        # Title and colorbar
        analysis_type = "Correlations" if is_continuous else "Group Differences"
        fig.suptitle(f'{demo_var} - Regional {analysis_type}', 
                    fontsize=16, fontweight='bold', y=0.98)
        
        cbar = fig.colorbar(im, ax=axes, orientation='horizontal', 
                           fraction=0.046, pad=0.04, aspect=40)
        cbar.set_label('% Features with Significant Association', 
                      fontsize=12, fontweight='bold')
        
        plt.tight_layout()
        output_file = figures_dir / f'{output_prefix}atlas_{demo_var}.png'
        plt.savefig(output_file, dpi=300, bbox_inches='tight')
        plt.show()
        
        print(f"    ✓ Atlas heatmap saved: {output_file.name}")
        
    except Exception as e:
        print(f"    ⚠️  Atlas visualization failed: {e}")
        print(f"    Creating fallback bar plot...")
        create_fallback_barplot(regional_sig, demo_var, output_prefix, 
                               is_continuous, figures_dir)


def create_fallback_barplot(regional_sig, demo_var, output_prefix, 
                            is_continuous, figures_dir):
    """Fallback bar plot if atlas fails."""
    regions = sorted(regional_sig.keys())
    values = [regional_sig[r] for r in regions]
    
    fig, ax = plt.subplots(figsize=(14, 5))
    
    cmap = plt.get_cmap('YlOrRd')  # plt.cm.get_cmap was removed in Matplotlib 3.9
    norm = plt.Normalize(vmin=0, vmax=100)
    colors = [cmap(norm(v)) for v in values]
    
    bars = ax.bar(range(len(regions)), values, color=colors, 
                  edgecolor='black', linewidth=0.5)
    
    analysis_type = "Correlations" if is_continuous else "Group Differences"
    ax.set_xlabel('Region ID', fontsize=12, fontweight='bold')
    ax.set_ylabel('% Features Significant', fontsize=12, fontweight='bold')
    ax.set_title(f'{demo_var} - Regional {analysis_type}', 
                fontsize=14, fontweight='bold')
    ax.set_xticks(range(len(regions)))
    ax.set_xticklabels(regions, rotation=45, ha='right')
    ax.grid(True, alpha=0.3, axis='y')
    ax.set_ylim(0, 100)
    
    plt.tight_layout()
    output_file = figures_dir / f'{output_prefix}atlas_{demo_var}.png'
    plt.savefig(output_file, dpi=300, bbox_inches='tight')
    plt.show()
    
    print(f"    ✓ Bar plot saved: {output_file.name}")


# ================================================================
# 4. CREATE ATLAS-BASED HEATMAPS
# ================================================================
print("\n" + "-"*80)
print("4. CREATING ATLAS-BASED HEATMAPS")
print("-"*80)

# SET YOUR ATLAS PATH HERE
ATLAS_PATH = '/home/falcetta/ISBI2025/LIANE/ArterialAtlas.nii.gz'  # UPDATE THIS

# Check if atlas exists
atlas_available = Path(ATLAS_PATH).exists() if ATLAS_PATH else False
if atlas_available:
    print(f"✓ Atlas found at: {ATLAS_PATH}")
else:
    print(f"⚠️  Atlas not found. Will create bar plots instead.")
    print(f"   Set ATLAS_PATH to enable atlas visualization.")

# ----------------------------------------------------------------
# Heatmap 1: Continuous demographics (one atlas per variable)
# ----------------------------------------------------------------
if len(all_continuous_results) > 0 and len(continuous_vars) > 0:
    print("\n  Creating continuous demographics atlas heatmaps...")
    
    for demo_var in continuous_vars:
        print(f"\n  Processing {demo_var}...")
        create_atlas_heatmap_from_dataframe(
            atlas_path=ATLAS_PATH,
            results_df=continuous_results_df,
            demo_var=demo_var,
            is_continuous=True,
            output_prefix='regional_continuous_',
            figures_dir=FIGURES_DIR
        )
    
    print("\n  ✓ All continuous demographics atlas heatmaps complete")

# ----------------------------------------------------------------
# Heatmap 2: Categorical demographics (one atlas per variable)
# ----------------------------------------------------------------
if len(all_categorical_results) > 0 and len(categorical_vars) > 0:
    print("\n  Creating categorical demographics atlas heatmaps...")
    
    for demo_var in categorical_vars:
        print(f"\n  Processing {demo_var}...")
        create_atlas_heatmap_from_dataframe(
            atlas_path=ATLAS_PATH,
            results_df=categorical_results_df,
            demo_var=demo_var,
            is_continuous=False,
            output_prefix='regional_categorical_',
            figures_dir=FIGURES_DIR
        )
    
    print("\n  ✓ All categorical demographics atlas heatmaps complete")

# ----------------------------------------------------------------
# Optional: Combined heatmap averaging all demographics
# ----------------------------------------------------------------
print("\n  Creating combined atlas heatmap...")

# Combine all regional significance values
all_regional_sig = {}

# Process continuous variables
if len(all_continuous_results) > 0:
    for demo_var in continuous_vars:
        demo_results = continuous_results_df[continuous_results_df['demographic'] == demo_var]
        for region in demo_results['region'].unique():
            region_data = demo_results[demo_results['region'] == region]
            n_sig = region_data['pearson_significant'].sum()
            total = len(region_data)
            pct_sig = 100 * n_sig / total if total > 0 else 0
            
            if region not in all_regional_sig:
                all_regional_sig[region] = []
            all_regional_sig[region].append(pct_sig)

# Process categorical variables
if len(all_categorical_results) > 0:
    for demo_var in categorical_vars:
        demo_results = categorical_results_df[categorical_results_df['demographic'] == demo_var]
        for region in demo_results['region'].unique():
            region_data = demo_results[demo_results['region'] == region]
            n_sig = region_data['anova_significant'].sum()
            total = len(region_data)
            pct_sig = 100 * n_sig / total if total > 0 else 0
            
            if region not in all_regional_sig:
                all_regional_sig[region] = []
            all_regional_sig[region].append(pct_sig)

# Average across all demographics
avg_regional_sig = {region: np.mean(values) for region, values in all_regional_sig.items()}

# Create combined atlas
if atlas_available:
    try:
        import nibabel as nib
        
        print(f"  Loading atlas from {ATLAS_PATH}...")
        atlas_img = nib.load(ATLAS_PATH)
        atlas_data = atlas_img.get_fdata()
        
        value_map = np.zeros_like(atlas_data)
        for region_id, pct_sig in avg_regional_sig.items():
            mask = atlas_data == region_id
            if np.any(mask):
                value_map[mask] = pct_sig
        
        # Crop fixed margins (assumes a grid larger than 100 x 80 x 100 voxels) and prepare slices
        atlas_cropped = atlas_data[50:-50, 40:-40, 50:-50]
        value_cropped = value_map[50:-50, 40:-40, 50:-50]
        
        x, y, z = atlas_cropped.shape
        x_slice, y_slice, z_slice = x//2, y//2, z//2
        
        fig, axes = plt.subplots(1, 3, figsize=(18, 6))
        
        slices = {
            'Sagittal': (np.rot90(atlas_cropped[x_slice, :, :]), 
                        np.rot90(value_cropped[x_slice, :, :])),
            'Coronal': (np.rot90(atlas_cropped[:, y_slice, :]), 
                       np.rot90(value_cropped[:, y_slice, :])),
            'Axial': (np.rot90(atlas_cropped[:, :, z_slice]), 
                     np.rot90(value_cropped[:, :, z_slice]))
        }
        
        vmin, vmax = 0, 100
        for ax, (view_name, (atlas_sl, val_sl)) in zip(axes, slices.items()):
            im = plot_atlas_panel(ax, atlas_sl, val_sl, view_name, 
                                 cmap='YlOrRd', vmin=vmin, vmax=vmax)
        
        n_demos = len(continuous_vars) + len(categorical_vars)
        fig.suptitle(f'Combined Atlas - Average Across {n_demos} Demographics', 
                    fontsize=16, fontweight='bold', y=0.98)
        
        cbar = fig.colorbar(im, ax=axes, orientation='horizontal', 
                           fraction=0.046, pad=0.04, aspect=40)
        cbar.set_label('Average % Features with Significant Association', 
                      fontsize=12, fontweight='bold')
        
        plt.tight_layout()
        plt.savefig(FIGURES_DIR / 'regional_combined_atlas_heatmap.png', 
                   dpi=300, bbox_inches='tight')
        plt.show()
        
        print("    ✓ Combined atlas heatmap saved")
        
    except Exception as e:
        print(f"    ⚠️  Combined atlas failed: {e}")

print("\n" + "="*80)
print("✅ ATLAS-BASED HEATMAP GENERATION COMPLETE")
print("="*80)
print(f"\nGenerated outputs:")
if len(continuous_vars) > 0:
    print(f"  • {len(continuous_vars)} continuous demographics atlas heatmaps")
if len(categorical_vars) > 0:
    print(f"  • {len(categorical_vars)} categorical demographics atlas heatmaps")
if atlas_available:
    print(f"  • 1 combined atlas heatmap (average across all demographics)")
print(f"\nAll files saved to: {FIGURES_DIR}")
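
The atlas code above crops with fixed margins (`[50:-50, 40:-40, 50:-50]`) and paints values one region at a time. For an atlas with a different grid size, a crop derived from the labels themselves may be more robust. The sketch below runs on a small synthetic label volume rather than a real NIfTI file. `crop_to_labels` and `paint_regions` are hypothetical helper names, and non-negative integer region IDs are assumed.

```python
import numpy as np

def crop_to_labels(label_vol, margin=2):
    """Return slices bounding all non-zero labels, padded by `margin` voxels."""
    coords = np.argwhere(label_vol > 0)
    if coords.size == 0:
        raise ValueError("label volume contains no non-zero voxels")
    lo = np.maximum(coords.min(axis=0) - margin, 0)
    hi = np.minimum(coords.max(axis=0) + 1 + margin, label_vol.shape)
    return tuple(slice(int(a), int(b)) for a, b in zip(lo, hi))

def paint_regions(label_vol, region_values):
    """Map {region_id: value} onto a label volume via a lookup table."""
    # Labels loaded as floats are rounded back to integer IDs first
    labels = np.rint(label_vol).astype(int)
    lut = np.zeros(labels.max() + 1)
    for region_id, value in region_values.items():
        # IDs that do not occur in the volume are ignored
        if 0 < int(region_id) <= labels.max():
            lut[int(region_id)] = value
    return lut[labels]

# Synthetic 20x20x20 "atlas" with two labelled blocks
atlas = np.zeros((20, 20, 20))
atlas[5:8, 6:9, 7:10] = 1
atlas[10:12, 6:9, 7:10] = 2

box = crop_to_labels(atlas, margin=2)
values = paint_regions(atlas, {1: 40.0, 2: 75.0, 99: 10.0})
cropped = values[box]
print(cropped.shape)
```

Casting region IDs to integers before the lookup also sidesteps silent mismatches when the `region` column and the atlas array use different dtypes, in which case an equality mask would match nothing and the panel would stay empty.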

In [None]:
# ================================================================
# ENHANCED ATLAS-BASED HEATMAP WITH FEATURE IMPORTANCE ANALYSIS
# ================================================================
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path

# ================================================================
# FEATURE IMPORTANCE ANALYSIS FUNCTIONS
# ================================================================

def analyze_regional_feature_importance_continuous(results_df, demo_var, top_n=3):
    """
    Identify which features drive significance in each region for continuous demographics.
    
    Parameters:
    -----------
    results_df : pd.DataFrame
        continuous_results_df
    demo_var : str
        Demographic variable (e.g., 'AGE')
    top_n : int
        Number of top features to report per region
        
    Returns:
    --------
    pd.DataFrame with columns: region, rank, feature, r_value, abs_r, p_value,
    n_sig_in_region (empty if no correlation is significant)
    """
    # Filter for this demographic
    demo_results = results_df[results_df['demographic'] == demo_var].copy()
    
    # Only keep significant results
    demo_results = demo_results[demo_results['pearson_significant'] == True]
    
    if len(demo_results) == 0:
        return pd.DataFrame()
    
    # Rank features within each region by effect size (absolute correlation)
    demo_results['abs_r'] = demo_results['pearson_r'].abs()
    
    regional_importance = []
    
    for region in sorted(demo_results['region'].unique()):
        region_data = demo_results[demo_results['region'] == region].copy()
        
        # Sort by effect size (absolute correlation)
        region_data = region_data.sort_values('abs_r', ascending=False)
        
        # Get top N features
        for idx, row in enumerate(region_data.head(top_n).itertuples(), 1):
            regional_importance.append({
                'region': region,
                'rank': idx,
                'feature': row.feature,
                'r_value': row.pearson_r,
                'abs_r': row.abs_r,
                'p_value': row.pearson_p,  # raw p-value; significance itself uses the FDR-adjusted value
                'n_sig_in_region': len(region_data)
            })
    
    return pd.DataFrame(regional_importance)


def analyze_regional_feature_importance_categorical(results_df, demo_var, top_n=3):
    """
    Identify which features drive significance in each region for categorical demographics.
    
    Parameters:
    -----------
    results_df : pd.DataFrame
        categorical_results_df
    demo_var : str
        Demographic variable (e.g., 'SEX_ID')
    top_n : int
        Number of top features to report per region
        
    Returns:
    --------
    pd.DataFrame with columns: region, rank, feature, eta_squared, f_statistic,
    p_value, n_sig_in_region (empty if no group difference is significant)
    """
    # Filter for this demographic
    demo_results = results_df[results_df['demographic'] == demo_var].copy()
    
    # Only keep significant results
    demo_results = demo_results[demo_results['anova_significant'] == True]
    
    if len(demo_results) == 0:
        return pd.DataFrame()
    
    regional_importance = []
    
    for region in sorted(demo_results['region'].unique()):
        region_data = demo_results[demo_results['region'] == region].copy()
        
        # Sort by effect size (eta squared)
        region_data = region_data.sort_values('eta_squared', ascending=False)
        
        # Get top N features
        for idx, row in enumerate(region_data.head(top_n).itertuples(), 1):
            regional_importance.append({
                'region': region,
                'rank': idx,
                'feature': row.feature,
                'eta_squared': row.eta_squared,
                'f_statistic': row.f_statistic,
                'p_value': row.anova_p,
                'n_sig_in_region': len(region_data)
            })
    
    return pd.DataFrame(regional_importance)

def create_feature_importance_summary_continuous(results_df, demo_var, figures_dir):
    """Create visual summary of top features per region for continuous demographics."""
    
    importance_df = analyze_regional_feature_importance_continuous(results_df, demo_var, top_n=3)
    
    if len(importance_df) == 0:
        print(f"    No significant results for {demo_var}")
        return
    
    # Create text summary
    summary_lines = [
        f"\n{'='*80}",
        f"FEATURE IMPORTANCE ANALYSIS: {demo_var}",
        f"{'='*80}\n",
        f"Top features driving regional significance patterns:\n"
    ]
    
    # Group by region and create formatted summary
    regions_with_sig = importance_df['region'].unique()
    
    for region in sorted(regions_with_sig):
        region_data = importance_df[importance_df['region'] == region]
        n_sig = region_data.iloc[0]['n_sig_in_region']
        
        summary_lines.append(f"\n📍 REGION {region} ({n_sig} significant features total):")
        summary_lines.append("-" * 60)
        
        for _, row in region_data.iterrows():
            direction = "↑ Positive" if row['r_value'] > 0 else "↓ Negative"
            summary_lines.append(
                f"  {row['rank']}. {row['feature']:<35} "
                f"r={row['r_value']:>6.3f} {direction:>12} "
                f"(p={row['p_value']:.2e})"
            )
    
    # Print summary
    summary_text = '\n'.join(summary_lines)
    print(summary_text)
    
    # Save to file
    output_file = figures_dir / f'feature_importance_{demo_var}_detailed.txt'
    with open(output_file, 'w', encoding='utf-8') as f:  # summary contains non-ASCII symbols
        f.write(summary_text)
    
    print(f"\n✓ Detailed feature importance saved to: {output_file.name}")
    
    # Create visualization: Heatmap of top features by region
    create_feature_importance_heatmap_continuous(importance_df, demo_var, figures_dir)
    
    return importance_df


def create_feature_importance_summary_categorical(results_df, demo_var, figures_dir):
    """Create visual summary of top features per region for categorical demographics."""
    
    importance_df = analyze_regional_feature_importance_categorical(results_df, demo_var, top_n=3)
    
    if len(importance_df) == 0:
        print(f"    No significant results for {demo_var}")
        return
    
    # Create text summary
    summary_lines = [
        f"\n{'='*80}",
        f"FEATURE IMPORTANCE ANALYSIS: {demo_var}",
        f"{'='*80}\n",
        f"Top features driving regional group differences:\n"
    ]
    
    # Group by region
    regions_with_sig = importance_df['region'].unique()
    
    for region in sorted(regions_with_sig):
        region_data = importance_df[importance_df['region'] == region]
        n_sig = region_data.iloc[0]['n_sig_in_region']
        
        summary_lines.append(f"\n📍 REGION {region} ({n_sig} significant features total):")
        summary_lines.append("-" * 60)
        
        for _, row in region_data.iterrows():
            effect_size = "Large" if row['eta_squared'] > 0.14 else "Medium" if row['eta_squared'] > 0.06 else "Small"
            summary_lines.append(
                f"  {row['rank']}. {row['feature']:<35} "
                f"η²={row['eta_squared']:>6.3f} ({effect_size:>6}) "
                f"F={row['f_statistic']:>6.2f} (p={row['p_value']:.2e})"
            )
    
    # Print summary
    summary_text = '\n'.join(summary_lines)
    print(summary_text)
    
    # Save to file
    output_file = figures_dir / f'feature_importance_{demo_var}_detailed.txt'
    with open(output_file, 'w', encoding='utf-8') as f:  # summary contains non-ASCII symbols (η², 📍)
        f.write(summary_text)
    
    print(f"\n✓ Detailed feature importance saved to: {output_file.name}")
    
    # Create visualization
    create_feature_importance_heatmap_categorical(importance_df, demo_var, figures_dir)
    
    return importance_df


def create_feature_importance_heatmap_continuous(importance_df, demo_var, figures_dir):
    """Create heatmap showing which features are most important in each region."""
    
    # Pivot to get feature × region matrix with effect sizes
    pivot_data = importance_df.pivot_table(
        index='feature',
        columns='region',
        values='abs_r',
        aggfunc='first'  # Take the first (highest ranked) value if duplicates
    )
    
    # Fill NaN with 0 (feature not in top 3 for that region)
    pivot_data = pivot_data.fillna(0)
    
    # Create figure
    fig, ax = plt.subplots(figsize=(max(12, len(pivot_data.columns)*0.5), 
                                     max(8, len(pivot_data)*0.4)))
    
    # Create heatmap
    sns.heatmap(pivot_data, annot=True, fmt='.2f', cmap='YlOrRd', 
                cbar_kws={'label': '|Correlation Coefficient|'},
                linewidths=0.5, linecolor='gray', ax=ax,
                vmin=0, vmax=pivot_data.max().max())
    
    ax.set_title(f'{demo_var}: Top Feature Importance by Region\n(Absolute Correlation Coefficients)', 
                fontweight='bold', fontsize=14)
    ax.set_xlabel('Region ID', fontweight='bold', fontsize=12)
    ax.set_ylabel('Feature', fontweight='bold', fontsize=12)
    
    plt.tight_layout()
    
    output_file = figures_dir / f'feature_importance_heatmap_{demo_var}.png'
    plt.savefig(output_file, dpi=300, bbox_inches='tight')
    plt.show()
    
    print(f"✓ Feature importance heatmap saved: {output_file.name}")


def create_feature_importance_heatmap_categorical(importance_df, demo_var, figures_dir):
    """Create heatmap showing which features are most important in each region."""
    
    # Pivot to get feature × region matrix with effect sizes
    pivot_data = importance_df.pivot_table(
        index='feature',
        columns='region',
        values='eta_squared',
        aggfunc='first'
    )
    
    # Fill NaN with 0
    pivot_data = pivot_data.fillna(0)
    
    # Create figure
    fig, ax = plt.subplots(figsize=(max(12, len(pivot_data.columns)*0.5), 
                                     max(8, len(pivot_data)*0.4)))
    
    # Create heatmap
    sns.heatmap(pivot_data, annot=True, fmt='.3f', cmap='YlOrRd', 
                cbar_kws={'label': 'Eta Squared (η²)'},
                linewidths=0.5, linecolor='gray', ax=ax,
                vmin=0, vmax=pivot_data.max().max())
    
    ax.set_title(f'{demo_var}: Top Feature Importance by Region\n(Effect Sizes - η²)', 
                fontweight='bold', fontsize=14)
    ax.set_xlabel('Region ID', fontweight='bold', fontsize=12)
    ax.set_ylabel('Feature', fontweight='bold', fontsize=12)
    
    plt.tight_layout()
    
    output_file = figures_dir / f'feature_importance_heatmap_{demo_var}.png'
    plt.savefig(output_file, dpi=300, bbox_inches='tight')
    plt.show()
    
    print(f"✓ Feature importance heatmap saved: {output_file.name}")


def create_regional_feature_profile_plot(importance_df, demo_var, figures_dir, is_continuous=True):
    """
    Create bar plot showing distribution of significant features across regions.
    Identifies which regions have most/least feature diversity.
    """
    
    # Count significant features per region
    region_counts = importance_df.groupby('region')['n_sig_in_region'].first().sort_values(ascending=False)
    
    if len(region_counts) == 0:
        return
    
    fig, ax = plt.subplots(figsize=(14, 6))
    
    # Create colormap based on counts
    cmap = plt.get_cmap('YlOrRd')  # plt.cm.get_cmap was removed in recent Matplotlib releases
    norm = plt.Normalize(vmin=region_counts.min(), vmax=region_counts.max())
    colors = [cmap(norm(v)) for v in region_counts.values]
    
    bars = ax.bar(range(len(region_counts)), region_counts.values, color=colors,
                  edgecolor='black', linewidth=0.5)
    
    ax.set_xlabel('Region ID', fontsize=12, fontweight='bold')
    ax.set_ylabel('Number of Significant Features', fontsize=12, fontweight='bold')
    
    analysis_type = "Correlations" if is_continuous else "Group Differences"
    ax.set_title(f'{demo_var}: Regional Feature Diversity\nNumber of Significant {analysis_type} per Region',
                fontsize=14, fontweight='bold')
    
    ax.set_xticks(range(len(region_counts)))
    ax.set_xticklabels(region_counts.index, rotation=45, ha='right')
    ax.grid(True, alpha=0.3, axis='y')
    
    # Add value labels on bars
    for i, (bar, count) in enumerate(zip(bars, region_counts.values)):
        ax.text(bar.get_x() + bar.get_width()/2., bar.get_height(),
               f'{int(count)}', ha='center', va='bottom', fontsize=9)
    
    plt.tight_layout()
    
    output_file = figures_dir / f'regional_feature_diversity_{demo_var}.png'
    plt.savefig(output_file, dpi=300, bbox_inches='tight')
    plt.show()
    
    print(f"✓ Regional feature diversity plot saved: {output_file.name}")


def create_overall_feature_ranking(all_importance_dfs, demo_var, figures_dir, is_continuous=True):
    """
    Create overall ranking of features across all regions.
    Shows which features are consistently important.
    """
    
    if len(all_importance_dfs) == 0:
        return
    
    # Count how many regions each feature appears in (top 3)
    feature_frequency = all_importance_dfs['feature'].value_counts()
    
    # Calculate average effect size for each feature
    if is_continuous:
        feature_avg_effect = all_importance_dfs.groupby('feature')['abs_r'].mean()
    else:
        feature_avg_effect = all_importance_dfs.groupby('feature')['eta_squared'].mean()
    
    # Combine into summary
    feature_summary = pd.DataFrame({
        'frequency': feature_frequency,
        'avg_effect': feature_avg_effect
    }).sort_values('frequency', ascending=False)
    
    # Create figure with two subplots
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6))
    
    # Plot 1: Frequency (in how many regions is this feature important?)
    top_features = feature_summary.head(10)
    
    ax1.barh(range(len(top_features)), top_features['frequency'].values, 
            color='steelblue', edgecolor='black', linewidth=0.5)
    ax1.set_yticks(range(len(top_features)))
    ax1.set_yticklabels(top_features.index)
    ax1.set_xlabel('Number of Regions (Top 3)', fontsize=11, fontweight='bold')
    ax1.set_title('Most Frequently Important Features', fontsize=12, fontweight='bold')
    ax1.grid(True, alpha=0.3, axis='x')
    ax1.invert_yaxis()
    
    # Add value labels
    for i, v in enumerate(top_features['frequency'].values):
        ax1.text(v, i, f' {int(v)}', va='center', fontsize=9)
    
    # Plot 2: Average effect size
    effect_label = 'Average |r|' if is_continuous else 'Average η²'
    
    ax2.barh(range(len(top_features)), top_features['avg_effect'].values,
            color='coral', edgecolor='black', linewidth=0.5)
    ax2.set_yticks(range(len(top_features)))
    ax2.set_yticklabels(top_features.index)
    ax2.set_xlabel(effect_label, fontsize=11, fontweight='bold')
    ax2.set_title('Average Effect Size', fontsize=12, fontweight='bold')
    ax2.grid(True, alpha=0.3, axis='x')
    ax2.invert_yaxis()
    
    # Add value labels
    for i, v in enumerate(top_features['avg_effect'].values):
        ax2.text(v, i, f' {v:.3f}', va='center', fontsize=9)
    
    fig.suptitle(f'{demo_var}: Overall Feature Importance Ranking', 
                fontsize=14, fontweight='bold')
    
    plt.tight_layout()
    
    output_file = figures_dir / f'overall_feature_ranking_{demo_var}.png'
    plt.savefig(output_file, dpi=300, bbox_inches='tight')
    plt.show()
    
    print(f"✓ Overall feature ranking saved: {output_file.name}")
    
    # Save summary table
    summary_file = figures_dir / f'overall_feature_ranking_{demo_var}.csv'
    feature_summary.to_csv(summary_file)
    print(f"✓ Feature ranking table saved: {summary_file.name}")


# ================================================================
# 5. FEATURE IMPORTANCE ANALYSIS FOR ALL DEMOGRAPHICS
# ================================================================
print("\n" + "="*80)
print("5. FEATURE IMPORTANCE ANALYSIS")
print("="*80)
print("\nAnalyzing which features drive regional significance patterns...\n")

# ----------------------------------------------------------------
# Continuous Demographics
# ----------------------------------------------------------------
if len(all_continuous_results) > 0 and len(continuous_vars) > 0:
    print("\n" + "-"*80)
    print("CONTINUOUS DEMOGRAPHICS - FEATURE IMPORTANCE")
    print("-"*80)
    
    all_continuous_importance = {}
    
    for demo_var in continuous_vars:
        print(f"\n{'='*80}")
        print(f"Analyzing: {demo_var}")
        print('='*80)
        
        # Get detailed feature importance
        importance_df = create_feature_importance_summary_continuous(
            continuous_results_df, demo_var, FIGURES_DIR
        )
        
        if importance_df is not None and len(importance_df) > 0:
            all_continuous_importance[demo_var] = importance_df
            
            # Create regional feature profile
            create_regional_feature_profile_plot(
                importance_df, demo_var, FIGURES_DIR, is_continuous=True
            )
            
            # Create overall feature ranking
            create_overall_feature_ranking(
                importance_df, demo_var, FIGURES_DIR, is_continuous=True
            )

# ----------------------------------------------------------------
# Categorical Demographics
# ----------------------------------------------------------------
if len(all_categorical_results) > 0 and len(categorical_vars) > 0:
    print("\n" + "-"*80)
    print("CATEGORICAL DEMOGRAPHICS - FEATURE IMPORTANCE")
    print("-"*80)
    
    all_categorical_importance = {}
    
    for demo_var in categorical_vars:
        print(f"\n{'='*80}")
        print(f"Analyzing: {demo_var}")
        print('='*80)
        
        # Get detailed feature importance
        importance_df = create_feature_importance_summary_categorical(
            categorical_results_df, demo_var, FIGURES_DIR
        )
        
        if importance_df is not None and len(importance_df) > 0:
            all_categorical_importance[demo_var] = importance_df
            
            # Create regional feature profile
            create_regional_feature_profile_plot(
                importance_df, demo_var, FIGURES_DIR, is_continuous=False
            )
            
            # Create overall feature ranking
            create_overall_feature_ranking(
                importance_df, demo_var, FIGURES_DIR, is_continuous=False
            )

print("\n" + "="*80)
print("✅ FEATURE IMPORTANCE ANALYSIS COMPLETE")
print("="*80)
print("\nGenerated outputs for each demographic variable:")
print("  • Detailed text summary (feature_importance_[VAR]_detailed.txt)")
print("  • Feature importance heatmap (feature_importance_heatmap_[VAR].png)")
print("  • Regional feature diversity plot (regional_feature_diversity_[VAR].png)")
print("  • Overall feature ranking plot (overall_feature_ranking_[VAR].png)")
print("  • Feature ranking table (overall_feature_ranking_[VAR].csv)")
print(f"\nAll files saved to: {FIGURES_DIR}")
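
The overall ranking produced above combines two quantities per feature: the number of regions in which it appears among the top features, and its mean effect size across those regions. A minimal sketch on a toy table follows; the feature names, region IDs, and `abs_r` values are invented for illustration and are not results from this dataset.

```python
import pandas as pd

# Toy importance table: one row per (region, top-ranked feature).
# All values below are invented for illustration only.
toy_importance = pd.DataFrame({
    'region':  [1, 1, 2, 2, 3],
    'feature': ['mean_radius', 'total_length', 'mean_radius', 'tortuosity', 'mean_radius'],
    'abs_r':   [0.40, 0.30, 0.20, 0.25, 0.30],
})

# Frequency: in how many regions does each feature appear?
frequency = toy_importance['feature'].value_counts()

# Average effect size per feature
avg_effect = toy_importance.groupby('feature')['abs_r'].mean()

# Both Series are indexed by feature name, so they align when combined
toy_summary = pd.DataFrame({'frequency': frequency, 'avg_effect': avg_effect})
toy_summary = toy_summary.sort_values('frequency', ascending=False)
print(toy_summary)
```

Sorting by frequency alone can rank a feature with many weak effects above one with a few strong effects, which is why the two panels are shown side by side.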

In [None]:
# ================================================================
# 5. CREATE FEATURE-SPECIFIC REGIONAL MAPS
# ================================================================
print("\n" + "-"*80)
print("5. CREATING FEATURE-SPECIFIC REGIONAL SIGNIFICANCE MAPS")
print("-"*80)

# For top 5 features, show which regions have significant demographic effects
if len(all_continuous_results) > 0 or len(all_categorical_results) > 0:
    
    # Get top features by total number of significant findings
    feature_sig_counts = {}
    
    if len(all_continuous_results) > 0:
        for feature in ALL_FEATURES:
            if 'region' in feature:
                continue
            cont_count = continuous_results_df[
                (continuous_results_df['feature'] == feature) & 
                (continuous_results_df['pearson_significant'])
            ].shape[0]
            feature_sig_counts[feature] = feature_sig_counts.get(feature, 0) + cont_count
    
    if len(all_categorical_results) > 0:
        for feature in ALL_FEATURES:
            if 'region' in feature:
                continue
            cat_count = categorical_results_df[
                (categorical_results_df['feature'] == feature) & 
                (categorical_results_df['anova_significant'])
            ].shape[0]
            feature_sig_counts[feature] = feature_sig_counts.get(feature, 0) + cat_count
    
    top_features = sorted(feature_sig_counts.items(), key=lambda x: x[1], reverse=True)[:5]
    
    print(f"\n  Top 5 features by demographic significance:")
    for feat, count in top_features:
        print(f"    {feat}: {count} significant region-demographic combinations")
    
    for feature, _ in top_features:
        if 'region' in feature:
            continue
        print(f"\n  Creating regional significance map for {feature}...")
        
        # Collect significance data by region
        region_sig_data = {}
        
        for region in regions:
            sig_demos = []
            
            # Check continuous variables
            if len(all_continuous_results) > 0:
                cont_region = continuous_results_df[
                    (continuous_results_df['feature'] == feature) & 
                    (continuous_results_df['region'] == region) & 
                    (continuous_results_df['pearson_significant'])
                ]
                sig_demos.extend(cont_region['demographic'].tolist())
            
            # Check categorical variables
            if len(all_categorical_results) > 0:
                cat_region = categorical_results_df[
                    (categorical_results_df['feature'] == feature) & 
                    (categorical_results_df['region'] == region) & 
                    (categorical_results_df['anova_significant'])
                ]
                sig_demos.extend(cat_region['demographic'].tolist())
            
            region_sig_data[region] = len(set(sig_demos))  # Count unique demographics
        
        # Create bar plot
        fig, ax = plt.subplots(figsize=(14, max(6, len(regions)*0.25)))
        
        sorted_regions = sorted(region_sig_data.keys())
        counts = [region_sig_data[r] for r in sorted_regions]
        
        bars = ax.barh(range(len(sorted_regions)), counts, 
                        color='steelblue', edgecolor='black', alpha=0.7)
        ax.set_yticks(range(len(sorted_regions)))
        ax.set_yticklabels([f'Region {r}' for r in sorted_regions], fontsize=8)
        ax.set_xlabel('Number of Demographics with Significant Effect', fontweight='bold')
        ax.set_title(f'{feature} - Regional Demographic Sensitivity', 
                    fontweight='bold', fontsize=14)
        ax.grid(True, alpha=0.3, axis='x')
        
        # Color bars by count
        max_count = max(counts, default=0) or 1  # avoid division by zero when no region is significant
        for bar, count in zip(bars, counts):
            bar.set_color(plt.cm.YlOrRd(count / max_count))
        
        plt.tight_layout()
        plt.savefig(FIGURES_DIR / f'regional_demographic_sensitivity_{feature}.png', 
                    dpi=300, bbox_inches='tight')
        plt.show()
        print(f"    ✓ Saved regional sensitivity map for {feature}")


In [None]:

# ================================================================
# 6. COMPREHENSIVE SUMMARY TABLE
# ================================================================
print("\n" + "-"*80)
print("6. CREATING COMPREHENSIVE SUMMARY")
print("-"*80)

summary_rows = []

# Summarize each feature
for feature in ALL_FEATURES:
    if 'region' in feature:
        continue
    row = {
        'feature': feature,
        'category': feature_categories[feature]
    }
    
    # Continuous demographics
    if len(all_continuous_results) > 0:
        feat_cont = continuous_results_df[continuous_results_df['feature'] == feature]
        for demo_var in continuous_vars:
            demo_feat = feat_cont[feat_cont['demographic'] == demo_var]
            if len(demo_feat) > 0:
                n_sig = demo_feat['pearson_significant'].sum()
                pct_sig = 100 * n_sig / len(demo_feat)
                max_r = demo_feat['pearson_r'].abs().max()
                row[f'{demo_var}_n_regions_sig'] = n_sig
                row[f'{demo_var}_pct_regions_sig'] = pct_sig
                row[f'{demo_var}_max_abs_r'] = max_r
    
    # Categorical demographics
    if len(all_categorical_results) > 0:
        feat_cat = categorical_results_df[categorical_results_df['feature'] == feature]
        for demo_var in categorical_vars:
            demo_feat = feat_cat[feat_cat['demographic'] == demo_var]
            if len(demo_feat) > 0:
                n_sig = demo_feat['anova_significant'].sum()
                pct_sig = 100 * n_sig / len(demo_feat)
                max_eta = demo_feat['eta_squared'].max()
                row[f'{demo_var}_n_regions_sig'] = n_sig
                row[f'{demo_var}_pct_regions_sig'] = pct_sig
                row[f'{demo_var}_max_eta_squared'] = max_eta
    
    summary_rows.append(row)

summary_df = pd.DataFrame(summary_rows)
summary_df.to_csv(TABLES_DIR / 'regional_demographics_summary.csv', index=False)
print(f"✓ Comprehensive summary saved to {TABLES_DIR / 'regional_demographics_summary.csv'}")

# Display top features
print(f"\nTop features by average demographic sensitivity:")
sig_cols = [c for c in summary_df.columns if '_pct_regions_sig' in c]
if len(sig_cols) > 0:
    summary_df['avg_pct_sig'] = summary_df[sig_cols].mean(axis=1)
    top_summary = summary_df.nlargest(10, 'avg_pct_sig')[['feature', 'category', 'avg_pct_sig'] + sig_cols]
    display(top_summary)
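
The per-feature loop above can also be written as a single grouped aggregation. The sketch below assumes a results table with the same `feature`, `demographic`, `region`, `pearson_r`, and `pearson_significant` columns; the rows are invented for illustration.

```python
import pandas as pd

# Invented results: 2 features x 2 demographics x 2 regions
toy_results = pd.DataFrame({
    'feature':     ['mean_radius'] * 4 + ['tortuosity'] * 4,
    'demographic': ['AGE', 'AGE', 'BMI', 'BMI'] * 2,
    'region':      [1, 2, 1, 2] * 2,
    'pearson_r':   [0.35, -0.42, 0.05, 0.10, 0.02, 0.31, -0.04, 0.08],
    'pearson_significant': [True, True, False, False, False, True, False, False],
})

toy_summary = (
    toy_results
    .assign(abs_r=toy_results['pearson_r'].abs())
    .groupby(['feature', 'demographic'])
    .agg(n_regions_sig=('pearson_significant', 'sum'),
         pct_regions_sig=('pearson_significant', lambda s: 100 * s.mean()),
         max_abs_r=('abs_r', 'max'))
    .unstack('demographic')  # one column per (statistic, demographic) pair
)
print(toy_summary)
```

The result has a two-level column index (statistic, demographic) rather than the flat `AGE_n_regions_sig` style names used in the loop, so the columns would need flattening before being written to the same CSV layout.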



In [None]:
# ================================================================
# FINAL SUMMARY
# ================================================================
print("\n" + "="*80)
print("✅ STATISTICAL TESTING COMPLETE")
print("="*80)
print("\nGenerated outputs:")
print(f"  • Continuous demographics tests: {TABLES_DIR / 'regional_continuous_demographics_tests.csv'}")
print(f"  • Categorical demographics tests: {TABLES_DIR / 'regional_categorical_demographics_tests.csv'}")
print(f"  • Comprehensive summary: {TABLES_DIR / 'regional_demographics_summary.csv'}")
print(f"  • Significance heatmaps (by feature and demographic)")
print(f"  • Regional sensitivity maps (top 5 features)")

if len(all_continuous_results) > 0:
    total_sig_cont = continuous_results_df['pearson_significant'].sum()
    total_tests_cont = len(continuous_results_df)
    print(f"\n📊 CONTINUOUS DEMOGRAPHICS:")
    print(f"  Total significant correlations: {total_sig_cont} / {total_tests_cont} ({100*total_sig_cont/total_tests_cont:.1f}%)")

if len(all_categorical_results) > 0:
    total_sig_cat = categorical_results_df['anova_significant'].sum()
    total_tests_cat = len(categorical_results_df)
    print(f"\n📊 CATEGORICAL DEMOGRAPHICS:")
    print(f"  Total significant group differences: {total_sig_cat} / {total_tests_cat} ({100*total_sig_cat/total_tests_cat:.1f}%)")

print("\n💡 KEY INSIGHTS:")
print("  • FDR is controlled at 5% within each demographic variable")
print("  • Both parametric (Pearson/ANOVA) and non-parametric (Spearman/Kruskal-Wallis) tests were run")
print("  • Effect sizes included (Pearson r, η², Cohen's d)")
print("  • Results show feature-specific regional sensitivity to demographics")
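
The FDR note above refers to applying the Benjamini-Hochberg correction separately to the tests of each demographic variable. A minimal sketch of that grouping follows; the column name `pearson_p` and the p-values are assumptions for illustration, and the helper `fdr_within_group` is hypothetical.

```python
import pandas as pd
from statsmodels.stats.multitest import multipletests

# Invented raw p-values: four tests per demographic variable
toy_tests = pd.DataFrame({
    'demographic': ['AGE'] * 4 + ['BMI'] * 4,
    'pearson_p':   [0.001, 0.010, 0.030, 0.800, 0.040, 0.200, 0.500, 0.900],
})

def fdr_within_group(pvalues, alpha=0.05):
    """Benjamini-Hochberg adjusted p-values for one group of tests."""
    adjusted = multipletests(pvalues, alpha=alpha, method='fdr_bh')[1]
    return pd.Series(adjusted, index=pvalues.index)

# Correct each demographic's family of tests independently
toy_tests['pearson_p_fdr'] = (
    toy_tests.groupby('demographic')['pearson_p'].transform(fdr_within_group)
)
toy_tests['pearson_significant'] = toy_tests['pearson_p_fdr'] < 0.05
print(toy_tests)
```

In this toy table the smallest BMI p-value (0.04) is below 0.05 before correction but not after, since its adjusted value is 0.04 × 4 / 1 = 0.16. Correcting within each demographic controls the FDR per family of tests, not across the whole analysis.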

---
## 11. Hemispheric Asymmetry Analysis (if applicable)

In [None]:
if not HAS_HEMISPHERE:
    print("\n⚠️  No hemisphere information available. Skipping hemispheric asymmetry analysis.")
else:
    print("\n" + "="*80)
    print("HEMISPHERIC ASYMMETRY ANALYSIS")
    print("="*80)
    print("\n⚠️  Hemisphere analysis requires manual implementation based on your specific data structure.")
    print("     Please refer to the IXI_Vessel_Advanced_Analysis.ipynb for detailed hemisphere analysis code.")
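
One common way to quantify hemispheric asymmetry is the asymmetry index `AI = (L - R) / (L + R)`, tested against zero across subjects. The sketch below is not the implementation from the referenced notebook. It assumes a hypothetical long-format table with `subject_id`, `hemisphere` (`'L'`/`'R'`), and one feature column, and it uses synthetic values.

```python
import numpy as np
import pandas as pd
from scipy.stats import ttest_1samp

rng = np.random.default_rng(0)
n_subjects = 30

# Synthetic long-format data: one row per subject and hemisphere (hypothetical layout)
toy_hemi = pd.DataFrame({
    'subject_id': np.repeat(np.arange(n_subjects), 2),
    'hemisphere': np.tile(['L', 'R'], n_subjects),
    'total_length': rng.normal(loc=100.0, scale=10.0, size=2 * n_subjects),
})

# Wide format: one column per hemisphere
wide = toy_hemi.pivot(index='subject_id', columns='hemisphere', values='total_length')

# Asymmetry index; for positive-valued features it lies in (-1, 1), positive means left > right
asymmetry_index = (wide['L'] - wide['R']) / (wide['L'] + wide['R'])

# One-sample t-test of the index against zero (no asymmetry)
t_stat, p_value = ttest_1samp(asymmetry_index.dropna(), popmean=0.0)
print(f"Mean AI = {asymmetry_index.mean():.3f}, t = {t_stat:.2f}, p = {p_value:.3f}")
```

With real data, the same index would be computed per feature and region pair, and the resulting p-values would need multiple-comparison correction as in the regional analyses.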

---
## 12. Advanced Analyses

### 12.1 Machine Learning: Age Prediction

In [None]:
if 'AGE' not in df.columns:
    print("\n⚠️  AGE variable not available. Skipping ML age prediction.")
else:
    print("\n" + "="*80)
    print("MACHINE LEARNING: AGE PREDICTION")
    print("="*80)
    
    # Prepare data - remove rows with missing age
    valid_data = df[['AGE'] + ALL_FEATURES].dropna(subset=['AGE'])
    
    # Check if we have enough data
    if len(valid_data) < 20:
        print(f"\n⚠️  Insufficient data for ML ({len(valid_data)} subjects with valid age). Skipping.")
    else:
        print(f"\nUsing {len(valid_data)} subjects with valid age")
        
        # Prepare features and target
        X = valid_data[ALL_FEATURES].fillna(0)  # Impute missing feature values with 0
        y = valid_data['AGE']
        
        # Check for features with no variance
        feature_variance = X.var()
        zero_variance_features = feature_variance[feature_variance == 0].index.tolist()
        if len(zero_variance_features) > 0:
            print(f"\n⚠️  Removing {len(zero_variance_features)} zero-variance features")
            X = X.drop(columns=zero_variance_features)
            features_for_ml = [f for f in ALL_FEATURES if f not in zero_variance_features]
        else:
            features_for_ml = ALL_FEATURES
        
        print(f"Using {len(features_for_ml)} features for prediction")
        
        # Standardize features
        scaler = StandardScaler()
        X_scaled = scaler.fit_transform(X)
        
        # Random Forest age prediction with cross-validation
        print("\nTraining Random Forest regressor...")
        rf = RandomForestRegressor(n_estimators=100, max_depth=10, random_state=42, n_jobs=-1)
        
        # 5-fold cross-validation
        cv_scores = cross_val_score(rf, X_scaled, y, cv=5, scoring='r2')
        cv_mae = -cross_val_score(rf, X_scaled, y, cv=5, scoring='neg_mean_absolute_error')
        
        print(f"\nCross-validation results (5-fold):")
        print(f"  R² score: {cv_scores.mean():.3f} ± {cv_scores.std():.3f}")
        print(f"  MAE: {cv_mae.mean():.2f} ± {cv_mae.std():.2f} years")
        
        # Fit on all data for feature importances; y_pred holds out-of-fold predictions for the plot
        rf.fit(X_scaled, y)
        y_pred = cross_val_predict(rf, X_scaled, y, cv=5)
        
        # Feature importance
        feature_importance = pd.DataFrame({
            'feature': features_for_ml,
            'importance': rf.feature_importances_
        }).sort_values('importance', ascending=False)
        
        print(f"\nTop 10 most important features for age prediction:")
        display(feature_importance.head(10))
        
        feature_importance.to_csv(TABLES_DIR / 'age_prediction_feature_importance.csv', index=False)
        
        # Plot predictions vs actual
        fig, ax = plt.subplots(figsize=(8, 8))
        ax.scatter(y, y_pred, alpha=0.5, s=30)
        ax.plot([y.min(), y.max()], [y.min(), y.max()], 'r--', lw=2)
        ax.set_xlabel('Actual Age (years)', fontweight='bold')
        ax.set_ylabel('Predicted Age (years)', fontweight='bold')
        ax.set_title(f'Age Prediction from Vessel Features\nR²={cv_scores.mean():.3f}, MAE={cv_mae.mean():.1f} years', 
                     fontweight='bold')
        ax.grid(True, alpha=0.3)
        
        # Add correlation coefficient
        correlation = np.corrcoef(y, y_pred)[0, 1]
        ax.text(0.05, 0.95, f'Pearson r = {correlation:.3f}', 
                transform=ax.transAxes, fontsize=10, 
                verticalalignment='top', bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))
        
        plt.tight_layout()
        plt.savefig(FIGURES_DIR / 'age_prediction.png', dpi=300, bbox_inches='tight')
        plt.show()
        
        print(f"\n✓ Age prediction analysis complete")
        print(f"  Subjects used: {len(valid_data)}")
        print(f"  Features used: {len(features_for_ml)}")
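
When preprocessing such as imputation and scaling is fitted on the full dataset before cross-validation, information from the held-out folds can leak into training. One way to avoid this is to wrap the steps in a scikit-learn `Pipeline` so that they are refitted inside each fold. The sketch below uses synthetic data and median imputation, both of which are assumptions for illustration and differ from the zero-imputation used above.

```python
import numpy as np
from sklearn.ensemble import RandomForestRegressor
from sklearn.impute import SimpleImputer
from sklearn.model_selection import KFold, cross_validate
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

rng = np.random.default_rng(42)
X_toy = rng.normal(size=(120, 5))
y_toy = 50 + 10 * X_toy[:, 0] + rng.normal(scale=2.0, size=120)  # synthetic "age"
X_toy[rng.random(X_toy.shape) < 0.05] = np.nan                    # add some missing values

# Imputer and scaler are refitted on the training part of every fold
pipeline = Pipeline([
    ('impute', SimpleImputer(strategy='median')),
    ('scale', StandardScaler()),
    ('rf', RandomForestRegressor(n_estimators=50, max_depth=10, random_state=42)),
])

# One pass over the folds yields both metrics
cv = KFold(n_splits=5, shuffle=True, random_state=42)
cv_results = cross_validate(pipeline, X_toy, y_toy, cv=cv,
                            scoring=('r2', 'neg_mean_absolute_error'))

r2_scores = cv_results['test_r2']
mae_scores = -cv_results['test_neg_mean_absolute_error']
print(f"R²: {r2_scores.mean():.3f} ± {r2_scores.std():.3f}")
print(f"MAE: {mae_scores.mean():.2f} ± {mae_scores.std():.2f}")
```

Tree ensembles are insensitive to monotone rescaling of individual features, so the `scale` step mainly matters if the regressor is swapped for a linear or kernel model.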

### 12.2 Age × Sex Interaction Effects

In [None]:
if 'AGE' not in df.columns or 'SEX_ID' not in df.columns:
    print("\n⚠️  AGE or SEX_ID not available. Skipping interaction analysis.")
else:
    print("\n" + "="*80)
    print("AGE × SEX INTERACTION EFFECTS")
    print("="*80)
    
    interaction_results = []
    
    for feature in ALL_FEATURES:
        # Prepare data
        data_for_model = df[['AGE', 'SEX_ID', feature]].dropna()
        
        if len(data_for_model) < 20:
            continue
        
        try:
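            # Fit model with interaction; Q() quotes feature names that are not valid Python identifiers
            model = smf.ols(f'Q("{feature}") ~ AGE + C(SEX_ID) + AGE:C(SEX_ID)', data=data_for_model).fit()
            
            # Get interaction term p-value; the level label depends on the SEX_ID dtype (e.g. [T.2] vs [T.2.0])
            interaction_terms = [t for t in model.pvalues.index if t.startswith('AGE:C(SEX_ID)')]
            interaction_pval = model.pvalues[interaction_terms[0]] if interaction_terms else np.nan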
            
            interaction_results.append({
                'feature': feature,
                'category': feature_categories[feature],
                'n': len(data_for_model),
                'interaction_pvalue': interaction_pval,
                'model_r_squared': model.rsquared
            })
        except Exception:  # skip features whose model cannot be fitted, but let KeyboardInterrupt through
            continue
    
    if len(interaction_results) > 0:
        interaction_df = pd.DataFrame(interaction_results)
        interaction_df = interaction_df.dropna(subset=['interaction_pvalue'])
        
        if len(interaction_df) > 0:
            interaction_df['interaction_pvalue_fdr'] = multipletests(interaction_df['interaction_pvalue'], method='fdr_bh')[1]
            interaction_df['significant'] = interaction_df['interaction_pvalue_fdr'] < 0.05
            interaction_df = interaction_df.sort_values('interaction_pvalue')
            
            n_sig = interaction_df['significant'].sum()
            print(f"\nInteraction analysis summary:")
            print(f"  Features tested: {len(interaction_df)}")
            print(f"  Significant Age×Sex interactions (FDR<0.05): {n_sig}")
            
            if n_sig > 0:
                print(f"\nFeatures with significant Age×Sex interaction:")
                display(interaction_df[interaction_df['significant']][['feature', 'category', 'interaction_pvalue_fdr', 'model_r_squared']])
            
            interaction_df.to_csv(TABLES_DIR / 'age_sex_interactions.csv', index=False)
            print(f"\n✓ Interaction analysis saved to {TABLES_DIR / 'age_sex_interactions.csv'}")
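
The label of the interaction coefficient in a statsmodels formula fit depends on how the categorical levels are stored, for example `[T.2]` for integer codes and `[T.2.0]` for floats. The sketch below fits the same model on synthetic data and locates the term by prefix instead of a fixed label. The feature name `mean_radius` and the simulated slopes are assumptions for illustration.

```python
import numpy as np
import pandas as pd
import statsmodels.formula.api as smf

rng = np.random.default_rng(1)
n = 200
toy = pd.DataFrame({
    'AGE': rng.uniform(20, 80, size=n),
    'SEX_ID': rng.choice([1.0, 2.0], size=n),  # floats, as happens when the column contained NaNs
})

# Simulate a steeper age-related decline for SEX_ID == 2
slope = np.where(toy['SEX_ID'] == 2.0, -0.05, -0.01)
toy['mean_radius'] = 5.0 + slope * toy['AGE'] + rng.normal(scale=0.3, size=n)

model = smf.ols('Q("mean_radius") ~ AGE + C(SEX_ID) + AGE:C(SEX_ID)', data=toy).fit()

# Find the interaction coefficient without hard-coding the level label
interaction_terms = [t for t in model.pvalues.index if t.startswith('AGE:C(SEX_ID)')]
interaction_pval = model.pvalues[interaction_terms[0]] if interaction_terms else np.nan
print(interaction_terms, interaction_pval)
```

With more than two levels of the categorical variable there would be several interaction coefficients, and a joint test of all of them would be more appropriate than reading a single p-value.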

---
## 13. Summary and Export for Paper

In [None]:
print("\n" + "="*80)
print("COMPREHENSIVE ANALYSIS SUMMARY FOR PAPER")
print("="*80)

# Helper function to convert numpy/pandas types to Python native types
def convert_to_native(obj):
    """Convert numpy/pandas types to native Python types for JSON serialization."""
    if isinstance(obj, (np.integer, np.int64, np.int32)):
        return int(obj)
    elif isinstance(obj, (np.floating, np.float64, np.float32)):
        return float(obj)
    elif isinstance(obj, np.ndarray):
        return obj.tolist()
    elif pd.isna(obj):
        return None
    return obj

summary = {
    'dataset': {
        'total_subjects': int(len(df)),
        'n_regions': int(regional_df[region_col].nunique()) if 'regional_df' in locals() and 'region_col' in locals() else 'N/A',
        'age_range': f"{df['AGE'].min():.1f}-{df['AGE'].max():.1f}" if 'AGE' in df.columns and df['AGE'].notna().any() else 'N/A',
        'age_mean': float(df['AGE'].mean()) if 'AGE' in df.columns and df['AGE'].notna().any() else 'N/A',
        'age_std': float(df['AGE'].std()) if 'AGE' in df.columns and df['AGE'].notna().any() else 'N/A',
        'male_subjects': int(len(df[df['SEX_ID']==1])) if 'SEX_ID' in df.columns else 'N/A',
        'female_subjects': int(len(df[df['SEX_ID']==2])) if 'SEX_ID' in df.columns else 'N/A',
        'n_sites': int(df['site'].nunique()) if 'site' in df.columns else 'N/A'
    },
    'features': {
        'total_features': int(len(ALL_FEATURES)),
        'morphometric': int(len(MORPHOMETRIC_FEATURES)),
        'topological': int(len(TOPOLOGICAL_FEATURES)),
        'curvature': int(len(CURVATURE_FEATURES)),
        'other': int(len(OTHER_FEATURES))
    },
    'regional_analysis': {},
    'key_findings': {}
}

# Regional demographic associations
if 'continuous_results_df' in locals() and len(continuous_results_df) > 0:
    summary['regional_analysis']['continuous_demographics'] = {}
    for demo_var in continuous_vars:
        demo_data = continuous_results_df[continuous_results_df['demographic'] == demo_var]
        n_sig = demo_data['pearson_significant'].sum()
        total_tests = len(demo_data)
        
        summary['regional_analysis']['continuous_demographics'][demo_var] = {
            'total_tests': int(total_tests),
            'significant_associations': int(n_sig),
            'percent_significant': float(100 * n_sig / total_tests) if total_tests > 0 else 0.0,
            'n_regions_with_effects': int(demo_data[demo_data['pearson_significant']]['region'].nunique()),
            'strongest_correlation': {
                'feature': str(demo_data.loc[demo_data['pearson_r'].abs().idxmax(), 'feature']) if n_sig > 0 else 'N/A',
                'region': int(demo_data.loc[demo_data['pearson_r'].abs().idxmax(), 'region']) if n_sig > 0 else 'N/A',
                'r': float(demo_data.loc[demo_data['pearson_r'].abs().idxmax(), 'pearson_r']) if n_sig > 0 else 'N/A',
                'p_fdr': float(demo_data.loc[demo_data['pearson_r'].abs().idxmax(), 'pearson_p_fdr']) if n_sig > 0 else 'N/A'
            }
        }

if 'categorical_results_df' in locals() and len(categorical_results_df) > 0:
    summary['regional_analysis']['categorical_demographics'] = {}
    for demo_var in categorical_vars:
        demo_data = categorical_results_df[categorical_results_df['demographic'] == demo_var]
        n_sig = demo_data['anova_significant'].sum()
        total_tests = len(demo_data)
        
        summary['regional_analysis']['categorical_demographics'][demo_var] = {
            'total_tests': int(total_tests),
            'significant_differences': int(n_sig),
            'percent_significant': float(100 * n_sig / total_tests) if total_tests > 0 else 0.0,
            'n_regions_with_effects': int(demo_data[demo_data['anova_significant']]['region'].nunique()),
            'largest_effect': {
                'feature': str(demo_data.loc[demo_data['eta_squared'].idxmax(), 'feature']) if n_sig > 0 else 'N/A',
                'region': int(demo_data.loc[demo_data['eta_squared'].idxmax(), 'region']) if n_sig > 0 else 'N/A',
                'eta_squared': float(demo_data.loc[demo_data['eta_squared'].idxmax(), 'eta_squared']) if n_sig > 0 else 'N/A',
                'p_fdr': float(demo_data.loc[demo_data['eta_squared'].idxmax(), 'anova_p_fdr']) if n_sig > 0 else 'N/A'
            }
        }

# Feature importance summary
if 'all_continuous_importance' in locals() and len(all_continuous_importance) > 0:
    summary['feature_importance'] = {}
    
    for demo_var, importance_df in all_continuous_importance.items():
        if len(importance_df) > 0:
            # Get most frequently important features
            feature_freq = importance_df['feature'].value_counts()
            top_feature = feature_freq.index[0] if len(feature_freq) > 0 else 'N/A'
            
            # Get average effect size for top feature
            if top_feature != 'N/A':
                top_feature_data = importance_df[importance_df['feature'] == top_feature]
                avg_effect = top_feature_data['abs_r'].mean()
            else:
                avg_effect = 'N/A'
            
            summary['feature_importance'][demo_var] = {
                'n_regions_analyzed': int(importance_df['region'].nunique()),
                'most_important_feature': str(top_feature),
                'n_regions_where_top': int(feature_freq.iloc[0]) if len(feature_freq) > 0 else 0,
                'avg_effect_size': float(avg_effect) if avg_effect != 'N/A' else 'N/A'
            }

# Traditional whole-brain findings (if available)
if 'AGE' in df.columns and 'age_corr_df' in locals() and len(age_corr_df) > 0:
    summary['key_findings']['whole_brain_age'] = {
        'n_correlated_features': int(age_corr_df['significant_pearson'].sum()),
        'strongest_correlation': {
            'feature': str(age_corr_df.iloc[0]['feature']),
            'r': float(age_corr_df.iloc[0]['pearson_r']),
            'p_fdr': float(age_corr_df.iloc[0]['pearson_p_fdr'])
        }
    }

if 'SEX_ID' in df.columns and 'sex_comp_df' in locals() and len(sex_comp_df) > 0:
    summary['key_findings']['whole_brain_sex'] = {
        'n_differences': int(sex_comp_df['significant_ttest'].sum()),
        'largest_difference': {
            'feature': str(sex_comp_df.iloc[0]['feature']),
            'cohens_d': float(sex_comp_df.iloc[0]['cohens_d']),
            'p_fdr': float(sex_comp_df.iloc[0]['t_pvalue_fdr'])
        }
    }

# ML results
if 'cv_scores' in locals():
    summary['key_findings']['age_prediction'] = {
        'r2_mean': float(cv_scores.mean()),
        'r2_std': float(cv_scores.std()),
        'mae_mean': float(cv_mae.mean()),
        'mae_std': float(cv_mae.std()),
        'n_subjects': int(len(valid_data)) if 'valid_data' in locals() else 'N/A',
        'n_features': int(len(features_for_ml)) if 'features_for_ml' in locals() else int(len(ALL_FEATURES))
    }

# Save summary
with open(TABLES_DIR / 'analysis_summary.json', 'w') as f:
    json.dump(summary, f, indent=2, default=convert_to_native)

# ============================================================================
# PRINT FORMATTED SUMMARY
# ============================================================================

print("\n📊 DATASET OVERVIEW:")
print(f"   Total subjects: {summary['dataset']['total_subjects']}")
print(f"   Brain regions analyzed: {summary['dataset']['n_regions']}")
print(f"   Age range: {summary['dataset']['age_range']} years")
if summary['dataset']['age_mean'] != 'N/A':
    print(f"   Age: {summary['dataset']['age_mean']:.1f} ± {summary['dataset']['age_std']:.1f} years")
if summary['dataset']['male_subjects'] != 'N/A':
    male_pct = 100 * summary['dataset']['male_subjects'] / summary['dataset']['total_subjects']
    female_pct = 100 * summary['dataset']['female_subjects'] / summary['dataset']['total_subjects']
    print(f"   Sex: {summary['dataset']['male_subjects']} male ({male_pct:.1f}%), {summary['dataset']['female_subjects']} female ({female_pct:.1f}%)")
if summary['dataset']['n_sites'] != 'N/A':
    print(f"   Multi-center sites: {summary['dataset']['n_sites']}")

print("\n🔬 VESSEL FEATURES ANALYZED:")
print(f"   Total features: {summary['features']['total_features']}")
print(f"   • Morphometric: {summary['features']['morphometric']}")
print(f"   • Topological: {summary['features']['topological']}")
print(f"   • Curvature: {summary['features']['curvature']}")
if summary['features']['other'] > 0:
    print(f"   • Other: {summary['features']['other']}")

# Regional analysis summary
if 'regional_analysis' in summary and summary['regional_analysis']:
    print("\n🗺️  REGIONAL DEMOGRAPHIC ASSOCIATIONS:")
    print("="*80)
    
    if 'continuous_demographics' in summary['regional_analysis']:
        print("\nContinuous Demographics (Correlations):")
        # Loop variable is named demo_stats so it does not shadow the
        # scipy `stats` module imported at the top of the notebook.
        for demo_var, demo_stats in summary['regional_analysis']['continuous_demographics'].items():
            print(f"\n  {demo_var}:")
            print(f"    • Total region-feature tests: {demo_stats['total_tests']}")
            print(f"    • Significant associations: {demo_stats['significant_associations']} ({demo_stats['percent_significant']:.1f}%)")
            print(f"    • Regions with effects: {demo_stats['n_regions_with_effects']}")
            if demo_stats['strongest_correlation']['feature'] != 'N/A':
                print("    • Strongest correlation:")
                print(f"      - Feature: {demo_stats['strongest_correlation']['feature']}")
                print(f"      - Region: {demo_stats['strongest_correlation']['region']}")
                print(f"      - r = {demo_stats['strongest_correlation']['r']:.3f}, p(FDR) = {demo_stats['strongest_correlation']['p_fdr']:.2e}")
    
    if 'categorical_demographics' in summary['regional_analysis']:
        print("\nCategorical Demographics (Group Differences):")
        # Loop variable is named demo_stats so it does not shadow the
        # scipy `stats` module imported at the top of the notebook.
        for demo_var, demo_stats in summary['regional_analysis']['categorical_demographics'].items():
            print(f"\n  {demo_var}:")
            print(f"    • Total region-feature tests: {demo_stats['total_tests']}")
            print(f"    • Significant differences: {demo_stats['significant_differences']} ({demo_stats['percent_significant']:.1f}%)")
            print(f"    • Regions with effects: {demo_stats['n_regions_with_effects']}")
            if demo_stats['largest_effect']['feature'] != 'N/A':
                print("    • Largest effect size:")
                print(f"      - Feature: {demo_stats['largest_effect']['feature']}")
                print(f"      - Region: {demo_stats['largest_effect']['region']}")
                print(f"      - η² = {demo_stats['largest_effect']['eta_squared']:.3f}, p(FDR) = {demo_stats['largest_effect']['p_fdr']:.2e}")

# Feature importance summary
if 'feature_importance' in summary and summary['feature_importance']:
    print("\n⭐ FEATURE IMPORTANCE (Top Drivers by Region):")
    print("="*80)
    # Loop variable is named demo_stats so it does not shadow the
    # scipy `stats` module imported at the top of the notebook.
    for demo_var, demo_stats in summary['feature_importance'].items():
        if demo_stats['most_important_feature'] != 'N/A':
            print(f"\n  {demo_var}:")
            print(f"    • Most consistently important: {demo_stats['most_important_feature']}")
            print(f"    • Appears in top 3 for: {demo_stats['n_regions_where_top']} / {demo_stats['n_regions_analyzed']} regions")
            if demo_stats['avg_effect_size'] != 'N/A':
                print(f"    • Average effect size: {demo_stats['avg_effect_size']:.3f}")

# Traditional findings
if summary['key_findings']:
    print("\n🎯 ADDITIONAL KEY FINDINGS:")
    print("="*80)
    
    if 'whole_brain_age' in summary['key_findings']:
        age_stats = summary['key_findings']['whole_brain_age']
        print("\nWhole-Brain Age Associations:")
        print(f"  • Significant features: {age_stats['n_correlated_features']}")
        print(f"  • Strongest: {age_stats['strongest_correlation']['feature']} (r={age_stats['strongest_correlation']['r']:.3f})")
    
    if 'whole_brain_sex' in summary['key_findings']:
        sex_stats = summary['key_findings']['whole_brain_sex']
        print("\nWhole-Brain Sex Differences:")
        print(f"  • Significant features: {sex_stats['n_differences']}")
        print(f"  • Largest: {sex_stats['largest_difference']['feature']} (d={sex_stats['largest_difference']['cohens_d']:.3f})")
    
    if 'age_prediction' in summary['key_findings']:
        ml_stats = summary['key_findings']['age_prediction']
        print("\nAge Prediction (Machine Learning):")
        print(f"  • Cross-validated R²: {ml_stats['r2_mean']:.3f} ± {ml_stats['r2_std']:.3f}")
        print(f"  • Mean Absolute Error: {ml_stats['mae_mean']:.2f} ± {ml_stats['mae_std']:.2f} years")
        print(f"  • Subjects: {ml_stats['n_subjects']}, Features: {ml_stats['n_features']}")

print("\n" + "="*80)
print("📁 OUTPUT FILES GENERATED:")
print("="*80)

# Count files by category
csv_files = sorted(TABLES_DIR.glob('*.csv'))
txt_files = sorted(FIGURES_DIR.glob('*.txt'))
png_files = sorted(FIGURES_DIR.glob('*.png'))

print(f"\n📊 Tables & Data ({len(csv_files)} CSV files in {TABLES_DIR}):")
for file in csv_files:
    print(f"  ✓ {file.name}")

print(f"\n📝 Feature Importance Reports ({len(txt_files)} text files in {FIGURES_DIR}):")
for file in txt_files:
    print(f"  ✓ {file.name}")

print(f"\n📈 Visualizations ({len(png_files)} PNG files in {FIGURES_DIR}):")
# Group by type
atlas_files = [f for f in png_files if 'atlas' in f.name]
# Exclude atlas files here so each PNG is listed under exactly one group
importance_files = [f for f in png_files
                    if f not in atlas_files
                    and any(key in f.name for key in ('importance', 'diversity', 'ranking'))]
other_files = [f for f in png_files if f not in atlas_files and f not in importance_files]

if atlas_files:
    print(f"\n  Atlas Visualizations ({len(atlas_files)}):")
    for file in atlas_files:
        print(f"    ✓ {file.name}")

if importance_files:
    print(f"\n  Feature Importance Analyses ({len(importance_files)}):")
    for file in importance_files:
        print(f"    ✓ {file.name}")

if other_files:
    print(f"\n  Other Visualizations ({len(other_files)}):")
    for file in other_files:
        print(f"    ✓ {file.name}")

print(f"\n📄 Summary: {TABLES_DIR / 'analysis_summary.json'}")

print("\n" + "="*80)
print("✅ COMPREHENSIVE POPULATION-LEVEL ANALYSIS COMPLETE!")
print("="*80)

print("\n📝 PAPER WRITING GUIDE:")
print("="*80)
print("\nFor your ISBI conference paper, use these outputs:\n")

print("MAIN FIGURES (Choose 3-4):")
print("  1. Combined atlas heatmap (regional_combined_atlas_heatmap.png)")
print("     → Shows overall spatial pattern of demographic associations")
print("  2. AGE atlas + feature importance (create multi-panel figure)")
print("     → Shows WHERE age effects occur and WHICH features drive them")
print("  3. SEX atlas (if significant effects found)")
print("     → Shows sex-specific regional patterns")
print("  4. Age prediction results (if ML analysis performed)")
print("     → Shows predictive power of vascular features")

print("\nSUPPLEMENTARY MATERIALS:")
print("  • All individual demographic atlases")
print("  • All feature importance heatmaps")
print("  • Complete statistical tables (CSV files)")
print("  • Feature ranking tables for all demographics")

print("\nMETHODS SECTION:")
print("  Use: analysis_summary.json for dataset statistics")
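# Report the counts detected for this run rather than hardcoded values
print(f"  Mention: {summary['dataset']['n_regions']} brain regions, {summary['features']['morphometric']} morphometric features")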
print("  Describe: Regional correlation analysis + feature importance ranking")
print("  State: FDR correction (q<0.05) for multiple comparisons")

print("\nRESULTS SECTION - Key Points to Highlight:")
if 'regional_analysis' in summary and 'continuous_demographics' in summary['regional_analysis']:
    if 'AGE' in summary['regional_analysis']['continuous_demographics']:
        age_stats = summary['regional_analysis']['continuous_demographics']['AGE']
        print(f"  • Age effects: {age_stats['percent_significant']:.1f}% of region-feature tests significant")
        print(f"    - {age_stats['n_regions_with_effects']} regions show age-related changes")
        if 'feature_importance' in summary and 'AGE' in summary['feature_importance']:
            age_imp = summary['feature_importance']['AGE']
            print(f"    - {age_imp['most_important_feature']} most consistently affected")
            print(f"      (top 3 in {age_imp['n_regions_where_top']} regions)")

if 'regional_analysis' in summary and 'categorical_demographics' in summary['regional_analysis']:
    if 'SEX_ID' in summary['regional_analysis']['categorical_demographics']:
        sex_stats = summary['regional_analysis']['categorical_demographics']['SEX_ID']
        print(f"  • Sex differences: {sex_stats['percent_significant']:.1f}% of region-feature tests significant")
        print(f"    - {sex_stats['n_regions_with_effects']} regions show sex-specific patterns")

print("\nDISCUSSION POINTS:")
print("  • Spatial heterogeneity of demographic effects across brain regions")
print("  • Feature-specific mechanisms (e.g., vessel rarefaction vs. tortuosity)")
print("  • Anterior-posterior gradients (if observed)")
print("  • Clinical implications for age-related vascular changes")
print("  • Comparison with existing literature on cerebrovascular aging")

print("\n💡 TIPS:")
print("  1. Atlas visualizations are publication-ready at 300 DPI")
print("  2. Feature importance heatmaps complement atlas maps perfectly")
print("  3. Use CSV files to create custom summary tables in manuscript")
print("  4. Text files contain top 3 features per region for easy reference")
print("  5. Consider creating multi-panel figures combining atlas + importance")

print("\n🚀 NEXT STEPS:")
print("  1. Review all atlas visualizations for biological plausibility")
print("  2. Examine feature importance patterns for mechanistic insights")
print("  3. Cross-reference strongest effects with known vascular anatomy")
print("  4. Prepare figure legends explaining color scales and interpretations")
print("  5. Draft results paragraph for each key demographic variable")

print("\n" + "="*80)
print("="*80)