In [1]:
import os
from pathlib import Path
import numpy as np
import pandas as pd
from nilearn import datasets, image, plotting
from nilearn.plotting.find_cuts import find_xyz_cut_coords
from nilearn.image import iter_img, index_img
from nilearn.datasets import load_fsaverage
from nilearn import surface
import plotly.graph_objects as go
import matplotlib.pyplot as plt  # 👈 ADDED for slice plots
import warnings
from concurrent.futures import ThreadPoolExecutor
from functools import lru_cache
import logging
import gc
import time

# Silence every warning category for the whole process (nilearn/plotly are chatty).
warnings.simplefilter("ignore")

# Root logging: INFO level, timestamped "time - LEVEL - message" lines.
_LOG_FORMAT = '%(asctime)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger(__name__)

# ===== CONFIGURATION =====
# Project root; override with the XTRA_PROJECT_BASE environment variable,
# otherwise the original hard-coded location is used.
PROJECT_BASE = os.environ.get("XTRA_PROJECT_BASE", '/home/jaizor/jaizor/xtra')
GROUP_OUTPUT_DIR = Path(PROJECT_BASE) / "derivatives" / "group"  # group-average matrices live here
OUTPUT_DIR = GROUP_OUTPUT_DIR / "stat"          # CSV exports of top edges
HTML_OUTPUT_DIR = OUTPUT_DIR / "html"           # interactive plotly figures
SLICE_OUTPUT_DIR = OUTPUT_DIR / "slices"        # per-ROI ortho slice PNGs

# Create all output directories up front (parents included; idempotent).
for _out_dir in (OUTPUT_DIR, HTML_OUTPUT_DIR, SLICE_OUTPUT_DIR):
    _out_dir.mkdir(parents=True, exist_ok=True)
del _out_dir

BANDS = ["Low_Beta", "High_Beta", "Low_Gamma", "High_Gamma"]
CONDITIONS = ['InPhase', 'OutofPhase']
N_TOP = 10  # number of strongest edges kept per condition and band

# Hex colour per condition, used for connection lines: InPhase = blue, OutofPhase = red.
CONDITION_COLORS = dict(
    InPhase='#1f77b4',
    OutofPhase='#d62728',
)

# ===== CACHED DATA LOADING =====
@lru_cache(maxsize=1)
def get_difumo_atlas():
    """Fetch the 512-component, 2 mm DiFuMo atlas once and memoise the result.

    Returns the object produced by ``datasets.fetch_atlas_difumo``; other code
    in this file reads its ``labels`` and ``maps`` entries.
    """
    logger.info("Loading DiFuMo 512 atlas...")
    fetched = datasets.fetch_atlas_difumo(dimension=512, resolution_mm=2)
    return fetched

@lru_cache(maxsize=1)
def get_roi_names():
    """Return the DiFuMo component names, sanitised for CSV output (memoised).

    If the atlas labels cannot be read, ``Component_000`` … ``Component_511``
    are used instead. Each name has commas replaced by semicolons, CR/LF
    replaced by spaces, and surrounding whitespace stripped.
    """
    try:
        raw_names = get_difumo_atlas().labels['difumo_names'].astype(str).tolist()
        logger.info(f"Loaded {len(raw_names)} ROI names from atlas")
    except Exception as e:
        logger.warning(f"Failed to load ROI names: {e}, using generic names")
        raw_names = [f"Component_{i:03d}" for i in range(512)]

    def _sanitise(label):
        # Commas would break naive CSV readers; newlines would break rows.
        return label.replace(',', ';').replace('\n', ' ').replace('\r', ' ').strip()

    return [_sanitise(label) for label in raw_names]

@lru_cache(maxsize=1)
def get_fsaverage_mesh():
    """Load nilearn's ``fsaverage4`` surface bundle once and memoise it.

    NOTE(review): fsaverage4 is a coarser mesh (fewer vertices) than nilearn's
    default fsaverage5, so this trades surface detail for rendering speed and
    smaller HTML output — confirm that is the intended trade-off.
    """
    logger.info("Loading fsaverage4 mesh (optimal detail/performance)...")
    surface_bundle = load_fsaverage(mesh='fsaverage4')
    return surface_bundle

