In [None]:
import os
import warnings

import mne
import polars as pl

In [None]:
# Project root. Set the LNDM_ROOT environment variable to run the notebook on
# another machine; the default keeps the original author's checkout location.
ROOT_PATH = os.environ.get(
    "LNDM_ROOT", "/home/bobby/repos/latent-neural-dynamics-modeling"
)
# All participant folders and participants.tsv live under <ROOT_PATH>/data.
DATA_PATH = os.path.join(ROOT_PATH, "data")

In [None]:
# Participant table: tab-separated, with "n/a" marking missing values.
participants = pl.read_csv(
    source=os.path.join(DATA_PATH, "participants.tsv"),
    null_values="n/a",
    separator="\t",
)

In [None]:
def list_files(folder_path: str, root_: bool = False) -> list:
    """Return the entries of a directory, sorted by name.

    Args:
        folder_path: Directory to list. It is joined onto ``DATA_PATH`` unless
            ``root_`` is true, in which case it is used exactly as given.
        root_: When true, list ``folder_path`` directly instead of
            ``os.path.join(DATA_PATH, folder_path)``.

    Returns:
        Names of the entries (files and sub-directories) in lexicographic
        order. ``os.listdir`` yields entries in arbitrary, filesystem-dependent
        order; sorting makes the downstream ``explode`` output reproducible.
    """
    # DATA_PATH is only looked up when the path is relative to the data root.
    directory = folder_path if root_ else os.path.join(DATA_PATH, folder_path)
    return sorted(os.listdir(directory))

In [None]:
# One row per (participant, session): list each participant's folder, then explode.
participants = participants.with_columns(
    session=pl.col("participant_id").map_elements(
        lambda participant_dir: list_files(participant_dir),
        return_dtype=pl.List(pl.String),
    )
).explode("session")

## iEEG

In [None]:
# Build "<DATA_PATH>/<participant_id>/<session>/ieeg" for every session row.
participants_ieeg = participants.with_columns(
    ieeg_path=pl.concat_str(
        pl.lit(DATA_PATH),
        pl.col("participant_id"),
        pl.col("session"),
        pl.lit("ieeg"),
        separator="/",
    )
)

In [None]:
# Expand to one row per file found in each session's ieeg directory.
participants_ieeg = participants_ieeg.with_columns(
    ieeg_file=pl.col("ieeg_path").map_elements(
        lambda directory: list_files(directory, root_=True),
        return_dtype=pl.List(pl.String),
    )
).explode("ieeg_file")

In [None]:
# Split each file name on "_": the last token yields `type` (text before the
# first ".") and `data_format` (text after the last "."); the second-to-last
# token is kept as `run`. Channel tables (channels.tsv) are then discarded.
participants_ieeg = (
    participants_ieeg
    .with_columns(splitted_file=pl.col("ieeg_file").str.split(by="_"))
    .with_columns(
        type=pl.col("splitted_file").list.last().str.split(".").list.first(),
        data_format=pl.col("splitted_file").list.last().str.split(".").list.last(),
        run=pl.col("splitted_file").list.get(-2),
    )
    .drop("splitted_file")
    .filter((pl.col("type") != "channels") | (pl.col("data_format") != "tsv"))
)

In [None]:
def read_csv_(row: dict[str, str]) -> pl.Series:
    """Read the tab-separated table referenced by ``row`` and return it as structs.

    ``row`` is the dict that ``pl.struct([...]).map_elements`` passes in. It must
    contain a key whose name includes ``"path"`` (the directory) and another
    whose name includes ``"file"`` (the file name inside that directory), e.g.
    ``{"ieeg_path": ..., "ieeg_file": ...}``. When several keys match, the last
    one wins; a key containing both words counts as the path.

    Returns:
        ``DataFrame.to_struct()`` of the table: a struct Series with one element
        per table row. ``"n/a"`` cells are read as null.

    Raises:
        ValueError: if ``row`` has no non-null path-like or file-like entry.
    """
    path_ = None
    file_ = None
    for key, value in row.items():
        if "path" in key:
            path_ = value
        elif "file" in key:
            file_ = value
    # Previously a missing key surfaced as an opaque UnboundLocalError.
    if path_ is None or file_ is None:
        raise ValueError(
            f"row must provide non-null '*path*' and '*file*' entries, got {row!r}"
        )
    table = pl.read_csv(
        os.path.join(path_, file_),
        separator="\t",
        null_values="n/a",
    )
    return table.to_struct()

