In [1]:
%load_ext autoreload
%autoreload 2

import sys
from collections import OrderedDict

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Extend the import search path with the directory three levels up so that the
# `lib.*` imports below can be resolved.
# NOTE(review): presumably this is the repository root relative to the
# notebook's working directory — confirm if the notebook is moved.
sys.path.append("../../..")

import lib.ehr.example_datasets.mimiciv_aki as m4aki
from lib.ehr.tvx_ehr import TVxEHR
from lib.utils import modified_environ, write_config


In [8]:
# Build the database connection URL and export the dataset configurations.
#
# Connection settings are taken from environment variables so that credentials
# are not stored in the notebook. The non-secret settings fall back to the
# values previously used here; the password has no default and is prompted for
# interactively when MIMICIV_DB_PASSWORD is not set.
import os
from getpass import getpass
from urllib.parse import quote_plus

sqluser = os.environ.get('MIMICIV_DB_USER', 'postgres')
dbname = os.environ.get('MIMICIV_DB_NAME', 'mimiciv')
hostname = os.environ.get('MIMICIV_DB_HOST', 'localhost')
port = os.environ.get('MIMICIV_DB_PORT', '5432')
password = os.environ.get('MIMICIV_DB_PASSWORD') or getpass(f'Database password for {sqluser}@{hostname}: ')

# User name and password are URL-encoded so that special characters such as
# '@', ':' or '/' cannot corrupt the connection URL.
url = f'postgresql+psycopg2://{quote_plus(sqluser)}:{quote_plus(password)}@{hostname}:{port}/{dbname}'

# The configuration objects are instantiated while MIMICIV_URL is temporarily
# set in the environment.
# NOTE(review): presumably the config constructors read MIMICIV_URL; if the URL
# ends up in `to_dict()`, the JSON files written below contain the password —
# verify before sharing them.
with modified_environ(MIMICIV_URL=url):
    mimiciv_config = m4aki.AKIMIMICIVDatasetConfig()
    tvx_config = m4aki.TVxAKIMIMICIVDatasetConfig()

    write_config(mimiciv_config.to_dict(), 'dataset_mimiciv_aki_config.json')
    write_config(tvx_config.to_dict(), 'tvx_mimiciv_aki_config.json')


In [9]:
# with modified_environ(MIMICIV_URL=url):
#     dataset = m4aki.AKIMIMICIVDataset(config=mimiciv_config)           
# dataset = dataset.execute_pipeline()
# tvx = m4aki.TVxAKIMIMICIVDataset(config=tvx_config, dataset=dataset)
# tvx = tvx.execute_pipeline()

In [4]:
# tvx.dataset.pipeline_report.to_csv('dataset_pipeline_report.csv')
# tvx.pipeline_report.to_csv('tvx_pipeline_report.csv')
# tvx.save('tvx_aki.h5', True)

In [3]:
import os

# Location of the saved TVx AKI dataset file. Defaults to the path originally
# used in this notebook; set the TVX_AKI_PATH environment variable to load the
# file from another location without editing the notebook.
tvx_aki_path = os.environ.get('TVX_AKI_PATH', '/home/asem/GP/ehr-data/mimic4aki-cohort/tvx_aki_tb6.h5')
tvx = m4aki.TVxAKIMIMICIVDataset.load(tvx_aki_path)

In [8]:
# NOTE(review): prints a constant string only; looks like leftover debug output
# (e.g. a check that the kernel is responsive) — candidate for removal.
print('x')

In [5]:
# len(tvx.subjects)

In [6]:
# Display the `splits` setting stored in the loaded dataset's configuration.
tvx.config.splits

In [6]:
# from lib.ehr.tvx_transformations import TrainingSplitGroups
# tvx_list = TrainingSplitGroups()(tvx, n_groups=10, seed=0)

In [7]:
# for i, tvx_item in enumerate(tvx_list):
#     tvx_item.save(f'tvx_aki_training_groups/tvx_aki_{i}.h5', True)

In [11]:
# Load the dataset file with index 0 from the `tvx_aki_training_groups`
# directory (the commented-out loop above saves files named `tvx_aki_{i}.h5`
# there). The path is relative to the notebook's working directory.
tvx0 = m4aki.TVxAKIMIMICIVDataset.load('tvx_aki_training_groups/tvx_aki_0.h5')

In [9]:
# Display the number of entries in `tvx0.subjects`.
len(tvx0.subjects)

In [15]:
from lib.ml.embeddings import InICENODEEmbeddingsConfig, InterventionsEmbeddingsConfig
from lib.ml.in_models import InpatientModelConfig, ICENODEConfig, InICENODELite, GRUODEBayes
from lib.ml.model import Precomputes


In [16]:
# Embedding sizes for the intervention inputs, built separately so the outer
# embeddings configuration stays readable.
interventions_emb_config = InterventionsEmbeddingsConfig(
    icu_inputs=10,
    icu_procedures=10,
    hosp_procedures=10,
    interventions=20,
)