# ===== ON-DEMAND ROI STATISTICAL MAP LOADER =====
@lru_cache(maxsize=32)
def get_roi_stat_map_by_index(roi_index):
    """Return ``(roi_map, cut_coords)`` for one DiFuMo component (memoised).

    ``roi_map`` is volume ``roi_index`` of the atlas ``maps`` image and
    ``cut_coords`` comes from ``find_xyz_cut_coords`` with an activation
    threshold of 0.1. On any failure during extraction, a warning is logged
    and ``(None, [0, 0, 0])`` is returned.
    """
    maps_source = get_difumo_atlas()["maps"]
    try:
        component_img = index_img(maps_source, roi_index)
        return component_img, find_xyz_cut_coords(component_img, activation_threshold=0.1)
    except Exception as e:
        logger.warning(f"Failed to load stat map for ROI index {roi_index}: {e}")
    return None, [0, 0, 0]

def get_needed_roi_stat_maps(roi_names_needed, all_roi_names):
    """
    Load stat maps and cut coordinates ONLY for ROIs that appear in top connections.

    All requested components are extracted from the 4D DiFuMo image with a
    single ``index_img`` call, so the atlas file is read once per call rather
    than once per ROI. If that batch extraction fails, each ROI falls back to
    ``get_roi_stat_map_by_index``.

    Args:
        roi_names_needed: iterable of ROI names (duplicates are ignored).
        all_roi_names: atlas ROI names, ordered like the atlas volumes.

    Returns:
        roi_stat_maps: dict[roi_name] -> 3D Nifti image
        roi_coords_dict: dict[roi_name] -> [x, y, z]
        ROIs whose name cannot be matched, or whose coordinates cannot be
        computed, are omitted from both dicts.
    """
    # Resolve names to atlas volume indices first (insertion order preserved).
    index_by_name = {}
    for roi_name in roi_names_needed:
        idx = find_roi_index_by_name(roi_name, all_roi_names)
        if idx is None:
            logger.warning(f"ROI name '{roi_name}' not found in atlas.")
            continue
        index_by_name[roi_name] = idx

    roi_stat_maps = {}
    roi_coords_dict = {}

    if not index_by_name:
        logger.info(f"Loaded stat maps for {len(roi_stat_maps)} relevant ROIs")
        return roi_stat_maps, roi_coords_dict

    # One read of the large 4D atlas for all needed components.
    try:
        atlas_maps = get_difumo_atlas()["maps"]
        subset_img = index_img(atlas_maps, list(index_by_name.values()))
        volumes = list(iter_img(subset_img))
    except Exception as e:
        logger.warning(f"Batch ROI extraction failed ({e}); falling back to per-ROI loading")
        volumes = None

    for position, (roi_name, idx) in enumerate(index_by_name.items()):
        if volumes is None:
            stat_map, coord = get_roi_stat_map_by_index(idx)
        else:
            stat_map = volumes[position]
            try:
                coord = find_xyz_cut_coords(stat_map, activation_threshold=0.1)
            except Exception as e:
                logger.warning(f"Failed to load stat map for ROI index {idx}: {e}")
                stat_map = None

        if stat_map is not None:
            roi_stat_maps[roi_name] = stat_map
            roi_coords_dict[roi_name] = coord

    logger.info(f"Loaded stat maps for {len(roi_stat_maps)} relevant ROIs")
    return roi_stat_maps, roi_coords_dict

