In [1]:
import os
import re
import glob
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import mne


def split_dataframe_on_blink(df):
    """Split a DataFrame into contiguous segments separated by 'blink' rows.

    Rows whose ``Elements`` value equals ``"blink"`` act as separators and are
    NOT included in any segment; the rows between consecutive separators form
    one segment.  Empty segments (back-to-back blinks, or a blink on the first
    or last row) are skipped so every returned segment has at least one row.

    Parameters
    ----------
    df : pandas.DataFrame
        Input frame.  If it has no ``Elements`` column, the whole frame is
        returned as a single segment.

    Returns
    -------
    list of pandas.DataFrame
        Non-empty segments, each re-indexed with a fresh ``RangeIndex``.

    Raises
    ------
    AssertionError
        If no non-empty segment remains (empty or blink-only input).
    """
    if "Elements" not in df.columns:
        segments = [df.reset_index(drop=True)] if len(df) else []
    else:
        # Positional indices of blink rows.  iloc needs positions, not index
        # labels, so this stays correct for filtered / non-RangeIndex frames.
        blink_positions = np.flatnonzero((df["Elements"] == "blink").to_numpy())
        # Segment k runs from just after blink k-1 up to (excluding) blink k.
        starts = np.concatenate(([0], blink_positions + 1))
        stops = np.concatenate((blink_positions, [len(df)]))
        segments = [
            df.iloc[start:stop].reset_index(drop=True)
            for start, stop in zip(starts, stops)
            if stop > start  # skip empty runs between adjacent blinks
        ]

    assert len(segments) > 0, "no non-blink samples to split"
    return segments


def compute_real_sfreq(df, default_sfreq=256.0):
    """Estimate the sampling frequency (Hz) from the ``TimeStamp`` column.

    Timestamps are strings of the form ``HH:MM:SS.mmm``.  The estimate is
    ``1 / mean(interval between consecutive rows)``.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain a ``TimeStamp`` column (unless it has fewer than 2 rows).
    default_sfreq : float, optional
        Returned when no positive interval can be measured (fewer than two
        rows, or a non-positive mean interval).  Default 256.0.

    Returns
    -------
    float
        Estimated sampling frequency, or ``default_sfreq``.
    """

    def time_to_seconds(time_str):
        # "HH:MM:SS.mmm" -> seconds since midnight
        hours, minutes, seconds = time_str.split(":")
        return int(hours) * 3600 + int(minutes) * 60 + float(seconds)

    # At least two samples are needed to measure an interval; np.mean of an
    # empty diff would emit a RuntimeWarning and yield NaN.
    if len(df) < 2:
        return float(default_sfreq)

    timestamps_seconds = df["TimeStamp"].apply(time_to_seconds).to_numpy(dtype=float)
    time_diffs = np.diff(timestamps_seconds)
    # Clock-time stamps wrap at midnight.  A backwards jump of more than half a
    # day cannot be sample jitter, so treat it as a day rollover.
    time_diffs = np.where(time_diffs < -43200.0, time_diffs + 86400.0, time_diffs)

    avg_time_diff = np.mean(time_diffs)
    return 1.0 / avg_time_diff if avg_time_diff > 0 else float(default_sfreq)


