In [19]:
import torch
import polars as pl
import pandas as pd
import numpy as np
from torch import nn
from torch.utils.data import DataLoader, Dataset

import lightning as L

In [None]:
class EEGDataset(Dataset):
    """Minimal map-style dataset that pairs ``data[i]`` with ``labels[i]``."""

    def __init__(self, data, labels):
        """Store the two indexable containers without copying or validating them."""
        self.data, self.labels = data, labels

    def __len__(self):
        """Return the number of samples, taken from ``data`` only."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the ``(sample, label)`` pair found at position ``idx``."""
        sample = self.data[idx]
        label = self.labels[idx]
        return sample, label

In [20]:
# Read the training label table from disk and preview its first rows.
train = pl.read_csv("../data/train.csv")
train.head()

eeg_id,eeg_sub_id,eeg_label_offset_seconds,spectrogram_id,spectrogram_sub_id,spectrogram_label_offset_seconds,label_id,patient_id,expert_consensus,seizure_vote,lpd_vote,gpd_vote,lrda_vote,grda_vote,other_vote
i64,i64,f64,i64,i64,f64,i64,i64,str,i64,i64,i64,i64,i64,i64
1628180742,0,0.0,353733,0,0.0,127492639,42516,"""Seizure""",3,0,0,0,0,0
1628180742,1,6.0,353733,1,6.0,3887563113,42516,"""Seizure""",3,0,0,0,0,0
1628180742,2,8.0,353733,2,8.0,1142670488,42516,"""Seizure""",3,0,0,0,0,0
1628180742,3,18.0,353733,3,18.0,2718991173,42516,"""Seizure""",3,0,0,0,0,0
1628180742,4,24.0,353733,4,24.0,3080632009,42516,"""Seizure""",3,0,0,0,0,0


In [21]:
# Names of the six per-class vote-count columns of `train`.
# Later cells normalise these columns and return them as the targets.
target_cols = [
    "seizure_vote",
    "lpd_vote",
    "gpd_vote",
    "lrda_vote",
    "grda_vote",
    "other_vote",
]

In [23]:
# Number of distinct patient_id values in the training table.
train["patient_id"].n_unique()

1950

In [28]:
# Rows per patient_id, largest first.
train["patient_id"].value_counts().sort(by="count", descending=True)

patient_id,count
i64,u32
30631,2215
2641,2185
35627,1403
28330,1362
54199,1350
35225,1288
57378,1165
38549,1012
56450,952
32481,942


In [29]:
# Rows per eeg_id, largest first.
train["eeg_id"].value_counts().sort(by="count", descending=True)

eeg_id,count
i64,u32
2259539799,743
2428433259,664
1641054670,562
2860052642,534
525664301,531
1712056492,433
1480985066,416
188361788,412
3525185677,286
1596590162,275


In [31]:
# Mean / min / max of every vote column over the rows of one eeg_id.
# A column whose min equals its max holds the same vote count on all of those rows.
train.filter(pl.col("eeg_id") == 2259539799).select(
    [pl.col(vote_col).mean().alias(f"{vote_col}_mean") for vote_col in target_cols]
    + [pl.col(vote_col).min().alias(f"{vote_col}_min") for vote_col in target_cols]
    + [pl.col(vote_col).max().alias(f"{vote_col}_max") for vote_col in target_cols]
)

seizure_vote_mean,lpd_vote_mean,gpd_vote_mean,lrda_vote_mean,grda_vote_mean,other_vote_mean,seizure_vote_min,lpd_vote_min,gpd_vote_min,lrda_vote_min,grda_vote_min,other_vote_min,seizure_vote_max,lpd_vote_max,gpd_vote_max,lrda_vote_max,grda_vote_max,other_vote_max
f64,f64,f64,f64,f64,f64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64
5.0,0.0,10.0,0.0,0.0,0.0,5,0,10,0,0,0,5,0,10,0,0,0


In [32]:
# Mean / min / max of every vote column over the rows of one eeg_id.
# A column whose min equals its max holds the same vote count on all of those rows.
train.filter(pl.col("eeg_id") == 2428433259).select(
    [pl.col(vote_col).mean().alias(f"{vote_col}_mean") for vote_col in target_cols]
    + [pl.col(vote_col).min().alias(f"{vote_col}_min") for vote_col in target_cols]
    + [pl.col(vote_col).max().alias(f"{vote_col}_max") for vote_col in target_cols]
)

