In [None]:
import scipy.io
import numpy as np
import matplotlib.pyplot as plt
from scipy.signal import butter, filtfilt, hilbert
import os
from sklearn.cluster import KMeans

## `LEiDA_EEG_eigenvectors.m`

#### Helper Functions - Filtering

In [None]:
def butter_bandpass(lowcut, highcut, fs, order=6):
    """
    Design a digital Butterworth band-pass filter in transfer-function form.

    The two band edges are expressed as fractions of the Nyquist frequency
    (``fs / 2``) before being handed to ``scipy.signal.butter``.

    Parameters
    ----------
    lowcut : float
        Lower band edge in Hz.
    highcut : float
        Upper band edge in Hz.
    fs : float
        Sampling frequency in Hz.
    order : int, optional
        Order passed to ``butter``. Default is 6.

    Returns
    -------
    b, a : ndarray
        Numerator (b) and denominator (a) coefficients of the filter.
    """
    half_rate = 0.5 * fs
    # Normalised edges, in the order [low, high]
    normalised_band = [edge / half_rate for edge in (lowcut, highcut)]
    coefficients = butter(order, normalised_band, btype='band')
    return coefficients[0], coefficients[1]


def bandpass_filter(data, fs, lowcut, highcut, order=6):
    """
    Band-pass a time series with a forward-backward Butterworth filter.

    The coefficients come from ``butter_bandpass`` and are applied with
    ``scipy.signal.filtfilt``, which runs the filter in both directions.

    Parameters
    ----------
    data : ndarray
        One-dimensional time series (e.g., a single ROI).
    fs : float
        Sampling frequency in Hz.
    lowcut : float
        Lower band edge in Hz.
    highcut : float
        Upper band edge in Hz.
    order : int, optional
        Butterworth order forwarded to ``butter_bandpass``. Default is 6.

    Returns
    -------
    filtered_data : ndarray
        The filtered series, same shape as ``data``.
    """
    numerator, denominator = butter_bandpass(lowcut, highcut, fs, order=order)
    return filtfilt(numerator, denominator, data)


#### Eigenvectors

