In [None]:
# Visualize top connections


import os
from pathlib import Path
import numpy as np
import pandas as pd
from nilearn import datasets, image
from nilearn.plotting.find_cuts import find_xyz_cut_coords
from nilearn.image import iter_img
from nilearn.datasets import load_fsaverage
import plotly.graph_objects as go
import warnings
from concurrent.futures import ThreadPoolExecutor
from functools import lru_cache
import logging

# Silence every warning category for the whole process. NOTE(review): this also
# hides deprecation/runtime warnings from nilearn, pandas and numpy.
warnings.filterwarnings("ignore")

# Set up logging: timestamped INFO-level records, shared by every function below
# through the module-level `logger`.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# ===== CONFIGURATION =====
# NOTE(review): absolute, machine-specific path; adjust before running elsewhere.
PROJECT_BASE = '/home/jaizor/jaizor/xtra'
# Directory load_matrix() reads the group-average connectivity CSVs from.
GROUP_OUTPUT_DIR = Path(PROJECT_BASE) / "derivatives" / "group"
# Destination for the per-band top-edge CSV files.
OUTPUT_DIR = GROUP_OUTPUT_DIR / "top_connections"
# Destination for the per-band interactive HTML figures.
HTML_OUTPUT_DIR = OUTPUT_DIR / "html"

# Created at import time; OUTPUT_DIR is made first (with parents) so the
# html sub-directory only needs its own level created.
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
HTML_OUTPUT_DIR.mkdir(exist_ok=True)

# Frequency-band and condition labels; both are interpolated into the input
# file names ("matrix_{condition}_{band}_group_avg.csv").
BANDS = ["Theta", "Alpha", "Low_Beta", "High_Beta", "Low_Gamma", "High_Gamma"]
CONDITIONS = ['InPhase', 'OutofPhase']
# Number of edges kept per condition and band.
N_TOP = 10

# Color scheme: line colour used for each condition's edges in the 3D figure.
CONDITION_COLORS = {
    'InPhase': '#1f77b4',     # Blue
    'OutofPhase': '#d62728'   # Red
}

# ===== CACHED DATA LOADING =====
@lru_cache(maxsize=1)
def get_difumo_atlas():
    """Return the 512-component DiFuMo atlas at 2 mm resolution.

    The result of nilearn's fetcher is memoised with ``lru_cache`` so the
    fetch is performed at most once per process.
    """
    logger.info("Loading DiFuMo 512 atlas...")
    atlas = datasets.fetch_atlas_difumo(dimension=512, resolution_mm=2)
    return atlas

@lru_cache(maxsize=1)
def get_roi_coords():
    """Compute one cut coordinate per atlas component and return them stacked.

    Each component map of the DiFuMo atlas is passed to
    ``find_xyz_cut_coords`` with ``activation_threshold=0.1``. A component
    whose computation raises is logged and given the placeholder
    ``[0, 0, 0]`` so the output keeps one row per component. The result is
    memoised, so the loop runs at most once per process.

    Returns:
        numpy array built from the per-component coordinates, in atlas order.
    """
    logger.info("Computing ROI coordinates...")
    component_maps = get_difumo_atlas()["maps"]

    expected_count = 512  # only used in the progress message
    centers = []

    for done, component in enumerate(iter_img(component_maps), start=1):
        try:
            centers.append(find_xyz_cut_coords(component, activation_threshold=0.1))
        except Exception as e:
            # Best effort: keep going and mark this component with the origin.
            logger.warning(f"Failed to compute coordinates for ROI {done - 1}: {e}")
            centers.append([0, 0, 0])  # Fallback coordinates

        # Report progress every hundred components.
        if done % 100 == 0:
            logger.info(f"Progress: {done}/{expected_count} ROI coordinates computed")

    return np.array(centers)

@lru_cache(maxsize=1)
def get_roi_names():
    """Return the atlas component names, sanitised for CSV output.

    Names are read from the ``difumo_names`` field of the atlas labels. If
    that fails for any reason, generic ``Component_000`` ... ``Component_511``
    names are used instead. Commas become semicolons, line breaks become
    spaces, and surrounding whitespace is stripped. The result is memoised.
    """
    try:
        labels = get_difumo_atlas().labels
        raw_names = labels['difumo_names'].astype(str).tolist()
        logger.info(f"Loaded {len(raw_names)} ROI names from atlas")
    except Exception as e:
        logger.warning(f"Failed to load ROI names: {e}, using generic names")
        raw_names = [f"Component_{i:03d}" for i in range(512)]

    # Commas and line breaks would break the CSV rows written later.
    return [
        raw.replace(',', ';').replace('\n', ' ').replace('\r', ' ').strip()
        for raw in raw_names
    ]

