# V&V anemia screening and iron interventions

This notebook focuses on anemia screening, oral iron, and IV iron.

Each separate check in this notebook is labeled "CHECK" (all caps).

## Setup

In [None]:
from vivarium import Artifact, InteractiveContext
import pandas as pd, numpy as np, os

In [None]:
# Record installed vivarium package versions for provenance.
! pip list | grep vivarium

# [software version + hash + date]

In [None]:
# pip freeze output for vivarium packages; for git-installed packages this may include the VCS URL/commit.
! pip freeze | grep vivarium

In [None]:
import warnings
# Hide FutureWarnings so notebook output stays readable.
warnings.simplefilter(action='ignore', category=FutureWarning)

# NOTE(review): InteractiveContext and Artifact are already imported in the first code cell; this re-import is redundant.
from vivarium import InteractiveContext, Artifact

In [None]:
import vivarium_gates_mncnh
from vivarium.framework.configuration import build_model_specification
from pathlib import Path

# Build the packaged model specification, then tailor it for interactive V&V:
# observers are dropped (this notebook inspects the state table directly),
# a single input draw is pinned, and the population is set to 200,000 simulants.
path = Path(vivarium_gates_mncnh.__file__).parent.joinpath('model_specifications', 'model_spec.yaml')
custom_model_specification = build_model_specification(path)
del custom_model_specification.configuration.observers
custom_model_specification.configuration.input_data.input_draw_number = 60
custom_model_specification.configuration.population.population_size = 200_000

# The artifact referenced by the spec holds every input used for checks below.
artifact_path = custom_model_specification.configuration.input_data.artifact_path
art = Artifact(artifact_path)

artifact_path

In [None]:
# Artifact data stores one column per input draw, named 'draw_<n>'.
draw_num = custom_model_specification.configuration.input_data.input_draw_number
draw = f'draw_{draw_num}'
draw

CHECK: IFA coverage in the artifact does not vary by sex, age, or year.

Suite: artifact tests.

Type: precise assert.

In [None]:
ifa_coverage_at_anc = art.load('risk_factor.iron_folic_acid_supplementation.coverage')[draw]
# Each coverage category must take a single value across all sex/age/year rows.
assert ifa_coverage_at_anc.groupby('parameter').nunique().eq(1).all()
# Category meanings (after the artifact update): cat1 = uncovered, cat2 = covered.
# Collapse to one value per category and keep the covered proportion.
ifa_coverage_at_anc = ifa_coverage_at_anc.groupby('parameter').first()['cat2']
ifa_coverage_at_anc

CHECK: ANC1 coverage in the artifact does not vary by sex, age, or year.

Suite: artifact tests.

Type: precise assert.

In [None]:
# ANC1 coverage must be a single row (no variation by sex, age, or year).
anc1 = art.load('covariate.antenatal_care_1_visit_coverage_proportion.estimate')[draw]
assert len(anc1) == 1
anc1 = anc1.iloc[0]
# Overall IFA coverage = coverage among ANC attendees x ANC1 coverage.
ifa_coverage = anc1 * ifa_coverage_at_anc
ifa_coverage

In [None]:
# IFA effect sizes keyed by affected value; keep the shift applied to 'hemoglobin.exposure'.
ifa_effect = art.load('risk_factor.iron_folic_acid_supplementation.effect_size')[draw]
ifa_effect_hgb = ifa_effect.loc['hemoglobin.exposure']
ifa_effect_hgb

CHECK: IV iron's effect on hemoglobin in the artifact does not vary by sex, age, or year.

Suite: artifact tests.

Type: precise assert.

In [None]:
# IV iron's hemoglobin shift must be a single row (no variation by sex, age, or year);
# reduce it to a scalar.
iv_iron_effect_hgb = art.load('intervention.iv_iron.hemoglobin_effect_size')[draw]
assert len(iv_iron_effect_hgb) == 1
iv_iron_effect_hgb = iv_iron_effect_hgb.squeeze()
iv_iron_effect_hgb

CHECK: Relative risks of hemoglobin in the artifact do not vary by sex, age, or year.

Suite: artifact tests.

Type: precise assert.

In [None]:
art_rrs = art.load('risk_factor.hemoglobin.relative_risk').reset_index()
# Sort so the TMREL lookup does not depend on the order in which exposure levels
# happen to appear in the artifact.
exposure_levels = np.sort(art_rrs.parameter.unique())
# The TMREL is taken to be the smallest artifact exposure level at or above 120 g/L.
tmrel = exposure_levels[exposure_levels >= 120][0]
# Express every RR relative to the RR at the TMREL (aligned on the remaining index columns).
art_tmrel = art_rrs.loc[art_rrs.parameter == tmrel].drop(columns='parameter')
art_rrs = (art_rrs.set_index([x for x in art_rrs.columns if 'draw' not in x])
          / art_tmrel.set_index([x for x in art_tmrel.columns if 'draw' not in x])).reset_index()
# Within each (affected entity, exposure level) the RR must not vary by sex, age, or year.
assert (art_rrs.groupby(['affected_entity', 'parameter'])[draw].nunique() == 1).all()
art_rrs = art_rrs.groupby(['affected_entity', 'parameter'])[draw].first()
art_rrs

In [None]:
art_rrs.index.unique(level='affected_entity')

CHECK: Probability of low ferritin in the artifact in the 'not_anemic' category is always half of the corresponding value in the 'mild' category.

Suite: artifact tests.

Type: precise assert.

