### Original code from Léa Schmidt ###
 

# In [None]:  (notebook cell prompt, kept as a comment so the file parses as Python)
# Complete pipeline: fMRIPrep preprocessed BOLD to Functional Connectivity Matrix
# This replicates much of what XCP-D does but with full manual control
#%%
import numpy as np
import pandas as pd
import nibabel as nib
import matplotlib.pyplot as plt
from nilearn import datasets, input_data, plotting, image, signal
from nilearn.connectome import ConnectivityMeasure
from nilearn.image import clean_img, smooth_img
from scipy import stats
import os
import warnings
warnings.filterwarnings('ignore')
import glob
import re

def _load_atlas(atlas_name, custom_atlas_file, custom_atlas_labels):
    """Load the requested atlas and return ``(atlas_img, labels, coords)``.

    ``atlas_img`` is None for the coordinate-based Power atlas; in that case
    ``coords`` is an (n_rois, 3) array of sphere centres, otherwise None.
    Labels are returned as a list of ``str``.
    """
    coords = None
    if atlas_name == 'custom':
        if custom_atlas_file is None:
            raise ValueError("custom_atlas_file must be provided when atlas_name='custom'")

        print(f"   Loading custom atlas: {custom_atlas_file}")
        atlas_img = nib.load(custom_atlas_file)

        # Label values present in the atlas, excluding background (0). Filtering
        # by value (instead of dropping the first unique value) stays correct
        # for atlases that contain no background voxels.
        unique_labels = np.unique(atlas_img.get_fdata())
        unique_labels = unique_labels[unique_labels != 0]
        n_regions = len(unique_labels)
        # get_fdata() yields floats, so cast before the integer format spec.
        default_labels = [f"Region_{int(value):03d}" for value in unique_labels]

        if custom_atlas_labels is None:
            labels = default_labels
        elif len(custom_atlas_labels) != n_regions:
            print(f"   Warning: Provided {len(custom_atlas_labels)} labels but atlas has {n_regions} regions")
            labels = default_labels
        else:
            labels = list(custom_atlas_labels)

        print(f"   Custom atlas loaded: {n_regions} regions")
        print(f"   Atlas shape: {atlas_img.shape}")
        print(f"   Unique labels: {unique_labels[:10]}{'...' if len(unique_labels) > 10 else ''}")

    elif atlas_name in ('schaefer_400', 'schaefer_200'):
        n_rois = 400 if atlas_name == 'schaefer_400' else 200
        atlas = datasets.fetch_atlas_schaefer_2018(n_rois=n_rois, yeo_networks=7, resolution_mm=2)
        atlas_img = atlas.maps
        labels = list(atlas.labels)
    elif atlas_name == 'aal':
        atlas = datasets.fetch_atlas_aal(version='SPM12')
        atlas_img = atlas.maps
        labels = list(atlas.labels)
    elif atlas_name == 'power_264':
        atlas = datasets.fetch_coords_power_2011()
        # Coordinate atlas: time series are extracted from spheres around each ROI centre.
        atlas_img = None
        coords = np.vstack((atlas.rois['x'], atlas.rois['y'], atlas.rois['z'])).T
        labels = [f"Power_{i:03d}" for i in range(len(coords))]
    else:
        raise ValueError(f"Unknown atlas: {atlas_name}")

    # Some nilearn versions return labels as bytes; normalise to str so they
    # work as DataFrame index/columns and are written to CSV cleanly.
    labels = [lab.decode() if isinstance(lab, bytes) else str(lab) for lab in labels]
    return atlas_img, labels, coords