@lru_cache(maxsize=1)
def get_fsaverage_mesh():
    """Return the ``fsaverage5`` surface meshes, loaded at most once per process."""
    logger.info("Loading fsaverage mesh...")
    meshes = load_fsaverage(mesh='fsaverage5')
    return meshes

# ===== OPTIMIZED MATRIX OPERATIONS =====
def load_matrix(condition, band):
    """Load one group-average connectivity matrix from CSV.

    The file is expected at
    ``GROUP_OUTPUT_DIR / "matrix_{condition}_{band}_group_avg.csv"``; its
    first column is used as the row index and dropped from the values.

    Args:
        condition: Condition label used in the file name.
        band: Frequency-band label used in the file name.

    Returns:
        The matrix values as a ``float32`` numpy array.

    Raises:
        FileNotFoundError: If the CSV file does not exist.
        ValueError: If the file exists but cannot be read or converted; the
            original exception is attached as ``__cause__``.
    """
    fname = GROUP_OUTPUT_DIR / f"matrix_{condition}_{band}_group_avg.csv"
    if not fname.exists():
        raise FileNotFoundError(f"Missing matrix: {fname}")

    # Keep the try body to the one call that can actually fail on bad input.
    try:
        matrix = pd.read_csv(fname, index_col=0).values.astype(np.float32)
    except Exception as e:
        # Chain explicitly so the underlying parser/conversion error stays visible.
        raise ValueError(f"Error loading matrix {fname}: {e}") from e

    logger.info(f"Loaded matrix {condition}_{band}: shape {matrix.shape}")
    return matrix

def extract_top_edges_vectorized(matrix, condition, band, roi_names, top_n=N_TOP):
    """Return the strongest upper-triangle edges of a connectivity matrix.

    Only entries strictly above the diagonal are considered, and non-finite
    values (NaN/inf) are ignored. Edges are *selected* by largest absolute
    value and then *ordered* by signed value, descending.

    Args:
        matrix: Square numpy array indexed consistently with ``roi_names``.
        condition: Condition label copied into each edge record.
        band: Frequency-band label copied into each edge record.
        roi_names: Sequence of ROI names; its length sets the matrix size used.
        top_n: Maximum number of edges to return. Values larger than the
            number of finite edges are clamped; zero or negative yields ``[]``.

    Returns:
        A list of dicts with keys ``Marker1``, ``Marker2``, ``Stability``,
        ``Band``, ``Condition``, ``ROI1_Index`` and ``ROI2_Index``. The list is
        empty when there is no finite edge or ``top_n`` is not positive.
    """
    n_rois = len(roi_names)

    # Upper triangle without the diagonal: each undirected edge appears once.
    i_indices, j_indices = np.triu_indices(n_rois, k=1)
    upper_triangle_values = matrix[i_indices, j_indices]

    # Drop NaN/inf entries before ranking.
    valid_mask = np.isfinite(upper_triangle_values)
    if not np.any(valid_mask):
        logger.warning(f"No valid connections found for {condition}_{band}")
        return []

    valid_values = upper_triangle_values[valid_mask]
    valid_i = i_indices[valid_mask]
    valid_j = j_indices[valid_mask]

    # argpartition rejects a kth outside [-size, size - 1], and a slice of
    # [-0:] would return every element, so clamp and short-circuit first.
    n_select = min(int(top_n), valid_values.size)
    if n_select <= 0:
        return []

    # Pick the n_select largest magnitudes (unordered at this point).
    top_indices = np.argpartition(np.abs(valid_values), -n_select)[-n_select:]

    # Order the selection by signed value, largest first.
    sorted_top_indices = top_indices[np.argsort(valid_values[top_indices])[::-1]]

    edges = []
    for idx in sorted_top_indices:
        i, j = valid_i[idx], valid_j[idx]
        edges.append({
            'Marker1': roi_names[i],
            'Marker2': roi_names[j],
            'Stability': float(valid_values[idx]),  # plain float, not a numpy scalar
            'Band': band,
            'Condition': condition,
            'ROI1_Index': int(i),
            'ROI2_Index': int(j)
        })

    return edges