In [None]:
art_low_ferritin_probs = art.load('ferritin.probability_of_low_ferritin')[draw]
# Documented relationship: P(low ferritin | not anemic) = P(low ferritin | mild) / 2.
# https://vivarium-research.readthedocs.io/en/latest/models/intervention_models/mncnh_pregnancy/anemia_screening.html#id6
assert np.allclose(
    art_low_ferritin_probs.xs('not_anemic', level=2),
    art_low_ferritin_probs.xs('mild', level=2) / 2,
)
art_low_ferritin_probs

# I expected the probability of low ferritin to be higher in the more severe anemia categories, but the artifact shows the opposite.

In [None]:
# TODO: Merge this section with the next
art_iv_iron_sb_rr = art.load('intervention.iv_iron.stillbirth_relative_risk')[draw].rename('val')
art_iv_iron_lbwsg_shifts = art.load('intervention.iv_iron.low_birth_weight_and_short_gestation_effect_size')[draw]

In [None]:
def create_df(scenario, n_steps=5):
    """Run `scenario` interactively and stack population snapshots taken before each step.

    Parameters
    ----------
    scenario : str
        Value written to ``configuration.intervention.scenario``.
    n_steps : int, default 5
        Number of snapshots to record; one ``sim.step()`` is taken after each.

    Returns
    -------
    pd.DataFrame
        State table plus the 'hemoglobin.exposure' pipeline value, one row per
        simulant per snapshot (so the index repeats across snapshots).
        'timestep' is 'initialization' for the first snapshot and otherwise the
        name of the step taken just before the snapshot; 'timestep_n' is the
        snapshot number. Empty if ``n_steps`` is 0.
    """
    from copy import deepcopy
    custom_model_specification_scenario = deepcopy(custom_model_specification)
    custom_model_specification_scenario.configuration.intervention.scenario = scenario
    sim = InteractiveContext(custom_model_specification_scenario)
    hemoglobin_component = sim.list_components()['risk_factor.hemoglobin']
    sim_step_name = hemoglobin_component._sim_step_name

    # Collect snapshots and concatenate once (concat inside the loop is quadratic).
    snapshots = []
    timestep = 'initialization'
    for i in range(n_steps):
        pop = sim.get_population()
        pop = pd.concat([pop, sim.get_value('hemoglobin.exposure')(pop.index)], axis=1)
        pop['timestep'] = timestep
        pop['timestep_n'] = i
        snapshots.append(pop)
        # The next snapshot is labelled with the step about to be taken.
        timestep = sim_step_name()
        sim.step()
    return pd.concat(snapshots) if snapshots else pd.DataFrame()

In [None]:
# Population snapshots for the baseline scenario.
baseline_data = create_df(scenario='baseline')
baseline_data

In [None]:
iv_iron_data = create_df(scenario='anemia_screening_and_iv_iron_scaleup')  # IV iron scale-up snapshots

In [None]:
# Simulants who received IV iron. iv_iron_data has one row per simulant per
# snapshot, so a covered simulant's label repeats; deduplicate so that
# .loc[iv_iron_indices] selects each simulant's rows once instead of once per
# covered snapshot (which the merge below would then multiply further).
iv_iron_indices = iv_iron_data.loc[iv_iron_data.iv_iron_intervention == 'covered'].index.unique()

In [None]:
# Columns compared between the IV iron and baseline scenarios.
cols = [
    'gestational_age_exposure',
    'birth_weight_exposure',
    'hemoglobin.exposure',
    'pregnancy_outcome',
    'hemoglobin_exposure',
    'oral_iron_intervention',
    'iv_iron_intervention',
    'timestep',
    'sex_of_child',
    'simulant',
]

In [None]:
# Copy the simulant id out of the (repeating) index so it survives merges.
baseline_data['simulant'] = baseline_data.index.to_numpy()
iv_iron_data['simulant'] = iv_iron_data.index.to_numpy()

CHECK: Among simulants receiving IV iron, the ratio of the stillbirth proportion in the IV iron scenario to that in the baseline scenario
(binned by hemoglobin) follows a curve similar to the stillbirth relative risk saved in the artifact.

Type: manual check.

In [None]:
iv_iron_comp = iv_iron_data.loc[iv_iron_indices][cols].merge(baseline_data.loc[iv_iron_indices][cols],
                                                                on=['simulant','timestep'], suffixes=['_iv_iron','_baseline']).drop_duplicates()
