In [1]:
from functools import partial
from multiprocessing import Pool
from pathlib import Path
import pickle
import warnings

warnings.simplefilter('ignore')

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import tqdm

# import db_queries

from rra_climate_health.data_prep.location_mapping import FHS_HIERARCHY_PATH

# Root directory of the modeling outputs. Cells in this notebook read
# `<measure>/models/...`, `<measure>/results/...` and `input/gbd_prevalence/...` beneath it.
ROOT = Path('/mnt/team/rapidresponse/pub/population/modeling/climate_malnutrition')


In [2]:
# Location hierarchy, ordered by its `sort_order` column.
fhs_loc_meta = (
    pd.read_parquet(FHS_HIERARCHY_PATH)
    .sort_values('sort_order')
    .reset_index(drop=True)
)

# location_ids = fhs_loc_meta.loc[fhs_loc_meta['most_detailed'] == 1, 'location_id'].to_list()

# For each measure and comparison label: the model version whose random
# effects are read and the results version whose predictions are read.
# Both measures currently use the same version pairs.
versions = {
    measure_name: {
        'old model': {'model': '2024_09_09.02', 'results': '2024_09_09.01'},
        'new model': {'model': '2024_09_09.02', 'results': '2024_10_18.03'},
    }
    for measure_name in ('stunting', 'wasting')
}

# Dimensions iterated over when loading models and predictions.
scenarios = ['ssp245']
year_ids = list(range(2020, 2023))  # list(range(2020, 2101))
age_group_ids = [4, 5]
sex_ids = [1, 2]


In [3]:
# def load_prediction(location_id: int, root: Path, label: str, measure: str, scenario: str, year_id: int):
#     try:
#         data = pd.read_parquet(root / 'results' / label / f'{measure}_{location_id}_{scenario}_{year_id}.parquet').reset_index()
#         data['location_id'] = data['fhs_location_id'].astype(int)
#         return data.set_index(['location_id', 'year_id', 'age_group_id', 'sex_id', 'scenario', 'measure']).loc[:, 'affected_proportion']
#     except FileNotFoundError:
#         return pd.Series()

def load_model(root: Path, measure: str, version: str, age_group_id: int, sex_id: int):
    """Unpickle the model object saved for one age-group/sex combination.

    The file read is ``root / measure / 'models' / version / '<age_group_id>_<sex_id>.pkl'``.

    NOTE: unpickling can execute arbitrary code, so only point this at
    model directories you trust.
    """
    model_path = root / measure / 'models' / version / f'{age_group_id}_{sex_id}.pkl'
    with model_path.open('rb') as handle:
        return pickle.load(handle)


def load_random_effects(root: Path, measure: str, version: str,
                        version_label: str, age_group_id: int, sex_id: int, fhs_loc_meta: pd.DataFrame):
    """Load one model's random intercepts and attach location ids.

    Parameters
    ----------
    root, measure, version, age_group_id, sex_id
        Forwarded to ``load_model`` to locate the pickled model.
    version_label
        Label stored in the ``version_label`` index level of the result.
    fhs_loc_meta
        Location hierarchy with ``ihme_loc_id`` and ``location_id`` columns.
        It is NOT modified by this function.

    Returns
    -------
    pd.Series
        The ``'X.Intercept.'`` column of the model's ``ranef`` table, indexed by
        ``version_label, measure, location_id, age_group_id, sex_id``.
    """
    data = load_model(root, measure, version, age_group_id, sex_id).ranef
    data['version_label'] = version_label
    data['age_group_id'] = age_group_id
    data['sex_id'] = sex_id
    data['measure'] = measure

    # Match on the first three characters of ihme_loc_id, so every location
    # sharing that prefix receives the same random effect.
    # Presumably `ranef` is indexed by 3-letter codes — TODO confirm.
    # The truncation is done on a derived frame: the previous version overwrote
    # the caller's `ihme_loc_id` column in place on every call.
    location_ids = (
        fhs_loc_meta
        .assign(ihme_loc_id=fhs_loc_meta['ihme_loc_id'].str[:3])
        .set_index('ihme_loc_id')
        .loc[:, 'location_id']
    )
    data = data.join(location_ids)

    return data.set_index(['version_label', 'measure', 'location_id', 'age_group_id', 'sex_id']).loc[:, 'X.Intercept.']


