In [1]:
import warnings
from typing import Tuple

import numpy as np
import pandas as pd
from scipy import stats

from gbd_mapping import causes, covariates, risk_factors, sequelae
from db_queries import (
    get_covariate_estimates,
    get_location_metadata,
    get_population,
)
from vivarium_gbd_access.utilities import get_draws
from vivarium.framework.artifact import EntityKey
from vivarium_gbd_access import constants as gbd_constants, gbd
from vivarium_inputs import (
    globals as vi_globals,
    interface,
    utilities as vi_utils,
    utility_data,
)
from vivarium_inputs.mapping_extension import alternative_risk_factors

from vivarium_gates_iv_iron.constants import data_keys, metadata
from vivarium_gates_iv_iron.data import utilities
from vivarium_gates_iv_iron.paths import PREGNANT_PROPORTION_WITH_HEMOGLOBIN_BELOW_70_CSV as HGB_BELOW_70_CSV
from vivarium_gates_iv_iron.utilities import create_draws
# NOTE(review): `get_random_variable_draws` and `get_random_variable` are called later in
# this notebook but are not imported anywhere; presumably they live in
# vivarium_gates_iv_iron.utilities next to `create_draws` -- confirm and import them.


In [2]:
def get_data(lookup_key: str, location: str) -> pd.DataFrame:
    """Dispatch an artifact key to the loader that produces its data.

    Parameters
    ----------
    lookup_key
        Artifact key the data will be stored under; it selects the loader and
        is also forwarded to it.
    location
        Location to load data for; forwarded to the selected loader.

    Returns
    -------
        Whatever the selected loader returns for ``(lookup_key, location)``.

    Raises
    ------
    KeyError
        If ``lookup_key`` has no registered loader.
    """
    # (key, loader) pairs; several keys intentionally share a generic loader.
    key_loader_pairs = [
        (data_keys.POPULATION.LOCATION, load_population_location),
        (data_keys.POPULATION.STRUCTURE, load_population_structure),
        (data_keys.POPULATION.AGE_BINS, load_age_bins),
        (data_keys.POPULATION.DEMOGRAPHY, load_demographic_dimensions),
        (data_keys.POPULATION.TMRLE, load_theoretical_minimum_risk_life_expectancy),
        (data_keys.POPULATION.ACMR, load_standard_data),
        (data_keys.POPULATION.PREGNANT_LACTATING_WOMEN_LOCATION_WEIGHTS, get_pregnant_lactating_women_location_weights),
        (data_keys.POPULATION.WOMEN_REPRODUCTIVE_AGE_LOCATION_WEIGHTS, get_women_reproductive_age_location_weights),
        (data_keys.PREGNANCY.INCIDENCE_RATE, load_pregnancy_incidence_rate),
        (data_keys.PREGNANCY.PREGNANT_PREVALENCE, get_prevalence_pregnant),
        (data_keys.PREGNANCY.NOT_PREGNANT_PREVALENCE, get_prevalence_not_pregnant),
        (data_keys.PREGNANCY.POSTPARTUM_PREVALENCE, get_prevalence_postpartum),
        (data_keys.PREGNANCY.INCIDENCE_RATE_MISCARRIAGE, load_standard_data),
        (data_keys.PREGNANCY.INCIDENCE_RATE_ECTOPIC, load_standard_data),
        (data_keys.PREGNANCY.ASFR, load_asfr),
        (data_keys.PREGNANCY.SBR, load_sbr),
        (data_keys.LBWSG.DISTRIBUTION, load_metadata),
        (data_keys.LBWSG.CATEGORIES, load_metadata),
        (data_keys.LBWSG.EXPOSURE, load_lbwsg_exposure),
        (data_keys.PREGNANCY_OUTCOMES.STILLBIRTH, load_pregnancy_outcome),
        (data_keys.PREGNANCY_OUTCOMES.LIVE_BIRTH, load_pregnancy_outcome),
        (data_keys.PREGNANCY_OUTCOMES.OTHER, load_pregnancy_outcome),
        (data_keys.MATERNAL_DISORDERS.CSMR, load_standard_data),
        (data_keys.MATERNAL_DISORDERS.INCIDENCE_RATE, load_standard_data),
        (data_keys.MATERNAL_DISORDERS.YLDS, load_maternal_disorders_ylds),
        (data_keys.MATERNAL_HEMORRHAGE.CSMR, load_standard_data),
        (data_keys.MATERNAL_HEMORRHAGE.INCIDENCE_RATE, load_standard_data),
        (data_keys.HEMOGLOBIN.MEAN, get_hemoglobin_data),
        (data_keys.HEMOGLOBIN.STANDARD_DEVIATION, get_hemoglobin_data),
        (data_keys.HEMOGLOBIN.PREGNANT_PROPORTION_WITH_HEMOGLOBIN_BELOW_70, get_hemoglobin_csv_data),
    ]
    loader = dict(key_loader_pairs)[lookup_key]
    return loader(lookup_key, location)

In [None]:
# Location under analysis in this notebook.
location: str = "South Asia"


In [None]:
def get_random_variable_draws_for_location(columns: pd.Index, location: str, seed: str, distribution) -> pd.Series:
    """Draw values from ``distribution`` using a location-specific seed.

    The seed passed on to ``get_random_variable_draws`` is ``seed`` with
    ``_<location>`` appended, so each location gets its own reproducible stream.
    """
    location_seed = f"{seed}_{location}"
    return get_random_variable_draws(columns, location_seed, distribution)


