In [1]:
import pandas as pd
import numpy as np, os
import matplotlib.pyplot as plt

from pathlib import Path
import yaml
import re

import gbd_mapping as gm
from vivarium import Artifact

from db_queries import get_ids, get_outputs, get_population, get_covariate_estimates
from get_draws.api import get_draws

import vivarium_helpers as vh
import vivarium_helpers.id_helper as idh
from vivarium_helpers.vph_output.operations import VPHOperator

# Record provenance: when, by whom, and from which directory this notebook was run
!date
!whoami
!pwd

Tue Oct 28 08:23:50 PDT 2025
lutzes
/mnt/share/homes/lutzes/vivarium_research_alzheimers


# Load Needed Data

In [2]:
# Project directory
# Change the working directory to the project root (the same path as `project_dir` defined below)
%cd /mnt/team/simulation_science/pub/models/vivarium_csu_alzheimers/

/mnt/team/simulation_science/pub/models/vivarium_csu_alzheimers


In [3]:
locations = [
    'United States of America',
    # 'Brazil',
    # 'China',
    # 'Germany',
    # 'Israel',
    # 'Japan',
    # 'Spain',
    # 'Sweden',
    # 'Taiwan (Province of China)',
    # 'United Kingdom',
]

# Shorter names to use for plotting; a location without an override below
# keeps its full name.
location_to_short_name = ({loc: loc for loc in locations} | {
    'Taiwan (Province of China)': 'Taiwan',
    'United Kingdom': 'UK',
    'United States of America': 'USA',
})

# Select a subset of locations to draw plots for
locations_to_plot = locations[:2]

project_dir = '/mnt/team/simulation_science/pub/models/vivarium_csu_alzheimers/'

model_number = '7.0' # Artifacts are stored here
run_subdirectories = [
    'results/model7.4/united_states_of_america/2025_10_24_16_02_54/',
]
run_dirs = [project_dir + run_subdir for run_subdir in run_subdirectories]
results_dirs = [run_dir + 'results/' for run_dir in run_dirs]

# Option 1: One artifact (and results directory) per location, paired in
# order. zip() silently drops unmatched entries, so fail loudly instead of
# quietly losing locations when the two lists get out of sync.
if len(results_dirs) != len(locations):
    raise ValueError(
        f'Expected one results directory per location, but got '
        f'{len(results_dirs)} directories for {len(locations)} locations'
    )
location_to_results_dir = dict(zip(locations, results_dirs))

# # Option 2: All locations in one artifact
# location_to_results_dir = {'all': results_dirs[0]}

# Artifact files are named after the location, lowercased with spaces -> underscores
location_to_artifact_subdir = {loc: loc.lower().replace(' ', '_') for loc in locations}
artifact_subpaths = [f'artifacts/model{model_number}/' + subdir + '.hdf' for subdir in location_to_artifact_subdir.values()]
location_to_artifact_path = {loc: project_dir + subpath for loc, subpath in zip(locations, artifact_subpaths)}
artifact_path_to_location = {path: loc for loc, path in location_to_artifact_path.items()}
artifact_path_to_location

{'/mnt/team/simulation_science/pub/models/vivarium_csu_alzheimers/artifacts/model7.0/united_states_of_america.hdf': 'United States of America'}

# Get list of draws and draw columns from `keyspace.yaml`, and reduce to a subset of draws to save memory and time

In [4]:
# The run's keyspace lists the input draws that were actually simulated
keyspace_path = run_dirs[0] + 'keyspace.yaml'
with open(keyspace_path) as handle:
    keyspace = yaml.safe_load(handle)
draws = keyspace['input_draw']
print(draws)

[457, 169, 323, 392, 346, 480, 258, 446, 356, 273, 158, 24, 265, 317, 177, 199, 2, 13, 46, 89, 41, 113, 262, 460, 499]


