In [1]:
import pandas as pd, numpy as np

In [2]:
from vivarium import Artifact

In [3]:
# Location of the model artifact, taken verbatim from the original call.
ARTIFACT_PATH = '/ihme/costeffectiveness/artifacts/vivarium_ciff_sam/ethiopia.hdf'

# Filter terms handed to the Artifact: year_start equal to 2019 and age_start of
# at least 0.5. (The second term was an f-string with no placeholders; it is a
# plain string literal with the same value.)
ARTIFACT_FILTER_TERMS = ['year_start==2019', 'age_start>=0.5']

art = Artifact(ARTIFACT_PATH, filter_terms=ARTIFACT_FILTER_TERMS)

# Define constants

## Artifact values

In [4]:
def _rows_for(frame, column, value):
    """Return the rows of ``frame`` where ``frame[column] == value``.

    The result is indexed by every column whose name does not contain 'draw'
    (other than ``column`` itself) and ``column`` is dropped, so only the
    columns with 'draw' in their name remain as data.
    """
    index_columns = [c for c in frame.columns if 'draw' not in c and c != column]
    return frame.loc[frame[column] == value].set_index(index_columns).drop(columns=column)


def _wasting_rr_for(affected_entity):
    """Return the wasting relative-risk rows for one affected entity.

    The 'affected_entity' and 'affected_measure' columns are dropped; the
    'parameter' (exposure category) column is kept for later splitting.
    """
    return (wasting_rr
            .loc[wasting_rr.affected_entity == affected_entity]
            .drop(columns=['affected_entity', 'affected_measure']))


# --- Child wasting risk factor: exposure, PAF and relative risks ---
wasting_exposure = art.load('risk_factor.child_wasting.exposure').reset_index()
wasting_paf = (art.load('risk_factor.child_wasting.population_attributable_fraction')
               .reset_index()
               .drop(columns='affected_measure'))
wasting_rr = art.load('risk_factor.child_wasting.relative_risk').reset_index()
wasting_rr_dd = _wasting_rr_for('diarrheal_diseases')
wasting_rr_lri = _wasting_rr_for('lower_respiratory_infections')
wasting_rr_measles = _wasting_rr_for('measles')

# --- Wasting exposure by category (cat1..cat4) ---
prevalence_c302 = art.load('cause.diarrheal_diseases.prevalence')
p_1 = _rows_for(wasting_exposure, 'parameter', 'cat1')
p_2 = _rows_for(wasting_exposure, 'parameter', 'cat2')
p_3 = _rows_for(wasting_exposure, 'parameter', 'cat3')
p_4 = _rows_for(wasting_exposure, 'parameter', 'cat4')

# --- Cause-level rates ---
ACMR = art.load('cause.all_causes.cause_specific_mortality_rate')
csmr_c302 = art.load('cause.diarrheal_diseases.cause_specific_mortality_rate')
emr_c302 = art.load('cause.diarrheal_diseases.excess_mortality_rate')
csmr_pem = art.load('cause.protein_energy_malnutrition.cause_specific_mortality_rate')
emr_pem = art.load('cause.protein_energy_malnutrition.excess_mortality_rate')
csmr_lri = art.load('cause.lower_respiratory_infections.cause_specific_mortality_rate')
csmr_measles = art.load('cause.measles.cause_specific_mortality_rate')
incidence_c302 = art.load('cause.diarrheal_diseases.incidence_rate')
# NOTE: this artifact value is replaced by an adjusted remission rate in the next cell.
remission_c302 = art.load('cause.diarrheal_diseases.remission_rate')

# --- Wasting PAFs by affected cause ---
paf_wasting_lri = _rows_for(wasting_paf, 'affected_entity', 'lower_respiratory_infections')
paf_wasting_measles = _rows_for(wasting_paf, 'affected_entity', 'measles')
paf_wasting_c302 = _rows_for(wasting_paf, 'affected_entity', 'diarrheal_diseases')

