# In [11]:
from typing import Tuple, Union

import numpy as np
import pandas as pd
from scipy import stats
from vivarium.framework.randomness import get_hash
from vivarium_inputs import utility_data

from vivarium_gates_iv_iron.constants import data_keys, metadata
from vivarium_gates_iv_iron.data import loader, ut
from vivarium_gates_iv_iron.utilities import create_draws


# In [19]:
# Parent location whose national-level child locations are looked up below.
location = "South Asia"

# In [None]:
# Demographic columns identifying one row of the SBR / weighting-unit data.
index_cols = 'sex age_start age_end year_start year_end'.split()

# NOTE(review): get_child_locs is defined in a later cell of this file; when the
# file is executed top to bottom this call raises NameError unless that
# definition has already run.
child_locs = get_child_locs(location)

# In [12]:
def load_sbr(key: str, location: str):
    """Return SBR draws for ``location`` as a weighted average of its children.

    Every national-level (admin0) child location of ``location`` contributes
    its SBR value, weighted draw by draw by that child's weighting units
    (ASFR draws times WRA population, see ``get_weighting_units``). For each
    draw and demographic group the result is
    ``sum(sbr * weight) / sum(weight)`` over the child locations.

    Parameters
    ----------
    key
        Artifact key being loaded. NOTE(review): currently unused -- the SBR
        data key is hard-coded in ``get_child_sbr_with_weighting_unit``.
    location
        Name of the parent location (e.g. a GBD region).

    Returns
    -------
    pd.DataFrame
        Columns ``draw_0`` ... ``draw_999``, indexed by
        ``sex, age_start, age_end, year_start, year_end``.

    Raises
    ------
    ValueError
        If ``location`` has no national-level child locations.
    """
    index_cols = ['sex', 'age_start', 'age_end', 'year_start', 'year_end']
    draw_cols = [f"draw_{i}" for i in range(1000)]

    child_locs = get_child_locs(location)
    if not child_locs:
        # pd.concat([]) would otherwise fail with an unhelpful message.
        raise ValueError(f"No national-level child locations found for {location!r}.")

    # Each child frame carries its own 0..n-1 RangeIndex; renumber so the
    # combined frame has unique row labels.
    disaggregated_df = pd.concat(
        [get_child_sbr_with_weighting_unit(loc) for loc in child_locs],
        ignore_index=True,
    )

    weights = disaggregated_df[draw_cols]
    weighted_sbr = weights.multiply(disaggregated_df['sbr'], axis=0)
    group_keys = [disaggregated_df[col] for col in index_cols]

    # Weighted mean for all draws in a single groupby -- the same per-draw
    # arithmetic as weighted_average(..., 'sbr', draw, index_cols), without
    # 1000 separate groupbys and column insert/delete cycles.
    df = weighted_sbr.groupby(group_keys).sum() / weights.groupby(group_keys).sum()

    return df


def get_child_sbr_with_weighting_unit(location: str):
    """Return ``location``'s weighting units tagged with its 2019 mean SBR.

    Loads the weighting units for ``location`` and the SBR data for the same
    location. It takes the single ``value`` from the row with
    ``year_start == 2019`` and ``parameter == 'mean_value'`` and broadcasts
    it into an ``sbr`` column. It also adds a ``location`` column and moves
    the index into ordinary columns.

    Raises IndexError (from ``.to_numpy()[0]``) if no such SBR row exists.
    """
    weighting_units = get_weighting_units(location)

    sbr_data = load_standard_data(data_keys.PREGNANCY.SBR, location).reset_index()
    mean_2019 = (sbr_data['year_start'] == 2019) & (sbr_data['parameter'] == 'mean_value')
    weighting_units['sbr'] = sbr_data.loc[mean_2019, 'value'].to_numpy()[0]

    weighting_units['location'] = location
    return weighting_units.reset_index()