In [5]:
# NOTE: This subset ended up unused -- the analysis below uses all draws.
# Keep the first 10 draws from the keyspace, sorted, plus matching column names.
draws = sorted(draws[:10])
draw_cols = ['draw_' + str(draw) for draw in draws]
print(draw_cols)

['draw_169', 'draw_258', 'draw_273', 'draw_323', 'draw_346', 'draw_356', 'draw_392', 'draw_446', 'draw_457', 'draw_480']


# Load one artifact and define age bins

In [6]:
# The USA artifact supplies the age-bin definitions loaded in the next cell
usa_artifact_path = location_to_artifact_path['United States of America']
usa_art = Artifact(usa_artifact_path)

In [8]:
# age_bins is an empty DataFrame with a MultiIndex storing age group data;
# flatten it and add an underscore-separated 'age_group' label (e.g.
# '25_to_29') matching the labels used in the simulation output.
age_bins = usa_art.load('population.age_bins')
age_dictionary = age_bins.reset_index()
age_dictionary['age_group'] = age_dictionary['age_group_name'].str.replace(' ', '_')
# Filter to ages that actually appear in our sim
age_dictionary = age_dictionary[age_dictionary['age_start'] >= 25]

In [84]:
# Build the population table used to scale simulation counts up to
# real-world numbers. For each location it holds the artifact's
# 'population.structure' draws (one row per sex / age group / year) plus a
# single 'ratio' = simulated population size / real-world size of the
# population the simulation represents.
SIM_POPULATION_SIZE = 100_000  # NOTE(review): assumed simulated population size -- confirm against the model spec
SCALING_FACTOR_YEAR = 2023
POPULATION_YEAR = 2025
LAST_POPULATION_YEAR = 2050  # rows for this year are repeated for all later years
LAST_OUTPUT_YEAR = 2099

scale_frames = []
for location in locations:
    art = Artifact(location_to_artifact_path[location])
    # Loaded once and reused for both the ratio and the output table
    population_structure = art.load('population.structure')

    # Pair each stratum's scaling factor with its population, then take the
    # draw-mean of their product and sum over strata.
    # NOTE(review): scaling_factor is presumably a per-stratum prevalence
    # fraction, making `prev` a real-world prevalent count -- confirm.
    df_prev_pop = pd.merge(
        art.load('population.scaling_factor').query(f"year_start == {SCALING_FACTOR_YEAR}"),
        population_structure.query(f"year_start == {POPULATION_YEAR}").droplevel(['year_start', 'year_end']),
        left_index=True,
        right_index=True,
        suffixes=['_prev', '_pop'],
    )
    draw_data = df_prev_pop.filter(like='draw_')
    prev = (
        (draw_data.filter(like='_prev') * draw_data.filter(like='_pop').values)
        .mean(axis=1)
        .sum(axis=0)
    )
    # TODO: use draw-specific scale instead of mean

    ratio = SIM_POPULATION_SIZE / prev
    print(ratio)

    location_scale = (
        population_structure
        .reset_index()
        .assign(location=location, ratio=ratio)
        .rename(columns={'year_start': 'event_year'})
        .merge(age_dictionary, on=['age_start', 'age_end'])
    )
    # Hold population constant after LAST_POPULATION_YEAR by repeating that
    # year's rows for every later year -- built as a list and concatenated
    # once, rather than growing the frame inside the loop.
    last_year_rows = location_scale.loc[location_scale['event_year'] == LAST_POPULATION_YEAR]
    later_years = [
        last_year_rows.assign(event_year=year)
        for year in range(LAST_POPULATION_YEAR + 1, LAST_OUTPUT_YEAR + 1)
    ]
    scale_frames.append(pd.concat([location_scale, *later_years], ignore_index=True))

scale = pd.concat(scale_frames, ignore_index=True) if scale_frames else pd.DataFrame()
scale.head()

0.018283685923580616