In [None]:
def compute_leading_eigenvectors(data, fs, window_size, freq_band='alpha', verbose=True, do_plots=False):
    """
    Python port of the MATLAB LEiDA-EEG eigenvector pipeline:
      1) Band-pass filter each ROI of a multi-channel EEG time series.
      2) Take the Hilbert transform to obtain instantaneous phases.
      3) Build a dynamic phase-locking (dPL) matrix per non-overlapping window.
      4) Keep the leading eigenvector of every dPL matrix.

    Parameters
    ----------
    data : array_like
        Shape [n_areas, n_timepoints]. Each row is the time series of one brain region.
    fs : float
        Sampling frequency in Hz.
    window_size : int
        Number of samples in each non-overlapping window (e.g., 250).
    freq_band : str, optional
        'alpha' (8-12 Hz), 'beta' (15-25 Hz) or 'gamma' (30-80 Hz). Default is 'alpha'.
    verbose : bool, optional
        If True, prints progress messages. If False, nothing is printed. Default is True.
    do_plots : bool, optional
        If True, shows intermediate plots (raw vs filtered signal, phase, one example
        dPL matrix and its eigenvector). Default is False.

    Returns
    -------
    lead_eigs : ndarray
        Leading eigenvectors, shape [n_windows - 2, n_areas], where
        n_windows = ceil(n_timepoints / window_size). Shape is (0, n_areas) when the
        recording holds fewer than three window slots.

    Raises
    ------
    ValueError
        If `freq_band` is not one of the supported names, `data` is not 2D, or
        `window_size` is smaller than 1.

    Notes
    -----
    The window loop below analyses window slots 0 .. n_windows-3, i.e. it leaves out
    the final two slots (including any incomplete trailing one).
    NOTE(review): the original intent was described as "skip the first and last
    window"; the indexing is kept exactly as it was so numerical results do not
    change - confirm against the MATLAB indexing before shifting it.

    The sign of each eigenvector is whatever ``np.linalg.eigh`` returns; no sign
    convention is applied here.
    """

    # --------------------------
    # 1) Determine filter band
    # --------------------------
    if freq_band == 'alpha':
        lowcut, highcut = 8, 12
    elif freq_band == 'beta':
        lowcut, highcut = 15, 25 # same as in other projects of mine
    elif freq_band == 'gamma':
        lowcut, highcut = 30, 80
    else:
        raise ValueError("freq_band must be 'alpha', 'beta', or 'gamma'.")

    data = np.asarray(data)
    if data.ndim != 2:
        raise ValueError(f"data must be 2D [n_areas, n_timepoints], got shape {data.shape}.")
    if window_size < 1:
        raise ValueError("window_size must be a positive number of samples.")

    n_areas, T = data.shape

    # All console output is gated on `verbose` so that verbose=False is really silent.
    if verbose:
        print(f"Filtering data ({data.shape[0]} areas, {data.shape[1]} timepoints) "
              f"from {lowcut} to {highcut} Hz, order=6.")
        print(f"N areas: {n_areas}, T: {T}")

    # -------------------------------------------
    # 2) De-mean and filter each ROI separately
    # -------------------------------------------
    # Subtract the mean per channel (as in the MATLAB code)
    data_demean = data - np.mean(data, axis=1, keepdims=True)

    filtered_data = np.zeros_like(data_demean)
    for i in range(n_areas):
        filtered_data[i, :] = bandpass_filter(data_demean[i, :], fs,
                                             lowcut, highcut, order=6)

    if do_plots:
        # Plot an example channel (ROI 0) before and after filtering
        t = np.arange(T) / fs
        plt.figure(figsize=(10, 4))
        plt.plot(t, data_demean[0, :], label='Raw (demeaned)', alpha=0.7)
        plt.plot(t, filtered_data[0, :], label='Filtered', alpha=0.7)
        plt.xlim([0, min(10.0, t[-1])])  # zoom in on (at most) the first 10 seconds
        plt.legend()
        plt.title("ROI 0: Before and After Filtering")
        plt.xlabel("Time (s)")
        plt.show()

    # --------------------------------------------------------
    # 3) Compute Hilbert transform to get instantaneous phase
    # --------------------------------------------------------
    analytic_signal = hilbert(filtered_data, axis=1)
    phases = np.angle(analytic_signal)

    if do_plots:
        # Plot example channel's phase (`t` was defined in the previous plotting branch)
        plt.figure(figsize=(10, 4))
        plt.plot(t, phases[0, :], label='Phase (ROI 0)')
        plt.title("Instantaneous Phase of Filtered Signal (ROI 0)")
        plt.xlabel("Time (s)")
        plt.ylabel("Phase (radians)")
        plt.xlim([0, min(10.0, t[-1])])  # zoom in
        plt.legend()
        plt.show()

    # -----------------------------------
    # 4) Windowing & dynamic phase-locking
    # -----------------------------------
    # In MATLAB code:
    #   repArray = 1:window_size:size(data,2)
    #   for t = 2 : (repetitions-1)
    # The loop below mirrors that range; see the Notes section of the docstring
    # for which window slots this actually covers.
    repArray = np.arange(0, T, window_size)  # e.g. [0, 250, 500, ...]
    repetitions = len(repArray)
    if T % window_size != 0 and verbose:
        print("Warning: discarding last incomplete window since T not multiple of window_size.")

    # Loop index whose dPL matrix is kept for the diagnostic plots; shorter
    # recordings fall back to the last window they do have.
    example_t_idx = 10

    lead_eig_list = []
    example_iFC = None
    example_V1 = None

    for t_idx in range(1, repetitions - 1):
        start_idx = repArray[t_idx - 1]
        end_idx = repArray[t_idx]

        # dPL matrix for this window: iFC[n, p] = mean_t cos(phase_n(t) - phase_p(t)),
        # evaluated for all ROI pairs at once via broadcasting.
        window_phases = phases[:, start_idx:end_idx]
        iFC = np.cos(window_phases[:, None, :] - window_phases[None, :, :]).mean(axis=2)

        # --------------------------------
        # 5) Leading eigenvector of iFC
        # --------------------------------
        # iFC is real and symmetric, so eigh applies; pick the eigenvector
        # belonging to the largest eigenvalue.
        vals, vecs = np.linalg.eigh(iFC)
        V1 = vecs[:, np.argmax(vals)]
        lead_eig_list.append(V1)

        if do_plots and t_idx <= example_t_idx:
            example_iFC = iFC.copy()
            example_V1 = V1.copy()

    # Keep the result 2D even when no window was analysed, so callers can stack it.
    if lead_eig_list:
        lead_eigs = np.array(lead_eig_list)  # shape [n_windows - 2, n_areas]
    else:
        lead_eigs = np.empty((0, n_areas))

    if verbose:
        print(f"Computed {lead_eigs.shape[0]} leading eigenvectors "
              f"for {n_areas} areas with window_size={window_size}.")

    if do_plots and example_iFC is not None:
        # Plot the phase-locking matrix (dPL) as an image
        plt.figure(figsize=(6, 5))
        plt.imshow(example_iFC, cmap='bwr', aspect='auto', vmin=-1, vmax=1)
        plt.colorbar(label='Phase Coherence (mean cos(diff))')
        plt.title("Example Dynamic Phase-Locking Matrix (dPL)")
        plt.xlabel("Brain Region (ROI index)")
        plt.ylabel("Brain Region (ROI index)")
        plt.show()

        # Plot the corresponding leading eigenvector with sign preserved.
        plt.figure(figsize=(6, 4))
        markerline, stemlines, baseline = plt.stem(np.arange(n_areas), example_V1)
        plt.setp(markerline, marker='o', markersize=6, color='b')
        plt.setp(stemlines, color='b')
        plt.title("Leading Eigenvector (with sign) from Example dPL")
        plt.xlabel("Brain Region (ROI index)")
        plt.ylabel("Eigenvector Component")
        plt.show()

    return lead_eigs