def load_csv_to_raw(csv_file, sfreq, do_compute_sfreq=True):
    """Load one EEG CSV file and convert it to an MNE Raw object.

    Assumed column layout: ``TimeStamp`` first, EEG channels in the middle,
    and two trailing non-EEG columns (``Battery``, ``Elements``).  Rows whose
    ``Elements`` value is ``"blink"`` are dropped before building the Raw.

    Parameters
    ----------
    csv_file : str
        Path to the CSV file.
    sfreq : float
        Nominal sampling frequency used for the Raw's ``info``.
    do_compute_sfreq : bool
        If True, estimate the real sampling frequency from ``TimeStamp`` for
        each blink-free segment; otherwise report the nominal ``sfreq``.

    Returns
    -------
    raw : mne.io.RawArray
        Built with the nominal ``sfreq`` (not the estimated one).
    computed_sfreqs : list of float
        One estimate per blink-free segment, or ``[sfreq]``.
    """
    df = pd.read_csv(csv_file)
    has_elements = "Elements" in df.columns

    computed_sfreqs = []
    if do_compute_sfreq:
        # Estimate per blink-free segment so removed blink rows do not show up
        # as long gaps in the interval average.  Without an Elements column
        # there is nothing to split on, so use the whole file.
        segments = split_dataframe_on_blink(df) if has_elements else [df]
        computed_sfreqs.extend(compute_real_sfreq(segment) for segment in segments)
    else:
        computed_sfreqs.append(sfreq)

    if has_elements:
        df = df[df["Elements"] != "blink"]

    # Column 0 is TimeStamp; the last two columns are non-EEG (Battery, Elements).
    # NOTE(review): positional slicing assumes this exact layout in every file.
    eeg = df.iloc[:, 1:-2]
    data = eeg.to_numpy().T  # (n_channels, n_samples), as RawArray expects
    ch_names = eeg.columns.tolist()

    # All channels are typed EEG.
    # NOTE(review): MNE treats EEG data as volts; confirm the CSV's unit.
    info = mne.create_info(ch_names=ch_names, sfreq=sfreq, ch_types=["eeg"] * len(ch_names))
    raw = mne.io.RawArray(data, info, verbose=None)
    return raw, computed_sfreqs


def load_all_data(base_dir, sfreq, do_compute_sfreq,
                  categories=("Read", "See", "Translate", "Update"),
                  max_files_per_category=11):
    """Load EEG CSV files for each category under ``base_dir``.

    Files are read from ``<base_dir>/<category>/*.csv``.  Each file name must
    contain an id of the form ``P##-##-S###`` (person, order, sequence).

    Parameters
    ----------
    base_dir : str
        Root directory containing one sub-directory per category.
    sfreq : float
        Nominal sampling frequency passed to ``load_csv_to_raw``.
    do_compute_sfreq : bool
        Whether to estimate the real sampling frequency from timestamps.
    categories : iterable of str, optional
        Category sub-directories to load.
    max_files_per_category : int or None, optional
        Cap on files loaded per category, taken in sorted file-name order.
        ``None`` loads every file.  Default 11 keeps the previous
        quick-iteration limit.

    Returns
    -------
    data : dict
        ``data[category][person][seq] -> Raw``.
    freq : dict
        Nominal ``sfreq`` plus mean/std/min/max of the estimated sampling
        frequencies (NaN when no estimate was produced).

    Raises
    ------
    ValueError
        If a CSV file name does not contain a ``P##-##-S###`` id.
    """
    name_pattern = re.compile(r"(P\d{2})-(\d{2})-(S\d{3})")
    computed_sfreqs = []

    data = {}
    for cat in categories:
        cat_data = {}
        # Sorted: glob order is filesystem-dependent, so without sorting the
        # capped subset (and every downstream result) is not reproducible.
        csv_files = sorted(glob.glob(os.path.join(base_dir, cat, "*.csv")))
        print(f'Category "{cat}": found {len(csv_files)} files.')
        if max_files_per_category is not None:
            csv_files = csv_files[:max_files_per_category]

        for f in csv_files:
            # Match on the file name only so directory names cannot match.
            match = name_pattern.search(os.path.basename(f))
            if match is None:
                raise ValueError(f"Unexpected EEG file name (no P##-##-S### id): {f}")
            person, _order, seq = match.groups()

            raw, file_sfreqs = load_csv_to_raw(f, sfreq, do_compute_sfreq)
            computed_sfreqs.extend(file_sfreqs)
            # NOTE(review): keyed by seq only; two files for the same person and
            # seq but a different order would overwrite each other.
            cat_data.setdefault(person, {})[seq] = raw

        data[cat] = cat_data

    computed_sfreqs = np.asarray(computed_sfreqs, dtype=float)
    if computed_sfreqs.size:
        stats = (
            float(np.average(computed_sfreqs)),
            float(np.std(computed_sfreqs)),
            float(np.min(computed_sfreqs)),
            float(np.max(computed_sfreqs)),
        )
    else:
        # No files loaded: reductions on an empty array would raise.
        stats = (float("nan"),) * 4

    freq = {
        "expected": sfreq,
        "computed_avg": stats[0],
        "computed_std": stats[1],
        "computed_min": stats[2],
        "computed_max": stats[3],
    }

    return data, freq