def _select_confounds(confounds_df, denoising_strategy):
    """Return the ordered, de-duplicated confound column names for a strategy string.

    The strategy is parsed by substring:
    '6P', '12P', '24P', 'PhysioCor', 'GSR' / 'global_signal', and 'AROMA'.
    Columns missing from ``confounds_df`` are skipped.
    """
    motion_params = ['trans_x', 'trans_y', 'trans_z', 'rot_x', 'rot_y', 'rot_z']
    motion_derivs = [f"{param}_derivative1" for param in motion_params]
    motion_squared = [f"{param}_power2" for param in motion_params]
    motion_deriv_squared = [f"{param}_derivative1_power2" for param in motion_params]

    selected = []
    if '6P' in denoising_strategy:
        selected += motion_params
    if '12P' in denoising_strategy:
        # 6 motion parameters + their 6 temporal derivatives
        selected += motion_params + motion_derivs
    if '24P' in denoising_strategy:
        # 12P + the squares of all 12 terms
        selected += motion_params + motion_derivs + motion_squared + motion_deriv_squared
    if 'PhysioCor' in denoising_strategy:
        # First 8 aCompCor columns (in file order) plus mean CSF / WM signals
        selected += [col for col in confounds_df.columns if 'a_comp_cor_' in col][:8]
        selected += ['csf', 'white_matter']
    # The global signal is added whenever requested, independently of PhysioCor,
    # so that strategies such as 'AROMA+GSR' really include it.
    if 'GSR' in denoising_strategy or 'global_signal' in denoising_strategy.lower():
        selected.append('global_signal')
    if 'AROMA' in denoising_strategy:
        selected += [col for col in confounds_df.columns if 'aroma_motion_' in col]

    # dict.fromkeys removes duplicates (e.g. motion params requested by both
    # '6P' and '24P') while preserving first-seen order.
    available = set(confounds_df.columns)
    return [col for col in dict.fromkeys(selected) if col in available]


def _find_problematic_regions(time_series):
    """Return column indices that are all-zero, constant, or non-finite (printing why)."""
    problematic = []
    for idx, region_ts in enumerate(time_series.T):
        if np.all(region_ts == 0):
            print(f"   Warning: Region {idx} has all-zero time series")
        elif np.std(region_ts) < 1e-10:
            print(f"   Warning: Region {idx} has constant time series (std={np.std(region_ts):.2e})")
        elif not np.all(np.isfinite(region_ts)):
            print(f"   Warning: Region {idx} contains NaN/Inf values")
        else:
            continue
        problematic.append(idx)
    return problematic


def _interpolate_censored(time_series, censor_mask):
    """Linearly interpolate censored rows of a (time, regions) array from uncensored rows.

    Each censored frame is interpolated between its nearest uncensored
    neighbours. Frames before the first or after the last uncensored frame
    take that frame's value (``np.interp`` edge behaviour). If every frame is
    censored, the data are returned unchanged (as a copy).
    """
    result = time_series.copy()
    if not censor_mask.any() or censor_mask.all():
        return result
    frames = np.arange(time_series.shape[0])
    kept = ~censor_mask
    for col in range(time_series.shape[1]):
        result[censor_mask, col] = np.interp(frames[censor_mask], frames[kept], time_series[kept, col])
    return result


