Author: Yongquan Xie, Nathaniel Blair-Stahn<br>
Date: July 25, 2019<br>
Purpose: SQ-LNS presentation Nigeria results preparation<br>
Note: Yongquan and Nathaniel will give this presentation on August 1, 2019

In [None]:
%matplotlib inline

import pandas as pd
import matplotlib.pyplot as plt
from ipywidgets import interact, FloatSlider, IntSlider

pd.set_option('display.max_rows', 8)

!date
!whoami

## Load output data and aggregate over random seeds

In [None]:
# Causes reported by the simulation. The ORDER matters to later cells:
# get_disaggregated_results treats the first four entries as causes with both
# YLLs and YLDs, and the plotting functions omit 'other_causes' by slicing off
# the last element.
cause_names = ['lower_respiratory_infections', 'measles', 'diarrheal_diseases', 
               'protein_energy_malnutrition', 'iron_deficiency', 'other_causes']
# NOTE(review): risk_names is not referenced anywhere else in the visible notebook.
risk_names = ['anemia', 'child_stunting', 'child_wasting']

# Key columns of the long-form results table: the scenario parameters followed
# by cause, measure and input_draw. 'coverage' must stay first and 'input_draw'
# last, because later code uses template_cols[1:] (everything but coverage) and
# template_cols[:-1] (everything but the draw).
template_cols = ['coverage', 'duration', 'child_stunting_permanent', 
                 'child_wasting_permanent', 'iron_deficiency_permanent', 
                 'iron_deficiency_mean', 'cause', 'measure', 'input_draw']

# Directory holding the simulation output file for this run.
result_dir = '/share/costeffectiveness/results/sqlns/presentation/nigeria/2019_07_23_10_57_25'

In [None]:
# note that we have applied coefficient of variation as constant with different sqlns effect on iron deficiency
def clean_and_aggregate(path, filename, negate_treated_days=True):
    """Load a simulation output file and sum its results over random seeds.

    The ``sqlns.*`` configuration columns are renamed to short names, the
    coverage column is multiplied by 100, and all remaining columns are summed
    within each (scenario, input_draw) group.

    Parameters
    ----------
    path : str
        Directory containing the output file.
    filename : str
        Name of the HDF file inside ``path``.
    negate_treated_days : bool, default True
        Flip the sign of ``sqlns_treated_days``. The 2019_07_23_10_57_25 run
        computed this column with the subtraction in the wrong order, so the
        default keeps the correction that run needs. Pass ``False`` for
        output produced after the simulation code is fixed.

    Returns
    -------
    pandas.DataFrame
        Summed results indexed by (coverage, duration,
        child_stunting_permanent, child_wasting_permanent,
        iron_deficiency_permanent, iron_deficiency_mean, input_draw).
    """
    short_names = {
        'sqlns.effect_on_child_stunting.permanent': 'child_stunting_permanent',
        'sqlns.effect_on_child_wasting.permanent': 'child_wasting_permanent',
        'sqlns.effect_on_iron_deficiency.permanent': 'iron_deficiency_permanent',
        'sqlns.effect_on_iron_deficiency.mean': 'iron_deficiency_mean',
        'sqlns.program_coverage': 'coverage',
        'sqlns.duration': 'duration',
    }
    r = pd.read_hdf(f'{path}/{filename}').rename(columns=short_names)

    # Plots label coverage as a percentage, so scale it up front.
    r['coverage'] *= 100

    if negate_treated_days:
        # Run-specific correction; see the parameter description.
        r['sqlns_treated_days'] = -1 * r['sqlns_treated_days']

    group_keys = ['coverage', 'duration', 'child_stunting_permanent',
                  'child_wasting_permanent', 'iron_deficiency_permanent',
                  'iron_deficiency_mean', 'input_draw']
    # Rows differing only by random seed fall in the same group and are summed.
    return r.groupby(group_keys).sum()

In [None]:
# Load output data - as of 2019-07-25 there are random seeds missing
r = clean_and_aggregate(result_dir, 'output.hdf')
# Raw data aggregated by random seed, with intervention columns renamed
r

## Get a list of the unique draws for plotting by draw

In [None]:
draws = r.reset_index().input_draw.unique()
draws

## Plot total YLLs and YLDs vs. coverage for each draw

Raw YLLs and YLDs are plotted side by side with the rates per 100,000 person years. Plots should be monotonically decreasing as coverage level increases.

Create a `pandas.IndexSlice` object to easily select with the multi-index of the original aggregated dataframe.