def plot_sb_rr(hgb_exposure_timing):
    """Plot the simulated stillbirth RR due to IV iron against the artifact RR curve.

    Among simulants in ``iv_iron_comp`` (those who received IV iron), bins the
    IV-iron-scenario hemoglobin at ``hgb_exposure_timing`` into 15 equal-width
    bins and, per bin, divides the stillbirth proportion in the IV iron scenario
    by that in the baseline scenario (outcomes taken at the 'ultrasound' timestep).
    Bins where the baseline has no stillbirths give an undefined (NaN) RR.
    """
    import matplotlib.pyplot as plt

    pre_int_hgb = (
        iv_iron_comp.loc[iv_iron_comp.timestep == hgb_exposure_timing]
        .set_index('simulant')['hemoglobin.exposure_iv_iron']
        .rename('pre_int_hgb')
    )
    outcomes = (
        iv_iron_comp.loc[iv_iron_comp.timestep == 'ultrasound']
        .set_index('simulant')[['pregnancy_outcome_iv_iron', 'pregnancy_outcome_baseline']]
    )
    # Pair hemoglobin and outcomes by simulant, not by row position in two separately filtered frames.
    plot_data = outcomes.join(pre_int_hgb, how='inner')
    hemoglobin_bin = pd.cut(plot_data['pre_int_hgb'], 15)

    def stillbirth_proportion(outcome_col):
        # Mean of a boolean is a proportion; a bin with no stillbirths yields 0 rather than a missing entry.
        is_stillbirth = plot_data[outcome_col] == 'stillbirth'
        return is_stillbirth.groupby(hemoglobin_bin, observed=False).mean()

    sb_rr = (
        stillbirth_proportion('pregnancy_outcome_iv_iron')
        / stillbirth_proportion('pregnancy_outcome_baseline')
    ).replace([np.inf, -np.inf], np.nan)
    hgb_bin_mid = [interval.mid for interval in sb_rr.index]

    plt.figure()
    plt.plot(hgb_bin_mid, sb_rr.to_numpy(), label='Simulation', marker='o')
    plt.plot(art_iv_iron_sb_rr.reset_index().first_trimester_hemoglobin_exposure_start, art_iv_iron_sb_rr, color='tab:orange', label='Artifact')
    plt.xlabel(f'{hgb_exposure_timing} hemoglobin')
    plt.ylabel('RR of stillbirth due to IV iron')
    plt.title(f'Effect of IV iron on stillbirth by {hgb_exposure_timing} hemoglobin exposure')
    plt.grid()
    plt.legend()

for hgb_exposure_timing in ['first_trimester_anc']:
    plot_sb_rr(hgb_exposure_timing)

In [None]:
iv_iron_comp['ga_diff'] = iv_iron_comp.gestational_age_exposure_iv_iron - iv_iron_comp.gestational_age_exposure_baseline
iv_iron_comp['bw_diff'] = iv_iron_comp.birth_weight_exposure_iv_iron - iv_iron_comp.birth_weight_exposure_baseline

import matplotlib.pyplot as plt

def plot_lbwsg_shifts(outcome, sexes=('Male', 'Female'), hgb_exposure_timing='first_trimester_anc', consistent_live_birth_restriction=False):
    """Plot (and optionally assert) IV iron's shift to an LBWSG exposure against hemoglobin.

    For each sex, scatters each simulant's IV-iron-minus-baseline difference in
    ``outcome`` (taken at the 'ultrasound' timestep) against that simulant's
    IV-iron-scenario hemoglobin at ``hgb_exposure_timing``, overlaid on the
    artifact shift curve.

    Parameters
    ----------
    outcome : {'birth_weight', 'gestational_age'}
        Which exposure shift to plot.
    sexes : iterable of str
        Values of ``sex_of_child_baseline`` to plot; one figure each.
    hgb_exposure_timing : str
        Timestep whose hemoglobin is used on the x-axis.
    consistent_live_birth_restriction : bool
        If True, keep only simulants who are live births in both scenarios and
        assert that each one's shift equals the artifact value for their
        hemoglobin bin.

    Raises
    ------
    KeyError
        If ``outcome`` is not one of the supported values.
    AssertionError
        If the restricted recalculation does not match the simulation.
    """
    outcome_col = {
        'birth_weight': 'bw_diff',
        'gestational_age': 'ga_diff',
    }[outcome]

    # The live-birth set does not depend on sex, so compute it once.
    is_ultrasound = iv_iron_comp.timestep == 'ultrasound'
    is_live_birth = (
        (iv_iron_comp.pregnancy_outcome_baseline == 'live_birth') &
        (iv_iron_comp.pregnancy_outcome_iv_iron == 'live_birth')
    )
    live_birth_simulants = iv_iron_comp.loc[
        is_ultrasound & is_live_birth, 'simulant'
    ].unique()

    for sex in sexes:
        plt.figure()

        p = iv_iron_comp.loc[iv_iron_comp.sex_of_child_baseline == sex]
        a = art_iv_iron_lbwsg_shifts.loc[sex, outcome]

        if consistent_live_birth_restriction:
            p = p.loc[p.simulant.isin(live_birth_simulants)]

        # Pair each simulant's hemoglobin with the same simulant's shift by id, not row position.
        hgb_by_simulant = p.loc[p.timestep == hgb_exposure_timing].set_index('simulant')['hemoglobin.exposure_iv_iron']
        shift_by_simulant = p.loc[p.timestep == 'ultrasound'].set_index('simulant')[outcome_col]
        hgb_by_simulant, shift_by_simulant = hgb_by_simulant.align(shift_by_simulant, join='inner')

        if consistent_live_birth_restriction:
            starts = a.index.get_level_values('first_trimester_hemoglobin_exposure_start').to_series()
            ends = a.index.get_level_values('first_trimester_hemoglobin_exposure_end').to_series()

            # Open both outer edges so out-of-range hemoglobin falls in the nearest
            # edge bin (the lower edge was already treated this way). Previously,
            # values above the top edge got get_indexer's -1, which only picked the
            # top bin because it happened to be the last row.
            starts = starts.replace(starts.min(), -np.inf)
            ends = ends.replace(ends.max(), np.inf)

            intervals = pd.IntervalIndex.from_arrays(starts, ends, closed="left")
            idx = intervals.get_indexer(hgb_by_simulant)
            assert (idx >= 0).all(), "hemoglobin values not covered by any artifact bin"

            recalc = pd.Series(a.to_numpy()[idx], index=hgb_by_simulant.index)

            assert np.allclose(
                recalc,
                shift_by_simulant,
                rtol=0,
                atol=1e-10,
            ), f"{outcome} shift doesn't match recalculation"

        plt.scatter(hgb_by_simulant, shift_by_simulant, label='Simulation')
        plt.plot(
            a.reset_index().first_trimester_hemoglobin_exposure_start,
            a,
            color='tab:orange',
            label='Artifact',
        )

        plt.xlabel('Hemoglobin exposure')
        plt.ylabel(f'{outcome} shift')
        plt.title(f'IV iron effect on {outcome} ({sex}) based on {hgb_exposure_timing} hemoglobin exposure')
        plt.grid()
        plt.legend()