# ===== ROI MATCHING =====
def find_roi_index_by_name(roi_name, roi_names):
    """Find the index of ``roi_name`` in ``roi_names`` with tiered fuzzy matching.

    Matching is case-insensitive and tried in this order, returning the first hit:

    1. exact name match;
    2. substring match in either direction (first candidate in list order wins);
    3. word-set Jaccard similarity, keeping the best score strictly above 0.3.

    Empty or whitespace-only names never match: ``""`` is a substring of every
    string, so without this guard it would resolve to an arbitrary ROI.

    Args:
        roi_name: name to look up.
        roi_names: candidate names, in atlas order.

    Returns:
        int index into ``roi_names``, or ``None`` when nothing matches.
    """
    query = roi_name.lower()
    if not query.strip():
        return None

    # Lower-case every candidate once instead of once per pass.
    lowered = [name.lower() for name in roi_names]

    # 1. Exact match.
    for i, name in enumerate(lowered):
        if name == query:
            return i

    # 2. Substring match either way; blank candidates would match everything.
    for i, name in enumerate(lowered):
        if name.strip() and (query in name or name in query):
            return i

    # 3. Word-overlap (Jaccard) similarity.
    query_words = set(query.split())
    best_match_idx = None
    best_score = 0.0
    for i, name in enumerate(lowered):
        name_words = set(name.split())
        if not name_words:
            continue
        score = len(query_words & name_words) / len(query_words | name_words)
        if score > best_score and score > 0.3:
            best_score = score
            best_match_idx = i

    return best_match_idx

# ===== SURFACE EXTRACTION =====
def extract_roi_surface_data(roi_map, fsaverage_mesh, percentile=70):
    """Project an ROI volume onto both fsaverage pial hemispheres.

    Args:
        roi_map: volume image to sample with ``surface.vol_to_surf``.
        fsaverage_mesh: surface bundle whose ``pial.parts['left'/'right']``
            meshes are sampled.
        percentile: percentile of the strictly positive texture values used as
            the display threshold (default 70, a deliberately soft cut).

    Returns:
        ``(texture, threshold)`` where ``texture`` is the left-hemisphere
        values followed by the right-hemisphere values, and ``threshold`` is
        0 when no value is positive. On any failure a warning is logged and
        ``(None, 0)`` is returned.
    """
    try:
        pial_parts = fsaverage_mesh.pial.parts
        left_texture = surface.vol_to_surf(roi_map, pial_parts['left'])
        right_texture = surface.vol_to_surf(roi_map, pial_parts['right'])

        # Left first, then right: matches the vertex stacking used for plotting.
        full_texture = np.concatenate([left_texture, right_texture])

        positive_values = full_texture[full_texture > 0]
        threshold = np.percentile(positive_values, percentile) if positive_values.size else 0

        return full_texture, threshold
    except Exception as e:
        logger.warning(f"Failed to extract surface data: {e}")
        return None, 0

# ===== MATRIX OPERATIONS =====
def load_matrix(condition, band):
    """Load the group-average connectivity matrix for one condition and band.

    Reads ``GROUP_OUTPUT_DIR / f"matrix_{condition}_{band}_group_avg.csv"``,
    using the first CSV column as the row index, and returns its values as a
    float32 NumPy array.

    Raises:
        FileNotFoundError: if the CSV file does not exist.
        ValueError: if the CSV cannot be read or converted to float32; the
            underlying exception is chained as ``__cause__``.
    """
    fname = GROUP_OUTPUT_DIR / f"matrix_{condition}_{band}_group_avg.csv"
    if not fname.exists():
        raise FileNotFoundError(f"Missing matrix: {fname}")

    try:
        matrix = pd.read_csv(fname, index_col=0).to_numpy(dtype=np.float32)
    except Exception as e:
        # Chain the cause so the original parser traceback is not lost.
        raise ValueError(f"Error loading matrix {fname}: {e}") from e

    logger.info(f"Loaded matrix {condition}_{band}: shape {matrix.shape}")
    return matrix

