In [1]:
import numpy as np
import pandas as pd
import json
import collections
from collections import defaultdict 
from functools import partial
import networkx as nx
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm

import jax

# Global flag to set a specific platform, must be used at startup.
jax.config.update('jax_platform_name', 'gpu')
jax.config.update("jax_debug_nans", True)
jax.config.update("jax_debug_infs", True)

jax.config.update('jax_log_compiles', False)
jax.config.update('jax_check_tracer_leaks', False)

In [2]:
# Good read: https://iq-inc.com/importerror-attempted-relative-import/

import sys
import importlib
from mimicnet import concept
from mimicnet import jax_interface
from mimicnet import dag
from mimicnet import glove
from mimicnet import gram
from mimicnet import ode

importlib.reload(sys.modules['mimicnet.concept'])
importlib.reload(sys.modules['mimicnet.dag'])
importlib.reload(sys.modules['mimicnet.jax_interface'])
importlib.reload(sys.modules['mimicnet.glove'])
importlib.reload(sys.modules['mimicnet.gram'])
importlib.reload(sys.modules['mimicnet.ode'])

<module 'mimicnet.ode' from '/home/asem/GP/MIMIC-SNONET/mimicnet/ode.py'>

In [3]:

KG = dag.CCSDAG()

In [4]:
# multi_visit_mimic_dir = '/home/am8520/GP/ehr-data/mimic3-multi-visit'
multi_visit_mimic_dir = '/home/asem/GP/ehr-data/mimic3-multi-visit'
transformed_mimic_dir = '/home/asem/GP/ehr-data/mimic3-transforms'
mimic_dir = '/home/asem/GP/ehr-data/mimic3-v1.4/physionet.org/files/mimiciii/1.4'
experiments_dir = '/home/asem/GP/ehr-data/mimic3-snonet-exp'
experiment_prefix = 'NOV29'

In [5]:
D_LABITEMS = pd.read_csv(f'{mimic_dir}/D_LABITEMS.csv.gz')
D_ITEMS = pd.read_csv(f'{mimic_dir}/D_ITEMS.csv.gz')
D_TEST = pd.concat([D_LABITEMS, D_ITEMS], join='inner')
test_label_dict = dict(zip(D_TEST.ITEMID, D_TEST.LABEL))
test_cat_dict = dict(zip(D_TEST.ITEMID, D_TEST.CATEGORY))

In [6]:
static_df = pd.read_csv(f'{transformed_mimic_dir}/static_df.csv.gz')
adm_df = pd.read_csv(f'{transformed_mimic_dir}/adm_df.csv.gz')
diag_df = pd.read_csv(f'{transformed_mimic_dir}/diag_df.csv.gz', dtype={'ICD9_CODE': str})
proc_df = pd.read_csv(f'{transformed_mimic_dir}/proc_df.csv.gz', dtype={'ICD9_CODE': str})
test_df = pd.read_csv(f'{transformed_mimic_dir}/test_df.csv.gz')


In [7]:
# Cast columns of dates to datetime64

static_df['DOB'] = pd.to_datetime(static_df.DOB, infer_datetime_format=True).dt.normalize()
adm_df['ADMITTIME'] = pd.to_datetime(adm_df.ADMITTIME, infer_datetime_format=True).dt.normalize()
adm_df['DISCHTIME'] = pd.to_datetime(adm_df.DISCHTIME, infer_datetime_format=True).dt.normalize()
test_df['DATE'] = pd.to_datetime(test_df.DATE, infer_datetime_format=True).dt.normalize()

In [8]:
static_df.ETHNIC_GROUP.nunique()

7

In [9]:
test_df

Unnamed: 0,SUBJECT_ID,ITEMID,DATE,VALUE
0,17,50852,2134-12-22,-1.201339
1,17,50861,2134-12-22,-0.809370
2,17,50862,2134-12-22,1.634410
3,17,50863,2134-12-22,-1.452596
4,17,50867,2134-12-22,-0.608894
...,...,...,...,...
4546204,99982,227456,2157-02-22,1.794272
4546205,99982,227457,2157-02-22,-0.723781
4546206,99982,227465,2157-02-22,2.247434
4546207,99982,227466,2157-02-22,-0.029395