def create_epochs_from_global_raw(global_raw, events, event_id, tmin=0, tmax=None):
    """Create an Epochs object (one epoch per event) from the concatenated Raw.

    Parameters
    ----------
    global_raw : mne.io.Raw
        Concatenated recording.
    events : ndarray, shape (n_events, 3)
        MNE events array; column 0 is the event sample.
    event_id : dict
        Category name -> event code.
    tmin : float
        Epoch start relative to each event, in seconds.
    tmax : float or None
        Epoch end (inclusive) relative to each event, in seconds.  If None, it
        is chosen so that every epoch fits inside the shortest trial, where a
        trial runs from one event to the next (the last one to the end of the
        recording).

    Returns
    -------
    mne.Epochs
        Preloaded, with no baseline correction.

    Raises
    ------
    ValueError
        If ``tmax`` is None and ``events`` is empty.
    """
    if tmax is None:
        if len(events) == 0:
            raise ValueError("cannot infer tmax: no events given")
        # Sort defensively: negative diffs from unsorted events would corrupt the minimum.
        event_samples = np.sort(np.asarray(events)[:, 0])
        # MNE event samples are absolute (they include first_samp), so the
        # end-of-recording boundary must be expressed the same way.
        end_sample = global_raw.first_samp + global_raw.n_times
        trial_lengths = np.diff(np.append(event_samples, end_sample))
        # tmax is inclusive in MNE, hence the -1 sample.
        tmax = (np.min(trial_lengths) - 1) / global_raw.info["sfreq"]

    epochs = mne.Epochs(global_raw, events, event_id, tmin=tmin, tmax=tmax, baseline=None, preload=True)
    return epochs


def create_global_raw(data, sfreq):
    """Concatenate every trial into one Raw and mark each trial start with an event.

    Parameters
    ----------
    data : dict
        ``data[category][person][seq] -> Raw``.
    sfreq : float
        Unused; kept for interface compatibility.

    Returns
    -------
    global_raw : mne.io.Raw
        All trials concatenated in category / sorted person / sorted seq order.
    events : ndarray of int, shape (n_trials, 3)
        ``[start_sample, 0, category_code]`` for each trial.
    event_id : dict
        Category -> event code (1, 2, ...), only for categories that contribute
        at least one trial.

    Raises
    ------
    ValueError
        If ``data`` contains no trials.
    """
    all_raws = []
    events = []
    event_id = {}
    current_sample = 0  # first sample of the next trial in the concatenated Raw

    for category, persons in data.items():
        # Sorted persons and sequences give a deterministic trial order.
        for person in sorted(persons):
            trials = persons[person]
            for seq in sorted(trials):
                raw = trials[seq]
                # Register a category only once it has a trial: mne.Epochs
                # raises for event_id entries with no matching events.
                if category not in event_id:
                    event_id[category] = len(event_id) + 1
                events.append([current_sample, 0, event_id[category]])
                all_raws.append(raw)
                current_sample += raw.n_times

    if not all_raws:
        raise ValueError("no trials found in data")

    # mne.concatenate_raws modifies its FIRST Raw in place.  Copy it so the
    # trial stored in `data` is not turned into the whole recording (which
    # would corrupt data and event offsets when this cell is re-run).
    global_raw = mne.concatenate_raws([all_raws[0].copy()] + all_raws[1:])
    events = np.array(events, dtype=int)
    return global_raw, events, event_id