Unnamed: 0,location,sex,age_start,age_end,event_year,year_end,draw_0,draw_1,draw_2,draw_3,...,draw_494,draw_495,draw_496,draw_497,draw_498,draw_499,ratio,age_group_id,age_group_name,age_group
0,United States of America,Female,25.0,30.0,2021,2022,11211250.0,11669210.0,11351830.0,10881110.0,...,10920460.0,11323230.0,11662140.0,11038690.0,10522820.0,10899720.0,0.018284,10,25 to 29,25_to_29
1,United States of America,Female,25.0,30.0,2022,2023,11211130.0,11660510.0,11357430.0,10881320.0,...,10930500.0,11325370.0,11652530.0,11041420.0,10530330.0,10909250.0,0.018284,10,25 to 29,25_to_29
2,United States of America,Female,25.0,30.0,2023,2024,11213620.0,11653960.0,11366900.0,10885900.0,...,10940210.0,11329190.0,11643920.0,11047920.0,10544490.0,10922530.0,0.018284,10,25 to 29,25_to_29
3,United States of America,Female,25.0,30.0,2024,2025,11223640.0,11659970.0,11388270.0,10901210.0,...,10961140.0,11341260.0,11644920.0,11065020.0,10568790.0,10948600.0,0.018284,10,25 to 29,25_to_29
4,United States of America,Female,25.0,30.0,2025,2026,11250620.0,11682560.0,11429710.0,10935410.0,...,10999470.0,11372270.0,11666910.0,11101230.0,10608270.0,10993980.0,0.018284,10,25 to 29,25_to_29


# Create VPHOperator object to perform operations on simulation output

In [10]:
ops = VPHOperator()
# Also treat 'location' as an index column, since load_sim_output adds a
# 'location' column to every output table it returns.
ops.index_cols.append('location')
# NOTE(review): this bare expression is not the last line of the cell, so
# it is never displayed; wrap it in display(...) if the intent is to show it.
ops.index_cols

# Module-level default index columns. The captured output shows
# ['input_draw', 'scenario'] without 'location', which suggests the append
# above modified a per-instance copy, not this list -- confirm if relied on.
vh.vph_output.operations.INDEX_COLUMNS

['input_draw', 'scenario']

# Define functions to load simulation output and summarize it for plotting

In [47]:
def load_sim_output(
        measure,
        results_dict=location_to_results_dir,
        artifact_path_to_location=artifact_path_to_location,
        drop_superfluous_cols=True, # drop redundant or empty columns
        force_categorical=True,
        aggregate_seeds=True,
        raw=False, # Overrides other parameters if True
        **kwargs, # keyword args to pass to .read_parquet
    ):
    """Load simulation output from .parquet files for all locations.

    Reads ``<directory>/<measure>.parquet`` for every entry of
    ``results_dict``, optionally reduces the size of each table, and
    returns the concatenated outputs with a 'location' column added.

    Parameters
    ----------
    measure : str
        Stem of the parquet file to read, e.g.
        'person_time_alzheimers_disease_and_other_dementias'.
    results_dict : dict
        Maps location name -> results directory. The special key 'all'
        means one directory holds every location; the location is then
        recovered from the 'artifact_path' column (if present) via
        ``artifact_path_to_location``.
    artifact_path_to_location : dict
        Maps artifact path -> location name; used only for the 'all' key.
    drop_superfluous_cols : bool
        Drop 'input_draw_number' / 'sub_entity' when they duplicate
        'input_draw' / 'entity', and drop columns that are entirely NaN.
    force_categorical : bool
        Convert every non-float column to categorical to save memory.
    aggregate_seeds : bool
        Sum values over 'random_seed' with this notebook's ``marginalize``.
    raw : bool
        If True, disables all three options above.
    **kwargs
        Passed through to ``pd.read_parquet``.

    Returns
    -------
    pd.DataFrame

    Raises
    ------
    ValueError
        If ``results_dict`` is empty.
    """
    if raw:
        drop_superfluous_cols = False
        force_categorical = False
        aggregate_seeds = False

    dfs = []
    for location, directory in results_dict.items():
        df = pd.read_parquet(Path(directory) / f'{measure}.parquet', **kwargs)
        if drop_superfluous_cols:
            # Drop the second column of each pair when it duplicates the first
            for col1, col2 in [
                ('input_draw', 'input_draw_number'),
                ('entity', 'sub_entity'),
            ]:
                if col1 in df and col2 in df and df[col1].equals(df[col2]):
                    df = df.drop(columns=col2)
            # Drop empty columns in one call, rather than dropping columns
            # while iterating over them
            df = df.dropna(axis='columns', how='all')
        if force_categorical:
            convert_to_categorical(df, inplace=True)
        if aggregate_seeds:
            # Use this notebook's marginalize (dropna=False) so rows whose
            # grouping columns contain NaN are summed rather than silently
            # dropped, as Vivarium Helpers' version would do.
            df = marginalize(df, 'random_seed')
        if location == 'all':
            if 'artifact_path' in df:
                df['location'] = df['artifact_path'].map(artifact_path_to_location)
        else:
            df['location'] = location
        dfs.append(df)
    if not dfs:
        raise ValueError('results_dict is empty; there is no simulation output to load')
    return pd.concat(dfs)