In [10]:
patients = concept.Subject.to_list(static_df, adm_df, diag_df, proc_df, test_df)

In [11]:
len(patients)

4434

In [12]:
# ehr_dfs = concept.Subject.to_df(patients)

In [13]:
# def df_eq(df1, df2):
#     df11 = df1.sort_values(by=df1.columns.tolist()).reset_index(drop=True)
#     df22 = df2.sort_values(by=df2.columns.tolist()).reset_index(drop=True)
#     return df11[df22.columns].equals(df22)

# all(map(df_eq, ehr_dfs, [static_df, adm_df, diag_df, proc_df, test_df]))

### [FORK] Skip the cell below to load the jaxified data from a stored file on disc

In [14]:
# subjects_interface = jax_interface.SubjectJAXInterface(patients, set(test_df.ITEMID), KG)
# import pickle
# with open(f'{experiments_dir}/{experiment_prefix}_subjects_interface.pkl', 'wb') as pickleFile:
#     pickle.dump(subjects_interface, pickleFile)

In [15]:
import pickle
with open(f'{experiments_dir}/{experiment_prefix}_subjects_interface.pkl', 'rb') as pickleFile:
    subjects_interface = pickle.load(pickleFile)

## GloVe Initialization

In [16]:
%load_ext autoreload
%autoreload 2

In [17]:
import logging
logging.basicConfig(level=logging.DEBUG)
logging.debug("test")

DEBUG:root:test


In [18]:
glove_args = {
    'diag_idx': subjects_interface.diag_multi_ccs_idx,
    'proc_idx': subjects_interface.proc_multi_ccs_idx,
    'ccs_dag': KG,
    'subjects': patients,
    'diag_vector_size': 100,
    'proc_vector_size': 60,
    'iterations': 30,
    'window_size_days': 2 * 365
}

diag_glove_rep, proc_glove_rep = glove.glove_representation(**glove_args)

In [19]:
print(f'#point_indices: {len(subjects_interface.nth_points)}')
print(f'#total_points: {sum(len(points) for n, points in subjects_interface.nth_points.items())}')

#[len(points) for n, points in subjects_interface.nth_points.items()]

#point_indices: 1085
#total_points: 129334


## GRAM objects

In [20]:
from datetime import datetime
daily_tracer = "/tmp/tensorboard/"+ datetime.now().strftime("%Y%m%d-%H%M%S") 
print(daily_tracer)

/tmp/tensorboard/20211202-222535


In [21]:


logs = '/tmp/tensorboard/20210708-182059'
#server = jax.profiler.start_server(9999)

In [22]:
       
config = {
    'gram_config': {
        'diag': {
            'ccs_dag': KG,
            'code2index': subjects_interface.diag_multi_ccs_idx,
            'attention_method': 'tanh', #l2, tanh
            'attention_dim': 100,
            'ancestors_mat': subjects_interface.diag_multi_ccs_ancestors_mat,
            'basic_embeddings': diag_glove_rep
        },
        'proc': {
            'ccs_dag': KG,
            'code2index': subjects_interface.proc_multi_ccs_idx,
            'attention_method': 'tanh',
            'attention_dim': 60,
            'ancestors_mat': subjects_interface.proc_multi_ccs_ancestors_mat,
            'basic_embeddings': proc_glove_rep
        }
    },
    'model': {
        'ode_dyn': 'mlp', # gru, mlp
        'state_size': 50,
        'numeric_hidden_size': 40,
        'bias': True
    },
    'training': {
        'train_validation_split': 0.8,
        'batch_size': 12,
        'epochs': 5,
        'lr': 1e-3,
        'diag_loss': 'balanced_focal', # balanced_focal, bce
        'tay_reg': None, # Order of regularized derivative of the dynamics function (None for disable).
        'loss_mixing': {
            'num_alpha': 0.1,
            'diag_alpha': 0.1,
            'ode_alpha': 1e-3,
            'l1_reg': 1e-6,
            'l2_reg': 1e-5,
            'dyn_reg': 1e-5
        },
        'eval_freq': 10,
        'save_freq': 100,
        'save_params_prefix': None
    }
}


