In [None]:
import scipy.io
import numpy as np
import matplotlib.pyplot as plt
from scipy.signal import butter, filtfilt, hilbert
import os
from sklearn.cluster import KMeans
import mne

## Leading eigenvectors (Python port of `LEiDA_EEG_eigenvectors.m`)

#### Helper Functions - Filtering

In [None]:
def butter_bandpass(lowcut, highcut, fs, order=6):
    """
    Design a digital Butterworth band-pass filter in transfer-function form.

    The band edges are given in Hz and normalised by the Nyquist rate
    (fs / 2) before being handed to ``scipy.signal.butter``.

    Parameters
    ----------
    lowcut : float
        Lower band edge (Hz).
    highcut : float
        Upper band edge (Hz).
    fs : float
        Sampling frequency in Hz.
    order : int, optional
        Butterworth order (default 6). A band-pass of order N has 2N poles.

    Returns
    -------
    b, a : ndarray
        Numerator (b) and denominator (a) polynomial coefficients.
    """
    half_rate = fs / 2.0
    normalised_edges = [lowcut / half_rate, highcut / half_rate]
    return butter(order, normalised_edges, btype='band')


def bandpass_filter(data, fs, lowcut, highcut, order=6):
    """
    Apply a zero-phase Butterworth band-pass filter along the last axis of `data`.

    Parameters
    ----------
    data : ndarray
        Time series to filter. A 1-D array (e.g. one ROI) or an N-D array
        whose last axis is time.
    fs : float
        Sampling frequency in Hz.
    lowcut : float
        Low cutoff frequency (Hz).
    highcut : float
        High cutoff frequency (Hz).
    order : int, optional
        The order of the Butterworth filter. Default is 6.

    Returns
    -------
    filtered_data : ndarray
        Filtered time series, same shape as input.

    Notes
    -----
    The filter is designed and applied as cascaded second-order sections
    (SOS) with forward-backward filtering (zero phase). A 6th-order band-pass
    is a 12th-order transfer function. Evaluating it as a single (b, a)
    polynomial pair is numerically fragile for narrow bands at low normalised
    frequencies (e.g. 8-12 Hz), which is why SciPy recommends the SOS form.
    The nominal filter is the same.
    """
    from scipy.signal import sosfiltfilt

    nyquist = 0.5 * fs
    sos = butter(order, [lowcut / nyquist, highcut / nyquist], btype='band', output='sos')
    return sosfiltfilt(sos, data, axis=-1)


#### Eigenvectors

In [None]:
# Band edges (Hz) accepted by compute_leading_eigenvectors.
_LEIDA_FREQ_BANDS = {
    'alpha': (8, 12),
    'beta': (15, 25),   # same as in other projects of mine
    'gamma': (30, 80),
}


def _phase_locking_matrix(phase_window):
    """Return mean(cos(phase_n - phase_p)) over a window for every ROI pair.

    `phase_window` is [n_areas, n_samples]. Uses the identity
    cos(a - b) = cos(a)cos(b) + sin(a)sin(b), so the whole matrix is two matrix
    products instead of an n_areas**2 Python loop.
    """
    cos_ph = np.cos(phase_window)
    sin_ph = np.sin(phase_window)
    return (cos_ph @ cos_ph.T + sin_ph @ sin_ph.T) / phase_window.shape[1]


def _leading_eigenvector(ifc):
    """Return the eigenvector of the largest eigenvalue of the symmetric matrix `ifc`.

    The LEiDA sign convention is applied: the vector is negated when more than
    half of its entries are positive, or on an exact 50/50 split when the
    positive entries have the larger total magnitude.
    """
    vals, vecs = np.linalg.eigh(ifc)
    v1 = vecs[:, np.argmax(vals)]
    frac_positive = np.mean(v1 > 0)
    if frac_positive > 0.5:
        v1 = -v1
    elif frac_positive == 0.5 and np.sum(v1[v1 > 0]) > -np.sum(v1[v1 < 0]):
        v1 = -v1
    return v1