# TODO: Consider making certain columns ordered Categoricals
def convert_to_categorical(df, inplace=False):
    """Convert every column that is not float or already categorical to
    categorical dtype.

    Categorical columns use far less memory for repeated labels, allowing
    larger DataFrames to be loaded and manipulated. Float columns of any
    width (float32, float64, nullable Float64, ...) are left unchanged.

    Parameters
    ----------
    df : pd.DataFrame
    inplace : bool, default False
        If True, modify ``df`` and return None; otherwise return a
        converted copy and leave ``df`` untouched.
    """
    if not inplace:
        df = df.copy()
    for col in df:
        dtype = df[col].dtype
        # Use dtype predicates: a `dtype in ('float', ...)` test only
        # matches float64, so float32 columns would be made categorical.
        if not (pd.api.types.is_float_dtype(dtype)
                or isinstance(dtype, pd.CategoricalDtype)):
            df[col] = df[col].astype('category')
    return None if inplace else df

# NOTE: Differs from version in Vivarium Helpers in that here,
# dropna=False
def marginalize(
    df:pd.DataFrame,
    marginalized_cols,
    value_cols=None,
    reset_index=True,
    func='sum',
    args=(), # Positional args to pass to func in DataFrameGroupBy.agg
    **kwargs, # Keywords to pass to DataFrameGroupBy.agg
)->pd.DataFrame:
    """Aggregate ``value_cols`` of ``df`` over ``marginalized_cols``.

    Rows are grouped by every column that is neither marginalized nor a
    value column, and ``func`` (with ``args``/``kwargs``) is applied to the
    value columns within each group. Groups whose keys contain NaN are kept
    (``dropna=False``); ``observed=True`` limits categorical groupers to
    combinations that actually occur. ``marginalized_cols`` may name index
    levels as well as columns. If ``reset_index`` is False, the grouping
    columns form the index of the result.
    """
    if value_cols is None:
        value_cols = vh.vph_output.operations.value_col
    cols_to_remove = vh.utils._ensure_iterable(marginalized_cols)
    value_list = vh.utils._ensure_iterable(value_cols)
    # Index levels named in cols_to_remove are moved into columns so they
    # can be excluded from the grouping like ordinary columns
    data = vh.utils._ensure_columns_not_levels(df, cols_to_remove)
    excluded = [*cols_to_remove, *value_list]
    # groupby needs a plain list, not an Index
    groupers = data.columns.difference(excluded).to_list()
    grouped = data.groupby(
        groupers,
        as_index=not reset_index,
        observed=True,  # needed for Categorical data
        dropna=False,
    )
    return grouped[value_list].agg(func, *args, **kwargs)