In [23]:
diag_gram = gram.DAGGRAM(**config['gram_config']['diag'])

DEBUG:absl:Initializing backend 'interpreter'
DEBUG:absl:Backend 'interpreter' initialized
DEBUG:absl:Initializing backend 'cpu'
DEBUG:absl:Backend 'cpu' initialized
DEBUG:absl:Initializing backend 'tpu_driver'
INFO:absl:Unable to initialize backend 'tpu_driver': Not found: Unable to find driver in registry given worker: 
DEBUG:absl:Initializing backend 'gpu'
DEBUG:absl:Backend 'gpu' initialized
DEBUG:absl:Initializing backend 'tpu'
INFO:absl:Unable to initialize backend 'tpu': Invalid argument: TpuPlatform is not available.


In [24]:
proc_gram = gram.DAGGRAM(**config['gram_config']['proc'])

## GRU-ODE-Bayes

In [None]:
import random
%load_ext autoreload
%autoreload 2
from absl import logging
logging.set_verbosity(logging.INFO)

#with jax.profiler.trace(logs):
res = ode.train_ehr(subject_interface=subjects_interface,
                diag_gram=diag_gram,
                proc_gram=proc_gram,
                rng=random.Random(42),
                **config['model'],
                **config['training'],
                verbose_debug=False,
                shape_debug=False,
                nan_debug=False,
                memory_profile=False)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


INFO:ode:#params: 293044
INFO:numexpr.utils:NumExpr defaulting to 8 threads.
info retrieval:   0%|          | 1/1477 [04:19<106:35:09, 259.97s/it]                 INFO:ode:
                        Training    Validation
prejump_num_loss       14.317818    917.747375
postjump_num_loss     890.755676  61111.664062
prejump_diag_loss       0.002124      0.076024
postjump_diag_loss      0.002105      0.074312
num_loss              101.961604   6937.139044
diag_loss               0.002122      0.075853
ode_loss                0.104082      7.012916
l1_loss             14247.754883  14247.754883
l1_loss_per_point      31.732194      0.562463
l2_loss              1785.656128   1785.656128
l2_loss_per_point       3.976962      0.070493
dyn_loss                 14200.0      775924.0
dyn_loss_per_week      10.344469     7.9545236
loss                 0.104449585      7.013223
INFO:ode:
                          Training     Valdation
accuracy                  0.506583      0.486989
recall        

info retrieval:   3%|▎         | 41/1477 [33:01<10:43:15, 26.88s/it]                 INFO:ode:
                        Training    Validation
prejump_num_loss       16.301437    978.553467
postjump_num_loss     781.436096  49125.906250
prejump_diag_loss       0.001322      0.049367
postjump_diag_loss      0.001341      0.049680
num_loss               92.814903   5793.288745
diag_loss               0.001324      0.049398
ode_loss                0.094137      5.842637
l1_loss             14348.020508  14348.020508
l1_loss_per_point      46.584482      0.566421
l2_loss              1760.824951   1760.824951
l2_loss_per_point       5.716964      0.069513
dyn_loss                  9208.0      782686.0
dyn_loss_per_week       8.288029      8.023846
loss                  0.09451132      5.842947
INFO:ode:
                          Training     Valdation
accuracy                  0.804049      0.799993
recall                    0.830409      0.780383
npv                       0.993489      0.9

INFO:ode:
                          Training     Valdation