CHECK: Each of gestational age and birthweight vary between scenarios for the simulants who received IV iron,
by an amount matching the shift saved in the artifact, *except* for simulants whose birth outcome changed.

Type: precise assert.

In [None]:
# Unrestricted population: shifts differ from the artifact where the birth outcome changed.
for lbwsg_outcome in ('gestational_age', 'birth_weight'):
    plot_lbwsg_shifts(outcome=lbwsg_outcome)

In [None]:
# Live births in both scenarios only: shifts are asserted against the artifact.
for lbwsg_outcome in ('gestational_age', 'birth_weight'):
    plot_lbwsg_shifts(outcome=lbwsg_outcome, consistent_live_birth_restriction=True)

This function contains many checks. Unless noted otherwise, they are precise asserts in the interactive sim test suite.

Before first step:

* CHECK: ANC attendance column starts as all null.
* CHECK: Hemoglobin and ferritin screening coverage columns start as all False (not screened), *and stay that way until the screening timestep*.
* CHECK: Oral iron column starts as all 'no_treatment'.
* CHECK: Each simulant's hemoglobin value should be lower than their 'raw_hemoglobin' value (the simulant's draw from the GBD distribution) by baseline IFA coverage times the IFA effect on hemoglobin.

After first-trimester ANC timestep:

* CHECK: ANC attendance column all non-null.
* CHECK: Oral iron column remains 'no_treatment' for simulants who did not attend first-trimester ANC.
* CHECK: Oral iron column all 'mms' for simulants who did attend first-trimester ANC, in an MMS scale-up scenario.
* CHECK: Oral iron column a mix of 'ifa' and 'no_treatment' in non-MMS scale-up scenarios.
* CHECK: In non-MMS scale-up scenarios, proportion of those attending first-trimester ANC who received IFA should match IFA coverage from artifact. Type: fuzzy check of proportion.
* CHECK: Simulants who received IFA or MMS on this timestep should have had their hemoglobin value increased by the IFA effect (shared by MMS), relative to the previous timestep.
* CHECK: Simulants who did not receive oral iron on this timestep should have the same hemoglobin they had on the previous timestep.

After later-pregnancy screening timestep:

* CHECK: In non-screening-scale-up scenarios, there is non-zero hemoglobin screening.
* CHECK: In screening scale-up scenarios, the simulants with hemoglobin screening coverage are exactly those who attend later-pregnancy ANC.
* CHECK: The proportion of simulants with truly low hemoglobin (hemoglobin <100) who test low (out of those who test) is approximately the documented 85% sensitivity. Type: fuzzy check of proportion.
* CHECK: The proportion of simulants with truly adequate hemoglobin (hemoglobin >= 100) who test adequate (out of those who test) is approximately the documented 80% specificity. Type: fuzzy check of proportion. 
* CHECK: In non-screening-scale-up scenarios, there is zero ferritin screening.
* CHECK: In screening scale-up scenarios, the simulants with ferritin screening coverage are exactly those who test low hemoglobin.
* CHECK: The proportion of simulants who test low ferritin (out of those who are tested) matches the target in the artifact (by age and anemia status). Type: fuzzy check of proportion.

After later-pregnancy intervention timestep:

* CHECK: In non-MMS-scaleup scenarios, the only change to oral iron coverage vs the previous timestep is some later-pregnancy-only-ANC-attending simulants shifting from 'no_treatment' to 'ifa'.
* CHECK: In MMS-scaleup scenarios, the only change to oral iron coverage vs the previous timestep is some later-pregnancy-only-ANC-attending simulants shifting from 'no_treatment' to 'mms'.
* CHECK: No IV iron received in non-IV-iron scaleup scenarios.
* CHECK: In IV-iron scaleup scenarios, IV iron received by exactly the simulants who test low ferritin.
* CHECK: Hemoglobin unchanged from previous timestep for simulants who neither changed oral iron coverage on this timestep, nor received IV iron.
* CHECK: Hemoglobin increased vs previous timestep in those who received IV iron by the IV iron effect size on hemoglobin.
* CHECK: Hemoglobin increased vs previous timestep in those who received new oral iron and did NOT receive IV iron by the IFA effect size on hemoglobin.

For each outcome (maternal hemorrhage and maternal sepsis) timestep:

* CHECK: Hemoglobin unchanged from after later-pregnancy intervention timestep.
* CHECK: Simulant-level relative risk on this outcome equals a recalculated interpolation of the artifact RRs.
* CHECK: The relative risk of the event happening in each quantile-bin of hemoglobin (vs the quantile-bin that contains the TMREL) approximately matches the mean of the artifact RRs for hemoglobin values in that quantile-bin. Type: manual, since this is a pretty odd check with lots of room for imprecision. We should consider if there is another good way for us to (partially) check this.