def load_prediction(year_id: int, root: Path, measure: str, version: str, version_label: str, scenario: str):
    """Read one year/scenario results file and return its ``value`` column.

    The file read is ``root / measure / 'results' / version / '<year_id>_<scenario>.parquet'``.
    ``year_id``, ``version_label`` and ``measure`` are added as columns before
    indexing. The returned Series is sorted and indexed by
    ``version_label, measure, location_id, year_id, age_group_id, sex_id``.
    """
    index_cols = ['version_label', 'measure', 'location_id', 'year_id', 'age_group_id', 'sex_id']
    results_path = root / measure / 'results' / version / f'{year_id}_{scenario}.parquet'
    labelled = pd.read_parquet(results_path).assign(
        year_id=year_id,
        version_label=version_label,
        measure=measure,
    )
    return labelled.set_index(index_cols)['value'].sort_index()


# Collect random effects (per age group/sex) and predictions (per year/scenario)
# for every measure and version label, then stack each into one Series.
random_effects = []
predictions = []
for measure, measure_versions in versions.items():
    for version_label, version_dict in measure_versions.items():
        random_effects.extend(
            load_random_effects(ROOT, measure, version_dict['model'],
                                version_label, age_group_id, sex_id, fhs_loc_meta)
            for age_group_id in age_group_ids
            for sex_id in sex_ids
        )
        predictions.extend(
            load_prediction(year_id, ROOT, measure, version_dict['results'], version_label, scenario)
            for year_id in year_ids
            for scenario in scenarios
        )
predictions = pd.concat(predictions).sort_index().rename('pred')
random_effects = pd.concat(random_effects).sort_index().rename('ranef')


FileNotFoundError: [Errno 2] No such file or directory: '/mnt/team/rapidresponse/pub/population/modeling/climate_malnutrition/stunting/models/2024_09_09.02/4_1.pkl'

In [None]:
# Display the stacked prediction Series (named 'pred').
predictions

In [None]:
# Display the stacked random-intercept Series (named 'ranef').
random_effects

In [20]:
# kwargs = {}
# for id_var in ['location_id', 'year_id', 'age_group_id', 'sex_id']:
#     kwargs[id_var] = predictions.index.get_level_values(id_var).unique().tolist()
# gbd = db_queries.get_outputs(
#     'rei',
#     release_id=9,
#     measure_id=29,
#     metric_id=3,
#     rei_id=[240, 241],
#     **kwargs
# )
# gbd['measure'] = gbd['rei'].str.replace('nutrition_', '')
# gbd = gbd.set_index(['location_id', 'year_id', 'age_group_id', 'sex_id', 'measure']).sort_index().loc[:, 'val'].rename('gbd')

# plot_data = predictions.loc[:, :, :, :, 'ssp119', :].to_frame().join(gbd).dropna()

# plt.scatter(
#     plot_data['gbd'],
#     plot_data['pred'],
#     alpha=0.25
# )
# plt.plot((0, 1), (0, 1), color='red')

In [21]:
# me_ids = db_queries.get_ids('modelable_entity')
# me_ids.loc[me_ids['modelable_entity_id'].isin([8949, 8950, 8951, 10556, 10557,
#                                                8945, 8946, 8947, 10558, 10559])]


In [22]:
# me_ids = db_queries.get_ids('modelable_entity')
# me_ids.loc[
#     (me_ids['modelable_entity_name'].str.contains('asting'))
#     & (me_ids['modelable_entity_name'].str.contains('oderate'))
# ]
# me_ids.loc[
#     (me_ids['modelable_entity_name'].str.contains('tunting'))
#     & (me_ids['modelable_entity_name'].str.contains('oderate'))
# ]

# 10557 --> Mild Stunting, < -1 SD (post-ensemble)
# 10559 --> Mild Wasting, < -1 SD (post-ensemble)

# 10556 --> Moderate Stunting, < -2 SD (post-ensemble)
# 10558 --> Moderate Wasting, < -2 SD (post-ensemble)


In [23]:
# age_metadata = db_queries.get_age_metadata(release_id=9)
age_metadata = pd.read_parquet(ROOT / 'input' / 'gbd_prevalence' / 'age_metadata.parquet')
# Keep only age groups that start at 28+ days and end by age 5.
age_metadata = age_metadata.loc[
    (age_metadata['age_group_days_start'] >= 28)
    & (age_metadata['age_group_years_end'] <= 5)
]

# Id values covered by the predictions. Only the commented-out db_queries
# call below consumes these.
kwargs = {}
for id_var in ['location_id', 'year_id', 'sex_id']:  # , 'age_group_id'
    kwargs[id_var] = predictions.index.get_level_values(id_var).unique().tolist()