def compute_leading_eigenvectors(data, fs, window_size, freq_band='alpha', verbose=True, do_plots=False):
    """
    Replicates the MATLAB pipeline for:
      1) Bandpass filtering a multi-channel EEG time series (per ROI).
      2) Computing the Hilbert transform to extract instantaneous phases.
      3) Computing dynamic phase-locking (dPL) matrices in non-overlapping windows.
      4) Extracting the leading eigenvector from each dPL.

    Parameters
    ----------
    data : ndarray
        Shape [n_areas, n_timepoints]. Each row is the time series of one brain region.
    fs : float
        Sampling frequency in Hz.
    window_size : int
        Number of samples in each non-overlapping window (e.g., 250).
    freq_band : str, optional
        Which frequency band to use: 'alpha' (8-12 Hz), 'beta' (15-25 Hz) or
        'gamma' (30-80 Hz). Default is 'alpha'.
    verbose : bool, optional
        If True, prints progress messages. Default is True.
    do_plots : bool, optional
        If True, shows intermediate plots (raw vs filtered signal, phase, an
        example dPL matrix and its leading eigenvector). Default is False.

    Returns
    -------
    lead_eigs : ndarray
        Array of leading eigenvectors, shape [n_windows - 2, n_areas], where
        n_windows = ceil(n_timepoints / window_size). The first and the last
        window are skipped because they are closest to the Hilbert-transform
        edge effects. The result has shape (0, n_areas) when there are fewer
        than three windows.

    Raises
    ------
    ValueError
        If `data` is not 2-D, `window_size` < 1, or `freq_band` is unknown.

    Notes
    -----
    NOTE(review): when `data` is several epochs concatenated end-to-end, the
    filter and the Hilbert transform run across the epoch boundaries.
    Windows spanning a boundary may carry artefactual phase. Consider
    filtering per epoch.
    """
    data = np.asarray(data)
    if data.ndim != 2:
        raise ValueError(f"data must be 2-D [n_areas, n_timepoints], got shape {data.shape}.")
    if window_size < 1:
        raise ValueError(f"window_size must be a positive number of samples, got {window_size}.")

    # --------------------------
    # 1) Determine filter band
    # --------------------------
    band = _LEIDA_FREQ_BANDS.get(freq_band) if isinstance(freq_band, str) else None
    if band is None:
        raise ValueError("freq_band must be 'alpha', 'beta', or 'gamma'.")
    lowcut, highcut = band

    n_areas, T = data.shape
    if verbose:
        print(f"Filtering data ({n_areas} areas, {T} timepoints) "
              f"from {lowcut} to {highcut} Hz, order=6.")
        print(f"N areas: {n_areas}, T: {T}")

    # -------------------------------------------
    # 2) De-mean and filter each ROI separately
    # -------------------------------------------
    # Subtract the mean per channel (as in the MATLAB code). The float64 output
    # buffer avoids truncating the filter output for lower-precision input.
    data_demean = data - np.mean(data, axis=1, keepdims=True)
    filtered_data = np.empty(data_demean.shape, dtype=float)
    for i in range(n_areas):
        filtered_data[i, :] = bandpass_filter(data_demean[i, :], fs,
                                              lowcut, highcut, order=6)

    t_sec = np.arange(T) / fs if do_plots else None
    if do_plots:
        # Plot an example channel (ROI 0) before and after filtering
        plt.figure(figsize=(10, 4))
        plt.plot(t_sec, data_demean[0, :], label='Raw (demeaned)', alpha=0.7)
        plt.plot(t_sec, filtered_data[0, :], label='Filtered', alpha=0.7)
        plt.xlim([0, min(10.0, t_sec[-1])])  # zoom in on the first 10 seconds
        plt.legend()
        plt.title("ROI 0: Before and After Filtering")
        plt.xlabel("Time (s)")
        plt.show()

    # --------------------------------------------------------
    # 3) Compute Hilbert transform to get instantaneous phase
    # --------------------------------------------------------
    phases = np.angle(hilbert(filtered_data, axis=1))

    if do_plots:
        # Plot example channel's phase
        plt.figure(figsize=(10, 4))
        plt.plot(t_sec, phases[0, :], label='Phase (ROI 0)')
        plt.title("Instantaneous Phase of Filtered Signal (ROI 0)")
        plt.xlabel("Time (s)")
        plt.ylabel("Phase (radians)")
        plt.xlim([0, min(10.0, t_sec[-1])])  # zoom in on the first 10 seconds
        plt.legend()
        plt.show()

    # -----------------------------------------------------------
    # 4) Windowing & dynamic phase-locking + 5) leading eigenvector
    # -----------------------------------------------------------
    # Windows start at 0, window_size, 2*window_size, ... Windows 1 .. n_windows-2
    # are used, i.e. the first and the last window are skipped. Every window
    # kept is complete: window n_windows-2 ends at (n_windows-1)*window_size <= T.
    window_starts = np.arange(0, T, window_size)
    n_windows = len(window_starts)
    if T % window_size != 0 and verbose:
        print("Warning: discarding last incomplete window since T not multiple of window_size.")

    example_window = min(10, n_windows - 2)  # window shown in the example plots
    lead_eig_list = []
    example_iFC = None
    example_V1 = None

    for w in range(1, n_windows - 1):
        start = window_starts[w]
        iFC = _phase_locking_matrix(phases[:, start:start + window_size])
        V1 = _leading_eigenvector(iFC)
        lead_eig_list.append(V1)
        if w == example_window:
            example_iFC = iFC.copy()
            example_V1 = V1.copy()

    # Keep a 2-D result even when no window qualified.
    lead_eigs = np.array(lead_eig_list) if lead_eig_list else np.empty((0, n_areas))

    if verbose:
        print(f"Computed {lead_eigs.shape[0]} leading eigenvectors "
              f"for {n_areas} areas with window_size={window_size}.")

    if do_plots and example_iFC is not None:
        # Plot the phase-locking matrix (dPL) as an image
        plt.figure(figsize=(6, 5))
        plt.imshow(example_iFC, cmap='bwr', aspect='auto', vmin=-1, vmax=1)
        plt.colorbar(label='Phase Coherence (mean cos(diff))')
        plt.title("Example Dynamic Phase-Locking Matrix (dPL)")
        plt.xlabel("Brain Region (ROI index)")
        plt.ylabel("Brain Region (ROI index)")
        plt.show()

        # Plot the corresponding leading eigenvector with sign preserved.
        plt.figure(figsize=(6, 4))
        markerline, stemlines, baseline = plt.stem(np.arange(n_areas), example_V1)
        plt.setp(markerline, marker='o', markersize=6, color='b')
        plt.setp(stemlines, color='b')
        plt.title("Leading Eigenvector (with sign) from Example dPL")
        plt.xlabel("Brain Region (ROI index)")
        plt.ylabel("Eigenvector Component")
        plt.show()

    return lead_eigs

