In [11]:
import pandas as pd
import sqlalchemy

# Path of the SQLite file that holds the evaluation tables.
db_name = "seg/seg_evals.sqlite"

# Keyword arguments are spelled out one per line; connect_args is handed to
# the underlying SQLite driver when a connection is opened.
engine = sqlalchemy.create_engine(
    f"sqlite:///{db_name}",
    connect_args={'timeout': 5},
    execution_options={"sqlite_raw_colnames": True},
)


In [12]:

# Read every evaluation table into its own DataFrame, keyed by table name.
table_names = ('evaluation_runs', 'evaluation_status', 'experiments', 'metrics', 'results')
df = {}
for table_name in table_names:
    df[table_name] = pd.read_sql_table(table_name, engine)

In [13]:
# Preview the raw 'results' table as loaded from the database.
df['results']

Unnamed: 0,id,evaluation_id,metric_id,value
0,1,2,1,7.243400e+04
1,2,2,2,
2,3,2,3,8.929600e+04
3,4,2,4,5.330500e+04
4,5,2,5,
...,...,...,...,...
3771,3772,6,468,2.425939e-01
3772,3773,6,469,1.417415e-01
3773,3774,6,470,-inf
3774,3775,6,471,3.578369e-01


In [14]:
# Give the generic 'id' / 'name' columns table-specific names so that each
# lookup table exposes an unambiguous join key and label column.
metrics = df['metrics'].rename(columns={'id': 'metric_id', 'name': 'metric'})
eval_runs = df['evaluation_runs'].rename(columns={'id': 'evaluation_id'})
experiments = df['experiments'].rename(columns={'id': 'experiment_id', 'name': 'experiment'})
eval_status = df['evaluation_status'].rename(columns={'name': 'status', 'id': 'status_id'})

# Left-join the lookup tables onto the results one key at a time, so every
# result row is kept even when a lookup is missing.
res = (
    df['results']
    .merge(metrics, on='metric_id', how='left')
    .merge(eval_runs, on='evaluation_id', how='left')
    .merge(experiments, on='experiment_id', how='left')
    .merge(eval_status, on='status_id', how='left')
)


In [15]:
# Parse the training step out of snapshot names such as 'step7260.eqx'.
# The pattern is a raw string so '\d' is not treated as an (invalid) string
# escape, and expand=False makes extract() return a Series rather than a
# one-column DataFrame before it is assigned to the new column.
res['step'] = res['snapshot'].str.extract(r'(\d+)', expand=False).astype(int)

In [16]:
# Inspect the merged results table, including the parsed 'step' column.
res

Unnamed: 0,id,evaluation_id,metric_id,value,metric,created_at,updated_at,status_id,experiment_id,snapshot,experiment,status,step
0,1,2,1,7.243400e+04,LeadingAKIPredictionAccuracy.n_timestamps_nega...,2024-08-07 07:02:24.982616,2024-08-08 08:44:56.853237,2,2,step7260.eqx,monotonic_bce_inicenodelite_gru_g0,FINISHED,7260
1,2,2,2,,LeadingAKIPredictionAccuracy.n_timestamps_unknown,2024-08-07 07:02:24.982616,2024-08-08 08:44:56.853237,2,2,step7260.eqx,monotonic_bce_inicenodelite_gru_g0,FINISHED,7260
2,3,2,3,8.929600e+04,LeadingAKIPredictionAccuracy.n_timestamps_firs...,2024-08-07 07:02:24.982616,2024-08-08 08:44:56.853237,2,2,step7260.eqx,monotonic_bce_inicenodelite_gru_g0,FINISHED,7260
3,4,2,4,5.330500e+04,LeadingAKIPredictionAccuracy.n_timestamps_late...,2024-08-07 07:02:24.982616,2024-08-08 08:44:56.853237,2,2,step7260.eqx,monotonic_bce_inicenodelite_gru_g0,FINISHED,7260
4,5,2,5,,LeadingAKIPredictionAccuracy.n_timestamps_reco...,2024-08-07 07:02:24.982616,2024-08-08 08:44:56.853237,2,2,step7260.eqx,monotonic_bce_inicenodelite_gru_g0,FINISHED,7260
...,...,...,...,...,...,...,...,...,...,...,...,...,...
3771,3772,6,468,2.425939e-01,ObsPredictionLossMetric.mae,2024-08-08 03:47:28.986545,2024-08-13 04:43:47.554999,2,6,step0350.eqx,monotonic_bce_inicenodeliteicnn_g0,FINISHED,350
3772,3773,6,469,1.417415e-01,ObsPredictionLossMetric.mse,2024-08-08 03:47:28.986545,2024-08-13 04:43:47.554999,2,6,step0350.eqx,monotonic_bce_inicenodeliteicnn_g0,FINISHED,350
3773,3774,6,470,-inf,ObsPredictionLossMetric.r2,2024-08-08 03:47:28.986545,2024-08-13 04:43:47.554999,2,6,step0350.eqx,monotonic_bce_inicenodeliteicnn_g0,FINISHED,350
3774,3775,6,471,3.578369e-01,ObsPredictionLossMetric.rms,2024-08-08 03:47:28.986545,2024-08-13 04:43:47.554999,2,6,step0350.eqx,monotonic_bce_inicenodeliteicnn_g0,FINISHED,350