In [None]:
def check_hemoglobin_in_interactive_sim(scenario):
    """Step one scenario through pregnancy and check hemoglobin, screening, and iron interventions.

    Runs the CHECKs listed in the markdown cell above. Precise checks are
    asserts; fuzzy proportion checks and the binned relative-risk check are
    printed/plotted for manual review.

    Parameters
    ----------
    scenario : str
        Value for ``configuration.intervention.scenario``; one of 'baseline',
        'mms_total_scaleup', 'anemia_screening_vv', or
        'anemia_screening_and_iv_iron_scaleup'.

    Returns
    -------
    pd.DataFrame
        Final state table with its 'hemoglobin_exposure' column replaced by the
        pipeline value, plus 'raw_hemoglobin' and 'initial_hemoglobin' columns.

    Raises
    ------
    ValueError
        If ``scenario`` is not one of the scenarios above (raised before the
        simulation is built).
    AssertionError
        If any precise check fails.
    """
    from copy import deepcopy
    import matplotlib.pyplot as plt

    # Single source of truth for scenario names; a mistyped literal previously
    # made the MMS branches unreachable.
    ifa_scenarios = ('baseline', 'anemia_screening_vv', 'anemia_screening_and_iv_iron_scaleup')
    mms_scenario = 'mms_total_scaleup'
    screening_scenarios = ('anemia_screening_vv', 'anemia_screening_and_iv_iron_scaleup')
    iv_iron_scenario = 'anemia_screening_and_iv_iron_scaleup'
    known_scenarios = ifa_scenarios + (mms_scenario,)
    if scenario not in known_scenarios:
        raise ValueError(f"Unknown scenario {scenario!r}; expected one of {known_scenarios}")

    custom_model_specification_scenario = deepcopy(custom_model_specification)
    custom_model_specification_scenario.configuration.intervention.scenario = scenario

    sim = InteractiveContext(custom_model_specification_scenario)

    hemoglobin_component = sim.list_components()['risk_factor.hemoglobin']
    sim_step_name = hemoglobin_component._sim_step_name
    # NOTE(review): return value unused; call kept as in the original.
    sim_step_name()

    # ---- Before the first step ----
    initial_pop = sim.get_population()
    assert initial_pop.anc_attendance.isnull().all()
    assert (initial_pop.oral_iron_intervention == 'no_treatment').all()

    # Recompute each simulant's draw from the GBD exposure distribution, as the risk component does:
    # https://github.com/ihmeuw/vivarium_public_health/blob/3b2b4f13ef53c0e9fd378e824eeb81a9057fea8d/src/vivarium_public_health/risks/base_risk.py#L303-L305
    raw_hemoglobin = hemoglobin_component.exposure_distribution.ppf(hemoglobin_component.propensity(initial_pop.index)).rename('raw_hemoglobin')
    initial_hemoglobin = sim.get_value('hemoglobin.exposure')(initial_pop.index).rename('initial_hemoglobin')

    # ---- First-trimester ANC ----
    assert sim_step_name() == "first_trimester_anc"
    sim.step()
    first_trimester_anc_pop = sim.get_population()
    assert first_trimester_anc_pop.anc_attendance.notnull().all()
    attended_first_trimester = ~first_trimester_anc_pop.anc_attendance.isin(['none', 'later_pregnancy_only'])
    assert (
        first_trimester_anc_pop[~attended_first_trimester].oral_iron_intervention == 'no_treatment'
    ).all(), "people receiving IFA before screening who did not attend ANC in first trimester"

    if scenario in ifa_scenarios:
        # are we seeing the right coverage of IFA among those who attend ANC in the first trimester?
        print(f'Target IFA coverage at ANC: {ifa_coverage_at_anc}')
    elif scenario == mms_scenario:
        assert (
            first_trimester_anc_pop.loc[attended_first_trimester, 'oral_iron_intervention'] == 'mms'
        ).all(), "MMS not fully scaled up!"
    print('Simulated oral iron coverage at ANC in first trimester:')
    display(first_trimester_anc_pop.loc[attended_first_trimester].oral_iron_intervention.value_counts(normalize=True))

    first_trimester_anc_pop = pd.concat([
        # NOTE: In previous versions of this notebook we also checked the hemoglobin_exposure column in the state table
        # This is an implementation detail that won't be needed after the population refactor, and it lags behind the
        # actual value by one timestep, which is a bit confusing.
        # We ignore it here and check everything using the pipeline.
        first_trimester_anc_pop.drop(columns=['hemoglobin_exposure']),
        raw_hemoglobin,
        initial_hemoglobin,
        sim.get_value('hemoglobin.exposure')(first_trimester_anc_pop.index).rename('hemoglobin_exposure'),
        sim.get_value('hemoglobin_on_maternal_hemorrhage.relative_risk')(first_trimester_anc_pop.index)
    ], axis=1)

    print('Mean hemoglobin exposure: GBD vs after first trimester ANC in sim')
    display(first_trimester_anc_pop[['raw_hemoglobin', 'hemoglobin_exposure']].mean())

    if scenario in ifa_scenarios:
        # hemoglobin exposure is still lower at this point than GBD estimates, because not all IFA has been applied
        assert first_trimester_anc_pop['hemoglobin_exposure'].mean() < first_trimester_anc_pop['raw_hemoglobin'].mean()
    elif scenario == mms_scenario:
        # There's not a clear target for whether this should be lower or higher,
        # since not all oral iron has been applied yet, but MMS has been scaled up to 100% of early ANC visits
        pass

    received_oral_iron = first_trimester_anc_pop.oral_iron_intervention != 'no_treatment'
    assert np.allclose(
        first_trimester_anc_pop.loc[received_oral_iron, 'hemoglobin_exposure'] -
        first_trimester_anc_pop.loc[received_oral_iron, 'initial_hemoglobin'],
        ifa_effect_hgb,
        rtol=0,
        atol=1e-13,
    ), "Oral iron effect does not match expectation from recalculation"

    assert (
        first_trimester_anc_pop.loc[~received_oral_iron, 'hemoglobin_exposure'] ==
        first_trimester_anc_pop.loc[~received_oral_iron, 'initial_hemoglobin']
    ).all(), "hemoglobin modified in simulants not receiving oral iron"

    # Initial hemoglobin is the GBD draw minus the population-average IFA effect.
    assert np.allclose(
        first_trimester_anc_pop['raw_hemoglobin'] - first_trimester_anc_pop['initial_hemoglobin'],
        ifa_coverage * ifa_effect_hgb,
        rtol=0,
        atol=1e-13,
    ), "IFA deletion does not match expectation from recalculation"

    assert sim_step_name() == "later_pregnancy_screening"
    assert (~first_trimester_anc_pop.filter(like='screening_coverage')).all().all(), "Screening occurred before screening timestep"
    assert first_trimester_anc_pop.anemia_status_during_pregnancy.isnull().all(), "Anemia status during pregnancy assigned before screening timestep"

    # ---- Later-pregnancy screening ----
    sim.step() # Step past screening
    pop_after_screening = sim.get_population()
    pop_after_screening = pd.concat([
        pop_after_screening.drop(columns=['hemoglobin_exposure']),
        sim.get_value('hemoglobin.exposure')(pop_after_screening.index).rename('hemoglobin_exposure'),
    ], axis=1)

    assert pop_after_screening["hemoglobin_screening_coverage"].any(), "No baseline hemoglobin screening"
    if scenario not in screening_scenarios:
        assert not pop_after_screening["ferritin_screening_coverage"].any(), "Ferritin screening when not scaled up"
    else:
        # Scale-up: hemoglobin screening covers exactly the later-pregnancy ANC attendees (both directions).
        attended_later_pregnancy = ~pop_after_screening.anc_attendance.isin(['none', 'first_trimester_only'])
        assert (
            pop_after_screening.hemoglobin_screening_coverage == attended_later_pregnancy
        ).all(), "Hemoglobin screening not fully scaled up, or applied to those not attending later-pregnancy ANC"
        # NOTE: ferritin_screening_coverage column is not correct!
        assert (
            (pop_after_screening.tested_hemoglobin == 'low')
            ==
            (pop_after_screening.tested_ferritin != 'not_tested')
        ).all(), "Ferritin screening not fully scaled up, or applied to those not testing low hemoglobin!"
        # NOTE: anemia_status_during_pregnancy column is not updated!
        # https://github.com/ihmeuw/vivarium_gates_mncnh/blob/98980208ebcf8c4f19fa121f6ca20eba56357840/src/vivarium_gates_mncnh/components/screening.py#L90
        anemia_status_during_pregnancy = (
            pd.cut(
                pop_after_screening['hemoglobin_exposure'],
                bins=[-np.inf, 70, 100, 110],
                labels=["severe", "moderate", "mild"],
                right=False,
            )
            .astype("object")
            .fillna("not_anemic")
        )
        age_cutoffs = sorted(list(set(art_low_ferritin_probs.index.get_level_values('age_start')) | set(art_low_ferritin_probs.index.get_level_values('age_end'))))
        age_groups = (
            pd.cut(
                pop_after_screening['age'],
                bins=age_cutoffs,
                right=False,
            )
        )
        sim_ferritin_results = (
            pop_after_screening.assign(
                anemia_status_during_pregnancy=anemia_status_during_pregnancy,
                age_start=age_groups.cat.categories.left[age_groups.cat.codes],
                age_end=age_groups.cat.categories.right[age_groups.cat.codes],
            )
                [pop_after_screening.tested_ferritin != 'not_tested']
                .groupby(['age_start', 'age_end', 'anemia_status_during_pregnancy'])
                .tested_ferritin.value_counts(normalize=True).loc[(slice(None), slice(None), slice(None), 'low')]
        )
        comparison = art_low_ferritin_probs.rename('target').to_frame().join(sim_ferritin_results.rename('sim'))
        print('Percent of those tested who have low ferritin:')
        display(comparison)

    # TODO: Why is this not working?
    # assert pop_after_screening.anemia_status_during_pregnancy.notnull().all()

    # True status uses the 100 g/L threshold: < 100 is low, >= 100 is adequate (as documented above).
    # https://vivarium-research.readthedocs.io/en/latest/models/intervention_models/mncnh_pregnancy/anemia_screening.html#hemoglobin-screening-accuracy-instructions
    print('Specificity target (percent of true negatives that test negative): 80%')
    print('Observed:')
    display(pop_after_screening[
        (pop_after_screening.hemoglobin_exposure >= 100) &
        (pop_after_screening.hemoglobin_screening_coverage)
    ].tested_hemoglobin.value_counts(dropna=False, normalize=True))

    print('Sensitivity target (percent of true positives that test positive): 85%')
    print('Observed:')
    display(pop_after_screening[
        (pop_after_screening.hemoglobin_exposure < 100) &
        (pop_after_screening.hemoglobin_screening_coverage)
    ].tested_hemoglobin.value_counts(dropna=False, normalize=True))

    # ---- Later-pregnancy intervention ----
    screening_step_hemoglobin_exposure = sim.get_value('hemoglobin.exposure')(pop_after_screening.index)
    assert sim_step_name() == 'later_pregnancy_intervention'
    sim.step()
    pop_after_later_pregnancy_intervention = sim.get_population()

    oral_iron_changed = pop_after_later_pregnancy_intervention.oral_iron_intervention != pop_after_screening.oral_iron_intervention
    later_oral_iron_changes = pd.DataFrame({
        'oral_iron_before': pop_after_screening[oral_iron_changed].oral_iron_intervention,
        'oral_iron_after': pop_after_later_pregnancy_intervention[oral_iron_changed].oral_iron_intervention,
    })
    assert (later_oral_iron_changes.oral_iron_before == 'no_treatment').all()
    if scenario in ifa_scenarios:
        assert (later_oral_iron_changes.oral_iron_after == 'ifa').all()
    elif scenario == mms_scenario:
        assert (later_oral_iron_changes.oral_iron_after == 'mms').all()
    print('Oral iron changes at later ANC:')
    display(later_oral_iron_changes)

    assert (pop_after_screening[oral_iron_changed].anc_attendance == 'later_pregnancy_only').all()

    iv_iron_received = pop_after_later_pregnancy_intervention.iv_iron_intervention == 'covered'

    if scenario == iv_iron_scenario:
        assert (
            (pop_after_later_pregnancy_intervention.tested_ferritin == 'low') ==
            iv_iron_received
        ).all(), "IV iron not fully scaled up, or given to people who did not test low ferritin"
    else:
        assert not iv_iron_received.any(), "baseline IV iron"

    assert pop_after_screening[iv_iron_received].anc_attendance.isin(['first_trimester_and_later_pregnancy', 'later_pregnancy_only']).all()

    hemoglobin_exposure_after = sim.get_value('hemoglobin.exposure')(pop_after_later_pregnancy_intervention.index)
    assert (
        screening_step_hemoglobin_exposure[~oral_iron_changed & ~iv_iron_received] ==
        hemoglobin_exposure_after[~oral_iron_changed & ~iv_iron_received]
    ).all(), "Hemoglobin exposure changed at later intervention timestep for those not receiving later intervention"

    assert np.allclose(
        hemoglobin_exposure_after[oral_iron_changed & ~iv_iron_received] -
        screening_step_hemoglobin_exposure[oral_iron_changed & ~iv_iron_received],
        ifa_effect_hgb,
        rtol=0,
        atol=1e-13,
    ), "IFA effect does not match expectation from recalculation"

    assert np.allclose(
        hemoglobin_exposure_after[~oral_iron_changed & iv_iron_received] -
        screening_step_hemoglobin_exposure[~oral_iron_changed & iv_iron_received],
        iv_iron_effect_hgb,
        rtol=0,
        atol=1e-13,
    ), "IV iron effect does not match expectation from recalculation"

    # https://vivarium-research.readthedocs.io/en/latest/models/concept_models/vivarium_mncnh_portfolio/hemoglobin_module/module_document.html#end-of-pregnancy-hemoglobin-module
    assert np.allclose(
        hemoglobin_exposure_after[oral_iron_changed & iv_iron_received] -
        screening_step_hemoglobin_exposure[oral_iron_changed & iv_iron_received],
        iv_iron_effect_hgb,
        rtol=0,
        atol=1e-13,
    ), "IV iron effect combining with IFA effect (it shouldn't)"

    # ---- Outcome timesteps ----
    outcomes = ['maternal_hemorrhage', 'maternal_sepsis_and_other_maternal_infections']
    for outcome in outcomes:
        while sim_step_name() != outcome:
            sim.step()
        # now one more step to advance past
        sim.step()

        stepped_pop = sim.get_population()
        assert (sim.get_value('hemoglobin.exposure')(stepped_pop.index) == hemoglobin_exposure_after).all(), "Changes to hemoglobin after late-pregnancy ANC"
        stepped_pop = pd.concat([
            stepped_pop.drop(columns=['hemoglobin_exposure']),
            raw_hemoglobin,
            initial_hemoglobin,
            sim.get_value('hemoglobin.exposure')(stepped_pop.index).rename('hemoglobin_exposure'),
            sim.get_value(f'hemoglobin_on_{outcome}.relative_risk')(stepped_pop.index)
        ], axis=1)

        outcome_art_rrs = art_rrs.loc[outcome]
        assert np.allclose(
            np.interp(stepped_pop['hemoglobin_exposure'], outcome_art_rrs.index, outcome_art_rrs),
            stepped_pop[f'hemoglobin_on_{outcome}.relative_risk'],
            rtol=0,
            atol=1e-2,
        ), f"{outcome} relative risk doesn't match recalculation"

        # NOTE: This plot is redundant to the check above
        plt.scatter(stepped_pop['raw_hemoglobin'], stepped_pop[f'hemoglobin_on_{outcome}.relative_risk'], s = 20, label='raw hemoglobin exposure')
        plt.scatter(stepped_pop['hemoglobin_exposure'], stepped_pop[f'hemoglobin_on_{outcome}.relative_risk'], s = 20, label='intervention-modified hemoglobin exposure')
        plt.plot(outcome_art_rrs.index, outcome_art_rrs, label='artifact RR curve', color='tab:green')
        plt.legend()
        plt.xlabel('hemoglobin exposure (g/L)')
        plt.ylabel(f'relative risk of {outcome}')
        plt.title('Which hemoglobin exposure measure is affecting RR?')
        plt.show()

        # ok so now let's check if this actually translates to incidence

        # Bin hemoglobin exposure
        stepped_pop['hb_bin'] = pd.qcut(stepped_pop['hemoglobin_exposure'], 30)

        # Calculate rate of outcome == True in each bin
        sim_rate = stepped_pop.groupby('hb_bin')[outcome].value_counts(normalize=True).unstack().fillna(0)

        # Find the quantile bin containing the TMREL. qcut bins are right-closed,
        # so test interval membership rather than left <= x < right.
        tmrel_bin = [b for b in sim_rate.index if tmrel in b][0]
        tmrel_rate = sim_rate.loc[tmrel_bin]

        # Average only this outcome's artifact RRs within each hemoglobin bin.
        art_rrs_by_bin = (
            outcome_art_rrs.reset_index()
            .assign(hb_bin=lambda df: pd.cut(df.parameter, stepped_pop['hb_bin'].cat.categories))
            .groupby('hb_bin', observed=False)[draw]
            .mean()
        )
        art_rrs_by_bin.plot(kind='line', color='tab:orange', label='Artifact')

        plt.xlabel('Hemoglobin Exposure (binned)')
        plt.ylabel('Relative Rate vs. Bin Containing 120 g/L')
        plt.title(f'Relative Rate of {outcome} by Hemoglobin Exposure')
        (sim_rate[True] / tmrel_rate[True]).plot(kind='bar')
        plt.legend()
        plt.xticks(rotation=90)
        plt.grid()
        plt.show()

    stepped_pop = sim.get_population()
    assert (sim.get_value('hemoglobin.exposure')(stepped_pop.index) == hemoglobin_exposure_after).all(), "Changes to hemoglobin after late-pregnancy ANC"
    stepped_pop = pd.concat([
        stepped_pop.drop(columns=['hemoglobin_exposure']),
        raw_hemoglobin,
        initial_hemoglobin,
        sim.get_value('hemoglobin.exposure')(stepped_pop.index).rename('hemoglobin_exposure'),
    ], axis=1)
    return stepped_pop