## Main pipeline

In [None]:
# Load one subject's source-reconstructed epochs.
EPOCHS_FNAME = "../data/source_reconstruction/s_101_Coordination-source-epo.fif"
eeg_data = mne.read_epochs(EPOCHS_FNAME)
epoch_data = eeg_data.get_data()  # (n_epochs, n_channels, n_samples)
print(f"Data shape: {epoch_data.shape}")
fs = eeg_data.info['sfreq']
print(f"Sampling frequency: {fs} Hz")

# Stitch the epochs end-to-end into one [n_channels, n_epochs * n_samples]
# series: column e * n_samples + s holds sample s of epoch e.
# NOTE(review): epoch boundaries become discontinuities that the later
# band-pass filter and Hilbert transform will run across.
n_epochs, n_channels, n_samples = epoch_data.shape
data = epoch_data.transpose(1, 0, 2).reshape(n_channels, n_epochs * n_samples)
print(f"New data shape: {data.shape}")

In [None]:
# Non-overlapping windows of a quarter second (fs / 4 samples, truncated).
window_size = int(fs / 4)  # 250 ms window
print(f"Window size: {window_size} samples")
freq_band = 'alpha'  # Frequency band: 'alpha', 'beta', 'gamma'

print("\n🔍 Processing data...")
lead_vecs = compute_leading_eigenvectors(
    data,
    fs,
    window_size,
    freq_band=freq_band,
    verbose=True,
    do_plots=True,
)
print("\n✅ Processing complete!")

In [None]:
# Expected shape: [n_windows - 2, n_areas]
print("Data shape after processing:", f"Lead vectors shape: {lead_vecs.shape}", sep="\n")

## 2. K-means clustering of the leading eigenvectors

In [None]:
def run_leida_kmeans(coll_eigenvectors, K, n_init=50, max_iter=200, random_state=None):
    """
    Python port of the essential steps of LEiDA_EEG_kmeans.m.

    Runs k-means (scikit-learn's squared-Euclidean objective, i.e. MATLAB's
    'sqeuclidean'). Clusters are then relabelled by size, so the largest
    cluster becomes 0, the second largest 1, and so on. Ties keep the original
    cluster order, like MATLAB's stable ``sort(..., 'descend')``.

    Parameters
    ----------
    coll_eigenvectors : ndarray, shape (n_samples, n_features)
        Rows to cluster (e.g. one leading eigenvector per window).
    K : int
        Number of clusters.
    n_init, max_iter : int, optional
        Passed to ``sklearn.cluster.KMeans``.
    random_state : int, RandomState or None, optional
        Passed to ``KMeans``. Use a fixed value for reproducible labels.

    Returns
    -------
    dict
        'IDX'       : (n_samples,) labels after size-based relabelling.
        'C'         : (K, n_features) centroids, row c belongs to label c.
        'counts'    : (K,) cluster sizes in descending order.
        'distances' : (n_samples, K) Euclidean distance of each point to each
                      centroid. Column c is the distance to relabelled cluster c.
        'SUMD'      : (K,) within-cluster sum of squared distances per
                      relabelled cluster (MATLAB's ``sumd``).
    """
    # 1) Run k-means with the given K.
    kmeans = KMeans(n_clusters=K,
                    n_init=n_init,
                    max_iter=max_iter,
                    random_state=random_state)
    kmeans.fit(coll_eigenvectors)
    old_labels = kmeans.labels_            # sklearn's cluster ids 0..K-1
    centers = kmeans.cluster_centers_

    # 2) Order clusters by descending size. A stable sort on the negated
    #    counts breaks ties by original id, matching MATLAB's stable sort.
    counts = np.bincount(old_labels, minlength=K)
    ind_sort = np.argsort(-counts, kind='stable')  # ind_sort[new] = old id
    cluster_map = np.empty(K, dtype=int)
    cluster_map[ind_sort] = np.arange(K)           # cluster_map[old] = new id

    # 3) Relabel assignments and reorder centroids consistently.
    new_labels = cluster_map[old_labels]
    new_centers = centers[ind_sort, :]

    # 4) transform() returns columns in sklearn's ORIGINAL cluster order.
    #    Reorder them so column c is the distance to relabelled cluster c and
    #    agrees with 'IDX' and 'C'. Distances stay Euclidean; square them for
    #    MATLAB's 'sqeuclidean' D.
    distances = kmeans.transform(coll_eigenvectors)[:, ind_sort]

    # 5) Within-cluster sum of squared distances (MATLAB's sumd).
    own_dist = distances[np.arange(len(new_labels)), new_labels]
    sumd = np.bincount(new_labels, weights=own_dist ** 2, minlength=K)

    # Return results in a dictionary (mimicking MATLAB struct)
    return {
        'IDX': new_labels,           # cluster labels (time course) after re-labeling
        'C': new_centers,            # cluster centroids
        'counts': counts[ind_sort],  # cluster sizes in descending order
        'distances': distances,      # distance of each point to each centroid
        'SUMD': sumd,                # within-cluster sum of squared distances
    }


