In [None]:
import scipy.io
import numpy as np
import matplotlib.pyplot as plt
from scipy.signal import butter, filtfilt, hilbert
import os
from sklearn.cluster import KMeans
import mne

# Improved pipeline: load, collate, cluster, and map labels back per condition and subject

In [None]:
import os
import numpy as np
from sklearn.cluster import KMeans

##############################################################################
# 1) Load data into a nested dictionary
##############################################################################
def load_leading_eigenvectors(data_dir, possible_conditions=("Coordination", "Solo", "Spontaneous")):
    """
    Load leading-eigenvector .npy files from ``data_dir`` into a nested dict.

    Files are expected to be named like 's_{subject_number}_{Condition}-eigenvectors.npy'
    (the 's_' prefix is optional). Files that do not end in '.npy', do not contain
    'eigenvectors', cannot be split into subject/condition, or whose condition is not
    in ``possible_conditions`` are skipped with a printed warning.

    Parameters
    ----------
    data_dir : str
        Directory to scan (non-recursively).
    possible_conditions : sequence of str
        Condition names to recognise; a file matches the first one its condition
        part starts with (case-insensitive).

    Returns
    -------
    dict
        data_dict[condition][subject_number] = array as stored on disk,
        presumably of shape (n_epochs, n_windows, n_channels).
    """
    data_dict = {}

    # Sorted so loading (and later collation) order is reproducible.
    for fname in sorted(os.listdir(data_dir)):
        if not (fname.endswith(".npy") and "eigenvectors" in fname):
            continue

        # Strip only the trailing extension, then an optional "s_" prefix.
        base = fname[: -len(".npy")]
        if base.startswith("s_"):
            base = base[2:]

        # "101_Coordination-eigenvectors" -> ("101", "Coordination-eigenvectors")
        subject_str, sep, remainder = base.partition("_")
        if not sep or not subject_str:
            print(f"Warning: could not parse subject/condition from filename {fname}. Skipping.")
            continue

        cond_part = remainder.split("-")[0]
        condition = next(
            (c for c in possible_conditions if cond_part.lower().startswith(c.lower())),
            None,
        )
        if condition is None:
            print(f"Warning: condition not recognized in filename {fname}. Skipping.")
            continue

        data_array = np.load(os.path.join(data_dir, fname))

        subjects = data_dict.setdefault(condition, {})
        if subject_str in subjects:
            # Two files mapping to the same (condition, subject) would otherwise
            # replace each other silently.
            print(f"Warning: duplicate data for subject={subject_str}, condition={condition}; "
                  f"overwriting with {fname}.")
        subjects[subject_str] = data_array
        print(f"Loaded file {fname} for subject={subject_str}, condition={condition}, shape={data_array.shape}")

    return data_dict

##############################################################################
# 2) Collate data into a 2D array (and keep track of indices)
##############################################################################
def collate_eigenvectors(data_dict):
    """
    Stack every eigenvector in a nested dict into one 2D array, with row bookkeeping.

    Parameters
    ----------
    data_dict : dict
        data_dict[condition][subject] = array of shape (n_epochs, n_windows, n_channels).
        All arrays must share the same n_channels.

    Returns
    -------
    coll_eigs : ndarray, shape (sum of n_epochs * n_windows, n_channels)
        Rows in iteration order: condition, then subject, then epoch, then window.
    meta_list : list of tuple
        meta_list[i] = (condition, subject, epoch_idx, window_idx) for row i.

    Raises
    ------
    ValueError
        If data_dict holds no arrays, an array is not 3-D, or channel counts differ.
    """
    coll_list = []
    meta_list = []
    expected_channels = None

    for condition, subj_dict in data_dict.items():
        for subject, eigenvectors in subj_dict.items():
            if eigenvectors.ndim != 3:
                raise ValueError(
                    f"Expected a 3-D array (n_epochs, n_windows, n_channels) for "
                    f"condition={condition}, subject={subject}; got shape {eigenvectors.shape}."
                )
            n_epochs, n_windows, n_channels = eigenvectors.shape
            if expected_channels is None:
                expected_channels = n_channels
            elif n_channels != expected_channels:
                raise ValueError(
                    f"Channel count mismatch for condition={condition}, subject={subject}: "
                    f"{n_channels} != {expected_channels}."
                )

            # C-order reshape: row index = epoch * n_windows + window,
            # which is exactly the order meta_list is extended in below.
            coll_list.append(eigenvectors.reshape(-1, n_channels))
            meta_list.extend(
                (condition, subject, e, w) for e in range(n_epochs) for w in range(n_windows)
            )

    if not coll_list:
        raise ValueError("data_dict contains no eigenvector arrays to collate.")

    coll_eigs = np.vstack(coll_list)
    coll_eigs_shape = coll_eigs.shape
    print(f"Collated shape = {coll_eigs_shape} (rows, n_channels).  #rows = sum of (n_epochs*n_windows) across all data.")

    return coll_eigs, meta_list


##############################################################################
# 3) K-means function (similar to run_leida_kmeans)
##############################################################################
def run_leida_kmeans(coll_eigenvectors, K, n_init=50, max_iter=200, random_state=None):
    """
    Replicates the essential steps of LEiDA_EEG_kmeans in Python:
      - K-means with the squared-Euclidean objective (scikit-learn's KMeans inertia).
      - Re-label clusters by descending frequency (largest cluster -> cluster 0, etc.).

    Parameters
    ----------
    coll_eigenvectors : ndarray, shape (n_samples, n_channels)
        Rows to cluster.
    K : int
        Number of clusters.
    n_init, max_iter, random_state
        Forwarded to ``sklearn.cluster.KMeans``.

    Returns
    -------
    dict with keys
        'IDX'       : (n_samples,) relabelled cluster index of every row.
        'C'         : (K, n_channels) centroids; row i is relabelled cluster i.
        'counts'    : (K,) cluster sizes in descending order.
        'distances' : (n_samples, K) Euclidean (not squared) distance of each row
                      to each centroid, with columns in the same order as 'C'.
    """
    print(f"Running k-means with K={K}, n_init={n_init}, max_iter={max_iter}")
    kmeans = KMeans(n_clusters=K,
                    n_init=n_init,
                    max_iter=max_iter,
                    random_state=random_state,
                    verbose=0)
    kmeans.fit(coll_eigenvectors)

    old_labels = kmeans.labels_        # shape [n_samples,]
    centers = kmeans.cluster_centers_

    # Sort by descending size
    counts = np.bincount(old_labels, minlength=K)
    ind_sort = np.argsort(counts)[::-1]  # largest cluster first

    # cluster_map[old_label] -> new_label (the inverse permutation of ind_sort)
    cluster_map = np.empty(K, dtype=int)
    cluster_map[ind_sort] = np.arange(K)

    new_labels = cluster_map[old_labels]
    new_centers = centers[ind_sort, :]
    # transform() returns columns in KMeans' original label order; reorder them
    # so distances[:, j] refers to the same cluster as new_centers[j].
    distances = kmeans.transform(coll_eigenvectors)[:, ind_sort]

    return {
        'IDX': new_labels,
        'C': new_centers,
        'counts': counts[ind_sort],
        'distances': distances,
    }

##############################################################################
# 4) Map cluster labels back into the same dictionary structure
##############################################################################
def map_labels_back(kmeans_labels, meta_list, data_dict):
    """
    Reorganize the 1D array of cluster assignments into a nested dict.

    data_dict[condition][subject] has shape (n_epochs, n_windows, n_channels),
    so labels_dict[condition][subject] gets shape (n_epochs, n_windows). Only
    (condition, subject) pairs that appear in meta_list are included.

    Parameters
    ----------
    kmeans_labels : sequence of int, length n_rows
        Cluster label of each collated row.
    meta_list : list of (condition, subject, epoch_idx, window_idx)
        Row bookkeeping as produced by collate_eigenvectors.
    data_dict : dict
        The nested eigenvector dict, used only to size the label grids.

    Returns
    -------
    dict
        labels_dict[condition][subject] = int array of shape (n_epochs, n_windows).

    Raises
    ------
    ValueError
        If kmeans_labels and meta_list have different lengths; zip() would
        otherwise silently drop rows and leave zeros (a valid label) behind.
    """
    if len(kmeans_labels) != len(meta_list):
        raise ValueError(
            f"kmeans_labels has {len(kmeans_labels)} entries but meta_list has "
            f"{len(meta_list)}; they must describe the same rows."
        )

    labels_dict = {}
    for label_value, (condition, subject, e, w) in zip(kmeans_labels, meta_list):
        subj_labels = labels_dict.setdefault(condition, {})
        if subject not in subj_labels:
            # Pre-allocate the (n_epochs, n_windows) grid from the source array's shape.
            n_epochs, n_windows, _ = data_dict[condition][subject].shape
            subj_labels[subject] = np.zeros((n_epochs, n_windows), dtype=int)
        subj_labels[subject][e, w] = label_value

    return labels_dict