seizure_vote_mean,lpd_vote_mean,gpd_vote_mean,lrda_vote_mean,grda_vote_mean,other_vote_mean,seizure_vote_min,lpd_vote_min,gpd_vote_min,lrda_vote_min,grda_vote_min,other_vote_min,seizure_vote_max,lpd_vote_max,gpd_vote_max,lrda_vote_max,grda_vote_max,other_vote_max
f64,f64,f64,f64,f64,f64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64
0.0,0.0,0.0,0.0,13.0,2.0,0,0,0,0,13,2,0,0,0,0,13,2


In [33]:
# Mean / min / max of every vote column over the rows of one eeg_id.
# A column whose min equals its max holds the same vote count on all of those rows.
train.filter(pl.col("eeg_id") == 952904038).select(
    [pl.col(vote_col).mean().alias(f"{vote_col}_mean") for vote_col in target_cols]
    + [pl.col(vote_col).min().alias(f"{vote_col}_min") for vote_col in target_cols]
    + [pl.col(vote_col).max().alias(f"{vote_col}_max") for vote_col in target_cols]
)

seizure_vote_mean,lpd_vote_mean,gpd_vote_mean,lrda_vote_mean,grda_vote_mean,other_vote_mean,seizure_vote_min,lpd_vote_min,gpd_vote_min,lrda_vote_min,grda_vote_min,other_vote_min,seizure_vote_max,lpd_vote_max,gpd_vote_max,lrda_vote_max,grda_vote_max,other_vote_max
f64,f64,f64,f64,f64,f64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64,i64
0.0,0.0,0.0,3.0,0.0,0.0,0,0,0,3,0,0,0,0,0,3,0,0


In [39]:
# Keep only the vote columns plus eeg_id and drop duplicate rows, leaving one
# row per distinct (eeg_id, vote vector) combination.
tmp = train.select(target_cols + ["eeg_id"]).unique()
tmp.shape

(20183, 7)

In [66]:
# Number of distinct vote vectors per eeg_id, largest first (top 100).
tmp["eeg_id"].value_counts().sort(by="count", descending=True).head(100)

eeg_id,count
i64,u32
188361788,43
3246176805,30
1460778765,29
1128738777,26
833106986,22
216551560,19
553638140,17
1378468467,17
1062096035,17
1954809341,15


In [56]:
# Raise pandas' row-display limit (a global option, so it stays in effect for
# later cells) and show every row recorded for one eeg_id.
pd.set_option("display.max_rows", 450)
train.filter(pl.col("eeg_id") == 653687449).to_pandas()

Unnamed: 0,eeg_id,eeg_sub_id,eeg_label_offset_seconds,spectrogram_id,spectrogram_sub_id,spectrogram_label_offset_seconds,label_id,patient_id,expert_consensus,seizure_vote,lpd_vote,gpd_vote,lrda_vote,grda_vote,other_vote
0,653687449,0,0.0,683350709,0,0.0,1757843188,45536,GPD,5,0,7,0,0,2
1,653687449,1,2.0,683350709,1,2.0,2102025444,45536,GPD,5,0,7,0,0,2
2,653687449,2,6.0,683350709,2,6.0,3585042465,45536,GPD,5,0,7,0,0,2
3,653687449,3,10.0,683350709,3,10.0,2213345486,45536,GPD,5,0,7,0,0,2
4,653687449,4,20.0,683350709,4,20.0,3052998559,45536,GPD,5,0,7,0,0,2
5,653687449,5,22.0,683350709,5,22.0,1276719223,45536,GPD,5,0,7,0,0,2
6,653687449,6,26.0,683350709,6,26.0,3599136675,45536,GPD,5,0,7,0,0,2
7,653687449,7,28.0,683350709,7,28.0,2403855604,45536,GPD,5,0,7,0,0,2
8,653687449,8,34.0,683350709,8,34.0,4091879711,45536,GPD,5,0,7,0,0,2
9,653687449,9,36.0,683350709,9,36.0,3737790377,45536,GPD,5,0,7,0,0,2


In [70]:
# Show every row belonging to one patient_id (as a pandas frame).
train.filter(pl.col("patient_id") == 56885).to_pandas()