In [None]:
# Create a pandas IndexSlice object to easily multi-index the original dataframe
idx = pd.IndexSlice
# Example selection: every coverage level for one scenario and one draw.
# Positions follow the index built in clean_and_aggregate:
# (coverage, duration, child_stunting_permanent, child_wasting_permanent,
#  iron_deficiency_permanent, iron_deficiency_mean, input_draw).
r.loc[idx[:, 365.25, False, False, False, 0.895, 55],
      ['years_of_life_lost', 'years_lived_with_disability', 'person_time']].reset_index()

In [None]:
@interact()
def plot_total_dalys_by_draw(duration=[365.25, 730.50],
                             cgf_permanent=[False, True],
                             iron_permanent=[False, True],
                             iron_mean=[0.895, 4.475, 8.950],
                             input_draw=draws,
                             ):
    """Plot total YLLs and YLDs against coverage for one scenario and draw.

    Top row: YLLs; bottom row: YLDs. Left column: raw counts; right column:
    rates per 100,000 person-years. ``cgf_permanent`` is applied to both the
    stunting and the wasting index levels.
    """
    wanted = ['years_of_life_lost', 'years_lived_with_disability', 'person_time']
    scenario = idx[:, duration, cgf_permanent, cgf_permanent,
                   iron_permanent, iron_mean, input_draw]
    data = r.loc[scenario, wanted].reset_index()

    fig, ax = plt.subplots(2, 2, figsize=(12, 8))
    coverage = data['coverage']

    panels = [('years_of_life_lost', 'YLL'),
              ('years_lived_with_disability', 'YLD')]

    # Each row of the axes grid is a (count axis, rate axis) pair.
    for (count_ax, rate_ax), (column, short_name) in zip(ax, panels):
        count_ax.plot(coverage, data[column], '-o')
        count_ax.set_title(f'Total {short_name} count vs. coverage', fontsize=20)
        count_ax.set_xlabel('Program Coverage (%)', fontsize=16)
        count_ax.set_ylabel(f'{short_name}s', fontsize=20)
        count_ax.grid()

        rate = 100_000*data[column] / data['person_time']
        rate_ax.plot(coverage, rate, '-o', color='orange')
        rate_ax.set_title(f'Total {short_name} rate vs. coverage', fontsize=20)
        rate_ax.set_xlabel('Program Coverage (%)', fontsize=16)
        rate_ax.set_ylabel(f'{short_name}s per 100,000 person years', fontsize=12)
        rate_ax.grid()

    fig.tight_layout()

## Plot treated days and estimated fraction of population treated for all draws

In [None]:
r.filter(regex='treated_days|population|person_time')

In [None]:
# The fraction of population tracked is about 54.4% for all scenarios and draws.
# Why? How do you compute this?
(r['total_population_tracked']/r['total_population']).describe()

In [None]:
# Conversion factor between treated days and treated years.
days_per_year = 365.25
# Length of the simulated period in years.
# NOTE(review): no active code in the visible notebook reads this value.
years_of_simulation = 5

@interact()
def plot_treated_days_by_draw(duration=[365.25, 730.50],
                              cgf_permanent=[False, True],
                              iron_permanent=[False, True],
                              iron_mean=[0.895, 4.475, 8.950],
                              input_draw=draws,
                              ):
    """Plot treated years and the estimated treated fraction against coverage.

    Left panel: ``sqlns_treated_days`` converted to years. Right panel:
    treated days divided by (treatment duration x tracked population), used
    as an estimate of the fraction of the population treated.
    ``cgf_permanent`` is applied to both the stunting and wasting levels.
    """
    wanted = ['sqlns_treated_days', 'total_population_living',
              'total_population_tracked', 'person_time']
    scenario = idx[:, duration, cgf_permanent, cgf_permanent,
                   iron_permanent, iron_mean, input_draw]
    data = r.loc[scenario, wanted].reset_index()

    fig, (years_ax, fraction_ax) = plt.subplots(1, 2, figsize=(13, 6))

    coverage = data['coverage']
    treated_days = data['sqlns_treated_days']

    # Left: total treated time in years.
    years_ax.plot(coverage, treated_days / days_per_year, '-o')
    years_ax.set_title('Treated years vs. coverage', fontsize=20)
    years_ax.set_xlabel('Program Coverage (%)', fontsize=16)
    years_ax.set_ylabel('SQ-LNS treated years', fontsize=20)
    years_ax.grid()

    # Right: treated time per tracked simulant, in units of one full treatment course.
    treated_fraction = treated_days / (duration * data['total_population_tracked'])
    fraction_ax.plot(coverage, treated_fraction, '-o', color='orange')
    fraction_ax.set_title('Estimated fraction of\npopulation treated vs. coverage', fontsize=20)
    fraction_ax.set_xlabel('Program Coverage (%)', fontsize=16)
    fraction_ax.set_ylabel('sqlns_treated_time /\n(treatment_duration x population_tracked)', fontsize=12)
    fraction_ax.grid()

    fig.tight_layout()