accuracy                  0.793427      0.790047
recall                    0.768340      0.802502
npv                       0.986937      0.990483
specificity               0.794566      0.789568
precision                 0.145150      0.127777
f1-score                  0.244172      0.220453
tp                        0.033367      0.029687
tn                        0.760060      0.760360
fp                        0.196512      0.202647
fn                        0.010060      0.007306
points_count            229.000000  25331.000000
odeint_weeks_per_visit    4.089832      3.850815
nfe_per_point             0.000000      0.000000
nfe_per_week              0.000000      0.000000
nfex1000                  0.000000      0.000000
INFO:ode:
    Training(pre)  Training(post)  Validation(pre)  Validation(post)
p0       0.000000        0.000000         0.009000          0.006429
p1       0.000000        0.000000         0.013609        

INFO:ode:
    Training(pre)  Training(post)  Validation(pre)  Validation(post)
p0       0.000000        0.000000         0.007715          0.010929
p1       0.035714        0.107143         0.024556          0.044970
p2       0.191176        0.132353         0.364190          0.288063
p3       0.901961        0.882353         0.714646          0.747790
p4       1.000000        0.963636         0.998612          0.991396
info retrieval:   9%|▉         | 131/1477 [1:34:27<10:44:32, 28.73s/it]                 INFO:ode:
                        Training    Validation
prejump_num_loss       16.610407    995.001038
postjump_num_loss     618.306885  37928.191406
prejump_diag_loss       0.000999      0.046697
postjump_diag_loss      0.001009      0.047370
num_loss               76.780055   4688.320074
diag_loss               0.001000      0.046764
ode_loss                0.077779      4.735038
l1_loss             14359.715820  14359.715820
l1_loss_per_point      39.558446      0.566883
l2_loss 

info retrieval:  12%|█▏        | 171/1477 [2:03:17<9:36:20, 26.48s/it]                  INFO:ode:
                        Training    Validation
prejump_num_loss       17.161905    999.980896
postjump_num_loss     652.004089  37614.722656
prejump_diag_loss       0.001277      0.046602
postjump_diag_loss      0.001314      0.047388
num_loss               80.646124   4661.455072
diag_loss               0.001280      0.046680
ode_loss                0.081925      4.708089
l1_loss             14207.440430  14207.440430
l1_loss_per_point      57.057994      0.560872
l2_loss              1780.918457   1780.918457
l2_loss_per_point       7.152283      0.070306
dyn_loss                  8208.0      800272.0
dyn_loss_per_week      3.5327106      8.204131
loss                  0.08234761      4.708406
INFO:ode:
                          Training     Valdation
accuracy                  0.778169      0.785526
recall                    0.782051      0.791563
npv                       0.990139      

INFO:ode:
                          Training     Valdation
accuracy                  0.781530      0.784337
recall                    0.821918      0.806430
npv                       0.991776      0.990599
specificity               0.780063      0.783488
precision                 0.119522      0.125170
f1-score                  0.208696      0.216704
tp                        0.028809      0.029832
tn                        0.752721      0.754505
fp                        0.212228      0.208502
fn                        0.006242      0.007161
points_count            246.000000  25331.000000
odeint_weeks_per_visit    2.495354      3.850815
nfe_per_point             0.000000      0.000000
nfe_per_week              0.000000      0.000000
nfex1000                  0.000000      0.000000
INFO:ode:
    Training(pre)  Training(post)  Validation(pre)  Validation(post)
p0       0.058824        0.058824         0.010608          0.011893
p1       0.000000        0.022727         0.007396        

INFO:ode:
    Training(pre)  Training(post)  Validation(pre)  Validation(post)
p0       0.000000        0.000000         0.008357          0.009643
p1       0.000000        0.000000         0.020118          0.061834
p2       0.403509        0.350877         0.276492          0.312728
p3       0.875000        0.625000         0.890783          0.736427
p4       0.985915        1.000000         0.998890          0.999167
info retrieval:  18%|█▊        | 261/1477 [3:06:59<9:14:48, 27.38s/it]                  INFO:ode:
                        Training    Validation
