In [69]:
import pandas as pd
import numpy as np
from config import *
from xgboost import XGBClassifier
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import confusion_matrix
from sklearn.pipeline import Pipeline
from sklearn.model_selection import GridSearchCV
import matplotlib.pyplot as plt
import seaborn as sns

In [70]:
df = pd.read_parquet(DATA_DIR / "anti_saccade_processed.pq")

In [104]:
n_correct_trials_df = (df
 .query("stimulus_active == True")
 .sort_values(by=["participant_id", "trial_id", "stand_time"])
 .assign(stimulus_time = lambda x: np.select([x.event == "FIXPOINT", x.event != "FIXPOINT"], [x.stand_time, None]))
 .ffill()
 .assign(saccade_direction = lambda x: np.where(x["sacc_end_x"] > x["sacc_start_x"], "right", "left"))
 .assign(is_trial_correct = lambda x: np.where(x["saccade_direction"] != x["stimulus_side"], True, False))
 .query("event == 'ESACC'")
 .groupby(["experiment","participant_id", "trial_id"])
 .first()
 .reset_index()
 .groupby(["experiment","participant_id"])
 .agg(n_correct_trials = ('is_trial_correct', 'sum'),
      n_trials = ('is_trial_correct', 'count'))
 .reset_index()
 .assign(prop_correct_trials = lambda x: x["n_correct_trials"] / x["n_trials"])
 [["experiment", "participant_id", "prop_correct_trials"]]
)

  .ffill()


In [None]:

reaction_time_df = (df
 .query("stimulus_active == True")
 .sort_values(by=["participant_id", "trial_id", "stand_time"])
 .assign(stimulus_time = lambda x: np.select([x.event == "FIXPOINT", x.event != "FIXPOINT"], [x.stand_time, None]))
 .ffill()
 .assign(saccade_direction = lambda x: np.where(x["sacc_end_x"] > x["sacc_start_x"], "right", "left"))
 .assign(is_trial_correct = lambda x: np.where(x["saccade_direction"] != x["stimulus_side"], True, False))
 .query("event == 'ESACC'")
 .groupby(["experiment","participant_id", "trial_id", "is_trial_correct"])
 .first()
 .reset_index()
 .assign(reaction_time = lambda x: x["stand_start_time"] - x["stimulus_time"])
 .groupby(["experiment","participant_id","is_trial_correct"])
 .agg(mean_reaction_time = ('reaction_time', 'mean'))
 .reset_index()
 .pivot(index=["experiment", "participant_id"], columns="is_trial_correct",values="mean_reaction_time")
 .reset_index()
 .rename({True: 'correct_reaction_time',
          False: 'incorrect_reaction_time'}, axis=1)
)


  .ffill()


is_trial_correct,experiment,participant_id,incorrect_reaction_time,correct_reaction_time
0,ANTI_SACCADE,106,254.851530,284.381760
1,ANTI_SACCADE,111,210.245131,307.381760
2,ANTI_SACCADE,113,288.445679,160.612346
3,ANTI_SACCADE,121,418.045996,717.879329
4,ANTI_SACCADE,122,851.298322,252.131656
...,...,...,...,...
165,ANTI_SACCADE,399,394.128375,285.631760
166,ANTI_SACCADE,401,523.551108,269.819260
167,ANTI_SACCADE,402,341.257276,303.881760
168,ANTI_SACCADE,403,351.841997,313.756760


In [85]:
features = (df.groupby(["experiment", "participant_id"])
 .agg({'peak_velocity': [np.mean, np.min, np.max, np.median, np.std],
       'amplitude': [np.mean, np.min, np.max, np.median, np.std],
       'duration': [np.mean, np.min, np.max, np.median, np.std],
       'avg_pupil_size': [np.mean, np.min, np.max, np.median, np.std]
       })
 .reset_index()
)
    