In [None]:
# Expected return dtype of read_csv_ for *_events.tsv files: one struct per event row.
events_schema = pl.List(
    pl.Struct(
        {
            "onset": pl.Float64,
            "duration": pl.Float64,
            "trial_type": pl.Int64,
            "value": pl.Int64,
            "sample": pl.Int64,
        }
    )
)

In [None]:
# Load every events.tsv into a list-of-structs column keyed by participant/session/run.
events_df = (
    participants_ieeg
    .filter(pl.col("type").eq("events"), pl.col("data_format").eq("tsv"))
    .select(
        "participant_id",
        "session",
        "run",
        events=pl.struct("ieeg_path", "ieeg_file").map_elements(
            read_csv_, return_dtype=events_schema
        ),
    )
)

In [None]:
# One row per event, dropping events whose `value` is 25.
# NOTE(review): the meaning of event code 25 is not documented here — confirm.
events_df = events_df.explode("events")
events_df = events_df.filter(pl.col("events").struct.field("value").ne(25))

In [None]:
# Order events chronologically within each participant/session/run.
events_df = events_df.sort(
    "participant_id",
    "session",
    "run",
    pl.col("events").struct.field("onset"),
)

In [None]:
# Re-collect the (filtered, sorted) events into one list per participant/session/run.
events_df = events_df.group_by("participant_id", "session", "run").agg(
    pl.col("events")
)

In [None]:
# Attach the events list to each run, then drop the events.tsv rows themselves.
participants_ieeg = participants_ieeg.join(
    events_df,
    how="left",
    on=["participant_id", "session", "run"],
).filter((pl.col("type") != "events") | (pl.col("data_format") != "tsv"))

In [None]:
participants_ieeg = participants_ieeg.filter(pl.col("data_format").ne("json"))  # drop JSON sidecars

In [None]:
# Map each run to its BrainVision header (.vhdr) file name.
headers_df = (
    participants_ieeg
    .filter(pl.col("type").eq("ieeg"), pl.col("data_format").eq("vhdr"))
    .select(
        "participant_id",
        "session",
        "run",
        ieeg_headers_file=pl.col("ieeg_file"),
    )
)

In [None]:
# Attach the header file name; among ieeg rows keep only the .eeg data file.
participants_ieeg = participants_ieeg.join(
    headers_df,
    how="left",
    on=["participant_id", "session", "run"],
).filter((pl.col("type") != "ieeg") | (pl.col("data_format") == "eeg"))

In [None]:
# Remove parsing helpers; strict=False tolerates columns that are absent.
participants_ieeg = participants_ieeg.drop(
    ["type", "data_format", "channels_info_right"], strict=False
)

In [None]:
def band_pass_resample(
    ieeg_headers_file: str,
    sfreq: float = 1000,
    low_freq: float = 3,
    high_freq: float = 100,
    notch_freqs: tuple[float, ...] = (50, 100),
    save_dir: str | None = None,
) -> pl.Series | None:
    """Load one BrainVision iEEG recording, filter it, and resample it.

    Processing order: notch filter at ``notch_freqs``, band-pass between
    ``low_freq`` and ``high_freq`` Hz, then resample to ``sfreq`` Hz.
    Defaults reproduce the values previously hard-coded in this function.

    Args:
        ieeg_headers_file: Path to the ``.vhdr`` header file.
        sfreq: Target sampling frequency in Hz.
        low_freq: Lower pass-band edge in Hz.
        high_freq: Upper pass-band edge in Hz.
        notch_freqs: Frequencies (Hz) removed by the notch filter.
        save_dir: If given, the processed recording is also saved to
            ``<save_dir>/<header file stem>.fif`` (directory created as needed).
            ``None`` (default) writes nothing to disk.

    Returns:
        ``pl.DataFrame(...).to_struct()`` of the processed data: a struct Series
        with one element per resampled time point and one field per channel.
        Returns ``None`` if the header file does not exist, or if loading or
        processing fails (a ``UserWarning`` is emitted in that case).
    """
    if not os.path.exists(ieeg_headers_file):
        return None

    try:
        raw = mne.io.read_raw_brainvision(
            ieeg_headers_file, preload=True, verbose=False
        )
        raw.notch_filter(freqs=list(notch_freqs), verbose=False)
        raw.filter(l_freq=low_freq, h_freq=high_freq, verbose=False)
        raw.resample(sfreq=sfreq, verbose=False)

        if save_dir is not None:
            os.makedirs(save_dir, exist_ok=True)
            stem = os.path.splitext(os.path.basename(ieeg_headers_file))[0]
            raw.save(os.path.join(save_dir, f"{stem}.fif"), overwrite=True, verbose=False)

        # get_data() returns one row per channel, aligned with raw.ch_names.
        channels_data = {
            name: samples.tolist() for name, samples in zip(raw.ch_names, raw.get_data())
        }
        return pl.DataFrame(channels_data).to_struct()
    except Exception as err:
        # Best effort per recording: one bad file must not abort the whole
        # map_elements pass, but the failure should not be silent either.
        warnings.warn(
            f"band_pass_resample skipped {ieeg_headers_file}: {err!r}", stacklevel=2
        )
        return None

