In [1]:
# Simplified Group Statistics Script
"""
- Only paired t-test
- Only top 3 connections per band with p < 0.01 (uncorrected)
- Clean, publication-ready figures and reports
- No multiple comparison corrections
"""

import os
import numpy as np
import pandas as pd
from pathlib import Path
from scipy.stats import ttest_rel
import matplotlib.pyplot as plt
import seaborn as sns
from typing import Dict, List, Tuple
import warnings
warnings.filterwarnings('ignore')

# ===========================
# CONFIGURATION
# ===========================

# Project input locations.
PROJECT_BASE = '/home/jaizor/jaizor/xtra'
GROUP_OUTPUT_DIR = Path(PROJECT_BASE) / "derivatives" / "group"
SUBJECT_MATRICES_DIR = GROUP_OUTPUT_DIR / "subject_level"

# Output tree: statistics/ with matrices/, reports/ and figures/ beneath it.
STATS_OUTPUT_DIR = GROUP_OUTPUT_DIR / "statistics"
MATRICES_DIR, REPORTS_DIR, FIGURES_DIR = (
    STATS_OUTPUT_DIR / sub_name for sub_name in ("matrices", "reports", "figures")
)

# Ensure every output directory exists (parents first, idempotent).
for dir_path in (STATS_OUTPUT_DIR, MATRICES_DIR, REPORTS_DIR, FIGURES_DIR):
    dir_path.mkdir(parents=True, exist_ok=True)

# Analysis parameters.
BANDS = ["Theta", "Alpha", "Low_Beta", "High_Beta", "Low_Gamma", "High_Gamma"]
CONDITIONS = ['InPhase', 'OutofPhase']
N_ROIS = 512
MIN_SUBJECT_THRESHOLD = 8  # Minimum subjects needed for analysis

print("\n".join(
    ["📁 Output directories created:"]
    + [
        f"   {label}: {location}"
        for label, location in (
            ("Statistics", STATS_OUTPUT_DIR),
            ("Matrices", MATRICES_DIR),
            ("Reports", REPORTS_DIR),
            ("Figures", FIGURES_DIR),
        )
    ]
))

# ===========================
# LOAD ROI NAMES WITH FALLBACK
# ===========================

def load_roi_names(candidate_files=None, n_rois=None) -> List[str]:
    """Load ROI labels from the first usable group-average CSV.

    Each candidate CSV is read with its first column as the index, and the
    index values are taken as ROI names. A candidate is accepted only if it
    yields exactly ``n_rois`` names with no duplicates. Downstream code labels
    ``n_rois x n_rois`` DataFrames with these names and looks ROIs up by name,
    so a wrong count or repeated labels would corrupt the outputs.

    Args:
        candidate_files: Optional iterable of CSV paths (str or Path) to try
            in order. Defaults to the InPhase/OutofPhase Alpha and InPhase
            Theta group-average matrices under ``GROUP_OUTPUT_DIR``.
        n_rois: Expected number of ROIs. Defaults to ``N_ROIS``.

    Returns:
        The ROI names from the first valid candidate. If no candidate is
        usable, returns generic names ``ROI_000``, ``ROI_001``, ...
    """
    if n_rois is None:
        n_rois = N_ROIS
    if candidate_files is None:
        candidate_files = [
            GROUP_OUTPUT_DIR / "matrix_InPhase_Alpha_group_avg.csv",
            GROUP_OUTPUT_DIR / "matrix_OutofPhase_Alpha_group_avg.csv",
            GROUP_OUTPUT_DIR / "matrix_InPhase_Theta_group_avg.csv",
        ]

    for csv_path in map(Path, candidate_files):
        if not csv_path.exists():
            continue
        try:
            df = pd.read_csv(csv_path, index_col=0)
        except Exception as e:  # best-effort: an unreadable file just means "try the next one"
            print(f"⚠️ Error reading {csv_path.name}: {e}")
            continue

        roi_names = df.index.tolist()
        if len(roi_names) != n_rois:
            print(f"⚠️ {csv_path.name} has {len(roi_names)} ROI names, expected {n_rois}; skipping")
            continue
        if len(set(roi_names)) != len(roi_names):
            # Name-based lookups (e.g. list.index) would resolve to the wrong ROI.
            print(f"⚠️ {csv_path.name} contains duplicate ROI names; skipping")
            continue

        print(f"✅ Loaded {len(roi_names)} ROI names from: {csv_path.name}")
        return roi_names

    print("⚠️ No CSV with ROI names found, creating generic names")
    return [f"ROI_{i:03d}" for i in range(n_rois)]

