In [1]:
from mult_model_fns import *
from folic_acid.folic_acid_mult_model_fns import *

In [2]:
from functions_for_all_nutrients import *

In [3]:
# --- Demographic scope (ids later passed as sex_id / age_group_id) ---
sexes = [1, 2]
ages = [2, 3, 4, 5]

# --- Draw column labels and the index columns used by the helper functions ---
draws = ['draw_' + str(i) for i in range(1000)]
index_cols = ['location_id', 'sex_id', 'age_group_id']

# --- Alternative scenario coverage levels ---
# Each level is the proportion of additional coverage achieved in the
# alternative scenario. "Additional coverage" is defined as the difference
# between the proportion of the population that eats the fortified vehicle
# and the proportion of the population that eats the industrially produced
# vehicle. Four levels are modelled.
alternative_scenario_coverage_levels = [0.2, 0.5, 0.8, 1]
coverage_levels = alternative_scenario_coverage_levels  # shorter alias for the same list

# --- Nutrient-specific settings ---
nutrient = 'folic acid'
rei_id = []  # folic acid doesn't affect any risks; just NTDs
cause_ids = [642]  # NTDs
nonfatal_causes = [642]  # YLLs and YLDs

In [4]:
# Every base vehicle is immediately followed by its 'zero ...' counterpart,
# which is the naming used in the coverage data.
vehicles = [
    label
    for base in ('salt', 'industry salt', 'wheat flour', 'industry wheat', 'maize flour')
    for label in (base, f'zero {base}')
]

In [5]:
# Location of the coverage CSV for all vehicles. Despite the `_dir` suffix this
# is a file path: it is handed straight to pd.read_csv below.
# NOTE(review): absolute, user-specific path -- the notebook only runs where
# this location is readable; consider a configurable data directory.
coverage_data_dir = '/ihme/homes/alibow/notebooks/vivarium_research_lsff/data_prep/outputs/waterfall_coverage_all_vehicles.csv'

In [6]:
# Load the coverage table and list which vehicles it contains
cov = pd.read_csv(coverage_data_dir)
cov['vehicle'].unique()

array(['industry oil', 'industry salt', 'maize flour', 'oil',
       'wheat flour', 'salt', 'zero wheat flour', 'zero industry oil',
       'zero industry salt', 'zero oil', 'zero industry wheat',
       'zero maize flour', 'zero salt', 'industry wheat'], dtype=object)

In [7]:
# Inspect every salt-related coverage row for location_id 214. Both conditions
# are combined into a single boolean mask so the filter does not depend on
# pandas re-aligning a mask built from the unfiltered frame.
cov.loc[cov.vehicle.str.contains('salt') & (cov.location_id == 214)]

Unnamed: 0,location_id,location_name,nutrient,vehicle,value_description,value_mean,value_025_percentile,value_975_percentile,wra_applicable,u5_applicable,sub_population
68,214,Nigeria,folic acid,industry salt,percent of population eating fortified vehicle,0.0,0.0,0.0,True,True,
93,214,Nigeria,folic acid,industry salt,percent of population eating industrially prod...,98.52,97.53,99.51,True,True,
302,214,Nigeria,folic acid,salt,percent of population eating fortified vehicle,0.0,0.0,0.0,True,True,
311,214,Nigeria,na,salt,percent of population eating industrially prod...,92.1,87.5,96.7,True,True,
554,214,Nigeria,folic acid,zero industry salt,percent of population eating fortified vehicle,0.0,0.0,0.0,True,True,
579,214,Nigeria,folic acid,zero industry salt,percent of population eating industrially prod...,0.0,0.0,0.0,True,True,
864,214,Nigeria,folic acid,zero salt,percent of population eating fortified vehicle,0.0,0.0,0.0,True,True,
873,214,Nigeria,na,zero salt,percent of population eating industrially prod...,0.0,0.0,0.0,True,True,


In [8]:
# Relative risk distribution for "no fortification", modelled as lognormal.
# Folic acid specific -- replace these parameters for other nutrient models.

from numpy import log
from scipy.stats import norm, lognorm