In [None]:
def load_example_eeg_mat(file_path, n_participants=8):
    """
    Load per-participant EEG matrices stored as variables p1, p2, ... in a .mat file.

    Written for exampleSourceEEG_8.mat, which holds eight participants
    (the default); pass `n_participants` for files with a different count.

    Parameters
    ----------
    file_path : str
        Path to the .mat file.
    n_participants : int, optional
        Number of consecutive variables ``p1 .. p<n_participants>`` to read.
        Default is 8.

    Returns
    -------
    participants : list of ndarray
        One 2D array per participant, in the order p1, p2, ...
        # presumably each array is (N_areas, T); the file layout is not checked beyond ndim == 2

    Raises
    ------
    KeyError
        If any of the expected ``p<i>`` variables is absent from the file.
    ValueError
        If a participant's array is not 2D.
    """
    mat_contents = scipy.io.loadmat(file_path)

    expected_keys = [f'p{i+1}' for i in range(n_participants)]

    # Report every missing variable at once instead of failing on the first lookup
    missing_keys = [key for key in expected_keys if key not in mat_contents]
    if missing_keys:
        raise KeyError(f"Variables {missing_keys} not found in '{file_path}'.")

    # Extract participants (p1, p2, ...)
    participants = [mat_contents[key] for key in expected_keys]

    # Check consistency
    for i, data in enumerate(participants):
        if data.ndim != 2:
            raise ValueError(f"Participant p{i+1} data is not 2D (found shape {data.shape}).")

    print(f"Loaded EEG data for {len(participants)} participants.")
    return participants

## Main pipeline

In [None]:
# ✅ Location of the example dataset (adjust to your checkout if necessary)
file_path = '../data/LEiDA_EEG/exampleSourceEEG_8.mat'

# Stop early with a readable message rather than a loader traceback
if not os.path.exists(file_path):
    raise FileNotFoundError(f"File '{file_path}' not found. Please check the path.")

participants_data = load_example_eeg_mat(file_path)

# Analysis settings
fs = 250             # sampling frequency in Hz
window_size = 50     # samples per window (50 / 250 Hz = 200 ms)
freq_band = 'alpha'  # one of 'alpha', 'beta', 'gamma'

# One [n_windows, n_areas] array of leading eigenvectors per participant
all_leading_eigenvectors = []
for i, data in enumerate(participants_data):
    print(f"\n--- Processing Participant {i+1} ---")
    lead_vecs = compute_leading_eigenvectors(
        data,
        fs,
        window_size,
        freq_band=freq_band,
        verbose=True,
        do_plots=True,
    )
    all_leading_eigenvectors.append(lead_vecs)