# ===== OPTIMIZED COORDINATE MATCHING =====
def _normalize_region_name(name):
    """Lower-case *name* and replace hyphens/underscores with spaces."""
    return name.lower().replace('-', ' ').replace('_', ' ')


def create_region_lookup(roi_names):
    """Build a name/word -> index lookup for ROI matching.

    Two kinds of keys are produced:

    * the original name and its normalised form map to the ROI index
      (an ``int``; if several ROIs share a name, the last one wins);
    * every word longer than two characters of the normalised name maps to
      the ``list`` of ROI indices whose name contains that word.

    When a word is also a full (normalised or original) name, the entry is a
    list whose first element is the exactly-named ROI, followed by the other
    ROIs containing the word, so ``entry[0]`` is always the exact match.

    Args:
        roi_names: Sequence of ROI name strings.

    Returns:
        dict mapping ``str`` to ``int`` or ``list[int]`` as described above.
    """
    exact = {}
    by_word = {}
    for i, name in enumerate(roi_names):
        clean_name = _normalize_region_name(name)
        exact[name] = i
        exact[clean_name] = i

        for word in clean_name.split():
            if len(word) > 2:  # Skip very short words
                bucket = by_word.setdefault(word, [])
                if i not in bucket:
                    bucket.append(i)

    # Merge afterwards so a full-name key can never wipe out the word
    # candidates collected from ROIs seen earlier.
    lookup = dict(exact)
    for word, indices in by_word.items():
        if word in exact:
            owner = exact[word]
            lookup[word] = [owner] + [idx for idx in indices if idx != owner]
        else:
            lookup[word] = indices

    return lookup


def get_coordinates_for_regions_optimized(regions, roi_coords, roi_names):
    """Resolve region names to ROI coordinates, with fuzzy fallbacks.

    For each region the following are tried in order:

    1. the name exactly as given;
    2. the normalised name (lower-case, hyphens/underscores as spaces);
    3. the ROI sharing at least one word with the region that has the highest
       Jaccard overlap of word sets, accepted only if the score exceeds 0.2.

    If nothing matches, a warning is logged and the region is placed at
    ``[0, 0, 0]`` under its own name.

    Args:
        regions: Iterable of region name strings to resolve.
        roi_coords: Indexable of per-ROI coordinates, aligned with ``roi_names``.
        roi_names: Sequence of ROI name strings.

    Returns:
        ``(coordinates, matched_labels)``: a numpy array with one row per
        region and the list of ROI names (or the unmatched input name) the
        rows were taken from.
    """
    lookup = create_region_lookup(roi_names)
    # Candidate word sets use the same normalisation as the query, so
    # "Middle-Frontal" and "middle frontal" compare equal word by word.
    roi_word_sets = [set(_normalize_region_name(name).split()) for name in roi_names]

    coordinates = []
    matched_labels = []

    for region in regions:
        clean_region = _normalize_region_name(region)

        # Exact name first, then the normalised form.
        entry = lookup.get(region)
        if entry is None:
            entry = lookup.get(clean_region)
        if entry is not None:
            idx = entry if isinstance(entry, int) else entry[0]
            coordinates.append(roi_coords[idx])
            matched_labels.append(roi_names[idx])
            continue

        # Word-based matching scored by Jaccard similarity of word sets.
        region_words = clean_region.split()
        region_word_set = set(region_words)
        best_match_idx = None
        best_score = 0

        for word in region_words:
            candidates = lookup.get(word)
            if not isinstance(candidates, list):
                continue
            for candidate_idx in candidates:
                candidate_words = roi_word_sets[candidate_idx]
                union = region_word_set | candidate_words
                if union:
                    score = len(region_word_set & candidate_words) / len(union)
                    # Strict '>' keeps the first candidate on ties.
                    if score > best_score:
                        best_score = score
                        best_match_idx = candidate_idx

        if best_match_idx is not None and best_score > 0.2:
            coordinates.append(roi_coords[best_match_idx])
            matched_labels.append(roi_names[best_match_idx])
        else:
            # Fallback to origin with warning
            logger.warning(f"No match found for region: {region}")
            coordinates.append([0, 0, 0])
            matched_labels.append(region)

    return np.array(coordinates), matched_labels