population = pd.read_parquet(ROOT / 'input' / 'gbd_prevalence' / 'population.parquet')

gbd_id_cols = ['measure', 'location_id', 'year_id', 'age_group_id', 'sex_id']

gbd_parts = []
for measure, me_id in [('stunting', 10556), ('wasting', 10558)]:  # Moderate
# for measure, me_id in [('stunting', 10557), ('wasting', 10559)]:  # Mild
# for measure, me_id in [('stunting', 8949), ('stunting', 8950), ('stunting', 8951),
#                        ('wasting', 8945), ('wasting', 8946), ('wasting', 8947)]:  # Combine individual models
    # _gbd = db_queries.get_model_results(
    #     'epi',
    #     me_id,
    #     release_id=9,
    #     measure_id=5,
    #     age_group_id=age_metadata['age_group_id'].to_list(),
    #     **kwargs
    # )
    _gbd = pd.read_parquet(ROOT / 'input' / 'gbd_prevalence' / f'{measure}.parquet')
    _gbd['measure'] = measure
    gbd_parts.append(_gbd)
gbd = pd.concat(gbd_parts)
# Rows sharing the same ids are summed (relevant when several models are combined per measure).
gbd = gbd.groupby(gbd_id_cols, as_index=False)['mean'].sum()

gbd = gbd.merge(population)
gbd = gbd.merge(age_metadata.loc[:, ['age_group_id', 'age_group_years_start', 'age_group_years_end']])
# Collapse the detailed age groups into the two modeled ones:
# 4 for groups ending at or before 1 year, 5 for groups starting at or after 1 year.
gbd.loc[gbd['age_group_years_end'] <= 1, 'age_group_id'] = 4
gbd.loc[gbd['age_group_years_start'] >= 1, 'age_group_id'] = 5
# Population-weighted mean within each collapsed group.
gbd['gbd'] = gbd['mean'] * gbd['population']
# Select columns with a list: tuple selection on a GroupBy raises in pandas 2.x.
gbd = gbd.groupby(gbd_id_cols)[['gbd', 'population']].sum()
gbd['gbd'] /= gbd['population']


In [24]:
# inputs = pd.read_parquet('/mnt/team/rapidresponse/pub/population/modeling/climate_malnutrition/stunting/training_data/2024_06_30.02/data.parquet')
# inputs['scenario'] = 'ssp119'
# inputs['year_id'] = inputs['int_year']
# for int_var in ['location_id', 'year_id', 'age_group_id', 'sex_id']:
#     inputs[int_var] = inputs[int_var].astype(int)
# inputs = inputs.groupby(['location_id', 'year_id', 'age_group_id', 'sex_id', 'scenario'])[['stunting', 'wasting']].mean()
# inputs.columns.name = 'measure'
# inputs = inputs.stack().rename('pred')


In [None]:
# Predictions joined to GBD values (rows without a GBD match are dropped) and to
# the random effects (left join, so locations without one are kept and flagged).
plot_data = predictions.to_frame().join(gbd).dropna().join(random_effects, how='left').reorder_levels(predictions.index.names)
plot_data['has_raneff'] = plot_data['ranef'].notnull()
# NOTE(review): this keeps only age_group_id 4 and sex_id 1 before aggregating — confirm that is intended.
plot_data = plot_data.loc[:, :, :, :, [4], [1]]
## plot_data = inputs.loc[:, :, :, :, 'ssp119', :].to_frame().join(gbd).dropna()
# Population-weighted means of prediction and GBD value per location-year.
plot_data['pred'] *= plot_data['population']
plot_data['gbd'] *= plot_data['population']
# Select columns with a list: tuple selection on a GroupBy raises in pandas 2.x.
plot_data = plot_data.groupby(['version_label', 'measure', 'has_raneff', 'location_id', 'year_id'])[['pred', 'gbd', 'population']].sum()
plot_data['pred'] /= plot_data['population']
plot_data['gbd'] /= plot_data['population']
plot_data = plot_data.drop('population', axis=1)
plot_data.describe()
# plot_data = np.log(plot_data + plot_data['gbd'].min() / 10)


In [None]:
# Wasting rows (all version labels) where has_raneff is True.
plot_data.loc[:, 'wasting', True, :, :]