print("\n✅ Processing complete!")

In [None]:
# Sanity check: raw data shape of the first participant, then the
# eigenvector-array shape of the second participant.
print(participants_data[0].shape, all_leading_eigenvectors[1].shape, sep="\n")

# 2. K-means clusters

In [None]:
def run_leida_kmeans(coll_eigenvectors, K, n_init=50, max_iter=200, random_state=None):
    """
    Replicates the essential steps of LEiDA_EEG_kmeans.m in Python:
      - K-means on the collated eigenvectors (scikit-learn minimises the
        sum of squared Euclidean distances, i.e. MATLAB's 'sqeuclidean' objective).
      - Re-label clusters by size (largest cluster -> label 0, second largest -> 1, ...).

    Parameters
    ----------
    coll_eigenvectors : ndarray, shape [n_samples, n_features]
        Rows to cluster (one leading eigenvector per row).
    K : int
        Number of clusters.
    n_init : int, optional
        Number of k-means restarts. Default is 50.
    max_iter : int, optional
        Maximum iterations per restart. Default is 200.
    random_state : int or None, optional
        Forwarded to ``KMeans``; pass an int for reproducible results.

    Returns
    -------
    dict
        'IDX'       : ndarray [n_samples], cluster label per row after re-labelling.
        'C'         : ndarray [K, n_features], centroids in the re-labelled order.
        'counts'    : ndarray [K], cluster sizes in descending order.
        'distances' : ndarray [n_samples, K], Euclidean distance of every row to
                      every centroid; column c refers to re-labelled cluster c,
                      matching 'IDX' and 'C'.
    """

    # 1) Run k-means with the given K
    kmeans = KMeans(n_clusters=K,
                    n_init=n_init,
                    max_iter=max_iter,
                    random_state=random_state)
    kmeans.fit(coll_eigenvectors)

    # labels_: shape [n_samples,], each sample assigned cluster 0..K-1
    old_labels = kmeans.labels_
    centers = kmeans.cluster_centers_

    # 2) Order the original cluster ids by descending size
    counts = np.bincount(old_labels, minlength=K)  # shape [K]
    ind_sort = np.argsort(counts)[::-1]  # ind_sort[new_label] == old_label

    # Inverse permutation: cluster_map[old_label] == new_label
    cluster_map = np.empty(K, dtype=int)
    cluster_map[ind_sort] = np.arange(K)

    # 3) Re-label the assignments according to cluster size
    new_labels = cluster_map[old_labels]

    # 4) Reorder the cluster centers the same way
    new_centers = centers[ind_sort, :]

    # 5) Distance of each point to each centroid (MATLAB's "D").
    #    KMeans.transform() returns plain Euclidean distances with columns in the
    #    ORIGINAL cluster order, so the columns are permuted with ind_sort to line
    #    up with the re-labelled IDX / C. Square them if MATLAB's 'sqeuclidean'
    #    values are needed, e.g. the within-cluster sum for cluster c is
    #    np.sum(distances[new_labels == c, c] ** 2).
    distances = kmeans.transform(coll_eigenvectors)[:, ind_sort]

    # Return results in a dictionary (mimicking MATLAB struct)
    return {
        'IDX': new_labels,           # cluster labels (time course) after re-labeling
        'C': new_centers,            # cluster centroids
        'counts': counts[ind_sort],  # cluster sizes in descending order
        'distances': distances,      # distance of each point to each (re-labelled) centroid
    }


In [None]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA

def plot_kmeans_3d_pca(collEigenvectors, labels, centers, title="K-means Clusters in PCA Space"):
    """
    Project eigenvector data (collEigenvectors) onto 3 principal components,
    then scatter-plot them in 3D, color-coded by cluster label. The cluster
    centers are projected with the same PCA and plotted as black-edged triangles.

    Points and centers share one color scale spanning 0 .. K-1, so a center
    always has the color of its own cluster even if `labels` does not contain
    every cluster index.

    Parameters
    ----------
    collEigenvectors : ndarray, shape (n_samples, n_features)
        All collated eigenvectors that were clustered.
    labels : ndarray, shape (n_samples,)
        Integer cluster labels for each row in collEigenvectors.
    centers : ndarray, shape (K, n_features)
        The cluster centroids (in original feature space), row c belonging
        to label c.
    title : str
        Title for the plot (optional).
    """
    # 1) Run PCA to reduce dimensionality to 3
    pca = PCA(n_components=3)
    X_pca = pca.fit_transform(collEigenvectors)   # shape: [n_samples, 3]
    centers_pca = pca.transform(centers)          # shape: [K, 3]

    K = centers.shape[0]
    # Common normalisation for both scatter calls (see docstring)
    color_scale = dict(cmap='rainbow', vmin=0, vmax=max(K - 1, 1))

    # 2) Plot the data, color-coded by cluster label
    fig = plt.figure(figsize=(8, 6))
    ax = fig.add_subplot(111, projection='3d')

    scatter = ax.scatter(
        X_pca[:, 0],
        X_pca[:, 1],
        X_pca[:, 2],
        c=labels,                # use cluster labels as color
        alpha=0.6,
        **color_scale
    )

    # 3) Plot the cluster centers as bigger, distinct markers
    ax.scatter(
        centers_pca[:, 0],
        centers_pca[:, 1],
        centers_pca[:, 2],
        c=np.arange(K),          # row c of `centers` gets the color of label c
        marker='^',
        s=200,
        edgecolors='k',
        label='Cluster centers',
        **color_scale
    )

    # 4) Tidy up the figure
    ax.set_xlabel('PC1')
    ax.set_ylabel('PC2')
    ax.set_zlabel('PC3')
    ax.set_title(title)
    # Show the 'Cluster centers' entry that the second scatter registers
    ax.legend(loc='upper left')
    # Colorbar with cluster label indices
    cbar = fig.colorbar(scatter, ax=ax, fraction=0.03, pad=0.07)
    cbar.set_label("Cluster Label")
    plt.tight_layout()
    plt.show()

In [None]:
# 1) Stack every participant's block of eigenvectors row-wise into one array
collEigenvectors = np.concatenate(
    [np.atleast_2d(subject_block) for subject_block in all_leading_eigenvectors],
    axis=0,
)
print("Shape of collEigenvectors:", collEigenvectors.shape)

In [None]:
%matplotlib inline
# 2) Try k-means with k from e.g. 4..10
# Fixed seed so that Restart & Run All reproduces the same states and figures;
# change it (or set it to None) to explore other initialisations.
KMEANS_SEED = 0
rangeK = range(4, 11)
kmeans_solutions = {}  # K -> result dict returned by run_leida_kmeans

for k in rangeK:
    print(f"Running k-means for K={k} ...")
    res = run_leida_kmeans(collEigenvectors, K=k, n_init=50, max_iter=200,
                           random_state=KMEANS_SEED)
    kmeans_solutions[k] = res
    plot_kmeans_3d_pca(collEigenvectors, res['IDX'], res['C'], title=f"K-means Clusters for K={k}")
    print(f"Finished K={k}.")
    print("Cluster sizes:", res['counts'])
    print("Cluster centers shape:", res['C'].shape)
    print("Cluster labels shape:", res['IDX'].shape)
    print("Distances shape:", res['distances'].shape)
    print("------\n")


In [None]:
# Report on the largest K that was fitted, looked up from the results dict
# itself rather than from the loop variable `k` left over by the sweep cell.
largest_k = max(kmeans_solutions)
print("\n✅ K-means clustering complete!")
print(f"Results for K={largest_k}:\n", kmeans_solutions[largest_k].keys())

In [None]:
# Labels for the largest fitted K (taken from the dict, not from a leftover loop variable)
print("Cluster assignments (IDX):", kmeans_solutions[max(kmeans_solutions)]['IDX'])


In [None]:
# Centroid array shape for the largest fitted K (taken from the dict, not from a leftover loop variable)
print("Cluster centers (C):", kmeans_solutions[max(kmeans_solutions)]['C'].shape)


In [None]:

# Cluster sizes for the largest fitted K (taken from the dict, not from a leftover loop variable)
print("Cluster counts:", kmeans_solutions[max(kmeans_solutions)]['counts'])

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
%matplotlib widget
# Choose fixed K
chosen_k = 10
solution = kmeans_solutions[chosen_k]
all_labels = solution['IDX']  # shape [total_time_points]