prejump_num_loss       15.866828    987.392822
postjump_num_loss     564.509705  35631.312500
prejump_diag_loss       0.001748      0.046589
postjump_diag_loss      0.001754      0.047047
num_loss               70.731116   4451.784790
diag_loss               0.001749      0.046634
ode_loss                0.072478      4.498373
l1_loss             14219.660156  14219.661133
l1_loss_per_point      64.342354      0.561354
l2_loss 

info retrieval:  20%|██        | 301/1477 [3:34:00<7:19:50, 22.44s/it]                 INFO:ode:
                        Training    Validation
prejump_num_loss       16.942181    966.001831
postjump_num_loss     599.654907  34366.835938
prejump_diag_loss       0.001652      0.046669
postjump_diag_loss      0.001789      0.047602
num_loss               75.213453   4306.085242
diag_loss               0.001666      0.046762
ode_loss                0.076878      4.352801
l1_loss             14829.550781  14829.550781
l1_loss_per_point      56.601339      0.585431
l2_loss              1854.801880   1854.801880
l2_loss_per_point       7.079396      0.073223
dyn_loss                  8786.0      835450.0
dyn_loss_per_week       8.592065      8.564765
loss                  0.07730517      4.353132
INFO:ode:
                          Training     Valdation
accuracy                  0.788537      0.784912
recall                    0.835000      0.798876
npv                       0.991532      0

INFO:ode:
                          Training     Valdation
accuracy                  0.790074      0.789454
recall                    0.799020      0.776213
npv                       0.991068      0.989235
specificity               0.789757      0.789963
precision                 0.118632      0.124315
f1-score                  0.206591      0.214308
tp                        0.027331      0.028715
tn                        0.762743      0.760740
fp                        0.203052      0.202267
fn                        0.006875      0.008279
points_count            234.000000  25331.000000
odeint_weeks_per_visit    7.581197      3.850815
nfe_per_point             0.000000      0.000000
nfe_per_week              0.000000      0.000000
nfex1000                  0.000000      0.000000
INFO:ode:
    Training(pre)  Training(post)  Validation(pre)  Validation(post)
p0       0.034483        0.103448         0.018322          0.021536
p1       0.093023        0.093023         0.141124        

INFO:ode:
    Training(pre)  Training(post)  Validation(pre)  Validation(post)
p0       0.000000        0.000000         0.014143          0.015429
p1       0.068966        0.103448         0.100000          0.143787
p2       0.272727        0.295455         0.250305          0.233557
p3       0.783784        0.432432         0.738005          0.579545
p4       1.000000        0.971429         0.997225          0.974188
info retrieval:  26%|██▋       | 391/1477 [4:33:48<6:28:43, 21.48s/it]                  INFO:ode:
                        Training    Validation
prejump_num_loss       11.259975    922.281006
postjump_num_loss     383.115723  31522.070312
prejump_diag_loss       0.001451      0.046120
postjump_diag_loss      0.001509      0.047371
num_loss               48.445550   3982.259937
diag_loss               0.001457      0.046245
ode_loss                0.049901      4.028458
l1_loss             16201.416992  16201.416992
l1_loss_per_point      39.806921      0.639589
l2_loss 

info retrieval:  29%|██▉       | 431/1477 [5:01:51<6:53:30, 23.72s/it]                  INFO:ode:
                        Training    Validation
prejump_num_loss       12.925032   1023.044983
postjump_num_loss     448.784119  35414.164062
prejump_diag_loss       0.001403      0.047684
postjump_diag_loss      0.001491      0.048350
num_loss               56.510940   4462.156891
diag_loss               0.001412      0.047750
ode_loss                0.057922      4.509860
l1_loss             17802.902344  17802.906250
l1_loss_per_point      48.377452      0.702811
l2_loss              2160.539795   2160.539795
l2_loss_per_point       5.871032      0.085292
dyn_loss                 10786.0      835372.0
dyn_loss_per_week       7.066168      8.563966
loss                 0.058292434     4.5101905
INFO:ode:
                          Training     Valdation
accuracy                  0.781470      0.780561
recall                    0.816901      0.791322
npv                       0.992488      