def _plot_summary(fc_matrix, time_series_final, confounds_df, outlier_mask, fd_threshold,
                  connectivity_kind, atlas_name, coords, output_dir):
    """Save (and show) the 2x2 QC summary figure, plus a connectome plot for coordinate atlases."""
    fig, axes = plt.subplots(2, 2, figsize=(15, 12))

    # Plot 1: FC Matrix
    im1 = axes[0, 0].imshow(fc_matrix, cmap='RdBu_r', vmin=-1, vmax=1)
    axes[0, 0].set_title(f'Functional Connectivity Matrix\n({connectivity_kind})')
    axes[0, 0].set_xlabel('Brain Regions')
    axes[0, 0].set_ylabel('Brain Regions')
    fig.colorbar(im1, ax=axes[0, 0])

    # Plot 2: Connectivity Distribution (finite upper-triangle values only)
    fc_values = fc_matrix[np.triu_indices_from(fc_matrix, k=1)]
    fc_values_clean = fc_values[np.isfinite(fc_values)]
    if len(fc_values_clean) > 0:
        axes[0, 1].hist(fc_values_clean, bins=50, alpha=0.7, color='skyblue', edgecolor='black')
        axes[0, 1].set_title(f'Connectivity Distribution\n({len(fc_values_clean)}/{len(fc_values)} finite values)')
        axes[0, 1].set_xlabel('Correlation Coefficient')
        axes[0, 1].set_ylabel('Frequency')
        axes[0, 1].axvline(np.mean(fc_values_clean), color='red', linestyle='--',
                           label=f'Mean: {np.mean(fc_values_clean):.3f}')
        axes[0, 1].legend()
    else:
        axes[0, 1].text(0.5, 0.5, 'No finite connectivity values\nto display',
                        ha='center', va='center', transform=axes[0, 1].transAxes)
        axes[0, 1].set_title('Connectivity Distribution - No Data')

    # Plot 3: Motion Parameters
    if 'framewise_displacement' in confounds_df.columns:
        time_points = np.arange(len(confounds_df))
        fd = confounds_df['framewise_displacement']
        axes[1, 0].plot(time_points, fd, color='blue', alpha=0.7)
        axes[1, 0].axhline(fd_threshold, color='red', linestyle='--',
                           label=f'Threshold: {fd_threshold}mm')
        axes[1, 0].fill_between(time_points, 0, fd, where=outlier_mask, color='red', alpha=0.3,
                                label=f'Outliers: {outlier_mask.sum()}')
        axes[1, 0].set_title('Framewise Displacement')
        axes[1, 0].set_xlabel('Time (TR)')
        axes[1, 0].set_ylabel('FD (mm)')
        axes[1, 0].legend()

    # Plot 4: Sample Time Series (vertically offset for readability)
    n_show = min(10, time_series_final.shape[1])
    for i in range(n_show):
        axes[1, 1].plot(time_series_final[:, i] + i * 3, alpha=0.7, linewidth=0.8)
    axes[1, 1].set_title(f'Sample Time Series (first {n_show} regions)')
    axes[1, 1].set_xlabel('Time (TR)')
    axes[1, 1].set_ylabel('Signal (standardized) + offset')

    fig.tight_layout()
    fig.savefig(os.path.join(output_dir, 'fc_analysis_summary.png'), dpi=300, bbox_inches='tight')
    plt.show()
    # Close explicitly: this runs once per subject in batch loops, and open
    # figures otherwise accumulate in memory.
    plt.close(fig)

    # Connectome visualisation, only possible when node coordinates are known
    if coords is not None:
        fig = plt.figure(figsize=(10, 8))
        # Draw into our figure so the savefig below saves the connectome itself.
        plotting.plot_connectome(fc_matrix, coords,
                                 edge_threshold='90%',
                                 title=f'Functional Connectome - {atlas_name}',
                                 figure=fig)
        fig.savefig(os.path.join(output_dir, 'connectome_plot.png'), dpi=300, bbox_inches='tight')
        plt.show()
        plt.close(fig)