In [None]:
data_dir = "../data/leading/MNE/alpha/"

# 1) Load data into dictionary: data_dict[condition][subject] -> eigenvector array
data_dict = load_leading_eigenvectors(data_dir)
# Each (condition, subject) entry is one recording; the same subject usually
# appears under several conditions, so count unique subjects separately.
n_recordings = sum(len(subj_dict) for subj_dict in data_dict.values())
n_unique_subjects = len({subj for subj_dict in data_dict.values() for subj in subj_dict})
print(f"Loaded data_dict with {len(data_dict)} conditions, {n_recordings} recordings "
      f"from {n_unique_subjects} unique subjects.")

In [None]:
# 2) Flatten every (condition, subject, epoch, window) eigenvector into one row;
#    meta_list records where each row came from.
coll_eigs, meta_list = collate_eigenvectors(data_dict)

# 3) Cluster all rows jointly so the states are shared across conditions/subjects.
k = 10
kmeans_results = run_leida_kmeans(
    coll_eigs,
    K=k,
    n_init=50,
    max_iter=100,
    random_state=42,
)

# 4) Fold the flat label vector back into labels_dict[condition][subject].
labels_dict = map_labels_back(
    kmeans_labels=kmeans_results['IDX'],
    meta_list=meta_list,
    data_dict=data_dict,
)

In [None]:
# Sanity check: each condition/subject should map to an (n_epochs, n_windows) label grid.
for cond, subj_labels in labels_dict.items():
    for subj, label_grid in subj_labels.items():
        print(f"{cond}, subject={subj}, labels shape = {label_grid.shape}")
        # label_grid[0, 0] would be the cluster of the first epoch's first window.

# The centroids themselves are in kmeans_results['C'], one row per state.
print(f"Cluster centers shape: {kmeans_results['C'].shape}")
print("Done.")


# Prepare the data (legacy pipeline: one flat list of recordings, without condition/subject tracking)

In [None]:
# Load the .npy files into one flat list (no condition/subject bookkeeping).
data_dir = '../data/leading/MNE/alpha/'
recordings = []
# os.listdir returns entries in arbitrary order; sort so the stacking order,
# and therefore the row order fed to k-means, is reproducible.
for filename in sorted(os.listdir(data_dir)):
    if not filename.endswith('.npy'):
        continue
    filepath = os.path.join(data_dir, filename)
    data = np.load(filepath)
    print(f'Loaded {filename} with shape {data.shape}')  # shape should be (n_epochs, n_windows, n_channels)
    recordings.append(data)


In [None]:
n_loaded_recordings = len(recordings)  # one entry per loaded .npy file
print(f'Loaded {n_loaded_recordings} recordings')

In [None]:
# Stack along the epoch axis -> (n_recordings * n_epochs, n_windows, n_channels);
# assumes every recording shares n_windows and n_channels.
coll_eigenvectors = np.vstack(recordings)
print(f'Collated eigenvectors shape: {coll_eigenvectors.shape}')

# Merge epochs and windows into rows -> (n_recordings * n_epochs * n_windows, n_channels).
coll_eigenvectors = coll_eigenvectors.reshape(-1, coll_eigenvectors.shape[-1])
print(f'Collated eigenvectors reshaped to: {coll_eigenvectors.shape}')

# Run k-means for a chosen k

In [None]:
def run_leida_kmeans(coll_eigenvectors, K, n_init=50, max_iter=200, random_state=None):
    """
    Replicates the essential steps of LEiDA_EEG_kmeans.m in Python:
      - K-means with 'sqeuclidean' (which is the usual sum-of-squares in scikit-learn).
      - Re-label clusters by size (largest cluster -> cluster #0, second largest -> #1, etc.).

    Returns a dict similar to the MATLAB struct:
        'IDX'       : (n_samples,) relabelled cluster index of every row.
        'C'         : (K, n_features) centroids; row i is relabelled cluster i.
        'counts'    : (K,) cluster sizes in descending order.
        'distances' : (n_samples, K) Euclidean (not squared) distances to each
                      centroid, columns in the same order as 'C'.
    """
    # 1) Run k-means with the given K (inertia_ is the sum of squared distances)
    kmeans = KMeans(n_clusters=K,
                    n_init=n_init,
                    max_iter=max_iter,
                    random_state=random_state,
                    verbose=1
                   )
    kmeans.fit(coll_eigenvectors)

    # labels_: shape [n_samples,], each sample assigned cluster 0..K-1
    old_labels = kmeans.labels_
    centers = kmeans.cluster_centers_

    # 2) Sort clusters by descending frequency (size)
    counts = np.bincount(old_labels, minlength=K)  # shape [K]
    ind_sort = np.argsort(counts)[::-1]  # largest cluster first
    # Mapping array: cluster 'ind_sort[0]' -> new label 0, etc. (inverse permutation)
    cluster_map = np.empty(K, dtype=int)
    cluster_map[ind_sort] = np.arange(K)

    # 3) Re-label the assignments according to cluster size
    new_labels = cluster_map[old_labels]

    # 4) Reorder the cluster centers the same way
    new_centers = centers[ind_sort, :]

    # 5) Reorder the distance columns to match the new centers; transform()
    #    returns them in KMeans' original label order.
    distances = kmeans.transform(coll_eigenvectors)[:, ind_sort]

    # Return results in a dictionary (mimicking MATLAB struct)
    return {
        'IDX': new_labels,           # cluster labels (time course) after re-labeling
        'C': new_centers,            # cluster centroids
        'counts': counts[ind_sort],  # cluster sizes in descending order
        'distances': distances,      # distance of each point to each (relabelled) centroid
    }


In [None]:
# Run k-means on the legacy flat stack. Store the result under its own name:
# coll_eigenvectors rows are ordered differently from coll_eigs/meta_list, so
# overwriting `kmeans_results` would misalign the labels used by the plotting
# and occurrence cells (which index coll_eigs and labels_dict).
k = 10
legacy_kmeans_results = run_leida_kmeans(coll_eigenvectors, k, n_init=50, max_iter=100, random_state=42)


# Plotting

### PCA

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA

def _stratified_subsample(labels, max_points, rng):
    """Pick up to ``max_points`` row indices while roughly keeping each cluster's share of ``labels``.

    Every cluster contributes at least one row, so rounding can overshoot; the
    result is shuffled and then truncated to ``max_points``. ``rng`` is either the
    ``np.random`` module or a ``np.random.Generator`` (both provide choice/shuffle).
    """
    n_samples = labels.shape[0]
    all_indices = np.arange(n_samples)
    picked = []
    for cluster in np.unique(labels):
        cluster_indices = all_indices[labels == cluster]
        cluster_size = cluster_indices.size
        sample_size = max(1, int(cluster_size * max_points / n_samples))
        if sample_size < cluster_size:
            cluster_indices = rng.choice(cluster_indices, size=sample_size, replace=False)
        picked.append(cluster_indices)
    picked = np.concatenate(picked)
    rng.shuffle(picked)
    return picked[:max_points]