# ===== SCIENTIFIC ABBREVIATIONS =====
# Phrase -> short label. Iteration follows insertion order, so a longer phrase
# listed before its shorter prefix (e.g. 'middle frontal gyrus' before
# 'middle frontal') is the one that wins in abbreviate_region_name().
SCIENTIFIC_ABBREVIATIONS = {
    'ventromedial prefrontal cortex': 'vmPFC',
    'ventromedial prefrontal': 'vmPFC',
    'middle frontal gyrus': 'DLPFC',
    'middle frontal': 'MFG',
    'precentral gyrus': 'M1',
    'precentral': 'M1',
    'superior frontal gyrus': 'SFG',
    'superior frontal': 'SFG',
    'central sulcus': 'CS',
    'globus pallidus': 'GP',
    'amygdala': 'Amyg',
    'caudate': 'Caud',
    'cerebellum': 'Cereb',
    'midbrain': 'MB',
    'hippocampus': 'Hipp',
    'thalamus': 'Thal',
    'putamen': 'Put',
    'insula': 'Ins'
}

def abbreviate_region_name(name, max_length=15):
    """Return a short display label for a region name.

    Rules, applied in order:

    1. If the lower-cased name contains a phrase from
       ``SCIENTIFIC_ABBREVIATIONS``, return that abbreviation (first hit in
       dictionary order; ``max_length`` is not applied to it).
    2. If the name already fits in ``max_length`` characters, return it as is.
    3. For multi-word names, return the upper-cased initials if they fit.
    4. Otherwise truncate to ``max_length`` characters, ending in ``'..'``.
       When ``max_length`` is below 2 there is no room for the marker, so the
       name is simply cut to ``max_length`` characters (empty if not positive).

    Args:
        name: Region name to shorten.
        max_length: Maximum length for labels produced by rules 2-4.

    Returns:
        The abbreviated label.
    """
    name_lower = name.lower()

    # Check for scientific abbreviations
    for full_term, abbrev in SCIENTIFIC_ABBREVIATIONS.items():
        if full_term in name_lower:
            return abbrev

    if len(name) <= max_length:
        return name

    # Try initials first: they keep a hint of every word.
    words = name.split()
    if len(words) > 1:
        initials = ''.join(word[0].upper() for word in words)
        if len(initials) <= max_length:
            return initials

    # A negative slice bound would cut from the end and produce a result
    # longer than max_length, so handle tiny limits explicitly.
    if max_length < 2:
        return name[:max(max_length, 0)]

    # Simple truncation
    return name[:max_length - 2] + '..'

