In [None]:
!pip install tslearn nilearn

import gc # For garbage collection and memory management
import os
import shutil # For removing partially extracted data directories
import tarfile

import matplotlib.cm as cm # For colormap handling
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd # For saving to CSV
import requests

# Nilearn for brain image plotting and datasets
from nilearn import plotting, datasets

# tslearn for K-Means clustering (DTW K-Means)
from tslearn.clustering import TimeSeriesKMeans
from tslearn.barycenters import dtw_barycenter_averaging

# --- Basic parameters ---
HCP_DIR = "./DATA" # Directory to store downloaded and extracted HCP data
plots_output_dir = "./brain_plots" # Directory to store generated plots
csv_output_dir = "./csv_data" # Directory to store generated CSV files

# Ensure directories exist. os.makedirs (unlike os.mkdir) also creates missing
# parent directories, so the paths above may be changed to nested locations;
# exist_ok=True tolerates the directory appearing between the check and the call.
for _dir_path, _dir_label in (
    (HCP_DIR, "data directory"),
    (plots_output_dir, "plots output directory"),
    (csv_output_dir, "CSV output directory"),
):
    if not os.path.isdir(_dir_path):
        os.makedirs(_dir_path, exist_ok=True)
        print(f"Created {_dir_label}: {_dir_path}")

print(f"Brain plots will be saved to: {plots_output_dir}")
print(f"CSV data will be saved to: {csv_output_dir}")

N_PARCELS = 360  # Number of parcels in the Glasser Atlas
TR = 0.72  # Repetition Time (TR) for HCP fMRI data
N_RUNS_REST = 4  # Number of resting-state runs per subject
N_SUBJECTS_FOR_GROUP_ANALYSIS = 100  # Subjects included in the group-level K-means on brain states

# Resting-state BOLD run names, listed in run-ID order:
# rfMRI_REST1_LR, rfMRI_REST1_RL, rfMRI_REST2_LR, rfMRI_REST2_RL.
# (Position in this list, counted from 1, is the run ID used by get_image_ids.)
BOLD_NAMES = [
    f"rfMRI_REST{rest_no}_{suffix}"
    for rest_no in (1, 2)
    for suffix in ("LR", "RL")
]

# --- Download and Extract Data ---
print("\n--- Checking and downloading necessary data files ---")
# Only need hcp_rest.tgz and atlas.npz for this specific request
fnames_to_check = ["hcp_rest.tgz", "atlas.npz"]
urls_to_check = {
    "hcp_rest.tgz": "https://osf.io/bqp7m/download",
    "atlas.npz": "https://osf.io/j5kuc/download"
}
# (connect, read) timeouts in seconds. Without a timeout requests.get can block
# indefinitely on an unresponsive server.
DOWNLOAD_TIMEOUT = (10, 60)

for fname in fnames_to_check:
    url = urls_to_check.get(fname)
    if not url:
        print(f"Error: URL for {fname} not found in urls_to_check map. Skipping download.")
        continue

    if os.path.isfile(fname):
        print(f"'{fname}' already exists, skipping download.")
        continue

    # Stream into a temporary ".part" file and rename it only once the whole body
    # has arrived. Writing straight to `fname` would leave a truncated file behind
    # if the transfer breaks, and the isfile() check above would then skip the
    # download on every later run.
    partial_fname = fname + ".part"
    try:
        print(f"Downloading {fname} from {url}...")
        with requests.get(url, stream=True, timeout=DOWNLOAD_TIMEOUT) as r:
            r.raise_for_status() # Raise an exception for HTTP errors (4xx or 5xx)
            with open(partial_fname, "wb") as fid:
                for chunk in r.iter_content(chunk_size=8192):
                    fid.write(chunk)
        os.replace(partial_fname, fname)
    except (requests.exceptions.RequestException, OSError) as e:
        # Covers network/HTTP errors, mid-stream disconnects, and local write errors.
        print(f"!!! Failed to download {fname}: {e} !!! Please check your internet connection or URL.")
        if os.path.exists(partial_fname):
            os.remove(partial_fname)
    else:
        print(f"Download of {fname} completed!")