In [None]:
# NOTE(review): 'base model' is not one of the labels defined in `versions`
# ('old model' / 'new model'), so this lookup looks stale — confirm which label is meant.
foo = plot_data.loc[['base model'], ['wasting'], True, :, :].join(random_effects)
foo['resid'] = foo['pred'] - foo['gbd']
# Random effect against the predicted and the GBD value; `resid` is computed but not plotted here.
plt.scatter(foo['ranef'], foo['pred'], alpha=0.5)
plt.scatter(foo['ranef'], foo['gbd'], alpha=0.5)


In [None]:
sns.set_style('whitegrid')

# One figure per measure with one panel per version label: prediction vs GBD,
# split by whether the location has a random effect in the model.
for measure, measure_versions in versions.items():
    marker = 'o'
    # Two panels per row, with as many rows as the versions need. A fixed 2x2
    # grid left empty rows and reused panels once there were more than four.
    # squeeze=False keeps `ax` two-dimensional even for a single row.
    n_cols = 2
    n_rows = max(1, -(-len(measure_versions) // n_cols))
    fig, ax = plt.subplots(n_rows, n_cols, figsize=(11, 4.25 * n_rows), squeeze=False)
    for i, version_label in enumerate(measure_versions):
        idx = divmod(i, n_cols)
        for has_ranef, color in [(True, 'mediumseagreen'), (False, 'mediumorchid')]:
            # Slice once and reuse for both the RMSE and the scatter.
            subset = plot_data.loc[version_label, measure, has_ranef, :, :]
            rmse = np.round(((subset['gbd'] - subset['pred']) ** 2).mean() ** 0.5, 4)
            if has_ranef:
                legend_label = f'{measure}, locs in model (RMSE: {rmse})'
            else:
                legend_label = f'{measure}, locs not in model (RMSE: {rmse})'
            ax[idx].scatter(
                subset['gbd'],
                subset['pred'],
                color=color,
                s=100,
                alpha=0.1,
                marker=marker,
            )
            # Invisible point so the legend shows a fully opaque marker.
            ax[idx].scatter(
                np.nan,
                np.nan,
                color=color,
                label=legend_label,
                s=100,
                alpha=1.,
                marker=marker,
            )
        ax[idx].plot((0, 1), (0, 1), color='red')
        ax[idx].set_ylabel('Prediction')
        ax[idx].set_xlabel('GBD 2021')
        ax[idx].set_ylim(0, 1)
        ax[idx].set_xlim(0, 1)
        ax[idx].set_title(f"{version_label}")
        ax[idx].legend()
    # Hide any panel left over when the version count is odd.
    for unused_ax in ax.flat[len(measure_versions):]:
        unused_ax.set_visible(False)
    fig.tight_layout()
    fig.show()


In [None]:
# stunting_grid = stunting_data.groupby('grid_cell').stunting.mean()
# wasting_grid = wasting_data.groupby('grid_cell').stunting.mean()
# NOTE(review): no earlier cell assigns a notebook-level `model` (it is only bound
# in later cells), so this depends on out-of-order execution — confirm intent.
model.formula

In [None]:
# NOTE(review): `stunting_grid` and `wasting_grid` only appear in commented-out
# code in the visible notebook, so this cell relies on leftover kernel state — confirm.
pd.concat([stunting_grid, wasting_grid.rename('wasting')], axis=1).dropna()

In [None]:

import pickle

def load_model_data(root: Path, label: str, measure: str, age_group_id: int, sex_id: int):
    """Mean observed and fitted values per location from a pickled model's data.

    Reads ``root / 'models' / label / 'model_<measure>_<age_group_id>_<sex_id>.pkl'``,
    tags the model's ``data`` frame with the age group, sex and measure, renames
    the ``measure`` column to ``'obs'``, and returns the mean of ``obs`` and
    ``fits`` grouped by ``ihme_loc_id, age_group_id, sex_id, measure``.

    NOTE: unpickling can execute arbitrary code, so only load trusted files.
    """
    model_path = root / 'models' / label / f'model_{measure}_{age_group_id}_{sex_id}.pkl'
    with model_path.open('rb') as handle:
        model = pickle.load(handle)
    tagged = (
        model.data
        .assign(age_group_id=age_group_id, sex_id=sex_id, measure=measure)
        .rename(columns={measure: 'obs'})
    )
    group_cols = ['ihme_loc_id', 'age_group_id', 'sex_id', 'measure']
    return tagged.groupby(group_cols)[['obs', 'fits']].mean()


In [None]:
# Mean stunting/wasting per location, age group and sex straight from the
# training data. The path is built from ROOT instead of repeating the absolute prefix.
new = pd.read_parquet(ROOT / 'stunting' / 'training_data' / '2024_06_30.02' / 'data.parquet')
for int_var in ['age_group_id', 'sex_id']:
    new[int_var] = new[int_var].astype(int)
new = new.groupby(['ihme_loc_id', 'age_group_id', 'sex_id'])[['stunting', 'wasting']].mean()
new.columns.name = 'measure'
new = new.stack().rename('pred')

# The same summary taken from the data stored inside each pickled model.
# NOTE(review): `LABEL` is not defined in any visible cell — it must be set
# before this cell can run on a fresh kernel.
old = pd.concat([
    load_model_data(ROOT, LABEL, 'stunting', age_group_id, sex_id)
    for age_group_id in [4, 5]
    for sex_id in [1, 2]
]).sort_index()
old

In [None]:
# Rows of `old` whose first index level is 'ZWE'.
old.loc['ZWE']

In [None]:
# NOTE(review): `old` comes from load_model_data, which returns a two-column
# ('obs', 'fits') frame, so `old.rename('old')` probably does not produce a single
# 'old' column — confirm which column is meant to be compared with `new`.
foo = pd.concat([
    new.rename('new'),
    old.rename('old')
], axis=1).loc[:, :, :, 'stunting']
# Dataset-derived means against the values taken from the model object.
plt.scatter(
    foo['new'],
    foo['old'],
)
plt.xlabel('From dataset')
plt.ylabel('From model object')
plt.show()

In [None]:
# Rows of the comparison frame whose first index level is 'ZWE'.
foo.loc['ZWE']

In [None]:
# NOTE(review): this key order does not match the index that load_prediction builds
# (version_label, measure, location_id, year_id, age_group_id, sex_id), and
# `scenarios` only contains 'ssp245' — this cell looks stale; confirm.
predictions.loc[198, :, :, :, 'ssp119', 'stunting'].groupby(['age_group_id', 'sex_id']).mean()

In [None]:
# Observed vs fitted stunting per grid cell for ZWE, taken from one pickled model.
root = ROOT
label = LABEL
measure = 'stunting'
age_group_id = 5
sex_id = 1

model_path = root / 'models' / label / f'model_{measure}_{age_group_id}_{sex_id}.pkl'
with model_path.open('rb') as file:
    model = pickle.load(file)

md = model.data
md = (
    md.loc[md['ihme_loc_id'] == 'ZWE']
    .groupby('grid_cell')[['stunting', 'fits']]
    .mean()
)

plt.scatter(md['stunting'], md['fits'])
plt.xlabel('obs')
plt.ylabel('pred')
plt.plot((0, 1), (0, 1), color='red')  # identity line
plt.show()


In [None]:
import pickle
import matplotlib.pyplot as plt
import seaborn as sns
sns.set_style('whitegrid')

measure = 'wasting'
VERSIONS = ['2024_07_09.01', '2024_07_09.02', '2024_07_09.03', '2024_07_09.04']

# Load every version/age/sex model through the shared `load_model` helper
# (same ROOT/<measure>/models/<version>/<age>_<sex>.pkl layout) and collect
# each model's data tagged with its identifiers.
input_data = []
models = {}
for version in VERSIONS:
    for age_group_id in [4, 5]:
        for sex_id in [1, 2]:
            model = load_model(ROOT, measure, version, age_group_id, sex_id)
            # `assign` returns a new frame, so the data held by the model kept
            # in `models` is no longer modified as a side effect.
            data = (
                model.data
                .assign(version=version, measure=measure, age_group_id=age_group_id, sex_id=sex_id)
                .rename(columns={measure: 'obs'})
            )
            input_data.append(data)
            models[f'{version}_{age_group_id}_{sex_id}'] = model
input_data = pd.concat(input_data)
plot_inputs = input_data.groupby(['version', 'ihme_loc_id', 'age_group_id', 'sex_id', 'measure'])[['obs', 'fits']].mean()

# Mean observed vs mean fitted value per version/location/age/sex.
plt.scatter(
    plot_inputs['obs'],
    plot_inputs['fits'],
    s=20,
    alpha=0.5
)
plt.plot(
    (0, 1),
    (0, 1),
    color='red'
)
plt.xlim(0, 1)
plt.xlabel('Obs')
plt.ylim(0, 1)
plt.ylabel('Pred')
plt.tight_layout()
plt.show()


In [None]:
# Work on a copy so adding a column does not write into a slice of `input_data`.
data = input_data.loc[input_data['version'] == '2024_07_09.01'].copy()
# Inverse-logit of the residuals.
data['lin_residuals'] = 1 / (1 + np.exp(-data['residuals']))
# Select columns with a list: tuple selection on a GroupBy raises in pandas 2.x.
data = data.groupby(['any_days_over_30C', 'ihme_loc_id'])[['obs', 'fits', 'lin_residuals']].mean()
# Per-location means split by the any_days_over_30C flag, one figure per column.
for column in ['obs', 'fits']:
    data.loc[0, column].hist(color='dodgerblue', alpha=0.75, label='0 days over 30')
    data.loc[1, column].hist(color='firebrick', alpha=0.75, label='1+ days over 30')
    plt.title(column)
    # The labels above were never rendered without an explicit legend.
    plt.legend()
    plt.show()

In [None]:
# Inverse-logit of a hand-entered linear predictor: each hard-coded coefficient
# is multiplied by the covariate value chosen for this scenario.
# NOTE(review): coefficient values are presumably copied from a fitted model — confirm source.
1 / (
    1 + np.exp(
        -(
            -2.065182 * 1 # intercept
            + 0.256654 * 1 # days over 30
            + -0.441034 * 0 # income
            + 0.242929 * 1 # any days over 30
            + 0.359003 * 0 # income/any days over 30
        )
    )
)

In [None]:
# Sum of ten 0.1-weighted inverse-logit terms with the intercept commented out.
# NOTE(review): the loop body does not use `i`, so all ten terms are identical —
# confirm whether a covariate was meant to vary with `i`.
foo = 0
for i in range(10):
    foo += 0.1 * 1 / (
        1 + np.exp(
            -(
                # -2.065182 * 1 # intercept
                + 0.256654 * 0 # days over 30
                + -0.441034 * 5 # income
                + 0.242929 * 0 # any days over 30
                + 0.359003 * 0 # income/any days over 30
            )
        )
    )
foo

In [None]:
# Coefficient estimates of every loaded model, one column per model key.
print(measure)
coefs = [fitted.coefs['Estimate'].rename(key) for key, fitted in models.items()]
pd.concat(coefs, axis=1)

In [None]:
# Coefficient estimates of every loaded model, one column per model key.
print(measure)
coefs = [fitted.coefs['Estimate'].rename(key) for key, fitted in models.items()]
pd.concat(coefs, axis=1)

In [None]:
# dir(models['2024_07_09.01_4_1'])
# Compare GBD values with the training-data mean for one measure.
measure = 'wasting'
data = pd.read_parquet(ROOT / measure / 'training_data' / '2024_07_09.01' / 'data.parquet')
data['year_id'] = data['int_year']
id_vars = ['location_id', 'year_id', 'age_group_id', 'sex_id']
for id_var in id_vars:
    data[id_var] = data[id_var].astype(int)
# Use `measure` for the column as well, so changing it above switches the whole
# cell (the column name was previously hard-coded to 'wasting').
data = data.groupby(id_vars)[measure].mean().rename('dhs')
# Right join keeps every input-data group, including those without a GBD value.
data = gbd.loc[measure].join(data, how='right')
plt.scatter(data['gbd'], data['dhs'])
plot_max = min(1, data[['gbd', 'dhs']].max().max() * 1.1)
plt.plot((0, plot_max), (0, plot_max), color='red')
plt.xlim(0, plot_max)
plt.ylim(0, plot_max)
plt.title(f'{measure} by location-year-age-sex')
plt.xlabel('gbd 2021')
plt.ylabel('input data')
plt.show()
# data.columns

In [None]:
# NOTE(review): the `models` dict built earlier is keyed '<version>_<age_group_id>_<sex_id>',
# so the key '5_2' would not be found — confirm which version is intended.
models['5_2'].var_info['ldi_pc_pd']['transformer']

In [None]:
# NOTE(review): the `models` dict built earlier is keyed '<version>_<age_group_id>_<sex_id>',
# so the key '5_2' would not be found — confirm which version is intended.
models['5_2'].data.describe()

In [None]:
# Summary statistics of a single results file. The path is built from ROOT
# to match the other cells instead of repeating the absolute prefix.
pd.read_parquet(ROOT / 'stunting' / 'results' / '2024_07_02.01' / '2020_ssp119.parquet').describe()