In [1]:
from typing import Tuple, Dict
from pathlib import Path
import functools
import multiprocessing
import tqdm
from loguru import logger

import pandas as pd
import numpy as np
import xarray as xr
from scipy.special import logit, expit

import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.backends.backend_pdf import PdfPages

from rra_climate_health.paths import OUTPUT_ROOT
from rra_climate_health.data_prep.location_mapping import FHS_HIERARCHY_PATH

# Directory templates, filled in with `.format(measure=..., scenario=...)`.
# The input directory is read for `nutrition_{measure}.nc`; the output directory
# receives the diagnostic PDF and the CSV written by `get_shifted_forecast`.
SEV_INPUT_PATH_TEMPLATE = '/mnt/share/forecasting/data/7/future/sev/20240910_nutrition_{measure}_{scenario}'
SEV_OUTPUT_PATH_TEMPLATE = '/mnt/share/forecasting/data/7/future/sev/20241021_nutrition_{measure}_{scenario}'

# Measures processed by `main`; each name is also used to build file names.
MEASURES = [
    'stunting',
    'wasting',
]
# Value written to the `rei_id` column of the output CSV for each measure.
REI_IDS = {
    'stunting': 241,
    'wasting': 240,
}
# Scenario labels processed by `main`; combined with every entry of MEASURES.
CMIP6_SCENARIOS = [
    'ssp119',
    'ssp245',
    'ssp585',
]