print("\n--- Extracting data ---")
extraction_targets = {
    "hcp_rest.tgz": os.path.join(HCP_DIR, "hcp_rest"),
}
# The archive is downloaded from the network. Where the running Python supports
# tarfile extraction filters (3.12+ and security backports), the "data" filter
# rejects members that would escape HCP_DIR (absolute paths, "..", unsafe links).
tar_extract_kwargs = {"filter": "data"} if hasattr(tarfile, "data_filter") else {}

for tgz_file, extract_dir in extraction_targets.items():
    if not os.path.isfile(tgz_file):
        print(f"'{tgz_file}' not found, skipping extraction.")
        continue
    if os.path.exists(extract_dir):
        print(f"{extract_dir} already exists, skipping extraction for {tgz_file}.")
        continue

    print(f"Extracting {tgz_file} to {extract_dir}...")
    try:
        with tarfile.open(tgz_file, "r:gz") as tar:
            # Extracted into HCP_DIR; extract_dir is assumed to be the archive's
            # top-level directory (the existence check above relies on this).
            tar.extractall(path=HCP_DIR, **tar_extract_kwargs)
    except (tarfile.TarError, EOFError, OSError) as e:
        # TarError covers ReadError and filter rejections; a truncated gzip stream
        # surfaces as EOFError or OSError during extraction.
        print(f"!!! Failed to extract {tgz_file}: {e} !!! Ensure download was complete/not corrupted.")
        # extract_dir did not exist before this attempt, so anything there now is a
        # partial extraction; remove it so the next run retries instead of skipping.
        shutil.rmtree(extract_dir, ignore_errors=True)
    else:
        print(f"Extraction of {tgz_file} completed!")

# --- Load Atlas Information and fsaverage surface ---
print("\n--- Loading atlas and fsaverage surface information ---")

try:
    with np.load("atlas.npz") as dobj:
        # Copy every array out of the archive before the `with` closes it.
        atlas = dict(**dobj)
except FileNotFoundError:
    print("Error: atlas.npz not found. Please ensure it is in the current directory or downloaded.")
    # The builtin `exit()` is a `site`-module convenience: in a script it exits with
    # status 0 (success) and in a Jupyter kernel it may shut the kernel down.
    # SystemExit(1) stops the script/cell with a failure status instead.
    raise SystemExit(1)
print("Atlas information (Glasser parcels and hemisphere labels) loaded.")

fsaverage = datasets.fetch_surf_fsaverage()
print("fsaverage surface data loaded.")


# --- Helper functions ---
def get_image_ids(name):
    """
    Return the 1-based positions in BOLD_NAMES of every run matching ``name``.

    ``name`` is upper-cased and then tested as a substring of each entry, so
    ``"rest"`` matches ``"rfMRI_REST1_LR"``. These 1-based IDs are the run
    numbers that end up in the ``bold<N>_...`` file names.

    Raises:
        ValueError: If no entry of BOLD_NAMES contains ``name.upper()``.
    """
    matching_ids = []
    for position, run_code in enumerate(BOLD_NAMES, start=1):
        if name.upper() in run_code:
            matching_ids.append(position)

    if not matching_ids:
        raise ValueError(f"Found no data for '{name}'")
    return matching_ids

def load_single_timeseries(subject, bold_run, dir, remove_mean=True):
    """
    Load one parcellated BOLD time series for a given subject and run.

    The file read is
    ``<dir>/subjects/<subject>/timeseries/bold<bold_run>_Atlas_MSMAll_Glasser360Cortical.npy``.

    Args:
        subject: Subject identifier; converted with ``str()`` for the directory name.
        bold_run: Run number embedded in the file name (callers in this file pass
            the 1-based IDs produced by ``get_image_ids``).
        dir: Root data directory that contains ``subjects/``.
        remove_mean: If True, subtract each row's mean taken along axis 1
            (the time axis, given the (parcels, time) layout used elsewhere here).

    Returns:
        The loaded array, row-centred when ``remove_mean`` is True.

    Raises:
        FileNotFoundError: If the ``.npy`` file does not exist.
    """
    bold_path = os.path.join(dir, "subjects", str(subject), "timeseries")
    bold_file = f"bold{bold_run}_Atlas_MSMAll_Glasser360Cortical.npy"
    ts = np.load(os.path.join(bold_path, bold_file))
    if remove_mean:
        # Out-of-place on purpose: an in-place `-=` raises UFuncTypeError for
        # integer-typed arrays, because the float mean cannot be cast back to int.
        # Float inputs keep their dtype (e.g. float32 stays float32).
        ts = ts - ts.mean(axis=1, keepdims=True)
    return ts