# ===========================
# DATA LOADING
# ===========================

def _load_connectivity_matrix(file_path: Path):
    """Load one ``N_ROIS x N_ROIS`` ``.npy`` matrix.

    Returns:
        ``(matrix, None)`` on success. Returns ``(None, reason)`` when the file
        is missing, unreadable, or has the wrong shape. Non-finite entries are
        replaced (NaN -> 0, +inf -> 1, -inf -> -1) before returning.
    """
    if not file_path.exists():
        return None, "missing"
    try:
        matrix = np.load(file_path)
        if matrix.shape != (N_ROIS, N_ROIS):
            return None, f"wrong shape {matrix.shape}, expected {(N_ROIS, N_ROIS)}"
        if not np.all(np.isfinite(matrix)):
            matrix = np.nan_to_num(matrix, nan=0.0, posinf=1.0, neginf=-1.0)
    except Exception as e:  # corrupt/unreadable file: record the reason, keep loading others
        return None, f"unreadable ({e})"
    return matrix, None


def load_subject_matrices() -> Dict:
    """Load paired InPhase/OutofPhase subject matrices for every band.

    Subjects are the ``sub-*`` directories under ``SUBJECT_MATRICES_DIR``.
    Each one should contain ``<condition>_<band>.npy`` for every condition in
    ``CONDITIONS`` and every band in ``BANDS``.

    A subject is kept for a given band only if *both* conditions load as valid
    ``N_ROIS x N_ROIS`` matrices. This keeps the ``'InPhase'``,
    ``'OutofPhase'`` and ``'subjects'`` lists index-aligned for the paired
    t-test. A band with fewer than ``MIN_SUBJECT_THRESHOLD`` complete subjects
    is returned with empty lists.

    Problem files (missing, unreadable, wrong shape) are written, one per line
    with a reason, to ``REPORTS_DIR/missing_files.txt``. A report left over
    from a previous run is removed when there are no problems.

    Returns:
        ``{band: {'InPhase': [...], 'OutofPhase': [...], 'subjects': [...]}}``

    Raises:
        FileNotFoundError: If ``SUBJECT_MATRICES_DIR`` does not exist.
        ValueError: If fewer than ``MIN_SUBJECT_THRESHOLD`` subject
            directories are found.
    """
    if not SUBJECT_MATRICES_DIR.exists():
        raise FileNotFoundError(f"Subject matrices directory not found: {SUBJECT_MATRICES_DIR}")

    subjects = sorted([d.name for d in SUBJECT_MATRICES_DIR.iterdir()
                       if d.is_dir() and d.name.startswith('sub-')])

    if len(subjects) < MIN_SUBJECT_THRESHOLD:
        raise ValueError(f"Only {len(subjects)} subjects found, need at least {MIN_SUBJECT_THRESHOLD}")

    print(f"🧠 Found {len(subjects)} subjects: {subjects}")

    data = {band: {'InPhase': [], 'OutofPhase': [], 'subjects': []} for band in BANDS}
    problem_files = []  # "path: reason" lines

    for subject in subjects:
        subject_dir = SUBJECT_MATRICES_DIR / subject
        for band in BANDS:
            loaded = {}
            for condition in CONDITIONS:
                file_path = subject_dir / f"{condition}_{band}.npy"
                matrix, reason = _load_connectivity_matrix(file_path)
                if matrix is None:
                    problem_files.append(f"{file_path}: {reason}")
                else:
                    loaded[condition] = matrix
            # Append only complete pairs so index k of every list is the same subject.
            if len(loaded) == len(CONDITIONS):
                for condition in CONDITIONS:
                    data[band][condition].append(loaded[condition])
                data[band]['subjects'].append(subject)

    # Final validation: drop bands that don't have enough complete subjects.
    cleaned_data = {band: {'InPhase': [], 'OutofPhase': [], 'subjects': []} for band in BANDS}
    for band in BANDS:
        included = data[band]['subjects']
        n_complete = len(included)
        included_set = set(included)
        excluded = [s for s in subjects if s not in included_set]
        if n_complete >= MIN_SUBJECT_THRESHOLD:
            cleaned_data[band] = data[band]
            if excluded:
                print(f"⚠️ {band}: {n_complete} complete subjects (excluded incomplete: {excluded})")
            else:
                print(f"✅ {band}: {n_complete} complete subjects")
        else:
            print(f"❌ {band}: only {n_complete} complete subjects "
                  f"(need {MIN_SUBJECT_THRESHOLD}; incomplete: {excluded})")

    report_path = REPORTS_DIR / "missing_files.txt"
    if problem_files:
        print(f"\n⚠️ {len(problem_files)} missing files saved to missing_files.txt")
        with open(report_path, 'w', encoding='utf-8') as f:
            f.write("\n".join(problem_files))
    elif report_path.exists():
        # Don't leave a previous run's report around to be mistaken for this one.
        report_path.unlink()

    return cleaned_data