In [17]:
# List the distinct experiment names present in the merged results.
res.experiment.unique()

array(['monotonic_bce_inicenodelite_gru_g0', 'monotonic_bce_inkoopman_g0',
       'monotonic_bce_gruodebayes_g0', 'mlp_mse_inicenodelite_g0',
       'monotonic_bce_inicenodelite_g0', 'monotonic_mse_inkoopman_g0',
       'mlp_mse_inicenodeliteicnn_g0',
       'monotonic_bce_inicenodeliteicnn_g0'], dtype=object)

In [18]:
# List the distinct snapshot file names present in the merged results.
res.snapshot.unique()

array(['step7260.eqx', 'step1640.eqx', 'step0120.eqx', 'step6770.eqx',
       'step6570.eqx', 'step2160.eqx', 'step0260.eqx', 'step0350.eqx'],
      dtype=object)

In [5]:
def model_name(exp):
    """Return the display label of the model family encoded in an experiment name.

    The rules are tried in order and the first match wins, so the more
    specific 'icenodelite' is tested before 'icenode'. Returns None when no
    rule matches.
    """
    rules = (
        (lambda name: 'icenodelite' in name, 'ICE-NODE-L'),
        (lambda name: 'icenode' in name, 'ICE-NODE'),
        (lambda name: 'retain' in name, 'RETAIN'),
        (lambda name: name.endswith('gru'), 'GRU'),
        (lambda name: name.endswith('grujump'), 'GRU-Jump'),
        (lambda name: 'koopman' in name, 'Koopman'),
    )
    return next((label for matches, label in rules if matches(exp)), None)
    
def loss_name(exp):
    """Return the training-loss label encoded in an experiment name.

    Tokens are checked in the order 'mse', 'dtw', 'mae'; the first one found
    decides the label. Returns None when none of them occurs.
    """
    token_to_label = (('mse', 'mse'), ('dtw', 'soft-dtw'), ('mae', 'mae'))
    for token, label in token_to_label:
        if token in exp:
            return label
    return None
    
def predictor_name(exp):
    """Return the predictor label encoded in an experiment name.

    The label is 'MLP' or 'monotonic' ('mlp' is checked first), with '48'
    appended whenever the experiment name contains '48'. Returns None when
    neither predictor token occurs.
    """
    suffix = '48' if '48' in exp else ''
    for token, base in (('mlp', 'MLP'), ('monotonic', 'monotonic')):
        if token in exp:
            return f'{base}{suffix}'
    return None

def state_modularity(exp):
    """Return 'Modular' when the experiment name contains 'modular', else ''."""
    return 'Modular' if 'modular' in exp else ''

In [6]:
# Derive descriptive columns from the experiment name: each new column is
# produced by applying its labelling function to every experiment name.
labellers = {
    'model': model_name,
    'loss': loss_name,
    'predictor': predictor_name,
    'state_modularity': state_modularity,
}
for column, labeller in labellers.items():
    res[column] = res['experiment'].map(labeller)

In [7]:
# Spot-check the rows of a single snapshot of one experiment.
is_snapshot = res['snapshot'] == 'step0480.eqx'
is_experiment = res['experiment'] == 'modular_mlp_mse_icenodelite48'
res[is_snapshot & is_experiment]