def load_timeseries(subject, name, dir, runs=None, concat=True, remove_mean=True):
    """
    Load one subject's BOLD runs for a condition, optionally concatenated.

    Args:
        subject: Subject identifier, forwarded to ``load_single_timeseries``.
        name: Condition name, matched against ``BOLD_NAMES`` by ``get_image_ids``
            (e.g. ``"rest"`` selects all resting-state runs, ``"rest1"`` only the
            REST1 runs).
        dir: Root data directory, forwarded to ``load_single_timeseries``.
        runs: 0-based indices into the runs that match ``name``. ``None`` (default)
            loads every matching run; a single int loads just that run.
        concat: If True, concatenate the runs along the last axis into one array;
            otherwise return a list with one array per run.
        remove_mean: Forwarded to ``load_single_timeseries``.

    Returns:
        One array if ``concat`` is True, else a list of per-run arrays.

    Raises:
        ValueError: If ``name`` matches no run, or a requested index is outside
            ``range(<number of matching runs>)``.
        FileNotFoundError: If a run's ``.npy`` file is missing.
    """
    # 1-based run IDs (the numbers used in the file names) of all runs matching `name`.
    run_ids = get_image_ids(name)

    if runs is None:
        # Default to exactly the runs of this condition. A fixed range(N_RUNS_REST)
        # would make e.g. name="rest1" also load the REST2 runs.
        runs = range(len(run_ids))
    elif isinstance(runs, int):
        runs = [runs]
    runs = list(runs)  # materialize so it can be validated and then iterated

    out_of_range = [run for run in runs if not 0 <= run < len(run_ids)]
    if out_of_range:
        raise ValueError(
            f"Run indices {out_of_range} are out of range for '{name}' "
            f"({len(run_ids)} matching runs)"
        )

    # Look up each relative index in run_ids rather than computing first_id + run,
    # which silently assumes the matching runs are contiguous (false for e.g. "LR").
    bold_data = [load_single_timeseries(subject, run_ids[run], dir, remove_mean) for run in runs]

    if concat:
        bold_data = np.concatenate(bold_data, axis=-1)
    return bold_data

# --- End Helper Functions ---


##################################################################################
#                                 MAIN ANALYSIS:                               #
#        K-Means Clustering on Subject-Level Resting-State Brain States        #
##################################################################################

print("\n" + "="*80)
print(f"Starting K-Means Clustering on Resting-State Brain States for {N_SUBJECTS_FOR_GROUP_ANALYSIS} Subjects")
print("="*80 + "\n")

# Parameters for K-Means Clustering on Resting States
chosen_k = 3 # Number of clusters (kernels)
# NOTE(review): DTW lets the clustering warp along the *parcel index* axis, i.e. it
# may align parcel i of one subject with parcel j of another. Parcel order carries
# no temporal meaning, so "euclidean" is likely the more appropriate metric. Kept
# as "dtw" to preserve the original analysis; confirm before interpreting results.
kmeans_metric = "dtw"

resting_state_data_dir = os.path.join(HCP_DIR, "hcp_rest")

print(f"Collecting resting-state brain states for the first {N_SUBJECTS_FOR_GROUP_ANALYSIS} subjects...")

# One vector per successfully loaded subject: the temporal mean of every parcel.
all_subjects_resting_brain_states = []

for subject_idx in range(N_SUBJECTS_FOR_GROUP_ANALYSIS):
    try:
        # remove_mean MUST be False here. With remove_mean=True each parcel's
        # temporal mean is subtracted while loading, so the time-average below
        # would be ~0 for every parcel of every subject and K-Means would only be
        # clustering floating-point noise.
        # NOTE(review): if the stored series are already mean-centred upstream,
        # the time-average will still be ~0; confirm against the raw data.
        ts_rest_subject = load_timeseries(subject=subject_idx,
                                          name="rest",
                                          dir=resting_state_data_dir,
                                          concat=True,
                                          remove_mean=False)

        # Average over time (axis 1) -> a single brain state (one value per parcel).
        all_subjects_resting_brain_states.append(np.mean(ts_rest_subject, axis=1))

        # Release the full time series before loading the next subject. Reference
        # counting frees it immediately, so a gc.collect() per subject is not needed.
        del ts_rest_subject

    except FileNotFoundError:
        print(f"    Warning: Resting-state data not found for Subject {subject_idx} at {os.path.join(resting_state_data_dir, 'subjects', str(subject_idx))}. Skipping this subject.")
    except Exception as e:
        print(f"    Error processing Subject {subject_idx}: {e}. Skipping this subject.")