In [2]:
def compute_average_signal(raw_list):
    """Average several trials sample-by-sample (an ERP-style trial mean).

    Trials may differ in length; every trial is truncated to the shortest one
    so that they can be stacked.

    Parameters
    ----------
    raw_list : sequence of mne.io.Raw
        Trials whose ``get_data()`` returns ``(n_channels, n_samples)``.  All
        trials must have the same number of channels.

    Returns
    -------
    ndarray, shape (n_channels, n_samples_of_shortest_trial)

    Raises
    ------
    ValueError
        If ``raw_list`` is empty.
    """
    data_arrays = [raw.get_data() for raw in raw_list]
    if not data_arrays:
        raise ValueError("need at least one trial to average")
    # np.stack requires identical shapes, so crop to the common length.
    n_samples = min(arr.shape[-1] for arr in data_arrays)
    data_stack = np.stack([arr[..., :n_samples] for arr in data_arrays], axis=0)  # (n_trials, n_channels, n_samples)
    return np.mean(data_stack, axis=0)


def plot_raw_data(raw):
    """Open MNE's raw-data browser with every channel visible at once."""
    n_visible = len(raw.ch_names)
    raw.plot(show=True, n_channels=n_visible)


def plot_average_signal(avg_signal, sfreq, ch_names, category_name):
    """Draw one line per channel of a trial-averaged signal against time.

    ``avg_signal`` is indexed ``[channel, sample]``; the time axis is the
    sample index divided by ``sfreq``.  ``ch_names[i]`` labels row ``i``.
    """
    times = np.arange(avg_signal.shape[1]) / sfreq  # seconds

    fig, ax = plt.subplots(figsize=(12, 6))
    for row, name in enumerate(ch_names):
        ax.plot(times, avg_signal[row, :], label=name)
    ax.set_title(f"Average EEG Signal for category: {category_name}")
    ax.set_xlabel("Time (s)")
    ax.set_ylabel("Amplitude (µV)")
    ax.legend()
    plt.show()

In [None]:
# Nominal sampling frequency (Hz) used to build every Raw, and whether to also
# estimate the real rate from each file's timestamps.
SFREQ = 256
COMPUTE_SFREQ = True
DATA_DIR = "../ufal_emmt/preprocessed-data/eeg/"

data, freq = load_all_data(DATA_DIR, SFREQ, COMPUTE_SFREQ)

In [None]:
# Nominal vs. estimated sampling-frequency summary returned by load_all_data.
freq

In [None]:
# Overview of what was loaded: number of trials per person in each category
# (more readable than dumping the repr of every Raw object).
{cat: {person: len(trials) for person, trials in persons.items()} for cat, persons in data.items()}

In [None]:
# Create a global Raw dataset by concatenating individual trials;
# each trial start is marked with an event coding its category.
global_raw, events, event_id = create_global_raw(data, SFREQ)
# Read the shape from metadata: get_data() would copy the whole recording.
print("Global Raw dataset created with shape:", (global_raw.info["nchan"], global_raw.n_times))
print("Event markers (first few):", events[:5])

# One epoch per trial; with the default tmax the epoch length is set by the shortest trial.
epochs = create_epochs_from_global_raw(global_raw, events, event_id)
print("Epochs object created with", len(epochs), "epochs.")

epochs

In [None]:
# ----- Analysis Section -----
def compute_global_erp(epochs):
    """Return the grand-average Evoked over all epochs, regardless of category."""
    return epochs.average()


def compute_category_erp(epochs, event_id):
    """Average the epochs of each category separately.

    Returns a dict mapping every key of ``event_id`` (in its iteration order)
    to the Evoked obtained by averaging ``epochs[key]``.
    """
    return {name: epochs[name].average() for name in event_id}