def summarize_sim_data(df, age_dictionary=age_dictionary):
    """Summarize simulation data for plotting.

    When ``df`` has an 'age_group' column, it is first merged with
    ``age_dictionary`` on 'age_group' to attach age-bin columns such as
    'age_start'. Returns ``ops.describe`` of the data, with the '2.5%' and
    '97.5%' columns renamed to 'lower' and 'upper' to match the artifact.
    """
    data = df.merge(age_dictionary, on='age_group') if 'age_group' in df else df
    described = ops.describe(data)
    return described.rename(columns={'2.5%': 'lower', '97.5%': 'upper'})


In [85]:
# Output labels for the simulation's scenario and disease-state categories
_SCENARIO_LABELS = {
    'baseline': 'Reference',
    'bbbm_testing': 'BBBM Testing Only',
    'bbbm_testing_and_treatment' : 'BBBM Testing and Treatment',
}
_SEVERITY_LABELS = {
    'alzheimers_blood_based_biomarker_state': 'Preclinical AD',
    'alzheimers_mild_cognitive_impairment_state': 'MCI due to AD',
    'alzheimers_disease_state' : 'Clinical AD',
}
# Columns identifying one population stratum in the scale table
_STRATUM_COLS = ['location', 'sex', 'age_group', 'event_year']
# Identifying columns of the final summary table
_OUTPUT_GROUP_COLS = ['Year ID', 'Location', 'Age', 'Sex', 'Disease Severity',
                      'Scenario', 'Measure', 'Metric']


def _population_by_draw(scale_df):
    """Reshape the wide 'draw_<n>' population columns of ``scale_df`` into
    long form: stratum columns plus integer 'input_draw' and 'population'."""
    draw_columns = [col for col in scale_df.columns if str(col).startswith('draw_')]
    population = scale_df.melt(
        id_vars=_STRATUM_COLS,
        value_vars=draw_columns,
        var_name='input_draw',
        value_name='population',
    )
    population['input_draw'] = (
        population['input_draw'].str.replace('draw_', '', regex=False).astype(int))
    return population


