In [None]:
import numpy as np
import matplotlib.pyplot as plt

# Set seed for reproducibility (legacy global RNG; the randn call below draws from it)
np.random.seed(0)

# Parameters
n_channels = 5
fs = 256  # Sampling rate (Hz)
duration = 5  # seconds
n_samples = int(fs * duration)
# Sample instants spaced exactly 1/fs apart. np.linspace(0, duration, n_samples)
# includes the endpoint and would space samples duration/(n_samples - 1) apart,
# which disagrees with the fs passed to the filters in later cells.
times = np.arange(n_samples) / fs

# Row 1 is never drawn (see the loop below); its "..." label suggests a
# larger montage between channel 1 and channels 62-64.
channels_names = ["1", "...", "62", "63", "64"]

# Generate mock EEG data: white noise plus a weak 10 Hz sinusoid shared by all channels
eeg_data = 0.8 * np.random.randn(n_channels, len(times)) + np.sin(2 * np.pi * 10 * times) * 0.1


fig, ax = plt.subplots(figsize=(8, 5))
offset = 5  # vertical spacing between stacked channel traces
for i in range(n_channels):
    if i == 1:
        continue  # placeholder row for the "..." label
    ax.plot(times, eeg_data[i] + i * offset, color="#D72638", lw=1)

ax.set_xlabel("Time (s)")
ax.grid(axis='x', linestyle='--', alpha=0.7)
ax.set_ylabel("ROIs")
ax.set_yticks(np.arange(0, n_channels * offset, offset))
ax.set_yticklabels(channels_names)
ax.set_title("EEG Signal")
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)

plt.tight_layout()
plt.show()



In [None]:
from scipy.signal import butter, filtfilt

# Butterworth band-pass design helper
def butter_bandpass(lowcut, highcut, fs, order=6):
    """Design a Butterworth band-pass filter and return its ``(b, a)`` coefficients.

    ``lowcut`` and ``highcut`` are in Hz; both are divided by the Nyquist
    frequency (``fs / 2``) before being handed to ``scipy.signal.butter``.
    """
    normalized_edges = [edge / (fs / 2) for edge in (lowcut, highcut)]
    return butter(order, normalized_edges, btype='band')

def filter_alpha_band(data, fs, lowcut=8, highcut=12, order=6):
    """
    Zero-phase band-pass filter every channel of ``data`` (default: alpha, 8-12 Hz).

    Each channel is de-meaned, then filtered forward and backward with a
    Butterworth band-pass, so the output has no phase shift.

    Parameters
    ----------
    data : array_like, shape (n_channels, n_timepoints)
        Signals to filter. Filtering runs along the last axis, so a 1-D
        single-channel array is accepted as well.
    fs : float
        Sampling rate in Hz.
    lowcut, highcut : float, optional
        Pass-band edges in Hz (default 8 and 12, the alpha band).
    order : int, optional
        Butterworth order of a single pass (default 6).

    Returns
    -------
    filtered_data : ndarray, same shape as ``data``
        Filtered signals. Floating-point inputs keep their dtype; integer
        inputs are promoted to float64 instead of being truncated.
    """
    from scipy.signal import sosfiltfilt  # SOS counterpart of filtfilt

    data = np.asarray(data)
    out_dtype = data.dtype if np.issubdtype(data.dtype, np.floating) else np.float64
    # Second-order sections avoid the numerical problems of the (b, a) form
    # for high-order, narrow band-pass designs (see scipy.signal.butter docs).
    sos = butter(order, [lowcut, highcut], btype='band', fs=fs, output='sos')
    demeaned = data - data.mean(axis=-1, keepdims=True)
    return sosfiltfilt(sos, demeaned, axis=-1).astype(out_dtype, copy=False)

# Band-limit every mock channel to the alpha range
eeg_data_alpha = filter_alpha_band(eeg_data, fs)

# Same stacked layout as the raw-signal figure; traces are scaled by 3 so the
# band-limited signal stays visible at the shared vertical spacing.
fig, ax = plt.subplots(figsize=(8, 5))
visible_rows = [row for row in range(n_channels) if row != 1]  # row 1 sits under the "..." label
for i in visible_rows:
    ax.plot(times, eeg_data_alpha[i]*3 + i * offset, color="#D72638", lw=1)