def get_child_locs(location, location_set_id: int = 35, decomp: str = 'step4'):
    """Return names of the admin0 (national) locations beneath ``location``.

    A location is selected when its ``location_type`` is ``'admin0'`` and the
    id of ``location`` appears in its ``path_to_top_parent``.

    Parameters
    ----------
    location
        Location name, resolved to an id via ``utility_data.get_location_id``.
    location_set_id
        Location set passed to ``get_location_metadata``; 35 is the GBD model
        results set.
    decomp
        Decomposition step passed to ``get_location_metadata``.

    Returns
    -------
    list of str
        ``location_name`` values of the matching locations, in metadata order.
    """
    parent_id = utility_data.get_location_id(location)
    loc_metadata = get_location_metadata(location_set_id=location_set_id,
                                         decomp_step=decomp,
                                         gbd_round_id=metadata.GBD_2019_ROUND_ID)

    # path_to_top_parent is a comma-separated string of location ids; a
    # location is a descendant when the parent's id is among them.
    is_child_loc = loc_metadata['path_to_top_parent'].map(
        lambda path: parent_id in {int(loc_id) for loc_id in path.split(',')}
    )

    # Restrict to national-level locations.
    is_country = loc_metadata['location_type'] == 'admin0'

    return loc_metadata.loc[is_child_loc & is_country, 'location_name'].tolist()


def get_weighting_units(location):
    """Return per-draw weighting units for ``location``.

    Each ASFR draw (``data_keys.PREGNANCY.ASFR``) is multiplied by the ``wra``
    population of the same demographic row (from ``get_wra``).

    The ASFR and WRA frames are outer-joined on their index by ``pd.concat``.
    Rows present in only one of them therefore yield NaN.

    Returns
    -------
    pd.DataFrame
        Columns ``draw_0`` ... ``draw_999`` on the joined index.
    """
    asfr_draws = get_data(data_keys.PREGNANCY.ASFR, location)
    wra = get_wra(location)

    df = pd.concat([asfr_draws, wra], axis=1)
    draw_cols = [f"draw_{i}" for i in range(1000)]

    # Multiply by the 'wra' column of the *joined* frame: it is already aligned
    # row-for-row with the draws, so the result keeps df's index as-is. The raw
    # `wra` frame's index would be re-aligned (and possibly re-sorted), and
    # overwriting that result's index with df.index could mislabel rows.
    wu_df = df[draw_cols].multiply(df['wra'], axis=0)

    return wu_df


def get_wra(location: str, decomp: str = "step4"):
    """Return female population for ``location`` reshaped to vivarium format.

    Queries population for ``sex_id=2`` and age groups 7-15 (GBD 2019 round),
    drops ``run_id`` and scrubs GBD conventions. It splits the age and year
    intervals with ``vi_utils.split_interval``, sorts the rows, renames
    ``population`` to ``wra`` (presumably "women of reproductive age") and
    drops the ``location`` index level.

    NOTE(review): ``get_population``, ``utilities`` and ``vi_utils`` are not
    imported in the visible part of this file; they must be brought into
    scope elsewhere before this runs.
    """
    location_id = utility_data.get_location_id(location)
    population = get_population(
        decomp_step=decomp,
        age_group_id=list(range(7, 16)),
        sex_id=2,
        gbd_round_id=metadata.GBD_2019_ROUND_ID,
        location_id=location_id,
    )

    # Reshape the GBD output into the vivarium layout.
    population = (
        population
        .set_index(['age_group_id', 'location_id', 'sex_id', 'year_id'])
        .drop(columns='run_id')
    )
    population = utilities.scrub_gbd_conventions(population, location)
    for interval in ('age', 'year'):
        population = vi_utils.split_interval(population, interval_column=interval,
                                             split_column_prefix=interval)
    population = vi_utils.sort_hierarchical_data(population)

    population = population.rename(columns={'population': 'wra'})
    population.index = population.index.droplevel('location')
    return population


def weighted_average(df, data_col, weight_col, by_col):
    """Return the ``weight_col``-weighted mean of ``data_col`` per group.

    For every group defined by ``by_col`` this computes
    ``sum(df[data_col] * df[weight_col]) / sum(df[weight_col])``.
    ``df`` itself is left unmodified.

    Parameters
    ----------
    df : pd.DataFrame
        Input data.
    data_col : str
        Column holding the values to average.
    weight_col : str
        Column holding the weights.
    by_col
        Grouping key(s), anything accepted by ``DataFrame.groupby``.

    Returns
    -------
    pd.Series
        One weighted mean per group, indexed by the group keys.
    """
    # Work on a copy (via assign) rather than adding, then deleting, columns
    # on the caller's frame.
    with_product = df.assign(_data_times_weight=df[data_col] * df[weight_col])
    grouped = with_product.groupby(by_col)
    return grouped['_data_times_weight'].sum() / grouped[weight_col].sum()