# --- Wasting relative risks by affected cause and exposure category ---
RR_lri_1 = _rows_for(wasting_rr_lri, 'parameter', 'cat1')
RR_lri_2 = _rows_for(wasting_rr_lri, 'parameter', 'cat2')
RR_lri_3 = _rows_for(wasting_rr_lri, 'parameter', 'cat3')
RR_measles_1 = _rows_for(wasting_rr_measles, 'parameter', 'cat1')
RR_measles_2 = _rows_for(wasting_rr_measles, 'parameter', 'cat2')
RR_measles_3 = _rows_for(wasting_rr_measles, 'parameter', 'cat3')
# Unsuffixed RR_1..RR_4 are the diarrheal-disease relative risks.
RR_1 = _rows_for(wasting_rr_dd, 'parameter', 'cat1')
RR_2 = _rows_for(wasting_rr_dd, 'parameter', 'cat2')
RR_3 = _rows_for(wasting_rr_dd, 'parameter', 'cat3')
RR_4 = _rows_for(wasting_rr_dd, 'parameter', 'cat4')

In [5]:
"""NOTE: the artifact values for diarrheal disease remission is greater than 
remission after re-scaling to the total population. I have adjusted the remission 
rate to be equal to the number of incident cases in the population minus the mortality 
rate for the average disease duration (6 days) to get the model to behave correctly. 
This method ignores correlation between diarrheal diseases and other causes"""

# Mortality term subtracted from incidence: all-cause mortality with the
# diarrheal-disease CSMR swapped for its excess mortality rate.
_mortality_while_diseased = ACMR - csmr_c302 + emr_c302

# Overrides the artifact remission rate loaded earlier; 6 / 365 is the 6-day
# duration from the note above expressed as a fraction of a year.
remission_c302 = incidence_c302 - (6 / 365) * _mortality_while_diseased

# Per-row summary across the draw columns, with a 95% interval.
remission_c302.apply(pd.DataFrame.describe, percentiles=[0.025, 0.975], axis=1)

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,Unnamed: 4_level_0,count,mean,std,min,2.5%,50%,97.5%,max
sex,age_start,age_end,year_start,year_end,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1
Female,0.5,1.0,2019,2020,1000.0,2.912131,0.291084,2.092265,2.366062,2.893852,3.548673,3.812014
Female,1.0,2.0,2019,2020,1000.0,1.642085,0.22166,0.944368,1.252067,1.629663,2.098036,2.330801
Female,2.0,5.0,2019,2020,1000.0,1.642085,0.22166,0.944368,1.252067,1.629663,2.098036,2.330801
Male,0.5,1.0,2019,2020,1000.0,2.898951,0.291449,2.119631,2.351311,2.88659,3.507169,3.729021
Male,1.0,2.0,2019,2020,1000.0,1.667292,0.219592,1.065674,1.294353,1.661789,2.110573,2.491325
Male,2.0,5.0,2019,2020,1000.0,1.667292,0.219592,1.065674,1.294353,1.661789,2.110573,2.491325


## Wasting transition rates

These were read in from simulation model output rather than calculated directly, given the complexity of the calculations. This should be updated as the model is finalized. The exceptions are r4 and i3, which are currently undergoing calibration.


In [6]:
output_dir ='/ihme/costeffectiveness/results/vivarium_ciff_sam/v4.1_wasting_treatment/ciff_sam/2021_09_24_16_36_30/count_data/'


def _load_baseline_totals(filename, state_column):
    """Read a count file from ``output_dir`` and total the baseline scenario.

    Parameters
    ----------
    filename : str
        CSV file name inside ``output_dir``.
    state_column : str
        Column identifying what each row counts ('measure' for transition
        counts, 'cause' for state person-time).

    Returns
    -------
    pandas.DataFrame
        Baseline rows summed within input_draw / sex / age / ``state_column``,
        with the summed 'year' column removed and the keys restored as columns.
    """
    counts = pd.read_csv(output_dir + filename).drop(columns='Unnamed: 0')
    return (counts
            .loc[counts.scenario == 'baseline']
            .groupby(['input_draw', 'sex', 'age', state_column])
            # numeric_only=True keeps string columns such as 'scenario' out of
            # the sum. Older pandas did this by default; pandas >= 2.0 would
            # otherwise concatenate the strings and break later arithmetic.
            .sum(numeric_only=True)
            # A sum of calendar years carries no meaning, so it is discarded.
            .drop(columns='year')
            .reset_index())


wasting_transitions = _load_baseline_totals('wasting_transition_count.csv', 'measure')
wasting_person_time = _load_baseline_totals('wasting_state_person_time.csv', 'cause')