# ===== OPTIMIZED PLOTLY VISUALIZATION =====
def create_optimized_connectome(df_edges, roi_coords, roi_names, fsaverage_mesh, 
                              title="Top Connections by Condition", brain_opacity=0.1):
    """Build an interactive 3D figure of edges over a translucent cortical surface.

    The figure contains, in order: one mesh trace with both pial hemispheres,
    one marker+text trace with a node per region named in ``df_edges``, and
    one line trace per condition in ``CONDITIONS`` that has edges.

    Nodes and edges are keyed by the region names exactly as they appear in
    ``df_edges``. Coordinates come from
    ``get_coordinates_for_regions_optimized``; when that matcher resolves a
    region to a differently-named atlas ROI, the node keeps the ``df_edges``
    name (so its edges and statistics stay attached) and the hover text
    additionally shows the atlas label used.

    Args:
        df_edges: DataFrame with columns ``Marker1``, ``Marker2``,
            ``Stability``, ``Band`` and ``Condition``.
        roi_coords: Per-ROI coordinates aligned with ``roi_names``.
        roi_names: Sequence of ROI name strings.
        fsaverage_mesh: Object exposing ``.pial.parts['left'|'right']`` with
            ``coordinates`` and ``faces`` arrays.
        title: Figure title (rendered in bold).
        brain_opacity: Alpha channel of the surface colour.

    Returns:
        The assembled ``plotly.graph_objects.Figure``.
    """
    fig = go.Figure()

    # Prepare brain surface: stack both hemispheres into one mesh.
    mesh = fsaverage_mesh.pial
    vertices_lh = mesh.parts['left'].coordinates
    vertices_rh = mesh.parts['right'].coordinates
    vertices = np.vstack([vertices_lh, vertices_rh])

    # Right-hemisphere faces index into the stacked array, so shift them past
    # the left-hemisphere vertices.
    faces_lh = mesh.parts['left'].faces
    faces_rh = mesh.parts['right'].faces + len(vertices_lh)
    faces = np.vstack([faces_lh, faces_rh])

    # Add brain surface
    fig.add_trace(go.Mesh3d(
        x=vertices[:, 0], y=vertices[:, 1], z=vertices[:, 2],
        i=faces[:, 0], j=faces[:, 1], k=faces[:, 2],
        color=f'rgba(200, 200, 200, {brain_opacity})',
        flatshading=True,
        name='Cortical Surface',
        hoverinfo='skip',
        showlegend=False,
        lighting=dict(ambient=0.4, diffuse=0.7, fresnel=0.1, specular=0.2, roughness=0.3),
        lightposition=dict(x=100, y=100, z=200)
    ))

    # Get unique regions and their coordinates
    unique_regions = pd.unique(df_edges[['Marker1', 'Marker2']].values.ravel())
    marker_coords, matched_labels = get_coordinates_for_regions_optimized(
        unique_regions, roi_coords, roi_names)

    # Key by the names used in df_edges, not by the matcher's labels: the two
    # differ whenever a fuzzy match was used, and the edge loop below looks
    # regions up by their df_edges name.
    region_to_coord = dict(zip(unique_regions, marker_coords))

    # Calculate node properties: degree, strongest |weight|, conditions seen.
    node_stats = {}
    for _, row in df_edges.iterrows():
        for marker in [row['Marker1'], row['Marker2']]:
            if marker not in node_stats:
                node_stats[marker] = {'connections': 0, 'max_stability': 0, 'conditions': set()}
            node_stats[marker]['connections'] += 1
            node_stats[marker]['max_stability'] = max(node_stats[marker]['max_stability'], abs(row['Stability']))
            node_stats[marker]['conditions'].add(row['Condition'])

    # Prepare node visualization
    node_coords = []
    node_sizes = []
    node_colors = []
    node_labels = []
    node_hover_texts = []

    max_connections = max([stats['connections'] for stats in node_stats.values()]) if node_stats else 1

    for region, matched in zip(unique_regions, matched_labels):
        node_coords.append(region_to_coord[region])

        stats = node_stats.get(region, {'connections': 0, 'max_stability': 0, 'conditions': set()})

        # Node size based on connections (8 for isolated, up to 20 for the hub)
        size_factor = stats['connections'] / max_connections
        node_sizes.append(8 + 12 * size_factor)

        # Node color based on stability: grey level 100..255, clamped at 255
        stability_intensity = min(255, int(100 + 155 * stats['max_stability']))
        node_colors.append(f'rgb({stability_intensity}, {stability_intensity}, {stability_intensity})')

        # Labels and hover
        node_labels.append(abbreviate_region_name(region))

        conditions_str = ', '.join(sorted(stats['conditions']))
        hover_text = (f"<b>{region}</b><br>"
                     f"Connections: {stats['connections']}<br>"
                     f"Max Stability: {stats['max_stability']:.3f}<br>"
                     f"Conditions: {conditions_str}")
        if matched != region:
            # Make it visible that the position was borrowed from another label.
            hover_text += f"<br>Atlas match: {matched}"
        node_hover_texts.append(hover_text)

    if node_coords:
        node_coords = np.array(node_coords)

        # Add nodes
        fig.add_trace(go.Scatter3d(
            x=node_coords[:, 0], y=node_coords[:, 1], z=node_coords[:, 2],
            mode='markers+text',
            marker=dict(
                size=node_sizes, 
                color=node_colors,
                line=dict(width=1, color='black'),
                opacity=0.9
            ),
            text=node_labels,
            textfont=dict(size=14, color='black', family="Arial Bold"),
            textposition='top center',
            hovertext=node_hover_texts,
            hoverinfo='text',
            name='Brain Regions',
            showlegend=False
        ))

    # Add edges grouped by condition: one trace per condition, with segments
    # separated by None so Plotly does not join consecutive edges.
    for condition in CONDITIONS:
        condition_edges = df_edges[df_edges['Condition'] == condition]
        if condition_edges.empty:
            continue

        edge_x, edge_y, edge_z = [], [], []
        edge_hover_texts = []

        for _, row in condition_edges.iterrows():
            marker1, marker2 = row['Marker1'], row['Marker2']

            if marker1 in region_to_coord and marker2 in region_to_coord:
                coord1 = region_to_coord[marker1]
                coord2 = region_to_coord[marker2]

                edge_x.extend([coord1[0], coord2[0], None])
                edge_y.extend([coord1[1], coord2[1], None])
                edge_z.extend([coord1[2], coord2[2], None])

                hover_text = (f"<b>{marker1} ↔ {marker2}</b><br>"
                             f"Weight: {row['Stability']:.4f}<br>"
                             f"Band: {row['Band']}<br>"
                             f"Condition: {condition}")
                # One hover entry per point, including the None separator.
                edge_hover_texts.extend([hover_text, hover_text, ""])

        if edge_x:  # Only add trace if there are valid edges
            fig.add_trace(go.Scatter3d(
                x=edge_x, y=edge_y, z=edge_z,
                mode='lines',
                line=dict(color=CONDITION_COLORS[condition], width=5),
                opacity=0.8,
                name=f'{condition} (Top {N_TOP})',
                showlegend=True,
                hovertext=edge_hover_texts,
                hoverinfo='text'
            ))

    # Layout: white background, hidden axes, data-proportional aspect.
    fig.update_layout(
        title=dict(
            text=f'<b>{title}</b>',
            x=0.5, y=0.95,
            font=dict(size=16, family="Arial", color='black')
        ),
        paper_bgcolor='white',
        plot_bgcolor='white',
        scene=dict(
            bgcolor='white',
            xaxis=dict(visible=False),
            yaxis=dict(visible=False),
            zaxis=dict(visible=False),
            aspectmode='data',
            camera=dict(eye=dict(x=1.3, y=1.3, z=0.8))
        ),
        legend=dict(
            orientation="v",
            yanchor="top", y=0.98,
            xanchor="left", x=0.02,
            bgcolor="rgba(255, 255, 255, 0.9)",
            bordercolor="black",
            borderwidth=1,
            font=dict(size=11, family="Arial"),
            title=dict(text="<b>Conditions</b>", font=dict(size=12))
        ),
        margin=dict(l=10, r=10, b=10, t=50),
        height=800,
        width=1000
    )

    return fig