n_collected_subjects = len(all_subjects_resting_brain_states)

if n_collected_subjects == 0:
    print("No subjects were successfully processed for K-Means analysis. Please check data paths and files.")
elif n_collected_subjects < chosen_k:
    # k-means needs at least as many samples as clusters.
    print(f"Only {n_collected_subjects} subject(s) were processed, fewer than k={chosen_k}. Skipping K-Means clustering.")
else:
    # Shape: (n_subjects, n_parcels)
    all_subjects_resting_brain_states_array = np.array(all_subjects_resting_brain_states)
    print(f"Shape of collected subject-level brain states: {all_subjects_resting_brain_states_array.shape}")

    # tslearn expects (n_samples, n_timestamps, n_features). Here each subject is a
    # sample, the parcels play the role of the "time" axis, and there is one feature.
    resting_states_for_kmeans = all_subjects_resting_brain_states_array[:, :, np.newaxis]
    print(f"Shape of resting brain states for tslearn K-Means: {resting_states_for_kmeans.shape}")

    print(f"\nPerforming K-Means Clustering (k={chosen_k}) on {n_collected_subjects} subject resting brain states...")

    try:
        km_rest = TimeSeriesKMeans(n_clusters=chosen_k, metric=kmeans_metric, max_iter=10, random_state=0, verbose=False)
        km_rest.fit(resting_states_for_kmeans)
        print("  Clustering of subject resting states complete.")

        # Cluster assignment of each subject
        subject_clusters = km_rest.labels_
        print(f"Subject cluster assignments (labels_) shape: {subject_clusters.shape}")
        print(f"Unique subject cluster IDs: {np.unique(subject_clusters)}")

        # Barycenters: the 'average' brain state of each cluster of subjects.
        barycenters_rest = km_rest.cluster_centers_
        # Drop only the trailing single-feature axis. `.squeeze()` would also drop
        # the cluster axis when chosen_k == 1, so barycenters_rest_reshaped[i] would
        # become a scalar and the plotting below would fail.
        barycenters_rest_reshaped = barycenters_rest[:, :, 0]
        print(f"Shape of barycenters (mean brain states for each cluster): {barycenters_rest_reshaped.shape}")

        # --- Plot Barycenter Brain States on Brain Surface ---
        print(f"\nPlotting barycenter brain states for k={chosen_k} clusters...")
        # (hemisphere, atlas label key, title suffix) for the two columns of each row.
        hemisphere_specs = (("left", "labels_L", "Left"), ("right", "labels_R", "Right"))
        fig_bary_brains = None
        try:
            fig_bary_brains = plt.figure(figsize=(16, 4 * chosen_k))
            fig_bary_brains.suptitle(f'Resting State Brain State Barycenters (k={chosen_k})', fontsize=16)

            for i, current_barycenter_state in enumerate(barycenters_rest_reshaped):
                # Symmetric colour range around 0 so the diverging colormap is centred.
                vabs_max_bary = np.max(np.abs(current_barycenter_state))
                if vabs_max_bary == 0:
                    vabs_max_bary = 1e-12 # keep vmin < vmax for an all-zero barycenter
                vmin_bary, vmax_bary = -vabs_max_bary, vabs_max_bary

                for col, (hemi, labels_key, hemi_title) in enumerate(hemisphere_specs, start=1):
                    ax_bary = fig_bary_brains.add_subplot(chosen_k, 2, i * 2 + col, projection='3d')
                    # atlas[labels_key] is used as an index array into the per-parcel
                    # vector, producing one value per vertex of the fsaverage surface.
                    plotting.plot_surf_stat_map(
                        fsaverage[f'infl_{hemi}'],
                        current_barycenter_state[atlas[labels_key]],
                        hemi=hemi,
                        view='lateral',
                        cmap='cold_hot',
                        colorbar=False,
                        bg_map=fsaverage[f'sulc_{hemi}'],
                        title=f'Barycenter {i} - {hemi_title}',
                        vmax=vmax_bary, vmin=vmin_bary,
                        figure=fig_bary_brains,
                        axes=ax_bary,
                    )

                # One horizontal colorbar per row, since each barycenter has its own
                # range. Position: [left, bottom, width, height] in figure fractions.
                cbar_ax_bary = fig_bary_brains.add_axes([0.3, (chosen_k - i - 1) * (1/(chosen_k+0.5)) + 0.05, 0.4, 0.02])
                norm_bary = plt.Normalize(vmin=vmin_bary, vmax=vmax_bary)
                sm_bary = cm.ScalarMappable(cmap='cold_hot', norm=norm_bary)
                sm_bary.set_array([])
                cbar_bary = fig_bary_brains.colorbar(sm_bary, cax=cbar_ax_bary, orientation='horizontal')
                cbar_bary.set_label('Mean BOLD (Barycenter)')
                cbar_bary.ax.tick_params(labelsize=8)

            fig_bary_brains.subplots_adjust(left=0.01, right=0.99, top=0.95, bottom=0.01, hspace=0.3, wspace=0.05)
            # Save this specific figure rather than relying on pyplot's "current" figure.
            fig_bary_brains.savefig(os.path.join(plots_output_dir, f'resting_state_barycenter_brain_states_k{chosen_k}.png'))
            print(f"Barycenter brain states plot saved to '{plots_output_dir}'.")

        except Exception as e:
            print(f"Error plotting barycenter brain states: {e}")
        finally:
            # Always release the figure, even if plotting failed part-way.
            if fig_bary_brains is not None:
                plt.close(fig_bary_brains)

        # --- Save Barycenter Values to CSV ---
        print("\nSaving barycenter values to CSV...")
        try:
            column_names = [f'Barycenter_{i}' for i in range(chosen_k)]
            # Rows are parcels (index 0..n_parcels-1), columns are barycenters.
            df_barycenters = pd.DataFrame(barycenters_rest_reshaped.T, columns=column_names)
            df_barycenters.index.name = 'Parcel_ID'

            csv_file_path = os.path.join(csv_output_dir, f'resting_state_barycenters_k{chosen_k}.csv')
            df_barycenters.to_csv(csv_file_path)
            print(f"Barycenter values saved to '{csv_file_path}'.")
        except Exception as e:
            print(f"Error saving barycenter values to CSV: {e}")

    except Exception as e:
        print(f"  An error occurred during K-Means clustering of resting states: {e}")

