In [1]:
import os
from pathlib import Path

import polars as pl

In [2]:
# Repository root. Set the LNDM_ROOT_DIR environment variable to run this notebook
# from a different checkout; the default is the original author's local path.
ROOT_DIR = Path(
    os.environ.get("LNDM_ROOT_DIR", "/home/bobby/repos/latent-neural-dynamics-modeling")
)
# Directory holding the label CSVs and the intermediate / output parquet files.
saved_dir = ROOT_DIR / "data"

In [41]:
# Block-level recording metadata produced by the resampling step.
participants = pl.read_parquet(
    saved_dir.joinpath("resampled", "participants_intermediate.parquet")
)

In [42]:
# Load the three label tables (block-level, trial-level, trial durations) in one pass.
PDI_labels, PDI_labels_trials, PDI_trial_durations = (
    pl.read_csv(saved_dir / f"{stem}.csv")
    for stem in ("PDI_labels", "PDI_labels_trials", "PDI_trial_durations")
)

In [43]:
# Normalise the trial-level labels, then collapse them to one row per
# (participant_id, session, block) holding the per-trial values as lists.
# dbs_stim is encoded as 1 when the CSV says "on", else 0.
PDI_labels_trials = (
    PDI_labels_trials.select(
        participant_id=pl.col("subject"),
        session=pl.col("session").cast(pl.UInt32),
        block=pl.col("ix_block").cast(pl.UInt32),
        trial=pl.col("ix_trial").cast(pl.UInt32),
        dbs_stim=pl.when(pl.col("dbs_stim") == "on").then(1).otherwise(0),
        yscore=pl.col("yscore").cast(pl.Float32),
    )
    # Sorting first (and keeping group order) makes each list trial-ordered.
    .sort("participant_id", "session", "block", "trial")
    .group_by("participant_id", "session", "block", maintain_order=True)
    .agg(
        trials=pl.col("trial"),
        dbs_stim=pl.col("dbs_stim"),
        yscores=pl.col("yscore"),
    )
    .with_columns(trial_count=pl.col("trials").list.len())
)

In [None]:
# Normalise the block-level labels: rename keys, cast to the join dtypes and
# encode dbs_stim as 1 ("on") / 0 (anything else).
PDI_labels = PDI_labels.select(
    participant_id=pl.col("subject"),
    session=pl.col("session").cast(pl.UInt32),
    block=pl.col("ix_block").cast(pl.UInt32),
    dbs_stim=pl.when(pl.col("dbs_stim") == "on").then(1).otherwise(0),
    valid_trial_cnt=pl.col("valid_trial_cnt").cast(pl.UInt32),
).sort("participant_id", "session", "block")

In [45]:
# "sub-XXX" / "ses-N" -> "XXX" / N, so keys match the other label tables.
PDI_trial_durations = PDI_trial_durations.select(
    participant_id=pl.col("sub").str.split("-").list.get(-1),
    session=pl.col("ses").str.split("-").list.get(-1).cast(pl.UInt32),
    block=pl.col("block").cast(pl.UInt32),
    dt_s=pl.col("dt_s").cast(pl.Float32),
).sort(by=["participant_id", "session", "block"])

In [46]:
# Normalise the recording metadata so it can be joined to the label tables:
#   - participant_id "sub-PDI1" -> "PDI1"
#   - ieeg_parquet: <saved_dir>/resampled/<ieeg header basename without extension>.parquet
#   - session / run cast to UInt32, run renamed to block
#   - session_path: last two path components re-rooted under saved_dir
# NOTE(review): session_path is built with concat_str (no per-row Python UDF), mirroring
# ieeg_parquet; this assumes the stored paths have no trailing "/" — confirm.
participants = (
    participants.with_columns(pl.col("participant_id").str.split(by="-").list.get(-1))
    .with_columns(
        pl.col("ieeg_headers_file")
        .str.split("/")
        .list.get(-1)
        .str.split(".")
        .list.get(0)
        .alias("base_ieeg_file")
    )
    .with_columns(
        pl.concat_str(
            pl.lit(str(saved_dir / "resampled")),
            pl.concat_str(pl.col("base_ieeg_file"), pl.lit("parquet"), separator="."),
            separator="/",
        ).alias("ieeg_parquet"),
        pl.col("session").cast(pl.UInt32),
        pl.col("run").cast(pl.UInt32).alias("block"),
        pl.concat_str(
            pl.lit(str(saved_dir)),
            pl.col("session_path")
            .str.split(by="/")
            .list.tail(2)
            .list.join(separator="/"),
            separator="/",
        ).alias("session_path"),
    )
    .drop("base_ieeg_file", "ieeg_headers_file", "durations", "run")
)