# ===========================
# STATISTICAL ANALYSIS — ONLY T-TEST + TOP 3
# ===========================

def enhanced_statistical_analysis(data: Dict, roi_names: List[str]) -> Dict:
    """Run paired t-tests (InPhase vs OutofPhase) on every connection of every band.

    For each band in ``BANDS`` that has data, the subject matrices are stacked
    and a paired t-test (``scipy.stats.ttest_rel``) is run on every
    upper-triangle cell, diagonal included.

    A cell is not tested, and keeps t = 0 and p = 1, if any of these has
    variance <= 1e-10 across subjects:

    - its InPhase values,
    - its OutofPhase values,
    - its paired differences.

    The per-cell t and p values are written into full symmetric
    ``N_ROIS x N_ROIS`` matrices.

    The following are saved under ``MATRICES_DIR/<band>/``:

    - group averages, as CSV labelled with ``roi_names``;
    - the InPhase - OutofPhase difference, as CSV and ``.npy``;
    - the t and p matrices, as CSV and ``.npy``.

    Args:
        data: ``{band: {'InPhase': [...], 'OutofPhase': [...], ...}}`` holding
            index-aligned lists of ``N_ROIS x N_ROIS`` subject matrices.
        roi_names: ``N_ROIS`` labels for CSV rows and columns.

    Returns:
        ``{band: result_dict}`` for every analysed band. See the
        ``results[band]`` assignment at the end of the loop for the keys.
    """
    results = {}

    print(f"\n🔬 STATISTICAL ANALYSIS (Top 3 per band, p<0.01)")
    print("=" * 60)

    # Upper triangle, diagonal included. The flattened boolean mask and
    # np.triu_indices both enumerate cells in row-major order, so position k
    # of the tested vectors corresponds to cell (rows[k], cols[k]).
    upper_tri_mask = np.triu(np.ones((N_ROIS, N_ROIS), dtype=bool)).flatten()
    rows, cols = np.triu_indices(N_ROIS)

    for band in BANDS:
        if not data[band]['InPhase']:
            print(f"⏭️ Skipping {band} - no data")
            continue

        print(f"\n📊 {band.upper()} BAND")

        inphase = np.stack(data[band]['InPhase'])
        outphase = np.stack(data[band]['OutofPhase'])
        if inphase.shape != outphase.shape:
            # A paired test is meaningless unless both conditions have the same subjects.
            print(f"   ❌ InPhase/OutofPhase shapes differ {inphase.shape} vs {outphase.shape}; skipping {band}")
            continue
        n_subjects = inphase.shape[0]

        print(f"   📈 Analyzing {n_subjects} subjects")

        avg_inphase = np.mean(inphase, axis=0)
        avg_outphase = np.mean(outphase, axis=0)
        avg_difference = avg_inphase - avg_outphase

        in_flat = inphase.reshape(n_subjects, -1)[:, upper_tri_mask]
        out_flat = outphase.reshape(n_subjects, -1)[:, upper_tri_mask]
        n_conn = in_flat.shape[1]

        print(f"   🔼 Testing only upper triangle: {n_conn} connections")

        # ttest_rel divides by the spread of the paired differences. A cell can
        # vary in both conditions and still give t = ±inf or NaN when its
        # difference is constant across subjects, so exclude those cells too.
        var_mask = (
            (np.var(in_flat, axis=0) > 1e-10)
            & (np.var(out_flat, axis=0) > 1e-10)
            & (np.var(in_flat - out_flat, axis=0) > 1e-10)
        )
        valid_connections = int(np.sum(var_mask))

        print(f"   🔗 Valid connections: {valid_connections}/{n_conn}")

        if valid_connections == 0:
            print(f"   ❌ No valid connections found for {band}")
            continue

        # Untested cells keep t = 0, p = 1.
        t_vals = np.zeros(n_conn)
        p_vals = np.ones(n_conn)
        t_vals_valid, p_vals_valid = ttest_rel(in_flat[:, var_mask], out_flat[:, var_mask], axis=0)
        t_vals[var_mask] = t_vals_valid
        p_vals[var_mask] = p_vals_valid

        # Assign (i, j) and (j, i) explicitly. An additive mirror
        # (M + M.T - diag(M)) is only correct for a zero-filled matrix. On the
        # ones-filled p matrix it would add 1 to every off-diagonal p-value.
        t_matrix = np.zeros((N_ROIS, N_ROIS))
        p_matrix = np.ones((N_ROIS, N_ROIS))
        t_matrix[rows, cols] = t_vals
        t_matrix[cols, rows] = t_vals
        p_matrix[rows, cols] = p_vals
        p_matrix[cols, rows] = p_vals

        n_uncorr = int(np.sum(p_vals < 0.01))
        max_t = float(np.nanmax(np.abs(t_vals)))
        print(f"   📊 Uncorrected p<0.01: {n_uncorr}")
        print(f"   📊 Max |t-value|: {max_t:.3f}")

        band_dir = MATRICES_DIR / band
        band_dir.mkdir(parents=True, exist_ok=True)

        csv_outputs = {
            f"{band}_average_inphase.csv": avg_inphase,
            f"{band}_average_outphase.csv": avg_outphase,
            f"{band}_difference.csv": avg_difference,
            f"{band}_t_values.csv": t_matrix,
            f"{band}_p_values.csv": p_matrix,
        }
        for filename, matrix in csv_outputs.items():
            pd.DataFrame(matrix, index=roi_names, columns=roi_names).to_csv(band_dir / filename)

        np.save(band_dir / f"{band}_t_values.npy", t_matrix)
        np.save(band_dir / f"{band}_p_values.npy", p_matrix)
        np.save(band_dir / f"{band}_difference.npy", avg_difference)

        results[band] = {
            'n_subjects': n_subjects,
            'average_inphase': avg_inphase,
            'average_outphase': avg_outphase,
            'difference_matrix': avg_difference,
            't_matrix': t_matrix,
            'p_matrix': p_matrix,
            'n_significant_uncorr': n_uncorr,
            'max_t': max_t,
            'valid_connections': valid_connections,
            't_vals_upper': t_vals.copy(),
            'p_vals_upper': p_vals.copy(),
            'upper_tri_mask': upper_tri_mask
        }

    return results

