In [2]:
import glob
import os
import mne
import matplotlib.pyplot as plt
import plotly.graph_objects as go

In [17]:
def get_subjects(release: int):
    subject_paths = glob.glob(os.path.join(f"release{release}", "sub*"))
    subject_names = [os.path.basename(path) for path in subject_paths]
    return subject_names


def get_bdfs_for_subject(release: int, subject: str):
    path = f"release{release}/{subject}/eeg"
    bdf_files = glob.glob(os.path.join(path, "*.bdf"))
    return bdf_files


def get_dicts_for_release(release: int):
    r_subjects = get_subjects(release)

    subj_to_all_bdfs = {subj: get_bdfs_for_subject(release, subj) for subj in r_subjects}
    subj_to_split_bdfs = {}

    for subj in r_subjects:
        bdfs = get_bdfs_for_subject(release, subj)

        tasks = [os.path.basename(bdf).split("task-")[-1].replace("_eeg.bdf", "") for bdf in bdfs]

        task_to_bdfs = {tasks[i]: bdfs[i] for i in range(len(bdfs))}

        subj_to_split_bdfs[subj] = task_to_bdfs

    return subj_to_all_bdfs, subj_to_split_bdfs

def get_tasks_for_subj(release: int, subj: str):
    subj_to_all_bdfs, subj_to_split_bdfs = get_dicts_for_release(release)
    return list(subj_to_split_bdfs[subj].keys())


In [32]:
def plot_eeg_segment(
    release: int, 
    subj: str, 
    task: str, 
    times: list[float],
    eegs: list[int] = None
):
    start_time = times[0]
    end_time = times[1]
    
    subj_to_all_bdfs, subj_to_split_bdfs = get_dicts_for_release(release)
    print(subj_to_split_bdfs[subj][task])
    bdf_path = subj_to_split_bdfs[subj][task]

    raw = mne.io.read_raw_bdf(bdf_path, preload=False)
    raw.pick(["eeg", "eog"])

    ch_names = raw.ch_names
    if eegs is None:
        selected_idxs = range(len(ch_names))
    else:
        selected_idxs = []
        for num in eegs:
            if num == 0:  # 0 corresponds to Cz
                if "Cz" in ch_names:
                    selected_idxs.append(ch_names.index("Cz"))
                else:
                    raise ValueError("Channel 'Cz' not found in data.")
            else:
                ch_name = f"E{num}"
                if ch_name in ch_names:
                    selected_idxs.append(ch_names.index(ch_name))
                else:
                    raise ValueError(f"Channel '{ch_name}' not found in data.")

    sfreq = raw.info["sfreq"]
    start_sample = int(start_time * sfreq)
    stop_sample = int(end_time * sfreq)
    data, times = raw.get_data(start=start_sample, stop=stop_sample, return_times=True)

    offset = 50e-6
    fig = go.Figure()
    for i, ch_idx in enumerate(selected_idxs):
        fig.add_trace(go.Scatter(
            x=times, 
            y=data[ch_idx] + i * offset, 
            mode="lines", 
            name=ch_names[ch_idx]
        ))

    fig.update_layout(
        height=800,
        title=f"{subj}, {task} from {start_time}s to {end_time}s",
        xaxis_title="Time (s)",
        yaxis_title="EEG Channels (offset for clarity)",
        legend=dict(
            itemsizing='constant', 
            tracegroupgap=2, 
            yanchor="top", 
            y=1, 
            xanchor="left", 
            x=1.02
        )
    )
    fig.show()

In [33]:
temp_subj = "sub-NDARAC904DMU"

for task in get_tasks_for_subj(1, temp_subj):
    if "contrast" in task:
        plot_eeg_segment(
            release=1, 
            subj=temp_subj, 
            task=task, 
            times=[3,6],
            eegs=None
        )

release1/sub-NDARAC904DMU/eeg/sub-NDARAC904DMU_task-contrastChangeDetection_run-3_eeg.bdf
Extracting EDF parameters from /Users/carinaxguo/ESE 5380/eeg_challenge/release1/sub-NDARAC904DMU/eeg/sub-NDARAC904DMU_task-contrastChangeDetection_run-3_eeg.bdf...
BDF file detected
Setting channel info structure...
Creating raw.info structure...


release1/sub-NDARAC904DMU/eeg/sub-NDARAC904DMU_task-contrastChangeDetection_run-2_eeg.bdf
Extracting EDF parameters from /Users/carinaxguo/ESE 5380/eeg_challenge/release1/sub-NDARAC904DMU/eeg/sub-NDARAC904DMU_task-contrastChangeDetection_run-2_eeg.bdf...
BDF file detected
Setting channel info structure...
Creating raw.info structure...