def compute_power_spectral_density(raw, fmin=1.0, fmax=50.0):
    """Compute the Welch power spectral density (PSD) of a Raw recording.

    Uses ``Raw.compute_psd`` (MNE >= 1.2).  ``mne.time_frequency.psd_welch``
    was removed in MNE 1.4 and is only used as a fallback for older MNE.

    Parameters
    ----------
    raw : mne.io.Raw
        Recording to analyse.
    fmin, fmax : float
        Frequency range (Hz) to keep.

    Returns
    -------
    psds : ndarray, shape (n_channels, n_freqs)
    freqs : ndarray, shape (n_freqs,)
    """
    if hasattr(raw, "compute_psd"):
        spectrum = raw.compute_psd(method="welch", fmin=fmin, fmax=fmax)
        psds, freqs = spectrum.get_data(return_freqs=True)
    else:  # MNE < 1.2
        psds, freqs = mne.time_frequency.psd_welch(raw, fmin=fmin, fmax=fmax)
    return psds, freqs


def compute_time_frequency(epochs, freqs=None, n_cycles=2, use_fft=True, return_itc=False, decim=3):
    """Compute time-frequency representations (TFR) of the epochs with Morlet wavelets.

    Parameters
    ----------
    epochs : mne.Epochs
        Epochs to analyse.
    freqs : array-like or None
        Frequencies of interest (Hz).  None means ``np.arange(6, 30, 2)``
        (6-28 Hz in 2 Hz steps), built fresh on each call instead of a shared
        array default.
    n_cycles : float or array-like
        Number of wavelet cycles per frequency.
    use_fft : bool
        Whether to convolve via FFT.
    return_itc : bool
        Whether to also return inter-trial coherence.
    decim : int
        Decimation factor for the output time axis (default 3, the previously
        hard-coded value).

    Returns
    -------
    AverageTFR, or (power, itc) when ``return_itc`` is True.
    """
    if freqs is None:
        freqs = np.arange(6, 30, 2)
    # NOTE(review): tfr_morlet is a legacy API in recent MNE
    # (Epochs.compute_tfr is the newer form); kept here for compatibility.
    power = mne.time_frequency.tfr_morlet(
        epochs,
        freqs=freqs,
        n_cycles=n_cycles,
        use_fft=use_fft,
        return_itc=return_itc,
        decim=decim,
    )
    return power


def compute_inter_trial_variability(epochs):
    """Standard deviation across epochs at every channel and time point.

    ``epochs.get_data()`` is ``(n_epochs, n_channels, n_times)``; reducing over
    the epoch axis yields an array of shape ``(n_channels, n_times)``.
    """
    return np.std(epochs.get_data(), axis=0)


# 1. Compute and plot the global ERP (average over all epochs)
global_erp = compute_global_erp(epochs)
print("Global ERP")
global_erp.plot()

# 2. Compute category-specific ERPs and plot one example ('Read')
category_erps = compute_category_erp(epochs, event_id)
if "Read" in category_erps:
    print("ERP for Category: Read")
    category_erps["Read"].plot()

# 3. Welch PSD of the whole concatenated recording, averaged over channels
psds, freqs = compute_power_spectral_density(global_raw, fmin=1.0, fmax=50.0)
fig, ax = plt.subplots(figsize=(10, 5))
# Log y-axis: EEG power spans several orders of magnitude across frequencies.
ax.semilogy(freqs, psds.mean(axis=0))
ax.set(title="Global Power Spectral Density", xlabel="Frequency (Hz)", ylabel="Power")
plt.show()

# 4. Time-frequency representation using Morlet wavelets
tfr = compute_time_frequency(epochs, freqs=np.arange(6, 30, 2), n_cycles=2)
# The epochs start at t=0 (no pre-event interval), so baseline=(None, 0) would
# normalise by a single time sample.  Use the whole-epoch mean power instead.
tfr.plot(baseline=(None, None), mode="logratio", title="Time-Frequency Representation")

# 5. Inter-trial variability over time for the first channel
variability = compute_inter_trial_variability(epochs)
times = epochs.times
fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(times, variability[0, :])
# NOTE(review): the µV label assumes the CSV values are stored in microvolts.
ax.set(
    title=f"Inter-Trial Variability (First Channel: {epochs.ch_names[0]})",
    xlabel="Time (s)",
    ylabel="Standard Deviation (µV)",
)
plt.show()