In [129]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [130]:
import src.eeg_proj.__main__ as ep
from src.eeg_proj.leapd import LeapdModel
import numpy as np
import numpy.typing as npt
import pandas as pd

In [131]:
%%capture
rng = np.random.RandomState(420)
patients = ep.load_data("ds004584/")

In [132]:
df = pd.DataFrame()
d = []
for p in patients:
    y, x = p.raw.get_data(picks=ep.ELECTRODES, return_times=True)
    y_normalized = [raw / np.abs(raw).max() for raw in y]
    y_bandpassed = [ep.band_pass(norm, highcut=14) for norm in y_normalized]
    tmp = {
        "Group": p.group,
        "Moca": p.moca,
        "time": x,
        **dict(zip(ep.ELECTRODES, y_bandpassed)),
    }
    d.append(tmp)
df = pd.DataFrame(d)

### K-fold Cross-Validation

In [137]:
from sklearn.model_selection import StratifiedKFold

y = pd.get_dummies(df[["Group", "Moca"]])
y_test_group = y["Group_Control"]
y_test_moca_ = y["Moca_COGNITIVE NORMAL"]
x = df[ep.ELECTRODES]

skf = StratifiedKFold(5, shuffle=True, random_state=rng)

for train_index, test_index in skf.split(np.zeros(len(y_test_group)), y_test_group):
    for electrode in ep.ELECTRODES:
        x_train = df.iloc[train_index]  # [electrode].to_numpy()  # .loc[electrode]
        x_parkinsons = x_train.loc[df['Group'] == "PD", electrode]
        x_control = x_train.loc[df['Group'] == "Control", electrode]
        model = LeapdModel(x_parkinsons[:], x_control[:], 4)

        pd_test = []
        control_test = []
        for i in test_index:
            test_value = df.iloc[i]
            group = test_value["Group"]
            test_x = test_value[electrode]
            rho = model.classify(test_x)
            if group == "PD":
                pd_test.append(rho > 0.5)
            else:
                control_test.append(rho < 0.5)
            # print(rho)
        n_pd, n_c = len(pd_test), len(control_test)
        c_pd = np.sum(pd_test)
        c_c = np.sum(control_test)
        print(f"Total PD {n_pd} - Correct PD {c_pd}")
        print(f"Total Control {n_c} - Correct Control {c_c}")
        accuracy = (c_pd + c_c) / (n_pd + n_c)
        print(f"Acc: {accuracy}")

Total PD 20 - Correct PD 20
Total Control 10 - Correct Control 0
Acc: 0.6666666666666666
Total PD 20 - Correct PD 0
Total Control 10 - Correct Control 10
Acc: 0.3333333333333333
Total PD 20 - Correct PD 18
Total Control 10 - Correct Control 0
Acc: 0.6
Total PD 20 - Correct PD 20
Total Control 10 - Correct Control 0
Acc: 0.6666666666666666
Total PD 20 - Correct PD 0
Total Control 10 - Correct Control 9
Acc: 0.3
Total PD 20 - Correct PD 20
Total Control 10 - Correct Control 0
Acc: 0.6666666666666666
Total PD 20 - Correct PD 0
Total Control 10 - Correct Control 10
Acc: 0.3333333333333333
Total PD 20 - Correct PD 19
Total Control 10 - Correct Control 0
Acc: 0.6333333333333333
Total PD 20 - Correct PD 20
Total Control 10 - Correct Control 0
Acc: 0.6666666666666666
Total PD 20 - Correct PD 19
Total Control 10 - Correct Control 0
Acc: 0.6333333333333333
Total PD 20 - Correct PD 1
Total Control 10 - Correct Control 9
Acc: 0.3333333333333333
Total PD 20 - Correct PD 20
Total Control 10 - Correc

In [134]:
len(pd_test)
len(control_test)

9