# The lognormal RR distribution is specified by its median and 0.975-quantile
median = 1.71
q_975 = 2.04

# 0.975-quantile of the standard normal distribution (approximately 1.96)
q_975_stdnorm = norm.ppf(0.975)

# log(RR) is normally distributed:
#   mu    -- its mean, i.e. the log of the RR median
#   sigma -- its std dev, from the distance between log(q_975) and mu in z-units
mu = log(median)
sigma = (log(q_975) - mu) / q_975_stdnorm

In [9]:
# Location ids included in this analysis (order is preserved in the outputs)
location_ids = [
    168, 161, 201, 205, 202,
    6, 171, 141, 179, 207,
    163, 11, 180, 181, 184,
    15, 164, 213, 214, 165,
    196, 522, 190, 189, 20,
]

In [10]:
# Relative risk of NTDs for lack of fortification, drawn from the lognormal
# distribution parameterised by mu and sigma (fixed seed) and laid out per location.
# Effect size documentation:
# https://vivarium-research.readthedocs.io/en/latest/concept_models/vivarium_conic_lsff/concept_model.html#effect-size-folic-acid
rr_ntds_nofort = format_rrs(
    lognormal_draws(mu, sigma, seed=7),
    location_ids,
)
# Sanity check: row-wise mean, one value per location_id
rr_ntds_nofort.mean(axis=1)

location_id
168    1.712028
161    1.712028
201    1.712028
205    1.712028
202    1.712028
6      1.712028
171    1.712028
141    1.712028
179    1.712028
207    1.712028
163    1.712028
11     1.712028
180    1.712028
181    1.712028
184    1.712028
15     1.712028
164    1.712028
213    1.712028
214    1.712028
165    1.712028
196    1.712028
522    1.712028
190    1.712028
189    1.712028
20     1.712028
dtype: float64

In [11]:
# Pull DALY draws for the NTD cause over the modelled locations, ages and sexes.
# NOTE(review): `cause_ids` is passed for both of the first two arguments. If the
# second one is meant to be the nonfatal cause list, `nonfatal_causes` (currently
# the same value) may be the intended name -- confirm against pull_dalys.
dalys = pull_dalys(
    cause_ids, cause_ids,
    location_ids, ages, sexes,
    index_cols,
)
dalys.head()

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,draw_0,draw_1,draw_10,draw_100,draw_101,draw_102,draw_103,draw_104,draw_105,draw_106,...,draw_990,draw_991,draw_992,draw_993,draw_994,draw_995,draw_996,draw_997,draw_998,draw_999
location_id,sex_id,age_group_id,cause_id,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1
6,1,2,642,10322.763878,9471.983857,9326.322863,10395.966085,13510.893438,10190.255967,9542.422309,11721.644281,8732.405622,9642.032379,...,9922.34626,11560.055241,9404.223828,10986.001146,10630.453477,11978.378562,10120.298142,9298.436526,7386.087558,7474.08052
6,1,3,642,10793.369985,10263.918859,9382.185956,9968.037842,11416.271719,9862.964908,11121.482312,11598.592613,9764.89327,10017.196952,...,8034.565573,8990.057434,11906.901506,11044.259261,8330.75918,11241.229038,9985.727664,9238.886705,9739.067857,8067.928753
6,1,4,642,25336.608266,21714.315391,20047.215818,23050.370131,28084.480363,21569.777665,22029.386753,26934.517383,23360.799193,20067.716892,...,19383.308185,27212.898218,31242.177965,27961.219167,18728.674471,22540.794924,24194.220376,24114.656987,18316.178064,17181.754974
6,1,5,642,66204.817692,68020.658153,67263.522786,71341.865485,86497.514199,69578.397166,78537.426646,88361.45928,75241.19274,61302.551952,...,54270.894583,84548.082144,89238.607739,104460.090941,70193.626853,93351.585484,80049.044676,90301.053605,54608.073708,56213.929743
6,2,2,642,10878.00812,10382.724118,10116.728296,9244.719719,12568.191988,8388.912625,9137.537798,10220.456155,7906.165243,8515.270704,...,6821.211242,11308.953852,9333.432262,11717.126442,10577.374858,9341.582233,6810.005658,11231.196256,7818.548155,12510.833231