def plot_kmeans_3d_pca(collEigenvectors, labels, centers, title="K-means Clusters in PCA Space",
                       max_points=1000, random_state=None):
    """
    Project eigenvector data onto 3 principal components and scatter-plot them in 3D,
    color-coded by cluster label. Cluster centers are projected with the same PCA
    and drawn as black-edged triangles.

    Parameters
    ----------
    collEigenvectors : ndarray, shape (n_samples, n_features)
        All collated eigenvectors that were clustered.
    labels : ndarray, shape (n_samples,)
        Integer cluster labels (0..K-1) for each row in collEigenvectors.
    centers : ndarray, shape (K, n_features)
        The cluster centroids (in original feature space).
    title : str
        Title for the plot (optional).
    max_points : int
        Maximum number of points to plot (stratified subsampling if exceeded).
    random_state : int | None
        Seed for the subsampling. None (default) keeps using NumPy's global
        random state, as before.

    Returns
    -------
    fig : matplotlib.figure.Figure

    Raises
    ------
    ValueError
        If labels and collEigenvectors have different numbers of rows.
    """
    collEigenvectors = np.asarray(collEigenvectors)
    labels = np.asarray(labels)
    centers = np.asarray(centers)
    n_samples = collEigenvectors.shape[0]
    if labels.shape[0] != n_samples:
        raise ValueError(
            f"labels has {labels.shape[0]} entries but collEigenvectors has {n_samples} rows."
        )
    K = centers.shape[0]

    subsampled = n_samples > max_points
    if subsampled:
        print(f"Subsampling {max_points} points from {n_samples} total points for visualization")
        rng = np.random if random_state is None else np.random.default_rng(random_state)
        sampled_indices = _stratified_subsample(labels, max_points, rng)
        plot_vectors = collEigenvectors[sampled_indices]
        plot_labels = labels[sampled_indices]
    else:
        plot_vectors = collEigenvectors
        plot_labels = labels

    # 1) Run PCA to reduce dimensionality to 3 (fitted on the plotted points only)
    pca = PCA(n_components=3)
    X_pca = pca.fit_transform(plot_vectors)     # shape: [n_plotted, 3]
    centers_pca = pca.transform(centers)        # shape: [K, 3]

    # Pin the color normalization to 0..K-1 for BOTH scatters. Without this, each
    # scatter autoscales to its own min/max, so if the plotted subsample lacks a
    # cluster the points and their centers get different colors.
    color_kwargs = dict(cmap='rainbow', vmin=0, vmax=K - 1)

    # 2) Plot the data, color-coded by cluster label
    fig = plt.figure(figsize=(10, 8))
    ax = fig.add_subplot(111, projection='3d')

    scatter = ax.scatter(
        X_pca[:, 0],
        X_pca[:, 1],
        X_pca[:, 2],
        c=plot_labels,
        alpha=0.6,
        s=30,
        **color_kwargs,
    )

    # 3) Plot the cluster centers as bigger, distinct markers
    ax.scatter(
        centers_pca[:, 0],
        centers_pca[:, 1],
        centers_pca[:, 2],
        c=np.arange(K),
        marker='^',
        s=200,
        edgecolors='k',
        linewidths=1.5,
        label='Cluster centers',
        **color_kwargs,
    )

    # 4) Tidy up the figure
    ax.set_xlabel('PC1')
    ax.set_ylabel('PC2')
    ax.set_zlabel('PC3')
    ax.set_title(title)

    explained_var = pca.explained_variance_ratio_ * 100
    ax.text2D(0.02, 0.95, f"Explained variance: PC1={explained_var[0]:.1f}%, PC2={explained_var[1]:.1f}%, PC3={explained_var[2]:.1f}%",
              transform=ax.transAxes)

    if subsampled:
        # Report the number actually drawn; per-cluster rounding can yield fewer than max_points.
        ax.text2D(0.02, 0.90, f"Showing {len(plot_labels)} of {n_samples} points", transform=ax.transAxes)

    cbar = fig.colorbar(scatter, ax=ax, fraction=0.03, pad=0.07)
    cbar.set_label("Cluster Label")

    ax.legend(loc='upper right')
    ax.grid(True)

    plt.tight_layout()
    return fig

In [None]:
# plot the k-means results
plot_kmeans_3d_pca(
    coll_eigs, 
    kmeans_results['IDX'], 
    kmeans_results['C'], 
    title=f"K-means Clusters in PCA Space (K={k})",
    max_points=1000
)

In [None]:
# Summary of the chosen solution. One concatenated string, because separate
# print() arguments insert a space at the start of every continuation line.
print(
    f"Results for K={k}:\n"
    f"Cluster sizes: {kmeans_results['counts']}\n"
    f"Cluster centers (shape {kmeans_results['C'].shape}):\n{kmeans_results['C']}\n"
    f"Distances to each center (first {k} samples):\n{kmeans_results['distances'][:k]}...\n"
)

### Cluster Centers

In [None]:
# The channel names of the source-space epochs are the parcellated ROI labels.
epochs_path = "../data/source/s_101_Coordination-source-epo.fif"
epochs = mne.read_epochs(epochs_path)
roi_names = epochs.ch_names
print("ROI names:", roi_names)

In [None]:
import matplotlib.pyplot as plt

# One bar chart per centroid: each bar is one ROI's eigenvector component,
# red when >= 0 and blue when negative. squeeze=False keeps axes 2-D even for k == 1.
fig, axes = plt.subplots(nrows=k, ncols=1, figsize=(12, 3 * k), sharex=True, squeeze=False)

for i, ax in enumerate(axes[:, 0]):
    center_vec = kmeans_results['C'][i, :]
    positions = range(len(center_vec))

    ax.bar(
        x=positions,
        height=center_vec,
        color=['red' if val >= 0 else 'blue' for val in center_vec],
    )
    ax.axhline(0, color='black', linewidth=1)  # zero reference line

    ax.set_title(f"Cluster {i} center", fontsize=12)
    ax.set_ylabel("Eigenvector component")
    ax.set_xticks(positions)
    ax.set_xticklabels(roi_names, rotation=90, fontsize=9)

fig.tight_layout()
plt.show()


In [None]:
def _both_hemispheres(*regions):
    """Expand region names into '<region>-lh', '<region>-rh' ROI names (left first, per region)."""
    return [f"{region}-{hemi}" for region in regions for hemi in ("lh", "rh")]


# Desikan–Killiany atlas ROIs grouped into lobes/regions, both hemispheres each.
lobe_groups = {
    "Frontal": _both_hemispheres(
        "caudalmiddlefrontal", "lateralorbitofrontal", "medialorbitofrontal",
        "parsopercularis", "parsorbitalis", "parstriangularis", "precentral",
        "rostralmiddlefrontal", "superiorfrontal", "frontalpole",
    ),
    "Parietal": _both_hemispheres(
        "inferiorparietal", "superiorparietal", "supramarginal", "postcentral",
        "precuneus",
        "paracentral",  # Often near boundary of frontal/parietal
    ),
    "Temporal": _both_hemispheres(
        "bankssts", "entorhinal", "fusiform", "inferiortemporal", "middletemporal",
        "parahippocampal", "superiortemporal", "temporalpole", "transversetemporal",
    ),
    "Occipital": _both_hemispheres(
        "cuneus", "lateraloccipital", "lingual", "pericalcarine",
    ),
    "Cingulate": _both_hemispheres(
        "caudalanteriorcingulate", "isthmuscingulate", "posteriorcingulate",
        "rostralanteriorcingulate",
    ),
    "Insula": _both_hemispheres("insula"),
}


In [None]:
import matplotlib.pyplot as plt

# Build a lookup: ROI name -> column index in the cluster-center vectors
roi_to_index = {roi: idx for idx, roi in enumerate(roi_names)}

# Ordered list of lobes for plotting (controls left-to-right order on the ROI axis)
ordered_lobes = ["Frontal", "Parietal", "Temporal", "Occipital", "Cingulate", "Insula"]

# ---------------------------------------------------------------------------
# 1) Build a new ROI order that groups lobes consecutively, and record each
#    group's first/last position on the reordered axis.
# ---------------------------------------------------------------------------
roi_names_ordered = []
group_boundaries = []  # list of (group_name, first_idx, last_idx)

for lobe_name in ordered_lobes:
    # Only ROIs that actually appear in roi_names
    rois_in_this_group = [r for r in lobe_groups[lobe_name] if r in roi_to_index]
    if not rois_in_this_group:
        continue
    first_idx = len(roi_names_ordered)
    roi_names_ordered.extend(rois_in_this_group)
    group_boundaries.append((lobe_name, first_idx, len(roi_names_ordered) - 1))

