EEG data generation with a variational autoencoder (VAE) - adapted from https://github.com/arkanivasarkar/EEG-Data-Augmentation-using-Variational-Autoencoder/blob/main/VAE%20model%20for%20EEG%20Data%20Augmentation.ipynb and translated from TensorFlow to PyTorch.

In [None]:
import glob
import os

import mne
import numpy as np
import pandas as pd
from mne.datasets import eegbci


In [None]:
# --- Output location -------------------------------------------------------
# The downstream loader reads epochs from this exact folder; keep them in sync.
OUT_DIR = r"bci_iv_2a_data/A01/train/0"
os.makedirs(OUT_DIR, exist_ok=True)

# --- EEGMMIDB download / epoching settings ---------------------------------
SUBJECTS = [1]      # subject IDs to fetch; start with one, extend later (1..109)
RUNS = [6, 10]      # EEGMMIDB runs holding the motor-imagery trials used here
SFREQ_TARGET = 250  # Hz after resampling, so a 4 s window is 1000 samples
TMIN = 0.0          # epoch start, seconds relative to each event
TMAX = 4.0          # epoch end, seconds relative to each event (4 s window)

# Preferred 22-channel selection, grouped by scalp region. pick_22_channels()
# uses it as the first choice and falls back to other EEG channels when some
# of these names are missing from the recording.
PICK_22 = [
    # frontal / fronto-central
    "Fz", "FC3", "FC1", "FCz", "FC2", "FC4",
    # central strip
    "C5", "C3", "C1", "Cz", "C2", "C4", "C6",
    # centro-parietal
    "CP3", "CP1", "CPz", "CP2", "CP4",
    # parietal / occipital
    "Pz", "POz", "Oz", "P3",
]

def pick_22_channels(raw_or_epochs, desired, n_channels=22):
    """Choose which EEG channel names to keep, preferring ``desired``.

    Channels from ``desired`` that exist (as EEG channels) in the recording
    come first, in the order given by ``desired``. If fewer than
    ``n_channels`` of them are present, the selection is topped up with the
    recording's remaining EEG channels in recording order. The matched
    preferred channels are never thrown away.

    Parameters
    ----------
    raw_or_epochs : mne.io.Raw | mne.Epochs
        Any MNE object exposing ``ch_names`` and ``get_channel_types()``.
    desired : iterable of str
        Preferred channel names, highest priority first. Duplicates are ignored.
    n_channels : int, optional
        Number of names to return (default 22, hence the function name).

    Returns
    -------
    list of str
        At most ``n_channels`` distinct EEG channel names. The list is shorter
        only when the recording has fewer EEG channels than that.
    """
    # Read channel types from metadata only; copying the object (and its
    # data) just to look at names is unnecessary.
    eeg_names = [
        name
        for name, kind in zip(raw_or_epochs.ch_names, raw_or_epochs.get_channel_types())
        if kind == "eeg"
    ]
    available = set(eeg_names)

    # dict.fromkeys de-duplicates while preserving the caller's priority order.
    chosen = [ch for ch in dict.fromkeys(desired) if ch in available]

    # Top up with the other EEG channels instead of discarding partial matches.
    already = set(chosen)
    chosen.extend(ch for ch in eeg_names if ch not in already)
    return chosen[:n_channels]

# ----------------------------
# Download, preprocess, epoch and export each subject
# ----------------------------
# Shape every exported epoch must have: (channels, samples).
EXPECTED_N_CHANNELS = 22
EXPECTED_N_TIMES = int(round((TMAX - TMIN) * SFREQ_TARGET))  # 4 s * 250 Hz -> 1000

# NOTE(review): T1 and T2 epochs are all written into the single folder
# OUT_DIR (".../train/0"). Confirm the loader really wants one class here.
file_count = 0

for subj in SUBJECTS:
    edf_files = eegbci.load_data(subj, RUNS)
    raws = [mne.io.read_raw_edf(f, preload=True, verbose=False) for f in edf_files]
    raw = mne.concatenate_raws(raws)

    # EEGMMIDB EDF channel labels are not standard 10-20 names (e.g. "Fc3.").
    # standardize() renames them in place so that both the montage and the
    # PICK_22 names below actually match instead of silently falling back.
    eegbci.standardize(raw)
    raw.set_montage("standard_1020", on_missing="ignore")

    # Basic MI preprocessing: 8-30 Hz band-pass (mu/beta range).
    raw.filter(8.0, 30.0, fir_design="firwin")

    # Mains notch. EEGMMIDB was recorded in the US (60 Hz mains, not 50 Hz).
    # A notch at or above the Nyquist frequency of the native sampling rate is
    # rejected by MNE, so only notch harmonics that the data can represent.
    nyquist = raw.info["sfreq"] / 2.0
    line_freqs = [f for f in (60.0, 120.0) if f < nyquist]
    if line_freqs:
        raw.notch_filter(line_freqs)

    # Average reference over *all* EEG channels, before reducing to 22.
    raw.set_eeg_reference("average", projection=False)

    # Reduce to the chosen channels (in the chosen order) before resampling,
    # so the remaining steps only process the data that is exported.
    chosen_22 = pick_22_channels(raw, PICK_22)
    raw.pick_channels(chosen_22, ordered=True)

    # Resample so that a 4 s window is exactly EXPECTED_N_TIMES samples.
    raw.resample(SFREQ_TARGET)

    events, event_id = mne.events_from_annotations(raw)

    # Keep imagery events. EEGMMIDB uses T1/T2 for imagery and T0 for rest.
    keep = {k: v for k, v in event_id.items() if k in ("T1", "T2")}
    if not keep:
        raise RuntimeError(f"No T1/T2 events found. event_id={event_id}")

    epochs = mne.Epochs(
        raw,
        events,
        event_id=keep,
        tmin=TMIN,
        # MNE epochs include the sample at tmax, so [0, 4.0] would give 1001
        # samples; stop one sample early to get exactly EXPECTED_N_TIMES.
        tmax=TMAX - 1.0 / SFREQ_TARGET,
        baseline=None,
        preload=True,
        verbose=False,
    )

    X = epochs.get_data()  # (n_epochs, n_channels, n_times)
    if X.shape[1:] != (EXPECTED_N_CHANNELS, EXPECTED_N_TIMES):
        raise RuntimeError(
            f"Got shape {X.shape}, expected (n, {EXPECTED_N_CHANNELS}, {EXPECTED_N_TIMES}). "
            "Check picks or resampling/window."
        )

    # Row labels come from the epochs themselves, so they always match X's
    # channel axis whatever order MNE kept the channels in.
    ch_labels = list(epochs.ch_names)

    # One headerless CSV per epoch, in the loader's format:
    # 22 rows x 1001 cols = [channel_name | 1000 samples].
    # Channel names are written as the index. Inserting them as a new column
    # labelled 0 would collide with the DataFrame's existing integer column 0.
    for i, epoch in enumerate(X):
        out_path = os.path.join(OUT_DIR, f"sub{subj:03d}_epoch{i:05d}.csv")
        pd.DataFrame(epoch, index=ch_labels).to_csv(out_path, header=False)
        file_count += 1

print(f"Wrote {file_count} CSV epoch files to {OUT_DIR}")


FileNotFoundError: [Errno 2] No such file or directory: 'bci_iv_2a_data/A01/train/0/'