# ===========================
# REPORTING — TOP 3 PER BAND
# ===========================

def generate_comprehensive_report(results: Dict, roi_names: List[str]):
    """Write the per-band summary and the top-3 connections per band (p<0.01).

    For each band, the top-3 list holds at most three off-diagonal
    upper-triangle connections with uncorrected p < 0.01, ranked by descending
    |t|. The effect size is the paired-samples Cohen's d_z = t / sqrt(n).

    Files written to ``REPORTS_DIR``:

    - ``analysis_summary.csv``: one row per band.
    - ``top3_connections_p01.csv``: the top-3 rows. It is written header-only
      when nothing passes, so no stale file from a previous run remains.
    - ``analysis_report.txt``: a plain-text summary.

    Args:
        results: Output of ``enhanced_statistical_analysis`` (band -> dict).
        roi_names: Labels indexed by ROI position.

    Returns:
        ``(summary_df, top3_findings)``, where ``top3_findings`` is a list of
        dicts, one per reported connection.
    """
    top3_columns = ['Band', 'Rank', 'ROI_1', 'ROI_2', 'T_Value', 'P_Value',
                    'Mean_Difference', 'Effect_Size_Cohen_d']
    summary_data = []
    top3_findings = []

    for band, res in results.items():
        summary_data.append({
            'Band': band,
            'N_Subjects': res['n_subjects'],
            'Valid_Connections': res['valid_connections'],
            'Significant_p01': res['n_significant_uncorr'],
            'Max_T_Value': res['max_t'],
            'Mean_Abs_Difference': np.mean(np.abs(res['difference_matrix']))
        })

        t_vals_upper = res['t_vals_upper']
        p_vals_upper = res['p_vals_upper']
        # The upper-triangle vectors enumerate np.triu_indices(n) (diagonal
        # included) in row-major order; n is taken from the stored matrix.
        rows, cols = np.triu_indices(res['t_matrix'].shape[0])

        # Candidates: off-diagonal cells with p < 0.01 (a NaN p compares False).
        candidates = np.flatnonzero((p_vals_upper < 0.01) & (rows != cols))
        # Rank by |t| descending; a stable sort keeps ties in upper-triangle order.
        order = np.argsort(-np.abs(t_vals_upper[candidates]), kind='stable')

        for rank, idx in enumerate(candidates[order[:3]], start=1):
            row, col = rows[idx], cols[idx]
            top3_findings.append({
                'Band': band,
                'Rank': rank,
                'ROI_1': roi_names[row],
                'ROI_2': roi_names[col],
                'T_Value': t_vals_upper[idx],
                'P_Value': p_vals_upper[idx],
                'Mean_Difference': res['difference_matrix'][row, col],
                'Effect_Size_Cohen_d': t_vals_upper[idx] / np.sqrt(res['n_subjects'])
            })

    summary_df = pd.DataFrame(summary_data)
    summary_df.to_csv(REPORTS_DIR / "analysis_summary.csv", index=False)

    # Always write the file, even header-only, so the advertised path reflects this run.
    top3_df = pd.DataFrame(top3_findings, columns=top3_columns)
    top3_df.to_csv(REPORTS_DIR / "top3_connections_p01.csv", index=False)
    if top3_findings:
        print(f"🎯 Found {len(top3_findings)} top-3 connections (p<0.01 uncorrected, ranked by |t|)")
    else:
        print("⚠️ No top-3 connections found with p<0.01")

    with open(REPORTS_DIR / "analysis_report.txt", 'w', encoding='utf-8') as f:
        f.write("SIMPLIFIED CONNECTIVITY ANALYSIS REPORT\n")
        f.write("=" * 50 + "\n\n")
        f.write("ANALYSIS METHOD\n")
        f.write("-" * 30 + "\n")
        f.write("Paired t-tests were performed for each connection (InPhase vs OutofPhase).\n")
        f.write("Due to high dimensionality (131K+ tests), no multiple comparison correction was applied.\n")
        f.write("Top 3 connections per band, ranked by absolute t-value and passing p < 0.01, are reported for exploratory analysis.\n\n")

        f.write("SUMMARY BY FREQUENCY BAND\n")
        f.write("-" * 30 + "\n")
        for _, row in summary_df.iterrows():
            f.write(f"\n{row['Band'].upper()} BAND:\n")
            f.write(f"  Subjects analyzed: {row['N_Subjects']}\n")
            f.write(f"  Valid connections: {row['Valid_Connections']}\n")
            f.write(f"  Connections with p<0.01: {row['Significant_p01']}\n")
            f.write(f"  Maximum |t-value|: {row['Max_T_Value']:.3f}\n")
            f.write(f"  Mean absolute difference: {row['Mean_Abs_Difference']:.6f}\n")

    print(f"📊 Comprehensive reports saved to: {REPORTS_DIR}")
    return summary_df, top3_findings

