In [2]:
%load_ext autoreload
%autoreload 2

In [None]:
import pandas as pd
import numpy as np, os
import matplotlib.pyplot as plt

from pathlib import Path
import yaml
import re
import datetime

import pyarrow as pa
import pyarrow.parquet as pq

import gbd_mapping as gm
from vivarium import Artifact

from db_queries import get_ids, get_outputs, get_population, get_covariate_estimates
from get_draws.api import get_draws

import vivarium_helpers as vh
import vivarium_helpers.id_helper as idh
from vivarium_helpers.vph_output.operations import VPHOperator
from vivarium_helpers.vph_output.measures import VPHResults
from vivarium_helpers.utils import convert_to_categorical, constant_categorical, print_memory_usage

!date
!whoami
!pwd

Fri Oct 31 11:31:28 PDT 2025
ndbs
/mnt/share/code/ndbs/vivarium_research_alzheimers/results_tables


# Find data

In [4]:
# Project directory
%cd /mnt/team/simulation_science/pub/models/vivarium_csu_alzheimers/

/mnt/team/simulation_science/pub/models/vivarium_csu_alzheimers


In [5]:
# Results directory containing model 7.6 results for all locations
model_run_subdir = 'results/abie_consistent_model_test/united_states_of_america/2025_10_28_08_55_05/'
!ls -halt $model_run_subdir/results

total 60M
drwxrwsr-x 5 abie IHME-Simulationscience 4.5K Oct 28 09:19 ..
drwxrwsr-x 2 abie IHME-Simulationscience 6.5K Oct 28 09:17 .
-rw-rw-r-- 1 abie IHME-Simulationscience  17M Oct 28 09:17 ylds.parquet
-rw-rw-r-- 1 abie IHME-Simulationscience 8.2M Oct 28 09:17 ylls.parquet
-rw-rw-r-- 1 abie IHME-Simulationscience 5.2M Oct 28 09:17 person_time_treatment.parquet
-rw-rw-r-- 1 abie IHME-Simulationscience 512K Oct 28 09:17 counts_newly_eligible_for_bbbm_testing.parquet
-rw-rw-r-- 1 abie IHME-Simulationscience 794K Oct 28 09:17 person_time_eligible_for_bbbm_testing.parquet
-rw-rw-r-- 1 abie IHME-Simulationscience 1.4M Oct 28 09:17 deaths.parquet
-rw-rw-r-- 1 abie IHME-Simulationscience 4.1M Oct 28 09:17 person_time_ever_eligible_for_bbbm_testing.parquet
-rw-rw-r-- 1 abie IHME-Simulationscience 841K Oct 28 09:17 counts_new_simulants.parquet
-rw-rw-r-- 1 abie IHME-Simulationscience 406K Oct 28 09:17 counts_bbbm_tests.parquet
-rw-rw-r-- 1 abie IHME-Simulationscience 2.4M Oct 28 09:17 counts_

In [6]:
!ls artifacts

2		  model1.0  model3.0  model4.1	model4.4  model6.0
basic_model	  model2.0  model3.1  model4.2	model4.5  model7.0
consistent-rates  model2.2  model4.0  model4.3	model5.0  model8.3


In [7]:
# This is where results will eventually be
!ls results/model8.4/model_spec

2025_10_29_20_39_18  2025_10_30_14_03_51  2025_10_31_01_03_40
2025_10_29_20_41_39  2025_10_30_16_32_03  2025_10_31_01_09_31
2025_10_29_20_45_13  2025_10_30_17_25_38


# Define directories

Output directory:

`J:\Project\simulation_science\alzheimers\results_10_31_2025`

In [74]:
r"J:\Project\simulation_science\alzheimers\results_10_31_2025".replace('\\', '/').replace('J:', '/snfs1')

'/snfs1/Project/simulation_science/alzheimers/results_10_31_2025'

In [75]:
output_dir = Path('/snfs1/Project/simulation_science/alzheimers/results_10_31_2025')
output_dir.exists()

True

In [8]:
# Locations included in this analysis, by full name
locations = [
    'United States of America',
    'Brazil',
    'China',
    'Germany',
    'Israel',
    'Japan',
    'Spain',
    'Sweden',
    'Taiwan (Province of China)',
    'United Kingdom',
]

# Shorter names to use for plotting: a location keeps its full name
# unless an abbreviation is listed here
short_name_overrides = {
    'Taiwan (Province of China)': 'Taiwan',
    'United Kingdom': 'UK',
    'United States of America': 'USA',
}
location_to_short_name = {
    loc: short_name_overrides.get(loc, loc) for loc in locations}

# Only draw plots for the first two locations
locations_to_plot = locations[:2]

project_dir = '/mnt/team/simulation_science/pub/models/vivarium_csu_alzheimers/'

# Model number of the artifacts to read (artifacts/model<number>/)
artifact_model_number = '8.3'

# Absolute run directories and their 'results/' subdirectories
run_subdirectories = [
    model_run_subdir,
]
run_dirs = [f'{project_dir}{run_subdir}' for run_subdir in run_subdirectories]
results_dirs = [f'{run_dir}results/' for run_dir in run_dirs]

# Option 1 (disabled): one results directory per location
# location_to_results_dir = {
#     loc: path for loc, path in zip(locations, results_dirs)}

# Option 2: all locations in one results directory
location_to_results_dir = {'all': results_dirs[0]}

# Artifact file names are the lowercased location names with spaces
# replaced by underscores
location_to_artifact_subdir = {
    loc: '_'.join(loc.lower().split(' ')) for loc in locations}
artifact_subpaths = [
    f'artifacts/model{artifact_model_number}/{subdir}.hdf'
    for subdir in location_to_artifact_subdir.values()
]
location_to_artifact_path = dict(
    zip(locations, (project_dir + subpath for subpath in artifact_subpaths)))

# Inverse lookup: artifact path -> location name
artifact_path_to_location = dict(
    zip(location_to_artifact_path.values(), location_to_artifact_path.keys()))
artifact_path_to_location