# Dropping the per-file path columns leaves duplicate block rows; de-duplicate them.
# maintain_order=True keeps first-occurrence order so the later join and the
# partitioned parquet write are reproducible across runs.
participants = participants.drop("participant_path", "ieeg_path", "ieeg_file").unique(
    maintain_order=True
)

In [None]:
# Attach block-level labels, then trial-level labels, to every duration row.
# Both label tables carry a `dbs_stim` column, so the second join adds the
# trial-level (list-valued) copy under polars' default suffix as `dbs_stim_right`.
PDI_trial_durations = (
    PDI_trial_durations
    .join(PDI_labels, on=["participant_id", "session", "block"], how="left")
    .join(PDI_labels_trials, on=["participant_id", "session", "block"], how="left")
)

In [48]:
# -1 marks "no matching label row" for either label table. Filling a UInt32
# column with -1 upcasts it to a signed integer type (i64 in the output below).
PDI_trial_durations = PDI_trial_durations.with_columns(
    pl.col("valid_trial_cnt", "trial_count").fill_null(-1)
)

In [None]:
# Blocks whose block-level valid-trial count disagrees with the number of trial-level labels.
# PDI4-3-12 all trials valid
# PDI3-4-12 all good for classification but won't take the x and y coordinates
PDI_trial_durations.filter(pl.col("valid_trial_cnt").ne(pl.col("trial_count")))

participant_id,session,block,dt_s,dbs_stim,valid_trial_cnt,trials,dbs_stim_right,yscores,trial_count
str,u32,u32,f32,i32,i64,list[u32],list[i32],list[f32],i64
"""PDI3""",4,12,9.018765,1,11,,,,-1
"""PDI4""",3,12,9.002478,1,12,,,,-1


In [50]:
"""
Process the PDI_trial_durations DataFrame according to the following logic:
    - For any row where trial_count is -1 but valid_trial_cnt is not -1, set the 'trials'
        column to the list [1, 2, ..., 12]. (A trial_count of -1 means the block has no
        per-trial labels.)
    - For rows where both trial_count and valid_trial_cnt are -1, remove these rows from
        the DataFrame.
The removed rows are printed before they are dropped.
"""

"\nProcess the PDI_trial_durations DataFrame according to the following logic:\n    - For any row where trial_count is -1 but valid_trial_cnt is not -1, set the 'trials'\n        column to the list [1, 2, ..., 12]. (This indicates there are no pre-assigned trial labels.)\n    - For rows where both trial_count and valid_trial_cnt are -1, remove these rows from\n        the DataFrame.\nThe function prints all the rows that were removed.\n"

In [None]:
# Blocks with neither block-level nor trial-level labels (both sentinels are -1).
# Multiple predicates passed to filter are combined with AND.
removed_rows = PDI_trial_durations.filter(
    pl.col("trial_count") == -1,
    pl.col("valid_trial_cnt") == -1,
)
print("Removed rows:")
print(
    removed_rows.select(
        ["participant_id", "session", "block", "valid_trial_cnt", "trial_count"]
    )
)

Removed rows:
shape: (4, 5)
┌────────────────┬─────────┬───────┬─────────────────┬─────────────┐
│ participant_id ┆ session ┆ block ┆ valid_trial_cnt ┆ trial_count │
│ ---            ┆ ---     ┆ ---   ┆ ---             ┆ ---         │
│ str            ┆ u32     ┆ u32   ┆ i64             ┆ i64         │
╞════════════════╪═════════╪═══════╪═════════════════╪═════════════╡
│ PDI3           ┆ 4       ┆ 1     ┆ -1              ┆ -1          │
│ PDI3           ┆ 4       ┆ 2     ┆ -1              ┆ -1          │
│ PDI3           ┆ 4       ┆ 3     ┆ -1              ┆ -1          │
│ PDI4           ┆ 3       ┆ 9     ┆ -1              ┆ -1          │
└────────────────┴─────────┴───────┴─────────────────┴─────────────┘