In [7]:
def compute_wasting_transition_rate(transition_count_parameter, person_time_state, ages,
                                    transitions=None, person_time=None, n_draws=1000):
    """Compute a wasting transition rate as event count divided by person-time.

    The rate is computed per input draw / sex / age, averaged over input draws,
    and the resulting mean is copied into ``n_draws`` identical draw columns so
    it can be combined with draw-level artifact data.

    Parameters
    ----------
    transition_count_parameter : str
        Value of the 'measure' column selecting the transition event count.
    person_time_state : str
        Value of the 'cause' column selecting the person-time denominator.
    ages : list of str
        Age-group labels to keep. Supported labels are '6-11_months',
        '12_to_23_months' and '2_to_4'.
    transitions : pandas.DataFrame, optional
        Transition counts with columns input_draw, sex, age, measure, value.
        Defaults to the notebook-level ``wasting_transitions``.
    person_time : pandas.DataFrame, optional
        Person-time with columns input_draw, sex, age, cause, value.
        Defaults to the notebook-level ``wasting_person_time``.
    n_draws : int, optional
        Number of identical ``draw_<i>`` columns to emit (default 1000).

    Returns
    -------
    pandas.DataFrame
        Indexed by (sex, age_start, age_end, year_start, year_end), sorted,
        with columns draw_0 .. draw_{n_draws-1}.

    Raises
    ------
    ValueError
        If a retained row has an age label with no known age bounds. The
        previous implementation silently labelled any such row as 2-5 years.
    """
    if transitions is None:
        transitions = wasting_transitions
    if person_time is None:
        person_time = wasting_person_time

    strata = ['input_draw', 'sex', 'age']
    events = (transitions.loc[transitions.measure == f'{transition_count_parameter}']
              .set_index(strata)
              .drop(columns='measure'))
    exposure = (person_time.loc[person_time.cause == f'{person_time_state}']
                .set_index(strata)
                .drop(columns='cause'))

    # Rate within each input draw first, then the mean across input draws.
    rate = (events / exposure).groupby(['sex', 'age']).mean().reset_index()

    # .copy() makes the filtered rows an independent frame, so the column
    # assignments below do not trigger SettingWithCopyWarning.
    rate = rate.loc[rate.age.isin(ages)].copy()

    age_bounds = {
        '6-11_months': (0.5, 1.0),
        '12_to_23_months': (1.0, 2.0),
        '2_to_4': (2.0, 5.0),
    }
    unmapped = sorted(set(rate.age) - set(age_bounds))
    if unmapped:
        raise ValueError(f'No age bounds defined for age group(s): {unmapped}')

    rate['sex'] = rate.sex.str.capitalize()
    rate['year_start'] = 2019
    rate['year_end'] = 2020
    rate['age_start'] = rate.age.map({k: v[0] for k, v in age_bounds.items()}).astype(float)
    rate['age_end'] = rate.age.map({k: v[1] for k, v in age_bounds.items()}).astype(float)

    # Build all draw columns in one block instead of inserting them one at a
    # time, which fragments the frame and is slow.
    draws = pd.DataFrame(
        np.repeat(rate[['value']].to_numpy(), n_draws, axis=1),
        index=rate.index,
        columns=[f'draw_{i}' for i in range(n_draws)],
    )
    return (pd.concat([rate.drop(columns=['age', 'value']), draws], axis=1)
            .set_index(['sex', 'age_start', 'age_end', 'year_start', 'year_end'])
            .sort_index())

In [8]:
ages = ['6-11_months','12_to_23_months','2_to_4']

# Wasting state names as they appear in the simulation count files.
_SAM = 'severe_acute_malnutrition'
_MAM = 'moderate_acute_malnutrition'
_MILD = 'mild_child_wasting'
_SUSCEPTIBLE = 'susceptible_to_child_wasting'


def _observed_rate(origin, destination):
    """Rate of origin -> destination events per unit of person-time spent in origin."""
    return compute_wasting_transition_rate(
        f'{origin}_to_{destination}_event_count', origin, ages)


# Transitions towards a less severe state.
t1 = _observed_rate(_SAM, _MILD)
r2 = _observed_rate(_SAM, _MAM)
r3 = _observed_rate(_MAM, _MILD)
r4_obs = _observed_rate(_MILD, _SUSCEPTIBLE)