# ROIs missing from every lobe group would otherwise vanish from the plot silently.
grouped_rois = set(roi_names_ordered)
unassigned_rois = [r for r in roi_names if r not in grouped_rois]
if unassigned_rois:
    print(f"Warning: {len(unassigned_rois)} ROI(s) are not in any lobe group and are not plotted: {unassigned_rois}")
if not roi_names_ordered:
    raise ValueError("None of the ROI names match lobe_groups; check the ROI naming convention.")

# Index array used to reorder each cluster center
plot_indices = [roi_to_index[r] for r in roi_names_ordered]

# ---------------------------------------------------------------------------
# 2) Plot each cluster with same-lobe ROIs consecutive
# ---------------------------------------------------------------------------
fig, axes = plt.subplots(nrows=k, ncols=1, figsize=(12, 3 * k), squeeze=False)

for cluster_i, ax in enumerate(axes[:, 0]):
    center_vec = kmeans_results['C'][cluster_i, :]
    center_vec_ordered = center_vec[plot_indices]

    # Bar colors: red if >= 0, else blue
    bar_colors = ['red' if val >= 0 else 'blue' for val in center_vec_ordered]

    ax.bar(x=range(len(center_vec_ordered)),
           height=center_vec_ordered,
           color=bar_colors,
           alpha=0.7)
    ax.axhline(0, color='black', linewidth=1)

    ax.set_title(f"Cluster {cluster_i} center", fontsize=12)
    ax.set_ylabel("Eigenvector\ncomponent")
    ax.set_xticks(range(len(center_vec_ordered)))
    ax.set_xticklabels(roi_names_ordered, rotation=90, fontsize=8)

    # Color tick labels to match their bars
    for tick_label, val in zip(ax.get_xticklabels(), center_vec_ordered):
        tick_label.set_color('red' if val >= 0 else 'blue')

    # Shade each lobe's span and label it near the top of the axis
    ymax = max(center_vec_ordered.max(), 0)
    y_label_pos = 1.07 * ymax if ymax != 0 else 0.05  # if all negative, place label near 0

    for lobe_name, first_idx, last_idx in group_boundaries:
        ax.axvspan(first_idx - 0.4, last_idx + 0.4, alpha=0.12, color='gray')
        mid = 0.5 * (first_idx + last_idx)
        ax.text(mid, y_label_pos, lobe_name,
                ha='center', va='bottom', fontsize=9, color='black')

fig.tight_layout()
plt.show()


In [None]:
# Render the positive-component ROIs of one cluster on the fsaverage white surface.
SUBJECTS_DIR = os.path.expanduser("~/mne_data/MNE-fsaverage-data/")
labels = mne.read_labels_from_annot(
    subject="fsaverage",
    parc="aparc",
    subjects_dir=SUBJECTS_DIR
)
# Drop the catch-all "unknown" parcel(s) by name rather than blindly removing the
# last entry, which would discard a real ROI if no "unknown" label is present.
labels = [lab for lab in labels if not lab.name.startswith("unknown")]

# Pair each cluster-center component with its annotation label: by ROI name when
# every name matches, otherwise fall back to the positional order of `labels`.
label_by_name = {lab.name: lab for lab in labels}
if all(name in label_by_name for name in roi_names):
    roi_labels = [label_by_name[name] for name in roi_names]
else:
    print("Warning: ROI names do not all match annotation label names; using positional order.")
    roi_labels = labels

Brain = mne.viz.get_brain_class()
brain = Brain(
    subject="fsaverage",
    hemi="both",
    surf="white",
    subjects_dir=SUBJECTS_DIR,
    background="white",
    size=(800, 600),
    alpha=0.3,
)

cluster_i = 0
cluster_vec = kmeans_results['C'][cluster_i, :]
if len(roi_labels) < len(cluster_vec):
    raise ValueError(
        f"Only {len(roi_labels)} labels available for {len(cluster_vec)} cluster-center components."
    )

for label, val in zip(roi_labels, cluster_vec):
    if val > 0:
        # Filled red patch plus a black outline. `hemi` is left at its default;
        # a Label object carries its own hemisphere (label.hemi). The label's own
        # .color is not modified, so the shared label objects stay untouched.
        brain.add_label(label, color='red', alpha=0.7)
        brain.add_label(label, color="black", alpha=0.8, borders=1)

# Show the brain
brain.show_view("lateral")

In [None]:
import matplotlib.pyplot as plt
from matplotlib.cm import get_cmap
import numpy as np
import mne
from matplotlib.colors import LinearSegmentedColormap

def _align_labels_to_rois(roi_names, labels, n_rois):
    """Return MNE labels ordered so entry i corresponds to cluster-center component i.

    Matches by label name when every entry of ``roi_names`` is found among the
    labels; otherwise falls back to the positional order of ``labels`` with a warning.
    """
    label_by_name = {lab.name: lab for lab in labels}
    if (roi_names is not None and len(roi_names) == n_rois
            and all(name in label_by_name for name in roi_names)):
        return [label_by_name[name] for name in roi_names]
    print("Warning: ROI names do not all match annotation label names; "
          "assuming `labels` is already in cluster-center component order.")
    if len(labels) < n_rois:
        raise ValueError(
            f"Need at least {n_rois} labels to match the cluster centers, got {len(labels)}."
        )
    return list(labels)


def visualize_brain_clusters(k, kmeans_results, roi_names, labels,
                             view_angles=('lateral-left', 'lateral-right', 'dorsal', 'ventral'),
                             subjects_dir="~/mne_data/MNE-fsaverage-data/"):
    """
    Visualize brain clusters from a k-means solution by highlighting, for each
    cluster center, the ROIs with positive components (opposite phase).

    Parameters
    ----------
    k : int
        Number of clusters in the solution; used in the figure title.
    kmeans_results : dict
        k-means results; 'C' holds the cluster centers, shape (n_clusters, n_rois).
    roi_names : list of str
        ROI name for each center component; used to pair components with `labels`
        by name (falls back to positional order if names do not match).
    labels : list of mne.Label
        Annotation labels for the brain regions.
    view_angles : sequence of str, optional
        Views to render; 'lateral-left' / 'lateral-right' show the lateral view of
        one hemisphere, any other value is passed to ``Brain.show_view``.
    subjects_dir : str, optional
        FreeSurfer subjects directory ('~' is expanded).

    Returns
    -------
    fig : matplotlib.figure.Figure
        Grid with one row per cluster and one column per view.
    """
    cluster_centers = kmeans_results['C']
    n_clusters, n_rois = cluster_centers.shape
    view_angles = list(view_angles)
    subjects_dir = os.path.expanduser(subjects_dir)
    roi_labels = _align_labels_to_rois(roi_names, labels, n_rois)

    # Distinct vibrant colors (no greys or whites); cycled if there are more than 12 clusters.
    colors = [
        '#e41a1c',  # red
        '#377eb8',  # blue
        '#4daf4a',  # green
        '#984ea3',  # purple
        '#ff7f00',  # orange
        '#ffff33',  # yellow
        '#a65628',  # brown
        '#f781bf',  # pink
        '#1b9e77',  # teal
        '#d95f02',  # vermillion
        '#7570b3',  # slate blue
        '#e6ab02',  # mustard
    ]

    # squeeze=False keeps axes 2-D for any number of clusters or views.
    fig, axes = plt.subplots(n_clusters, len(view_angles),
                             figsize=(2.5 * len(view_angles), 2 * n_clusters),
                             squeeze=False)

    Brain = mne.viz.get_brain_class()
    for c in range(n_clusters):
        center_vec = cluster_centers[c, :]
        pos_idx = np.where(center_vec > 0)[0]
        print(f"Cluster {c}: {len(pos_idx)} minority ROIs")
        color = colors[c % len(colors)]

        # Fresh Brain per cluster so highlighted labels do not accumulate.
        brain = Brain(
            subject="fsaverage",
            hemi="both",
            surf="pial",
            subjects_dir=subjects_dir,
            background="white",
            size=(400, 400),
            alpha=0.7
        )
        try:
            # Highlight only the positive ROIs: filled patch plus black border.
            for idx in pos_idx:
                lab = roi_labels[idx]
                brain.add_label(lab, color=color, alpha=1.0, borders=False)
                brain.add_label(lab, color='black', alpha=0.8, borders=True)

            # Capture one screenshot per view into this cluster's row.
            for j, view in enumerate(view_angles):
                if view == 'lateral-left':
                    brain.show_view('lateral', hemi='lh')
                elif view == 'lateral-right':
                    brain.show_view('lateral', hemi='rh')
                else:
                    brain.show_view(view)

                img = brain.screenshot()
                ax = axes[c, j]
                ax.imshow(img)
                ax.axis('off')

                # Column headers (view names) on the top row
                if c == 0:
                    clean_name = view.replace('-', ' ').capitalize()
                    ax.set_title(clean_name, fontsize=12)

                # Row labels (cluster numbers) in the leftmost column
                if j == 0:
                    ax.text(-0.15, 0.5, f"Cluster {c}", fontsize=12,
                            ha='right', va='center', transform=ax.transAxes,
                            bbox=dict(facecolor=color, alpha=0.5, boxstyle='round,pad=0.5'))
        finally:
            # Release the 3D render window even if a label or screenshot fails.
            brain.close()

    fig.suptitle(f"LEiDA States (K={k})", fontsize=16, y=1.02)

    # Leave room for the row labels
    plt.tight_layout()
    plt.subplots_adjust(left=0.1, wspace=0.05)

    plt.figtext(0.5, -0.05,
                f"Brain regions showing opposite-phase activity within each state.\n"
                f"Each row represents one of the {n_clusters} states identified by k-means clustering.",
                ha='center', fontsize=10)

    return fig



