# Demonstrate how to replace population and mortality data for Alzheimer's Model 3 with forecasts

    Age-specific population - /mnt/share/forecasting/data/9/future/population/20240320_daly_capstone_resubmission_squeeze_soft_round_shifted_hiv_shocks_covid_all_who_reagg/population_agg.nc

    Death rates - /snfs1/Project/forecasting/results/7/future/death/20240320_daly_capstone_resubmission_squeeze_soft_round_shifted_hiv_shocks_covid_all_who_reagg/



In [1]:
import vivarium
!date

Tue Aug 12 21:08:31 PDT 2025


In [2]:
import xarray as xr, numpy as np, pandas as pd, matplotlib.pyplot as plt

In [3]:
# copy model 2 artifact to a new place for modification
# TODO: make this work for all location artifacts
old_artifact_path = '/mnt/team/simulation_science/pub/models/vivarium_csu_alzheimers/artifacts/model2.0/united_states_of_america.hdf'
new_artifact_dir = '/mnt/team/simulation_science/pub/models/vivarium_csu_alzheimers/artifacts/model3.0/'
# NOTE: new_artifact_dir already ends with '/', so every `$new_artifact_dir/<file>`
# path built from it contains a doubled slash.
!cp $old_artifact_path $new_artifact_dir   # TODO: modify an artifact being built for the current model instead of copying the artifact from model 2
!ls -halt $new_artifact_dir/united_states_of_america.hdf

-rw-r--r-- 1 abie IHME-Simulationscience 2.9M Aug 12 21:08 /mnt/team/simulation_science/pub/models/vivarium_csu_alzheimers/artifacts/model3.0//united_states_of_america.hdf


In [4]:
# Open the copied artifact and list its keys, to see what we are planning to replace
# TODO: make this generate for any location
artifact_path = f'{new_artifact_dir}/united_states_of_america.hdf'
art = vivarium.Artifact(artifact_path)
art.keys

['metadata.keyspace',
 'metadata.locations',
 'population.location',
 'population.structure',
 'population.age_bins',
 'population.demographic_dimensions',
 'population.theoretical_minimum_risk_life_expectancy',
 'cause.all_causes.cause_specific_mortality_rate',
 'covariate.live_births_by_sex.estimate',
 'cause.alzheimers_disease_and_other_dementias.prevalence_scale_factor',
 'cause.alzheimers_disease_and_other_dementias.prevalence',
 'cause.alzheimers_disease_and_other_dementias.incidence_rate',
 'cause.alzheimers_disease_and_other_dementias.cause_specific_mortality_rate',
 'cause.alzheimers_disease_and_other_dementias.excess_mortality_rate',
 'cause.alzheimers_disease_and_other_dementias.disability_weight',
 'cause.alzheimers_disease_and_other_dementias.restrictions']

In [5]:
# art.load('population.structure')

In [6]:
# art.load('cause.all_causes.cause_specific_mortality_rate')

## make tables from FHS .nc files

In [7]:
def table_from_nc(fname_dict, param, loc_id, loc_name, age_mapping):
    """Read one FHS forecast from a NetCDF file into a wide, draw-per-column table.

    Parameters
    ----------
    fname_dict : dict
        Maps a parameter name to the path of its ``.nc`` file.
    param : str
        Key into ``fname_dict``.  It also picks the data variable to read:
        ``'births'`` reads ``'population'``; ``'migration'`` and
        ``'mortality'`` read ``'value'``; any other name reads the variable
        called ``param``.
    loc_id : int
        ``location_id`` coordinate value to select from the file.
    loc_name : str
        Written as-is into the ``location`` index level of the result.
    age_mapping : pandas.DataFrame
        Lookup with columns ``age_group_id``, ``age_start`` and ``age_end``.
        Not used when ``param == 'births'``.

    Returns
    -------
    pandas.DataFrame
        Indexed by ``location``, ``sex``, ``age_start``, ``age_end``
        (the two age levels are omitted for births), ``year_start`` and
        ``year_end``, with one ``draw_<n>`` column per draw.

    Notes
    -----
    Rows whose ``sex_id`` is not 1 or 2, or whose ``age_group_id`` is not in
    ``age_mapping``, get a missing index value and are dropped.
    """
    # Pick the data variable that holds this parameter inside the file
    if param == 'births':
        var_name = 'population'
    elif param in ['migration', 'mortality']:
        var_name = 'value'
    else:
        var_name = param

    # Open as a context manager so the file handle is released when we are
    # done, even if the selection below raises.
    with xr.open_dataset(
        fname_dict[param],
        engine="netcdf4",               # read with the netCDF4 backend
        decode_cf=True,                 # handle CF-conventions (time units, etc.)
    ) as ds:
        # Select relevant part of FHS dataset
        pop_ts = ds[var_name].sel(location_id=loc_id)

        if param != 'migration':
            # keep only the first entry along the scenario dimension
            pop_ts = pop_ts.isel(scenario=0)

        pop_ts = pop_ts.squeeze(drop=True)  # remove now-singleton dims

        # The values are pulled into memory here, so this has to happen
        # before the dataset is closed.
        df = pop_ts.to_dataframe(name='value').reset_index()

    # Transform to vivarium format
    # 1. Convert location_id to location name
    df['location'] = loc_name

    # 2. Convert sex_id to sex names (any other sex_id becomes NaN)
    sex_mapping = {1: 'Male', 2: 'Female'}
    df['sex'] = df['sex_id'].map(sex_mapping)

    # 3. Convert age_group_id to age intervals (unmapped ids become NaN)
    if param != 'births':
        age_bins = age_mapping.set_index('age_group_id')
        df['age_start'] = np.round(df['age_group_id'].map(age_bins['age_start']), 3)
        df['age_end'] = np.round(df['age_group_id'].map(age_bins['age_end']), 3)
        age_cols = ['age_start', 'age_end']
    else:
        age_cols = []

    # 4. Convert year_id to one-year intervals [year, year + 1)
    year_start = df['year_id'].map(int)
    df['year_start'] = year_start
    df['year_end'] = year_start + 1

    # 5. Drop rows with a missing index value, then pivot draws into columns
    index_cols = ['location', 'sex'] + age_cols + ['year_start', 'year_end', 'draw']
    df_indexed = df.dropna(subset=index_cols).set_index(index_cols)

    df_wide = df_indexed['value'].unstack(level='draw')

    # 6. Rename columns to draw_x format
    df_wide.columns = [f'draw_{col}' for col in df_wide.columns]

    return df_wide