### Based on the graphs, estimated coverage is close to program coverage -- how close?

The maximum difference is less than 2%, with a mean around 0.28%.

In [None]:
# Gap, in percentage points, between the estimated treated fraction
# (100 x treated days / (treatment duration x tracked population)) and the
# nominal program coverage, summarized over every scenario and draw.
(100*r.sqlns_treated_days / 
 (r.index.get_level_values('duration') * r.total_population_tracked)
 - r.index.get_level_values('coverage')).describe()

### If we invert the equation to estimate treated years from coverage, how close do we get?

The maximum difference is about 1764 treated years, with a mean around 208 treated years.

In [None]:
# Difference, in years, between observed treated time and the value predicted
# from coverage: (coverage / 100) x treatment duration x tracked population.
# The 0.01 factor converts coverage from percent back to a fraction.
((r.sqlns_treated_days
 - 0.01*r.index.get_level_values('coverage')
 * r.index.get_level_values('duration')
 * r.total_population_tracked)/days_per_year).describe()

### Check that deaths due to different causes add up to `total_population_dead`:  Yes, true

In [None]:
# Deaths due to any cause add up to total deaths
(r.filter(like='death').sum(axis=1) - r.total_population_dead).sum()

## Define functions to transform data into "long" form suitable for more analysis/graphing

In [None]:
def standardize_shape(data, measure):
    """Reshape the columns matching ``measure`` into long form.

    Every column of ``data`` whose name contains ``measure`` is stacked into a
    single ``value`` column, and the row index levels of ``data`` become
    ordinary columns. A ``measure`` column is added; if ``measure`` has the
    form ``<measure>_due_to_<cause>`` it is split and a ``cause`` column is
    added as well.

    Works for a row index of any depth (it previously assumed exactly seven
    levels).

    Parameters
    ----------
    data : pandas.DataFrame
        Wide results table.
    measure : str
        Substring used to select columns, e.g. ``'ylls_due_to_measles'`` or
        ``'person_time'``.

    Returns
    -------
    pandas.DataFrame
        Columns: the index levels of ``data``, then ``value``, ``measure``
        and (only for ``*_due_to_*`` measures) ``cause``.
    """
    matching_columns = [c for c in data.columns if measure in c]
    stacked = data.loc[:, matching_columns].stack()

    # Name the level holding the original column labels explicitly, rather
    # than relying on the auto-generated 'level_<n>' name, which depends on
    # how many index levels `data` has.
    stacked.index = stacked.index.set_names('label', level=-1)
    measure_data = stacked.reset_index(name='value')

    if 'due_to' in measure:
        measure, cause = measure.split('_due_to_', 1)
        measure_data['measure'] = measure
        measure_data['cause'] = cause
    else:
        measure_data['measure'] = measure

    # The original column label is redundant once measure/cause are recorded.
    return measure_data.drop(columns='label')

In [None]:
def get_person_time(data):
    """Return person time in long form, one row per scenario and draw.

    The generic ``value`` column is renamed to ``person_time`` and the
    ``measure`` column is removed.
    """
    long_form = standardize_shape(data, 'person_time')
    without_measure = long_form.drop(columns='measure')
    return without_measure.rename(columns={'value': 'person_time'})

In [None]:
def get_treated_days(data):
    """Return SQ-LNS treated days in long form, one row per scenario and draw.

    The generic ``value`` column is renamed to ``sqlns_treated_days`` and the
    ``measure`` column is removed.
    """
    long_form = standardize_shape(data, 'sqlns_treated_days')
    without_measure = long_form.drop(columns='measure')
    return without_measure.rename(columns={'value': 'sqlns_treated_days'})