# ===========================
# VISUALIZATION — TOP 3 PER BAND
# ===========================

def plot_top3_connections(results: Dict, roi_names: List[str], top3_findings: List[Dict]):
    """Save one figure per analysed band highlighting its top-3 connections.

    The band's mean-difference matrix (InPhase - OutofPhase) is drawn only at
    the top-3 cells and their mirrored cells; everything else is blank.

    In a 512x512 image a single cell is about 2 px wide, so each highlighted
    cell is also drawn as a colour-matched, rank-labelled marker so that it
    stays visible. Bands without findings get an explanatory note.

    Figures are written to ``FIGURES_DIR/<band>_top3_connections_p01.png``.
    """
    # O(1) name -> index lookup. The first occurrence wins, as with list.index.
    roi_index = {}
    for i, name in enumerate(roi_names):
        roi_index.setdefault(name, i)

    for band in BANDS:
        if band not in results:
            continue

        diff = results[band]['difference_matrix']
        n = diff.shape[0]
        top3_mask = np.zeros(diff.shape, dtype=bool)
        points = []  # (rank, row, col)

        for finding in (f for f in top3_findings if f['Band'] == band):
            row = roi_index[finding['ROI_1']]
            col = roi_index[finding['ROI_2']]
            top3_mask[row, col] = True
            top3_mask[col, row] = True  # symmetric
            points.append((finding.get('Rank', len(points) + 1), row, col))

        vmax = np.nanmax(np.abs(diff))
        fig, ax = plt.subplots(1, 1, figsize=(10, 8))
        im = ax.imshow(
            np.where(top3_mask, diff, np.nan),
            cmap='RdBu_r',
            aspect='auto',
            interpolation='none',
            vmin=-vmax,
            vmax=vmax
        )

        if points:
            _, pt_rows, pt_cols = zip(*points)
            values = diff[list(pt_rows), list(pt_cols)]
            # Draw both (row, col) and its mirror so markers match the symmetric image.
            ax.scatter(
                list(pt_cols) + list(pt_rows),
                list(pt_rows) + list(pt_cols),
                c=np.concatenate([values, values]),
                cmap='RdBu_r', vmin=-vmax, vmax=vmax,
                s=120, edgecolors='black', linewidths=1.0, zorder=3
            )
            for rank, row, col in points:
                ax.annotate(f"#{rank}", (col, row), xytext=(6, 6),
                            textcoords='offset points', fontsize=9, fontweight='bold')
        else:
            ax.text(0.5, 0.5, "No connections with p<0.01", transform=ax.transAxes,
                    ha='center', va='center', fontsize=12)

        # Keep the full matrix extent (row 0 at the top) regardless of marker autoscaling.
        ax.set_xlim(-0.5, n - 0.5)
        ax.set_ylim(n - 0.5, -0.5)
        ax.set_title(f"{band}: Top 3 Connections (p<0.01 uncorrected)", fontsize=14, fontweight='bold')
        fig.colorbar(im, ax=ax, shrink=0.8, label='Mean Difference (InPhase - OutofPhase)')
        ax.set_xticks([])
        ax.set_yticks([])
        ax.grid(False)

        fig.tight_layout()
        fig.savefig(FIGURES_DIR / f"{band}_top3_connections_p01.png", dpi=150, bbox_inches='tight')
        plt.close(fig)
        print(f"   🖼️  Saved top-3 plot for {band}")