In [None]:
# Run the filter/resample pipeline on "<ieeg_path>/<ieeg_headers_file>" per row.
participants_ieeg = participants_ieeg.with_columns(
    ieeg_raw=pl.concat_str(
        [pl.col("ieeg_path"), pl.col("ieeg_headers_file")], separator="/"
    ).map_elements(band_pass_resample, return_dtype=pl.List(pl.Struct))
)

In [None]:
# Inspect the frame after the `ieeg_raw` column (filtered/resampled signals) was added.
participants_ieeg

In [None]:
# `ieeg_raw` is produced with return_dtype=pl.List(pl.Struct), i.e. it is a List
# column, and DataFrame.unnest only accepts Struct columns. Explode to one row
# per list element (one resampled time point), then spread each channel field
# into its own column.
# NOTE(review): explode repeats every other column (including `events`) once per
# sample, which can be very large — consider keeping signals in a separate frame.
participants_ieeg = participants_ieeg.explode("ieeg_raw").unnest("ieeg_raw")

In [None]:
# Inspect the frame after `ieeg_raw` has been expanded into per-channel columns.
participants_ieeg

In [None]:
# Persist as a parquet dataset directory partitioned by participant and session.
participants_ieeg.write_parquet(
    file="./participants_ieeg",
    partition_by=["participant_id", "session"],
)

In [None]:
participants_ieeg = pl.read_parquet(source="./participants_ieeg")  # reload the partitioned dataset

In [None]:
# Inspect the frame as read back from the partitioned parquet directory.
participants_ieeg

## Motion

In [None]:
# Build "<DATA_PATH>/<participant_id>/<session>/motion" for every session row.
participants_motion = participants.with_columns(
    motion_path=pl.concat_str(
        pl.lit(DATA_PATH),
        pl.col("participant_id"),
        pl.col("session"),
        pl.lit("motion"),
        separator="/",
    )
)

In [None]:
# Expand to one row per file found in each session's motion directory.
participants_motion = participants_motion.with_columns(
    motion_file=pl.col("motion_path").map_elements(
        lambda directory: list_files(directory, root_=True),
        return_dtype=pl.List(pl.String),
    )
).explode("motion_file")

In [None]:
# Split each file name on "_": the last token yields `data_format` (after the
# last ".") and `type` (before the first "."); the third- and fourth-to-last
# tokens are kept as `chunk` and `run`.
participants_motion = (
    participants_motion
    .with_columns(splitted_file=pl.col("motion_file").str.split(by="_"))
    .with_columns(
        data_format=pl.col("splitted_file").list.last().str.split(".").list.last(),
        type=pl.col("splitted_file").list.last().str.split(".").list.first(),
        chunk=pl.col("splitted_file").list.get(-3),
        run=pl.col("splitted_file").list.get(-4),
    )
    .drop("splitted_file")
)

In [None]:
# Keep only data files: no JSON sidecars and no channel tables; then drop the helpers.
participants_motion = participants_motion.filter(
    pl.col("data_format").ne("json"),
    pl.col("type").ne("channels"),
).drop("data_format", "type")

In [None]:
# Inspect the motion file index after removing JSON sidecars and channel tables.
participants_motion

In [None]:
# Declared return dtype of read_csv_ for motion files.
# NOTE(review): x/y are declared as List(Float64) per row — confirm against the files.
motion_schema = pl.List(
    pl.Struct(
        {
            "x": pl.List(pl.Float64),
            "y": pl.List(pl.Float64),
        }
    )
)

In [None]:
# Load each motion table into a list-of-structs column.
participants_motion = participants_motion.with_columns(
    motion_coordinates=pl.struct("motion_path", "motion_file").map_elements(
        read_csv_, return_dtype=motion_schema
    )
)

In [None]:
participants_motion.schema  # column name -> dtype mapping

In [None]:
# Inspect the motion frame with the loaded `motion_coordinates` column.
participants_motion