# Determine subject-specific time window lengths
subject_lengths = [arr.shape[0] for arr in all_leading_eigenvectors]
num_subjects = len(subject_lengths)

# The labels are sliced back into subjects by these lengths, so they must add up
assert sum(subject_lengths) == len(all_labels), "labels do not match the collated eigenvectors"

# Generate colormap for the states
palette = sns.color_palette("husl", chosen_k)

# Plot for each participant.
# squeeze=False keeps `axes` an array even for a single subject; column 0 is the only column.
fig, axes = plt.subplots(num_subjects, 1, figsize=(10, 2 * num_subjects),
                         sharex=True, sharey=True, squeeze=False)
axes = axes[:, 0]

start_idx = 0
for i, (length_i, ax) in enumerate(zip(subject_lengths, axes)):
    end_idx = start_idx + length_i
    subj_labels = all_labels[start_idx:end_idx]  # Get states for this subject

    # Use a scatter plot where color represents the state
    ax.scatter(range(length_i), subj_labels, c=[palette[state] for state in subj_labels], s=10, alpha=0.8)
    ax.set_ylabel(f"Subj {i+1}")
    ax.set_yticks(range(chosen_k))  # Ensure all labels appear
    start_idx = end_idx

# Final plot formatting
axes[-1].set_xlabel("Time Window")
fig.suptitle(f"State Transitions for Each Participant (K={chosen_k})", fontsize=14)
plt.tight_layout()
plt.show()


In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

# Compute state occurrence probabilities per subject
subject_probabilities = []

start_idx = 0
for length_i in subject_lengths:
    end_idx = start_idx + length_i
    subj_labels = all_labels[start_idx:end_idx]  # Get labels for this subject
    prob_vec = [(subj_labels == c).mean() for c in range(chosen_k)]  # Fraction of windows in state c
    subject_probabilities.append(prob_vec)
    start_idx = end_idx

subject_probabilities = np.array(subject_probabilities)  # shape: [num_subjects, chosen_k]

# Plot grouped bar chart
fig, ax = plt.subplots(figsize=(10, 6))
x = np.arange(num_subjects)  # Subject indices (one unit apart)
# Subjects sit one unit apart, so the K bars of a group share 0.8 of that unit;
# a fixed bar width would make neighbouring groups overlap for larger K.
width = 0.8 / chosen_k

palette = sns.color_palette("husl", chosen_k)

for c in range(chosen_k):
    ax.bar(x + c * width, subject_probabilities[:, c], width, label=f"State {c}", color=palette[c])

ax.set_xlabel("Subjects")
ax.set_ylabel("Occurrence Probability")
ax.set_title(f"State Occurrence Probability for Each Subject (K={chosen_k})")
# Bars are centred at x + c*width for c = 0..K-1, so the group centre is at (K-1)/2 widths
ax.set_xticks(x + (chosen_k - 1) / 2 * width)
ax.set_xticklabels([f"Subj {i+1}" for i in range(num_subjects)])
ax.legend(title="Brain State")
plt.show()


In [None]:
# Stacked bars: one column per subject, one coloured segment per state.
# Relies on `x`, `palette` and `subject_probabilities` from the previous cell.
fig, ax = plt.subplots(figsize=(10, 6))
bottoms = np.zeros(num_subjects)

for c in range(chosen_k):
    ax.bar(x, subject_probabilities[:, c], bottom=bottoms, label=f"State {c}", color=palette[c])
    # Raise the baseline so the next state's segment sits on top of this one
    bottoms = bottoms + subject_probabilities[:, c]

ax.set(
    xlabel="Subjects",
    ylabel="Proportion of Time in State",
    title=f"State Distribution per Subject (K={chosen_k})",
)
ax.set_xticks(x)
ax.set_xticklabels([f"Subj {i+1}" for i in range(num_subjects)])
ax.legend(title="Brain State", bbox_to_anchor=(1,1))
plt.show()


In [None]:
import pandas as pd

# Labelled table (rows = subjects, columns = states) so seaborn takes the
# tick labels straight from the DataFrame
df = pd.DataFrame(
    subject_probabilities,
    index=[f"Subj {i+1}" for i in range(num_subjects)],
    columns=[f"State {i}" for i in range(chosen_k)],
)