In [12]:
# Population for the modelled locations, sexes and age groups
# (year 2019, GBD round 6, decomp step 'step4').
pop = get_population(
    location_id=location_ids,
    age_group_id=ages,
    sex_id=sexes,
    year_id=2019,
    gbd_round_id=6,
    decomp_step='step4',
)

In [13]:
# Estimate NTD DALYs averted by fortification for every vehicle and
# alternative-scenario coverage level, keeping the final year only.
# Three measures are stacked per vehicle into `results`:
#   counts_averted -- DALYs averted, summed within location
#   rates_averted  -- counts per 100,000 population
#   pif            -- counts as a percentage of the location's total DALYs
first_year, last_year = 2022, 2025

# Loop-invariant denominators: computed once instead of once per vehicle.
# Only `population` is kept; the other summed columns were discarded downstream anyway.
population_by_location = pop.groupby('location_id')[['population']].sum()
dalys_by_location = dalys.groupby('location_id').sum()

results_by_vehicle = []
for vehicle in vehicles:
    alpha, alpha_star = get_baseline_and_counterfactual_coverage(coverage_data_dir,
                                             location_ids,
                                             nutrient,
                                             [vehicle],
                                             list(range(first_year, last_year + 1)),
                                             coverage_levels, 'WRA')
    # Keep only the final year of baseline (alpha) and counterfactual (alpha_star) coverage
    alpha = alpha.loc[alpha.year==last_year].set_index('location_id').drop(columns=['vehicle','year'])
    alpha_star = alpha_star.loc[alpha_star.year==last_year].set_index(['location_id','coverage_level']).drop(columns=['vehicle','year'])
    # NOTE(review): this call does not depend on `vehicle`; it could likely be hoisted
    # out of the loop if the helpers below do not mutate its result -- confirm first.
    gets_intervn = prop_gets_intervention_effect(location_ids, year_start=first_year, estimation_years = range(first_year, last_year + 1))
    new_coverage = percolate_new_coverage(gets_intervn, alpha, alpha_star)
    paf_ntds_nofort = paf_o_r(rr_ntds_nofort, alpha)
    pif_ntds_nofort = pif_o_r(paf_ntds_nofort, alpha = alpha, alpha_star = new_coverage)
    dalys_averted = calc_dalys_averted(dalys, pif_ntds_nofort)

    # Collapse to location / year / coverage level totals
    dalys_averted_u5 = dalys_averted.reset_index().groupby(['location_id','year_id','coverage_level']).sum()[draws]
    dalys_averted_u5['vehicle'] = vehicle

    # Counts: final-year rows only; .copy() so the column assignment below
    # cannot hit a view of the intermediate frame
    averted_flat = dalys_averted_u5.reset_index()
    counts = averted_flat.loc[averted_flat.year_id==last_year].copy()
    counts['measure'] = 'counts_averted'
    counts = counts.set_index([c for c in counts.columns if 'draw' not in c])

    # Rates per 100,000: one vectorized division over all draw columns
    rates = counts.reset_index().merge(population_by_location, on='location_id')
    rates[draws] = rates[draws].div(rates['population'], axis=0) * 100_000
    rates['measure'] = 'rates_averted'
    rates = rates.set_index(['location_id','measure','coverage_level','year_id'])
    rates = rates.drop(columns=[c for c in rates.columns if 'draw' not in c])

    # PIF: averted counts as a percentage of each location's total DALYs
    pif = (counts / dalys_by_location * 100).reset_index()
    pif['measure'] = 'pif'
    pif = pif.set_index(['location_id','measure','coverage_level','year_id'])

    # The trailing reset_index keeps the extra `index` column of the original output schema
    vehicle_results = pd.concat([rates.reset_index(), pif.reset_index(), counts.reset_index()], ignore_index=True, sort=True).reset_index()
    vehicle_results['vehicle'] = vehicle
    results_by_vehicle.append(vehicle_results)