Unnamed: 0,id,evaluation_id,metric_id,value,metric,created_at,updated_at,status_id,experiment_id,snapshot,experiment,status,step,model,loss,predictor,state_modularity
67705,67706,666,1,0.098572,LossMetric.obs_mae,2024-01-15 02:59:11.080996,2024-01-19 15:25:42.146663,2,8,step0480.eqx,modular_mlp_mse_icenodelite48,FINISHED,480,ICE-NODE-L,mse,MLP48,Modular
67706,67707,666,2,0.027929,LossMetric.obs_mse,2024-01-15 02:59:11.080996,2024-01-19 15:25:42.146663,2,8,step0480.eqx,modular_mlp_mse_icenodelite48,FINISHED,480,ICE-NODE-L,mse,MLP48,Modular
67707,67708,666,3,0.167119,LossMetric.obs_rms,2024-01-15 02:59:11.080996,2024-01-19 15:25:42.146663,2,8,step0480.eqx,modular_mlp_mse_icenodelite48,FINISHED,480,ICE-NODE-L,mse,MLP48,Modular
67708,67709,666,4,0.21438,LossMetric.lead_mae,2024-01-15 02:59:11.080996,2024-01-19 15:25:42.146663,2,8,step0480.eqx,modular_mlp_mse_icenodelite48,FINISHED,480,ICE-NODE-L,mse,MLP48,Modular
67709,67710,666,5,0.070077,LossMetric.lead_mse,2024-01-15 02:59:11.080996,2024-01-19 15:25:42.146663,2,8,step0480.eqx,modular_mlp_mse_icenodelite48,FINISHED,480,ICE-NODE-L,mse,MLP48,Modular
67710,67711,666,6,0.21438,LossMetric.lead_rms,2024-01-15 02:59:11.080996,2024-01-19 15:25:42.146663,2,8,step0480.eqx,modular_mlp_mse_icenodelite48,FINISHED,480,ICE-NODE-L,mse,MLP48,Modular
67711,67712,666,7,0.140154,LossMetric.lead_softdtw(0.1),2024-01-15 02:59:11.080996,2024-01-19 15:25:42.146663,2,8,step0480.eqx,modular_mlp_mse_icenodelite48,FINISHED,480,ICE-NODE-L,mse,MLP48,Modular
67712,67713,666,8,16483.588653,LossMetric.eval_time,2024-01-15 02:59:11.080996,2024-01-19 15:25:42.146663,2,8,step0480.eqx,modular_mlp_mse_icenodelite48,FINISHED,480,ICE-NODE-L,mse,MLP48,Modular
67713,67714,666,9,43543.0,LeadingPredictionAccuracy.n_timestamps_negative,2024-01-15 02:59:11.080996,2024-01-19 15:25:42.146663,2,8,step0480.eqx,modular_mlp_mse_icenodelite48,FINISHED,480,ICE-NODE-L,mse,MLP48,Modular
67714,67715,666,10,,LeadingPredictionAccuracy.n_timestamps_unknown,2024-01-15 02:59:11.080996,2024-01-19 15:25:42.146663,2,8,step0480.eqx,modular_mlp_mse_icenodelite48,FINISHED,480,ICE-NODE-L,mse,MLP48,Modular


In [8]:
# Show the column names of the results table after the derived columns were added.
res.columns

Index(['id', 'evaluation_id', 'metric_id', 'value', 'metric', 'created_at',
       'updated_at', 'status_id', 'experiment_id', 'snapshot', 'experiment',
       'status', 'step', 'model', 'loss', 'predictor', 'state_modularity'],
      dtype='object')

In [9]:
# Order rows by training step within each experiment; the cumulative extrema
# below are computed in row order, so they depend on this sort.
res = res.sort_values(['experiment_id', 'step'])

# Running best / worst value of every metric within every experiment.
# One grouped cumulative pass replaces the nested Python loops over
# experiments and metrics and yields the same per-row values.
value_by_series = res.groupby(['experiment_id', 'metric'])['value']
running_max = value_by_series.cummax()
running_min = value_by_series.cummin()
res['last_max'] = running_max
res['last_min'] = running_min