def extract_top_edges_vectorized(matrix, condition, band, roi_names, top_n=N_TOP):
    """Return the ``top_n`` strongest edges of a connectivity matrix.

    Only the strict upper triangle is considered (the diagonal and the
    lower triangle are ignored), and non-finite weights are dropped. Edges are
    *selected* by absolute weight but *ordered* by signed weight, descending,
    so strong negative edges appear at the end of the list.

    Args:
        matrix: 2D array at least ``len(roi_names)`` on each side.
        condition, band: labels copied into every edge record.
        roi_names: ROI names indexed like the matrix rows/columns.
        top_n: number of edges to keep. If it exceeds the number of valid
            edges, all valid edges are returned; ``top_n <= 0`` returns ``[]``.

    Returns:
        list of dicts with keys Marker1, Marker2, Stability, Band, Condition,
        ROI1_Index, ROI2_Index.
    """
    # Previously top_n == 0 sliced with [-0:] and returned *every* edge.
    if top_n <= 0:
        return []

    # Upper-triangle indices, diagonal excluded.
    i_indices, j_indices = np.triu_indices(len(roi_names), k=1)
    upper_triangle_values = matrix[i_indices, j_indices]

    valid_mask = np.isfinite(upper_triangle_values)
    if not np.any(valid_mask):
        logger.warning(f"No valid connections found for {condition}_{band}")
        return []

    valid_values = upper_triangle_values[valid_mask]
    valid_i = i_indices[valid_mask]
    valid_j = j_indices[valid_mask]

    # Clamp k: argpartition raises when kth is out of bounds, i.e. when more
    # edges are requested than exist.
    k = min(int(top_n), valid_values.size)
    top_indices = np.argpartition(np.abs(valid_values), -k)[-k:]

    # Order the selected edges by signed weight, strongest positive first.
    sorted_top_indices = top_indices[np.argsort(valid_values[top_indices])[::-1]]

    return [
        {
            'Marker1': roi_names[valid_i[idx]],
            'Marker2': roi_names[valid_j[idx]],
            'Stability': float(valid_values[idx]),
            'Band': band,
            'Condition': condition,
            'ROI1_Index': int(valid_i[idx]),
            'ROI2_Index': int(valid_j[idx]),
        }
        for idx in sorted_top_indices
    ]

# ===== SLICE PLOT FUNCTION — LIKE YOUR EXAMPLE IMAGE =====
def plot_roi_slices(roi_map, roi_name, output_path):
    """Save an ortho (3-view) slice plot of one ROI map as a 300-dpi PNG.

    Cut coordinates come from ``find_xyz_cut_coords`` (activation threshold
    0.1) and are truncated to ints. Failures are logged as warnings and not
    raised. The figure is always closed, even when saving fails, so repeated
    calls do not accumulate open matplotlib figures.

    Args:
        roi_map: 3D volume image to display.
        roi_name: used as the plot title and in log messages.
        output_path: destination PNG path.
    """
    display = None
    try:
        coords = find_xyz_cut_coords(roi_map, activation_threshold=0.1)
        x, y, z = int(coords[0]), int(coords[1]), int(coords[2])

        display = plotting.plot_roi(
            roi_map,
            cut_coords=(x, y, z),
            display_mode='ortho',  # sagittal, coronal and axial views
            black_bg=True,         # dark background for contrast
            cmap='RdYlBu_r',
            title=f"{roi_name}",
            draw_cross=True,
            annotate=True,
            threshold=0.1          # low threshold shows more of the ROI
        )

        display.savefig(output_path, dpi=300)
        logger.info(f"✅ Saved slice plot: {output_path}")

    except Exception as e:
        logger.warning(f"Failed to create slice plot for {roi_name}: {e}")
    finally:
        # Release the figure on both the success and the failure path.
        if display is not None:
            try:
                display.close()
            except Exception as close_err:
                logger.debug(f"Could not close slice figure for {roi_name}: {close_err}")

# ===== BEAUTIFUL 3D VISUALIZATION =====
def _hex_to_rgb(hex_color):
    """Convert a ``'#rrggbb'`` string into ``[r, g, b]`` integers (0-255)."""
    digits = hex_color.lstrip('#')
    return [int(digits[pos:pos + 2], 16) for pos in (0, 2, 4)]


def _quadratic_bezier(start, end, lift=20.0, n_points=20):
    """Sample a quadratic Bézier curve between two 3D points.

    The control point is the midpoint of ``start``/``end`` raised by ``lift``
    along z, which bows the line upward. Returns ``(xs, ys, zs)`` arrays of
    length ``n_points``.
    """
    start = np.asarray(start, dtype=float)
    end = np.asarray(end, dtype=float)
    control = (start + end) / 2
    control[2] += lift
    t = np.linspace(0, 1, n_points)[:, None]
    points = (1 - t) ** 2 * start + 2 * (1 - t) * t * control + t ** 2 * end
    return points[:, 0], points[:, 1], points[:, 2]