In [None]:
def get_disaggregated_results(data, cause_names):
    """Build the long-form table of cause-specific deaths, YLLs, YLDs and DALYs.

    Causes are handled in three ways:

    * the first four entries of ``cause_names`` contribute deaths, YLLs and
      YLDs, with DALYs = YLLs + YLDs;
    * ``'iron_deficiency'`` contributes YLDs only, so DALYs = YLDs;
    * any other cause (``'other_causes'``) contributes deaths and YLLs only,
      so DALYs = YLLs.

    Returns a DataFrame sorted by ``template_cols``, with those columns
    followed by ``value``.
    """
    # Align YLL and YLD rows on everything except the measure label.
    alignment_keys = [c for c in template_cols if c != 'measure']

    death_frames, yll_frames, yld_frames, daly_frames = [], [], [], []

    for cause in cause_names:
        # Decide which measures this cause reports.
        if cause in cause_names[:4]:
            has_ylls, has_ylds = True, True
        elif cause == 'iron_deficiency':
            has_ylls, has_ylds = False, True
        else:  # cause == 'other_causes'
            has_ylls, has_ylds = True, False

        if has_ylls:
            death_frames.append(standardize_shape(data, f'death_due_to_{cause}'))
            cause_ylls = standardize_shape(data, f'ylls_due_to_{cause}')
            yll_frames.append(cause_ylls)
        if has_ylds:
            cause_ylds = standardize_shape(data, f'ylds_due_to_{cause}')
            yld_frames.append(cause_ylds)

        if has_ylls and has_ylds:
            # Index-aligned sum; the 'measure' text column is overwritten below.
            cause_dalys = (cause_ylds.set_index(alignment_keys)
                           + cause_ylls.set_index(alignment_keys)).reset_index()
        elif has_ylds:
            cause_dalys = cause_ylds.copy()
        else:
            cause_dalys = cause_ylls.copy()
        cause_dalys['measure'] = 'dalys'
        daly_frames.append(cause_dalys)

    per_measure = [pd.concat(frames, sort=False)
                   for frames in (death_frames, yll_frames, yld_frames, daly_frames)]
    combined = pd.concat(per_measure, sort=False)

    # Round-trip through the index to get rows sorted by template_cols.
    return combined.set_index(template_cols).sort_index().reset_index()

In [None]:
def get_all_cause_results(data):
    """Build the long-form table of all-cause deaths, YLLs, YLDs and DALYs.

    The three total columns are renamed to ``<measure>_due_to_all_causes`` so
    that ``standardize_shape`` labels them with cause ``'all_causes'``; DALYs
    are computed as YLLs + YLDs.
    """
    all_cause_names = {
        'total_population_dead': 'death_due_to_all_causes',
        'years_of_life_lost': 'ylls_due_to_all_causes',
        'years_lived_with_disability': 'ylds_due_to_all_causes',
    }
    all_cause_data = data[list(all_cause_names)].rename(columns=all_cause_names)

    ylls = all_cause_data['ylls_due_to_all_causes']
    ylds = all_cause_data['ylds_due_to_all_causes']
    all_cause_data['dalys_due_to_all_causes'] = ylls + ylds

    long_frames = []
    for column in all_cause_data.columns:
        long_frames.append(standardize_shape(all_cause_data, column))
    return pd.concat(long_frames, sort=False)

In [None]:
def get_all_results(data, cause_names):
    """Return cause-specific results followed by all-cause results, in long form."""
    by_cause = get_disaggregated_results(data, cause_names)
    all_causes = get_all_cause_results(data)
    return pd.concat([by_cause, all_causes], sort=False)

## Get the transformed data and get a list of unique measures to plot

In [None]:
output = get_all_results(r, cause_names)
output

In [None]:
measures = output.measure.unique()
measures

## Add columns recording person time and treated time for each (scenario, draw, cause) combination

In [None]:
# Person time and treated days do not vary by cause or measure, so join on
# the scenario columns plus input_draw only; each value is repeated on every
# (cause, measure) row of the same scenario and draw.
join_columns = [c for c in template_cols if c not in ['cause', 'measure']]
df = output.merge(get_person_time(r), on=join_columns).merge(get_treated_days(r), on=join_columns)
df

## Function to plot mortality/DALY/YLL/YLD by disease at the draw level

Each raw measure is plotted side by side with its rate per 100,000 person years. Plots should be monotonically decreasing as coverage level increases.