In [2]:
def plot_shifted_sevs(
    gbd: pd.Series, fhs: pd.Series, shifted_fhs: pd.Series,
    location_metadata: pd.DataFrame, sev_age_metadata: pd.DataFrame,
    sev_output_path_template: str, scenario: str, measure: str,
):
    """Write a multi-page PDF comparing GBD, unshifted FHS and shifted FHS values.

    One page is produced per location with ``level >= 3`` in
    ``location_metadata``. Each page holds one panel per age group in
    ``sev_age_metadata`` and, within a panel, one line per sex and source.

    Parameters
    ----------
    gbd
        Series indexed by (age_group_id, sex_id, location_id, year_id).
    fhs, shifted_fhs
        Series whose index contains at least the levels age_group_id, sex_id,
        location_id and year_id; any other levels are averaged out before
        plotting.
    location_metadata
        Must provide the columns ``level``, ``location_id`` and
        ``location_name``.
    sev_age_metadata
        Must provide the columns ``age_group_id`` and ``age_group_name``.
    sev_output_path_template
        Directory template formatted with ``measure`` and ``scenario``; the
        PDF is written there as ``nutrition_{measure}.pdf``.
    scenario, measure
        Values substituted into the template and file name.

    Raises
    ------
    KeyError
        If a plotted (age group, sex, location) combination is missing from
        any of the three series.
    """
    index_levels = ['age_group_id', 'sex_id', 'location_id', 'year_id']
    # Averaging over every level not listed here (e.g. draws) gives one mean trend per line.
    mean_fhs = fhs.groupby(index_levels).mean()
    mean_shifted_fhs = shifted_fhs.groupby(index_levels).mean()

    plot_locations = (
        location_metadata
        .loc[location_metadata['level'] >= 3]
        .set_index('location_id')
        .loc[:, 'location_name']
    )

    # Everything below is identical for every page, so build it once.
    age_groups = list(
        sev_age_metadata.set_index('age_group_id').loc[:, 'age_group_name'].items()
    )
    sources = [
        (gbd, 'GBD', 'dodgerblue'),
        (mean_fhs, 'FHS (unshifted)', 'firebrick'),
        (mean_shifted_fhs, 'FHS (shifted)', 'mediumorchid'),
    ]
    sexes = [(1, 'Males', '-'), (2, 'Females', '--')]

    # Two columns of panels; at least two rows so the usual four age groups keep
    # the 2x2 layout, with extra rows added instead of overplotting if there are more.
    n_cols = 2
    n_rows = max(2, -(-len(age_groups) // n_cols))

    sns.set_style('whitegrid')
    pdf_path = (
        Path(sev_output_path_template.format(measure=measure, scenario=scenario))
        / f'nutrition_{measure}.pdf'
    )
    with PdfPages(pdf_path) as pdf:
        for location_id, location_name in plot_locations.items():
            fig, axes = plt.subplots(
                n_rows, n_cols, figsize=(12, 4 * n_rows),
                sharey=True, sharex=True, squeeze=False,
            )
            try:
                for i, (age_group_id, age_group_name) in enumerate(age_groups):
                    panel = axes[divmod(i, n_cols)]

                    for sex_id, sex_name, linestyle in sexes:
                        for series, source_name, color in sources:
                            series.loc[age_group_id, sex_id, location_id, :].plot(
                                ax=panel, label=f'{source_name} ({sex_name})',
                                linestyle=linestyle, color=color,
                            )

                    # A single legend on the first panel covers every line style.
                    if i == 0:
                        panel.legend(ncol=2)

                    panel.set_title(age_group_name)
                    panel.set_xlabel(None)
                fig.suptitle(f'{location_name} ({location_id})')
                fig.tight_layout()
                pdf.savefig(fig)
            finally:
                # Always release the figure, even if plotting a location fails.
                plt.close(fig)


def _build_level_3_map(location_metadata: pd.DataFrame) -> Dict[int, list]:
    """Map each level-3 location id to the ids of the level > 3 locations beneath it."""
    subnationals = location_metadata.loc[location_metadata['level'] > 3]
    if subnationals.empty:
        return {}
    # Element 3 of the comma-separated path is the level-3 ancestor.
    level_3_parent = (
        subnationals['path_to_top_parent']
        .str.split(',').str[3]
        .astype(int)
        .rename('level_3_parent')
    )
    return subnationals.groupby(level_3_parent)['location_id'].apply(list).to_dict()


def _read_fhs_values(dataset_path: Path, age_group_ids: list) -> pd.Series:
    """Load the ``value`` variable for the given age groups and close the file."""
    with xr.open_dataset(dataset_path) as dataset:
        return (
            dataset
            .sel(age_group_id=age_group_ids)
            .to_dataframe()
            .loc[:, 'value']
        )


def _broadcast_to_subnationals(fhs: pd.Series, level_3_map: Dict[int, list]) -> pd.Series:
    """Copy each level-3 location's rows onto its subnational locations.

    Rows already present in ``fhs`` for a mapped subnational location are
    replaced by the broadcast copy of their level-3 parent.
    """
    fhs_subnat = fhs.reset_index('location_id')
    fhs_subnat['location_id'] = fhs_subnat['location_id'].map(level_3_map)
    # Unmapped locations become NaN and are dropped; the remaining lists of
    # children are expanded to one row per child.
    fhs_subnat = fhs_subnat.dropna().explode('location_id')
    fhs_subnat['location_id'] = fhs_subnat['location_id'].astype(int)
    fhs_subnat = (
        fhs_subnat
        .set_index('location_id', append=True)
        .reorder_levels(fhs.index.names)
        .loc[:, 'value']
    )

    # A boolean mask (rather than `drop`) also works when no subnational id is present.
    subnational_ids = [loc_id for children in level_3_map.values() for loc_id in children]
    is_subnational = fhs.index.get_level_values('location_id').isin(subnational_ids)
    return pd.concat([fhs.loc[~is_subnational], fhs_subnat])


def get_shifted_forecast(
    measure_scenario: Tuple[str, str],
    location_metadata: pd.DataFrame,
    age_metadata: pd.DataFrame,
    output_root: Path = OUTPUT_ROOT,
    sev_input_path_template: str = SEV_INPUT_PATH_TEMPLATE,
    sev_output_path_template: str = SEV_OUTPUT_PATH_TEMPLATE,
    rei_ids: Dict[str, int] = REI_IDS,
    verbose: bool = False
):
    """Shift an FHS forecast to match GBD in 2022 and write the results.

    Steps, in order:

    1. Read GBD values for ``measure`` and the FHS netCDF forecast for the
       age groups starting at >= 28 days and ending at <= 5 years.
    2. Broadcast each level-3 location's forecast onto its subnational
       locations.
    3. Compute ``logit(GBD) - logit(FHS)`` in 2022 and add it to the logit of
       every forecast value.
    4. Write a diagnostic PDF via ``plot_shifted_sevs``.
    5. Append the unmodified FHS values for all other age groups and write
       a wide (one column per draw) CSV.

    Parameters
    ----------
    measure_scenario
        ``(measure, scenario)`` pair.
    location_metadata
        Needs the columns ``level``, ``path_to_top_parent``, ``location_id``
        and ``location_name``. It is not modified.
    age_metadata
        Needs the columns ``age_group_id``, ``age_group_days_start`` and
        ``age_group_years_end``.
    output_root
        Root under which ``input/gbd_prevalence/{measure}_sev.parquet`` lives.
    sev_input_path_template, sev_output_path_template
        Directory templates formatted with ``measure`` and ``scenario``.
    rei_ids
        Mapping from measure name to the ``rei_id`` written to the CSV.
    verbose
        Emit progress log messages when true.

    Raises
    ------
    ValueError
        If the estimated and non-estimated age groups cover different sets of
        locations.

    Notes
    -----
    The GBD values are left-joined onto the forecast, so combinations missing
    from GBD produce NaN shifts. ``logit`` is infinite at exactly 0 or 1.
    """
    measure, scenario = measure_scenario
    if verbose:
        logger.warning(f'MEASURE: {measure}')
        logger.warning(f'SCENARIO: {scenario}')

    if verbose:
        logger.info('Creating path and preparing metadata')
    input_dir = Path(sev_input_path_template.format(measure=measure, scenario=scenario))
    output_dir = Path(sev_output_path_template.format(measure=measure, scenario=scenario))
    output_dir.mkdir(parents=True, exist_ok=True)
    fhs_dataset_path = input_dir / f'nutrition_{measure}.nc'

    sev_age_metadata = age_metadata.loc[
        (age_metadata['age_group_days_start'] >= 28)
        & (age_metadata['age_group_years_end'] <= 5)
    ]
    other_age_metadata = age_metadata.loc[
        (age_metadata['age_group_days_start'] < 28)
        | (age_metadata['age_group_years_end'] > 5)
    ]

    level_3_map = _build_level_3_map(location_metadata)

    if verbose:
        logger.info('Reading GBD data')
    gbd = pd.read_parquet(output_root / 'input' / 'gbd_prevalence' / f'{measure}_sev.parquet')
    gbd = (
        gbd
        .set_index(['age_group_id', 'sex_id', 'location_id', 'year_id']).loc[:, 'val']
    )
    gbd = gbd.loc[sev_age_metadata['age_group_id'], :, :, :]

    if verbose:
        logger.info('Reading FHS data for relevant age groups')
    fhs = _read_fhs_values(fhs_dataset_path, sev_age_metadata['age_group_id'].to_list())

    if verbose:
        logger.info('Broadcasting national trends from forecast to subnational histories')
    fhs = _broadcast_to_subnationals(fhs, level_3_map)

    if verbose:
        logger.info('Calculating logit difference in 2022 and applying shift')
    reference_year = 2022
    # Select by level name so the result does not depend on the netCDF's dimension order.
    logit_gbd = logit(gbd.xs(reference_year, level='year_id')).rename('gbd')
    logit_fhs = logit(fhs.xs(reference_year, level='year_id')).rename('fhs')
    logit_diff = logit_fhs.to_frame().join(logit_gbd, how='left')
    logit_diff = logit_diff['gbd'] - logit_diff['fhs']
    # The year-free difference is aligned on the shared index levels, so the
    # same shift is applied to every year.
    shifted_fhs = expit(logit(fhs) + logit_diff)

    plot_shifted_sevs(
        gbd, fhs, shifted_fhs,
        location_metadata, sev_age_metadata,
        sev_output_path_template, scenario, measure,
    )

    if verbose:
        logger.info('Getting ages not being estimated and checking we do not have any non-zero values')
    fhs_zeros = _read_fhs_values(fhs_dataset_path, other_age_metadata['age_group_id'].to_list())
    zero_locations = set(fhs_zeros.index.get_level_values('location_id').unique())
    estimate_locations = set(shifted_fhs.index.get_level_values('location_id').unique())
    if not estimate_locations <= zero_locations:
        raise ValueError('Data in estimates but not zeros.')
    if not zero_locations <= estimate_locations:
        raise ValueError('Data in zeros but not estimates.')

    if verbose:
        logger.info('Formatting dataset for FHS save results')
    shifted_fhs = pd.concat([
        fhs_zeros,
        shifted_fhs.rename('value'),
    ])
    shifted_fhs = shifted_fhs.unstack('draw')
    draw_columns = [f'draw_{dc}' for dc in shifted_fhs.columns]
    shifted_fhs.columns = draw_columns
    shifted_fhs['rei_id'] = rei_ids[measure]
    # NOTE(review): fixed identifier codes for the upload; their meaning is not
    # visible in this file — confirm against the save-results requirements.
    shifted_fhs['measure_id'] = 29
    shifted_fhs['metric_id'] = 3
    shifted_fhs['release_id'] = 9
    shifted_fhs['gbd_round_id'] = 7
    id_columns = [
        'measure_id', 'metric_id', 'rei_id', 'age_group_id', 'sex_id',
        'location_id', 'year_id', 'release_id', 'gbd_round_id', 'scenario',
    ]
    shifted_fhs = shifted_fhs.reset_index().loc[:, id_columns + draw_columns]

    if verbose:
        logger.info('Saving csv file for FHS save results')
    shifted_fhs.to_csv(output_dir / f'nutrition_{measure}.csv', index=False)


In [3]:
def main():
    """Run `get_shifted_forecast` for every measure/scenario combination.

    Age and location metadata are read once, sorted, and a fresh copy is
    handed to each run. Runs are executed one after another in the order
    measure-major, scenario-minor (a multiprocessing variant was used
    previously and has been retired).
    """
    gbd_input_dir = OUTPUT_ROOT / 'input' / 'gbd_prevalence'
    age_metadata = (
        pd.read_parquet(gbd_input_dir / 'age_metadata.parquet')
        .sort_values('age_group_years_start')
        .reset_index(drop=True)
    )
    location_metadata = (
        pd.read_parquet(FHS_HIERARCHY_PATH)
        .sort_values('sort_order')
        .reset_index(drop=True)
    )

    logger.warning('Creating measure-scenarios')
    for measure in MEASURES:
        for scenario in CMIP6_SCENARIOS:
            get_shifted_forecast(
                (measure, scenario),
                location_metadata=location_metadata.copy(),
                age_metadata=age_metadata.copy(),
                verbose=True,
            )
    logger.warning('Complete')


In [4]:

# Entry point: runs every measure/scenario combination when executed directly.
if __name__ == '__main__':
    main()


# Follow-up commands, one `fhs_save_results` call per CSV written above:
# /mnt/team/fhs/pub/venv/fhs_save_results /mnt/share/forecasting/data/7/future/sev/20241021_nutrition_stunting_ssp119/nutrition_stunting.csv
# /mnt/team/fhs/pub/venv/fhs_save_results /mnt/share/forecasting/data/7/future/sev/20241021_nutrition_stunting_ssp245/nutrition_stunting.csv
# /mnt/team/fhs/pub/venv/fhs_save_results /mnt/share/forecasting/data/7/future/sev/20241021_nutrition_stunting_ssp585/nutrition_stunting.csv

# /mnt/team/fhs/pub/venv/fhs_save_results /mnt/share/forecasting/data/7/future/sev/20241021_nutrition_wasting_ssp119/nutrition_wasting.csv
# /mnt/team/fhs/pub/venv/fhs_save_results /mnt/share/forecasting/data/7/future/sev/20241021_nutrition_wasting_ssp245/nutrition_wasting.csv
# /mnt/team/fhs/pub/venv/fhs_save_results /mnt/share/forecasting/data/7/future/sev/20241021_nutrition_wasting_ssp585/nutrition_wasting.csv



2024-10-22 12:21:06.887 | INFO     | __main__:get_shifted_forecast:63 - Creating path and preparing metadata
2024-10-22 12:21:06.918 | INFO     | __main__:get_shifted_forecast:87 - Reading GBD data
2024-10-22 12:21:07.411 | INFO     | __main__:get_shifted_forecast:96 - Reading FHS data for relevant age groups
2024-10-22 12:21:16.720 | INFO     | __main__:get_shifted_forecast:107 - Broadcasting national trends from forecast to subnational histories
2024-10-22 12:23:04.503 | INFO     | __main__:get_shifted_forecast:119 - Calculating logit difference in 2022 and applying shift
2024-10-22 12:25:37.703 | INFO     | __main__:get_shifted_forecast:133 - Getting ages not being estimated and c… [output truncated]

## Stunting paths:
#### /ihme/forecasting/data/7/future/sev/FHS_save_results_2024_10_22_at_12_50_06_scenario_53_nutrition_stunting
#### /ihme/forecasting/data/7/future/sev/FHS_save_results_2024_10_22_at_13_14_01_scenario_0_nutrition_stunting
#### /ihme/forecasting/data/7/future/sev/FHS_save_results_2024_10_22_at_13_37_19_scenario_54_nutrition_stunting

## Wasting paths:
#### /ihme/forecasting/data/7/future/sev/FHS_save_results_2024_10_22_at_14_00_08_scenario_53_nutrition_wasting
#### /ihme/forecasting/data/7/future/sev/FHS_save_results_2024_10_22_at_14_25_15_scenario_0_nutrition_wasting
#### /ihme/forecasting/data/7/future/sev/FHS_save_results_2024_10_22_at_14_49_55_scenario_54_nutrition_wasting