# Number of states = number of centroid rows, so the title always matches the plotted solution.
k = kmeans_results['C'].shape[0]
roi_names = epochs.ch_names  # Each name is the label of a parcellated ROI
print("ROI names:", roi_names)

# Read Desikan–Killiany annotation from fsaverage
SUBJECTS_DIR = os.path.expanduser("~/mne_data/MNE-fsaverage-data/")
labels = mne.read_labels_from_annot(
    subject="fsaverage",
    parc="aparc",
    subjects_dir=SUBJECTS_DIR
)
# Annotation label names carry a hemisphere suffix (e.g. "unknown-lh"), so an
# exact comparison with "unknown" never matches; filter on the prefix instead.
labels = [lab for lab in labels if not lab.name.startswith("unknown")]

view_angles = ['lateral-left', 'lateral-right', 'dorsal', 'ventral']
fig = visualize_brain_clusters(k, kmeans_results, roi_names, labels,
                               view_angles=view_angles, subjects_dir=SUBJECTS_DIR)
plt.show()



### State occurrence probability

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

chosen_k = 10
solution = kmeans_results
all_labels = np.asarray(solution['IDX'])
time_points = len(all_labels)

# Occurrence probability: fraction of ALL rows (every subject, condition,
# epoch and window pooled together) assigned to each cluster.
prob_vec = np.bincount(all_labels, minlength=chosen_k)[:chosen_k] / time_points

palette = sns.color_palette("husl", chosen_k)

plt.figure(figsize=(8, 4))
plt.bar(range(chosen_k), prob_vec, color=palette)
plt.xlabel("Cluster (State)")
plt.ylabel("Occurrence Probability")
# The labels pool every subject and condition, so the title must not claim "Single Subject".
plt.title(f"State Occurrence Probability (All Subjects and Conditions, K={chosen_k})")
plt.show()

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

chosen_k = 10
solution = kmeans_results
palette = sns.color_palette("husl", chosen_k)

# 'labels_dict' contains labels per condition and subject.
conditions = ["Coordination", "Solo", "Spontaneous"]

# Only plot conditions that were actually loaded instead of failing with KeyError.
plotted_conditions = [cond for cond in conditions if labels_dict.get(cond)]
missing_conditions = [cond for cond in conditions if cond not in plotted_conditions]
if missing_conditions:
    print(f"Warning: no labels for condition(s) {missing_conditions}; they are not plotted.")
if not plotted_conditions:
    raise ValueError("labels_dict has no labels for any of the requested conditions.")

# Occurrence probability per condition, pooling every subject/epoch/window.
# The loop variable is `subj_labels`, not `labels`, so the MNE label list bound
# to `labels` in earlier cells is not overwritten.
condition_probs = {}
for condition in plotted_conditions:
    labels_concat = np.concatenate(
        [subj_labels.ravel() for subj_labels in labels_dict[condition].values()]
    )
    condition_probs[condition] = np.bincount(labels_concat, minlength=chosen_k)[:chosen_k] / labels_concat.size

# Shared y-limit so the panels are directly comparable
y_top = max(max(p) for p in condition_probs.values()) * 1.1

n_panels = len(plotted_conditions)
fig, axes = plt.subplots(nrows=1, ncols=n_panels, figsize=(6 * n_panels, 4), sharey=True, squeeze=False)

for idx, (condition, ax) in enumerate(zip(plotted_conditions, axes[0])):
    ax.bar(range(chosen_k), condition_probs[condition], color=palette)
    ax.set_title(f"{condition}", fontsize=14)
    ax.set_xlabel("Cluster (State)", fontsize=12)
    if idx == 0:
        ax.set_ylabel("Occurrence Probability", fontsize=12)
    ax.set_xticks(range(chosen_k))
    ax.set_ylim(0, y_top)

fig.suptitle(f"State Occurrence Probability per Condition (K={chosen_k})", fontsize=16, y=1.02)
fig.tight_layout()
plt.show()


In [None]:
import numpy as np

conditions = ["Coordination", "Solo", "Spontaneous"]

# Union of subjects over all conditions, sorted so that row i of every
# occ_prob[cond] array refers to the same subject.
all_subjects = sorted(set().union(*(labels_dict[cond].keys() for cond in conditions)))
n_subj = len(all_subjects)

# occ_prob[condition] -> array (n_subjects, k): fraction of each subject's labels
# that fall in each state. Rows of subjects missing from a condition stay NaN.
occ_prob = {}
for cond in conditions:
    arr = np.full((n_subj, k), np.nan, dtype=np.float64)
    for si, subject in enumerate(all_subjects):
        if subject not in labels_dict[cond]:
            print(f"Warning: subject {subject} not found for condition {cond}. Filling with NaN.")
            continue

        # labels_dict[cond][subject] is expected to be (n_epochs, n_windows); flatten it.
        all_labels_1d = np.asarray(labels_dict[cond][subject]).ravel()
        arr[si] = [np.mean(all_labels_1d == c) for c in range(k)]
    occ_prob[cond] = arr

# len() of an (n_subjects, k) array is the number of subjects, not of probability values.
print(f"Number of subjects per condition: {occ_prob[conditions[0]].shape[0]}")
print(f"Shape of probability array for {conditions[0]}: {occ_prob[conditions[0]].shape}")


In [None]:
# import numpy as np
# import pandas as pd
# import seaborn as sns
# import matplotlib.pyplot as plt

# # Data preparation (already prepared by you)
# data_list = []
# for c in range(k):
#     for cond in conditions:
#         for subj_idx, subj in enumerate(all_subjects):
#             prob = occ_prob[cond][subj_idx, c]
#             if not np.isnan(prob):
#                 data_list.append({
#                     "Subject": subj,
#                     "Condition": cond,
#                     "Cluster": f"Cluster {c}",
#                     "Probability": prob
#                 })

# df = pd.DataFrame(data_list)

# # Define colors for conditions
# condition_colors = {
#     'Coordination': 'orchid',
#     'Solo': 'yellowgreen',
#     'Spontaneous': 'lightblue'
# }

# # Plot vertical raincloud plot per cluster
# for cluster in df['Cluster'].unique():
#     plt.figure(figsize=(8, 6))
#     cluster_df = df[df['Cluster'] == cluster]

#     conditions_order = ["Coordination", "Solo", "Spontaneous"]

#     for idx, cond in enumerate(conditions_order):
#         data_cond = cluster_df[cluster_df["Condition"] == cond]["Probability"].values

#         # Violin plot (density distribution)
#         vp = plt.violinplot(data_cond, positions=[idx], points=300,
#                             vert=True, widths=0.7,
#                             showmeans=False, showmedians=False, showextrema=False)