In [None]:
@interact()
def plot_cause_spceific_dalys_by_draw(duration=[365.25, 730.50],
                                      cgf_permanent=[False, True],
                                      iron_permanent=[False, True],
                                      iron_mean=[0.895, 4.475, 8.950],
                                      input_draw=df.input_draw.unique(),
                                      measure=df.measure.unique(),
                                      include_other_causes=True,
                                      include_all_causes=False,
                                      ):
    """Plot one measure by cause against coverage for a single scenario and draw.

    Left panel: raw counts, one line per cause. Right panel: the same values
    per 100,000 person-years. ``cgf_permanent`` filters both the stunting and
    the wasting columns.
    """
    selected = (
        (df.duration == duration)
        & (df.child_stunting_permanent == cgf_permanent)
        & (df.child_wasting_permanent == cgf_permanent)
        & (df.iron_deficiency_permanent == iron_permanent)
        & (df.iron_deficiency_mean == iron_mean)
        & (df.input_draw == input_draw)
        & (df.measure == measure)
    )
    data = df.loc[selected]

    fig, (count_ax, rate_ax) = plt.subplots(1, 2, figsize=(18, 8))

    # 'other_causes' is much larger than the rest and is the last entry of
    # cause_names, so it can be hidden by dropping the final element.
    displayed_causes = cause_names if include_other_causes else cause_names[:-1]
    if include_all_causes:
        displayed_causes = displayed_causes + ['all_causes']

    for cause in displayed_causes:
        cause_rows = data.loc[data.cause == cause]
        coverage = cause_rows['coverage']

        count_ax.plot(coverage, cause_rows['value'], '-o', label=cause)
        rate_ax.plot(coverage,
                     100_000* cause_rows['value'] / cause_rows['person_time'],
                     '-o')

    # 'death' is already singular; the other measures are plural ('ylls', ...).
    if measure == 'death':
        singular_measure, plural_measure = measure, 'deaths'
    else:
        singular_measure, plural_measure = measure[:-1], measure

    count_ax.set_title(f'{singular_measure.upper()} count by disease vs. coverage', fontsize=20)
    count_ax.set_xlabel('Program Coverage (%)', fontsize=20)
    count_ax.set_ylabel(f'{plural_measure.upper()}', fontsize=20)
    count_ax.grid()
    count_ax.legend(loc=(0.9, -.3))

    rate_ax.set_title(f'{singular_measure.upper()} rate by disease vs. coverage', fontsize=20)
    rate_ax.set_xlabel('Program Coverage (%)', fontsize=20)
    rate_ax.set_ylabel(f'{plural_measure.upper()} per 100,000 person years', fontsize=20)
    rate_ax.grid()

## Well, this is bad - ylds due to iron deficiency are always higher than total ylds

In [None]:
# r is raw data aggregated by random seed, with intervention columns renamed
(r['years_lived_with_disability'] - r['ylds_due_to_iron_deficiency']).describe()

In [None]:
# Maybe iron deficiency got excluded from the total? Compare total to sum without ID:
# They seem pretty close, so maybe that's what happened
# Except then this difference should always be positive, so this doesn't make sense
(r.filter(regex='ylds_due_to_(?!iron)').sum(axis=1) - r['years_lived_with_disability']).describe()

## Function to compute averted deaths/DALYs/YLLs/YLDs and treatment days per averted

In [None]:
def get_averted_results(df):
    """Compare every scenario with its zero-coverage baseline.

    Each row of ``df`` is matched with the ``coverage == 0`` row that shares
    all other key columns (``template_cols`` without ``'coverage'``).

    Parameters
    ----------
    df : pandas.DataFrame
        Long-form results with the ``template_cols`` columns plus ``value``,
        ``person_time`` and ``sqlns_treated_days``.

    Returns
    -------
    pandas.DataFrame
        ``df`` with these columns added:

        * ``value_bau``, ``person_time_bau`` - baseline value and person time
        * ``averted`` - baseline value minus scenario value
        * ``value_rate`` - value per 100,000 person-years
        * ``averted_rate`` - baseline rate minus scenario rate, both per
          100,000 person-years
        * ``treated_days_per_averted`` - treated days per averted unit
        * ``treated_days_per_averted_rate`` - the same ratio computed from
          the rate difference

        The two ``treated_days_per_*`` ratios can be multiplied by a cost per
        treated day to obtain cost effectiveness. They are NaN or infinite
        wherever the averted quantity is zero (always so in baseline rows).
    """
    # Baseline rows carry no coverage or treatment information of their own.
    bau = df[df.coverage == 0.0].drop(columns=['coverage', 'sqlns_treated_days'])
    # template_cols[1:] is every key column except 'coverage'.
    t = pd.merge(df, bau, on=template_cols[1:], suffixes=('', '_bau'))

    # Averted raw value
    t['averted'] = t['value_bau'] - t['value']

    # Rates per 100,000 person-years
    t['value_rate'] = 100_000 * t['value'] / t['person_time']
    value_rate_bau = 100_000 * t['value_bau'] / t['person_time_bau']

    # Both rates must be in the same units before subtracting. The previous
    # code subtracted the already-scaled scenario rate from the unscaled
    # baseline rate and then scaled the difference a second time.
    t['averted_rate'] = value_rate_bau - t['value_rate']

    t['treated_days_per_averted'] = t['sqlns_treated_days'] / t['averted']
    # person_time * averted_rate / 100,000 is the averted count implied by
    # the rate difference.
    t['treated_days_per_averted_rate'] = 100_000 * (
        t['sqlns_treated_days'] / (t['person_time'] * t['averted_rate']))

    return t