plt.figure(figsize=(10, 5))
sns.heatmap(df, annot=True, cmap="Blues", linewidths=0.5)
plt.xlabel("Brain State")
plt.ylabel("Subjects")
plt.title(f"Occurrence Probability Heatmap (K={chosen_k})")
plt.show()


In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

def compute_dwell_time(all_labels, subject_lengths, chosen_k):
    """
    Mean run length of every state, separately for every subject.

    A "run" is a maximal stretch of consecutive windows carrying the same
    label; the dwell time of a state is the number of windows spent in it
    divided by the number of its runs.

    Parameters
    ----------
    all_labels : ndarray
        Cluster assignments for all time windows, subjects concatenated.
    subject_lengths : list
        Number of time windows per subject, in concatenation order.
    chosen_k : int
        Number of clusters (states); states are the integers 0..chosen_k-1.

    Returns
    -------
    dwell_times : ndarray, shape [num_subjects, chosen_k]
        Average dwell time (in windows) per subject and state; 0 where a
        subject never visits the state.
    """

    def _mean_run_length(labels, state):
        # Positions at which this state occurs within the subject's sequence
        hits = np.flatnonzero(labels == state)
        if hits.size == 0:
            return 0
        # A jump larger than 1 between successive positions starts a new run
        n_runs = 1 + np.count_nonzero(np.diff(hits) > 1)
        return hits.size / n_runs

    per_subject = []
    offset = 0
    for n_windows in subject_lengths:
        segment = all_labels[offset:offset + n_windows]
        per_subject.append([_mean_run_length(segment, s) for s in range(chosen_k)])
        offset += n_windows

    return np.array(per_subject)  # Shape: [num_subjects, chosen_k]


# Average dwell time per subject (rows) and state (columns), in time windows
dwell_times = compute_dwell_time(
    all_labels=all_labels,
    subject_lengths=subject_lengths,
    chosen_k=chosen_k,
)


In [None]:
# Distribution of dwell times across subjects: one box per state (column of dwell_times)
plt.figure(figsize=(10, 6))
sns.boxplot(data=dwell_times, palette="husl")
plt.title(f"Dwell Time per State (K={chosen_k})")
plt.xlabel("State")
plt.ylabel("Dwell Time (avg. time windows in state)")
plt.xticks(np.arange(chosen_k), [f"State {i}" for i in range(chosen_k)])
plt.show()


In [None]:
# Dwell time per subject (rows) and state (columns)
plt.figure(figsize=(10, 5))
sns.heatmap(dwell_times, annot=True, cmap="Blues", linewidths=0.5)
plt.xlabel("Brain State")
plt.ylabel("Subjects")
plt.title(f"Dwell Time Heatmap (K={chosen_k})")
# Heatmap cell i spans [i, i+1], so labels belong at the cell centres i + 0.5
# (integer positions would put them on the cell borders).
plt.xticks(ticks=np.arange(chosen_k) + 0.5, labels=[f"State {i}" for i in range(chosen_k)])
plt.yticks(ticks=np.arange(len(subject_lengths)) + 0.5,
           labels=[f"Subj {i+1}" for i in range(len(subject_lengths))],
           rotation=0)
plt.show()


In [None]:
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

def compute_transition_matrices(all_labels, subject_lengths, chosen_k):
    """
    Row-normalised state-to-state transition matrix for every subject.

    Entry [s, i, j] is the fraction of subject s's transitions out of
    state i (window t) that lead to state j (window t + 1). Transitions are
    only counted inside a subject's own segment of `all_labels`.

    Parameters
    ----------
    all_labels : ndarray
        Cluster assignments for all time windows, subjects concatenated.
    subject_lengths : list
        Number of time windows per subject, in concatenation order.
    chosen_k : int
        Number of clusters (states).

    Returns
    -------
    transition_matrices : ndarray, shape [num_subjects, chosen_k, chosen_k]
        Transition probabilities per subject; a row stays all-zero when the
        subject never leaves from that state.
    """
    n_subjects = len(subject_lengths)
    result = np.zeros((n_subjects, chosen_k, chosen_k))

    offset = 0
    for s, n_windows in enumerate(subject_lengths):
        sequence = np.asarray(all_labels[offset:offset + n_windows])
        offset += n_windows

        # Tally every (state at t, state at t+1) pair; add.at accumulates
        # repeated index pairs instead of overwriting them
        tally = np.zeros((chosen_k, chosen_k))
        if sequence.size > 1:
            np.add.at(tally, (sequence[:-1], sequence[1:]), 1)

        # Turn counts into probabilities; rows without any outgoing
        # transition are divided by 1 and therefore remain zero
        outgoing = tally.sum(axis=1, keepdims=True)
        result[s] = tally / np.where(outgoing == 0, 1, outgoing)

    return result