def create_summary_visualization(results: Dict, top3_counts: Dict):
    """Save a 2x2 bar-chart overview of all analysed bands.

    The four panels show, per band:

    - the count of p<0.01 connections;
    - the maximum |t|;
    - the mean absolute InPhase - OutofPhase difference;
    - the number of top-3 connections found (0-3).

    The figure is written to ``FIGURES_DIR/analysis_summary.png``. When
    ``results`` is empty, nothing is created.

    Args:
        results: Output of ``enhanced_statistical_analysis`` (band -> dict).
        top3_counts: Mapping band -> number of top-3 findings. Missing bands
            count as 0.
    """
    bands = list(results.keys())
    if not bands:
        # Return before creating a figure so no figure is left open.
        return

    # Scope the style to this figure instead of changing global rcParams.
    with plt.style.context('seaborn-v0_8-whitegrid'):
        fig, axes = plt.subplots(2, 2, figsize=(15, 12))
        fig.suptitle('Connectivity Analysis Summary\nTop 3 per band (p<0.01)', fontsize=16)

        panels = (
            (axes[0, 0], [results[b]['n_significant_uncorr'] for b in bands],
             '#FF6B6B', 'Connections with p<0.01', 'Count'),
            (axes[0, 1], [results[b]['max_t'] for b in bands],
             '#DDA0DD', 'Maximum |t-value|', 't-value'),
            (axes[1, 0], [np.mean(np.abs(results[b]['difference_matrix'])) for b in bands],
             '#A29BFE', 'Mean Absolute Difference', 'Difference'),
            (axes[1, 1], [top3_counts.get(b, 0) for b in bands],
             '#FFEAA7', 'Top 3 Connections Found', 'Count'),
        )
        for ax, values, color, title, ylabel in panels:
            ax.bar(bands, values, color=color, alpha=0.7)
            ax.set_title(title, fontsize=12)
            ax.set_ylabel(ylabel)
            plt.setp(ax.xaxis.get_majorticklabels(), rotation=45)
        # At most 3 findings per band.
        axes[1, 1].set_ylim([0, 3])

        fig.tight_layout()
        fig.savefig(FIGURES_DIR / "analysis_summary.png", dpi=300, bbox_inches='tight')
        plt.close(fig)

    print(f"📊 Summary visualization saved to: {FIGURES_DIR}/analysis_summary.png")