# Transitions towards a more severe state.
i1 = _observed_rate(_MAM, _SAM)
i2 = _observed_rate(_MILD, _MAM)
i3_obs = _observed_rate(_SUSCEPTIBLE, _MILD)


# r4 is set from a daily probability of 1/50 rather than from r4_obs, then converted
# to an annual rate via -ln(1 - p) * 365.
# NOTE(review): presumably the mild-wasting -> susceptible transition (compare r4_obs) - confirm.
r4_daily = 1/50
r4 = -np.log(1 - r4_daily) * 365
# Daily probability implied by ACMR when ACMR is treated as an annual rate.
ap0 = 1 - np.exp(-ACMR * 1 / 365)
# Daily mortality probability from ACMR after replacing the diarrheal-disease, LRI and
# measles CSMRs with their (1 - wasting PAF) share; each "- csmr + csmr * (1 - paf)"
# pair performs that replacement for one cause.
m_D4_daily = 1 - np.exp(-(ACMR  
        - csmr_c302 + csmr_c302 * (1 - paf_wasting_c302)
        - csmr_lri + csmr_lri * (1 - paf_wasting_lri)
        - csmr_measles + csmr_measles * (1 - paf_wasting_measles)) / 365)

# Daily i3, then converted to an annual rate the same way as r4.
# Algebraically the first term equals ap0 * (ap0 + 1) and the second equals
# p_3 * r4_daily / p_4; the expression is left in its original form.
# NOTE(review): the derivation of this formula is not shown in the notebook - confirm against the model documentation.
i3_daily = ap0 * p_4 / (p_4 / (ap0+1)) + p_3 / (ap0 + 1) * r4_daily / (p_4 / (ap0+1)) - m_D4_daily
i3 = -np.log(1 - i3_daily) * 365

## Prevalence ratios

In [9]:
# Prevalence ratios for wasting categories 1-4. PR_1..PR_3 are used in the state
# prevalence formulas; the source of these values is not recorded in this notebook.
PR_1, PR_2, PR_3, PR_4 = 1.060416, 1.061946, 1.044849, 0.990530

# Solve for parameters

## State prevalence values

In [10]:
# prevalence values
p_D1 = (PR_1 * p_1 * prevalence_c302) / (PR_1 * prevalence_c302 - prevalence_c302 + 1)
p_D2 = (PR_2 * p_2 * prevalence_c302) / (PR_2 * prevalence_c302 - prevalence_c302 + 1)
p_D3 = (PR_3 * p_3 * prevalence_c302) / (PR_3 * prevalence_c302 - prevalence_c302 + 1)
p_S1 = (-p_1 * prevalence_c302 + p_1) / (PR_1 * prevalence_c302 - prevalence_c302 + 1)
p_S2 = (-p_2 * prevalence_c302 + p_2) / (PR_2 * prevalence_c302 - prevalence_c302 + 1)
p_S3 = (-p_3 * prevalence_c302 + p_3) / (PR_3 * prevalence_c302 - prevalence_c302 + 1)
p_D4 = prevalence_c302 - p_D1 - p_D2 - p_D3
p_S4 = (1 - prevalence_c302) - p_S1 - p_S2 - p_S3

In [11]:
_diseased_states = (p_D1, p_D2, p_D3, p_D4)

# Mean state prevalence, then each state's mean share of diarrheal-disease prevalence.
print(*(p.mean().mean() for p in _diseased_states))
print(*((p / prevalence_c302).mean().mean() for p in _diseased_states))

0.0008586534064018113 0.0029456870183582466 0.007246069255888614 0.021723145785027178
0.024237085046622014 0.08763267723950048 0.22124259485679937 0.6668876428570781


In [12]:
_disease_free_states = (p_S1, p_S2, p_S3, p_S4)

# Mean state prevalence, then each state's mean share of the diarrhea-free population.
print(*(p.mean().mean() for p in _disease_free_states))
print(*((p / (1 - prevalence_c302)).mean().mean() for p in _disease_free_states))

0.022046471988559396 0.07974698357651164 0.2048109589049813 0.6606220300642718
0.02285620459010616 0.08252084120991131 0.2117459985670651 0.6828769556329174