# One row-normalised [chosen_k, chosen_k] transition matrix per subject
transition_matrices = compute_transition_matrices(
    all_labels=all_labels,
    subject_lengths=subject_lengths,
    chosen_k=chosen_k,
)


In [None]:
avg_transition_matrix = np.mean(transition_matrices, axis=0)  # Mean over subjects

plt.figure(figsize=(8, 6))
sns.heatmap(avg_transition_matrix, annot=True, cmap="Blues", linewidths=0.5, fmt=".2f")
plt.xlabel("To State")
plt.ylabel("From State")
plt.title(f"Average State Transition Matrix (K={chosen_k})")
# Heatmap cell i spans [i, i+1]; put the state labels at the cell centres i + 0.5
plt.xticks(ticks=np.arange(chosen_k) + 0.5, labels=[f"{i}" for i in range(chosen_k)])
plt.yticks(ticks=np.arange(chosen_k) + 0.5, labels=[f"{i}" for i in range(chosen_k)], rotation=0)
plt.show()


In [None]:
# Grid of per-subject transition matrices: 4 rows and as many columns as needed.
# Rounding the column count up keeps every subject on the grid when the number
# of subjects is not a multiple of 4; squeeze=False keeps `axes` 2D in all cases.
n_grid_rows = 4
n_grid_cols = max(1, int(np.ceil(len(subject_lengths) / n_grid_rows)))
fig, axes = plt.subplots(n_grid_rows, n_grid_cols, figsize=(16, 24),
                         sharex=True, sharey=True, squeeze=False)
flat_axes = axes.flatten()

# Heatmap cell i spans [i, i+1], so state labels go at the cell centres
state_tick_positions = np.arange(chosen_k) + 0.5

for i, ax in enumerate(flat_axes):
    if i < len(subject_lengths):
        sns.heatmap(transition_matrices[i], ax=ax, annot=True, cmap="Blues", linewidths=0.5, fmt=".2f")
        ax.set_title(f"Subj {i+1}")
        ax.set_xlabel("To State")
        ax.set_ylabel("From State")
        ax.set_xticks(state_tick_positions)
        ax.set_yticks(state_tick_positions)
        ax.set_xticklabels([f"{j}" for j in range(chosen_k)])
        ax.set_yticklabels([f"{j}" for j in range(chosen_k)])
    else:
        ax.axis('off')  # Hide empty subplots
plt.tight_layout()
plt.show()


In [None]:

# # 3) Probability of each cluster for each subject
# # We'll illustrate for k=6 as an example:
# chosen_k = 6
# res_chosen = kmeans_solutions[chosen_k]
# all_labels = res_chosen['IDX']  # shape [total_time_points]

# # figure out how many time windows each subject had:
# subject_lengths = [arr.shape[0] for arr in all_leading_eigenvectors]
# subject_probabilities = []

# start_idx = 0
# for i, length_i in enumerate(subject_lengths):
#     end_idx = start_idx + length_i
#     # cluster labels for subject i
#     subj_labels = all_labels[start_idx:end_idx]
#     # probability of each cluster c in 0..(k-1)
#     # note: after re-labeling in run_leida_kmeans, clusters are 0..k-1
#     prob_vec = [(subj_labels == c).mean() for c in range(chosen_k)]
#     subject_probabilities.append(prob_vec)
#     start_idx = end_idx

# subject_probabilities = np.array(subject_probabilities)  # shape [num_subjects, chosen_k]
# print("Probability that each subject belongs to each cluster:\n", subject_probabilities)

# # (At this point, you would do your group-level stats or store them,
# # just like the MATLAB code does with P(s,k,c).)