print("\n" + "="*80)
print("K-Means Clustering on Resting-State Brain States Completed.")
print("="*80 + "\n")

Created CSV output directory: ./csv_data
Brain plots will be saved to: ./brain_plots
CSV data will be saved to: ./csv_data

--- Checking and downloading necessary data files ---
'hcp_rest.tgz' already exists, skipping download.
'atlas.npz' already exists, skipping download.

--- Extracting data ---
./DATA/hcp_rest already exists, skipping extraction for hcp_rest.tgz.

--- Loading atlas and fsaverage surface information ---
Atlas information (Glasser parcels and hemisphere labels) loaded.
fsaverage surface data loaded.

Starting K-Means Clustering on Resting-State Brain States for 100 Subjects

Collecting resting-state brain states for the first 100 subjects...
Shape of collected subject-level brain states: (100, 360)
Shape of resting brain states for tslearn K-Means: (100, 360, 1)

Performing K-Means Clustering (k=3) on 100 subject resting brain states...
  Clustering of subject resting states complete.
Subject cluster assignments (labels_) shape: (100,)
Unique subject cluster IDs: [0 

  display_bary_left = plotting.plot_surf_stat_map(
  display_bary_right = plotting.plot_surf_stat_map(


Barycenter brain states plot saved to './brain_plots'.

Saving barycenter values to CSV...
Barycenter values saved to './csv_data/resting_state_barycenters_k3.csv'.

K-Means Clustering on Resting-State Brain States Completed.