def get_lognorm_from_quantiles(
    mean: float,
    lower: float,
    upper: float,
    quantiles: Tuple[float, float] = (0.025, 0.975)
) -> stats.lognorm:
    """Return a frozen lognormal centred on ``mean`` with spread set by ``(lower, upper)``.

    Despite its name, ``mean`` is used as the lognormal's scale parameter,
    i.e. as the *median* of the distribution (the lognormal's actual mean,
    ``exp(mu + sigma**2 / 2)``, is larger). The shape ``sigma`` is chosen so
    that ``lower`` and ``upper`` are approximately the quantiles with ranks
    ``quantiles[0]`` and ``quantiles[1]``. The match is exact when ``mean`` is the
    geometric mean of ``lower`` and ``upper`` and the ranks are symmetric
    about 0.5.

    Parameters
    ----------
    mean
        Central value; used as the median of the lognormal.
    lower, upper
        Values that should sit at the ``quantiles`` ranks.
    quantiles
        Increasing pair of quantile ranks strictly between 0 and 1.

    Returns
    -------
        A frozen ``scipy.stats.lognorm``. In two cases no lognormal can be
        fitted: ``lower == upper`` (no uncertainty), and ``lower <= 0``
        (log undefined; a warning is emitted). In those cases the function
        returns a zero-width frozen ``scipy.stats.norm`` centred on ``mean``,
        whose ``rvs`` always yields ``mean``.

    Raises
    ------
    ValueError
        If ``mean`` is not within ``[lower, upper]``, or ``quantiles`` is not
        an increasing pair strictly between 0 and 1.
    """
    if not (lower <= mean <= upper):
        raise ValueError(
            f"The mean ({mean}) must be between the lower ({lower}) and upper ({upper}) "
            "quantile values."
        )
    rank_low, rank_high = quantiles
    if not (0 < rank_low < rank_high < 1):
        raise ValueError(
            f"quantiles must be an increasing pair strictly between 0 and 1, got {quantiles}."
        )

    if lower == upper:
        # No uncertainty. scipy rejects lognorm(s=0) when sampling, so use a
        # zero-width normal, whose rvs() returns `loc` exactly.
        return stats.norm(loc=mean, scale=0)
    if lower <= 0:
        # log(lower) is -inf/NaN, so no finite lognormal shape fits these bounds.
        # Fall back to a point mass, but warn so the data problem is not hidden.
        warnings.warn(
            f"Cannot fit a lognormal with non-positive lower bound ({lower}); "
            f"returning a point mass at {mean}."
        )
        return stats.norm(loc=mean, scale=0)

    # Let Y = log(X) ~ Normal(mu, sigma^2) with mu = log(mean), so that
    # X ~ lognorm(s=sigma, scale=exp(mu)) = lognorm(s=sigma, scale=mean) in scipy's
    # notation. sigma is the distance between the bounds in log space divided
    # by the distance between the matching standard-normal quantiles.
    z_low, z_high = stats.norm.ppf([rank_low, rank_high])
    sigma = (np.log(upper) - np.log(lower)) / (z_high - z_low)
    return stats.lognorm(s=sigma, scale=mean)

def create_draws(
    df: pd.DataFrame,
    key: str,
    location: str,
    distribution_function=get_lognorm_from_quantiles,
    n_draws: int = 1000,
) -> pd.Series:
    """Sample ``n_draws`` values of an uncertain quantity for one location.

    Parameters
    ----------
    df
        Object indexable by ``'mean_value'``, ``'lower_value'`` and
        ``'upper_value'``.
        NOTE(review): the default distribution function compares these values
        with a chained ``<=``, which only works for scalars. This presumably
        expects a single row (a Series) rather than a multi-row DataFrame;
        confirm against callers.
    key
        Name of the quantity being sampled; used as the base of the random seed.
    location
        Location name; appended to ``key`` to make the seed location-specific.
    distribution_function
        Callable taking ``mean``, ``lower`` and ``upper`` keyword arguments and
        returning a frozen distribution to sample from.
    n_draws
        Number of draws. The result is indexed ``draw_0`` ... ``draw_{n_draws - 1}``.

    Returns
    -------
        The sampled values, one per draw label, as produced by
        ``get_random_variable_draws_for_location``.
    """
    distribution = distribution_function(
        mean=df['mean_value'],
        lower=df['lower_value'],
        upper=df['upper_value'],
    )
    # TODO: take the draw count / labels from project constants.
    draw_columns = pd.Index([f'draw_{i}' for i in range(n_draws)])
    return get_random_variable_draws_for_location(draw_columns, location, key, distribution)


def create_draw(
    draw: int,
    distribution_parameters: Tuple,
    key: str,
    location: str,
    distribution_function=get_lognorm_from_quantiles
) -> pd.DataFrame:
    """Sample a single value of an uncertain quantity for one draw.

    Parameters
    ----------
    draw
        Draw number, forwarded to ``get_random_variable``.
    distribution_parameters
        Sequence whose first three items are the (mean, lower, upper) values
        passed to ``distribution_function``.
    key
        Name of the quantity; combined with ``location`` to build the seed.
    location
        Location name; combined with ``key`` to build the seed.
    distribution_function
        Callable taking ``mean``, ``lower`` and ``upper`` keyword arguments and
        returning a frozen distribution.

    Returns
    -------
        Whatever ``get_random_variable`` returns for this draw.
        NOTE(review): annotated as ``pd.DataFrame``, but presumably a single
        sampled value; confirm against ``get_random_variable``.
    """
    mean_value = distribution_parameters[0]
    lower_value = distribution_parameters[1]
    upper_value = distribution_parameters[2]
    frozen = distribution_function(mean=mean_value, lower=lower_value, upper=upper_value)
    return get_random_variable(draw, f"{key}_{location}", frozen)