def _overlay_vertex_colors(texture, significant_mask, base_color):
    """Build one RGBA string per vertex for an ROI overlay.

    Significant vertices get ``base_color`` with alpha 0.4-0.9, scaled by the
    min-max normalised texture value within the significant set. If all
    significant values are equal, they all get alpha 0.9. Other vertices are
    fully transparent. Assumes ``significant_mask`` has at least one True.
    """
    normalised = np.zeros(len(texture))
    sig_values = texture[significant_mask]
    value_range = sig_values.max() - sig_values.min()
    if value_range > 0:
        normalised[significant_mask] = (sig_values - sig_values.min()) / value_range
    else:
        # Constant ROI: raw values could push alpha above 1, so use full weight.
        normalised[significant_mask] = 1.0
    r, g, b = base_color
    # Colour by the mask, not by ``value > 0``: the weakest significant vertex
    # normalises to 0 and must still be drawn.
    return [
        f'rgba({r}, {g}, {b}, {0.4 + val * 0.5})' if is_sig else 'rgba(0,0,0,0)'
        for val, is_sig in zip(normalised, significant_mask)
    ]


def create_stat_map_connectome_optimized(df_edges, roi_stat_maps, roi_coords_dict, roi_names, fsaverage_mesh,
                                        title="Top Connections", brain_opacity=0.2):
    """Build an interactive 3D connectome figure with ROI surface overlays.

    The figure contains:
      1. the fsaverage pial surface as a translucent grey mesh;
      2. one overlay trace per ROI that has a stat map. It is coloured by the
         condition of the first edge in ``df_edges`` that mentions the ROI and
         drawn only on the faces touching supra-threshold vertices;
      3. one curved-line trace per condition connecting ROI coordinates.

    Args:
        df_edges: DataFrame with Marker1, Marker2, Condition, Stability and
            Band columns.
        roi_stat_maps: dict ROI name -> volume image projected onto the surface.
        roi_coords_dict: dict ROI name -> [x, y, z] used as line endpoints.
        roi_names: unused; kept for call compatibility.
        fsaverage_mesh: surface bundle whose ``pial`` mesh is drawn.
        title: figure title.
        brain_opacity: alpha of the grey base surface.

    Returns:
        plotly.graph_objects.Figure
    """
    fig = go.Figure()

    # Stack both hemispheres into a single vertex/face array; right-hemisphere
    # face indices are offset by the number of left-hemisphere vertices.
    mesh = fsaverage_mesh.pial
    vertices_lh = mesh.parts['left'].coordinates
    vertices_rh = mesh.parts['right'].coordinates
    vertices = np.vstack([vertices_lh, vertices_rh])
    faces = np.vstack([mesh.parts['left'].faces,
                       mesh.parts['right'].faces + len(vertices_lh)])

    # Base brain surface — smooth, slightly transparent.
    fig.add_trace(go.Mesh3d(
        x=vertices[:, 0], y=vertices[:, 1], z=vertices[:, 2],
        i=faces[:, 0], j=faces[:, 1], k=faces[:, 2],
        color=f'rgba(180, 180, 180, {brain_opacity})',
        flatshading=False,
        name='Cortical Surface',
        hoverinfo='skip',
        showlegend=False,
        lighting=dict(
            ambient=0.5,
            diffuse=0.8,
            specular=0.3,
            roughness=0.2,
            fresnel=0.1
        ),
        lightposition=dict(x=500, y=0, z=1000)
    ))

    unique_regions = pd.unique(df_edges[['Marker1', 'Marker2']].values.ravel())
    logger.info(f"Creating overlays for {len(unique_regions)} ROIs")

    # Each ROI is coloured by the condition of the first edge (in df order)
    # that mentions it; built once instead of rescanning df_edges per ROI.
    roi_condition_map = {}
    for m1, m2, cond in df_edges[['Marker1', 'Marker2', 'Condition']].itertuples(index=False, name=None):
        roi_condition_map.setdefault(m1, cond)
        roi_condition_map.setdefault(m2, cond)

    # ROI overlays
    for roi_name in unique_regions:
        if roi_name not in roi_stat_maps:
            continue
        try:
            texture_data, threshold = extract_roi_surface_data(roi_stat_maps[roi_name], fsaverage_mesh)
            if texture_data is None or not np.any(texture_data > 0):
                continue
            significant_mask = texture_data > threshold
            if not np.any(significant_mask):
                continue

            roi_condition = roi_condition_map.get(roi_name, 'InPhase')
            base_color = _hex_to_rgb(CONDITION_COLORS[roi_condition])

            # Keep only faces touching a significant vertex. A full-cortex
            # overlay would capture hover events everywhere and bloat the HTML.
            roi_faces = faces[significant_mask[faces].any(axis=1)]

            fig.add_trace(go.Mesh3d(
                x=vertices[:, 0], y=vertices[:, 1], z=vertices[:, 2],
                i=roi_faces[:, 0], j=roi_faces[:, 1], k=roi_faces[:, 2],
                # Colours are per *vertex*. ``facecolor`` expects one entry per
                # face, so per-vertex colours there land on the wrong faces.
                vertexcolor=_overlay_vertex_colors(texture_data, significant_mask, base_color),
                name=f'{roi_name} ({roi_condition})',
                hovertext=f'{roi_name}<br>Condition: {roi_condition}',
                hoverinfo='text',
                showlegend=True,
                lighting=dict(ambient=0.6, diffuse=0.9)
            ))

        except Exception as e:
            logger.warning(f"Failed to create overlay for ROI {roi_name}: {e}")

    # Connection lines, one trace per condition; None separates the segments.
    for condition in CONDITIONS:
        condition_edges = df_edges[df_edges['Condition'] == condition]
        if condition_edges.empty:
            continue

        edge_x, edge_y, edge_z = [], [], []
        edge_hover_texts = []

        for _, row in condition_edges.iterrows():
            m1, m2 = row['Marker1'], row['Marker2']
            if m1 not in roi_coords_dict or m2 not in roi_coords_dict:
                continue

            curve_x, curve_y, curve_z = _quadratic_bezier(roi_coords_dict[m1], roi_coords_dict[m2])
            edge_x.extend(curve_x.tolist() + [None])
            edge_y.extend(curve_y.tolist() + [None])
            edge_z.extend(curve_z.tolist() + [None])

            hover_text = (f"<b>{m1} ↔ {m2}</b><br>"
                          f"Weight: {row['Stability']:.4f}<br>"
                          f"Band: {row['Band']}<br>"
                          f"Condition: {condition}")
            edge_hover_texts.extend([hover_text] * len(curve_x) + [""])

        if edge_x:
            fig.add_trace(go.Scatter3d(
                x=edge_x, y=edge_y, z=edge_z,
                mode='lines',
                line=dict(color=CONDITION_COLORS[condition], width=5),
                opacity=0.95,
                name=f'{condition} Connections',
                showlegend=True,
                hovertext=edge_hover_texts,
                hoverinfo='text'
            ))

    # Publication-ready layout
    fig.update_layout(
        title=dict(
            text=f'<b>{title}</b>',
            x=0.5, y=0.95,
            font=dict(size=18, family="Arial", color='black')
        ),
        paper_bgcolor='white',
        plot_bgcolor='white',
        scene=dict(
            bgcolor='white',
            xaxis=dict(visible=False),
            yaxis=dict(visible=False),
            zaxis=dict(visible=False),
            aspectmode='data',
            camera=dict(
                eye=dict(x=1.5, y=-1.5, z=0.8),
                center=dict(x=0, y=0, z=0),
                up=dict(x=0, y=0, z=1)
            )
        ),
        legend=dict(
            orientation="v",
            yanchor="top", y=0.98,
            xanchor="left", x=0.01,
            bgcolor="rgba(255, 255, 255, 0.95)",
            bordercolor="lightgray",
            borderwidth=1,
            font=dict(size=11, family="Arial"),
            title=dict(text="<b>Legend</b>", font=dict(size=12, weight='bold'))
        ),
        margin=dict(l=20, r=20, b=20, t=60),
        height=800,
        width=1200,
        hoverlabel=dict(
            bgcolor="white",
            font_size=12,
            font_family="Arial"
        )
    )

    return fig