Unnamed: 0,eeg_id,eeg_sub_id,eeg_label_offset_seconds,spectrogram_id,spectrogram_sub_id,spectrogram_label_offset_seconds,label_id,patient_id,expert_consensus,seizure_vote,lpd_vote,gpd_vote,lrda_vote,grda_vote,other_vote
0,722738444,0,0.0,999431,0,0.0,557980729,56885,LRDA,0,1,0,14,0,1
1,722738444,1,2.0,999431,1,2.0,1949834128,56885,LRDA,0,1,0,14,0,1
2,722738444,2,4.0,999431,2,4.0,3790867376,56885,LRDA,0,1,0,14,0,1
3,722738444,3,6.0,999431,3,6.0,2641122017,56885,LRDA,0,1,0,14,0,1
4,722738444,4,8.0,999431,4,8.0,1991146353,56885,LRDA,0,1,0,14,0,1
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
727,3525185677,281,708.0,2011177737,281,708.0,1139532324,56885,LRDA,0,2,0,13,0,0
728,3525185677,282,710.0,2011177737,282,710.0,1840764203,56885,LRDA,0,2,0,13,0,0
729,3525185677,283,714.0,2011177737,283,714.0,2775502362,56885,LRDA,0,2,0,13,0,0
730,3525185677,284,716.0,2011177737,284,716.0,3596412880,56885,LRDA,0,2,0,13,0,0


In [68]:
# Rows per patient_id, without sorting.
train["patient_id"].value_counts()

patient_id,count
i64,u32
65026,25
57448,11
56308,59
54455,36
29885,11
25107,8
52421,32
15064,60
65430,46
39877,35


In [90]:
# Observations:
# Same EEG has mostly same labels
# Only 100 of 17089 EEGs have more than 5 labels
# Only 1000 have more than 2 different labellings

# As such it would make sense to only sample EEG with unique target once per epoch while training

# Samples need to be weighted by the number of labelers

# Only one sample per patient per epoch. Sample EEG of patient with a probability inversely proportional to the number of EEGs of that patient

# Patients with multiple diagnoses are difficult cases and need a closer look

# 1100/1950 have more than 1 label, 400 have more than 2 labels, 130 have more than 3 labels
# Those 100 patients need to be looked at more closely

In [88]:
# Number of distinct expert_consensus values per patient (stored in a column
# named "count"), highest first, limited to the top 140 patients.
train.group_by("patient_id").agg(
    pl.col("expert_consensus").n_unique().alias("count")
).sort(by="count", descending=True).head(140)

patient_id,count
i64,u32
44615,6
35627,6
6935,6
17948,6
58475,6
65356,6
6284,6
27986,6
13252,5
21145,5


In [91]:
# Divide each row's vote counts by that row's total so they become vote
# fractions, then attach them to `train` as `norm_<column>` columns.
# The raw vote columns are kept alongside the normalised ones.
norm_targets = train.select(target_cols).to_numpy()
norm_targets = np.true_divide(norm_targets, norm_targets.sum(axis=1, keepdims=True))
train = train.with_columns(
    [
        pl.Series(f"norm_{vote_col}", fractions)
        for vote_col, fractions in zip(target_cols, norm_targets.T)
    ]
)

In [92]:
# Average the normalised votes within each patient first, then average those
# per-patient means, so each patient contributes one value per column
# regardless of how many rows it has.
train.group_by("patient_id").agg(
    *[pl.col("norm_" + target).mean().alias(f"{target}") for target in target_cols],
).select(
    [pl.col(target).mean().alias(f"{target}_mean") for target in target_cols],
)

seizure_vote_mean,lpd_vote_mean,gpd_vote_mean,lrda_vote_mean,grda_vote_mean,other_vote_mean
f64,f64,f64,f64,f64,f64
0.34483,0.039809,0.047374,0.081013,0.234644,0.252329


In [93]:
# Dataset

# For training, we only take 1 sample per patient
# For validation, we take all samples but we average the score per patient first and then average across all patients

In [105]:
# Per-row sampling weights, used to draw one training row per patient.
#
# 1. patient_sample_count: number of rows belonging to the row's patient.
# 2. Raw weight = 1 - (rows sharing this eeg_id AND this exact vote vector)
#    / patient_sample_count, so rows from heavily repeated (EEG, votes)
#    groups are down-weighted.
# 3. The raw weights are normalised to sum to 1 within each patient.
#
# When all rows of a patient fall into a single (eeg_id, votes) group, every
# raw weight is 1 - n/n = 0 and the normalisation would be 0/0 = NaN, which
# `random.choices` cannot sample from. Those patients fall back to uniform
# weights 1/n instead.
train = (
    train.with_columns(
        pl.len().over("patient_id").alias("patient_sample_count"),
    )
    .with_columns(
        (
            1
            - pl.len().over(["eeg_id", *target_cols])
            / pl.col("patient_sample_count")
        ).alias("sample_weight"),
    )
    .with_columns(
        pl.when(pl.col("sample_weight").sum().over("patient_id") > 0)
        .then(
            pl.col("sample_weight")
            / pl.col("sample_weight").sum().over("patient_id")
        )
        .otherwise(1 / pl.col("patient_sample_count"))
        .alias("sample_weight")
    )
)