## Compute averted deaths/DALYs/YLLs/YLDs per person year

In [None]:
averted_df = get_averted_results(df)
averted_df

In [None]:
# Ranges of costs per year based on cost per day of SQ-LNS
print(.11*365, .25*365, 0.03*365)

## Plot ICERs at draw level

In [None]:
# Causes present in the averted-results table; offered as the 'cause' dropdown.
averted_cause_list = averted_df.cause.unique()
# Slider for the assumed cost per person-year of treatment (5 to 100 in steps
# of 5, default 50). The same widget object is passed to both ICER plots.
cost_slider = IntSlider(value=50, min=5, max=100, step=5, continuous_update=False)

@interact()
def plot_icers_by_draw(duration=[365.25, 730.50],
                       cgf_permanent=[False, True],
                       iron_permanent=[False, True],
                       iron_mean=[0.895, 4.475, 8.950],
                       input_draw=draws,
                       measure=measures,
                       cause=averted_cause_list,
                       cost_per_py=cost_slider,
                       ):
    """Plot cost, averted burden and ICERs against coverage for a single draw.

    Panels: total treatment cost (top left), averted measure (bottom left),
    cost per averted unit from raw counts (top right) and from the rate
    difference (bottom right). Costs are ``cost_per_py`` times treated time
    in years. ``cgf_permanent`` filters both the stunting and wasting columns.
    """
    selected = (
        (averted_df.duration == duration)
        & (averted_df.child_stunting_permanent == cgf_permanent)
        & (averted_df.child_wasting_permanent == cgf_permanent)
        & (averted_df.iron_deficiency_permanent == iron_permanent)
        & (averted_df.iron_deficiency_mean == iron_mean)
        & (averted_df.input_draw == input_draw)
        & (averted_df.cause == cause)
        & (averted_df.measure == measure)
    )
    data = averted_df.loc[selected]

    fig, ax = plt.subplots(2, 2, figsize=(12, 8))
    (cost_ax, icer_ax), (averted_ax, rate_icer_ax) = ax

    coverage = data['coverage']

    def as_cost(treated_days):
        # Convert treated days to years, then price them.
        return cost_per_py * treated_days / days_per_year

    cost_ax.plot(coverage, as_cost(data['sqlns_treated_days']), '-o')
    averted_ax.plot(coverage, data['averted'], '-o', color='green')
    icer_ax.plot(coverage, as_cost(data['treated_days_per_averted']),
                 '-o', color='orange')
    rate_icer_ax.plot(coverage, as_cost(data['treated_days_per_averted_rate']),
                      '-o', color='orange')

    # (axis, title, y label, y label font size)
    labelling = [
        (cost_ax, 'Total cost vs. coverage',
         'Cost of SQ-LNS\ntreatment ($)', 16),
        (averted_ax, f'Averted {measure} vs. coverage',
         f'Averted {measure}', 20),
        (icer_ax, 'Cost effectiveness (ICERs)\nvs. coverage',
         f'Cost per averted {measure}', 12),
        (rate_icer_ax, 'Cost effectiveness (ICERs)\nvs. coverage',
         f'Cost per averted {measure}\n(calculated using rate difference)', 12),
    ]
    for axis, title, ylabel, ylabel_size in labelling:
        axis.set_title(title, fontsize=20)
        axis.set_xlabel('Program Coverage (%)', fontsize=16)
        axis.set_ylabel(ylabel, fontsize=ylabel_size)
        axis.grid()

    fig.tight_layout()

### Note: This dataframe has two infinities in the `treated_days_per_averted` column

In [None]:
data = averted_df
data = data.loc[(data.duration == 365.25)
              & (data.child_stunting_permanent == False)
              & (data.child_wasting_permanent == False)
              & (data.iron_deficiency_permanent == False)
              & (data.iron_deficiency_mean == 0.895)
              & (data.input_draw == 29)
              & (data.cause == 'diarrheal_diseases')
              & (data.measure == 'death')
#               & (data.coverage == 60)
               ]
data

### Note: This dataframe has very small negative numbers instead of zero, which throws off the calculation