ax.set(xlabel="Time (s)", ylabel="ROIs", title="Bandpass Filtered EEG Signal")
ax.grid(axis='x', linestyle='--', alpha=0.7)
ax.set_yticks(np.arange(0, n_channels * offset, offset))
ax.set_yticklabels(channels_names)
for hidden_side in ('top', 'right'):
    ax.spines[hidden_side].set_visible(False)
plt.tight_layout()
plt.show()


In [None]:
from scipy.signal import hilbert

# Instantaneous phase and amplitude envelope both come from the analytic
# signal (Hilbert transform along the time axis, axis=1).
analytic_signal = hilbert(eeg_data_alpha, axis=1)
phases = np.angle(analytic_signal)   # instantaneous phase (radians, wrapped)
amplitude = np.abs(analytic_signal)  # amplitude envelope


def plot_stacked_traces(traces, gain, color, title):
    """Plot the rows of ``traces`` (times ``gain``) stacked ``offset`` apart; row 1 is left blank.

    Relies on the notebook-level ``times``, ``n_channels``, ``offset`` and
    ``channels_names``. Returns the created ``(fig, ax)``.
    """
    fig, ax = plt.subplots(figsize=(8, 5))
    for i in [row for row in range(n_channels) if row != 1]:
        ax.plot(times, traces[i] * gain + i * offset, color=color, lw=1)
    ax.set_xlabel("Time (s)")
    ax.grid(axis='x', linestyle='--', alpha=0.7)
    ax.set_ylabel("ROIs")
    ax.set_yticks(np.arange(0, n_channels * offset, offset))
    ax.set_yticklabels(channels_names)
    ax.set_title(title)
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    plt.tight_layout()
    plt.show()
    return fig, ax


# Phase is scaled by 0.7 and amplitude by 10 so each fits its row spacing
fig, ax = plot_stacked_traces(phases, 0.7, "#FFAA43", "Instantaneous Phase of EEG Signal")
fig, ax = plot_stacked_traces(amplitude, 10, "#FF7886", "Amplitude Envelope of EEG Signal")


In [None]:
import numpy as np
import matplotlib.pyplot as plt

# Two example phase angles (radians) and their difference
theta1 = 2.1
theta2 = 1.2
delta = theta1 - theta2

# Wrap the difference into [-pi, pi)
delta = (delta + np.pi) % (2 * np.pi) - np.pi

# Unit circle sampled at 400 angles
t = np.linspace(0, 2 * np.pi, 400)
circle_x, circle_y = np.cos(t), np.sin(t)

# Unit-circle points for each phase and for the phase difference
v1, v2, vd = ((np.cos(angle), np.sin(angle)) for angle in (theta1, theta2, delta))

fig_phase, ax_phase = plt.subplots(figsize=(6, 6))

# Light-grey reference geometry: unit circle plus both axes
ax_phase.plot(circle_x, circle_y, color='lightgray', linewidth=0.5)
ax_phase.axhline(0, color='lightgray', linewidth=0.5)
ax_phase.axvline(0, color='lightgray', linewidth=0.5)

# Arrows from the origin: the two phases (orange) and their difference (blue)
ax_phase.arrow(0, 0, v1[0], v1[1], head_width=0.05, length_includes_head=True, color='#FFAA43', label=r'$\theta_1$')
ax_phase.arrow(0, 0, v2[0], v2[1], head_width=0.05, length_includes_head=True, color='#FFAA43', label=r'$\theta_2$')
ax_phase.arrow(0, 0, vd[0], vd[1], head_width=0.08, length_includes_head=True, color='#4D7EA8', label=r'$\Delta\theta$')

# Drop the difference vector onto the real axis; the foot sits at cos(delta)
ax_phase.plot([vd[0], vd[0]], [vd[1], 0], linestyle='--', color='gray')
ax_phase.scatter([vd[0]], [0], color='#4D7EA8')
ax_phase.text(vd[0], -0.1, r'$\cos(\theta_{i,s}-\theta_{j,s})$', ha='center', color="#4D7EA8")