{'/mnt/team/simulation_science/pub/models/vivarium_csu_alzheimers/artifacts/model8.3/united_states_of_america.hdf': 'United States of America',
 '/mnt/team/simulation_science/pub/models/vivarium_csu_alzheimers/artifacts/model8.3/brazil.hdf': 'Brazil',
 '/mnt/team/simulation_science/pub/models/vivarium_csu_alzheimers/artifacts/model8.3/china.hdf': 'China',
 '/mnt/team/simulation_science/pub/models/vivarium_csu_alzheimers/artifacts/model8.3/germany.hdf': 'Germany',
 '/mnt/team/simulation_science/pub/models/vivarium_csu_alzheimers/artifacts/model8.3/israel.hdf': 'Israel',
 '/mnt/team/simulation_science/pub/models/vivarium_csu_alzheimers/artifacts/model8.3/japan.hdf': 'Japan',
 '/mnt/team/simulation_science/pub/models/vivarium_csu_alzheimers/artifacts/model8.3/spain.hdf': 'Spain',
 '/mnt/team/simulation_science/pub/models/vivarium_csu_alzheimers/artifacts/model8.3/sweden.hdf': 'Sweden',
 '/mnt/team/simulation_science/pub/models/vivarium_csu_alzheimers/artifacts/model8.3/taiwan_(province_of

# Define some ordered Categorical dtypes, and convert years to ints

In [9]:
# All modeled locations. The list is in project order; the categorical
# dtype below sorts the names so that locations compare alphabetically.
all_locations = [
    'United States of America',
    'Brazil',
    'China',
    'Germany',
    'Israel',
    'Japan',
    'Spain',
    'Sweden',
    'Taiwan (Province of China)',
    'United Kingdom',
]
all_locations_dtype = pd.CategoricalDtype(
    categories=sorted(all_locations), ordered=True)

# Years 2025-2100 do not fit in uint8 (0 to 255), so use int16
# (-32768 to 32767).
year_dtype = 'int16'

# Draws are stored as plain ints rather than categoricals because
# different draws from different results directories get concatenated
input_draw_dtype = 'int16'

# Age groups in chronological order: five-year bins from 25_to_29
# through 90_to_94, followed by the open-ended 95_plus group
age_groups = [f'{start}_to_{start + 4}' for start in range(25, 95, 5)]
age_groups.append('95_plus')
age_group_dtype = pd.CategoricalDtype(categories=age_groups, ordered=True)

# Scenarios ordered by complexity
scenarios = ['baseline', 'bbbm_testing', 'bbbm_testing_and_treatment']
scenario_dtype = pd.CategoricalDtype(categories=scenarios, ordered=True)

# Column name -> dtype to apply when loading simulation output
colname_to_dtype = dict(
    location=all_locations_dtype,
    event_year=year_dtype,
    age_group=age_group_dtype,
    scenario=scenario_dtype,
    input_draw=input_draw_dtype,
)

# Load one artifact to define age bins

In [10]:
# Open the artifact for the first location and print the location(s)
# recorded in its metadata
artifact_path = location_to_artifact_path[locations[0]]
art = Artifact(artifact_path)
print(art.load('metadata.locations'))

# age_bins is an empty DataFrame with a MultiIndex storing age group data
age_bins = art.load('population.age_bins')

# Move the index levels into columns and keep only the ages that
# actually appear in our sim
age_map = age_bins.reset_index().query("age_start >= 25")
# Add an 'age_group' column: the age group name with spaces replaced by
# underscores (e.g. '75 to 79' -> '75_to_79')
age_map = age_map.assign(
    age_group=age_map['age_group_name'].str.replace(' ', '_'))
age_map.tail()

['United States of America']


Unnamed: 0,age_group_id,age_group_name,age_start,age_end,age_group
14,20,75 to 79,75.0,80.0,75_to_79
15,30,80 to 84,80.0,85.0,80_to_84
16,31,85 to 89,85.0,90.0,85_to_89
17,32,90 to 94,90.0,95.0,90_to_94
18,235,95 plus,95.0,125.0,95_plus


In [11]:
ps = art.load('population.structure')
ps

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,Unnamed: 4_level_0,Unnamed: 5_level_0,draw_0,draw_1,draw_2,draw_3,draw_4,draw_5,draw_6,draw_7,draw_8,draw_9,...,draw_490,draw_491,draw_492,draw_493,draw_494,draw_495,draw_496,draw_497,draw_498,draw_499
location,sex,age_start,age_end,year_start,year_end,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1,Unnamed: 25_level_1,Unnamed: 26_level_1
United States of America,Female,5.0,10.0,2021,2022,9.732314e+06,1.012986e+07,9.854349e+06,9.445721e+06,9.566589e+06,9.327772e+06,9.296157e+06,9.627733e+06,9.700644e+06,1.013566e+07,...,1.078834e+07,8.907237e+06,1.018388e+07,9.617759e+06,9.479886e+06,9.829519e+06,1.012373e+07,9.582515e+06,9.134700e+06,9.461879e+06
United States of America,Female,5.0,10.0,2022,2023,9.641272e+06,1.003201e+07,9.764145e+06,9.357441e+06,9.481215e+06,9.241312e+06,9.207379e+06,9.539936e+06,9.612663e+06,1.003874e+07,...,1.068397e+07,8.825972e+06,1.008576e+07,9.530124e+06,9.394713e+06,9.738288e+06,1.002557e+07,9.493852e+06,9.051875e+06,9.376733e+06
United States of America,Female,5.0,10.0,2023,2024,9.548728e+06,9.932675e+06,9.672565e+06,9.268302e+06,9.393223e+06,9.153950e+06,9.117786e+06,9.450101e+06,9.522796e+06,9.939678e+06,...,1.057841e+07,8.743389e+06,9.986004e+06,9.442052e+06,9.306642e+06,9.645044e+06,9.925524e+06,9.403961e+06,8.968852e+06,9.290161e+06
United States of America,Female,5.0,10.0,2024,2025,9.462539e+06,9.841851e+06,9.588970e+06,9.186573e+06,9.313801e+06,9.073096e+06,9.034774e+06,9.368394e+06,9.439802e+06,9.848900e+06,...,1.048223e+07,8.668214e+06,9.895065e+06,9.361798e+06,9.226222e+06,9.558470e+06,9.832993e+06,9.321446e+06,8.892593e+06,9.211593e+06
United States of America,Female,5.0,10.0,2025,2026,9.394308e+06,9.768750e+06,9.524369e+06,9.122883e+06,9.253380e+06,9.009171e+06,8.969552e+06,9.305595e+06,9.377230e+06,9.776800e+06,...,1.040688e+07,8.610533e+06,9.823115e+06,9.299817e+06,9.163646e+06,9.490584e+06,9.759877e+06,9.257314e+06,8.832906e+06,9.151367e+06
United States of America,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
United States of America,Male,95.0,125.0,2046,2047,4.776578e+05,5.406595e+05,5.049817e+05,5.222527e+05,4.155201e+05,5.400147e+05,5.188244e+05,5.308816e+05,4.846771e+05,5.354384e+05,...,5.876757e+05,5.112679e+05,6.150060e+05,5.127036e+05,4.985853e+05,5.227748e+05,5.281299e+05,6.701626e+05,5.409489e+05,5.437706e+05
United States of America,Male,95.0,125.0,2047,2048,4.897427e+05,5.562757e+05,5.190315e+05,5.397755e+05,4.235024e+05,5.586342e+05,5.360702e+05,5.479343e+05,4.982888e+05,5.521345e+05,...,6.057714e+05,5.287325e+05,6.381382e+05,5.269859e+05,5.133524e+05,5.386952e+05,5.446530e+05,6.986519e+05,5.599898e+05,5.617033e+05
United States of America,Male,95.0,125.0,2048,2049,5.022796e+05,5.729534e+05,5.336017e+05,5.583712e+05,4.316039e+05,5.775186e+05,5.541319e+05,5.656445e+05,5.122992e+05,5.694017e+05,...,6.243642e+05,5.467937e+05,6.621348e+05,5.415670e+05,5.284282e+05,5.547403e+05,5.616294e+05,7.281752e+05,5.794990e+05,5.800382e+05
United States of America,Male,95.0,125.0,2049,2050,5.147342e+05,5.894088e+05,5.479230e+05,5.767340e+05,4.392530e+05,5.962732e+05,5.715906e+05,5.832752e+05,5.258001e+05,5.860537e+05,...,6.426140e+05,5.639182e+05,6.862095e+05,5.554695e+05,5.431262e+05,5.703575e+05,5.778054e+05,7.574645e+05,5.992580e+05,5.981090e+05


In [12]:
ps.index.names

FrozenList(['location', 'sex', 'age_start', 'age_end', 'year_start', 'year_end'])

# Define functions to load and merge Artifact data from all locations, and to summarize Artifact data

In [13]:
def load_artifact_data(
    key,
    filter_terms=None,
    location_to_artifact_path=location_to_artifact_path,
):
    """Load `key` from every location's artifact and stack the results.

    Parameters
    ----------
    key
        Artifact key to load, e.g. 'population.structure'.
    filter_terms
        Passed through as the second argument of `Artifact`.
    location_to_artifact_path
        Mapping of location name -> path of that location's artifact.

    Returns
    -------
    The per-location data concatenated row-wise, with a 'location'
    index level. A frame that lacks a 'location' level gets one
    prepended, filled with that artifact's location name.

    Raises
    ------
    AssertionError
        If an artifact's 'metadata.locations' is not exactly the
        location it is registered under.
    """
    frames = []
    for location, path in location_to_artifact_path.items():
        art = Artifact(path, filter_terms)
        # Check to make sure location matches artifact
        art_locations = art.load('metadata.locations')
        assert len(art_locations) == 1 and art_locations[0] == location, \
            f'Unexpected locations in artifact: {location=}, {art_locations=}'
        df = art.load(key)
        if 'location' not in df.index.names:
            # Decide per frame rather than once after the loop, so the
            # result does not depend on the index names of whichever
            # artifact happened to be loaded last. Concatenating a
            # one-entry dict prepends a 'location' level and keeps the
            # frame's existing level names.
            df = pd.concat({location: df}, names=['location'])
        frames.append(df)
    return pd.concat(frames)

def lower(x):
    """Return the 2.5th percentile of `x`.

    The function name matters: `summarize_artifact_data` passes this
    function to `DataFrame.agg`, which labels the result by name.
    """
    lower_tail_probability = 0.025
    return x.quantile(lower_tail_probability)

def upper(x):
    """Return the 97.5th percentile of `x`.

    The function name matters: `summarize_artifact_data` passes this
    function to `DataFrame.agg`, which labels the result by name.
    """
    upper_tail_probability = 0.975
    return x.quantile(upper_tail_probability)

def summarize_artifact_data(df):
    """Summarize each row of `df` across its columns.

    Aggregates along the columns axis with the mean plus the `lower`
    and `upper` percentile functions defined in this notebook, and
    returns the result of that aggregation.
    """
    row_statistics = ['mean', lower, upper]
    return df.agg(row_statistics, axis='columns')

# Define functions to load simulation results

In [53]:
# Create an operator object and add location to the index.
# NOTE(review): this comment previously also said each random seed is
# treated as a separate draw, but the line that would add 'random_seed'
# to the index columns is commented out below, and load_sim_output
# aggregates over random seeds by default (aggregate_seeds=True) --
# confirm which behavior is intended.
ops = VPHOperator(location_col=True)
# ops.index_cols.extend(['location', 'random_seed'])

def load_sim_output(
        measure,
        results_dict=location_to_results_dir,
        # Pass None to skip filtering locations (when None, must also
        # pass assign_location=False or raw=True)
        location_to_artifact_path=location_to_artifact_path,
        # specify dtypes of certain columns
        colname_to_dtype=colname_to_dtype,
        drop_superfluous_cols=True, # drop redundant or empty columns
        # Sets the 'read_dictionary' key of kwargs, which is passed to
        # pyarrow.parquet.read_table()
        force_parquet_dictionaries=True,
        force_pandas_categoricals=True,
        aggregate_seeds=True,
        assign_location=True,
        raw=False, # Overrides other parameters if True
        **kwargs, # keyword args to pass to .read_parquet
    ):
    """Load simulation output from .parquet files for all locations,
    optionally reducing the size of the data when possible. Returns
    concatenated outputs with a 'location' column added.
    """
    # Override optional transformations if raw=True
    if raw:
        drop_superfluous_cols = False
        force_parquet_dictionaries = False
        force_pandas_categoricals = False
        aggregate_seeds = False
        assign_location = False

    # Determine whether results for all locations are stored in same
    # directory, or if different locations have different results
    # directories
    match location_to_results_dir:
        case {'all': _}:
            all_locations_together = True
        case _:
            all_locations_together = False
    
    if all_locations_together and assign_location and location_to_artifact_path is None:
        raise ValueError(
            "Must provide mapping of artifacts to locations  when" \
            " assign_location=True and all locations are concatenated" \
            " in the simulation outputs."
        )

    dfs = []
    for location, directory in results_dict.items():

        parquet_file_path = Path(directory) / f'{measure}.parquet'
        # Read the Parquet file's schema to get column names and data types
        parquet_schema = pq.read_schema(parquet_file_path)

        if (
            all_locations_together
            and location_to_artifact_path is not None
        ):
            if 'artifact_path' in parquet_schema.names:
                # Filter to locations in list
                location_filter = (
                    'artifact_path',
                    'in',
                    list(location_to_artifact_path.values()),
                )
                user_filters = kwargs.get('filters') # Defaults to None
                kwargs['filters'] = add_parquet_AND_filter(
                    location_filter, user_filters)
                # TODO: Use logging not printing
                print(location_filter)
            else:
                print("'artifact_path' column missing from parquet file."
                      " Not filtering locations.")

        if force_parquet_dictionaries:
            # Read all columns as dictionaries except those containing 
            # floating point values
            kwargs['read_dictionary'] = [
                col.name for col in parquet_schema
                if not pa.types.is_floating(col.type)]

        # Read the parquet file
        df = pd.read_parquet(parquet_file_path, **kwargs)

        if drop_superfluous_cols:
            # Drop redundant columns
            for col1, col2 in [
                ('input_draw', 'input_draw_number'),
                ('entity', 'sub_entity'),
            ]:
                if (col1 in df and col2 in df and df[col1].equals(df[col2])):
                    df.drop(columns=col2, inplace=True)
            # Drop empty columns (e.g., sub-entity)
            for col in df:
                if df[col].isna().all():
                    df.drop(columns=col, inplace=True)
        if colname_to_dtype is not None:
            # Filter to avoid KeyError
            colname_to_dtype = {c: dtype for c, dtype
                                in colname_to_dtype.items() if c in df}
            # NOTE: If copy-on-write is enabled, copy keyword is ignored
            df = df.astype(colname_to_dtype, copy=False)
        if force_pandas_categoricals:
            convert_to_categorical(
                df, exclude_cols=colname_to_dtype or (), inplace=True)
        if aggregate_seeds:
            # Use default index and value columns when aggregating
            df = vh.vph_output.operations.marginalize(df, 'random_seed')
        if assign_location:
            if all_locations_together:
                # NOTE: location_to_artifact_path is guaranteed not to
                # be None because assign_location and
                # all_locations_together are both True

                # Create a Categorical dtype with all locations
                location_dtype = pd.CategoricalDtype(
                    sorted(location_to_artifact_path.keys()), ordered=True)
                # Invert the dictionary so we can map artifact paths to
                # locations
                artifact_path_to_location = {
                    path: loc for loc, path
                    in location_to_artifact_path.items()}
                if 'artifact_path' in df:
                    df['location'] = df['artifact_path'].map(
                        artifact_path_to_location).astype(location_dtype)
                else:
                    # In case the engineers change the DataFrame format
                    # on us...
                    print("'artifact_path' column missing from DataFrame."
                          " Not assigning locations.")
            else:
                # NOTE: location_to_results_dir contains actual
                # locations as keys (not 'all') since
                # all_locations_together is False

                # Create a Categorical dtype with all locations to avoid
                # converting back to object dtype.
                location_dtype = pd.CategoricalDtype(
                    sorted(location_to_results_dir.keys()), ordered=True)
                df['location'] = location
                df['location'] = df['location'].astype(location_dtype)
        dfs.append(df)
    # TODO: Maybe if assign_location is False and all_locations_together
    # is also False (and there is more than one location?), we should
    # return a dict mapping locations to dataframes (or just a list of
    # dataframes?) instead of concatenating, since it won't be possible
    # to filter the resulting concatenated dataframe by location...
    df = pd.concat(dfs)
    return df
    
def add_parquet_AND_filter(new_filter, existing_filters):
    """AND `new_filter` onto an existing parquet filter specification.

    `existing_filters` may be:
    - None: the result is a single AND group holding `new_filter`;
    - a list whose first element is a 3-tuple (one AND group):
      `new_filter` is placed in front of the existing conditions;
    - a list whose first element is a list starting with a 3-tuple
      (an OR of AND groups): `new_filter` is placed in front of every
      AND group.

    Only the first element (and, for nested lists, its first element)
    is inspected when recognizing the shape. Anything else raises
    ValueError.
    """
    def _is_condition(item):
        # A single filter condition is a tuple of exactly three items
        return isinstance(item, tuple) and len(item) == 3

    if existing_filters is None:
        return [new_filter]

    if isinstance(existing_filters, list) and len(existing_filters) > 0:
        first = existing_filters[0]
        if _is_condition(first):
            return [new_filter, *existing_filters]
        if isinstance(first, list) and len(first) > 0 and _is_condition(first[0]):
            return [[new_filter, *and_group] for and_group in existing_filters]

    raise ValueError(f"Malformed parquet filter: {existing_filters}")

# NOTE: age_map must already exist when this cell runs, because it is
# bound as a default argument at definition time. ops is looked up when
# the function is called, so it must exist before the first call.
def summarize_sim_data(df, age_map=age_map):
    """Summarize simulation data for plotting.

    If `df` has an 'age_group' column, it is first merged with
    `age_map` on that column (to get an age_start column for plotting).
    The data is then summarized with `ops.describe`, and the '2.5%' and
    '97.5%' columns are renamed to 'lower' and 'upper' to match the
    artifact summaries.
    """
    has_age_groups = 'age_group' in df
    if has_age_groups:
        df = df.merge(age_map, on='age_group')
    percentile_names = {'2.5%': 'lower', '97.5%': 'upper'}
    described = ops.describe(df)
    return described.rename(columns=percentile_names)

def current_time():
    """Print the current local date and time; returns None."""
    now = datetime.datetime.now()
    print(now)

# Calculate model scale

## First read population structure and initial all-state prevalences from the artifact

In [17]:
# This is the number of people in each demographic group in each year --
# these numbers come from the FHS population forecasts
pop_structure = load_artifact_data('population.structure')
pop_structure.tail()

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,Unnamed: 4_level_0,Unnamed: 5_level_0,draw_0,draw_1,draw_2,draw_3,draw_4,draw_5,draw_6,draw_7,draw_8,draw_9,...,draw_490,draw_491,draw_492,draw_493,draw_494,draw_495,draw_496,draw_497,draw_498,draw_499
location,sex,age_start,age_end,year_start,year_end,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1,Unnamed: 25_level_1,Unnamed: 26_level_1
United Kingdom,Male,95.0,125.0,2046,2047,129365.649457,130775.330848,132470.262731,122141.117458,122669.816649,116553.291993,137240.957664,146303.198839,134355.643238,129130.125638,...,129199.724763,133834.848564,128153.198273,126410.768209,108873.787649,129171.958392,132812.650873,145683.572345,153951.857746,133969.822489
United Kingdom,Male,95.0,125.0,2047,2048,129176.872428,131427.429524,133135.096509,122206.362659,121990.475534,116216.514579,137226.41911,147381.840121,134405.091989,128844.029807,...,128743.759762,134732.739968,129357.417043,127365.126985,108980.704141,128693.703184,133430.947564,146604.135918,154047.33539,134464.260699
United Kingdom,Male,95.0,125.0,2048,2049,128215.39423,131705.587524,132988.854403,121848.151327,120928.626569,115242.163772,136633.081048,147552.13904,133626.750369,127966.248102,...,127782.664392,134921.061786,129793.897779,127858.445137,108759.06186,127744.197044,133723.530245,146651.755678,153534.536161,133964.837991
United Kingdom,Male,95.0,125.0,2049,2050,127765.330292,131992.359105,132745.174212,121566.569294,120010.799977,114127.489703,136240.266101,147709.447517,133317.489612,127169.484355,...,127529.973622,135115.785472,130017.37333,128346.358923,108768.814727,127682.196818,134255.13048,146732.723463,153968.447922,134047.815962
United Kingdom,Male,95.0,125.0,2050,2051,128213.393685,133220.130675,133421.143899,122248.029681,119754.120437,113654.996368,136919.776385,148869.875216,133983.846752,126969.448946,...,128519.07973,135934.950475,130583.543525,129590.537377,109219.537772,128856.357034,135299.299123,147386.69597,155815.638277,135717.26801


In [18]:
# For each demographic group, the "population scaling factor" is the
# ratio of the real-world population that we want to simulate in that
# group to the total number of people in that group. For Model 4 and
# above, this equals the initial prevalence of all AD disease states
# combined (preclinical + MCI + AD-dementia), since we are modeling the
# population of people with any stage of AD. Note that this is defined
# for the population at the beginning of the simulation, so there is
# only one year of data.
art_all_states_initial_prev = load_artifact_data('population.scaling_factor')
art_all_states_initial_prev.tail()
# NOTE: This data has two age groups, 95-100 and 100-105, instead of the
# single age group 95-125 that's in the population structure. I'm not
# sure why. I'm going to drop the 100-105 age group and match the 95-100
# age group with the 95-125 age group from above

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,Unnamed: 4_level_0,Unnamed: 5_level_0,draw_0,draw_1,draw_2,draw_3,draw_4,draw_5,draw_6,draw_7,draw_8,draw_9,...,draw_490,draw_491,draw_492,draw_493,draw_494,draw_495,draw_496,draw_497,draw_498,draw_499
location,sex,age_start,age_end,year_start,year_end,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1,Unnamed: 25_level_1,Unnamed: 26_level_1
United Kingdom,Female,80,85,2025,2026,0.184057,0.205264,0.201615,0.203185,0.178318,0.175798,0.165233,0.152835,0.155335,0.174561,...,0.180658,0.197655,0.150986,0.137985,0.183084,0.153108,0.16389,0.135606,0.13778,0.178453
United Kingdom,Female,85,90,2025,2026,0.249261,0.251477,0.250982,0.279132,0.262662,0.236341,0.308766,0.292557,0.27689,0.286794,...,0.271585,0.290982,0.269875,0.259442,0.278545,0.276393,0.301692,0.294791,0.286761,0.287763
United Kingdom,Female,90,95,2025,2026,0.288238,0.285702,0.254878,0.292849,0.297609,0.266154,0.328095,0.294896,0.276581,0.281861,...,0.269374,0.26359,0.231749,0.22287,0.251981,0.253709,0.288297,0.272311,0.251585,0.264228
United Kingdom,Female,95,100,2025,2026,0.262417,0.263717,0.252117,0.27668,0.262065,0.260805,0.268706,0.25233,0.26228,0.255821,...,0.248522,0.244338,0.248601,0.242992,0.261974,0.263646,0.274336,0.246757,0.253818,0.256388
United Kingdom,Female,100,105,2025,2026,0.253172,0.257814,0.248932,0.272355,0.252686,0.253852,0.252204,0.245571,0.251487,0.24982,...,0.237925,0.215883,0.224073,0.222823,0.256208,0.267994,0.26968,0.248598,0.249409,0.246543


In [19]:
# There's only one year worth of data here
art_all_states_initial_prev.index.unique('year_end')

Int64Index([2026], dtype='int64', name='year_end')

In [None]:
# NOTE(review): scratch cell. `art_all_states_initial_prev_counts` is
# first assigned in a later cell of this notebook, so this line raises
# NameError on a fresh kernel (Restart & Run All). The result is only
# displayed, never assigned. Consider deleting this cell or moving it
# below the cell that defines the variable.
art_all_states_initial_prev_counts.rename({2025: 2022})

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,Unnamed: 4_level_0,Unnamed: 5_level_0,draw_0,draw_1,draw_2,draw_3,draw_4,draw_5,draw_6,draw_7,draw_8,draw_9,...,draw_490,draw_491,draw_492,draw_493,draw_494,draw_495,draw_496,draw_497,draw_498,draw_499
location,sex,age_start,age_end,year_start,year_end,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1,Unnamed: 25_level_1,Unnamed: 26_level_1
Brazil,Female,30.0,35.0,2022,2026,0.073616,0.099191,0.060202,0.065540,0.553247,0.136304,0.022932,0.149443,0.007885,0.005169,...,0.008785,0.059043,0.067581,0.054968,0.106028,0.038484,0.059845,0.058674,0.173776,0.126197
Brazil,Female,35.0,40.0,2022,2026,0.377879,1.312264,1.170651,1.626404,1.135345,0.942462,0.845416,0.802527,0.923892,1.227090,...,1.783461,0.776389,1.149924,1.207990,0.943360,1.321899,1.336783,0.759363,0.634467,0.196024
Brazil,Female,40.0,45.0,2022,2026,6.206119,5.088725,3.136926,3.493659,7.262730,8.906555,7.066643,10.983498,9.706377,4.594043,...,6.996392,4.900188,12.694149,10.887607,5.084747,5.233626,9.591838,7.720006,8.792822,5.230435
Brazil,Female,45.0,50.0,2022,2026,12028.742764,13326.780832,11443.581264,11367.446568,12267.957958,11519.313628,13476.311608,12848.173685,11678.693945,11032.900180,...,10336.574069,11703.115990,11459.630096,10615.893340,15372.988355,12767.215768,13372.677298,14674.922434,13978.757481,13087.986879
Brazil,Female,50.0,55.0,2022,2026,16406.670660,18841.214094,13379.005172,15555.660218,14170.293812,14741.096206,12471.388725,13260.743680,14215.840972,12725.359404,...,12746.588949,13386.686957,18992.611343,16895.799644,17894.621111,15581.539058,16313.316722,17914.767542,16918.069936,13131.848837
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
United States of America,Male,75.0,80.0,2022,2026,336666.110275,290798.790349,365199.499489,294668.064208,350244.370410,352417.831459,384171.558605,388628.920179,386774.048939,376573.976075,...,410186.804512,221278.281504,365540.956084,310498.520056,318819.507229,344733.109952,347665.108267,339677.191640,341236.157962,285725.909552
United States of America,Male,80.0,85.0,2022,2026,339418.412521,384306.492653,405396.170164,375662.500814,354867.644497,340363.264916,304324.534162,294109.243827,307959.996523,357667.691508,...,376962.121566,369808.305008,311574.517292,277054.985629,369842.744026,312305.839913,330013.160066,271401.573090,265236.090834,317864.262488
United States of America,Male,85.0,90.0,2022,2026,294987.297843,310270.327219,288558.807685,303560.563817,284552.032710,244599.753697,328245.098744,315303.969894,292590.852505,326876.521087,...,338844.158949,305379.813172,318539.705080,290095.268213,302566.724193,325561.822857,361073.561385,327636.913224,295007.734933,328116.986890
United States of America,Male,90.0,95.0,2022,2026,140061.953770,146223.117621,122692.121502,139897.574581,143161.807051,117954.989975,146209.759661,138746.729338,128278.098464,137579.011796,...,135085.262986,113199.457672,110389.764190,99725.993505,115771.246023,124383.955504,144800.848778,127676.806184,106026.998539,123172.328035


## Now compute initial real-world all-state prevalence counts and model scale

In [124]:
def get_real_world_initial_population(
       population_structure,
       initial_prevalence,
       start_year=2022,
):
    """Multiply population counts at `start_year` by initial prevalence.

    Parameters
    ----------
    population_structure
        DataFrame of population counts whose index has (at least) the
        levels 'age_end', 'year_start' and 'year_end'.
    initial_prevalence
        DataFrame of prevalences with the same index levels, holding
        exactly one 'year_start' value.
    start_year
        Year of `population_structure` to use, regardless of which
        single year is stored in `initial_prevalence`.

    Returns
    -------
    DataFrame of population * prevalence for the rows present in both
    inputs after relabeling; rows containing NaN (groups that do not
    align) are dropped.

    Raises
    ------
    AssertionError
        If `initial_prevalence` has more than one 'year_start' value.
    """
    years = initial_prevalence.index.unique('year_start')
    assert len(years) == 1, 'Unexpected years for initial prevalence!'
    year = years[0]

    population_at_start = population_structure.query("year_start==@start_year")

    # Relabel the prevalence's year_end to the population structure's
    # own year_end for start_year when both are unambiguous. This avoids
    # assuming year_end == year_start + 1; if that assumption failed, no
    # rows would align and the result would silently come back empty.
    prevalence_year_ends = initial_prevalence.index.unique('year_end')
    population_year_ends = population_at_start.index.unique('year_end')
    if len(prevalence_year_ends) == 1 and len(population_year_ends) == 1:
        year_end_map = {prevalence_year_ends[0]: population_year_ends[0]}
    else:
        # Fallback: only works if year_end = year_start + 1
        year_end_map = {year + 1: start_year + 1}

    # Rename year_start and year_end so the two indexes line up when
    # the dataframes are multiplied
    initial_prevalence = (
        initial_prevalence
        .rename({year: start_year}, level='year_start')
        .rename(year_end_map, level='year_end')
    )
    initial_prevalence_counts = (
        population_at_start
        # Change end of oldest age group to match prevalence data
        .rename({125.0: 100.0}, level='age_end')
        * initial_prevalence
    ).dropna() # Drop age groups we don't have in sim
    return initial_prevalence_counts

art_all_states_initial_prev_counts = get_real_world_initial_population(
    pop_structure, art_all_states_initial_prev
)
art_all_states_initial_prev_counts.tail()

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,Unnamed: 4_level_0,Unnamed: 5_level_0,draw_0,draw_1,draw_2,draw_3,draw_4,draw_5,draw_6,draw_7,draw_8,draw_9,...,draw_490,draw_491,draw_492,draw_493,draw_494,draw_495,draw_496,draw_497,draw_498,draw_499
location,sex,age_start,age_end,year_start,year_end,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1,Unnamed: 25_level_1,Unnamed: 26_level_1
United States of America,Male,75.0,80.0,2022,2023,279503.716159,241421.537335,303005.456325,245046.790482,291032.01973,292173.888069,319151.434213,322901.726876,320919.877307,312721.888359,...,340496.388986,183624.796107,303387.652317,257740.083278,264585.258062,286707.666817,288361.928754,281653.836405,283209.988963,237124.856198
United States of America,Male,80.0,85.0,2022,2023,295491.963153,334542.816129,352919.503418,327523.75573,309471.027115,296260.437344,265026.502337,256594.63654,268008.281527,311310.859435,...,328352.211159,321773.632827,270912.872088,241055.920798,321904.097725,272172.292983,286599.286178,235443.951067,230842.759084,276251.278427
United States of America,Male,85.0,90.0,2022,2023,265660.836082,279333.12336,259372.09353,273638.950137,256249.409965,220125.858697,295555.094216,283893.37546,262971.486971,293492.877529,...,304798.665607,274423.347832,285871.062247,260618.321976,271969.277566,292854.234936,324185.862549,293051.854129,265368.527577,293529.464736
United States of America,Male,90.0,95.0,2022,2023,135009.931912,140544.500518,117802.312175,135061.861965,138174.629778,113243.661152,140889.365668,133342.91333,123131.014317,132526.38297,...,129627.305356,108648.748859,105907.707572,95763.096448,111257.383359,119742.388165,139285.777478,121602.893171,101618.235854,117661.784595
United States of America,Male,95.0,100.0,2022,2023,29288.078689,30698.389085,30107.564684,32246.26041,30766.214728,29113.490401,29413.069566,29157.686021,30052.152467,30956.218671,...,32091.208571,26499.605071,31482.881082,29952.989616,32037.054667,33926.196355,34717.989582,29169.491828,28270.193829,31345.61671


In [123]:
# NOTE(review): this cell's saved output (execution count 123) predates the
# cell above (124) and shows year_start 2025 rather than 2022 -- re-run to
# refresh it, or drop this duplicate view.
temp = art_all_states_initial_prev_counts
temp.tail()

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,Unnamed: 4_level_0,Unnamed: 5_level_0,draw_0,draw_1,draw_2,draw_3,draw_4,draw_5,draw_6,draw_7,draw_8,draw_9,...,draw_490,draw_491,draw_492,draw_493,draw_494,draw_495,draw_496,draw_497,draw_498,draw_499
location,sex,age_start,age_end,year_start,year_end,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1,Unnamed: 25_level_1,Unnamed: 26_level_1
United States of America,Male,75.0,80.0,2025,2026,336666.110275,290798.790349,365199.499489,294668.064208,350244.37041,352417.831459,384171.558605,388628.920179,386774.048939,376573.976075,...,410186.804512,221278.281504,365540.956084,310498.520056,318819.507229,344733.109952,347665.108267,339677.19164,341236.157962,285725.909552
United States of America,Male,80.0,85.0,2025,2026,339418.412521,384306.492653,405396.170164,375662.500814,354867.644497,340363.264916,304324.534162,294109.243827,307959.996523,357667.691508,...,376962.121566,369808.305008,311574.517292,277054.985629,369842.744026,312305.839913,330013.160066,271401.57309,265236.090834,317864.262488
United States of America,Male,85.0,90.0,2025,2026,294987.297843,310270.327219,288558.807685,303560.563817,284552.03271,244599.753697,328245.098744,315303.969894,292590.852505,326876.521087,...,338844.158949,305379.813172,318539.70508,290095.268213,302566.724193,325561.822857,361073.561385,327636.913224,295007.734933,328116.98689
United States of America,Male,90.0,95.0,2025,2026,140061.95377,146223.117621,122692.121502,139897.574581,143161.807051,117954.989975,146209.759661,138746.729338,128278.098464,137579.011796,...,135085.262986,113199.457672,110389.76419,99725.993505,115771.246023,124383.955504,144800.848778,127676.806184,106026.998539,123172.328035
United States of America,Male,95.0,100.0,2025,2026,38008.402852,39984.911499,39142.669366,41966.499708,40120.02974,38241.33186,38405.196666,38044.256565,39183.892639,40301.345259,...,41816.515797,34843.11934,41545.145156,39186.427463,41986.266513,44275.89234,45193.978094,39152.661564,37156.903534,41415.029795


In [125]:
# Initial simulated population per draw, from concept model
# TODO: Change this to 100 seeds once we get final runs
num_seeds = 5 # 100 # 5 seeds for V&V runs, 100 seeds for final runs
pop_per_seed = 20_000
# Total initial simulated population per draw across all seeds
initial_sim_pop = num_seeds * pop_per_seed

def calculate_model_scale(
        simulated_initial_population,
        real_world_initial_population,
    ):
    """Return the ratio of simulated to real-world initial population.

    `real_world_initial_population` is summed within each `location`
    (collapsing every other index level), and
    `simulated_initial_population` is divided by those totals. The result
    has one row per location and keeps the columns of the input, so draw
    columns stay laid out horizontally as in the input.
    """
    # Total real-world population at time 0 in each location
    per_location = real_world_initial_population.groupby('location')
    real_world_totals = per_location.sum()
    # Model scale = simulated population / real-world population
    return simulated_initial_population / real_world_totals

# Compute model scale in Artifact format: one row per location, one
# column per draw
art_model_scale = calculate_model_scale(
    initial_sim_pop, art_all_states_initial_prev_counts)
art_model_scale

Unnamed: 0_level_0,draw_0,draw_1,draw_2,draw_3,draw_4,draw_5,draw_6,draw_7,draw_8,draw_9,...,draw_490,draw_491,draw_492,draw_493,draw_494,draw_495,draw_496,draw_497,draw_498,draw_499
location,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
Brazil,0.059124,0.053128,0.057234,0.053931,0.054138,0.056098,0.050175,0.053436,0.053412,0.061014,...,0.064518,0.053006,0.05434,0.061677,0.049235,0.058628,0.058628,0.051967,0.058857,0.056216
China,0.006671,0.006505,0.005715,0.005781,0.006172,0.005957,0.005879,0.006201,0.005759,0.006323,...,0.005984,0.006285,0.006387,0.006715,0.006338,0.006193,0.006245,0.005969,0.006396,0.006172
Germany,0.043165,0.044635,0.044121,0.043983,0.038993,0.0431,0.043733,0.044159,0.042866,0.047304,...,0.047435,0.042426,0.043344,0.047762,0.047484,0.046563,0.041533,0.044561,0.043855,0.043414
Israel,2.062731,2.015539,2.010517,2.002877,1.958312,2.20061,1.789856,1.910496,2.061328,2.04187,...,2.090477,2.010696,2.009026,2.274123,2.36142,1.904982,1.907266,1.731055,2.424713,2.002196
Japan,0.029343,0.028741,0.030905,0.027669,0.028203,0.027022,0.024412,0.025587,0.028524,0.028222,...,0.02681,0.029577,0.029757,0.032507,0.031096,0.030167,0.028334,0.028847,0.029266,0.028393
Spain,0.160561,0.170945,0.153672,0.143598,0.154303,0.16178,0.144117,0.145225,0.162515,0.14936,...,0.172631,0.153515,0.183955,0.170476,0.150638,0.161405,0.140952,0.146733,0.155639,0.160408
Sweden,0.680744,0.689549,0.58984,0.596716,0.627214,0.619798,0.592754,0.674568,0.682529,0.617883,...,0.68394,0.643573,0.669699,0.721838,0.618795,0.590199,0.643935,0.625364,0.718512,0.619943
Taiwan (Province of China),0.440412,0.386021,0.486283,0.425787,0.422177,0.448623,0.440461,0.462848,0.396845,0.426235,...,0.415727,0.415297,0.446948,0.490541,0.456253,0.420806,0.43014,0.446591,0.423496,0.428008
United Kingdom,0.104953,0.101831,0.103883,0.100616,0.103059,0.101663,0.101167,0.10137,0.099046,0.099543,...,0.095263,0.097617,0.110445,0.119529,0.107823,0.104996,0.096381,0.101247,0.101059,0.100265
United States of America,0.021761,0.021462,0.020718,0.021562,0.0214,0.02223,0.020873,0.020811,0.02119,0.020348,...,0.018585,0.023058,0.021247,0.023783,0.02267,0.021725,0.020391,0.021758,0.023457,0.022257


In [126]:
# Transpose so locations become columns, then summarize the model scale
# across draws for each location
art_model_scale.T.describe()

location,Brazil,China,Germany,Israel,Japan,Spain,Sweden,Taiwan (Province of China),United Kingdom,United States of America
count,500.0,500.0,500.0,500.0,500.0,500.0,500.0,500.0,500.0,500.0
mean,0.056273,0.006199,0.045402,2.055397,0.029518,0.15785,0.655658,0.440022,0.103484,0.022018
std,0.00455,0.000363,0.002716,0.174438,0.001708,0.010319,0.04214,0.026138,0.004784,0.001147
min,0.044334,0.005182,0.038993,1.624465,0.024412,0.135964,0.545795,0.378188,0.093223,0.018585
25%,0.053264,0.005957,0.043448,1.938195,0.028311,0.150649,0.622847,0.422211,0.100181,0.021244
50%,0.055972,0.00618,0.045206,2.041927,0.02946,0.157149,0.654567,0.439612,0.103251,0.021964
75%,0.059287,0.006397,0.047003,2.157508,0.030553,0.163973,0.682701,0.457267,0.106665,0.022742
max,0.071019,0.007572,0.05576,2.751542,0.034743,0.1906,0.785809,0.523962,0.119529,0.027132


In [127]:
# Reformat model scale to be compatible with simulation output: draws
# vertically in columns or index, as integers

# model_scale = (
#     art_model_scale
#     .rename_axis(columns='input_draw')
#     .pipe(lambda df: df.set_axis(
#         df.columns.str.removeprefix('draw_')
#         .astype(input_draw_dtype), axis=1))
#     .stack()
#     .sort_index()
#     .rename('value')
#     .reset_index()
#     .astype({'location': all_locations_dtype})
# )

def convert_to_sim_format(df, colname_to_dtype=colname_to_dtype):
    """Convert artifact data to a format compatible with sim output.

    The wide `draw_<n>` columns are stacked into long form: the column
    labels become an integer `input_draw` column and the cell values go
    into a `value` column. An index level named `year_start` is renamed to
    `event_year`, all index levels are moved into columns, and a
    `year_end` column is dropped if present. Finally, every column that
    has an entry in `colname_to_dtype` is cast to that dtype.
    """
    # TODO: Also convert age_start/age_end to age_group
    def draw_number(column_label):
        # 'draw_12' -> 12
        return int(column_label.removeprefix('draw_'))

    wide = df.rename_axis(columns='input_draw').rename(columns=draw_number)
    long_values = wide.stack().rename('value')
    long_values = long_values.rename_axis(index={'year_start': 'event_year'})
    # Drop the year_end column if it exists
    tidy = long_values.reset_index().drop(columns='year_end', errors='ignore')
    # Only cast the columns that are actually present
    dtypes_present = {
        name: dtype for name, dtype in colname_to_dtype.items()
        if name in tidy
    }
    return tidy.astype(dtypes_present)

# Long-format model scale: one row per (location, input_draw)
model_scale = convert_to_sim_format(art_model_scale)
model_scale

Unnamed: 0,location,input_draw,value
0,Brazil,0,0.059124
1,Brazil,1,0.053128
2,Brazil,2,0.057234
3,Brazil,3,0.053931
4,Brazil,4,0.054138
...,...,...,...
4995,United States of America,495,0.021725
4996,United States of America,496,0.020391
4997,United States of America,497,0.021758
4998,United States of America,498,0.023457


In [128]:
# Check the column dtypes after conversion
model_scale.dtypes

location      category
input_draw       int16
value          float64
dtype: object

# Define functions to scale measures to real-world values, add rates, and generate final results

In [129]:
def scale_to_real_world(measure, model_scale=model_scale, ops=ops):
    """Scale a simulated measure up to the real-world population.

    The values in `measure` are divided by the values in `model_scale`
    for the draws that appear in `measure`. Both tables are passed through
    `ops.value` before dividing, and the index of the quotient is reset
    so the result comes back as a flat dataframe.
    """
    # Keep only the scale factors for draws that appear in the measure
    draws = measure['input_draw'].unique()
    scale_for_draws = model_scale.query("input_draw in @draws")

    measure_values = ops.value(measure)
    # NOTE: Reindexing the scale to measure.index was tried (it preserves
    # the categorical location column) but produced all NaN's, so the
    # division below relies on index alignment instead.
    scale_values = ops.value(scale_for_draws)

    real_world_values = measure_values / scale_values
    return real_world_values.reset_index()

def calculate_rate(measure, population_structure=pop_structure, ops=ops):
    """Placeholder for converting a measure into a rate.

    Not implemented yet: the body is a stub and `measure` is returned
    unchanged. `population_structure` and `ops` are currently unused.
    """
    # TODO: Divide measure by total person time to get rate
    ...
    return measure

def summarize_and_beautify(
        df,
        disease_stage_column=None,
        model_scale=model_scale,
        population_structure=pop_structure,
        ops=ops,
    ):
    """Scale to real-world values, summarize over draws, and format the
    result for the final results tables.

    Steps: scale `df` with `scale_to_real_world`, summarize with
    `ops.summarize_draws`, rename columns to their display names, replace
    disease-stage and scenario codes with display labels, and select the
    output columns in their final order. `current_time()` is called
    before scaling, after scaling, and after summarizing.

    `disease_stage_column` names the column holding the disease stage
    (default `'sub_entity'`). The summarized data must contain a `Metric`
    column, since it is selected for the output but not created here.
    Rates are not appended yet, and `population_structure` is currently
    unused.
    """
    def _timestamp_then(frame):
        # current_time() is called for its timestamp output; `or` hands the
        # frame on to the next step
        return current_time() or frame

    stage_column = (
        'sub_entity' if disease_stage_column is None
        else disease_stage_column)

    # Raw column name -> display name
    column_name_map = {
        'event_year': 'Year',
        'age_group': 'Age',
        'location': 'Location',
        'sex': 'Sex',
        'scenario': 'Scenario',
        'measure': 'Measure',
        stage_column: 'Disease Stage',
        'mean': 'Mean',
        'lower': '95% UI Lower',
        'upper': '95% UI Upper',
    }
    # Display column -> {raw value: display label}
    value_label_maps = {
        'Disease Stage': {
            'alzheimers_blood_based_biomarker_state': 'Preclinical AD',
            'alzheimers_mild_cognitive_impairment_state': 'MCI due to AD',
            'alzheimers_disease_state': 'AD Dementia',
        },
        'Scenario': {
            'baseline': 'Reference',
            'bbbm_testing': 'BBBM Testing Only',
            'bbbm_testing_and_treatment': 'BBBM Testing and Treatment',
        },
    }
    column_order = [
        'Year', 'Location', 'Age', 'Sex', 'Disease Stage', 'Scenario',
        'Measure', 'Metric', 'Mean', '95% UI Lower', '95% UI Upper',
    ]

    current_time()
    # Scale to real-world values
    scaled = _timestamp_then(df.pipe(scale_to_real_world, model_scale, ops))
    # Summarize data
    summary = _timestamp_then(ops.summarize_draws(scaled).reset_index())
    # Rename columns, relabel values, and put columns in the right order
    renamed = summary.rename(columns=column_name_map)
    relabelled = renamed.replace(value_label_maps)
    return relabelled[column_order]

# Deaths and averted deaths

In [None]:
# deaths.entity.unique(): ['alzheimers_disease_state', 'other_causes']
# Filter out other causes when loading since we don't need it
deaths_filter = [('entity', '=', 'alzheimers_disease_state')]
deaths = load_sim_output('deaths', filters=deaths_filter)
# Report the memory footprint, then peek at the last rows
print_memory_usage(deaths)
deaths.tail()

('artifact_path', 'in', ['/mnt/team/simulation_science/pub/models/vivarium_csu_alzheimers/artifacts/model8.3/united_states_of_america.hdf', '/mnt/team/simulation_science/pub/models/vivarium_csu_alzheimers/artifacts/model8.3/brazil.hdf', '/mnt/team/simulation_science/pub/models/vivarium_csu_alzheimers/artifacts/model8.3/china.hdf', '/mnt/team/simulation_science/pub/models/vivarium_csu_alzheimers/artifacts/model8.3/germany.hdf', '/mnt/team/simulation_science/pub/models/vivarium_csu_alzheimers/artifacts/model8.3/israel.hdf', '/mnt/team/simulation_science/pub/models/vivarium_csu_alzheimers/artifacts/model8.3/japan.hdf', '/mnt/team/simulation_science/pub/models/vivarium_csu_alzheimers/artifacts/model8.3/spain.hdf', '/mnt/team/simulation_science/pub/models/vivarium_csu_alzheimers/artifacts/model8.3/sweden.hdf', '/mnt/team/simulation_science/pub/models/vivarium_csu_alzheimers/artifacts/model8.3/taiwan_(province_of_china).hdf', '/mnt/team/simulation_science/pub/models/vivarium_csu_alzheimers/a

Unnamed: 0,age_group,artifact_path,entity,entity_type,event_year,input_draw,measure,scenario,sex,value,location
161995,95_plus,/mnt/team/simulation_science/pub/models/vivari...,alzheimers_disease_state,cause,2060,457,deaths,baseline,Male,55.0,China
161996,95_plus,/mnt/team/simulation_science/pub/models/vivari...,alzheimers_disease_state,cause,2060,457,deaths,bbbm_testing,Female,944.0,China
161997,95_plus,/mnt/team/simulation_science/pub/models/vivari...,alzheimers_disease_state,cause,2060,457,deaths,bbbm_testing,Male,55.0,China
161998,95_plus,/mnt/team/simulation_science/pub/models/vivari...,alzheimers_disease_state,cause,2060,457,deaths,bbbm_testing_and_treatment,Female,951.0,China
161999,95_plus,/mnt/team/simulation_science/pub/models/vivari...,alzheimers_disease_state,cause,2060,457,deaths,bbbm_testing_and_treatment,Male,55.0,China


In [None]:
def process_deaths(deaths, ops=ops):
    """Prepare AD deaths and averted AD deaths for summarizing.

    Keeps only rows whose entity is `alzheimers_disease_state`, computes
    averted deaths relative to the `baseline` scenario with
    `ops.averted`, labels both measures for the output tables, stacks
    them into one dataframe with a `Metric` column set to `'Number'`,
    and converts the result with `convert_to_categorical`.
    """
    # Filter to only deaths due to AD
    ad_deaths = deaths.query("entity=='alzheimers_disease_state'")

    # Averted deaths relative to the baseline scenario
    averted = ops.averted(ad_deaths, baseline_scenario='baseline')
    averted = averted.assign(measure='Averted Deaths Associated with AD')

    # Rename the measure on the plain deaths
    labelled = ad_deaths.assign(measure='Deaths Associated with AD')

    # TODO: Concatenate with rates also?
    # The inner join keeps only the shared columns, which drops the
    # "subtracted_from" column added by .averted
    combined = pd.concat(
        [labelled, averted], join='inner', ignore_index=True)
    combined = combined.assign(Metric='Number')
    return convert_to_categorical(combined)

# Deaths plus averted deaths, ready for summarizing
deaths_prepped = process_deaths(deaths)
deaths_prepped.tail()

4.541504 MB measure
1.517504 MB minuend
3.029504 MB subtrahend
0.979635 MB minuend re-indexed
1.951635 MB subtrahend re-indexed
1.951695 MB difference
3.461636 MB difference with reset index
3.569961 MB final difference


Unnamed: 0,age_group,artifact_path,entity,entity_type,event_year,input_draw,measure,scenario,sex,value,location,Metric
269995,95_plus,/mnt/team/simulation_science/pub/models/vivari...,alzheimers_disease_state,cause,2060,392,Averted Deaths Associated with AD,bbbm_testing_and_treatment,Male,0.0,China,Number
269996,95_plus,/mnt/team/simulation_science/pub/models/vivari...,alzheimers_disease_state,cause,2060,457,Averted Deaths Associated with AD,bbbm_testing,Female,0.0,China,Number
269997,95_plus,/mnt/team/simulation_science/pub/models/vivari...,alzheimers_disease_state,cause,2060,457,Averted Deaths Associated with AD,bbbm_testing,Male,0.0,China,Number
269998,95_plus,/mnt/team/simulation_science/pub/models/vivari...,alzheimers_disease_state,cause,2060,457,Averted Deaths Associated with AD,bbbm_testing_and_treatment,Female,-7.0,China,Number
269999,95_plus,/mnt/team/simulation_science/pub/models/vivari...,alzheimers_disease_state,cause,2060,457,Averted Deaths Associated with AD,bbbm_testing_and_treatment,Male,0.0,China,Number


## Summarize deaths and save to file

In [131]:
# Restrict to the locations of interest, then summarize. `@locations` lets
# query() read the variable directly instead of pasting its repr into the
# expression string, so it works for any list-like of location names.
deaths_output = summarize_and_beautify(
    deaths_prepped.query("location.isin(@locations)"),
    'entity',
)
deaths_output

2025-10-31 15:19:54.344969
2025-10-31 15:19:54.995965
2025-10-31 15:20:43.568479


Unnamed: 0,Year,Location,Age,Sex,Disease Stage,Scenario,Measure,Metric,Mean,95% UI Lower,95% UI Upper
0,2025,Brazil,25_to_29,Female,AD Dementia,BBBM Testing Only,Averted Deaths Associated with AD,Number,0.000000,0.000000,0.000000
1,2025,Brazil,25_to_29,Male,AD Dementia,BBBM Testing Only,Averted Deaths Associated with AD,Number,0.000000,0.000000,0.000000
2,2025,Brazil,25_to_29,Female,AD Dementia,BBBM Testing and Treatment,Averted Deaths Associated with AD,Number,0.000000,0.000000,0.000000
3,2025,Brazil,25_to_29,Male,AD Dementia,BBBM Testing and Treatment,Averted Deaths Associated with AD,Number,0.000000,0.000000,0.000000
4,2025,Brazil,25_to_29,Female,AD Dementia,Reference,Deaths Associated with AD,Number,0.000000,0.000000,0.000000
...,...,...,...,...,...,...,...,...,...,...,...
53995,2060,United States of America,95_plus,Male,AD Dementia,Reference,Deaths Associated with AD,Number,7429.848949,6255.963409,8252.754277
53996,2060,United States of America,95_plus,Female,AD Dementia,BBBM Testing Only,Deaths Associated with AD,Number,17903.336207,16910.776729,20713.459752
53997,2060,United States of America,95_plus,Male,AD Dementia,BBBM Testing Only,Deaths Associated with AD,Number,7429.848949,6255.963409,8252.754277
53998,2060,United States of America,95_plus,Female,AD Dementia,BBBM Testing and Treatment,Deaths Associated with AD,Number,17946.222330,16951.327141,20683.576600


In [216]:
# TODO: Check this
# Spot check: one year / age group / sex / location slice of the deaths table
deaths_output.loc[
    (deaths_output['Year'] == 2055)
    & (deaths_output['Age'] == '80_to_84')
    & (deaths_output['Sex'] == 'Female')
    & (deaths_output['Disease Stage'] == 'AD Dementia')
    & (deaths_output['Metric'] == 'Number')
    & (deaths_output['Location'] == 'Brazil')
]

Unnamed: 0,Year,Location,Age,Sex,Disease Stage,Scenario,Measure,Metric,Mean,95% UI Lower,95% UI Upper
39900,2055,Brazil,80_to_84,Female,AD Dementia,BBBM Testing Only,Averted Deaths Associated with AD,Number,0.0,0.0,0.0
39902,2055,Brazil,80_to_84,Female,AD Dementia,BBBM Testing and Treatment,Averted Deaths Associated with AD,Number,677.267979,428.505008,770.994766
39904,2055,Brazil,80_to_84,Female,AD Dementia,Reference,Deaths Associated with AD,Number,38556.219109,31437.38888,45259.430006
39906,2055,Brazil,80_to_84,Female,AD Dementia,BBBM Testing Only,Deaths Associated with AD,Number,38556.219109,31437.38888,45259.430006
39908,2055,Brazil,80_to_84,Female,AD Dementia,BBBM Testing and Treatment,Deaths Associated with AD,Number,37878.951129,30691.966355,44553.509627


In [133]:
# Save the summarized deaths table, without the row index
deaths_output.to_csv(output_dir / "deaths.csv", index=False)

# DALYs

In [219]:
# ylls.entity.unique(): ['alzheimers_disease_state', 'other_causes']
# Keep only the AD rows when loading
ylls_filter = [('entity', '==', 'alzheimers_disease_state')]
ylls = load_sim_output('ylls', filters=ylls_filter)
print(len(ylls), 'rows')
print_memory_usage(ylls)

('artifact_path', 'in', ['/mnt/team/simulation_science/pub/models/vivarium_csu_alzheimers/artifacts/model8.3/united_states_of_america.hdf', '/mnt/team/simulation_science/pub/models/vivarium_csu_alzheimers/artifacts/model8.3/brazil.hdf', '/mnt/team/simulation_science/pub/models/vivarium_csu_alzheimers/artifacts/model8.3/china.hdf', '/mnt/team/simulation_science/pub/models/vivarium_csu_alzheimers/artifacts/model8.3/germany.hdf', '/mnt/team/simulation_science/pub/models/vivarium_csu_alzheimers/artifacts/model8.3/israel.hdf', '/mnt/team/simulation_science/pub/models/vivarium_csu_alzheimers/artifacts/model8.3/japan.hdf', '/mnt/team/simulation_science/pub/models/vivarium_csu_alzheimers/artifacts/model8.3/spain.hdf', '/mnt/team/simulation_science/pub/models/vivarium_csu_alzheimers/artifacts/model8.3/sweden.hdf', '/mnt/team/simulation_science/pub/models/vivarium_csu_alzheimers/artifacts/model8.3/taiwan_(province_of_china).hdf', '/mnt/team/simulation_science/pub/models/vivarium_csu_alzheimers/a

In [220]:
# ylds.entity.unique():
# ['alzheimers_disease_and_other_dementias', 'treatment', 'all_causes']
# Keep only the AD and other dementias rows when loading
ylds_filter = [('entity', '==', 'alzheimers_disease_and_other_dementias')]
ylds = load_sim_output('ylds', filters=ylds_filter)
print(len(ylds), 'rows')
print_memory_usage(ylds)

('artifact_path', 'in', ['/mnt/team/simulation_science/pub/models/vivarium_csu_alzheimers/artifacts/model8.3/united_states_of_america.hdf', '/mnt/team/simulation_science/pub/models/vivarium_csu_alzheimers/artifacts/model8.3/brazil.hdf', '/mnt/team/simulation_science/pub/models/vivarium_csu_alzheimers/artifacts/model8.3/china.hdf', '/mnt/team/simulation_science/pub/models/vivarium_csu_alzheimers/artifacts/model8.3/germany.hdf', '/mnt/team/simulation_science/pub/models/vivarium_csu_alzheimers/artifacts/model8.3/israel.hdf', '/mnt/team/simulation_science/pub/models/vivarium_csu_alzheimers/artifacts/model8.3/japan.hdf', '/mnt/team/simulation_science/pub/models/vivarium_csu_alzheimers/artifacts/model8.3/spain.hdf', '/mnt/team/simulation_science/pub/models/vivarium_csu_alzheimers/artifacts/model8.3/sweden.hdf', '/mnt/team/simulation_science/pub/models/vivarium_csu_alzheimers/artifacts/model8.3/taiwan_(province_of_china).hdf', '/mnt/team/simulation_science/pub/models/vivarium_csu_alzheimers/a

In [217]:
# Peek at the last rows of the YLDs table
ylds.tail()

Unnamed: 0,age_group,artifact_path,entity,entity_type,event_year,input_draw,measure,scenario,sex,sub_entity,value,location
485995,95_plus,/mnt/team/simulation_science/pub/models/vivari...,alzheimers_disease_and_other_dementias,cause,2060,457,ylds,bbbm_testing_and_treatment,Female,alzheimers_mild_cognitive_impairment_state,50.921112,China
485996,95_plus,/mnt/team/simulation_science/pub/models/vivari...,alzheimers_disease_and_other_dementias,cause,2060,457,ylds,bbbm_testing_and_treatment,Female,alzheimers_disease_state,2498.046147,China
485997,95_plus,/mnt/team/simulation_science/pub/models/vivari...,alzheimers_disease_and_other_dementias,cause,2060,457,ylds,bbbm_testing_and_treatment,Male,alzheimers_blood_based_biomarker_state,0.0,China
485998,95_plus,/mnt/team/simulation_science/pub/models/vivari...,alzheimers_disease_and_other_dementias,cause,2060,457,ylds,bbbm_testing_and_treatment,Male,alzheimers_mild_cognitive_impairment_state,3.943808,China
485999,95_plus,/mnt/team/simulation_science/pub/models/vivari...,alzheimers_disease_and_other_dementias,cause,2060,457,ylds,bbbm_testing_and_treatment,Male,alzheimers_disease_state,151.782175,China


In [105]:
# Entities present in the loaded YLDs
ylds.entity.unique()

['alzheimers_disease_and_other_dementias']
Categories (3, object): ['all_causes', 'alzheimers_disease_and_other_dementias', 'treatment']

In [218]:
# Peek at the last rows of the YLLs table
ylls.tail()

Unnamed: 0,age_group,artifact_path,entity,entity_type,event_year,input_draw,measure,scenario,sex,value,location
161995,95_plus,/mnt/team/simulation_science/pub/models/vivari...,alzheimers_disease_state,cause,2060,457,ylls,baseline,Male,450.140138,China
161996,95_plus,/mnt/team/simulation_science/pub/models/vivari...,alzheimers_disease_state,cause,2060,457,ylls,bbbm_testing,Female,7615.669815,China
161997,95_plus,/mnt/team/simulation_science/pub/models/vivari...,alzheimers_disease_state,cause,2060,457,ylls,bbbm_testing,Male,450.140138,China
161998,95_plus,/mnt/team/simulation_science/pub/models/vivari...,alzheimers_disease_state,cause,2060,457,ylls,bbbm_testing_and_treatment,Female,7670.896236,China
161999,95_plus,/mnt/team/simulation_science/pub/models/vivari...,alzheimers_disease_state,cause,2060,457,ylls,bbbm_testing_and_treatment,Male,450.140138,China


In [101]:
# Entities present in the loaded YLLs
ylls.entity.unique()

['alzheimers_disease_state', 'other_causes']
Categories (2, object): ['alzheimers_disease_state', 'other_causes']

In [None]:
def process_dalys(ylls, ylds, ops=ops):
    """Combine YLLs and YLDs into DALYs and averted DALYs.

    Both inputs are filtered to their AD entity and that entity is renamed
    to the shared label `'AD'`. YLLs get a `sub_entity` column set to
    the AD dementia state, plus a zero-valued copy for the MCI state. DALYs
    come from `VPHResults.get_burden('dalys')`, and averted DALYs from
    `ops.averted` relative to the `baseline` scenario. The two are
    stacked into one dataframe with display measure names and a `Metric`
    column set to `'Number'`.
    """
    # Filter to YLLs due to AD, and use the same (arbitrary) entity name
    # for YLLs and YLDs so that the VPHResults object will add YLLs and
    # YLDs instead of keeping them separate
    ad_ylls = (
        ylls
        .query("entity=='alzheimers_disease_state'")
        .replace({'entity': {'alzheimers_disease_state': 'AD'}})
        # Add a sub_entity column to specify disease stage
        .assign(sub_entity='alzheimers_disease_state')
    )
    # Zero YLLs for the MCI state, intended to make MCI DALYs equal MCI
    # YLDs when the two are summed, keeping disease stages separate
    # instead of aggregating across them
    zero_mci_ylls = ad_ylls.assign(
        sub_entity='alzheimers_mild_cognitive_impairment_state',
        value=0.0,
    )
    ylls_prepped = convert_to_categorical(
        pd.concat([ad_ylls, zero_mci_ylls]))

    # Filter to YLDs due to AD, with the same arbitrary disease name
    ad_ylds = (
        ylds
        .query("entity=='alzheimers_disease_and_other_dementias'")
        .replace({'entity': {'alzheimers_disease_and_other_dementias': 'AD'}})
    )
    ylds_prepped = convert_to_categorical(ad_ylds)

    # Create a VPHResults object to calculate DALYs, then compress
    results = VPHResults(ylls=ylls_prepped, ylds=ylds_prepped, ops=ops)
    dalys = convert_to_categorical(results.get_burden('dalys'))

    # Averted DALYs relative to the baseline scenario
    averted_dalys = ops.averted(dalys, baseline_scenario='baseline')
    averted_dalys = averted_dalys.assign(
        measure='Averted DALYs Associated with AD')

    # Rename the measure on the plain DALYs
    labelled_dalys = dalys.assign(measure='DALYs Associated with AD')

    # TODO: Concatenate with rates
    # The inner join keeps only the shared columns, which drops the
    # "subtracted_from" column added by .averted
    combined = pd.concat(
        [labelled_dalys, averted_dalys], join='inner', ignore_index=True)
    combined = combined.assign(Metric='Number')
    return convert_to_categorical(combined)

In [None]:
# DALYs plus averted DALYs, ready for summarizing
dalys = process_dalys(ylls, ylds)
print_memory_usage(dalys)
print(len(dalys), 'rows')
dalys.tail()

9.241499 MB measure
4.381367 MB minuend
8.755367 MB subtrahend
3.086028 MB minuend re-indexed
6.164028 MB subtrahend re-indexed
6.164088 MB difference
6.163499 MB difference with reset index
6.487824 MB final difference
16.207778 MB 


Unnamed: 0,age_group,artifact_path,entity,entity_type,event_year,measure,sex,sub_entity,input_draw,scenario,location,value,Metric
809995,95_plus,/mnt/team/simulation_science/pub/models/vivari...,AD,cause,2060,Averted DALYs Associated with AD,Female,alzheimers_disease_state,457,bbbm_testing_and_treatment,China,-83.755582,Number
809996,95_plus,/mnt/team/simulation_science/pub/models/vivari...,AD,cause,2060,Averted DALYs Associated with AD,Female,alzheimers_mild_cognitive_impairment_state,457,bbbm_testing_and_treatment,China,-1.937065,Number
809997,95_plus,/mnt/team/simulation_science/pub/models/vivari...,AD,cause,2060,Averted DALYs Associated with AD,Male,alzheimers_blood_based_biomarker_state,457,bbbm_testing_and_treatment,China,0.0,Number
809998,95_plus,/mnt/team/simulation_science/pub/models/vivari...,AD,cause,2060,Averted DALYs Associated with AD,Male,alzheimers_disease_state,457,bbbm_testing_and_treatment,China,-2.669536,Number
809999,95_plus,/mnt/team/simulation_science/pub/models/vivari...,AD,cause,2060,Averted DALYs Associated with AD,Male,alzheimers_mild_cognitive_impairment_state,457,bbbm_testing_and_treatment,China,-0.09755,Number


## Do some quick checks

In [None]:
# Verify that DALYs == YLDs except in AD dementia state
# df1: non-averted DALYs outside the AD dementia state
df1 = dalys.query("sub_entity!='alzheimers_disease_state' and ~measure.str.contains('Averted')").drop(columns=['entity', 'measure', 'Metric'])
# df2: YLDs outside the AD dementia state
df2 = ylds.query("sub_entity!='alzheimers_disease_state'").drop(columns=['entity', 'measure'])
# The check passes only when the comparison returns no rows
temp = ops.compare_values(df1, df2)
assert len(temp) == 0, 'DALYs differ from YLDs in MCI or BBBM state!'
temp


age_group,artifact_path,entity_type,event_year,input_draw,location,scenario,sex,sub_entity


In [None]:
# Check that DALYs are never less than YLDs (the assertion allows equality)
df1 = dalys.query("~measure.str.contains('Averted')").drop(columns=['entity', 'measure', 'Metric'])
df2 = ylds.drop(columns=['entity', 'measure'])
assert ((ops.value(df1) - ops.value(df2)) >= 0).value.all(), "DALYs are less than YLDs!"

## Summarize DALYs and save to file

In [None]:
# Summarize DALYs using the default disease stage column
dalys_output = summarize_and_beautify(dalys)
print_memory_usage(dalys_output)
print(len(dalys_output), 'rows')
dalys_output.tail()

2025-10-31 16:20:55.583040
2025-10-31 16:20:57.727263
2025-10-31 16:23:15.324897
15.945151 MB 


Unnamed: 0,Year,Location,Age,Sex,Disease Stage,Scenario,Measure,Metric,Mean,95% UI Lower,95% UI Upper
161995,2060,United States of America,95_plus,Female,AD Dementia,BBBM Testing and Treatment,Averted DALYs Associated with AD,Number,-388.573131,-1035.484113,304.878181
161996,2060,United States of America,95_plus,Female,MCI due to AD,BBBM Testing and Treatment,Averted DALYs Associated with AD,Number,-8.611066,-14.17629,-1.641364
161997,2060,United States of America,95_plus,Male,Preclinical AD,BBBM Testing and Treatment,Averted DALYs Associated with AD,Number,0.0,0.0,0.0
161998,2060,United States of America,95_plus,Male,AD Dementia,BBBM Testing and Treatment,Averted DALYs Associated with AD,Number,-41.925535,-393.899789,317.289804
161999,2060,United States of America,95_plus,Male,MCI due to AD,BBBM Testing and Treatment,Averted DALYs Associated with AD,Number,-4.442485,-12.458546,-0.339223


In [215]:
# Save the summarized DALYs table, without the row index
dalys_output.to_csv(output_dir / 'dalys.csv', index=False)

In [136]:
# NOTE(review): scratch arithmetic; presumably a ratio of two row counts --
# confirm what the numbers refer to, or remove this cell.
324000/162000

2.0

In [137]:
# Display the YLDs table
ylds

Unnamed: 0,age_group,artifact_path,entity,entity_type,event_year,input_draw,measure,scenario,sex,sub_entity,value,location
0,25_to_29,/mnt/team/simulation_science/pub/models/vivari...,alzheimers_disease_and_other_dementias,cause,2025,169,ylds,baseline,Female,alzheimers_blood_based_biomarker_state,0.000000,Japan
1,25_to_29,/mnt/team/simulation_science/pub/models/vivari...,alzheimers_disease_and_other_dementias,cause,2025,169,ylds,baseline,Female,alzheimers_mild_cognitive_impairment_state,0.000000,Japan
2,25_to_29,/mnt/team/simulation_science/pub/models/vivari...,alzheimers_disease_and_other_dementias,cause,2025,169,ylds,baseline,Female,alzheimers_disease_state,0.000000,Japan
3,25_to_29,/mnt/team/simulation_science/pub/models/vivari...,alzheimers_disease_and_other_dementias,cause,2025,169,ylds,baseline,Male,alzheimers_blood_based_biomarker_state,0.000000,Japan
4,25_to_29,/mnt/team/simulation_science/pub/models/vivari...,alzheimers_disease_and_other_dementias,cause,2025,169,ylds,baseline,Male,alzheimers_mild_cognitive_impairment_state,0.000000,Japan
...,...,...,...,...,...,...,...,...,...,...,...,...
485995,95_plus,/mnt/team/simulation_science/pub/models/vivari...,alzheimers_disease_and_other_dementias,cause,2060,457,ylds,bbbm_testing_and_treatment,Female,alzheimers_mild_cognitive_impairment_state,50.921112,China
485996,95_plus,/mnt/team/simulation_science/pub/models/vivari...,alzheimers_disease_and_other_dementias,cause,2060,457,ylds,bbbm_testing_and_treatment,Female,alzheimers_disease_state,2498.046147,China
485997,95_plus,/mnt/team/simulation_science/pub/models/vivari...,alzheimers_disease_and_other_dementias,cause,2060,457,ylds,bbbm_testing_and_treatment,Male,alzheimers_blood_based_biomarker_state,0.000000,China
485998,95_plus,/mnt/team/simulation_science/pub/models/vivari...,alzheimers_disease_and_other_dementias,cause,2060,457,ylds,bbbm_testing_and_treatment,Male,alzheimers_mild_cognitive_impairment_state,3.943808,China