def bold_to_fc_complete_pipeline(
    bold_file,
    confounds_file,
    atlas_name='schaefer_400',
    custom_atlas_file=None,
    custom_atlas_labels=None,
    smoothing_fwhm=6,
    low_pass=0.1,
    high_pass=0.01,
    t_r=2.0,
    fd_threshold=0.5,
    dvars_threshold=2.0,
    denoising_strategy='24P+8PhysioCor+SpikeReg',
    standardize=True,
    connectivity_kind='correlation',
    plot_results=True,
    output_dir='./fc_results',
    subject_id='XX',
    max_outlier_fraction=0.2,
):
    """
    Complete pipeline from fMRIPrep preprocessed BOLD to FC matrix
    Replicates XCP-D functionality with manual control

    Parameters:
    -----------
    bold_file : str
        Path to preprocessed BOLD file (*_desc-preproc_bold.nii.gz)
    confounds_file : str
        Path to confounds file (*_desc-confounds_timeseries.tsv); must have one
        row per BOLD volume
    atlas_name : str
        Atlas to use ('schaefer_400', 'schaefer_200', 'aal', 'power_264', 'custom')
    custom_atlas_file : str, optional
        Path to custom atlas NIfTI file (use when atlas_name='custom')
    custom_atlas_labels : list, optional
        List of region labels for custom atlas, ordered by ascending label value.
        If its length does not match the number of non-zero labels in the atlas,
        generic 'Region_XXX' names are used instead (with a warning).
    smoothing_fwhm : float
        FWHM for spatial smoothing in mm (0 disables smoothing)
    low_pass : float
        Low-pass filter frequency in Hz
    high_pass : float
        High-pass filter frequency in Hz
    t_r : float
        Repetition time in seconds
    fd_threshold : float
        Framewise displacement threshold for scrubbing (mm)
    dvars_threshold : float
        Standardized DVARS ('std_dvars') threshold for scrubbing
    denoising_strategy : str
        Substring-parsed strategy, e.g. '6P', '12P+8PhysioCor+SpikeReg',
        '24P+8PhysioCor+SpikeReg', '24P+8PhysioCor+GSR+SpikeReg', 'AROMA+GSR'.
        'GSR' adds the global signal regardless of the other components.
    standardize : bool
        Whether to standardize time series
    connectivity_kind : str
        Type of connectivity ('correlation', 'partial correlation', 'covariance')
    plot_results : bool
        Whether to generate plots
    output_dir : str
        Output directory for results
    subject_id : str
        Subject identifier; currently not used by the computation.
    max_outlier_fraction : float
        If the fraction of FD-flagged volumes exceeds this, the run is rejected
        and the function returns None without writing results (default 0.2).

    Returns:
    --------
    dict or None
        Keys 'fc_matrix', 'time_series', 'qc_metrics', 'outlier_mask',
        'confounds_used'; None when the run is rejected for excessive motion.

    Raises:
    -------
    ValueError
        Unknown atlas, missing custom atlas file, BOLD/confounds length
        mismatch, label/region mismatch, or too few valid regions/timepoints.
    """
    print("=== BOLD to FC Matrix Pipeline ===")
    print(f"Processing: {os.path.basename(bold_file)}")

    os.makedirs(output_dir, exist_ok=True)

    # ==========================================
    # 1. LOAD DATA
    # ==========================================
    print("\n1. Loading data...")
    bold_img = nib.load(bold_file)
    confounds_df = pd.read_csv(confounds_file, sep='\t')

    print(f"   BOLD shape: {bold_img.shape}")
    print(f"   Confounds shape: {confounds_df.shape}")

    # ==========================================
    # 2. LOAD ATLAS
    # ==========================================
    print(f"\n2. Loading atlas: {atlas_name}...")
    atlas_img, labels, coords = _load_atlas(atlas_name, custom_atlas_file, custom_atlas_labels)
    if atlas_img is not None:
        print(f"   Atlas loaded with {len(labels)} regions")
    else:
        print(f"   Coordinate-based atlas loaded with {len(coords)} regions")

    # ==========================================
    # 3. PREPARE CONFOUNDS
    # ==========================================
    print(f"\n3. Preparing confounds with strategy: {denoising_strategy}...")
    confound_vars = _select_confounds(confounds_df, denoising_strategy)
    # Zero-fill missing values (e.g. the undefined first row of derivative columns).
    confounds = confounds_df[confound_vars].fillna(0)

    print(f"   Using {len(confound_vars)} confound regressors:")
    print(f"   {confound_vars}")

    # ==========================================
    # 4. IDENTIFY OUTLIER TIMEPOINTS
    # ==========================================
    print("\n4. Identifying outlier timepoints...")

    n_volumes = len(confounds_df)
    outlier_mask = np.zeros(n_volumes, dtype=bool)

    # Framewise displacement (NaN entries compare False, i.e. are never flagged)
    if 'framewise_displacement' in confounds_df.columns:
        fd_outliers = (confounds_df['framewise_displacement'] > fd_threshold).to_numpy(dtype=bool)
        outlier_mask |= fd_outliers
        print(f"   FD outliers (>{fd_threshold}mm): {fd_outliers.sum()}")

    # Reject the whole run when too many volumes are FD outliers (checked
    # before DVARS is added, so the criterion is FD-only).
    fd_fraction = outlier_mask.mean() if n_volumes else 0.0
    if fd_fraction > max_outlier_fraction:
        print(f"Too many FD outliers: {fd_fraction:.1%} — stopping function.")
        return None

    # DVARS
    if 'std_dvars' in confounds_df.columns:
        dvars_outliers = (confounds_df['std_dvars'] > dvars_threshold).to_numpy(dtype=bool)
        outlier_mask |= dvars_outliers
        print(f"   DVARS outliers (>{dvars_threshold}): {dvars_outliers.sum()}")

    print(f"   Total outlier timepoints: {outlier_mask.sum()}/{n_volumes}")

    # ==========================================
    # 5. SPATIAL SMOOTHING
    # ==========================================
    if smoothing_fwhm > 0:
        print(f"\n5. Applying spatial smoothing (FWHM={smoothing_fwhm}mm)...")
        bold_img = smooth_img(bold_img, fwhm=smoothing_fwhm)
    else:
        print("\n5. Skipping spatial smoothing...")

    # ==========================================
    # 6. EXTRACT TIME SERIES
    # ==========================================
    print("\n6. Extracting time series...")

    # NOTE(review): nilearn.input_data is deprecated in newer nilearn releases
    # in favour of nilearn.maskers; confirm against the installed version.
    if atlas_img is not None:
        masker = input_data.NiftiLabelsMasker(
            labels_img=atlas_img,
            standardize=False,  # standardization happens in signal.clean below
            memory='nilearn_cache',
            verbose=0
        )
    else:
        masker = input_data.NiftiSpheresMasker(
            seeds=coords,
            radius=5,
            standardize=False,
            memory='nilearn_cache',
            verbose=0
        )

    time_series = masker.fit_transform(bold_img)
    print(f"   Extracted time series shape: {time_series.shape}")

    if time_series.shape[0] != n_volumes:
        raise ValueError(
            f"BOLD has {time_series.shape[0]} volumes but confounds file has {n_volumes} rows"
        )

    # Labels must map 1:1 onto masker columns before any regions are dropped;
    # otherwise names would silently be attached to the wrong regions.
    n_columns = time_series.shape[1]
    if len(labels) == n_columns + 1 and labels[0].lower() == 'background':
        # Some atlas fetchers include a leading background label.
        labels = labels[1:]
    if len(labels) != n_columns:
        raise ValueError(
            f"Mismatch: {len(labels)} atlas labels vs {n_columns} extracted regions"
        )

    # Remove regions with all-zero, constant, or non-finite signal
    problematic_regions = _find_problematic_regions(time_series)
    if problematic_regions:
        print(f"   Found {len(problematic_regions)} problematic regions out of {n_columns}")
        bad = set(problematic_regions)
        good_regions = [i for i in range(n_columns) if i not in bad]
        if len(good_regions) < 10:
            raise ValueError(f"Too few valid regions ({len(good_regions)}). Check your atlas alignment.")

        time_series = time_series[:, good_regions]
        labels = [labels[i] for i in good_regions]
        if coords is not None:
            # Keep node coordinates aligned with the surviving regions for plotting.
            coords = coords[good_regions]
        print(f"   Kept {len(good_regions)} valid regions for analysis")

    # ==========================================
    # 7. TEMPORAL CLEANING
    # ==========================================
    print("\n7. Temporal filtering and denoising...")

    regressors = [confounds.to_numpy(dtype=float)]
    n_outliers = int(outlier_mask.sum())
    if 'SpikeReg' in denoising_strategy and n_outliers > 0:
        # One indicator column per flagged volume
        spike_regressors = np.zeros((n_volumes, n_outliers))
        spike_regressors[outlier_mask, :] = np.eye(n_outliers)
        print(f"   Adding {spike_regressors.shape[1]} spike regressors")
        regressors.append(spike_regressors)
    all_confounds = np.hstack(regressors)

    # Pass None rather than a zero-column array when there is nothing to regress out.
    time_series_clean = signal.clean(
        time_series,
        confounds=all_confounds if all_confounds.shape[1] else None,
        t_r=t_r,
        low_pass=low_pass,
        high_pass=high_pass,
        detrend=True,
        standardize=standardize
    )

    print(f"   Cleaned time series shape: {time_series_clean.shape}")

    # ==========================================
    # 8. CENSORING (INTERPOLATION)
    # ==========================================
    print("\n8. Handling censored timepoints...")

    if n_outliers > 0:
        print(f"   Interpolating {n_outliers} outlier timepoints...")
        time_series_final = _interpolate_censored(time_series_clean, outlier_mask)
    else:
        time_series_final = time_series_clean

    # NOTE: a cleaned voxelwise NIfTI cannot be produced from these region time
    # series (shape: time x regions); that would require cleaning the voxel data
    # itself (e.g. with nilearn's clean_img).

    # ==========================================
    # 9. COMPUTE CONNECTIVITY MATRIX
    # ==========================================
    print(f"\n9. Computing {connectivity_kind} connectivity matrix...")

    # Drop any timepoint that still contains NaN/Inf
    finite_rows = np.all(np.isfinite(time_series_final), axis=1)
    if not finite_rows.all():
        print("   Warning: Found NaN/Inf values in final time series. Removing...")
        time_series_final = time_series_final[finite_rows, :]
        print(f"   Kept {time_series_final.shape[0]} valid timepoints")

    if time_series_final.shape[0] < 50:
        raise ValueError(f"Too few valid timepoints ({time_series_final.shape[0]}). Cannot compute reliable connectivity.")

    connectivity_measure = ConnectivityMeasure(kind=connectivity_kind)
    fc_matrix = connectivity_measure.fit_transform([time_series_final])[0]

    nan_count = np.sum(np.isnan(fc_matrix))
    if nan_count > 0:
        print(f"   Warning: FC matrix contains {nan_count} NaN values. Replacing with 0.")
        fc_matrix = np.nan_to_num(fc_matrix, nan=0.0, posinf=1.0, neginf=-1.0)

    print(f"   FC matrix shape: {fc_matrix.shape}")
    print(f"   FC matrix range: [{np.nanmin(fc_matrix):.3f}, {np.nanmax(fc_matrix):.3f}]")
    print(f"   FC matrix finite values: {np.sum(np.isfinite(fc_matrix))}/{fc_matrix.size}")

    # `labels` has been kept in lockstep with the time-series columns, so it
    # is the correct index for every atlas type (not just custom atlases).
    if len(labels) != fc_matrix.shape[0]:
        raise ValueError(f"Mismatch: atlas_labels ({len(labels)}) vs fc_matrix size ({fc_matrix.shape[0]})")

    fc_df = pd.DataFrame(fc_matrix, index=labels, columns=labels)

    # ==========================================
    # 10. QUALITY CONTROL METRICS
    # ==========================================
    print("\n10. Computing quality control metrics...")

    qc_metrics = {
        'mean_fd': confounds_df['framewise_displacement'].mean() if 'framewise_displacement' in confounds_df.columns else np.nan,
        'mean_dvars': confounds_df['dvars'].mean() if 'dvars' in confounds_df.columns else np.nan,
        'n_outliers': outlier_mask.sum(),
        'outlier_fraction': outlier_mask.sum() / len(outlier_mask),
        'n_timepoints': time_series_final.shape[0],
        'n_regions': time_series_final.shape[1],
        'mean_connectivity': np.nanmean(fc_matrix[np.triu_indices_from(fc_matrix, k=1)]),
        'connectivity_sparsity': np.sum(np.abs(fc_matrix) > 0.1) / (fc_matrix.shape[0] * fc_matrix.shape[1]),
        'n_finite_connections': np.sum(np.isfinite(fc_matrix)),
        'n_nan_connections': np.sum(np.isnan(fc_matrix))
    }

    print("   QC Metrics:")
    for key, value in qc_metrics.items():
        if isinstance(value, float):
            print(f"     {key}: {value:.3f}")
        else:
            print(f"     {key}: {value}")

    # ==========================================
    # 11. SAVE RESULTS
    # ==========================================
    print(f"\n11. Saving results to {output_dir}...")

    np.save(os.path.join(output_dir, 'fc_matrix.npy'), fc_matrix)
    fc_df.to_csv(os.path.join(output_dir, 'fc_matrix.csv'))
    np.save(os.path.join(output_dir, 'time_series_clean.npy'), time_series_final)
    pd.DataFrame([qc_metrics]).to_csv(os.path.join(output_dir, 'qc_metrics.csv'), index=False)

    labels_df = pd.DataFrame({'region_id': range(len(labels)), 'region_name': labels})
    labels_df.to_csv(os.path.join(output_dir, 'atlas_labels.csv'), index=False)

    params = {
        'atlas_name': atlas_name,
        'smoothing_fwhm': smoothing_fwhm,
        'low_pass': low_pass,
        'high_pass': high_pass,
        't_r': t_r,
        'fd_threshold': fd_threshold,
        'dvars_threshold': dvars_threshold,
        'denoising_strategy': denoising_strategy,
        'connectivity_kind': connectivity_kind,
        'n_confounds': len(confound_vars)
    }
    pd.DataFrame([params]).to_csv(os.path.join(output_dir, 'processing_params.csv'), index=False)

    # ==========================================
    # 12. VISUALIZATION
    # ==========================================
    if plot_results:
        print("\n12. Generating visualizations...")
        _plot_summary(fc_matrix, time_series_final, confounds_df, outlier_mask, fd_threshold,
                      connectivity_kind, atlas_name, coords, output_dir)

    print("\n=== Pipeline Complete ===")
    print(f"Results saved to: {output_dir}")

    return {
        'fc_matrix': fc_matrix,
        'time_series': time_series_final,
        'qc_metrics': qc_metrics,
        'outlier_mask': outlier_mask,
        'confounds_used': confound_vars
    }