In [None]:
# Exclude the rows where both trial_count and valid_trial_cnt are -1
PDI_trial_durations = PDI_trial_durations.filter(
    ~((pl.col("trial_count") == -1) & (pl.col("valid_trial_cnt") == -1))
)

In [None]:
# Default number of trials per block, used when a block has no trial-level labels.
N_TRIALS_PER_BLOCK = 12

PDI_trial_durations = PDI_trial_durations.with_columns(
    # Blocks without trial-level labels (trial_count == -1) get trials 1..N.
    # The explicit cast pins the result to the list[u32] dtype of the labelled
    # `trials` column instead of relying on supertype inference for the literal.
    pl.when(pl.col("trial_count") == -1)
    .then(pl.lit(list(range(1, N_TRIALS_PER_BLOCK + 1))))
    .otherwise(pl.col("trials"))
    .cast(pl.List(pl.UInt32))
    .alias("trials"),
    # `labels`: without trial-level labels, True iff every default trial is valid;
    # otherwise True iff the trial-level label count matches the valid-trial count.
    pl.when(pl.col("trial_count") == -1)
    .then(pl.col("valid_trial_cnt").eq(N_TRIALS_PER_BLOCK))
    .otherwise(pl.col("trial_count") == pl.col("valid_trial_cnt"))
    .alias("labels"),
)

In [55]:
# Blocks whose labels are flagged as incomplete.
PDI_trial_durations.filter(pl.col("labels").not_())

participant_id,session,block,dt_s,dbs_stim,valid_trial_cnt,trials,dbs_stim_right,yscores,trial_count,labels
str,u32,u32,f32,i32,i64,list[u32],list[i32],list[f32],i64,bool
"""PDI3""",4,12,9.018765,1,11,"[1, 2, … 12]",,,-1,False


In [None]:
# Drop the bookkeeping columns and the list-valued duplicate of dbs_stim that
# came from joining the trial-level labels.
PDI_trial_durations = PDI_trial_durations.drop(
    ["valid_trial_cnt", "trial_count", "dbs_stim_right"]
)

In [None]:
# Attach durations and labels to each recorded block (left join keeps unmatched blocks).
participants = participants.join(
    other=PDI_trial_durations,
    on=["participant_id", "session", "block"],
    how="left",
)

In [59]:
# Recorded blocks whose dbs_stim label is missing after the left join.
missing_dbs_stim = pl.col("dbs_stim").is_null()
participants.filter(missing_dbs_stim)

participant_id,session_path,session,onsets,ieeg_parquet,block,dt_s,dbs_stim,trials,yscores,labels
str,str,u32,list[f64],str,u32,f32,i32,list[u32],list[f32],bool
"""PDI4""","""/home/bobby/repos/latent-neura…",3,"[10.723909, 25.000409, … 190.073364]","""/home/bobby/repos/latent-neura…",9,,,,,
"""PDI3""","""/home/bobby/repos/latent-neura…",4,"[12.854636, 30.407727, … 268.204773]","""/home/bobby/repos/latent-neura…",2,,,,,
"""PDI3""","""/home/bobby/repos/latent-neura…",4,"[62.370591, 84.879727, … 269.472682]","""/home/bobby/repos/latent-neura…",1,,,,,
"""PDI3""","""/home/bobby/repos/latent-neura…",4,"[12.227955, 29.381682, … 199.263364]","""/home/bobby/repos/latent-neura…",3,,,,,


In [60]:
# Keep only blocks that received a dbs_stim label.
participants = participants.drop_nulls(subset="dbs_stim")

In [61]:
# With partition_by set, polars writes a hive-partitioned dataset, so
# "participants.parquet" is a directory of participant_id=.../session=.../ files.
participants.write_parquet(
    saved_dir.joinpath("participants.parquet"),
    partition_by=["participant_id", "session"],
)

In [9]:
# Load one hive partition (PDI1, session 2) of the trial-level resampled recordings.
# The path is derived from ROOT_DIR instead of repeating the absolute checkout path.
# NOTE(review): this rebinds `participants` to a different, trial-level dataset
# (one row per trial with signal columns), not the block table written above.
participants = pl.read_parquet(
    str(
        ROOT_DIR
        / "resampled_recordings"
        / "participants"
        / "participant_id=PDI1"
        / "session=2"
        / "*"
    )
)