In [8]:
# Forecast files to read, keyed by the parameter name handed to table_from_nc
fname_dict = dict(
    population='/mnt/share/forecasting/data/9/future/population/20240320_daly_capstone_resubmission_squeeze_soft_round_shifted_hiv_shocks_covid_all_who_reagg/population_agg.nc',
#     'births': '/mnt/share/forecasting/data/9/future/live_births/20231204_ref/live_births.nc',
#     'deaths': '/snfs1/Project/forecasting/results/7/future/death/20240320_daly_capstone_resubmission_squeeze_soft_round_shifted_hiv_shocks_covid_all_who_reagg_num/_all.nc',
    mortality='/snfs1/Project/forecasting/results/7/future/death/20240320_daly_capstone_resubmission_squeeze_soft_round_shifted_hiv_shocks_covid_all_who_reagg/_all.nc',
#     'migration': '/mnt/share/forecasting/data/6/future/migration/20230605_loc_intercept_shocks_pg_21LOCS_ATTENUATED/migration.nc',
)

# These IDs are listed on https://shiny.ihme.washington.edu/content/273/
# TODO: make this work for any location
loc_id = 102
loc_name = 'United States of America'

# Age-group lookup, downloaded with a different environment
# using vivarium_inputs.utility_data.get_age_bins;
# see also https://shiny.ihme.washington.edu/content/273/
# TODO: refactor this so that it does not require an extra .csv file
age_mapping = pd.read_csv('data/age_bins.csv')

In [None]:
%%time

# Forecast tables keyed by parameter name; filled one entry at a time below
df = {}

for param in ['population', 'mortality']:
    df[param] = table_from_nc(fname_dict, param, loc_id, loc_name, age_mapping)  # slow, takes minutes to hours to run, depending on disk caching

In [None]:
art.load('population.structure') # old population structure

In [None]:
# Collapse the draw columns to their row-wise mean, kept as a single 'value' column.
# The mean is built directly instead of being added to df['population'] and
# filtered back out, so the draw table itself is left unmodified.
new_pop_structure = df['population'].mean(axis=1).to_frame(name='value')

In [None]:
new_pop_structure

In [None]:
# Overwrite the stored population structure with the forecast table, then
# read it back from the artifact to check what was written.
art.replace('population.structure', new_pop_structure)
art.load('population.structure') # new population structure

In [None]:
art.load('cause.all_causes.cause_specific_mortality_rate')  # old mortality rates

In [None]:
df['mortality']

In [None]:
# Selecting the location with .loc on the first index level also drops that
# level from the result. Use loc_name (the value table_from_nc wrote into the
# 'location' level) rather than repeating the literal, so the two cannot drift apart.
art.replace('cause.all_causes.cause_specific_mortality_rate', df['mortality'].loc[loc_name])
art.load('cause.all_causes.cause_specific_mortality_rate')  # new mortality rates

In [None]:
# Keep only the age bins that start at age 5 or later, and store them back
all_pop_bins = art.load('population.age_bins')
new_pop_bins = all_pop_bins.query('age_start >= 5.0')
art.replace('population.age_bins', new_pop_bins)

In [None]:
# now we need to replicate the prevalence scale factor for each year
# NOTE: this rebinds `df` -- from here on it is the loaded scale-factor table,
# no longer the dict of forecast tables built earlier in the notebook.
df = art.load('cause.alzheimers_disease_and_other_dementias.prevalence_scale_factor')

In [None]:
# Remember the index level names so the same index can be rebuilt later,
# then turn the index levels into ordinary columns.
index_cols = list(df.index.names)
df = df.reset_index()

In [None]:
# One copy of the table per year 2022..2050, each stamped with the
# one-year interval [y, y + 1)
df_list = [
    df.assign(year_start=y, year_end=y + 1)
    for y in range(2022, 2051)
]

In [None]:
# Stack the per-year copies, then put the original index levels back
stacked = pd.concat(df_list)
df_all = stacked.set_index(index_cols)
df_all

In [None]:
art.replace('cause.alzheimers_disease_and_other_dementias.prevalence_scale_factor', df_all)

In [None]:
!ls -halt $new_artifact_dir/united_states_of_america.hdf