# ===== MAIN PROCESSING =====
def _save_roi_slice_plots(roi_names_to_plot, roi_stat_maps, band):
    """Write one ortho slice PNG per ROI into SLICE_OUTPUT_DIR; ROIs without a map are skipped."""
    for roi_name in roi_names_to_plot:
        roi_map = roi_stat_maps.get(roi_name)
        if roi_map is None:
            continue
        safe_name = roi_name.replace(" ", "_").replace("/", "_").replace("\\", "_")
        plot_roi_slices(roi_map, roi_name, SLICE_OUTPUT_DIR / f"{safe_name}_{band}.png")


def process_band_with_stat_maps(band, all_roi_names):
    """Process one frequency band: top-edge CSV, interactive 3D plot, and slice plots.

    Steps:
      1. load the InPhase and OutofPhase matrices and check they are square
         with ``len(all_roi_names)`` rows;
      2. keep the top ``N_TOP`` edges per condition and save them to CSV;
      3. load stat maps for the ROIs involved;
      4. write the 3D HTML figure;
      5. write one slice PNG per ROI. This runs even if step 4 failed.

    Returns:
        True when the 3D figure was written. False when a step fails, no
        edges/maps are available, or the 3D figure fails.
    """
    logger.info(f"Processing band: {band}")

    try:
        matrix_in = load_matrix("InPhase", band)
        matrix_out = load_matrix("OutofPhase", band)

        expected_size = len(all_roi_names)
        expected_shape = (expected_size, expected_size)
        if matrix_in.shape != expected_shape or matrix_out.shape != expected_shape:
            raise ValueError(f"Matrix size mismatch for {band}. Expected {expected_size}x{expected_size}")

        top_in = extract_top_edges_vectorized(matrix_in, 'InPhase', band, all_roi_names, N_TOP)
        top_out = extract_top_edges_vectorized(matrix_out, 'OutofPhase', band, all_roi_names, N_TOP)

        if not top_in and not top_out:
            logger.warning(f"No valid connections found for {band}")
            return False

        df_edges = pd.DataFrame(top_in + top_out)

        csv_path = OUTPUT_DIR / f"top_{N_TOP}_per_condition_{band}_statmaps.csv"
        df_edges.to_csv(csv_path, index=False)
        logger.info(f"Saved CSV: {csv_path}")

        unique_rois = pd.unique(df_edges[['Marker1', 'Marker2']].values.ravel())
        logger.info(f"Identified {len(unique_rois)} unique ROIs for {band}")

        roi_stat_maps, roi_coords_dict = get_needed_roi_stat_maps(unique_rois, all_roi_names)
        if not roi_stat_maps:
            logger.error(f"No statistical maps could be loaded for active ROIs in {band}. Skipping plot.")
            return False

        plot_3d_ok = True
        try:
            fig_3d = create_stat_map_connectome_optimized(
                df_edges=df_edges,
                roi_stat_maps=roi_stat_maps,
                roi_coords_dict=roi_coords_dict,
                roi_names=all_roi_names,
                fsaverage_mesh=get_fsaverage_mesh(),
                title=f"Top {N_TOP} Connections: {band} Band"
            )
            html_path = HTML_OUTPUT_DIR / f"top_{N_TOP}_per_condition_{band}_statmaps.html"
            fig_3d.write_html(html_path, include_plotlyjs='cdn')
            logger.info(f"Saved interactive 3D plot: {html_path}")
            # Drop the (large) figure before rendering the slice PNGs.
            del fig_3d
        except Exception as e:
            logger.exception(f"Failed to generate plots for {band}: {e}")
            plot_3d_ok = False

        # Slice plots are independent of the 3D figure, so produce them anyway.
        _save_roi_slice_plots(unique_rois, roi_stat_maps, band)

        return plot_3d_ok

    except Exception as e:
        logger.exception(f"Failed to process {band}: {e}")
        return False
    finally:
        # Close any matplotlib figures left open by nilearn on error paths.
        plt.close('all')
        gc.collect()

