In [7]:
from typing import Tuple, Union

import numpy as np
import pandas as pd
from scipy import stats
from vivarium.framework.randomness import get_hash

from vivarium_gates_iv_iron.constants import data_keys, metadata
from vivarium_gates_iv_iron.data.loader import get_data, load_standard_data
from vivarium_gates_iv_iron.utilities import create_draws

In [2]:
def get_prevalence_pregnant(key: str, location: str) -> pd.DataFrame:
    """Approximate the prevalence of pregnancy in ``location`` from fertility and pregnancy-loss rates.

    Two groups of pregnancy outcomes are combined:

    * births: ``ASFR + ASFR * SBR`` (presumably live births plus stillbirths, with SBR a
      stillbirth-to-live-birth ratio -- TODO confirm), weighted by 40/52;
    * early losses: miscarriage incidence + ectopic incidence, weighted by 24/52.

    NOTE(review): the 40/52 and 24/52 weights look like the fraction of a year spent
    pregnant for each outcome type (40 vs. 24 weeks); confirm against the model spec.

    Args:
        key: Not used by this function; presumably kept so the signature matches other
            data loaders that take ``(key, location)``.
        location: Location passed through to every ``get_data`` call.

    Returns:
        The weighted sum of the four loaded inputs, computed with pandas arithmetic.
    """
    fertility = get_data(data_keys.PREGNANCY.ASFR, location)
    stillbirth_ratio = get_data(data_keys.PREGNANCY.SBR, location)
    miscarriage = get_data(data_keys.PREGNANCY.INCIDENCE_RATE_MISCARRIAGE, location)
    ectopic = get_data(data_keys.PREGNANCY.INCIDENCE_RATE_ECTOPIC, location)

    births = fertility + fertility * stillbirth_ratio
    early_losses = miscarriage + ectopic
    # Evaluation order (x * 40 / 52, left to right) is kept so results are bit-identical.
    return births * 40 / 52 + early_losses * 24 / 52

In [12]:
def generate_lognormal_draws(df, seed, quantiles=(0.025, 0.975), n_draws=1000):
    """Vectorized lognormal draws for every row of a mean/lower/upper summary frame.

    Rows with ``0 < lower_value < mean_value < upper_value`` are sampled from a
    lognormal whose ``scale`` is the row mean and whose shape ``sigma`` is chosen so
    that the standard-normal ``quantiles`` map onto ``log(lower_value)`` and
    ``log(upper_value)``. Every other row (point estimates, zeros, or rows whose
    lower bound is not positive and therefore has no finite log) is returned as
    ``n_draws`` copies of its mean.

    NOTE(review): scipy's lognorm ``scale`` equals exp(mu), i.e. the distribution's
    median, so ``mean_value`` is effectively used as the median here.

    Parameters
    ----------
    df : pd.DataFrame
        Must contain ``mean_value``, ``lower_value`` and ``upper_value`` columns.
    seed : str
        Key hashed with ``get_hash`` to seed a local ``RandomState``.
    quantiles : tuple of float, default (0.025, 0.975)
        Probabilities that the lower and upper bounds correspond to.
    n_draws : int, default 1000
        Number of draws generated per row.

    Returns
    -------
    pd.DataFrame
        One row per row of ``df`` (sorted by index) with columns
        ``draw_0`` ... ``draw_{n_draws - 1}``.

    Raises
    ------
    AssertionError
        If a bound lies on the wrong side of the mean, or exactly one of the two
        bounds equals the mean.
    """
    mean = df['mean_value'].to_numpy()
    lower = df['lower_value'].to_numpy()
    upper = df['upper_value'].to_numpy()
    assert np.all((lower <= mean) & (mean <= upper)), 'bounds must satisfy lower <= mean <= upper'
    # A bound may coincide with the mean only if both do (a point estimate).
    assert np.all((lower == mean) == (upper == mean)), 'one bound equals the mean but the other does not'

    # lower > 0 (which implies mean > 0) keeps log(lower) finite; a zero or negative
    # lower bound would give sigma = inf/nan and degenerate 0/inf/nan samples.
    sample_mask = (lower > 0) & (lower < mean) & (mean < upper)
    n_sampled = int(sample_mask.sum())

    if n_sampled:
        stdnorm_quantiles = stats.norm.ppf(quantiles)
        log_lower = np.log(lower[sample_mask])
        log_upper = np.log(upper[sample_mask])
        sigma = (log_upper - log_lower) / (stdnorm_quantiles[1] - stdnorm_quantiles[0])
        distribution = stats.lognorm(s=sigma, scale=mean[sample_mask])
        # A local RandomState (same seeding as np.random.seed) keeps draws reproducible
        # without mutating the notebook's global numpy RNG.
        rng = np.random.RandomState(get_hash(seed))
        lognorm_samples = distribution.rvs(size=(n_draws, n_sampled), random_state=rng).T
    else:
        lognorm_samples = np.empty((0, n_draws))

    # Constant rows: each unsampled row's mean repeated across all draws,
    # shape (number of unsampled rows, n_draws).
    constant_draws = np.repeat(mean[~sample_mask][:, np.newaxis], n_draws, axis=1)

    draw_columns = [f'draw_{d}' for d in range(n_draws)]
    sampled = pd.DataFrame(lognorm_samples, index=df.index[sample_mask], columns=draw_columns)
    constant = pd.DataFrame(constant_draws, index=df.index[~sample_mask], columns=draw_columns)
    return pd.concat([sampled, constant]).sort_index()