# ===== MAIN PROCESSING FUNCTIONS =====
def process_band(band, roi_coords, roi_names):
    """Extract, save and plot the top edges of one frequency band.

    For every condition in ``CONDITIONS`` the group matrix is loaded and
    checked against ``len(roi_names)``. The top ``N_TOP`` edges of each are
    combined, written to ``OUTPUT_DIR`` as CSV, and rendered to an HTML
    figure in ``HTML_OUTPUT_DIR``.

    Args:
        band: Frequency-band label.
        roi_coords: Per-ROI coordinates aligned with ``roi_names``.
        roi_names: Sequence of ROI name strings.

    Returns:
        ``True`` only if both the CSV and the HTML figure were written.
        ``False`` if loading/validation failed, no edge was found, or the
        plot could not be generated (the CSV may already exist in that last
        case). Failures are logged with their traceback, never raised.
    """
    logger.info(f"Processing band: {band}")

    try:
        # Load matrices for every configured condition, in CONDITIONS order.
        matrices = {condition: load_matrix(condition, band) for condition in CONDITIONS}

        # Validate matrix dimensions
        expected_size = len(roi_names)
        expected_shape = (expected_size, expected_size)
        if any(matrix.shape != expected_shape for matrix in matrices.values()):
            raise ValueError(f"Matrix size mismatch for {band}. Expected {expected_size}x{expected_size}")

        # Extract top edges, keeping conditions in CONDITIONS order.
        all_edges = []
        for condition, matrix in matrices.items():
            all_edges.extend(
                extract_top_edges_vectorized(matrix, condition, band, roi_names, N_TOP))

        if not all_edges:
            logger.warning(f"No valid connections found for {band}")
            return False

        df_edges = pd.DataFrame(all_edges)

        # Save CSV
        csv_path = OUTPUT_DIR / f"top_{N_TOP}_per_condition_{band}.csv"
        df_edges.to_csv(csv_path, index=False)
        logger.info(f"Saved CSV: {csv_path}")

        # Generate 3D plot; a plotting failure is reported separately so the
        # log shows that the CSV step had already succeeded.
        try:
            fsaverage = get_fsaverage_mesh()
            fig_3d = create_optimized_connectome(
                df_edges=df_edges,
                roi_coords=roi_coords,
                roi_names=roi_names,
                fsaverage_mesh=fsaverage,
                title=f"Top {N_TOP} Connections: {band} Band"
            )

            html_path = HTML_OUTPUT_DIR / f"top_{N_TOP}_per_condition_{band}.html"
            fig_3d.write_html(html_path, include_plotlyjs='cdn')
            logger.info(f"Saved HTML: {html_path}")

            return True

        except Exception as e:
            # logger.exception keeps the traceback that logger.error dropped.
            logger.exception(f"Failed to generate plot for {band}: {e}")
            return False

    except Exception as e:
        logger.exception(f"Failed to process {band}: {e}")
        return False