In [13]:
# Each state prevalence must lie strictly between 0 and 1. The names are kept
# alongside the values so that a failing assertion says which parameter is at fault
# (previously the name was looked up through globals() and then never used).
prevalence_params = {
    'p_D1': p_D1, 'p_D2': p_D2, 'p_D3': p_D3, 'p_D4': p_D4,
    'p_S1': p_S1, 'p_S2': p_S2, 'p_S3': p_S3, 'p_S4': p_S4,
}
for name, param in prevalence_params.items():
    assert param.min().min()>0, f'{name}: non-positive values'
    assert param.max().max()<1, f'{name}: values greater than or equal to 1'

# The eight joint states partition the population.
assert np.all((p_D1+p_D2+p_D3+p_D4+p_S1+p_S2+p_S3+p_S4).round(5)==1), 'Prevalence parameters do not all sum to one'

## Other intermediate variables

We are making the assumption that wasting relative risks apply to 
excess mortality, NOT incidence.

In [14]:
# Mortality flow out of each diarrhea-and-wasting state: a per-capita mortality rate
# multiplied by the state's prevalence (p_Dk).
# Within the rate, ACMR has the diarrheal-disease CSMR replaced by
# emr * (1 - PAF) * RR for the wasting category, and the LRI and measles CSMRs
# replaced by csmr * (1 - PAF) * RR. For PEM, csmr_pem is swapped for emr_pem.
m_D1 = (ACMR - csmr_c302 + emr_c302 * (1 - paf_wasting_c302) * RR_1
        - csmr_pem + emr_pem
        - csmr_lri + csmr_lri * (1 - paf_wasting_lri) * RR_lri_1
        - csmr_measles + csmr_measles * (1 - paf_wasting_measles) * RR_measles_1) * p_D1

m_D2 = (ACMR - csmr_c302 + emr_c302 * (1 - paf_wasting_c302) * RR_2
        - csmr_pem + emr_pem
        - csmr_lri + csmr_lri * (1 - paf_wasting_lri) * RR_lri_2
        - csmr_measles + csmr_measles * (1 - paf_wasting_measles) * RR_measles_2) * p_D2

# NOTE(review): unlike m_D1 and m_D2, csmr_pem is removed here without adding emr_pem
# back - presumably because PEM mortality is only modelled in cat1/cat2; confirm.
m_D3 = (ACMR - csmr_c302 + emr_c302 * (1 - paf_wasting_c302) * RR_3
        - csmr_pem
        - csmr_lri + csmr_lri * (1 - paf_wasting_lri) * RR_lri_3
        - csmr_measles + csmr_measles * (1 - paf_wasting_measles) * RR_measles_3) * p_D3

# Mean mortality flow per state, then the mean per-capita mortality rate within each state.
print(m_D1.mean().mean(),
     m_D2.mean().mean(),
     m_D3.mean().mean())
print((m_D1/p_D1).mean().mean(),
     (m_D2/p_D2).mean().mean(),
     (m_D3/p_D3).mean().mean())

0.00038833795886415257 0.0003675113644543976 0.00041463564804416976
0.3789093264178939 0.11061201288018828 0.052375756298638984


In [15]:
# Diarrheal-disease incidence flow out of each diarrhea-free wasting state.
# Every state gets the same per-capita rate; an alternative that scaled incidence by
# (1 - PAF) * RR per wasting category was left disabled and is not used.
def _diarrhea_onset(p_state):
    """Incidence flow for a diarrhea-free state with prevalence ``p_state``."""
    return incidence_c302 * p_state / (1 - prevalence_c302)


di_1 = _diarrhea_onset(p_S1)
di_2 = _diarrhea_onset(p_S2)
di_3 = _diarrhea_onset(p_S3)
di_4 = _diarrhea_onset(p_S4)

# Mean flow per state, then the mean per-capita rate within each state.
_onset_by_state = ((di_1, p_S1), (di_2, p_S2), (di_3, p_S3), (di_4, p_S4))
print(*(flow.mean().mean() for flow, _ in _onset_by_state))
print(*((flow / p_state).mean().mean() for flow, p_state in _onset_by_state))

0.05133109288092936 0.17556029747694293 0.4385558397095897 1.4070671698369293
2.149389531421424 2.149389531421424 2.149389531421424 2.149389531421424


In [16]:
# Remission flow out of each diarrhea state: the remission rate per prevalent case,
# multiplied by the state's prevalence.
_remission_per_case = remission_c302 / prevalence_c302