In [None]:
data = averted_df
data = data.loc[(data.duration == 365.25)
              & (data.child_stunting_permanent == False)
              & (data.child_wasting_permanent == False)
              & (data.iron_deficiency_permanent == False)
              & (data.iron_deficiency_mean == 0.895)
              & (data.input_draw == 602)
              & (data.cause == 'measles')
              & (data.measure == 'ylls')
#               & (data.coverage == 60)
               ]
data

## Function to compute mean and percentiles over draws

In [None]:
def get_final_table(data):
    """Summarize results across input draws.

    Rows are grouped by every ``template_cols`` column except the last
    (``input_draw``), and each numeric result column is summarized with
    ``describe``.

    Parameters
    ----------
    data : pandas.DataFrame
        Draw-level table with the ``template_cols`` columns and the result
        columns listed below.

    Returns
    -------
    pandas.DataFrame
        Indexed by scenario, cause and measure, with two-level columns
        ``(result column, statistic)``. The statistics are those produced by
        ``describe`` and include ``'mean'``, ``'2.5%'`` and ``'97.5%'``.
        ``value_rate`` is summarized as well whenever ``data`` has that
        column, so a per-100,000-person-year summary of the measure is
        available next to the raw ``value``.
    """
    summarized = ['value',
                  'person_time',
                  'sqlns_treated_days',
                  'averted',
                  'averted_rate',
                  'treated_days_per_averted',
                  'treated_days_per_averted_rate',
                  ]
    if 'value_rate' in data.columns:
        # Optional so that tables built without this column still work.
        summarized.insert(1, 'value_rate')

    # template_cols[:-1] drops 'input_draw', so draws are what gets summarized.
    grouped = data.groupby(template_cols[:-1])[summarized]
    return grouped.describe(percentiles=[.025, .975])

## Compute aggregated results

In [None]:
aggregated_results_df = get_final_table(averted_df)
aggregated_results_df

## Draw graphs of aggregated averted DALYs per 100,000 PY, and DALYs per 100,000 PY

In [None]:
@interact()
def plot_aggregated_averted_rates(duration=[365.25, 730.50],
                                  cgf_permanent=[False, True],
                                  iron_permanent=[False, True],
                                  iron_mean=[0.895, 4.475, 8.950],
                                  measure=measures,
                                  include_other_causes=False,
                                  include_all_causes=False,
                                  ):
    """Plot the mean averted rate by cause against coverage, with a 95% band.

    One line per cause shows the mean of ``averted_rate`` over draws; the
    shaded band spans its 2.5% to 97.5% percentiles. ``cgf_permanent``
    filters both the stunting and the wasting columns.
    """
    table = aggregated_results_df.reset_index()

    selected = (
        (table.duration == duration)
        & (table.child_stunting_permanent == cgf_permanent)
        & (table.child_wasting_permanent == cgf_permanent)
        & (table.iron_deficiency_permanent == iron_permanent)
        & (table.iron_deficiency_mean == iron_mean)
        & (table.measure == measure)
    )
    data = table.loc[selected]

    plt.figure(figsize=(12, 8))

    # 'other_causes' is much larger than the rest and is the last entry of
    # cause_names, so it can be hidden by dropping the final element.
    displayed_causes = cause_names if include_other_causes else cause_names[:-1]
    if include_all_causes:
        displayed_causes = displayed_causes + ['all_causes']

    for cause in displayed_causes:
        cause_rows = data.loc[data.cause == cause]
        coverage = cause_rows['coverage']

        plt.plot(coverage, cause_rows[('averted_rate', 'mean')], '-o', label=cause)
        plt.fill_between(coverage,
                         cause_rows[('averted_rate', '2.5%')],
                         cause_rows[('averted_rate', '97.5%')],
                         alpha=0.1)

    plt.title('Nigeria')
    plt.xlabel('Program Coverage (%)')
    plt.ylabel(f'{measure.upper()} Averted (per 100,000 PY)')
    plt.legend(loc=(1.05, .05))
    plt.grid()