INFO:ode:
                          Training     Valdation
accuracy                  0.788579      0.785831
recall                    0.753943      0.771741
npv                       0.984369      0.988973
specificity               0.790346      0.786372
precision                 0.154994      0.121862
f1-score                  0.257127      0.210486
tp                        0.036589      0.028549
tn                        0.751990      0.757281
fp                        0.199479      0.205725
fn                        0.011941      0.008444
points_count            472.000000  25331.000000
odeint_weeks_per_visit    4.059625      3.850815
nfe_per_point             0.000000      0.000000
nfe_per_week              0.000000      0.000000
nfex1000                  0.000000      0.000000
INFO:ode:
    Training(pre)  Training(post)  Validation(pre)  Validation(post)
p0       0.014925        0.000000         0.023787          0.016715
p1       0.067797        0.000000         0.049408        

INFO:ode:
    Training(pre)  Training(post)  Validation(pre)  Validation(post)
p0        0.02500        0.000000         0.014465          0.008357
p1        0.12500        0.035714         0.090237          0.059172
p2        0.45614        0.421053         0.354141          0.382460
p3        0.62500        0.625000         0.637311          0.647096
p4        1.00000        0.931818         0.962254          0.922287
info retrieval:  35%|███▌      | 521/1477 [6:03:17<7:46:21, 29.27s/it]                  INFO:ode:
                        Training    Validation
prejump_num_loss       19.121511    976.768616
postjump_num_loss     675.844910  34039.500000
prejump_diag_loss       0.001656      0.050674
postjump_diag_loss      0.001634      0.048418
num_loss               84.793851   4283.041754
diag_loss               0.001654      0.050449
ode_loss                0.086446      4.333440
l1_loss             19263.277344  19263.277344
l1_loss_per_point      76.746125      0.760463
l2_loss 

info retrieval:  38%|███▊      | 561/1477 [6:30:16<7:02:20, 27.66s/it]                 INFO:ode:
                        Training    Validation
prejump_num_loss       14.201287    945.594482
postjump_num_loss     481.563202  31472.572266
prejump_diag_loss       0.001552      0.046459
postjump_diag_loss      0.001603      0.047699
num_loss               60.937479   3998.292261
diag_loss               0.001557      0.046583
ode_loss                0.062493      4.044829
l1_loss             19633.724609  19633.724609
l1_loss_per_point      86.492179      0.775087
l2_loss              2450.131104   2450.131348
l2_loss_per_point      10.793529      0.096725
dyn_loss                  7252.0      859324.0
dyn_loss_per_week       4.566751      8.809514
loss                  0.06295277      4.045169
INFO:ode:
                          Training     Valdation
accuracy                  0.758699      0.768196
recall                    0.784173      0.799903
npv                       0.991630      0

INFO:ode:
                          Training     Valdation
accuracy                  0.776634      0.770018
recall                    0.841785      0.810661
npv                       0.990563      0.990624
specificity               0.773599      0.768457
precision                 0.147634      0.118549
f1-score                  0.251211      0.206848
tp                        0.037468      0.029989
tn                        0.739166      0.740029
fp                        0.216324      0.222978
fn                        0.007042      0.007004
points_count            439.000000  25331.000000
odeint_weeks_per_visit    2.908884      3.850815
nfe_per_point             0.000000      0.000000
nfe_per_week              0.000000      0.000000
nfex1000                  0.000000      0.000000
INFO:ode:
    Training(pre)  Training(post)  Validation(pre)  Validation(post)
p0       0.000000        0.000000         0.021536          0.007393
p1       0.039604        0.128713         0.116864        

INFO:ode:
    Training(pre)  Training(post)  Validation(pre)  Validation(post)
p0       0.028986        0.028986         0.018322          0.023787
p1       0.120879        0.098901         0.156509          0.133432
p2       0.309524        0.214286         0.285627          0.294153
p3       0.491228        0.526316         0.434343          0.478220
p4       0.824561        0.842105         0.890092          0.892034
info retrieval:  44%|████▍     | 651/1477 [7:32:51<4:52:05, 21.22s/it]                  INFO:ode:
                        Training    Validation