# Embeddings configuration passed to the model constructor.
emb_config = InICENODEEmbeddingsConfig(
    dx_codes=50,
    demographic=10,
    interventions=interventions_emb_config,
)

# Model configuration: state size 50 with the 'monotonic' lead predictor.
model_config = ICENODEConfig(state=50, lead_predictor='monotonic')

In [17]:
import jax.random as jrandom

# Instantiate the model from the loaded dataset together with the model and
# embeddings configurations. The InICENODELite alternative is kept disabled:
# model = InICENODELite.from_tvx_ehr(tvx_ehr=tvx0, config=model_config, embeddings_config=emb_config)
model = GRUODEBayes.from_tvx_ehr(
    tvx_ehr=tvx0,
    config=model_config,
    embeddings_config=emb_config,
)


In [13]:
# Display the `leading_observable` setting of the loaded dataset's configuration.
tvx0.config.leading_observable

In [14]:
# Display the first admission (index 0) of the subject keyed '10002760'.
tvx0.subjects['10002760'].admissions[0]

In [15]:
# Take the first admission of subject '10002760', look up the demographics
# entry stored under its admission id, and pass both to the model's `f_emb`.
subject = tvx0.subjects['10002760']
adm = subject.admissions[0]
adm_demographics = tvx0.admission_demographics[adm.admission_id]
admission_emb = model.f_emb(adm, adm_demographics)

In [16]:
# Display the `mask` attribute of the selected admission's leading observable.
adm.leading_observable.mask

In [17]:
# Run the model on the selected admission and its embedding, passing a
# default-constructed `Precomputes` instance.
out = model(
    admission=adm,
    embedded_admission=admission_emb,
    precomputes=Precomputes(),
)

In [18]:
# Display the `value` attribute of the leading observable in the model output.
out.leading_observable.value

In [22]:
from lib.ml.experiment import Experiment, ExperimentConfig
from lib.ml.trainer import (LossMixer, OptimizerConfig, ReportingConfig,
                            Trainer, TrainerConfig)

# Optimizer with default settings, used by the trainer configuration below.
opt = OptimizerConfig()

# Trainer configuration using the 'mse' loss for both the observation and the
# lead terms.
trainer_config = TrainerConfig(obs_loss='mse', lead_loss='mse', optimizer=opt)

# Reporting configuration with output directory 'test'; model statistics are
# switched off, the remaining flags are switched on.
reporting_config = ReportingConfig(
    output_dir='test',
    console=True,
    parameter_snapshots=True,
    config_json=True,
    model_stats=False,
)

# Loss mixer with default settings.
loss_mix = LossMixer()


In [23]:
# from lib.ml.experiment import Experiment, ExperimentConfig
# from lib.ml.trainer import ProbTrainer, ProbTrainerConfig, ProbLossMixer, OptimizerConfig, ReportingConfig
# opt = OptimizerConfig()
# trainer_config=ProbTrainerConfig(prob_obs_loss='log_normal', prob_adjusted_obs_loss='kl_gaussian', lead_loss='mse', optimizer=opt)
# reporting_config = ReportingConfig(output_dir='test',
#                                    console=True,
#                                    parameter_snapshots=True,
#                                    config_json=True,
#                                    model_stats=False)
# loss_mix = ProbLossMixer()


In [24]:
# Assemble the experiment configuration from the model, embeddings, trainer,
# reporting and loss-mixer objects defined in the earlier cells.
# NOTE(review): `trainer_classname` is 'ProbTrainer', while the active cell
# above builds `TrainerConfig`/`LossMixer` (the `ProbTrainerConfig`/
# `ProbLossMixer` variant is commented out) — confirm this pairing is intended.
experiment_config = ExperimentConfig(
    model=model_config,
    embeddings=emb_config,
    trainer=trainer_config,
    model_classname='GRUODEBayes',
    trainer_classname='ProbTrainer',
    reporting=reporting_config,
    model_snapshot_frequency=10,
    continue_training=True,
    loss_mixer=loss_mix,
)
                                     
               
               

In [25]:
from lib.utils import write_config

# Convert the experiment configuration to a dictionary and write it to
# 'prob_config_template.json' in the current working directory.
experiment_config_dict = experiment_config.to_dict()
write_config(experiment_config_dict, 'prob_config_template.json')

In [6]:
import equinox as eqx
import jax.numpy as jnp


def _sum_present(a, b, c):
    """Return `a + b`, additionally adding `c` when it is not None.

    The original expression `a + b + c` raises `TypeError` when `c` is None,
    because None cannot be added to an array.
    """
    total = a + b
    return total if c is None else total + c


# Scratch check of `eqx.filter_vmap` with a None argument: the two arrays are
# mapped over, and the third argument is passed as None.
eqx.filter_vmap(_sum_present)(jnp.arange(10), jnp.arange(10), None)