In [None]:
baseline_final_pop = check_hemoglobin_in_interactive_sim(scenario='baseline')  # baseline scenario checks

In [None]:
# MMS total scale-up scenario checks.
mms_final_pop = check_hemoglobin_in_interactive_sim(scenario='mms_total_scaleup')

In [None]:
# Anemia-screening-only scenario.
# NOTE: This scenario has been removed from the model, so this check no longer runs.
# screening_final_pop = check_hemoglobin_in_interactive_sim('anemia_screening_vv')

# When this check last ran, results generally looked reasonable, but ferritin proportions among those with moderate and severe anemia looked pretty wiggly.

In [None]:
iv_iron_final_pop = check_hemoglobin_in_interactive_sim(scenario='anemia_screening_and_iv_iron_scaleup')  # screening + IV iron scale-up checks

In [None]:
# Final populations side by side; shared column names get a per-scenario suffix.
compare_population = mms_final_pop.join(
    baseline_final_pop,
    lsuffix='_mms_total_scaleup',
    rsuffix='_baseline',
)
compare_population

* CHECK: The same simulants attend ANC (in each ANC attendance category) in the baseline as in the MMS scale-up scenarios.
* CHECK: Simulants whose iron interventions did not change between scenarios, have the same final hemoglobin in the scenarios. Note it would be nice to check this for intermediate hemoglobin too.
* CHECK: Simulants who changed from IFA in baseline to MMS in the MMS scale-up scenario have the same final hemoglobin.
* CHECK: Simulants who changed from no oral iron in baseline to MMS in the MMS scale-up scenario have higher hemoglobin in the MMS scenario by the effect of IFA on hemoglobin.