#         # Cut violin plot in half (right side only)
#         for b in vp['bodies']:
#             path = b.get_paths()[0].vertices
#             mean_x = np.mean(path[:, 0])
#             path[:, 0] = np.clip(path[:, 0], idx, idx + 0.4)
#             b.set_facecolor(condition_colors[cond])
#             b.set_alpha(0.5)
#             b.set_edgecolor('none')

#         # Boxplot
#         bp = plt.boxplot(data_cond, positions=[idx - 0.15], widths=0.1,
#                          patch_artist=True, vert=True,
#                          showcaps=False, boxprops=dict(facecolor=condition_colors[cond], alpha=0.7),
#                          medianprops=dict(color="k", linewidth=1.5),
#                          whiskerprops=dict(color=condition_colors[cond], linewidth=1.5),
#                          flierprops=dict(marker='o', color='gray', alpha=0.5))

#         # Scatter plot (rain drops)
#         x_jittered = np.random.uniform(idx - 0.35, idx - 0.2, size=len(data_cond))
#         plt.scatter(x_jittered, data_cond, color=condition_colors[cond], alpha=0.6, s=15)

#     plt.xticks(range(len(conditions_order)), conditions_order, fontsize=12)
#     plt.ylabel("Occurrence Probability", fontsize=12)
#     plt.title(f"Raincloud Plot for {cluster}", fontsize=14)
#     sns.despine(trim=True)
#     plt.grid(axis='y', linestyle='--', alpha=0.5)
#     plt.tight_layout()
#     plt.show()


In [None]:
from scipy.stats import ttest_rel
import numpy as np

from statsmodels.stats.multitest import fdrcorrection

# Condition pairs compared with a paired t-test for every cluster.
condition_pairs = [
    ("Coordination", "Solo"),
    ("Coordination", "Spontaneous"),
    ("Solo", "Spontaneous")
]

alpha = 0.05
k = occ_prob['Coordination'].shape[1]  # number of clusters

# Run paired t-tests for each pair of conditions and each cluster
for cond1, cond2 in condition_pairs:
    print(f"\nPaired t-test between '{cond1}' and '{cond2}':")
    p_values = np.full(k, np.nan)

    for c in range(k):
        data1 = occ_prob[cond1][:, c]
        data2 = occ_prob[cond2][:, c]

        # Only consider subjects present in both conditions.
        mask = ~np.isnan(data1) & ~np.isnan(data2)
        _, p_values[c] = ttest_rel(data1[mask], data2[mask])

    # fdrcorrection does not skip NaN p-values (e.g. a cluster with identical
    # values in both conditions, or < 2 paired subjects): they are counted as
    # tests and propagate NaN into the corrected p-values. Correct only the
    # finite ones; undefined tests are reported as not significant.
    finite = np.isfinite(p_values)
    rejected = np.zeros(k, dtype=bool)
    p_values_corrected = np.full(k, np.nan)
    if finite.any():
        rejected[finite], p_values_corrected[finite] = fdrcorrection(p_values[finite], alpha=alpha)

    # Print results per cluster
    for c, (p_uncorrected, p_corr, reject) in enumerate(zip(p_values, p_values_corrected, rejected)):
        sig_str = "✅ significant" if reject else "❌ not significant"
        print(f"  Cluster {c}: p={p_uncorrected:.4f}, corrected p={p_corr:.4f} → {sig_str}")


In [None]:
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from scipy.stats import ttest_rel
from statsmodels.stats.multitest import fdrcorrection

# Long-format table: one row per (cluster, condition, subject) with a defined probability.
data_list = []
for c in range(k):
    for cond in conditions:
        for subj_idx, subj in enumerate(all_subjects):
            prob = occ_prob[cond][subj_idx, c]
            if not np.isnan(prob):
                data_list.append({
                    "Subject": subj,
                    "Condition": cond,
                    "Cluster": f"Cluster {c}",
                    "Probability": prob
                })

df = pd.DataFrame(data_list)

# Colors for conditions
condition_colors = {
    'Coordination': 'orchid',
    'Solo': 'yellowgreen',
    'Spontaneous': 'lightblue'
}

condition_pairs = [("Coordination", "Solo"),
                   ("Coordination", "Spontaneous"),
                   ("Solo", "Spontaneous")]

alpha = 0.05

# Paired t-test per cluster for each condition pair, FDR-corrected across clusters.
# significance_results[(cond1, cond2)][c] is True when cluster c differs significantly.
significance_results = {}
for cond1, cond2 in condition_pairs:
    p_values = np.full(k, np.nan)

    for c in range(k):
        data1 = occ_prob[cond1][:, c]
        data2 = occ_prob[cond2][:, c]

        mask = ~np.isnan(data1) & ~np.isnan(data2)
        _, p_values[c] = ttest_rel(data1[mask], data2[mask])

    # fdrcorrection does not skip NaN p-values: they would be counted as tests
    # and poison the corrected values. Correct only finite p-values; undefined
    # tests count as not significant.
    finite = np.isfinite(p_values)
    rejected = np.zeros(k, dtype=bool)
    if finite.any():
        rejected[finite], _ = fdrcorrection(p_values[finite], alpha=alpha)
    significance_results[(cond1, cond2)] = rejected

conditions_order = ["Coordination", "Solo", "Spontaneous"]
positions = np.arange(len(conditions_order))
rng = np.random.default_rng(0)  # fixed seed: identical jitter every time the cell runs

# Plotting raincloud plots with significance bars. Clusters are addressed by their
# numeric index so significance_results[...][cluster_idx] always matches the data.
for cluster_idx in range(k):
    cluster = f"Cluster {cluster_idx}"
    cluster_df = df[df['Cluster'] == cluster]
    if cluster_df.empty:
        continue  # no defined probabilities for this cluster: nothing to draw

    plt.figure(figsize=(8, 6))

    for idx, cond in enumerate(conditions_order):
        data_cond = cluster_df[cluster_df["Condition"] == cond]["Probability"].values

        # Violin plot, clipped to its right half
        vp = plt.violinplot(data_cond, positions=[positions[idx]], points=300,
                            vert=True, widths=0.7,
                            showmeans=False, showmedians=False, showextrema=False)

        for b in vp['bodies']:
            path = b.get_paths()[0].vertices
            path[:, 0] = np.clip(path[:, 0], positions[idx], positions[idx] + 0.4)
            b.set_facecolor(condition_colors[cond])
            b.set_alpha(0.5)
            b.set_edgecolor('none')

        # Boxplot
        plt.boxplot(data_cond, positions=[positions[idx] - 0.15], widths=0.1,
                    patch_artist=True, vert=True,
                    showcaps=False, boxprops=dict(facecolor=condition_colors[cond], alpha=0.7),
                    medianprops=dict(color="k", linewidth=1.5),
                    whiskerprops=dict(color=condition_colors[cond], linewidth=1.5),
                    flierprops=dict(marker='o', color='gray', alpha=0.5))

        # Scatter plot (rain drops)
        x_jittered = rng.uniform(positions[idx] - 0.35, positions[idx] - 0.2, size=len(data_cond))
        plt.scatter(x_jittered, data_cond, color='black', alpha=0.6, s=15)

    plt.xticks(positions, conditions_order, fontsize=12)
    plt.ylabel("Occurrence Probability", fontsize=12)
    plt.title(f"Raincloud Plot for {cluster}", fontsize=14)
    sns.despine(trim=True)
    plt.grid(axis='y', linestyle='--', alpha=0.5)

    # Significance bars, stacked upward from just above the highest data point.
    ymax = cluster_df["Probability"].max()
    ystart = ymax + 0.03
    ystep = 0.02  # height of each bar and gap to the next one

    for (cond1, cond2), significant_array in significance_results.items():
        if significant_array[cluster_idx]:
            idx1 = conditions_order.index(cond1)
            idx2 = conditions_order.index(cond2)
            y = ystart
            ystart += ystep  # next bar starts higher

            plt.plot([positions[idx1], positions[idx1], positions[idx2], positions[idx2]],
                     [y, y + ystep, y + ystep, y], lw=1.5, color='black')

            mid = (positions[idx1] + positions[idx2]) / 2
            plt.text(mid, y + ystep, "*", ha='center', va='bottom', fontsize=16)

    plt.ylim(0, ystart + 0.05)
    plt.tight_layout()
    plt.show()