#%%

# ==========================================
# USAGE EXAMPLE
# ==========================================
# Batch-run the pipeline over every subject's resting-state run found under
# INPUT_DIR, skipping subjects whose fc_matrix.npy already exists.

base_path = '/home/leas/Documents/PhD/'

labels_file = os.path.join(base_path, '03_Data', 'atlas-hMRF', 'atlas-hMRF_dseg.tsv')
if not os.path.exists(labels_file):
    raise FileNotFoundError(f"Labels file not found: {labels_file}")

labels_df = pd.read_csv(labels_file, sep="\t")  # Ensure tab separation

# Region names come from the 'label' column of the atlas TSV
if "label" not in labels_df.columns:
    raise ValueError(f"Column 'label' not found in {labels_file}. Available columns: {labels_df.columns}")

labels = labels_df["label"].tolist()

INPUT_DIR = '/media/leas/T7 Shield2/PNC_dataset/Bids_images/derivatives/'
BOLD_SUFFIX = '_echo-1_space-MNI152NLin2009cAsym_res-2_desc-preproc_bold.nii.gz'
CONFOUNDS_SUFFIX = '_desc-confounds_timeseries.tsv'
FILE_PATTERN = f'*/func/*_task-bbl1restbold1124_run-1{BOLD_SUFFIX}'