prejump_num_loss       15.497334   1078.232422
postjump_num_loss     513.971497  35470.035156
prejump_diag_loss       0.001235      0.047014
postjump_diag_loss      0.001275      0.047979
num_loss               65.344751   4517.412695
diag_loss               0.001239      0.047110
ode_loss                0.066582      4.564476
l1_loss             19037.414062  19037.414062
l1_loss_per_point      55.021428      0.751546
l2_loss 

info retrieval:  47%|████▋     | 691/1477 [7:59:22<6:04:48, 27.85s/it]                 INFO:ode:
                        Training    Validation
prejump_num_loss       19.002375   1156.663452
postjump_num_loss     639.723511  39060.488281
prejump_diag_loss       0.002270      0.059566
postjump_diag_loss      0.002059      0.055743
num_loss               81.074488   4947.045935
diag_loss               0.002249      0.059183
ode_loss                0.083321      5.006170
l1_loss             19168.255859  19168.253906
l1_loss_per_point      78.881711      0.756711
l2_loss              2385.247070   2385.247070
l2_loss_per_point       9.815832      0.094163
dyn_loss                 10488.0      887152.0
dyn_loss_per_week        6.88318      9.094797
loss                 0.083880566      5.006521
INFO:ode:
                          Training     Valdation
accuracy                  0.773721      0.776693
recall                    0.769633      0.765456
npv                       0.989195      0

INFO:ode:
                          Training     Valdation
accuracy                  0.773239      0.778495
recall                    0.800971      0.765637
npv                       0.990394      0.988575
specificity               0.772196      0.778989
precision                 0.116856      0.117447
f1-score                  0.203956      0.203654
tp                        0.029049      0.028323
tn                        0.744190      0.750172
fp                        0.219542      0.212835
fn                        0.007218      0.008670
points_count            260.000000  25331.000000
odeint_weeks_per_visit    4.889560      3.850815
nfe_per_point             0.000000      0.000000
nfe_per_week              0.000000      0.000000
nfex1000                  0.000000      0.000000
INFO:ode:
    Training(pre)  Training(post)  Validation(pre)  Validation(post)
p0       0.028571        0.000000         0.022822          0.030858
p1       0.088889        0.200000         0.109467        

INFO:ode:
    Training(pre)  Training(post)  Validation(pre)  Validation(post)
p0       0.000000        0.016129         0.020572          0.018644
p1       0.112903        0.145161         0.093787          0.129290
p2       0.084746        0.152542         0.169915          0.221376
p3       0.781250        0.500000         0.688763          0.559028
p4       0.961538        0.980769         0.959201          0.948099
info retrieval:  53%|█████▎    | 781/1477 [9:01:01<4:21:07, 22.51s/it]                  INFO:ode:
                        Training    Validation
prejump_num_loss       16.392815   1109.787476
postjump_num_loss     549.186951  37319.085938
prejump_diag_loss       0.001625      0.049492
postjump_diag_loss      0.001672      0.049253
num_loss               69.672228   4730.717322
diag_loss               0.001630      0.049468
ode_loss                0.071301      4.780136
l1_loss             19536.496094  19536.496094
l1_loss_per_point      70.528867      0.771249
l2_loss 

info retrieval:  56%|█████▌    | 821/1477 [9:27:17<4:11:30, 23.00s/it]                 INFO:ode:
                        Training    Validation
prejump_num_loss       17.329493   1094.135010
postjump_num_loss     585.177917  36549.941406
prejump_diag_loss       0.001440      0.050245
postjump_diag_loss      0.001466      0.049459
num_loss               74.114335   4639.715649
diag_loss               0.001443      0.050166
ode_loss                0.075556      4.689832
l1_loss             19584.375000  19584.375000
l1_loss_per_point      71.475821      0.773139
l2_loss              2458.484863   2458.484863
l2_loss_per_point       8.972572      0.097054
dyn_loss                  9386.0      892168.0
dyn_loss_per_week       8.434146       9.14622
loss                 0.076014526     4.6901855
INFO:ode:
                          Training     Valdation