In [None]:
compare_population['hemoglobin_diff'] = compare_population['hemoglobin_exposure_mms_total_scaleup'] - compare_population['hemoglobin_exposure_baseline']

# Same simulants attend ANC (in each attendance category) in both scenarios.
assert (
    compare_population.anc_attendance_baseline
    ==
    compare_population.anc_attendance_mms_total_scaleup
).all(), "MMS scale up changing ANC attendance"

# Unchanged oral iron -> identical final hemoglobin (same arithmetic runs in both scenarios).
assert (
    compare_population[
        (compare_population.oral_iron_intervention_baseline == compare_population.oral_iron_intervention_mms_total_scaleup)
    ].hemoglobin_diff == 0
).all(), "Scenario differences where oral iron didn't change"

# IFA in baseline -> MMS: MMS shares the IFA hemoglobin effect, so no difference.
assert (compare_population[compare_population.oral_iron_intervention_baseline == 'ifa'].hemoglobin_diff == 0).all(), "IFA effect on hemoglobin different from MMS effect"

# No oral iron in baseline -> MMS: hemoglobin higher by exactly the IFA effect size.
# Use the same tight tolerance as the other recalculation checks in this notebook
# (the default np.allclose tolerance would hide small discrepancies).
assert np.allclose(
    compare_population[
        (compare_population.oral_iron_intervention_baseline == 'no_treatment') &
        (compare_population.oral_iron_intervention_mms_total_scaleup == 'mms')
    ].hemoglobin_diff,
    ifa_effect_hgb,
    rtol=0,
    atol=1e-13,
), "IFA effect on hemoglobin different from MMS effect"
compare_population.groupby('oral_iron_intervention_baseline')['hemoglobin_diff'].describe()