In [None]:
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from scipy.stats import ttest_rel
from statsmodels.stats.multitest import fdrcorrection

# Conditions and pairs
conditions = ["Coordination", "Solo", "Spontaneous"]
condition_pairs = [
    ("Coordination", "Solo"),
    ("Coordination", "Spontaneous"),
    ("Solo", "Spontaneous")
]

# Compute paired t-tests
alpha = 0.05
k = occ_prob['Coordination'].shape[1]

# test_results[(cond1, cond2)] -> list over clusters of
# (t_stat, uncorrected p, FDR-corrected p, significant)
test_results = {}

for cond1, cond2 in condition_pairs:
    t_stats = np.full(k, np.nan)
    p_values = np.full(k, np.nan)

    for c in range(k):
        data1 = occ_prob[cond1][:, c]
        data2 = occ_prob[cond2][:, c]

        # Only subjects present in both conditions form valid pairs.
        mask = ~np.isnan(data1) & ~np.isnan(data2)
        t_stats[c], p_values[c] = ttest_rel(data1[mask], data2[mask])

    # fdrcorrection does not skip NaN p-values (they count as tests and make the
    # corrected values NaN). Correct only finite p-values; undefined tests keep
    # a NaN corrected p and are marked not significant.
    finite = np.isfinite(p_values)
    rejected = np.zeros(k, dtype=bool)
    p_corrected = np.full(k, np.nan)
    if finite.any():
        rejected[finite], p_corrected[finite] = fdrcorrection(p_values[finite], alpha=alpha)

    test_results[(cond1, cond2)] = [
        (t_stats[c], p_values[c], p_corrected[c], rejected[c]) for c in range(k)
    ]


In [None]:
# Long-format table: one row per (cluster, condition, subject) with a defined probability.
data_list = []
for c in range(k):
    for cond in conditions:
        for subj_idx, subj in enumerate(all_subjects):
            prob = occ_prob[cond][subj_idx, c]
            if not np.isnan(prob):
                data_list.append({
                    "Subject": subj,
                    "Condition": cond,
                    "Cluster": f"Cluster {c}",
                    "Probability": prob
                })

df = pd.DataFrame(data_list)

# Colors for conditions
condition_colors = {'Coordination':'orchid','Solo':'yellowgreen','Spontaneous':'lightblue'}

rng = np.random.default_rng(0)  # fixed seed: reproducible jitter

# Plot vertical raincloud plots
for c in range(k):
    plt.figure(figsize=(8,6))
    cluster_df = df[df['Cluster'] == f'Cluster {c}']

    positions = np.arange(len(conditions))
    for idx, cond in enumerate(conditions):
        data_cond = cluster_df[cluster_df["Condition"] == cond]["Probability"].values

        # Violin plot (density), clipped to its right half
        vp = plt.violinplot(data_cond, positions=[idx], points=300, vert=True, widths=0.7,
                            showmeans=False, showmedians=False, showextrema=False)
        for b in vp['bodies']:
            path = b.get_paths()[0].vertices
            path[:,0] = np.clip(path[:,0], idx, idx+0.4)
            b.set_facecolor(condition_colors[cond])
            b.set_alpha(0.5)
            b.set_edgecolor('none')

        # Boxplot
        plt.boxplot(data_cond, positions=[idx-0.15], widths=0.1, vert=True,
                    showcaps=False, patch_artist=True,
                    boxprops=dict(facecolor=condition_colors[cond],alpha=0.7),
                    medianprops=dict(color='k'))

        # Scatter points
        jitter = rng.uniform(idx-0.35, idx-0.2, size=len(data_cond))
        plt.scatter(jitter, data_cond, color='black', alpha=0.6, s=15)

    plt.xticks(positions, conditions, fontsize=12)
    plt.ylabel("Occurrence Probability", fontsize=12)
    plt.title(f"Raincloud Plot for Cluster {c}", fontsize=14)
    plt.grid(axis='y', linestyle='--', alpha=0.5)

    # Significance brackets: one vertical slot per condition pair. A slot of 20% of
    # the data maximum leaves room for the stars plus the p-value label above them.
    y_max = cluster_df['Probability'].max()
    y_start = y_max + 0.02
    y_step = 0.2 * y_max

    for idx, (cond1, cond2) in enumerate(condition_pairs):
        t_stat, p_val, p_corr, significant = test_results[(cond1, cond2)][c]
        if not significant:
            continue

        x1, x2 = positions[conditions.index(cond1)], positions[conditions.index(cond2)]
        y = y_start + idx * y_step
        y_bar = y + y_step / 4
        x_mid = (x1 + x2) / 2

        # Significance stars
        if p_corr < 0.001:
            stars = '***'
        elif p_corr < 0.01:
            stars = '**'
        elif p_corr < 0.05:
            stars = '*'
        else:
            stars = 'ns'

        plt.plot([x1, x1, x2, x2], [y, y_bar, y_bar, y], lw=1.5, color='grey', alpha=0.7)
        plt.text(x_mid, y_bar, f"{stars}",
                 ha='center', va='bottom', fontsize=20, color='r')
        # Shift the p-value up by a fixed number of points so it sits above the
        # stars instead of being drawn at the same anchor on top of them.
        plt.annotate(f"p = {p_corr:.5f}", xy=(x_mid, y_bar), xytext=(0, 22),
                     textcoords='offset points',
                     ha='center', va='bottom', fontsize=8, color='grey')

    # Keep every bracket and its labels inside the axes.
    plt.ylim(top=y_start + (len(condition_pairs) + 0.5) * y_step)
    sns.despine(trim=True)
    plt.tight_layout()
    plt.show()


### Dwell Time

In [None]:
import numpy as np

def compute_dwell_time_single_subject(labels, chosen_k):
    """
    Mean dwell time of every state in one label sequence.

    For each state s in 0..chosen_k-1 the dwell time is the number of
    positions labelled s divided by the number of uninterrupted runs of s,
    i.e. the mean run length. States that never occur get 0.
    """
    result = np.zeros(chosen_k)
    for s in range(chosen_k):
        hits = np.nonzero(labels == s)[0]
        if hits.size == 0:
            continue  # never visited: dwell time stays 0
        # A new run begins wherever two consecutive hit positions are not adjacent.
        n_runs = 1 + np.count_nonzero(np.diff(hits) > 1)
        result[s] = hits.size / n_runs
    return result

# Example usage: dwell times of the labels stored in the clustering solution.
# NOTE(review): if IDX concatenates several epochs/subjects, runs that span a
# boundary are merged; if IDX is 2-D the helper sees row indices -- confirm layout.
import matplotlib.pyplot as plt
import seaborn as sns

solution = kmeans_results
all_labels = solution['IDX']

dwell_times = compute_dwell_time_single_subject(all_labels, k)

plt.figure(figsize=(8, 4))
sns.barplot(x=list(range(k)), y=dwell_times)
plt.xlabel("State")
plt.ylabel("Dwell Time (avg. consecutive windows in state)")
# Report k -- the value the dwell times were computed with -- rather than the
# separate `chosen_k` global, which can disagree with it.
plt.title(f"Dwell Time per State (Single Subject, K={k})")
plt.show()


In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
from scipy.stats import ttest_rel
from statsmodels.stats.multitest import fdrcorrection

# Conditions and pairs
conditions = ["Coordination", "Solo", "Spontaneous"]
condition_pairs = [
    ("Coordination", "Solo"),
    ("Coordination", "Spontaneous"),
    ("Solo", "Spontaneous")
]

alpha = 0.05
k = dwell_time_dict['Coordination'].shape[1]

# Paired t-tests per state for each condition pair, FDR-corrected across states.
# test_results[(cond1, cond2)] -> list over states of (p, FDR-corrected p, significant)
test_results = {}