def main():
    """Run the whole pipeline: load atlas data, then process each band in turn.

    Bands in ``BANDS`` are handled sequentially through ``process_band``.

    Returns:
        ``True`` if at least one band was processed successfully, ``False``
        if none was or if an unexpected error occurred (which is logged).
    """
    logger.info("🚀 Starting optimized brain connectivity visualization...")

    try:
        # Both calls are memoised, so repeated runs in one session are cheap.
        roi_names = get_roi_names()
        roi_coords = get_roi_coords()

        logger.info(f"✅ Loaded {len(roi_names)} ROIs with coordinates")
        logger.info(f"🎯 Processing top {N_TOP} connections per condition:")
        for legend_line in ("   🔵 BLUE = InPhase", "   🔴 RED = OutofPhase"):
            logger.info(legend_line)

        # One band at a time; each call reports its own failures.
        outcomes = [process_band(band, roi_coords, roi_names) for band in BANDS]
        successful_bands = sum(1 for ok in outcomes if ok)

        # Summary
        logger.info("\n🎉 Processing complete!")
        logger.info(f"   • Successfully processed: {successful_bands}/{len(BANDS)} bands")
        logger.info(f"   • Data saved to: {OUTPUT_DIR}")
        logger.info(f"   • Interactive plots: {HTML_OUTPUT_DIR}")
        logger.info("   • Open HTML files in browser to explore")

        if successful_bands != 0:
            return True

        logger.error("No bands were processed successfully. Check input data and paths.")
        return False

    except Exception as e:
        logger.error(f"Fatal error in main execution: {e}")
        return False

if __name__ == "__main__":
    success = main()
    # Raise SystemExit directly instead of calling exit(): that helper is
    # injected by the `site` module for interactive use (and replaced by
    # IPython inside notebooks), so it is not reliable in program code.
    raise SystemExit(0 if success else 1)

2025-09-19 21:56:07,355 - INFO - 🚀 Starting optimized brain connectivity visualization...
2025-09-19 21:56:07,356 - INFO - Loading DiFuMo 512 atlas...


2025-09-19 21:56:07,461 - INFO - Loaded 512 ROI names from atlas
2025-09-19 21:56:07,462 - INFO - Computing ROI coordinates...
2025-09-19 21:56:18,750 - INFO - Progress: 100/512 ROI coordinates computed
2025-09-19 21:56:25,600 - INFO - Progress: 200/512 ROI coordinates computed
2025-09-19 21:56:32,942 - INFO - Progress: 300/512 ROI coordinates computed
2025-09-19 21:56:39,720 - INFO - Progress: 400/512 ROI coordinates computed
2025-09-19 21:56:46,607 - INFO - Progress: 500/512 ROI coordinates computed
2025-09-19 21:56:47,578 - INFO - ✅ Loaded 512 ROIs with coordinates
2025-09-19 21:56:47,579 - INFO - 🎯 Processing top 10 connections per condition:
2025-09-19 21:56:47,581 - INFO -    🔵 BLUE = InPhase
2025-09-19 21:56:47,581 - INFO -    🔴 RED = OutofPhase
2025-09-19 21:56:47,582 - INFO - Processing band: Theta
2025-09-19 21:56:47,608 - INFO - Loaded matrix InPhase_Theta: shape (512, 512)
2025-09-19 21:56:47,631 - INFO - Loaded matrix OutofPhase_Theta: shape (512, 512)
2025-09-19 21:56:47,

: 