In [None]:
import pandas as pd, numpy as np
from pathlib import Path
from vivarium import Artifact
import db_queries
from get_draws.api import get_draws
# import matplotlib.pyplot as 

from vivarium_testing_utils.automated_validation import ValidationContext

from vivarium_gates_mncnh.validation.measures import NeonatalCauseSpecificMortalityRisk

In [None]:
import warnings
warnings.filterwarnings('ignore', category=FutureWarning)

In [None]:
base_results_dir = Path('/mnt/team/simulation_science/pub/models/vivarium_gates_mncnh/results/model21.0.1/')

# Discover one results directory per location. Each location subdirectory holds one
# or more run directories; the lexicographically last name is used.
# Assumes run directory names are timestamps that sort chronologically — TODO confirm.
locations = []
results_dirs = []
# Sort (iterdir order is arbitrary) and skip stray files so the choice is deterministic.
for location_path in sorted(p for p in base_results_dir.iterdir() if p.is_dir()):
    location = location_path.name
    timestamps = sorted(p.name for p in location_path.iterdir() if p.is_dir())
    if not timestamps:
        # An empty location directory would otherwise crash on timestamps[-1].
        print(f'No runs found for {location}, skipping')
        continue
    last_timestamp = timestamps[-1]
    if len(timestamps) > 1:
        print(f'Multiple timestamps: {timestamps}, using {last_timestamp}')

    # Append both together so `locations` and `results_dirs` stay aligned.
    locations.append(location)
    # Stored relative to base_results_dir; joined back onto it when a full path is needed.
    results_dirs.append(str(Path(location) / last_timestamp / 'results/'))

# Look up GBD location metadata for the discovered locations by title-casing the
# directory names to match `location_name`.
# NOTE(review): names containing underscores would not match after .title() — confirm
# directory naming if multi-word locations are added.
location_ids = db_queries.get_ids('location')
location_ids = location_ids.loc[location_ids.location_name.isin([x.title() for x in locations])]

# Map each location name to its results directory (relative to base_results_dir).
results_dict = dict(zip(locations, results_dirs))
results_dict

In [None]:
# Location whose run is validated below.
VALIDATION_LOCATION = "ethiopia"

# Full path to that location's results. `base_results_dir` is a Path, so join with `/`
# (Path + str raises TypeError); `results_dict` holds the directories relative to it.
results_path = base_results_dir / results_dict[VALIDATION_LOCATION]

In [None]:
# Create the validation context for the selected run.
# `scenario_columns` takes a collection of column names. ("scenario") is just the
# string "scenario", so the trailing comma is required to make a one-element tuple.
vc = ValidationContext(results_path, scenario_columns=("scenario",))

In [None]:
# List the simulation outputs available in this run.
sim_outputs = vc.get_sim_outputs()
sim_outputs

In [None]:
# All keys available in the model artifact.
artifact_keys = vc.get_artifact_keys()
artifact_keys

In [None]:
# Keep only the artifact keys whose name contains "mortality_risk".
list(filter(lambda key: "mortality_risk" in key, vc.get_artifact_keys()))

In [None]:
# Artifact key for neonatal preterm-birth mortality risk, assembled from its
# dot-separated parts.
# NOTE(review): "comare" is a typo for "compare"; the name is kept because later
# cells reference it.
preterm_birth_comare_key = ".".join(("cause", "neonatal_preterm_birth", "mortality_risk"))

In [None]:
# Register the custom NeonatalCauseSpecificMortalityRisk measure class for this key.
# This lets the ValidationContext use measure classes that are not among the
# standard measure classes shipped with VTU (vivarium_testing_utils).
vc.add_new_measure(preterm_birth_comare_key, NeonatalCauseSpecificMortalityRisk)

In [None]:
# Compare the simulation outputs (test source) against the artifact data
# (reference source) for the preterm-birth mortality risk key.
vc.add_comparison(preterm_birth_comare_key, ref_source="artifact", test_source="sim")

In [None]:
# Metadata for the preterm-birth comparison, bound to a name and displayed in one step.
(preterm_csmrisk_metadata := vc.metadata(preterm_birth_comare_key))

In [None]:
# Comparison frame for the preterm-birth mortality risk key. The original cell
# called vc.get_frame(vc.get_frame( ... ) with unbalanced parentheses (a
# SyntaxError); a single call is what was intended.
# `aggregate_draws=True` presumably collapses the draw dimension — confirm against VTU docs.
preterm_csmrisk_frame = vc.get_frame(
    preterm_birth_comare_key,
    aggregate_draws=True,
    # stratifications=[],
)
preterm_csmrisk_frame