# A row is flagged when its value sets (or ties) the running extremum.
# NaN never compares equal, so rows with a missing value are never flagged.
res['is_max'] = res['value'].eq(res['last_max'])
res['is_min'] = res['value'].eq(res['last_min'])
        

In [10]:
# List the distinct metric names available for plotting.
res.metric.unique()


array(['LossMetric.obs_mae', 'LossMetric.obs_mse', 'LossMetric.obs_rms',
       'LossMetric.lead_mae', 'LossMetric.lead_mse',
       'LossMetric.lead_rms', 'LossMetric.lead_softdtw(0.1)',
       'LossMetric.eval_time',
       'LeadingPredictionAccuracy.n_timestamps_negative',
       'LeadingPredictionAccuracy.n_timestamps_unknown',
       'LeadingPredictionAccuracy.n_timestamps_first_pre_emergence',
       'LeadingPredictionAccuracy.n_timestamps_later_pre_emergence',
       'LeadingPredictionAccuracy.n_timestamps_recovery_window',
       'LeadingPredictionAccuracy.n_admissions_negative',
       'LeadingPredictionAccuracy.n_admissions_unknown',
       'LeadingPredictionAccuracy.n_admissions_first_pre_emergence',
       'LeadingPredictionAccuracy.n_admissions_later_pre_emergence',
       'LeadingPredictionAccuracy.n_admissions_recovery_window',
       'LeadingPredictionAccuracy.n_timestamps_first_pre_emergence_1-72',
       'LeadingPredictionAccuracy.AUC_first_pre_emergence_1-72',
      

In [11]:
# Metric to plot ('LossMetric.lead_mse' is the alternative that was tried).
METRIC = 'LeadingPredictionAccuracy.AUC_pre_emergence_48-72'
LOSS = 'mse'
PREDICTOR = 'MLP'
# Model families to compare. Single-model alternatives that were tried:
# 'ICE-NODE-L', 'ICE-NODE', 'GRU', 'GRU-Jump', 'Koopman', 'RETAIN'.
MODELS = ['ICE-NODE-L', 'GRU-Jump']

# Keep only rows of the chosen metric, model families, loss and predictor,
# and drop experiments labelled as modular.
selected = (
    (res['metric'] == METRIC)
    & res['model'].isin(MODELS)
    & (res['loss'] == LOSS)
    & (res['predictor'] == PREDICTOR)
    & ~res['state_modularity'].str.startswith('Modular')
)
res_metric = res[selected]

In [12]:
# Inspect the rows selected for plotting.
res_metric

Unnamed: 0,id,evaluation_id,metric_id,value,metric,created_at,updated_at,status_id,experiment_id,snapshot,...,status,step,model,loss,predictor,state_modularity,last_max,last_min,is_max,is_min
53238,53239,1012,54,0.649953,LeadingPredictionAccuracy.AUC_pre_emergence_48-72,2024-01-15 14:13:02.141582,2024-01-15 14:59:38.496145,2,16,step0515.eqx,...,FINISHED,515,GRU-Jump,mse,MLP,,0.649953,0.649953,True,True
52743,52744,1007,54,0.644476,LeadingPredictionAccuracy.AUC_pre_emergence_48-72,2024-01-15 13:59:19.808327,2024-01-15 14:39:03.344705,2,16,step1030.eqx,...,FINISHED,1030,GRU-Jump,mse,MLP,,0.649953,0.644476,False,True
52413,52414,1003,54,0.689997,LeadingPredictionAccuracy.AUC_pre_emergence_48-72,2024-01-15 13:50:56.135667,2024-01-15 14:32:29.470032,2,16,step1545.eqx,...,FINISHED,1545,GRU-Jump,mse,MLP,,0.689997,0.644476,True,False
52248,52249,999,54,0.627081,LeadingPredictionAccuracy.AUC_pre_emergence_48-72,2024-01-15 13:45:46.185605,2024-01-15 14:26:38.296637,2,16,step2060.eqx,...,FINISHED,2060,GRU-Jump,mse,MLP,,0.689997,0.627081,False,True
52138,52139,994,54,0.682980,LeadingPredictionAccuracy.AUC_pre_emergence_48-72,2024-01-15 13:38:21.967060,2024-01-15 14:20:45.244638,2,16,step2576.eqx,...,FINISHED,2576,GRU-Jump,mse,MLP,,0.689997,0.627081,False,False
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
113903,113904,2085,54,0.689357,LeadingPredictionAccuracy.AUC_pre_emergence_48-72,2024-02-12 08:19:42.096889,2024-02-12 15:09:42.467186,2,21,step4590.eqx,...,FINISHED,4590,ICE-NODE-L,mse,MLP,,0.692035,0.608219,False,False
111593,111594,2050,54,0.691031,LeadingPredictionAccuracy.AUC_pre_emergence_48-72,2024-02-12 00:49:28.505722,2024-02-12 06:40:21.717422,2,21,step4650.eqx,...,FINISHED,4650,ICE-NODE-L,mse,MLP,,0.692035,0.608219,False,False
111428,111429,2043,54,0.691900,LeadingPredictionAccuracy.AUC_pre_emergence_48-72,2024-02-11 23:23:52.207091,2024-02-12 05:38:06.348558,2,21,step4665.eqx,...,FINISHED,4665,ICE-NODE-L,mse,MLP,,0.692035,0.608219,False,False
113078,113079,2066,54,0.683711,LeadingPredictionAccuracy.AUC_pre_emergence_48-72,2024-02-12 04:18:15.273278,2024-02-12 12:42:25.450917,2,21,step4680.eqx,...,FINISHED,4680,ICE-NODE-L,mse,MLP,,0.692035,0.608219,False,False


In [13]:
import numpy as np

from bokeh.plotting import figure, show, curdoc
from bokeh.io import output_notebook, export_svgs
# Direct Bokeh output to the notebook so that show() renders inline.
output_notebook()

In [14]:
from bokeh.palettes import  mpl, small_palettes, viridis,inferno, cividis, YlOrRd4, Spectral

In [15]:
# List the experiments that survived the filters and will be plotted.
res_metric.experiment.unique()

array(['onestate_mlp_mse_ingrujump', 'onestate_mlp_mse_inicenodelite'],
      dtype=object)

In [16]:
p = figure(y_axis_label=METRIC, x_axis_label="Training Step")

# The first `color_offset` entries of the palette are skipped, which is why the
# offset is added both to the palette size and to the colour index.
# NOTE(review): Spectral is indexed by palette size; confirm the requested size
# exists when many experiments are plotted.
color_offset = 3
colors = palette = Spectral[res_metric.experiment.nunique() + color_offset]
res_metric = res_metric.sort_values('step')

# One line (running maximum) plus markers (steps that set a new maximum) per
# experiment. The per-experiment frame is named `exp_df`, not `df`, so the
# dict of database tables bound to `df` earlier in the notebook is not
# overwritten when this cell runs.
for i, (exp, exp_df) in enumerate(res_metric.groupby('experiment')):
    color = colors[i + color_offset]
    model_label = exp_df['model'].iloc[0]
    loss_label = exp_df['loss'].iloc[0]
    modularity = exp_df['state_modularity'].iloc[0]
    predictor_label = exp_df['predictor'].iloc[0]

    # Join only the non-empty parts so that an empty modularity does not
    # leave a leading space in the legend entry.
    title = " ".join(part for part in (modularity, model_label) if part)
    label = f'{title} ({loss_label}) ({predictor_label})'

    p.line(x='step', y='last_max', color=color,
           line_width=4, legend_label=label, source=exp_df)
    p.scatter(x='step', y='value', color=color,
              line_width=2, legend_label=label, source=exp_df[exp_df['is_max']])

p.legend.location = "bottom_right"
# Replaces the raw metric name passed to figure() with a readable axis title.
p.yaxis.axis_label = 'Prediction AUC 48-hours in-advance'
p.legend.label_text_font_size = '16pt'

curdoc().theme = 'caliber'
p.xaxis.axis_label_text_font_size = "20pt"
p.yaxis.axis_label_text_font_size = "20pt"
p.xaxis.major_label_text_font_size = '20px'
p.yaxis.major_label_text_font_size = '20px'

show(p)

In [17]:
# Switch the figure to the SVG backend before exporting it.
p.output_backend = "svg"
# Write the figure to disk; the cell output lists the file that was written.
export_svgs(p, filename="aki_prediction.svg")

['aki_prediction.svg']