def dataframe_beutification_and_summarizing(df, measure_name, scale_df=None):
    """Scale simulation output to real-world values and summarize over draws.

    Steps:
      1. Divide 'value' by the stratum's 'ratio' from ``scale_df`` (inner
         join on location / sex / age_group / event_year) -> 'Number'.
      2. Divide each Number by the draw-matched population of the same
         stratum (the 'draw_<n>' columns of ``scale_df``, built from the
         artifact's 'population.structure') and multiply by 100,000
         -> 'Rate per 100,000'. Strata/draws without a population get NaN.
      3. Rename columns and scenario / disease-state categories to their
         presentation labels.
      4. Sum each draw over the remaining columns (e.g. 'treatment'), then
         report the mean and 95% UI (2.5th / 97.5th percentiles) across draws.

    Parameters
    ----------
    df : pd.DataFrame
        Simulation output as returned by ``load_sim_output``; must have
        'event_year', 'location', 'sex', 'age_group', 'scenario',
        'sub_entity', 'input_draw' and 'value' columns, with 'scenario'
        and 'sub_entity' categorical. Not modified.
    measure_name : str
        Written to the 'Measure' column.
    scale_df : pd.DataFrame, optional
        Population/ratio table; defaults to the notebook's global ``scale``.

    Returns
    -------
    pd.DataFrame
        One row per Year ID / Location / Age / Sex / Disease Severity /
        Scenario / Measure / Metric, with 'Mean', '95% UI Lower' and
        '95% UI Upper' columns.
    """
    if scale_df is None:
        scale_df = scale

    # Build new frames instead of assigning into `df`, so the caller's
    # DataFrame is left unmodified.
    df = df.assign(
        event_year=df['event_year'].astype(int),
        input_draw=df['input_draw'].astype(int),
    )

    # Add in the scale factor multiplication
    df = df.merge(scale_df[_STRATUM_COLS + ['ratio']], on=_STRATUM_COLS)
    numbers = df.assign(value=df['value'] / df['ratio'], Metric='Number')

    # A rate per 100,000 needs the stratum population as its denominator.
    # Left join, so a missing population shows up as NaN rather than
    # silently dropping counts.
    numbers = numbers.merge(
        _population_by_draw(scale_df), on=_STRATUM_COLS + ['input_draw'], how='left')
    rates = numbers.assign(
        value=numbers['value'] / numbers['population'] * 100_000,
        Metric='Rate per 100,000',
    )
    df = pd.concat([numbers, rates], ignore_index=True)

    # Renaming columns and relabelling categories for presentation
    df = df.rename(columns={'event_year': 'Year ID',
                            'age_group': 'Age',
                            'location': 'Location',
                            'sex': 'Sex',
                            'scenario': 'Scenario',
                            'sub_entity': 'Disease Severity'})
    df['Measure'] = measure_name
    df['Scenario'] = df['Scenario'].cat.rename_categories(_SCENARIO_LABELS)
    df['Disease Severity'] = df['Disease Severity'].cat.rename_categories(_SEVERITY_LABELS)

    # observed=True restricts categorical groupers to combinations present
    # in the data (no spurious all-zero groups). min_count=1 keeps an
    # all-NaN group (e.g. missing population) NaN instead of 0.
    draw_totals = (
        df.groupby(_OUTPUT_GROUP_COLS + ['input_draw'], observed=True)['value']
        .sum(min_count=1)
        .reset_index()
    )
    summary = (
        draw_totals.groupby(_OUTPUT_GROUP_COLS, observed=True)['value']
        .describe(percentiles=[0.025, 0.975])
        .reset_index()
        .rename(columns={'mean': 'Mean',
                         '2.5%': '95% UI Lower',
                         '97.5%': '95% UI Upper'})
    )
    return summary[_OUTPUT_GROUP_COLS + ['Mean', '95% UI Lower', '95% UI Upper']]

In [86]:
# Person-time by disease state for every location, summed over random seeds
prevalence = load_sim_output('person_time_alzheimers_disease_and_other_dementias')
prevalence.head()

Unnamed: 0,age_group,entity,entity_type,event_year,input_draw,measure,scenario,sex,sub_entity,treatment,value,location
0,25_to_29,alzheimers_disease_and_other_dementias,cause,2025,499,person_time,baseline,Female,alzheimers_blood_based_biomarker_state,susceptible_to_treatment,0.498289,United States of America
1,25_to_29,alzheimers_disease_and_other_dementias,cause,2025,499,person_time,baseline,Female,alzheimers_blood_based_biomarker_state,waiting_for_treatment,0.0,United States of America
2,25_to_29,alzheimers_disease_and_other_dementias,cause,2025,499,person_time,baseline,Female,alzheimers_blood_based_biomarker_state,full_effect_long,0.0,United States of America
3,25_to_29,alzheimers_disease_and_other_dementias,cause,2025,499,person_time,baseline,Female,alzheimers_blood_based_biomarker_state,full_effect_short,0.0,United States of America
4,25_to_29,alzheimers_disease_and_other_dementias,cause,2025,499,person_time,baseline,Female,alzheimers_blood_based_biomarker_state,waning_effect_long,0.0,United States of America


In [92]:
prevalence_final = dataframe_beutification_and_summarizing(df=prevalence, measure_name='Prevalent person-time')

In [93]:
# Write the summary table. index=False keeps the meaningless RangeIndex out
# of the CSV, where it would otherwise appear as an unnamed first column.
prevalence_final.to_csv(
    '/ihme/homes/lutzes/vivarium_research_alzheimers/2025_10_28_prevalence_final.csv',
    index=False,
)