In [4]:
# Data key and location analysed throughout the rest of this notebook.
location = 'South Asia'
key = data_keys.PREGNANCY.ASFR

In [5]:
%%time
asfr = load_standard_data(key, location)
asfr = asfr.reset_index()
asfr_pivot = asfr.pivot(
    index=[col for col in metadata.ARTIFACT_INDEX_COLUMNS if col != "location"],
    columns='parameter',
    values='value'
)


CPU times: user 4.74 s, sys: 163 ms, total: 4.9 s
Wall time: 20.9 s


In [14]:
%%time
asfr_pivot.apply(create_draws, args=(key, location), axis=1).head()

CPU times: user 57.1 s, sys: 86.6 ms, total: 57.2 s
Wall time: 57.2 s


Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,Unnamed: 4_level_0,draw_0,draw_1,draw_2,draw_3,draw_4,draw_5,draw_6,draw_7,draw_8,draw_9,...,draw_990,draw_991,draw_992,draw_993,draw_994,draw_995,draw_996,draw_997,draw_998,draw_999
sex,age_start,age_end,year_start,year_end,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1,Unnamed: 25_level_1
Female,0.0,0.019178,1990,1991,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
Female,0.0,0.019178,1991,1992,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
Female,0.0,0.019178,1992,1993,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
Female,0.0,0.019178,1993,1994,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
Female,0.0,0.019178,1994,1995,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


In [15]:
%%time
seed = f'{key}_{location}'
generate_lognormal_draws(asfr_pivot, seed).head()

CPU times: user 22.5 ms, sys: 8.07 ms, total: 30.6 ms
Wall time: 28.1 ms


Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,Unnamed: 4_level_0,draw_0,draw_1,draw_2,draw_3,draw_4,draw_5,draw_6,draw_7,draw_8,draw_9,...,draw_990,draw_991,draw_992,draw_993,draw_994,draw_995,draw_996,draw_997,draw_998,draw_999
sex,age_start,age_end,year_start,year_end,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1,Unnamed: 25_level_1
Female,0.0,0.019178,1990,1991,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
Female,0.0,0.019178,1991,1992,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
Female,0.0,0.019178,1992,1993,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
Female,0.0,0.019178,1993,1994,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
Female,0.0,0.019178,1994,1995,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