# Labels placed just beyond each arrow tip
ax_phase.text(v1[0]*1.1, v1[1]*1.1, r'$\theta_{i,s}$', color='#FFAA43')
ax_phase.text(v2[0]*1.1, v2[1]*1.1, r'$\theta_{j,s}$', color='#FFAA43')
ax_phase.text(vd[0]*1.1, vd[1]*1.1, r'$\Delta\theta$', color='#4D7EA8')

ax_phase.set_title('Phase Difference')
ax_phase.axis('equal')
ax_phase.set_xticks([-1, 1])
ax_phase.set_yticks([-1, 0, 1])
ax_phase.grid(False)
# Minimal frame: hide top/right spines and make left/bottom cross at the origin
ax_phase.spines['top'].set_visible(False)
ax_phase.spines['right'].set_visible(False)
ax_phase.spines['left'].set_position('zero')
ax_phase.spines['bottom'].set_position('zero')

plt.show()


In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.signal import butter, filtfilt, hilbert
from sklearn.cluster import KMeans
import mne

class LEiDAEEGAnalyzer:
    """
    A class to perform the LEiDA (Leading Eigenvector Dynamics Analysis) pipeline on EEG data.

    The pipeline includes:
      1) Per-channel zero-phase Butterworth bandpass filtering.
      2) Hilbert transform to extract instantaneous phases.
      3) Dynamic phase-locking (dPL) matrix computation in fixed, non-overlapping windows.
      4) Leading eigenvector extraction for each window.

    Attributes
    ----------
    fs : float
        Sampling frequency in Hz.
    freq_band : str
        Which frequency band to analyze: 'alpha', 'beta', or 'gamma'
        (cut-offs are listed in ``FREQ_BANDS``).
    lowcut, highcut : float
        Pass-band edges in Hz for the current ``freq_band``.
    window_size : int
        Number of samples in each non-overlapping window (default 256).
    remove_edges : bool
        If True, skip the first and last window (as done in some MATLAB code)
        whenever an epoch has at least three windows. Otherwise, keep them.
    do_plots : bool
        If True, show optional diagnostic plots.
    verbose : bool
        If True, print progress messages.

    Methods
    -------
    set_frequency_band(freq_band)
        Update lowcut/highcut based on 'alpha', 'beta', or 'gamma'.
    filter_data(data)
        Filter each channel of input data using a zero-phase Butterworth filter.
    compute_hilbert_phases(filtered_data)
        Compute instantaneous phases via Hilbert transform (per channel).
    compute_leading_eigenvectors(data_3d)
        Run the LEiDA pipeline over an array of shape (n_epochs, n_channels, n_timepoints).
    _compute_windows(phases, epoch_idx)
        Private helper to subdivide phases into windows, build dPL, and compute leading eigenvectors.
    plot_filter_example(original, filtered, fs, epoch_idx=0, channel_idx=0)
        (Optional) Plot an example channel before and after filtering.
    plot_phase_example(phases, fs, epoch_idx=0, channel_idx=0)
        (Optional) Plot an example channel's phase.
    plot_example_dpl(iFC, V1)
        (Optional) Visualize a dynamic phase-locking matrix and the associated leading eigenvector.
    """

    # Supported frequency bands: name -> (lowcut, highcut) in Hz.
    FREQ_BANDS = {
        'alpha': (8, 12),
        'beta': (15, 25),
        'gamma': (30, 80),
    }

    # Butterworth order of a single filter pass (forward-backward filtering applies it twice).
    FILTER_ORDER = 6

    def __init__(self,
                 fs: float,
                 freq_band: str = 'alpha',
                 window_size: int = 256,
                 remove_edges: bool = True,
                 do_plots: bool = False,
                 verbose: bool = True):
        """
        Parameters
        ----------
        fs : float
            Sampling frequency (Hz).
        freq_band : str, optional
            Desired frequency band ('alpha', 'beta', or 'gamma'). Default is 'alpha'.
        window_size : int, optional
            Window size in samples for dPL calculations. Default is 256.
        remove_edges : bool, optional
            If True, skip first and last window. Default is True.
        do_plots : bool, optional
            Whether to generate diagnostic plots. Default is False.
        verbose : bool, optional
            Whether to print status messages. Default is True.

        Raises
        ------
        ValueError
            If ``window_size`` is smaller than one sample or ``freq_band`` is unknown.
        """
        if window_size < 1:
            raise ValueError("window_size must be at least 1 sample.")

        self.fs = fs
        self.freq_band = freq_band
        self.window_size = window_size
        self.remove_edges = remove_edges
        self.do_plots = do_plots
        self.verbose = verbose

        # Track whether we've already shown each plot type
        self.did_plot_filter = False
        self.did_plot_phase = False
        self.did_plot_dpl = False

        # Alpha defaults, immediately overwritten by the requested band
        self.lowcut, self.highcut = self.FREQ_BANDS['alpha']
        self.set_frequency_band(freq_band)

    def set_frequency_band(self, freq_band: str):
        """
        Set the lowcut/highcut frequency range based on the chosen freq_band.

        Parameters
        ----------
        freq_band : str
            Either 'alpha', 'beta', or 'gamma'.

        Raises
        ------
        ValueError
            If ``freq_band`` is not one of the keys of ``FREQ_BANDS``; the
            current band is left unchanged in that case.
        """
        try:
            self.lowcut, self.highcut = self.FREQ_BANDS[freq_band]
        except (KeyError, TypeError):
            raise ValueError("freq_band must be 'alpha', 'beta', or 'gamma'.") from None
        self.freq_band = freq_band

        if self.verbose:
            print(f"Frequency band set to {freq_band} "
                  f"({self.lowcut}-{self.highcut} Hz).")

    def _butter_bandpass(self, lowcut: float, highcut: float, fs: float, order=6, output='ba'):
        """
        Construct a Butterworth bandpass filter.

        Cut-offs are in Hz and normalised by the Nyquist frequency. ``output``
        is forwarded to ``scipy.signal.butter``: 'ba' (default) returns the
        ``(b, a)`` tuple, 'sos' returns second-order sections.
        """
        nyquist = 0.5 * fs
        low = lowcut / nyquist
        high = highcut / nyquist
        return butter(order, [low, high], btype='band', output=output)

    def filter_data(self, data: np.ndarray) -> np.ndarray:
        """
        Apply zero-phase bandpass filtering to each channel of the input.

        Each channel is de-meaned, then filtered forward and backward with a
        Butterworth bandpass of order ``FILTER_ORDER`` in second-order-section
        form. The (b, a) form is numerically unreliable for high-order, narrow
        band-pass designs.

        Parameters
        ----------
        data : ndarray, shape (n_channels, n_timepoints)
            Single epoch of EEG data or a single trial. Filtering runs along
            the last axis.

        Returns
        -------
        filtered_data : ndarray, shape (n_channels, n_timepoints)
            The filtered data per channel. Floating-point inputs keep their
            dtype; integer inputs are promoted to float64 rather than truncated.
        """
        from scipy.signal import sosfiltfilt  # SOS counterpart of filtfilt

        data = np.asarray(data)
        out_dtype = data.dtype if np.issubdtype(data.dtype, np.floating) else np.float64
        sos = self._butter_bandpass(self.lowcut, self.highcut, self.fs,
                                    order=self.FILTER_ORDER, output='sos')
        # De-mean each channel first
        demeaned = data - data.mean(axis=-1, keepdims=True)
        return sosfiltfilt(sos, demeaned, axis=-1).astype(out_dtype, copy=False)

    def compute_hilbert_phases(self, filtered_data: np.ndarray) -> np.ndarray:
        """
        Compute instantaneous phases using the Hilbert transform.

        Parameters
        ----------
        filtered_data : ndarray, shape (n_channels, n_timepoints)
            The bandpass filtered EEG data.

        Returns
        -------
        phases : ndarray, shape (n_channels, n_timepoints)
            Phases (radians, from np.angle) of the analytic signal.
        """
        # Axis=1 means we're applying Hilbert transform along the time dimension
        analytic_signal = hilbert(filtered_data, axis=1)
        phases = np.angle(analytic_signal)
        return phases

    def _compute_windows(self, phases: np.ndarray, epoch_idx: int) -> np.ndarray:
        """
        Subdivide `phases` into non-overlapping windows of length `self.window_size`,
        build the dynamic phase-locking (dPL) matrix for each window,
        and return its leading eigenvector (unit-normalised, no sign flip).

        The dPL entry (i, j) is the mean over the window of cos(phase_i - phase_j).
        Trailing samples that do not fill a complete window are ignored.

        Returns
        -------
        ndarray, shape (n_windows, n_channels)
            One leading eigenvector per kept window; shape (0, n_channels)
            when the epoch is shorter than one window.
        """
        n_chan, T = phases.shape
        W = self.window_size

        # indices of LEFT edges of every complete window -----------------
        starts = np.arange(0, T - W + 1, W)          # e.g. [0,256,512,...]
        if len(starts) == 0:                         # epoch shorter than one window
            return np.empty((0, n_chan))

        if self.remove_edges and len(starts) >= 3:   # need ≥3 windows to drop edges
            starts = starts[1:-1]                    # drop first & last

        lead_eigs = []
        for s in starts:
            seg = phases[:, s:s+W]                   # (n_chan, W)

            # ---------- vectorised iFC  -----------------
            # pairwise phase-diff tensor: (n_chan, n_chan, W)
            dphase = seg[:, None, :] - seg[None, :, :]
            iFC = np.cos(dphase).mean(axis=-1)       # (n_chan, n_chan), symmetric

            # ---------- leading eigenvector ----------------------------
            vals, vecs = np.linalg.eigh(iFC)         # ascending λ
            v1 = vecs[:, -1]                         # largest eigenvalue
            v1 /= np.linalg.norm(v1)                 # unit length
            lead_eigs.append(v1)

            # Optional diagnostic plot of the 11th window (index 10) for epochs 0-4.
            # NOTE(review): did_plot_dpl is never set to True (the assignment
            # below the loop is commented out), so this fires once per epoch
            # 0-4 rather than once overall; confirm which is intended.
            if (
                self.do_plots
                and not self.did_plot_dpl
                and epoch_idx in [0, 1, 2, 3, 4]
                and len(lead_eigs) == 11
            ):
                self.plot_example_dpl(iFC, v1)
        # self.did_plot_dpl = True

        return np.vstack(lead_eigs)                  # (n_windows, n_chan)

    def compute_leading_eigenvectors(self, data_3d: np.ndarray) -> np.ndarray:
        """
        Run the LEiDA pipeline over an array of shape (n_epochs, n_channels, n_timepoints).

        Parameters
        ----------
        data_3d : ndarray, shape (n_epochs, n_channels, n_timepoints)
            The EEG data to analyze. Each epoch is [n_channels, n_timepoints].

        Returns
        -------
        all_lead_eigs : ndarray, shape (n_epochs, n_windows, n_channels)
            Leading eigenvectors for each epoch and each window.
            If remove_edges=True and an epoch has at least three windows,
            n_windows = (#windows_of_epoch - 2).

        Raises
        ------
        ValueError
            If ``data_3d`` is not three-dimensional.
        """
        data_3d = np.asarray(data_3d)
        if data_3d.ndim != 3:
            raise ValueError("data_3d must have shape (n_epochs, n_channels, n_timepoints); "
                             f"got {data_3d.ndim} dimension(s).")
        n_epochs, n_channels, n_timepoints = data_3d.shape

        if self.verbose:
            print(f"Processing data with shape (epochs={n_epochs}, channels={n_channels}, "
                  f"timepoints={n_timepoints})")
            print(f"Window size = {self.window_size}, removing edges = {self.remove_edges}")
            print(f"Bandpass from {self.lowcut} to {self.highcut} Hz. Order={self.FILTER_ORDER}")

        epoch_eig_list = []

        for ep_idx in range(n_epochs):
            epoch_data = data_3d[ep_idx, :, :]  # shape (n_channels, n_timepoints)

            # 1) Filter
            filtered_epoch = self.filter_data(epoch_data)

            # Plot filter example once (epoch=0, channel=0). The raw trace is
            # de-meaned here so it matches the "Raw (demeaned)" label and the
            # signal that was actually filtered.
            if self.do_plots and not self.did_plot_filter and ep_idx == 0:
                raw_ch0 = epoch_data[0, :]
                self.plot_filter_example(raw_ch0 - raw_ch0.mean(), filtered_epoch[0, :],
                                         self.fs, epoch_idx=0, channel_idx=0)
                self.did_plot_filter = True

            # 2) Hilbert phases
            phases = self.compute_hilbert_phases(filtered_epoch)

            # Plot phase example once (epoch=0, channel=0)
            if self.do_plots and not self.did_plot_phase and ep_idx == 0:
                self.plot_phase_example(phases, self.fs, epoch_idx=0, channel_idx=0)
                self.did_plot_phase = True

            # 3) Windowing & dynamic phase-locking + leading eigenvectors
            lead_eigs = self._compute_windows(phases, ep_idx)
            epoch_eig_list.append(lead_eigs)
            # print every 10 epochs
            if self.verbose and ep_idx % 10 == 0:
                print(f"Epoch {ep_idx}/{n_epochs} processed.")

        # Convert list of arrays to a single 3D array (all epochs share n_timepoints,
        # hence the same number of windows)
        all_lead_eigs = np.stack(epoch_eig_list, axis=0)

        if self.verbose:
            print(f"\nCompleted LEiDA!!! Output shape = {all_lead_eigs.shape} "
                  "(epochs x windows x channels).")

        return all_lead_eigs

    # ----------------------- PLOTTING ROUTINES -----------------------
    def plot_filter_example(self, original, filtered, fs, epoch_idx=0, channel_idx=0):
        """
        Plot an example channel before and after filtering.

        Parameters
        ----------
        original : ndarray, shape (n_timepoints,)
            Original (demeaned) time series data for one channel.
        filtered : ndarray, shape (n_timepoints,)
            Filtered time series.
        fs : float
            Sampling frequency.
        epoch_idx : int, optional
            Epoch index (for labeling).
        channel_idx : int, optional
            Channel index (for labeling).
        """
        t = np.arange(len(original)) / fs
        plt.figure(figsize=(10, 4))
        plt.plot(t, original, label='Raw (demeaned)', alpha=0.7)
        plt.plot(t, filtered, label='Filtered', alpha=0.7)
        plt.xlim([0, min(10.0, t[-1])])  # zoom in up to 10s
        plt.legend()
        plt.title(f"Epoch {epoch_idx}, Channel {channel_idx}: Before/After Filtering")
        plt.xlabel("Time (s)")
        plt.show()

    def plot_phase_example(self, phases, fs, epoch_idx=0, channel_idx=0):
        """
        Plot an example channel's instantaneous phase.

        Parameters
        ----------
        phases : ndarray, shape (n_channels, n_timepoints)
            Phase array from Hilbert transform.
        fs : float
            Sampling frequency.
        epoch_idx : int, optional
            Epoch index (for labeling).
        channel_idx : int, optional
            Channel index (for labeling).
        """
        t = np.arange(phases.shape[1]) / fs
        plt.figure(figsize=(10, 4))
        plt.plot(t, phases[channel_idx, :], label=f'Phase (Ch={channel_idx})')
        plt.title(f"Epoch {epoch_idx}, Channel {channel_idx}: Instantaneous Phase")
        plt.xlabel("Time (s)")
        plt.ylabel("Phase (radians)")
        plt.xlim([0, min(10.0, t[-1])])
        plt.legend()
        plt.show()

    def plot_example_dpl(self, iFC, V1):
        """
        Plot a dynamic phase-locking (dPL) matrix and its leading eigenvector.

        The eigenvector is drawn twice, as ``V1`` and ``-V1``, because an
        eigenvector is only defined up to sign.

        Parameters
        ----------
        iFC : ndarray, shape (n_channels, n_channels)
            The phase-locking matrix for a given window.
        V1 : ndarray, shape (n_channels,)
            The leading eigenvector of iFC.
        """
        plt.figure(figsize=(6, 5))
        plt.imshow(iFC, cmap='bwr', aspect='auto', vmin=-1, vmax=1)
        plt.colorbar(label='Phase Coherence')
        plt.title("dPL Matrix")
        plt.xlabel("ROI")
        plt.ylabel("ROI")
        plt.show()

        self._plot_eigenvector_bars(V1)
        self._plot_eigenvector_bars(-V1)

    def _plot_eigenvector_bars(self, vec):
        """Horizontal bar plot of one eigenvector: non-negative entries red, negative blue."""
        fig, ax = plt.subplots(figsize=(3, 8))
        colors = ['red' if x >= 0 else 'blue' for x in vec]
        ax.barh(np.arange(len(vec)), vec, color=colors, height=0.8)
        ax.axvline(0, color='k', lw=.8)
        ax.set_title("Leading Eigenvector")
        ax.set_xlabel("Value")
        ax.set_ylabel("ROI")
        ax.set_xticks([-0.15, 0, 0.15])
        ax.invert_yaxis()  # ROI 0 at the top
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)
        plt.show()

    # Reference snippet (not used by this class): bar profiles of cluster centers.
    # def plot_centers_bar(out_png, centers, roi_names=None):
    #     if centers.ndim == 1:
    #         centers = centers.reshape(1, -1)
    #     k, n = centers.shape
    #     fig, axes = plt.subplots(1, k, figsize=(3 * k, max(4, 0.25 * n)), sharey=True)
    #     if k == 1:
    #         axes = [axes]
    #     for i, ax in enumerate(axes):
    #         v = centers[i]
    #         colors = ['red' if x >= 0 else 'blue' for x in v]
    #         ax.barh(range(n), v, color=colors, height=0.8)
    #         ax.axvline(0, color='k', lw=.8)
    #         ax.set_title(f"Center {i}", fontsize=10)
    #         ax.set_xlabel("Value", fontsize=8)
    #         ax.tick_params(axis='x', labelsize=7)
    #         ax.invert_yaxis()
    #     if roi_names is not None and k > 0:
    #         axes[0].set_yticks(range(n))
    #         axes[0].set_yticklabels(roi_names, fontsize=8)
    #     else:
    #         axes[0].set_yticks([])
    #     fig.tight_layout(rect=[0, 0.03, 1, 0.95])
    #     fig.suptitle("Cluster Center Profiles", fontsize=14)
    #     fig.savefig(out_png, dpi=150, bbox_inches='tight')
    #     plt.close(fig)

# ------------------ Main ------------------#
if __name__ == "__main__":
    # Load one epoched EEGLAB recording and keep only samples from t = 0 s onward.
    # Alternative input (source space): mne.read_epochs("../data/archive/source/s_101_Coordination-source-epo.fif")
    eeglab_set_path = "../data/raw_eeg/raw_all/PPT1/s_101_Coordination.set"
    epochs = mne.read_epochs_eeglab(eeglab_set_path)
    epochs.crop(tmin=0.0)

    data = epochs.get_data()
    print(f"Data shape: {data.shape}")  # previously observed: (87 epochs, 68 channels, 1280 samples)
    fs = epochs.info['sfreq']
    print(f"Sampling frequency: {fs} Hz")
    # fs / 8 samples per window, i.e. 125 ms windows
    window_size = int(fs / 8)

    leida = LEiDAEEGAnalyzer(
        fs=fs,
        freq_band='alpha',
        window_size=window_size,
        remove_edges=False,
        do_plots=True,
        verbose=True,
    )

    # Result shape: (epochs, windows, channels)
    all_eigenvectors = leida.compute_leading_eigenvectors(data)
    print("Final shape of leading eigenvectors:", all_eigenvectors.shape)