In [None]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA

def plot_kmeans_3d_pca(collEigenvectors, labels, centers, title="K-means Clusters in PCA Space"):
    """
    Scatter-plot clustered eigenvectors in the space of their first 3 principal components.

    PCA is fitted on `collEigenvectors`. Both the data points (colour-coded by
    cluster label) and the cluster centres (large black-edged triangles) are
    projected into that space. Points and centres share one colour scale over
    the label range 0..K-1, so a centre always has the same colour as its
    cluster.

    Parameters
    ----------
    collEigenvectors : ndarray, shape (n_samples, n_features)
        All collated eigenvectors that were clustered.
    labels : ndarray, shape (n_samples,)
        Integer cluster labels (0..K-1) for each row in collEigenvectors.
    centers : ndarray, shape (K, n_features)
        The cluster centroids (in original feature space).
    title : str
        Title for the plot (optional).
    """
    K = centers.shape[0]

    # 1) Run PCA to reduce dimensionality to 3
    pca = PCA(n_components=3)
    X_pca = pca.fit_transform(collEigenvectors)   # shape: [n_samples, 3]
    centers_pca = pca.transform(centers)          # shape: [K, 3]

    # Pin the colour limits for BOTH scatters. Without them each call
    # normalises to its own data range, so points and centres stop sharing
    # colours whenever `labels` lacks some cluster id.
    color_kw = dict(cmap='rainbow', vmin=0, vmax=max(K - 1, 1))

    # 2) Plot the data, color-coded by cluster label
    fig = plt.figure(figsize=(8, 6))
    ax = fig.add_subplot(111, projection='3d')
    scatter = ax.scatter(
        X_pca[:, 0],
        X_pca[:, 1],
        X_pca[:, 2],
        c=labels,
        alpha=0.6,
        **color_kw
    )

    # 3) Plot the cluster centers as bigger, distinct markers
    ax.scatter(
        centers_pca[:, 0],
        centers_pca[:, 1],
        centers_pca[:, 2],
        c=np.arange(K),          # centre c gets the colour of label c
        marker='^',
        s=200,
        edgecolors='k',
        label='Cluster centers',
        **color_kw
    )

    # 4) Tidy up the figure
    ax.set_xlabel('PC1')
    ax.set_ylabel('PC2')
    ax.set_zlabel('PC3')
    ax.set_title(title)
    ax.legend(loc='upper left')   # shows the 'Cluster centers' entry
    # Colorbar with cluster label indices
    cbar = fig.colorbar(scatter, ax=ax, fraction=0.03, pad=0.07)
    cbar.set_label("Cluster Label")
    plt.tight_layout()
    plt.show()

In [None]:
# 1) Collate all subject data into one big array.
#    Only one subject is processed here, so the collated matrix is simply that
#    subject's leading eigenvectors (one row per window).
collEigenvectors = lead_vecs
print("Shape of collEigenvectors:", np.shape(collEigenvectors))

In [None]:
%matplotlib inline
# 2) Try k-means with k from e.g. 4..10
rangeK = range(4, 11)
kmeans_solutions = {}

for k in rangeK:
    print(f"Running k-means for K={k} ...")
    res = run_leida_kmeans(collEigenvectors, K=k, n_init=50, max_iter=200)
    kmeans_solutions[k] = res
    plot_kmeans_3d_pca(collEigenvectors, res['IDX'], res['C'], title=f"K-means Clusters for K={k}")
    print(f"Finished K={k}.")
    print("Cluster sizes:", res['counts'])
    print("Cluster centers shape:", res['C'].shape)
    print("Cluster labels shape:", res['IDX'].shape)
    print("Distances shape:", res['distances'].shape)
    print("------\n")