release1/sub-NDARAC904DMU/eeg/sub-NDARAC904DMU_task-contrastChangeDetection_run-1_eeg.bdf
Extracting EDF parameters from /Users/carinaxguo/ESE 5380/eeg_challenge/release1/sub-NDARAC904DMU/eeg/sub-NDARAC904DMU_task-contrastChangeDetection_run-1_eeg.bdf...
BDF file detected
Setting channel info structure...
Creating raw.info structure...


In [37]:
def plot_overlay(
    release: int,
    subj1: str,
    task1: str,
    subj2: str,
    task2: str,
    times1: list[float],
    times2: list[float],
    eegs1: list[int] = None,
    eegs2: list[int] = None
):
    subj_to_all_bdfs, subj_to_split_bdfs = get_dicts_for_release(release)
    
    bdf_path1 = subj_to_split_bdfs[subj1][task1]
    raw1 = mne.io.read_raw_bdf(bdf_path1, preload=False)
    raw1.pick(["eeg", "eog"])
    ch_names1 = raw1.ch_names

    bdf_path2 = subj_to_split_bdfs[subj2][task2]
    raw2 = mne.io.read_raw_bdf(bdf_path2, preload=False)
    raw2.pick(["eeg", "eog"])
    ch_names2 = raw2.ch_names

    def get_selected_idxs(ch_names, eegs):
        if eegs is None:
            return range(len(ch_names))
        selected = []
        for num in eegs:
            if num == 0:
                if "Cz" in ch_names:
                    selected.append(ch_names.index("Cz"))
                else:
                    raise ValueError("Channel 'Cz' not found in data.")
            else:
                ch_name = f"E{num}"
                if ch_name in ch_names:
                    selected.append(ch_names.index(ch_name))
                else:
                    raise ValueError(f"Channel '{ch_name}' not found in data.")
        return selected

    idxs1 = get_selected_idxs(ch_names1, eegs1)
    idxs2 = get_selected_idxs(ch_names2, eegs2)

    sfreq1 = raw1.info["sfreq"]
    start_sample1 = int(times1[0] * sfreq1)
    stop_sample1 = int(times1[1] * sfreq1)
    data1, times1_data = raw1.get_data(start=start_sample1, stop=stop_sample1, return_times=True)
    times1_data = times1_data - times1_data[0]

    sfreq2 = raw2.info["sfreq"]
    start_sample2 = int(times2[0] * sfreq2)
    stop_sample2 = int(times2[1] * sfreq2)
    data2, times2_data = raw2.get_data(start=start_sample2, stop=stop_sample2, return_times=True)
    times2_data = times2_data - times2_data[0]

    offset = 50e-6
    fig = go.Figure()

    for i, ch_idx in enumerate(idxs1):
        fig.add_trace(go.Scatter(
            x=times1_data,
            y=data1[ch_idx] + i * offset,
            mode="lines",
            name=f"{subj1} - {task1} - {ch_names1[ch_idx]}"
        ))

    for i, ch_idx in enumerate(idxs2):
        fig.add_trace(go.Scatter(
            x=times2_data,
            y=data2[ch_idx] + i * offset,
            mode="lines",
            name=f"{subj2} - {task2} - {ch_names2[ch_idx]}"
        ))

    fig.update_layout(
        height=800,
        title=f"Overlay: {subj1}/{task1} vs {subj2}/{task2} (aligned at start)",
        xaxis_title="Time (s, aligned to start)",
        yaxis_title="EEG Channels (offset for clarity)",
        legend=dict(
            itemsizing='constant',
            tracegroupgap=3,
            yanchor="top",
            y=1,
            xanchor="left",
            x=1.02
        )
    )

    fig.show()

In [39]:
temp_subj = "sub-NDARAC904DMU"
temp_eegs = [12, 13, 14, 15, 16, 17]
task1 = "contrastChangeDetection_run-1"
task2 = "contrastChangeDetection_run-2"

plot_overlay(
    release=1,
    subj1=temp_subj,
    task1=task1,
    subj2=temp_subj,
    task2=task2,
    times1=[4.9, 6.9],
    times2=[4.5, 6.5],
    eegs1=temp_eegs,
    eegs2=temp_eegs
)

Extracting EDF parameters from /Users/carinaxguo/ESE 5380/eeg_challenge/release1/sub-NDARAC904DMU/eeg/sub-NDARAC904DMU_task-contrastChangeDetection_run-1_eeg.bdf...
BDF file detected
Setting channel info structure...
Creating raw.info structure...
Extracting EDF parameters from /Users/carinaxguo/ESE 5380/eeg_challenge/release1/sub-NDARAC904DMU/eeg/sub-NDARAC904DMU_task-contrastChangeDetection_run-2_eeg.bdf...
BDF file detected
Setting channel info structure...
Creating raw.info structure...


In [None]:
def plot_overlay(
    release: int,
    subj1: str,
    task1: str,
    start_time1: float,
    end_time1: float,
    subj2: str,
    task2: str,
    start_time2: float,
    end_time2: float,
    eegs1: list[int] = None,
    eegs2: list[int] = None
):