for cond1, cond2 in condition_pairs:
    p_values = np.full(k, np.nan)
    for state in range(k):
        data1 = dwell_time_dict[cond1][:, state]
        data2 = dwell_time_dict[cond2][:, state]

        mask = ~np.isnan(data1) & ~np.isnan(data2)
        _, p_values[state] = ttest_rel(data1[mask], data2[mask])

    # fdrcorrection does not skip NaN p-values (they count as tests and make the
    # corrected values NaN). Correct only finite p-values; the others keep a NaN
    # corrected p and are marked not significant.
    finite = np.isfinite(p_values)
    rejected = np.zeros(k, dtype=bool)
    p_corrected = np.full(k, np.nan)
    if finite.any():
        rejected[finite], p_corrected[finite] = fdrcorrection(p_values[finite], alpha=alpha)
    test_results[(cond1, cond2)] = [(p_values[i], p_corrected[i], rejected[i]) for i in range(k)]

pos_dict = {cond: idx for idx, cond in enumerate(conditions)}

# Plot boxplots with significance bars per cluster
for state in range(k):
    plt.figure(figsize=(8, 6))

    # Long-format table of this state's dwell time for every subject and condition.
    df_dwell = pd.DataFrame([
        {"Condition": cond, "Dwell_Time": subj_dwell[state]}
        for cond in conditions
        for subj_dwell in dwell_time_dict[cond]
    ])

    # Boxplot
    sns.boxplot(data=df_dwell, x="Condition", y="Dwell_Time",
                palette="Set2", linewidth=1.5)

    # Overlay individual points
    sns.stripplot(data=df_dwell, x="Condition", y="Dwell_Time",
                  color="black", alpha=0.6, jitter=True)

    plt.title(f"Dwell Time Distribution for Cluster {state}", fontsize=14)
    plt.ylabel("Average Dwell Time (windows)")
    plt.xlabel("Condition")
    plt.grid(axis='y', linestyle='--', alpha=0.6)

    # Significance brackets, one vertical slot per pair. A slot of 15% of the data
    # maximum leaves room for the two-line label; fall back to 1.0 when every dwell
    # time is 0 so the brackets do not collapse onto one another.
    y_max = df_dwell["Dwell_Time"].max()
    y_ref = y_max if y_max > 0 else 1.0
    y_start = y_max + 0.05 * y_ref
    y_step = 0.15 * y_ref

    for idx, (cond1, cond2) in enumerate(condition_pairs):
        p_val, p_corr, reject = test_results[(cond1, cond2)][state]

        x1, x2 = pos_dict[cond1], pos_dict[cond2]
        y = y_start + idx * y_step

        # Significance stars
        if p_corr < 0.001:
            stars = '***'
        elif p_corr < 0.01:
            stars = '**'
        elif p_corr < 0.05:
            stars = '*'
        else:
            stars = 'ns'

        # Draw line and annotation
        plt.plot([x1, x1, x2, x2], [y, y+y_step/4, y+y_step/4, y], lw=1.5, color='k')
        plt.text((x1+x2)/2, y+y_step/4, f"{stars}\n(p={p_corr:.3f})",
                 ha='center', va='bottom', fontsize=10, color='k')

    # Keep every bracket and its label inside the axes.
    plt.ylim(top=y_start + (len(condition_pairs) + 0.5) * y_step)
    sns.despine(trim=True)
    plt.tight_layout()
    plt.show()


In [None]:
import numpy as np

def compute_transition_matrix_single_subject(labels, chosen_k):
    """
    Empirical first-order transition probabilities of one label sequence.

    Entry (i, j) of the returned chosen_k x chosen_k array is the fraction of
    steps leaving state i that land in state j. Rows of states with no
    outgoing step are all zeros.
    """
    counts = np.zeros((chosen_k, chosen_k))
    for src, dst in zip(labels[:-1], labels[1:]):
        counts[src, dst] += 1

    outgoing = counts.sum(axis=1, keepdims=True)
    outgoing[outgoing == 0] = 1  # avoid 0/0 for states that are never left
    return counts / outgoing

# Example usage + plotting: transition matrix of the labels in the clustering solution.
import matplotlib.pyplot as plt
import seaborn as sns

solution = kmeans_results
all_labels = solution['IDX']

transition_matrix = compute_transition_matrix_single_subject(all_labels, k)

# Ticks and title use k -- the size of the matrix actually computed -- rather than
# the separate `chosen_k` global, which adds or drops ticks when it differs from k.
state_names = [str(i) for i in range(k)]

plt.figure(figsize=(6, 5))
sns.heatmap(transition_matrix, annot=True, cmap="Blues", linewidths=0.5, fmt=".2f",
            xticklabels=state_names, yticklabels=state_names)
plt.xlabel("To State")
plt.ylabel("From State")
plt.title(f"State Transition Matrix (Single Subject, K={k})")
plt.show()


In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

def compute_transition_matrix(labels_epoch, chosen_k):
    """
    Row-normalised transition matrix of a single epoch's label sequence.

    Counts every step labels_epoch[t] -> labels_epoch[t + 1] and divides each
    row by its total, so entry (i, j) is P(next = j | current = i). Rows of
    states with no outgoing step remain all zeros.
    """
    counts = np.zeros((chosen_k, chosen_k))
    if len(labels_epoch) > 1:
        # Unbuffered accumulation: repeated (from, to) pairs are all counted.
        np.add.at(counts, (labels_epoch[:-1], labels_epoch[1:]), 1)

    totals = counts.sum(axis=1, keepdims=True)
    totals[totals == 0] = 1  # avoid 0/0 for states that are never left
    return counts / totals

import warnings

# Average transition matrices over epochs (within subject), then over subjects.
conditions = ["Coordination", "Solo", "Spontaneous"]

transition_matrices = {}

for cond in conditions:
    subject_matrices = []
    for epochs_labels in labels_dict[cond].values():  # expected shape: (n_epochs, n_windows)
        epoch_matrices = []
        for epoch_labels in epochs_labels:
            epoch_trans_mat = compute_transition_matrix(epoch_labels, k)
            # A state with no outgoing transition in this epoch yields an all-zero
            # row. Averaging it in as "probability 0" biases the mean toward 0 and
            # makes averaged rows sum to < 1, so mark such rows as missing instead.
            epoch_trans_mat[epoch_trans_mat.sum(axis=1) == 0] = np.nan
            epoch_matrices.append(epoch_trans_mat)
        with warnings.catch_warnings():
            # nanmean warns on all-NaN rows (state never left in any epoch); they stay NaN.
            warnings.simplefilter("ignore", category=RuntimeWarning)
            subject_matrices.append(np.nanmean(epoch_matrices, axis=0))
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", category=RuntimeWarning)
        condition_mean_matrix = np.nanmean(subject_matrices, axis=0)
    # States never left in this condition are reported as all-zero rows.
    transition_matrices[cond] = np.nan_to_num(condition_mean_matrix, nan=0.0)

# Plotting clearly with heatmaps:
for cond in conditions:
    plt.figure(figsize=(7,6))
    sns.heatmap(transition_matrices[cond], annot=True, cmap="Blues",
                linewidths=0.5, fmt=".2f", square=True,
                xticklabels=[f"{i}" for i in range(k)],
                yticklabels=[f"{i}" for i in range(k)])
    plt.title(f"Average State Transition Matrix ({cond})", fontsize=14)
    plt.xlabel("To State")
    plt.ylabel("From State")
    plt.tight_layout()
    plt.show()


In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

condition_pairs = [
    ("Coordination", "Solo"),
    ("Coordination", "Spontaneous"),
    ("Solo", "Spontaneous")
]

# Pairwise differences of the condition-average transition matrices.
diff_matrices = {}
for pair in condition_pairs:
    cond1, cond2 = pair
    diff = transition_matrices[cond1] - transition_matrices[cond2]
    diff_matrices[pair] = diff
all_diffs = [np.abs(m) for m in diff_matrices.values()]

# One symmetric colour range shared by every heatmap so they are comparable.
vmax = np.max(all_diffs)
vmin = -vmax

tick_labels = [f"{i}" for i in range(k)]
for (cond1, cond2), diff_matrix in diff_matrices.items():
    plt.figure(figsize=(7, 6))
    sns.heatmap(diff_matrix, annot=True, cmap="bwr", center=0, linewidths=0.5, fmt=".2f",
                vmin=vmin, vmax=vmax,
                xticklabels=tick_labels,
                yticklabels=tick_labels)
    plt.title(f"Transition Matrix Difference: {cond1} − {cond2}", fontsize=14)
    plt.xlabel("To State")
    plt.ylabel("From State")
    plt.tight_layout()
    plt.show()