search_pattern = os.path.join(INPUT_DIR, FILE_PATTERN)
bold_file_list = sorted(glob.glob(search_pattern))

base_output_dir = os.path.join(base_path, 'Projects', 'pnc', 'data', 'custom_preproc')
custom_atlas_path = os.path.join(
    base_path, '03_Data', 'atlas-hMRF',
    'space-MNI152NLin2009cAsym_atlas-hMRF_res-2_dseg.nii.gz',
)

for bold_file in bold_file_list:
    filename = os.path.basename(bold_file)
    match = re.search(r'(sub-[^_]+)', filename)
    if match is None:
        print(f"Could not find a subject ID in {filename}; skipping")
        continue
    subject_id = match.group(1)

    # The confounds TSV sits next to the BOLD file, with the echo/space/res/desc
    # entities of the BOLD name replaced by the confounds suffix.
    confounds_file = bold_file.replace(BOLD_SUFFIX, CONFOUNDS_SUFFIX)
    subject_output_dir = os.path.join(base_output_dir, subject_id)

    if os.path.exists(os.path.join(subject_output_dir, 'fc_matrix.npy')):
        print(subject_id, " already exists")
        continue

    try:
        results = bold_to_fc_complete_pipeline(
            bold_file=bold_file,
            confounds_file=confounds_file,
            atlas_name='custom',
            custom_atlas_file=custom_atlas_path,
            custom_atlas_labels=labels,
            smoothing_fwhm=0,
            low_pass=0.1,
            high_pass=0.01,
            t_r=3.0,
            fd_threshold=0.6,
            dvars_threshold=3.0,
            # NO GSR. Options: '6P', '12P+8PhysioCor+SpikeReg', '24P+8PhysioCor+SpikeReg',
            # '24P+8PhysioCor+GSR+SpikeReg' (with GSR)
            denoising_strategy='12P+8PhysioCor+SpikeReg',
            standardize=True,
            connectivity_kind='correlation',  # Options: 'correlation', 'partial correlation', 'covariance'
            plot_results=True,
            output_dir=subject_output_dir,
            subject_id=subject_id,
        )
    except Exception as err:
        # Keep the batch going, but report why this subject failed instead of
        # silently swallowing the error (Ctrl-C still interrupts the loop).
        print(f"{subject_id}: pipeline failed ({type(err).__name__}: {err}); skipping")
#%%

# Access results of the most recent successful pipeline run in this session.
# `results` is undefined when no subject was processed, and the pipeline
# returns None when a run is rejected for excessive motion.
try:
    results
except NameError:
    results = None

if results is None:
    print("No pipeline results available (all subjects skipped, failed, or rejected).")
else:
    fc_matrix = results['fc_matrix']
    time_series = results['time_series']
    qc_metrics = results['qc_metrics']

    print(f"\nFinal FC matrix shape: {fc_matrix.shape}")
    print(f"Mean connectivity: {qc_metrics['mean_connectivity']:.3f}")

# %%