dr_1 = _remission_per_case * p_D1
dr_2 = _remission_per_case * p_D2
dr_3 = _remission_per_case * p_D3
dr_4 = _remission_per_case * p_D4

# Mean flow per state, then the mean per-capita rate within each state.
_remission_by_state = ((dr_1, p_D1), (dr_2, p_D2), (dr_3, p_D3), (dr_4, p_D4))
print(*(flow.mean().mean() for flow, _ in _remission_by_state))
print(*((flow / p_state).mean().mean() for flow, p_state in _remission_by_state))

0.05440804332086057 0.18635524910715998 0.45803077869112385 1.3728449675625325
63.158196485418635 63.158196485418635 63.158196485418635 63.158196485418635


In [17]:
# ACMR-scaled flow for each diarrhea state in wasting categories 1-3.
b_D1, b_D2, b_D3 = (ACMR * p_state for p_state in (p_D1, p_D2, p_D3))

# Wasting transition flows out of each diarrhea state: rate times state prevalence.
r_D1tx = t1 * p_D1
r_D1ux = r2 * p_D1
r_D2 = r3 * p_D2
r_D3 = r4 * p_D3

In [18]:
# Per-capita mortality must strictly decrease from wasting category 1 to category 3.
_rate_D1, _rate_D2, _rate_D3 = m_D1 / p_D1, m_D2 / p_D2, m_D3 / p_D3

assert np.all(_rate_D1 > _rate_D2), 'Mortality rate of D1 state less than D2 state'
assert np.all(_rate_D2 > _rate_D3), 'Mortality rate of D2 state less than D3 state'

## Incidence rates

In [19]:
# Closed-form solution of the flow-balance equations that a later cell asserts as
# eq1-eq6. By construction i_S1 + i_D1 == i1*p_2, i_S2 + i_D2 == i2*p_3 and
# i_S3 + i_D3 == i3*p_4: each wasting incidence flow is split between the
# diarrhea-free (S) and diarrhea (D) strata.
#
# NOTE(review): di_2 does not appear in any expression below, while di_1 enters
# i_S2, i_D2, i_S3 and i_D3 with a coefficient of 2.0. This follows from eq5 being
# written with di_1, whereas the balances for D1 (eq6) and D3 (eq4) use their own
# di_1 and di_3. If the D2 balance is meant to use di_2, every `2.0*di_1` here
# becomes `di_1 + di_2` and eq5 must be changed at the same time. Confirm against
# the model specification; the reported RR_i2 and RR_i3 depend on it.
i_S1 = b_D1 + di_1 - dr_1 + i1*p_2 - m_D1 - r_D1tx - r_D1ux
i_S2 = b_D1 + b_D2 + 2.0*di_1 - dr_1 - dr_2 + i2*p_3 - m_D1 - m_D2 - r_D1tx - r_D2
i_S3 = b_D1 + b_D2 + b_D3 + 2.0*di_1 + di_3 - dr_1 - dr_2 - dr_3 + i3*p_4 - m_D1 - m_D2 - m_D3 - r_D3
i_D1 = -b_D1 - di_1 + dr_1 + m_D1 + r_D1tx + r_D1ux
i_D2 = -b_D1 - b_D2 - 2.0*di_1 + dr_1 + dr_2 + m_D1 + m_D2 + r_D1tx + r_D2
i_D3 = -b_D1 - b_D2 - b_D3 - 2.0*di_1 - di_3 + dr_1 + dr_2 + dr_3 + m_D1 + m_D2 + m_D3 + r_D3

In [20]:
# Mean incidence flows: diarrhea-free strata first, then diarrhea strata.
print(*(flow.mean().mean() for flow in (i_S1, i_S2, i_S3)))
print(*(flow.mean().mean() for flow in (i_D1, i_D2, i_D3)))

0.13540387087524577 0.43054527539010995 1.3415866049591687
0.009263146737993277 0.16012468828318985 0.21208573676567363


In [21]:
def _incidence_ratio(i_diseased, i_free, p_free, p_diseased):
    """Ratio of the per-capita incidence rate with diarrhea to that without.

    Each flow is divided by the prevalence of its source state, written in
    cross-multiplied form.
    """
    return (i_diseased * p_free) / (i_free * p_diseased)