In [None]:
# Select the largest K explicitly instead of relying on the loop variable `k`
# that leaked out of the sweep above. The next cells inspect this `k`.
k = max(kmeans_solutions)
print("\n✅ K-means clustering complete!")
print(f"Results for K={k}:\n", kmeans_solutions[k].keys())

In [None]:
state_sequence = kmeans_solutions[k]['IDX']  # one state label per window
print("Cluster assignments (IDX):", state_sequence)


In [None]:
centroid_matrix = kmeans_solutions[k]['C']  # one row per cluster centroid
print("Cluster centers (C):", centroid_matrix.shape)


In [None]:
# Plot every cluster centroid of the selected K as a stem plot of ROI loadings.
# The grid is sized from K (two columns), so no centroid is silently dropped
# when K exceeds 10.
centers_k = kmeans_solutions[k]['C']
n_centers, n_rois = centers_k.shape
n_cols = 2
n_rows = max(1, -(-n_centers // n_cols))  # ceiling division
fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 2 * n_rows), squeeze=False)
for i, ax in enumerate(axes.flat):
    if i < n_centers:
        markerline, stemlines, baseline = ax.stem(np.arange(n_rois), centers_k[i, :])
        ax.set_title(f"Cluster center {i}")
        ax.set_xlabel("Brain Region (ROI index)")
        ax.set_ylabel("Eigenvector Component")
        plt.setp(markerline, marker='o', markersize=6, color='b')
        plt.setp(stemlines, color='b')
    else:
        ax.axis('off')  # unused slot in the last row
plt.tight_layout()
plt.show()


In [None]:
import mne
import numpy as np

##############################################################################
# 1) Each epoch channel corresponds to one source-space ROI
##############################################################################
roi_names = eeg_data.ch_names  # Each name is the label of a parcellated ROI
print("ROI names:", roi_names)
# Example: pick one cluster of one k-means solution
k = 5
cluster_idx = 2
center_vec = kmeans_solutions[k]['C'][cluster_idx, :]  # shape: (n_rois,)

##############################################################################
# 2) Split ROIs by the sign of their centroid loading (zeros go to neither)
##############################################################################
pos_idx = np.where(center_vec > 0)[0]
neg_idx = np.where(center_vec < 0)[0]

##############################################################################
# 3) Read Desikan–Killiany annotation from fsaverage
##############################################################################
# Expand "~" explicitly so the path does not depend on MNE doing it.
SUBJECTS_DIR = os.path.expanduser("~/mne_data/MNE-fsaverage-data/")
labels = mne.read_labels_from_annot(
    subject="fsaverage",
    parc="aparc",
    subjects_dir=SUBJECTS_DIR
)
# Drop every "unknown" parcel. MNE label names carry a hemisphere suffix
# (e.g. "unknown-lh"), so comparing the last name to "unknown" never matched.
labels = [lab for lab in labels if "unknown" not in lab.name]

##############################################################################
# 4) Create a semi-transparent Brain (alpha=0.3) on the 'white' surface
##############################################################################
Brain = mne.viz.get_brain_class()
brain = Brain(
    subject="fsaverage",
    hemi="both",
    surf="white",
    subjects_dir=SUBJECTS_DIR,
    background="white",
    size=(800, 600),
    alpha=0.3,
)
# No annotation is added, so only the labels highlighted below are coloured.
# For faint parcel outlines: brain.add_annotation("aparc", borders=True, alpha=0.02)

##############################################################################
# 5) Helper to match your channel names to MNE label names
##############################################################################
def match_roi_label(roi, all_labels):
    """
    Find the MNE Label that corresponds to the ROI name `roi`.

    A label whose name equals `roi` exactly is preferred. Otherwise the first
    label whose name contains `roi` as a substring is returned, which covers
    ROI names without a hemisphere suffix. Returns None if nothing matches.
    Adjust this logic if your ROI names use another format (e.g. "lh_bankssts").

    Parameters
    ----------
    roi : str
        ROI (channel) name to look up.
    all_labels : iterable
        Label objects exposing a ``.name`` string.

    Returns
    -------
    Label or None
    """
    all_labels = list(all_labels)  # iterated twice below
    # Exact match first. A substring test alone can pick the wrong parcel when
    # one name is contained in another ('cuneus-lh' is inside 'precuneus-lh').
    for lab in all_labels:
        if lab.name == roi:
            return lab
    # Fallback: first label whose name contains the ROI name.
    for lab in all_labels:
        if roi in lab.name:
            return lab
    return None