# ===========================
# MAIN EXECUTION
# ===========================

def main():
    """Run the full pipeline and report whether it succeeded.

    Steps, in order:

    1. Load ROI names.
    2. Load subject matrices.
    3. Run the paired t-tests.
    4. Write the reports.
    5. Draw the figures.
    6. Print a final summary.

    Returns:
        True on success. False if any step raised; the traceback is printed.
    """
    from collections import Counter

    print("🚀 STARTING SIMPLIFIED CONNECTIVITY ANALYSIS")
    print("=" * 60)

    try:
        roi_names = load_roi_names()

        data = load_subject_matrices()
        if not any(data[band]['InPhase'] for band in BANDS):
            raise ValueError("No valid data found for any frequency band")

        results = enhanced_statistical_analysis(data, roi_names)
        if not results:
            raise ValueError("No results generated from statistical analysis")

        summary_df, top3_findings = generate_comprehensive_report(results, roi_names)

        # Number of top-3 findings per band (Counter.get defaults work like a dict).
        top3_counts = Counter(finding['Band'] for finding in top3_findings)

        plot_top3_connections(results, roi_names, top3_findings)
        create_summary_visualization(results, top3_counts)

        print("\n" + "=" * 60)
        print("📊 ANALYSIS COMPLETE - SUMMARY")
        print("=" * 60)
        print(summary_df.to_string(index=False))

        if top3_findings:
            print("\n🎯 TOP 3 CONNECTIONS (p<0.01):")
            print(pd.DataFrame(top3_findings).to_string(index=False))

        print("\n📁 All outputs saved to:")
        print(f"   📊 Statistics: {STATS_OUTPUT_DIR}")
        print(f"   📈 Matrices: {MATRICES_DIR}")
        print(f"   📋 Reports: {REPORTS_DIR}")
        print(f"   📊 Figures: {FIGURES_DIR}")
        # Only point at the top-3 table when this run actually produced findings.
        if top3_findings:
            print(f"   🎯 Top 3: {REPORTS_DIR}/top3_connections_p01.csv")

    except Exception as e:  # top-level boundary: report and signal failure
        print(f"❌ ANALYSIS FAILED: {e}")
        import traceback
        traceback.print_exc()
        return False

    return True

if __name__ == "__main__":
    # Run the pipeline and print a one-line verdict based on main()'s boolean result.
    success = main()
    print(
        "\n✅ 🎯 ANALYSIS COMPLETED SUCCESSFULLY — Top 3 per band (p<0.01)!"
        if success
        else "\n❌ ANALYSIS FAILED - CHECK LOGS"
    )

📁 Output directories created:
   Statistics: /home/jaizor/jaizor/xtra/derivatives/group/statistics
   Matrices: /home/jaizor/jaizor/xtra/derivatives/group/statistics/matrices
   Reports: /home/jaizor/jaizor/xtra/derivatives/group/statistics/reports
   Figures: /home/jaizor/jaizor/xtra/derivatives/group/statistics/figures
🚀 STARTING SIMPLIFIED CONNECTIVITY ANALYSIS
✅ Loaded 512 ROI names from: matrix_InPhase_Alpha_group_avg.csv
🧠 Found 12 subjects: ['sub-01', 'sub-02', 'sub-03', 'sub-05', 'sub-06', 'sub-07', 'sub-08', 'sub-09', 'sub-10', 'sub-11', 'sub-12', 'sub-14']
✅ Theta: 12 complete subjects
✅ Alpha: 12 complete subjects
✅ Low_Beta: 12 complete subjects
✅ High_Beta: 12 complete subjects
✅ Low_Gamma: 12 complete subjects
✅ High_Gamma: 12 complete subjects

🔬 STATISTICAL ANALYSIS (Top 3 per band, p<0.01)

📊 THETA BAND
   📈 Analyzing 12 subjects
   🔼 Testing only upper triangle: 131328 connections
   🔗 Valid connections: 127260/131328
   📊 Uncorrected p<0.01: 1325
   📊 Max |t-value|: 