def main():
    """Run the visualisation pipeline for every band in ``BANDS``.

    Loads the ROI names once, calls ``process_band_with_stat_maps`` for each
    band, then logs a summary of where outputs were written.

    Returns:
        True if at least one band was processed successfully. False if none
        were, or if an unexpected error aborted the run (logged with traceback).
    """
    logger.info("🚀 Starting brain connectivity visualization...")

    try:
        roi_names = get_roi_names()
        logger.info(f"✅ Loaded {len(roi_names)} ROI names")

        successful_bands = 0
        for band in BANDS:
            if process_band_with_stat_maps(band, roi_names):
                successful_bands += 1
                gc.collect()

        logger.info("\n🎉 Processing complete!")
        logger.info(f"   • Successfully processed: {successful_bands}/{len(BANDS)} bands")
        logger.info(f"   • Data saved to: {OUTPUT_DIR}")
        logger.info(f"   • Interactive 3D plots: {HTML_OUTPUT_DIR}")
        logger.info(f"   • Slice plots (like your example): {SLICE_OUTPUT_DIR}")
        logger.info("   • Open HTML files in browser or use PNGs in papers")

        if successful_bands == 0:
            logger.error("No bands processed successfully.")
            return False

        return True

    except Exception as e:
        logger.exception(f"Fatal error: {e}")
        return False
    finally:
        gc.collect()