##############################################################################
# 6) Highlight only the minority ROIs (the opposite-phase "state")
##############################################################################
# The smaller sign group is the minority. On a tie the positive group is used.
if len(pos_idx) > len(neg_idx):
    minority_idx = neg_idx
else:
    minority_idx = pos_idx

# Collect ROIs without a matching label instead of dropping them silently,
# so a naming mismatch between channels and parcels is visible.
unmatched_rois = []
for idx in minority_idx:
    roi_name = roi_names[idx]
    found_label = match_roi_label(roi_name, labels)
    if found_label is None:
        unmatched_rois.append(roi_name)
        continue
    brain.add_label(found_label, color="red", alpha=0.7, borders=False)
if unmatched_rois:
    print(f"Warning: {len(unmatched_rois)} minority ROI(s) had no matching label "
          f"and were not drawn: {unmatched_rois}")

##############################################################################
# 7) Top-down (axial) view, captured as a screenshot
##############################################################################
import matplotlib.pyplot as plt
brain.show_view("axial")
s_axial = brain.screenshot()
plt.imshow(s_axial)
plt.axis('off')
plt.show()

##############################################################################
# 8) Lateral view, captured as a screenshot
##############################################################################
brain.show_view("lateral")
s_lateral = brain.screenshot()
plt.imshow(s_lateral)
plt.axis('off')
plt.show()


In [None]:
import matplotlib.pyplot as plt
from matplotlib.cm import get_cmap
import numpy as np
import mne
from matplotlib.colors import LinearSegmentedColormap

def visualize_brain_clusters(k, kmeans_solutions, roi_names, labels,
                             view_angles=('lateral-left', 'lateral-right', 'dorsal', 'ventral'),
                             subjects_dir="~/mne_data/MNE-fsaverage-data/"):
    """
    Visualize brain clusters from k-means solutions by highlighting opposite-phase ROIs.

    For each cluster centroid, the ROIs on the minority side of the sign split
    are painted on a semi-transparent fsaverage pial surface, in the cluster's
    colour. The minority side is the smaller of the positive and negative
    groups, with the positive group used on a tie. One screenshot per view is
    placed in an (n_clusters x n_views) grid.

    Parameters
    ----------
    k : int
        Key of the k-means solution in `kmeans_solutions` to visualize.
    kmeans_solutions : dict
        Maps K to a result dict whose 'C' entry holds the cluster centers,
        one row per cluster and one column per ROI.
    roi_names : list of str
        ROI names corresponding to the columns of each cluster center.
    labels : list
        MNE Label objects; each ROI name is resolved with `match_roi_label`.
    view_angles : sequence of str, optional
        Views to render. 'lateral-left' / 'lateral-right' show the lateral
        view of one hemisphere. Any other value is passed to `Brain.show_view`.
    subjects_dir : str or path-like, optional
        FreeSurfer subjects directory. A leading "~" is expanded.

    Returns
    -------
    fig : matplotlib.figure.Figure
        The figure containing all brain visualizations.
    """
    cluster_centers = kmeans_solutions[k]['C']
    n_clusters = cluster_centers.shape[0]
    view_angles = list(view_angles)
    n_views = len(view_angles)
    subjects_dir = os.path.expanduser(os.fspath(subjects_dir))

    # Distinct vibrant colours (no greys or whites), cycled when K > 12.
    colors = [
        '#e41a1c',  # red
        '#377eb8',  # blue
        '#4daf4a',  # green
        '#984ea3',  # purple
        '#ff7f00',  # orange
        '#ffff33',  # yellow
        '#a65628',  # brown
        '#f781bf',  # pink
        '#1b9e77',  # teal
        '#d95f02',  # vermillion
        '#7570b3',  # slate blue
        '#e6ab02',  # mustard
    ]

    # squeeze=False keeps `axes` 2-D, so axes[c, j] is valid for a single
    # cluster and/or a single view.
    fig, axes = plt.subplots(n_clusters, n_views,
                             figsize=(2.5 * n_views, 2 * n_clusters),
                             squeeze=False)
    Brain = mne.viz.get_brain_class()

    for c in range(n_clusters):
        center_vec = cluster_centers[c, :]
        pos_idx = np.where(center_vec > 0)[0]
        neg_idx = np.where(center_vec < 0)[0]
        minority_idx = neg_idx if len(pos_idx) > len(neg_idx) else pos_idx
        print(f"Cluster {c}: {len(minority_idx)} minority ROIs")
        color = colors[c % len(colors)]

        # Create a fresh Brain for this cluster
        brain = Brain(
            subject="fsaverage",
            hemi="both",
            surf="pial",
            subjects_dir=subjects_dir,
            background="white",
            size=(400, 400),
            alpha=0.3
        )
        try:
            # Highlight only the minority ROIs in the cluster's colour
            for idx in minority_idx:
                lab = match_roi_label(roi_names[idx], labels)
                if lab is not None:
                    brain.add_label(lab, color=color, alpha=1.0, borders=False)

            # Loop over each view, capture a screenshot, place it in the grid
            for j, view in enumerate(view_angles):
                if view == 'lateral-left':
                    brain.show_view('lateral', hemi='lh')
                elif view == 'lateral-right':
                    brain.show_view('lateral', hemi='rh')
                else:
                    brain.show_view(view)

                ax = axes[c, j]
                ax.imshow(brain.screenshot())
                ax.axis('off')

                # Column headers (view names) on the top row
                if c == 0:
                    clean_name = view.replace('-', ' ').capitalize()
                    ax.set_title(clean_name, fontsize=12)

                # Row labels (cluster numbers) on the leftmost column
                if j == 0:
                    ax.text(-0.15, 0.5, f"Cluster {c}", fontsize=12,
                            ha='right', va='center', transform=ax.transAxes,
                            bbox=dict(facecolor=color, alpha=0.5, boxstyle='round,pad=0.5'))
        finally:
            # Release the 3-D renderer even if adding labels or a screenshot fails.
            brain.close()

    # Add overall title
    fig.suptitle(f"LEiDA States (K={k})", fontsize=16, y=1.02)

    # Adjust layout to make room for the row labels
    plt.tight_layout()
    plt.subplots_adjust(left=0.1, wspace=0.05)

    # Add description at the bottom
    plt.figtext(0.5, -0.05,
                f"Brain regions showing opposite-phase activity within each state.\n"
                f"Each row represents one of the {n_clusters} states identified by k-means clustering.",
                ha='center', fontsize=10)

    return fig