# Concatenate once. Frames were previously prepended one at a time, which put the
# last vehicle first; reversing the list keeps that row order.
results = (pd.concat(results_by_vehicle[::-1], ignore_index=True)
           if results_by_vehicle else pd.DataFrame())

results

Excluded location IDs [] due to missing data
Excluded location IDs [] due to missing data
Excluded location IDs [] due to missing data
Excluded location IDs [] due to missing data
Excluded location IDs [] due to missing data
Excluded [11]/folic acid/wheat flour due to impossible logical values
Excluded location IDs [] due to missing data
Excluded location IDs [] due to missing data
Excluded location IDs [] due to missing data
Excluded location IDs [] due to missing data
Excluded location IDs [] due to missing data


Unnamed: 0,index,coverage_level,draw_0,draw_1,draw_10,draw_100,draw_101,draw_102,draw_103,draw_104,...,draw_994,draw_995,draw_996,draw_997,draw_998,draw_999,location_id,measure,vehicle,year_id
0,0,0.2,-0.000046,-0.000056,-0.000129,-0.000119,-0.000066,-0.000034,-0.000017,-0.000036,...,-0.000019,-0.000258,-0.000067,-0.000089,-0.000068,-0.000087,141,rates_averted,zero maize flour,2025
1,1,0.5,-0.000114,-0.000141,-0.000322,-0.000298,-0.000164,-0.000086,-0.000044,-0.000091,...,-0.000049,-0.000646,-0.000168,-0.000222,-0.000171,-0.000218,141,rates_averted,zero maize flour,2025
2,2,0.8,-0.000183,-0.000225,-0.000515,-0.000476,-0.000262,-0.000137,-0.000070,-0.000145,...,-0.000078,-0.001033,-0.000268,-0.000354,-0.000273,-0.000349,141,rates_averted,zero maize flour,2025
3,3,1.0,-0.000228,-0.000281,-0.000644,-0.000595,-0.000328,-0.000171,-0.000087,-0.000181,...,-0.000097,-0.001292,-0.000335,-0.000443,-0.000341,-0.000437,141,rates_averted,zero maize flour,2025
4,4,0.2,-1.204953,-8.441330,-3.597350,-10.619338,-7.241903,-12.903140,-5.200195,-3.196042,...,-7.465722,-16.348642,-6.880646,-20.920840,-13.175624,-12.487321,168,rates_averted,zero maize flour,2025
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2791,295,1.0,311740.097340,290536.301393,228292.424368,272132.857246,275061.930909,332506.947109,152192.152155,249990.423460,...,449393.194049,659382.493192,572461.487692,621215.381933,569902.398390,593526.700342,214,counts_averted,salt,2025
2792,296,0.2,6430.917608,7733.505672,5556.355992,5743.885360,12178.115937,4703.571264,4480.860167,6245.152351,...,2251.892701,4560.553199,3401.079380,8044.728667,3104.716472,9659.366654,522,counts_averted,salt,2025
2793,297,0.5,16077.294020,19333.764180,13890.889981,14359.713399,30445.289843,11758.928159,11202.150417,15612.880877,...,5629.731754,11401.382998,8502.698450,20111.821668,7761.791180,24148.416634,522,counts_averted,salt,2025
2794,298,0.8,25723.670433,30934.022687,22225.423969,22975.541439,48712.463748,18814.285055,17923.440667,24980.609403,...,9007.570806,18242.212796,13604.317520,32178.914669,12418.865888,38637.466615,522,counts_averted,salt,2025


In [14]:
# Persist the stacked results. An earlier variant of this cell wrote to
# 'results_raw/folic_acid_waterfall.pkl' instead.
# The relative path means the 'results_raw' folder is resolved against the
# kernel's working directory.
results.to_pickle(path='results_raw/folic_acid_waterfall_salt.pkl')

In [15]:
# Confirm which vehicles made it into the saved results
results['vehicle'].unique()

array(['zero maize flour', 'maize flour', 'zero industry wheat',
       'industry wheat', 'zero wheat flour', 'wheat flour',
       'zero industry salt', 'industry salt', 'zero salt', 'salt'],
      dtype=object)