RR_i3 = _incidence_ratio(i_D3, i_S3, p_S4, p_D4)
RR_i2 = _incidence_ratio(i_D2, i_S2, p_S3, p_D3)
RR_i1 = _incidence_ratio(i_D1, i_S1, p_S2, p_D2)

In [22]:
# Mean incidence ratios, in the order RR_i3, RR_i2, RR_i1.
print(*(ratio.mean().mean() for ratio in (RR_i3, RR_i2, RR_i1)))

4.764589652847803 10.542478253704584 1.743731173135167


In [23]:
# Check that the solved incidence flows satisfy the flow-balance equations, comparing
# to 8 decimal places.
# eq1-eq3: the S and D incidence flows add up to the total wasting incidence flow.
# The transpose / sort_index / transpose sequence sorts the columns by label before
# the element-wise comparison.
assert np.all((i1 * p_2).transpose().sort_index().transpose().round(8) == (i_S1 + i_D1).round(8)), 'eq1 untrue'
assert np.all((i2 * p_3).transpose().sort_index().transpose().round(8) == (i_S2 + i_D2).round(8)), 'eq2 untrue'
assert np.all((i3 * p_4).transpose().sort_index().transpose().round(8) == (i_S3 + i_D3).round(8)), 'eq3 untrue'
# eq4-eq6: one balance per diarrhea state (D3, D2, D1), with the two sides of each
# equality collecting the flows for that state.
assert np.all((i_D3 + di_3 + b_D3 + r_D2 + r_D1tx).round(8) 
              == (r_D3 + dr_3 + m_D3 + i_D2).round(8)), 'eq4 untrue'
# NOTE(review): eq5 uses di_1, whereas eq4 uses di_3 and eq6 uses di_1 for their own
# states; di_2 looks like the intended term here. These assertions pass either way
# only if the incidence solution is derived from the same equation, so a change here
# must be made together with the `2.0*di_1` terms in i_S2, i_D2, i_S3 and i_D3. Confirm.
assert np.all((i_D2 + di_1 + r_D1ux + b_D2).round(8)
             == (r_D2 + dr_2 + i_D1 + m_D2).round(8)), 'eq5 untrue'
assert np.all((i_D1 + di_1 + b_D1).round(8) 
             == (dr_1 + m_D1 + r_D1tx + r_D1ux).round(8)), 'eq6 untrue'

In [24]:
# Every solved incidence flow must be strictly positive. The names are kept alongside
# the values so that a failing assertion says which flow is at fault (previously the
# name was looked up through globals() and then never used).
incidence_params = {
    'i_S1': i_S1, 'i_S2': i_S2, 'i_S3': i_S3,
    'i_D1': i_D1, 'i_D2': i_D2, 'i_D3': i_D3,
}
for name, param in incidence_params.items():
    assert param.min().min()>0, f'{name}: non-positive values'

In [25]:
# RR_i3 averaged within year_start, then summarised across draws with a 95% interval.
(RR_i3
 .groupby('year_start')
 .mean()
 .apply(pd.DataFrame.describe, percentiles=[0.025, 0.975], axis=1))

Unnamed: 0_level_0,count,mean,std,min,2.5%,50%,97.5%,max
year_start,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
2019,1000.0,4.76459,0.230162,4.076506,4.327559,4.764401,5.20549,5.439851


In [26]:
# RR_i2 averaged within year_start, then summarised across draws with a 95% interval.
(RR_i2
 .groupby('year_start')
 .mean()
 .apply(pd.DataFrame.describe, percentiles=[0.025, 0.975], axis=1))

Unnamed: 0_level_0,count,mean,std,min,2.5%,50%,97.5%,max
year_start,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
2019,1000.0,10.542478,0.747141,8.263688,9.135932,10.502671,12.040487,12.805677


In [27]:
# RR_i1 averaged within year_start, then summarised across draws with a 95% interval.
(RR_i1
 .groupby('year_start')
 .mean()
 .apply(pd.DataFrame.describe, percentiles=[0.025, 0.975], axis=1))

Unnamed: 0_level_0,count,mean,std,min,2.5%,50%,97.5%,max
year_start,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
2019,1000.0,1.743731,0.064248,1.548605,1.622291,1.744116,1.869238,1.948257