k = 10
roi_names = eeg_data.ch_names  # Each name is the label of a parcellated ROI
print("ROI names:", roi_names)

# Read Desikan–Killiany annotation from fsaverage. "~" is expanded explicitly
# so the path does not depend on MNE doing it.
SUBJECTS_DIR = os.path.expanduser("~/mne_data/MNE-fsaverage-data/")
labels = mne.read_labels_from_annot(
    subject="fsaverage",
    parc="aparc",
    subjects_dir=SUBJECTS_DIR
)
# Drop every "unknown" parcel. MNE label names carry a hemisphere suffix
# (e.g. "unknown-lh"), so comparing the last name to "unknown" never matched.
labels = [lab for lab in labels if "unknown" not in lab.name]

view_angles = ['lateral-left', 'lateral-right', 'dorsal', 'ventral']
fig = visualize_brain_clusters(k, kmeans_solutions, roi_names, labels,
                               view_angles=view_angles, subjects_dir=SUBJECTS_DIR)
plt.show()



In [None]:

state_counts = kmeans_solutions[k]['counts']  # windows per state, largest first
print("Cluster counts:", state_counts)

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

# State time course for one K: the cluster label assigned to every window.
chosen_k = 5
solution = kmeans_solutions[chosen_k]
all_labels = solution['IDX']  # shape: [n_time_windows]
time_points = len(all_labels)

# One distinct colour per state, looked up for each window.
palette = sns.color_palette("husl", chosen_k)
window_colors = [palette[state] for state in all_labels]

fig, ax = plt.subplots(figsize=(10, 4))
ax.scatter(np.arange(time_points), all_labels, c=window_colors, s=10, alpha=0.8)
ax.set_title(f"State Transitions for Single Subject (K={chosen_k})", fontsize=14)
ax.set_xlabel("Time Window")
ax.set_ylabel("Cluster Label")
ax.set_yticks(range(chosen_k))
fig.tight_layout()
plt.show()


In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

chosen_k = 5
solution = kmeans_solutions[chosen_k]
all_labels = solution['IDX']
time_points = len(all_labels)

# Occurrence probability: fraction of windows assigned to each state.
prob_vec = list(np.bincount(all_labels, minlength=chosen_k) / time_points)

palette = sns.color_palette("husl", chosen_k)

fig, ax = plt.subplots(figsize=(8, 4))
ax.bar(range(chosen_k), prob_vec, color=palette)
ax.set_xlabel("Cluster (State)")
ax.set_ylabel("Occurrence Probability")
ax.set_title(f"State Occurrence Probability (Single Subject, K={chosen_k})")
plt.show()


In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

chosen_k = 10
solution = kmeans_solutions[chosen_k]
all_labels = solution['IDX']

# Fraction of windows spent in each state.
prob_vec = [(all_labels == c).mean() for c in range(chosen_k)]
palette = sns.color_palette("husl", chosen_k)

# Each state's segment starts where the previous one ends.
segment_bottoms = np.concatenate(([0.0], np.cumsum(prob_vec)[:-1]))

fig, ax = plt.subplots(figsize=(4, 5))
for state, (height, base) in enumerate(zip(prob_vec, segment_bottoms)):
    ax.bar([0], height, bottom=base, color=palette[state], label=f"State {state}")