features.columns = [''.join(col).strip() for col in features.columns.values]


  .agg({'peak_velocity': [np.mean, np.min, np.max, np.median, np.std],
  .agg({'peak_velocity': [np.mean, np.min, np.max, np.median, np.std],
  .agg({'peak_velocity': [np.mean, np.min, np.max, np.median, np.std],
  .agg({'peak_velocity': [np.mean, np.min, np.max, np.median, np.std],
  .agg({'peak_velocity': [np.mean, np.min, np.max, np.median, np.std],
  .agg({'peak_velocity': [np.mean, np.min, np.max, np.median, np.std],


In [107]:
features = pd.merge(features, reaction_time_df, left_on=["experiment","participant_id"], right_on=["experiment", "participant_id"], how='left')
features = pd.merge(features, n_correct_trials_df, left_on=["experiment","participant_id"], right_on=["experiment", "participant_id"], how='left')

# Load outcome

In [108]:
demographics = pd.read_excel(DATA_DIR / "demographic_info.xlsx")[["ID", "Group"]]

demographics["y"] = (demographics["Group"] == "PATIENT").astype(int)
demographics["participant_id"] = demographics["ID"].astype(str)
demographics = demographics[["participant_id", "y"]]



# Model training

In [136]:
data

Unnamed: 0,experiment,participant_id,peak_velocitymean,peak_velocitymin,peak_velocitymax,peak_velocitymedian,peak_velocitystd,amplitudemean,amplitudemin,amplitudemax,...,avg_pupil_sizemedian,avg_pupil_sizestd,incorrect_reaction_time_x,correct_reaction_time_x,incorrect_reaction_time_y,correct_reaction_time_y,incorrect_reaction_time,correct_reaction_time,prop_correct_trials,y
0,ANTI_SACCADE,106,210.209091,44.0,400.0,215.0,98.088061,4.626727,0.25,16.88,...,1569.0,158.880831,254.851530,284.381760,254.851530,284.381760,254.851530,284.381760,0.875000,1
1,ANTI_SACCADE,111,134.448630,39.0,441.0,94.0,95.472459,2.120034,0.00,18.52,...,1115.5,74.008895,210.245131,307.381760,210.245131,307.381760,210.245131,307.381760,0.375000,0
2,ANTI_SACCADE,113,312.754237,46.0,1215.0,262.0,244.816168,3.723190,0.23,14.43,...,1528.0,147.246028,288.445679,160.612346,288.445679,160.612346,288.445679,160.612346,0.666667,0
3,ANTI_SACCADE,121,326.761062,40.0,2247.0,238.0,332.641599,3.097615,0.23,16.35,...,1825.0,182.124466,418.045996,717.879329,418.045996,717.879329,418.045996,717.879329,0.500000,1
4,ANTI_SACCADE,122,298.308271,44.0,886.0,293.0,174.325678,5.100827,0.20,23.07,...,2349.0,168.067103,851.298322,252.131656,851.298322,252.131656,851.298322,252.131656,0.833333,1
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
165,ANTI_SACCADE,399,143.274510,43.0,491.0,82.0,104.099537,1.876928,0.05,9.79,...,1758.0,211.594525,394.128375,285.631760,394.128375,285.631760,394.128375,285.631760,0.875000,0
166,ANTI_SACCADE,401,318.266667,46.0,2169.0,235.0,345.185988,4.860333,0.19,17.65,...,2761.0,158.323412,523.551108,269.819260,523.551108,269.819260,523.551108,269.819260,0.875000,0
167,ANTI_SACCADE,402,250.290076,34.0,1293.0,110.0,250.869028,3.436260,0.02,17.05,...,1461.0,438.021165,341.257276,303.881760,341.257276,303.881760,341.257276,303.881760,0.937500,0
168,ANTI_SACCADE,403,188.622222,42.0,418.0,175.5,97.970023,4.277889,0.28,12.88,...,2356.0,277.240905,351.841997,313.756760,351.841997,313.756760,351.841997,313.756760,0.875000,0


In [163]:
data = pd.merge(features, demographics, how='left', on='participant_id')
y_data = data["y"]
X_data = data[["correct_reaction_time", "incorrect_reaction_time", "prop_correct_trials"]]
X_data
#X_data = data.drop(["experiment", "participant_id", "y"], axis=1)

Unnamed: 0,correct_reaction_time,incorrect_reaction_time,prop_correct_trials
0,284.381760,254.851530,0.875000
1,307.381760,210.245131,0.375000
2,160.612346,288.445679,0.666667
3,717.879329,418.045996,0.500000
4,252.131656,851.298322,0.833333
...,...,...,...
165,285.631760,394.128375,0.875000
166,269.819260,523.551108,0.875000
167,303.881760,341.257276,0.937500
168,313.756760,351.841997,0.875000


In [164]:
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(X_data, y_data, test_size=.2)

pipe = Pipeline([
    ("scaler", StandardScaler()),
    ("clf", XGBClassifier(n_estimators=10, max_depth=6, learning_rate=1, objective='binary:logistic'))
])

print(pipe.fit(X_train, y_train).score(X_test, y_test))

results=pd.DataFrame()
results['columns']=X_train.columns
results['importances'] = pipe["clf"].feature_importances_
results.sort_values(by='importances',ascending=False,inplace=True)

results

0.6176470588235294


Unnamed: 0,columns,importances
0,correct_reaction_time,0.441323
2,prop_correct_trials,0.28074
1,incorrect_reaction_time,0.277937