In [10]:
# Display the loaded trial-level partition (rendered truncated by polars).
participants

participant_id,session,block,trial,onset,duration,time,start_ts,trial_length_ts,chunk_margin,dbs_stim,yscore,LFP_1,LFP_2,LFP_3,LFP_4,LFP_5,LFP_6,LFP_7,LFP_8,LFP_9,LFP_10,LFP_11,LFP_12,LFP_13,LFP_14,LFP_15,LFP_16,ECOG_1,ECOG_2,ECOG_3,ECOG_4,x,y,tracing_coordinates_present
str,u32,u32,u32,f64,f32,list[f64],u32,u32,i32,i32,f32,list[f32],list[f32],list[f32],list[f32],list[f32],list[f32],list[f32],list[f32],list[f32],list[f32],list[f32],list[f32],list[f32],list[f32],list[f32],list[f32],list[f32],list[f32],list[f32],list[f32],list[i64],list[i64],bool
"""PDI1""",2,1,1,16.102864,11.019526,"[16.102864, 16.103864, … 27.120864]",16102,11019,2,0,2.703554,"[0.00001, 0.000007, … -0.000001]","[0.000009, 0.000005, … -0.000001]","[0.000027, 0.000002, … -0.00003]","[0.000006, 0.000079, … 0.000089]","[0.000004, 3.8836e-7, … -0.000002]","[-0.000028, -0.000034, … -0.000048]","[-0.000022, -0.000053, … -0.000019]","[0.000005, 0.000002, … 0.000002]","[0.000006, 2.1811e-7, … -0.000001]","[5.0463e-7, -0.000005, … 0.000002]","[0.000002, -0.000002, … -5.1700e-7]","[0.000002, -0.000002, … -2.2884e-7]","[0.000005, -7.1312e-7, … 6.5943e-7]","[0.000004, -0.000002, … 0.000003]","[0.000004, -0.000002, … -0.000003]","[0.000004, -0.000001, … -4.0172e-7]","[-0.000019, -0.000023, … 0.000012]","[-0.000017, -0.000018, … 0.000013]","[-0.000015, -0.000015, … 0.00001]","[-0.000013, -0.00001, … 0.000007]","[-15, -113, … -360]","[217, 45, … 47]",true
"""PDI1""",2,1,2,34.896455,11.019526,"[34.896455, 34.897455, … 45.914455]",34896,11019,2,0,1.811423,"[0.000001, -6.8187e-7, … 6.2975e-7]","[-0.000005, -0.000007, … -0.000005]","[0.000009, -0.000004, … -0.000032]","[5.1902e-7, -0.00005, … 0.000014]","[-0.000002, -0.000005, … -0.000006]","[0.00001, 0.00001, … -0.000009]","[-0.000022, -0.000035, … -0.000036]","[-0.000005, -0.000007, … -8.7612e-8]","[-4.5056e-7, -0.000002, … -0.000009]","[0.000001, -0.000001, … -0.000007]","[-0.000003, -0.000004, … -0.000004]","[-0.000005, -0.000007, … -0.000003]","[-0.000002, -0.000005, … -0.000003]","[-0.000004, -0.000006, … -0.000002]","[-0.000012, -0.000014, … 0.000004]","[-0.000006, -0.000008, … 6.7189e-9]","[0.000008, 0.000004, … -0.000005]","[0.000002, -0.000002, … -0.000007]","[0.000004, 2.4216e-7, … -0.00001]","[0.000008, 0.000007, … -0.000009]","[-6, -23, … -301]","[48, -6, … 24]",true
"""PDI1""",2,1,3,53.729091,11.019526,"[53.729091, 53.730091, … 64.747091]",53729,11019,2,0,2.518298,"[0.000004, 0.000006, … -0.000015]","[0.000004, 0.000005, … -0.000027]","[0.00005, 0.000018, … -0.000042]","[-0.000046, -0.00002, … -0.000041]","[0.000004, 0.000006, … -0.000023]","[0.000036, 0.000017, … -0.000023]","[0.000032, 0.000026, … -0.000038]","[0.000005, 0.000006, … -0.000026]","[0.000012, 0.000013, … -0.000019]","[0.000013, 0.000013, … -0.000016]","[0.000014, 0.000015, … -0.000023]","[0.00001, 0.000011, … -0.000026]","[0.000016, 0.000017, … -0.000028]","[-0.000005, -0.000004, … -0.000032]","[-0.000005, -0.000004, … -0.000043]","[0.000008, 0.000009, … -0.000029]","[0.000014, 0.000012, … 0.000009]","[0.000021, 0.000021, … 0.000016]","[0.000017, 0.000017, … 0.000014]","[0.000012, 0.000012, … 0.00001]","[-281, -209, … 3]","[-1, 110, … -193]",true
"""PDI1""",2,1,4,70.946727,11.019526,"[70.946727, 70.947727, … 81.964727]",70946,11019,2,0,2.052713,"[-0.000001, 0.000003, … 0.00002]","[-0.000005, -9.3995e-7, … 0.000018]","[-0.000016, -0.000022, … -0.000024]","[-0.000067, -0.000082, … -0.000057]","[-0.000002, 0.000001, … 0.000021]","[-0.000024, -0.000007, … 0.000027]","[0.000002, 0.000026, … 0.00002]","[-0.000003, 1.3734e-7, … 0.000019]","[-0.000008, -0.000004, … 0.000006]","[-0.000006, -0.000002, … 0.000003]","[4.5071e-7, 0.000004, … 0.000009]","[-0.000004, -6.8121e-7, … 0.000006]","[-0.000004, -0.000001, … 0.000007]","[-0.000005, -0.000003, … 0.00001]","[-0.000016, -0.000013, … 0.000006]","[-0.000004, -0.000001, … 0.00001]","[-0.000003, -0.000013, … -0.000015]","[-0.000007, -0.000018, … -0.000007]","[-0.000006, -0.000015, … -0.000002]","[-0.000013, -0.000021, … -0.000004]","[-56, 41, … -199]","[-175, 132, … -168]",true
"""PDI1""",2,1,5,89.455591,11.019526,"[89.455591, 89.456591, … 100.473591]",89455,11019,2,0,0.443145,"[0.000009, 0.000001, … -0.000025]","[0.000004, -0.000004, … -0.000023]","[-0.000008, -0.000029, … -0.000048]","[0.000063, -0.000077, … -0.000084]","[0.000004, -0.000003, … -0.000023]","[-0.000031, -0.000035, … 0.00002]","[0.000017, 0.000022, … -0.000005]","[0.000006, 8.9125e-8, … -0.000029]","[0.000008, 0.000003, … -0.000027]","[0.000012, 0.000007, … -0.000029]","[0.00001, 0.000005, … -0.000026]","[0.000008, 0.000001, … -0.000025]","[0.000012, 0.000007, … -0.000027]","[0.000003, -0.000002, … -0.000026]","[-0.000004, -0.00001, … -0.000015]","[0.000009, 0.000003, … -0.000025]","[0.000002, 0.000003, … 0.000005]","[0.000008, 0.000009, … -9.8528e-7]","[0.000009, 0.000009, … -0.000008]","[0.000008, 0.000007, … -0.000012]","[-277, -291, … -221]","[-85, -205, … 12]",true
…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…
"""PDI1""",2,12,8,134.893955,11.018443,"[134.893955, 134.894955, … 145.910955]",134893,11018,2,1,1.339005,"[0.000706, 0.000113, … 0.0]","[0.000015, 0.000009, … 5.4210e-20]","[-0.00001, -4.9192e-8, … 6.7763e-21]","[0.000267, 0.000046, … 0.0]","[6.3018e-8, -2.6153e-8, … 7.3017e-20]","[0.000403, 0.00001, … 0.0]","[-0.000063, -0.000059, … 5.0822e-21]","[0.00097, 0.000235, … 0.0]","[-0.000004, -0.000002, … 1.0164e-20]","[-0.000006, -0.000003, … -2.0329e-20]","[-0.000009, -0.000007, … -5.4210e-20]","[-0.000006, -0.000003, … -5.0822e-21]","[-0.000002, 0.000001, … 1.0164e-20]","[-0.000003, -4.5280e-7, … 0.0]","[-0.000009, -0.000007, … -6.7763e-21]","[-0.000005, -0.000002, … 1.0164e-20]","[0.000002, -0.000006, … -2.5411e-21]","[0.000006, -0.000001, … -6.7763e-21]","[0.000011, 0.000003, … -1.1858e-20]","[0.000017, 0.000006, … -4.2352e-21]","[-296, -222, … -293]","[-31, -136, … 119]",true
"""PDI1""",2,12,9,150.167591,11.018443,"[150.167591, 150.168591, … 161.184591]",150167,11018,2,1,2.221776,"[-0.000121, 0.000795, … 0.0]","[0.000013, 0.000014, … 5.4210e-20]","[0.000028, 0.000013, … 6.7763e-21]","[-0.000088, 0.000068, … 0.0]","[-9.2487e-9, 3.1576e-7, … 7.3017e-20]","[-0.000295, 0.000355, … 0.0]","[0.000024, 0.000003, … 5.0822e-21]","[-0.001348, 0.00062, … 0.0]","[0.000007, 0.000011, … 1.0164e-20]","[0.000018, 0.000021, … -2.0329e-20]","[0.000003, 0.000006, … -5.4210e-20]","[-0.000002, 7.5282e-7, … -5.0822e-21]","[0.000004, 0.000006, … 1.0164e-20]","[0.000034, 0.000036, … 0.0]","[-0.000006, -0.000004, … -6.7763e-21]","[0.000013, 0.000014, … 1.0164e-20]","[0.000009, 0.000006, … -2.5411e-21]","[0.000004, 3.4474e-7, … -6.7763e-21]","[0.000006, 0.000004, … -1.1858e-20]","[0.000011, 0.000007, … -4.2352e-21]","[-208, -42, … -17]","[-54, 200, … -146]",true
"""PDI1""",2,12,10,167.065227,11.018443,"[167.065227, 167.066227, … 178.082227]",167065,11018,2,1,-0.696225,"[0.00048, -0.000111, … 0.0]","[-0.000003, -0.000007, … 5.4210e-20]","[-0.000035, -0.00004, … 6.7763e-21]","[0.000264, 0.000042, … 0.0]","[-1.3278e-7, 2.9667e-10, … 7.3017e-20]","[0.000527, 0.000248, … 0.0]","[-0.000007, 0.000022, … 5.0822e-21]","[0.00096, 0.000974, … 0.0]","[-0.000003, -0.000005, … 1.0164e-20]","[-2.8684e-7, -0.000002, … -2.0329e-20]","[0.000002, -1.6764e-7, … -5.4210e-20]","[-0.000003, -0.000006, … -5.0822e-21]","[-0.000003, -0.000005, … 1.0164e-20]","[-0.000008, -0.000011, … 0.0]","[-0.000004, -0.000007, … -6.7763e-21]","[0.000004, 0.000002, … 1.0164e-20]","[0.000007, 0.000004, … -2.5411e-21]","[-0.000006, -0.000007, … -6.7763e-21]","[-0.00001, -0.00001, … -1.1858e-20]","[-0.000018, -0.000018, … -4.2352e-21]","[-70, 15, … -239]","[120, 190, … 215]",true
"""PDI1""",2,12,11,182.940955,11.018443,"[182.940955, 182.941955, … 193.957955]",182940,11018,2,1,-0.450917,"[-0.000564, 0.000683, … 0.0]","[0.000047, 0.000014, … 5.4210e-20]","[0.000014, -0.000007, … 6.7763e-21]","[-0.000349, 0.000079, … 0.0]","[1.9035e-7, 2.9359e-7, … 7.3017e-20]","[-0.000929, 0.000207, … 0.0]","[0.00001, 0.000003, … 5.0822e-21]","[-0.001829, 0.000117, … 0.0]","[-0.00003, -0.00002, … 1.0164e-20]","[-0.000026, -0.000019, … -2.0329e-20]","[-0.000037, -0.000026, … -5.4210e-20]","[-0.000033, -0.000023, … -5.0822e-21]","[-0.000021, -0.000013, … 1.0164e-20]","[-0.000047, -0.000039, … 0.0]","[-0.000016, -0.000007, … -6.7763e-21]","[-0.00003, -0.000022, … 1.0164e-20]","[-0.000019, -0.000022, … -2.5411e-21]","[-0.000011, -0.000013, … -6.7763e-21]","[-0.000009, -0.000008, … -1.1858e-20]","[-0.000013, -0.000012, … -4.2352e-21]","[-313, -49, … -233]","[-61, 182, … 218]",true