if __name__ == "__main__":
    # Run the full pipeline. main() returns True when at least one band was
    # processed; the result stays bound to the module-level name `success`
    # (e.g. for inspection from a later notebook cell).
    success = main()
    # Let Python exit naturally — don't force exit()

2025-09-20 00:12:32,101 - INFO - 🚀 Starting brain connectivity visualization...
2025-09-20 00:12:32,103 - INFO - Loading DiFuMo 512 atlas...


2025-09-20 00:12:32,210 - INFO - Loaded 512 ROI names from atlas
2025-09-20 00:12:32,210 - INFO - ✅ Loaded 512 ROI names
2025-09-20 00:12:32,211 - INFO - Processing band: Low_Beta
2025-09-20 00:12:32,235 - INFO - Loaded matrix InPhase_Low_Beta: shape (512, 512)
2025-09-20 00:12:32,265 - INFO - Loaded matrix OutofPhase_Low_Beta: shape (512, 512)
2025-09-20 00:12:32,271 - INFO - Saved CSV: /home/jaizor/jaizor/xtra/derivatives/group/stat/top_10_per_condition_Low_Beta_statmaps.csv
2025-09-20 00:12:32,272 - INFO - Identified 30 unique ROIs for Low_Beta
2025-09-20 00:14:44,538 - INFO - Loaded stat maps for 30 relevant ROIs
2025-09-20 00:14:44,539 - INFO - Loading fsaverage4 mesh (optimal detail/performance)...


2025-09-20 00:14:46,177 - INFO - Creating overlays for 30 ROIs
2025-09-20 00:14:47,081 - INFO - Saved interactive 3D plot: /home/jaizor/jaizor/xtra/derivatives/group/stat/html/top_10_per_condition_Low_Beta_statmaps.html
2025-09-20 00:14:48,513 - INFO - ✅ Saved slice plot: /home/jaizor/jaizor/xtra/derivatives/group/stat/slices/Superior_frontal_sulcus_middle_LH_Low_Beta.png
2025-09-20 00:14:49,440 - INFO - ✅ Saved slice plot: /home/jaizor/jaizor/xtra/derivatives/group/stat/slices/Cerebrospinal_fluid_(between_postcentral_sulcus_and_skull_RH)_Low_Beta.png
2025-09-20 00:14:50,396 - INFO - ✅ Saved slice plot: /home/jaizor/jaizor/xtra/derivatives/group/stat/slices/Cerebellum_Crus_I_anterior_RH_Low_Beta.png
2025-09-20 00:14:51,344 - INFO - ✅ Saved slice plot: /home/jaizor/jaizor/xtra/derivatives/group/stat/slices/Cuneus_anterior_Low_Beta.png
2025-09-20 00:14:52,315 - INFO - ✅ Saved slice plot: /home/jaizor/jaizor/xtra/derivatives/group/stat/slices/Precuneus_mid-posterior_Low_Beta.png
2025-09-2