ax.set_xticks([0])
ax.set_xticklabels(["Subject 1"])
ax.set_ylabel("Proportion of Time in State")
ax.set_title(f"State Distribution (Single Subject, K={chosen_k})")
ax.legend(bbox_to_anchor=(1, 1))
fig.tight_layout()
plt.show()


In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

chosen_k = 10
solution = kmeans_solutions[chosen_k]
all_labels = solution['IDX']

# Fraction of windows spent in each state, as a one-row table.
prob_vec = [(all_labels == c).mean() for c in range(chosen_k)]
state_names = [f"State {i}" for i in range(chosen_k)]
df = pd.DataFrame([prob_vec], columns=state_names, index=["Subject 1"])

fig, ax = plt.subplots(figsize=(8, 3))
sns.heatmap(df, annot=True, cmap="Blues", linewidths=0.5, ax=ax)
ax.set_title(f"Occurrence Probability Heatmap (Single Subject, K={chosen_k})")
ax.set_xlabel("Brain State")
ax.set_ylabel("Subject")
plt.show()


In [None]:
import numpy as np

def compute_dwell_time_single_subject(labels, chosen_k):
    """
    Mean dwell time of each state for one subject's state sequence, in windows.

    A state's dwell time is the total number of windows spent in it divided by
    the number of separate visits (runs of consecutive windows). States that
    never occur get 0.

    Parameters
    ----------
    labels : array-like of int
        State label per window (values 0..chosen_k-1). Lists are accepted.
    chosen_k : int
        Number of states.

    Returns
    -------
    ndarray, shape (chosen_k,)
        Mean run length per state.
    """
    # asarray: with a plain list, `labels == state` would be a single bool.
    labels = np.asarray(labels)
    dwell_times = np.zeros(chosen_k)
    for state in range(chosen_k):
        positions = np.flatnonzero(labels == state)
        if positions.size == 0:
            continue  # never visited: dwell time stays 0
        # A gap > 1 between successive positions starts a new run.
        n_runs = 1 + np.count_nonzero(np.diff(positions) > 1)
        dwell_times[state] = positions.size / n_runs
    return dwell_times

# Example usage:
import matplotlib.pyplot as plt
import seaborn as sns

chosen_k = 10
solution = kmeans_solutions[chosen_k]
all_labels = solution['IDX']

# Mean run length (in windows) of each state.
dwell_times = compute_dwell_time_single_subject(all_labels, chosen_k)

fig, ax = plt.subplots(figsize=(8, 4))
sns.barplot(x=list(range(chosen_k)), y=dwell_times, ax=ax)
ax.set_xlabel("State")
ax.set_ylabel("Dwell Time (avg. consecutive windows in state)")
ax.set_title(f"Dwell Time per State (Single Subject, K={chosen_k})")
plt.show()


In [None]:
import numpy as np

def compute_transition_matrix_single_subject(labels, chosen_k):
    """
    Row-normalised state transition matrix for one subject's state sequence.

    Entry (i, j) is the fraction of steps leaving state i that go to state j.
    Self-transitions (i -> i) are included. Rows of states that are never
    left stay all-zero.

    Parameters
    ----------
    labels : sequence of int
        State label per window (values 0..chosen_k-1).
    chosen_k : int
        Number of states.

    Returns
    -------
    ndarray, shape (chosen_k, chosen_k)
    """
    sequence = np.asarray(labels)
    counts = np.zeros((chosen_k, chosen_k))
    if len(sequence) > 1:
        # Accumulate one count per consecutive (from, to) pair; np.add.at
        # adds repeated index pairs correctly.
        np.add.at(counts, (sequence[:-1], sequence[1:]), 1)

    totals = counts.sum(axis=1, keepdims=True)
    totals[totals == 0] = 1  # avoid 0/0 for states never left
    return counts / totals

# Example usage + plotting:
import matplotlib.pyplot as plt
import seaborn as sns

chosen_k = 10
solution = kmeans_solutions[chosen_k]
all_labels = solution['IDX']

transition_matrix = compute_transition_matrix_single_subject(all_labels, chosen_k)

# Heatmap cells are centred at i + 0.5, so ticks go there.
tick_positions = np.arange(chosen_k) + 0.5
tick_labels = [str(i) for i in range(chosen_k)]

fig, ax = plt.subplots(figsize=(6, 5))
sns.heatmap(transition_matrix, annot=True, cmap="Blues", linewidths=0.5, fmt=".2f", ax=ax)
ax.set_xlabel("To State")
ax.set_ylabel("From State")
ax.set_title(f"State Transition Matrix (Single Subject, K={chosen_k})")
ax.set_xticks(tick_positions)
ax.set_xticklabels(tick_labels)
ax.set_yticks(tick_positions)
ax.set_yticklabels(tick_labels)
plt.show()