accuracy                  0.766308      0.765392
recall                    0.830097      0.790294
npv                       0.991248      0

INFO:ode:
                          Training     Valdation
accuracy                  0.767438      0.760373
recall                    0.812500      0.798574
npv                       0.990534      0.989907
specificity               0.765679      0.758906
precision                 0.119188      0.112876
f1-score                  0.207881      0.197795
tp                        0.030516      0.029542
tn                        0.736922      0.730832
fp                        0.225520      0.232175
fn                        0.007042      0.007451
points_count            240.000000  25331.000000
odeint_weeks_per_visit    4.856548      3.850815
nfe_per_point             0.000000      0.000000
nfe_per_week              0.000000      0.000000
nfex1000                  0.000000      0.000000
INFO:ode:
    Training(pre)  Training(post)  Validation(pre)  Validation(post)
p0       0.020408        0.020408         0.018965          0.012536
p1       0.071429        0.119048         0.082249        

INFO:ode:
    Training(pre)  Training(post)  Validation(pre)  Validation(post)
p0       0.037037        0.037037         0.012858          0.014143
p1       0.178571        0.214286         0.118935          0.139053
p2       0.275862        0.344828         0.327345          0.299026
p3       0.320000        0.320000         0.458333          0.428977
p4       0.967742        1.000000         0.991119          0.985012
info retrieval:  62%|██████▏   | 911/1477 [10:30:58<3:37:56, 23.10s/it]                 INFO:ode:
                        Training    Validation
prejump_num_loss       13.108520   1059.991943
postjump_num_loss     436.391235  35106.246094
prejump_diag_loss       0.001430      0.049438
postjump_diag_loss      0.001460      0.048819
num_loss               55.436791   4464.617358
diag_loss               0.001433      0.049376
ode_loss                0.056868      4.513944
l1_loss             19842.220703  19842.220703
l1_loss_per_point      42.763407      0.783318
l2_loss 

info retrieval:  64%|██████▍   | 951/1477 [11:00:05<4:38:12, 31.74s/it]                 INFO:ode:
                        Training    Validation
prejump_num_loss       17.980234   1056.454224
postjump_num_loss     612.331848  34982.953125
prejump_diag_loss       0.001095      0.049765
postjump_diag_loss      0.001101      0.048395
num_loss               77.415396   4449.104114
diag_loss               0.001096      0.049628
ode_loss                0.078510      4.498683
l1_loss             19902.640625  19902.640625
l1_loss_per_point      56.864688      0.785703
l2_loss              2520.234619   2520.234619
l2_loss_per_point       7.200670      0.099492
dyn_loss                 11344.0      895228.0
dyn_loss_per_week      10.278023      9.177589
loss                 0.078926794     4.4990373
INFO:ode:
                          Training     Valdation
accuracy                  0.759977      0.763237
recall                    0.789744      0.805221
npv                       0.989130      

INFO:ode:
                          Training     Valdation
accuracy                  0.759585      0.758147
recall                    0.820144      0.809573
npv                       0.993411      0.990419
specificity               0.757893      0.756171
precision                 0.086495      0.113117
f1-score                  0.156486      0.198499
tp                        0.022300      0.029949
tn                        0.737285      0.728198
fp                        0.235524      0.234809
fn                        0.004890      0.007044
points_count            305.000000  25331.000000
odeint_weeks_per_visit    4.580328      3.850815
nfe_per_point             0.000000      0.000000
nfe_per_week              0.000000      0.000000
nfex1000                  0.000000      0.000000
INFO:ode:
    Training(pre)  Training(post)  Validation(pre)  Validation(post)
p0       0.064516        0.032258         0.028608          0.024751
p1       0.269231        0.230769         0.194083        

#### Possible modifications:
- Add more layers to the adjustment function
- Use days instead of weeks for odeint