In [153]:
#
import random


def load_eeg_data(eeg_id, offset, sample_rate=200, window_seconds=50):
    """Load a fixed-length window of one EEG recording as a NumPy array.

    The whole file ``../data/train_eegs/{eeg_id}.parquet`` is read on every
    call and then sliced by row position.

    Args:
        eeg_id: Identifier used to build the parquet file name.
        offset: Start of the window in seconds. It is multiplied by
            ``sample_rate`` and truncated to an integer row index.
        sample_rate: Rows per second of recording. Defaults to 200, the value
            that used to be hard-coded here.
        window_seconds: Length of the window in seconds. Defaults to 50, the
            value that used to be hard-coded here.

    Returns:
        A 2-D array with one row per sample and one column per parquet column.
        It has ``sample_rate * window_seconds`` rows, or fewer when the file
        ends before the window does (positional slicing clamps at the end).
    """
    df = pd.read_parquet(f"../data/train_eegs/{eeg_id}.parquet")
    start = int(offset * sample_rate)
    stop = start + int(sample_rate * window_seconds)
    return df.iloc[start:stop].to_numpy()


class HMSTrainEEGData(Dataset):
    """Training dataset that yields one randomly drawn row per patient.

    Each index maps to one entry of ``patient_ids``. On every access a row of
    that patient is drawn at random using the ``sample_weight`` column of
    ``df``, so repeated accesses of the same index may return different rows.

    Args:
        patient_ids: Indexable collection of patient ids; its length is the
            dataset length.
        df: polars DataFrame containing at least ``patient_id``,
            ``sample_weight``, ``eeg_id``, ``eeg_label_offset_seconds`` and
            the columns listed in ``target_cols``.
    """

    def __init__(self, patient_ids, df):
        self.patient_ids = patient_ids
        self.df = df

    def __len__(self):
        return len(self.patient_ids)

    def __getitem__(self, idx):
        """Return ``(data, targets)`` for a randomly chosen row of one patient.

        ``data`` is whatever ``load_eeg_data`` returns for the chosen row.
        ``targets`` is a one-row pandas DataFrame of the ``target_cols``.
        """
        patient_id = self.patient_ids[idx]
        # Scans the full frame on every access.
        patient_rows = self.df.filter(pl.col("patient_id") == patient_id)

        # Weights that are NaN or sum to zero cannot be sampled from, so fall
        # back to uniform sampling (weights=None) instead of failing.
        weights = patient_rows["sample_weight"].to_numpy()
        weight_total = weights.sum()
        if not np.isfinite(weight_total) or weight_total <= 0:
            weights = None
        # A separate name keeps the caller's `idx` (the patient position) intact.
        row_idx = random.choices(range(len(patient_rows)), weights=weights)[0]

        row = patient_rows[row_idx].to_pandas()
        eeg_id = row["eeg_id"].iloc[0]
        offset = row["eeg_label_offset_seconds"].iloc[0]
        data = load_eeg_data(eeg_id, offset)
        # NOTE(review): targets is a pandas DataFrame; torch's default collate
        # function does not accept DataFrames, so convert before batching.
        targets = row[target_cols]
        return data, targets


class HMSValEEGData(Dataset):
    """Dataset that exposes every row of ``df`` as one ``(data, targets)`` item."""

    def __init__(self, df):
        """Keep a reference to the frame; nothing is copied or precomputed."""
        self.df = df

    def __len__(self):
        """One item per row of the wrapped frame."""
        return len(self.df)

    def __getitem__(self, idx):
        """Load the EEG window for row ``idx`` and return it with its target row."""
        row = self.df[idx].to_pandas()
        eeg_window = load_eeg_data(
            row["eeg_id"].iat[0],
            row["eeg_label_offset_seconds"].iat[0],
        )
        return eeg_window, row[target_cols]

In [148]:
# Training dataset with one entry per distinct patient.
# maintain_order=True keeps patients in first-appearance order, so an index
# maps to the same patient on every run; plain unique() gives no such guarantee.
data = HMSTrainEEGData(
    train["patient_id"].unique(maintain_order=True).to_numpy(), train
)

In [151]:
# Smoke test: fetch one item and unpack it into the EEG array and its target row.
x, y = data[0]

In [154]:
# Dataset length equals the number of patient ids passed in.
len(data)

1950

In [155]:
# Row-level dataset over the whole `train` frame (no patient split applied here).
datav1 = HMSValEEGData(train)

In [156]:
# Dataset length equals the number of rows in `train`.
len(datav1)

106800