In [None]:
@interact()
def plot_dalys_per_1e5_py(duration=[365.25, 730.50],
                          cgf_permanent=[False, True],
                          iron_permanent=[False, True],
                          iron_mean=[0.895, 4.475, 8.950],
                          include_other_causes=False):
    """Plot mean DALYs by cause against coverage, with a 95% band.

    If the aggregated table contains ``value_rate`` summary columns, the DALY
    rate per 100,000 person-years is plotted. Otherwise only the raw DALY
    count (``value``) is available; it is plotted and the y axis is labelled
    as a count. Previously the raw count was always plotted under a
    "per 100,000 PY" label.

    ``cgf_permanent`` filters both the stunting and the wasting columns.
    """
    df = aggregated_results_df.reset_index()

    # Pick the plotted quantity and a y-axis label that matches it.
    has_rate = 'value_rate' in df.columns.get_level_values(0)
    column = 'value_rate' if has_rate else 'value'
    y_label = 'DALYs per 100,000 PY' if has_rate else 'DALYs'

    data = df.loc[(df.duration == duration)
                  & (df.child_stunting_permanent == cgf_permanent)
                  & (df.child_wasting_permanent == cgf_permanent)
                  & (df.iron_deficiency_permanent == iron_permanent)
                  & (df.iron_deficiency_mean == iron_mean)
                  & (df.measure == 'dalys')]

    plt.figure(figsize=(12, 8))

    # 'other_causes' value is much higher - can omit by indexing with [:-1]
    displayed_causes = cause_names if include_other_causes else cause_names[:-1]
    for cause in displayed_causes:
        data_sub = data.loc[data.cause == cause]

        xx = data_sub['coverage']
        mean = data_sub[(column, 'mean')]
        lb = data_sub[(column, '2.5%')]
        ub = data_sub[(column, '97.5%')]

        plt.plot(xx, mean, '-o', label=cause)
        plt.fill_between(xx, lb, ub, alpha=0.1)

    plt.title('Nigeria')
    plt.xlabel('Program Coverage (%)')
    plt.ylabel(y_label)
    plt.legend(loc=(1.05, .05))
    plt.grid()

In [None]:
@interact()
def plot_icers(duration=[365.25, 730.50],
               cgf_permanent=[False, True],
               iron_permanent=[False, True],
               iron_mean=[0.895, 4.475, 8.950],
               measure=measures,
               cause=averted_cause_list,
               cost_per_py=cost_slider,
               ):
    """Plot cost, averted burden and ICERs against coverage, aggregated over draws.

    Each panel shows the mean over draws with a shaded 2.5%-97.5% band:
    total treatment cost (top left), averted measure (bottom left), cost per
    averted unit from raw counts (top right) and from the rate difference
    (bottom right). ``cgf_permanent`` filters both the stunting and wasting
    columns.
    """
    table = aggregated_results_df.reset_index()

    selected = (
        (table.duration == duration)
        & (table.child_stunting_permanent == cgf_permanent)
        & (table.child_wasting_permanent == cgf_permanent)
        & (table.iron_deficiency_permanent == iron_permanent)
        & (table.iron_deficiency_mean == iron_mean)
        & (table.cause == cause)
        & (table.measure == measure)
    )
    data = table.loc[selected]

    fig, ax = plt.subplots(2, 2, figsize=(14, 9))
    coverage = data['coverage']

    def as_cost(treated_days):
        # Convert treated days to years, then price them.
        return cost_per_py * treated_days / days_per_year

    def unchanged(values):
        return values

    # (axis, summarized column, conversion, extra style keywords)
    panels = [
        (ax[0, 0], 'sqlns_treated_days', as_cost, {}),
        (ax[1, 0], 'averted', unchanged, {'color': 'orange'}),
        (ax[0, 1], 'treated_days_per_averted', as_cost, {'color': 'green'}),
        (ax[1, 1], 'treated_days_per_averted_rate', as_cost, {'color': 'green'}),
    ]
    for axis, column, convert, style in panels:
        mean = convert(data[(column, 'mean')])
        lower = convert(data[(column, '2.5%')])
        upper = convert(data[(column, '97.5%')])
        axis.plot(coverage, mean, '-o', **style)
        axis.fill_between(coverage, lower, upper, alpha=0.1, **style)

    # (axis, title, y label, y label font size)
    labelling = [
        (ax[0, 0], 'Total cost vs. coverage',
         'Cost of SQ-LNS\ntreatment ($)', 16),
        (ax[1, 0], f'Averted {measure} vs. coverage',
         f'Averted {measure}', 16),
        (ax[0, 1], 'Cost effectiveness (ICERs)\nvs. coverage',
         f'Cost per averted {measure}', 12),
        (ax[1, 1], 'Cost effectiveness (ICERs)\nvs. coverage',
         f'Cost per averted {measure}\n(calculated using rate difference)', 12),
    ]
    for axis, title, ylabel, ylabel_size in labelling:
        axis.set_title(title, fontsize=16)
        axis.set_xlabel('Program Coverage (%)', fontsize=12)
        axis.set_ylabel(ylabel, fontsize=ylabel_size)
        axis.grid()

    fig.tight_layout()

In [None]:
3e6 /600