# HMB Intervention Package

In [13]:


import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import sciris as sc
import ipywidgets as widgets
from IPython.display import display, clear_output
import os

# Set up directories (paths are relative to the current working directory).
# Only `outfolder` is read in this section; the two stochastic paths are
# defined here but not referenced below.
outfolder, outfolder_stochastic, plotfolder_stochastic = (
    'results_extended/',
    'results_stochastic_extended/',
    'figures_stochastic_extended/',
)


def set_font(size=None):
    """Switch matplotlib's global font family to Arial.

    Args:
        size: optional base font size in points. When falsy (None or 0) the
            existing ``font.size`` rcParam is left unchanged.
    """
    rc = plt.rcParams
    rc['font.family'] = 'Arial'
    if not size:
        return
    rc['font.size'] = size

def load_all_results():
    """Load the pre-computed statistics produced by the parameter sweep.

    Reads ``<outfolder>uptake-sweep_results-stats.obj`` with ``sc.loadobj``.

    Returns:
        The loaded object, or ``None`` (after printing a hint) when the file
        does not exist.

    NOTE(review): ``sc.loadobj`` is, as far as I know, pickle-based -- only
    point it at result files you produced yourself.
    """
    path = outfolder + 'uptake-sweep_results-stats.obj'
    if not os.path.exists(path):
        print(f"Results file not found at {path}")
        print("Please run the 'run_coverage_sweep' section of your main script first.")
        return None
    return sc.loadobj(path)

def _plot_mean_band(ax, stats, res, si, years, label, color, lw, linestyle='-'):
    """Plot ``stats[res]['mean']`` (x100) over ``years`` and shade its lower/upper band.

    All three series are sliced from index ``si`` so they line up with ``years``.
    """
    mean = stats[res]['mean'][si:] * 100
    lower = stats[res]['lower'][si:] * 100
    upper = stats[res]['upper'][si:] * 100
    ax.plot(years, mean, label=label, color=color, linewidth=lw, linestyle=linestyle)
    ax.fill_between(years, lower, upper, color=color, alpha=0.2)


def plot_interactive_results(stats_baseline, stats_scenario, scenario_name,
                             selected_outcomes, fixed_scale=False,
                             show_baseline=True, *, start_year=2000,
                             plot_start_year=2020, intervention_year=2026):
    """Plot baseline vs. one scenario for each selected outcome in a subplot grid.

    For every outcome the ``'mean'`` series is drawn as a line and the
    ``'lower'``/``'upper'`` series as a shaded band. All three are multiplied
    by 100 for display as percentages. Only years from ``plot_start_year``
    onward are shown, and a dashed vertical line marks ``intervention_year``.

    Args:
        stats_baseline: mapping outcome key -> {'mean', 'lower', 'upper'}
            arrays for the baseline run, or None. Presumably the arrays hold
            proportions in [0, 1], since they are scaled by 100 -- confirm upstream.
        stats_scenario: same structure for the scenario run, or None.
        scenario_name: legend label and suptitle text for the scenario.
        selected_outcomes: outcome keys to plot, one subplot each (max 3 per row).
        fixed_scale: if True, pin every y-axis to 0-100.
        show_baseline: if False, the baseline curves are not drawn.
        start_year: calendar year of index 0 in the stats arrays.
        plot_start_year: first year shown on the x-axis.
        intervention_year: year at which the intervention marker is drawn.

    Returns:
        The matplotlib Figure, or None when ``selected_outcomes`` is empty.
    """
    # Outcome labels mapping
    outcome_labels = {
        'hiud': 'hIUD Usage',
        'pill': 'Pill Usage',
        'hmb': 'HMB Prevalence',
        'poor_mh': 'Poor Menstrual Health',
        'anemic': 'Anemia',
        'pain': 'Menstrual Pain'
    }

    n_outcomes = len(selected_outcomes)
    if n_outcomes == 0:
        print("Please select at least one outcome to visualize.")
        return None

    # Time axis: index 0 is `start_year`. The length comes from the data so a
    # different simulation end year does not break plotting. Fall back to
    # 2000-2032 (33 years) only when there is no data to measure.
    reference = stats_scenario if stats_scenario is not None else stats_baseline
    n_years = len(reference[selected_outcomes[0]]['mean']) if reference is not None else 33
    years_full = np.arange(start_year, start_year + n_years)
    si = int(np.searchsorted(years_full, plot_start_year))
    years = years_full[si:]

    # Colors
    color_baseline = '#6c757d'  # gray
    color_scenario = '#3c6e71'  # dark teal

    n_cols = min(3, n_outcomes)
    n_rows = int(np.ceil(n_outcomes / n_cols))

    set_font(14)
    # squeeze=False always yields a 2-D array of axes, even for one subplot.
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(5.5*n_cols, 4.5*n_rows),
                             squeeze=False)
    axes = axes.ravel()

    lw = 2.5

    for i, res in enumerate(selected_outcomes):
        ax = axes[i]

        if show_baseline and stats_baseline is not None:
            _plot_mean_band(ax, stats_baseline, res, si, years, 'Baseline',
                            color_baseline, lw, linestyle='--')

        if stats_scenario is not None:
            _plot_mean_band(ax, stats_scenario, res, si, years, scenario_name,
                            color_scenario, lw)

        # Intervention start marker
        ax.axvline(x=intervention_year, color='k', ls='--', linewidth=1.5, alpha=0.7)

        # Label the marker on the first subplot only
        if i == 0:
            if fixed_scale:
                label_height = 90
            else:
                ax.autoscale()
                y_lo, y_hi = ax.get_ylim()
                label_height = y_lo + (y_hi - y_lo) * 0.85

            ax.text(intervention_year - 0.5, label_height, 'Intervention\nstart',
                    ha='right', va='top', fontsize=9, color='#4d4d4d')

        ax.set_title(outcome_labels.get(res, res), fontsize=12, fontweight='bold')

        if fixed_scale:
            ax.set_ylim(0, 100)

        # Whole-number years on the x-axis
        ax.xaxis.set_major_locator(plt.MaxNLocator(integer=True))

        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)

        if i % n_cols == 0:
            ax.set_ylabel('Prevalence (%)')

        # Label the x-axis on every subplot with no visible subplot below it.
        # In a partially-filled last row this includes the last-but-one row.
        if i + n_cols >= n_outcomes:
            ax.set_xlabel('Year')

    # Hide unused subplots
    for j in range(n_outcomes, len(axes)):
        axes[j].set_visible(False)

    # Leave room on the right for the figure-level legend
    plt.tight_layout()
    plt.subplots_adjust(right=0.82)

    handles, labels = axes[0].get_legend_handles_labels()
    fig.legend(handles, labels, loc='center right', bbox_to_anchor=(0.98, 0.5),
               fontsize=11, frameon=True, fancybox=True)

    fig.suptitle(f'Scenario: {scenario_name}', fontsize=14, fontweight='bold', y=1.02)
    plt.show()

    return fig


# -- LOAD DATA
# `all_results` is None when the sweep output file is missing; the widget
# callbacks below check for that before plotting.
all_results = load_all_results()

if all_results is None:
    print("Could not load results. Please check the file path.")
else:
    print(f"Loaded {len(all_results)} scenarios")
    print(f"Available scenarios: {list(all_results.keys())}")


# -- MAKE VIZ

# Sweep values for the two intervention probabilities (as fractions).
prob_offer_values = [0.25, 0.5, 0.75]
prob_accept_values = [0.25, 0.5, 0.75]

# Outcome keys offered by the widgets, and their display labels.
all_outcomes = ['hiud', 'pill', 'hmb', 'poor_mh', 'anemic', 'pain']
outcome_labels_full = {
    'hiud': 'hIUD Usage',
    'pill': 'Pill Usage',
    'hmb': 'HMB Prevalence',
    'poor_mh': 'Poor Menstrual Health',
    'anemic': 'Anemia',
    'pain': 'Menstrual Pain',
}


def _percent_radio(values, description):
    """Return RadioButtons offering each fraction in *values*, labelled as '<n>%'.

    The initial selection is 0.5.
    """
    return widgets.RadioButtons(
        options=[(f'{int(v * 100)}%', v) for v in values],
        value=0.5,
        description=description,
        style={'description_width': '100px'},
        layout=widgets.Layout(width='200px'),
    )


# Parameter pickers for the single-scenario view.
prob_offer_radio = _percent_radio(prob_offer_values, 'Prob Offer:')
prob_accept_radio = _percent_radio(prob_accept_values, 'Prob Accept:')

# Y-axis scaling choice: False = autoscale per subplot, True = fixed 0-100%.
scale_radio = widgets.RadioButtons(
    options=[('Variable Scale', False), ('Fixed 0-100%', True)],
    value=False,
    description='Y-Axis:',
    style={'description_width': '100px'},
    layout=widgets.Layout(width='200px'),
)

show_baseline_checkbox = widgets.Checkbox(
    value=True,
    description='Show Baseline',
    style={'description_width': '100px'},
)

# One checkbox per outcome, in `all_outcomes` order. hIUD and pill usage start
# unticked; the other four start ticked.
_outcomes_ticked_initially = {'hmb', 'poor_mh', 'anemic', 'pain'}
outcome_checkboxes = {
    key: widgets.Checkbox(
        value=key in _outcomes_ticked_initially,
        description=outcome_labels_full[key],
        style={'description_width': 'initial'},
    )
    for key in all_outcomes
}

# Individual handles, used when laying out the outcome panel.
outcome_checkbox_hiud = outcome_checkboxes['hiud']
outcome_checkbox_pill = outcome_checkboxes['pill']
outcome_checkbox_hmb = outcome_checkboxes['hmb']
outcome_checkbox_poor_mh = outcome_checkboxes['poor_mh']
outcome_checkbox_anemic = outcome_checkboxes['anemic']
outcome_checkbox_pain = outcome_checkboxes['pain']

# Redraw trigger for the single-scenario view.
update_btn = widgets.Button(
    description='Update Plot',
    button_style='success',
    layout=widgets.Layout(width='150px', height='40px'),
)

# Area that the single-scenario figure and summary are rendered into.
output = widgets.Output()

# Main update function
def update_plot(b=None):
    """Redraw the single-scenario view from the current widget selections.

    Used as the ``update_btn`` click handler (``b`` is the clicked button and
    is ignored) and also called once directly. Everything is rendered into
    the ``output`` widget:

    * the figure from ``plot_interactive_results``, and
    * a text summary of the % change in the post-intervention average
      (from index 26 onward) relative to baseline for each selected outcome.
    """
    with output:
        clear_output(wait=True)

        if all_results is None:
            print("Results not loaded. Please run the data loading cell first.")
            return

        prob_offer = prob_offer_radio.value
        prob_accept = prob_accept_radio.value
        fixed_scale = scale_radio.value
        show_baseline = show_baseline_checkbox.value

        selected_outcomes = [key for key, cb in outcome_checkboxes.items() if cb.value]
        if not selected_outcomes:
            print("Please select at least one outcome to visualize.")
            return

        # Result keys look like 'scenario_offer-50.0_accept-75.0' (see the key
        # list printed when the results are loaded).
        scenario_name = f'scenario_offer-{prob_offer*100}_accept-{prob_accept*100}'
        display_name = f'Offer: {int(prob_offer*100)}%, Accept: {int(prob_accept*100)}%'

        stats_baseline = all_results.get('baseline', None)
        stats_scenario = all_results.get(scenario_name, None)

        if stats_scenario is None:
            print(f"Scenario '{scenario_name}' not found in results.")
            print(f"Available scenarios: {list(all_results.keys())}")
            return

        plot_interactive_results(
            stats_baseline=stats_baseline,
            stats_scenario=stats_scenario,
            scenario_name=display_name,
            selected_outcomes=selected_outcomes,
            fixed_scale=fixed_scale,
            show_baseline=show_baseline
        )

        print("\n" + "="*60)
        print(f"Summary for {display_name}")
        print("="*60)

        if stats_baseline is None:
            print("  No baseline results available; cannot compute % change.")
            return

        # Index 26 corresponds to 2026 when the series start in 2000.
        intervention_idx = 26

        for res in selected_outcomes:
            label = outcome_labels_full[res]
            baseline_post = np.mean(stats_baseline[res]['mean'][intervention_idx:])
            scenario_post = np.mean(stats_scenario[res]['mean'][intervention_idx:])

            if baseline_post == 0:
                # Relative change is undefined; say so instead of skipping silently.
                print(f"  {label:25s}: n/a (baseline average is zero)")
                continue

            pct_change = ((scenario_post - baseline_post) / baseline_post) * 100
            if pct_change < 0:
                direction = "decrease"
            elif pct_change > 0:
                direction = "increase"
            else:
                direction = "no change"
            print(f"  {label:25s}: {pct_change:+.1f}% ({direction})")

# Redraw the single-scenario view whenever the button is clicked.
update_btn.on_click(update_plot)


def _bordered_panel(title, children):
    """Return a bordered VBox whose first child is an <h3> heading reading *title*."""
    return widgets.VBox(
        [widgets.HTML(f"<h3>{title}</h3>")] + list(children),
        layout=widgets.Layout(padding='10px', border='1px solid #ddd', margin='5px'),
    )


# Control panels, side by side: parameters | outcomes | display options.
param_box = _bordered_panel(
    "Parameter Selection",
    [widgets.HBox([prob_offer_radio, prob_accept_radio])],
)
outcome_box = _bordered_panel("Outcomes to Display", outcome_checkboxes.values())
display_box = _bordered_panel("Display Options", [scale_radio, show_baseline_checkbox])
controls = widgets.HBox([param_box, outcome_box, display_box])

# Full interface: header, controls, centred update button, then the output area.
interface = widgets.VBox([
    widgets.HTML("""
    <h2>HMB Intervention Interactive Analysis</h2>
    <p>Use the controls below to explore different parameter combinations and their effects on health outcomes.</p>
    <hr>
    """),
    controls,
    widgets.HBox([update_btn], layout=widgets.Layout(justify_content='center', padding='10px')),
    output
])

display(interface)

# Draw once so the section is populated on first run.
update_plot()

# -- Multiple Scenarios

# Area that the comparison figure is drawn into.
comparison_output = widgets.Output()


def _scenario_key(offer_pct, accept_pct):
    """Result key for a sweep point, e.g. (25, 75) -> 'scenario_offer-25.0_accept-75.0'."""
    return f'scenario_offer-{float(offer_pct)}_accept-{float(accept_pct)}'


# Scenario checkboxes keyed by their `all_results` key, in display order:
# baseline first, then offer-major / accept-minor. Baseline and the
# (25%, 75%) and (50%, 75%) scenarios start ticked.
_scenarios_ticked_initially = {(25, 75), (50, 75)}
scenario_checkboxes = {
    'baseline': widgets.Checkbox(value=True, description='Baseline',
                                 style={'description_width': 'initial'}),
}
for _offer in (25, 50, 75):
    for _accept in (25, 50, 75):
        scenario_checkboxes[_scenario_key(_offer, _accept)] = widgets.Checkbox(
            value=(_offer, _accept) in _scenarios_ticked_initially,
            description=f'Offer {_offer}%, Accept {_accept}%',
            style={'description_width': 'initial'},
        )
del _offer, _accept

# Named handles, used when laying out the scenario panel.
scenario_checkbox_baseline = scenario_checkboxes['baseline']
scenario_checkbox_25_25 = scenario_checkboxes[_scenario_key(25, 25)]
scenario_checkbox_25_50 = scenario_checkboxes[_scenario_key(25, 50)]
scenario_checkbox_25_75 = scenario_checkboxes[_scenario_key(25, 75)]
scenario_checkbox_50_25 = scenario_checkboxes[_scenario_key(50, 25)]
scenario_checkbox_50_50 = scenario_checkboxes[_scenario_key(50, 50)]
scenario_checkbox_50_75 = scenario_checkboxes[_scenario_key(50, 75)]
scenario_checkbox_75_25 = scenario_checkboxes[_scenario_key(75, 25)]
scenario_checkbox_75_50 = scenario_checkboxes[_scenario_key(75, 50)]
scenario_checkbox_75_75 = scenario_checkboxes[_scenario_key(75, 75)]

# Bulk-toggle buttons for the scenario checkboxes.
select_all_scenarios_btn = widgets.Button(
    description='Select All',
    button_style='primary',
    layout=widgets.Layout(width='100px', height='30px'),
)
deselect_all_scenarios_btn = widgets.Button(
    description='Deselect All',
    button_style='warning',
    layout=widgets.Layout(width='100px', height='30px'),
)

def _set_all_scenarios(checked):
    """Set ``.value`` of every checkbox in ``scenario_checkboxes`` to *checked*."""
    for checkbox in scenario_checkboxes.values():
        checkbox.value = checked


def select_all_scenarios(b):
    """Click handler: tick every scenario checkbox (``b``, the button, is unused)."""
    _set_all_scenarios(True)


def deselect_all_scenarios(b):
    """Click handler: untick every scenario checkbox (``b``, the button, is unused)."""
    _set_all_scenarios(False)

# Wire the bulk-toggle buttons to their handlers.
select_all_scenarios_btn.on_click(select_all_scenarios)
deselect_all_scenarios_btn.on_click(deselect_all_scenarios)

# Outcome picker: shows the display label, yields the outcome key.
outcome_dropdown = widgets.Dropdown(
    options=[(outcome_labels_full[key], key) for key in all_outcomes],
    value='hmb',
    description='Outcome:',
    style={'description_width': '100px'},
)

# Y-axis scaling for the comparison plot: False = autoscale, True = 0-100%.
scale_radio_compare = widgets.RadioButtons(
    options=[('Variable Scale', False), ('Fixed 0-100%', True)],
    value=False,
    description='Y-Axis:',
    style={'description_width': '100px'},
    layout=widgets.Layout(width='200px'),
)

# Trigger for the comparison plot.
compare_btn = widgets.Button(
    description='Compare Scenarios',
    button_style='success',
    layout=widgets.Layout(width='180px', height='40px'),
)

def compare_scenarios(b):
    """Plot one outcome for every ticked scenario on a shared set of axes.

    Click handler for ``compare_btn`` (``b``, the clicked button, is
    ignored). Reads ``scenario_checkboxes``, ``outcome_dropdown`` and
    ``scale_radio_compare``, then draws into ``comparison_output``:

    * each scenario's ``'mean'`` series (x100, shown as %), and
    * its ``'lower'``/``'upper'`` band.

    Baseline is grey and dashed; other scenarios take successive colours from
    a fixed palette. Ticked scenarios absent from ``all_results`` are reported
    and skipped.
    """
    with comparison_output:
        clear_output(wait=True)

        if all_results is None:
            print("Results not loaded.")
            return

        selected_scenarios = [key for key, cb in scenario_checkboxes.items() if cb.value]
        outcome = outcome_dropdown.value
        fixed_scale = scale_radio_compare.value

        if not selected_scenarios:
            print("Please select at least one scenario.")
            return

        # Report missing scenarios instead of silently dropping them, and stop
        # early if nothing remains (avoids an empty figure and legend warning).
        missing = [s for s in selected_scenarios if s not in all_results]
        if missing:
            print(f"Skipping scenarios not found in results: {missing}")
        scenarios_to_plot = [s for s in selected_scenarios if s in all_results]
        if not scenarios_to_plot:
            return

        # Time axis: series are taken to start in 2000 (as elsewhere in this
        # notebook). The length comes from the data; 2020 onward is shown.
        n_years = len(all_results[scenarios_to_plot[0]][outcome]['mean'])
        years_full = np.arange(2000, 2000 + n_years)
        si = int(np.searchsorted(years_full, 2020))
        years = years_full[si:]

        # Rainbow color palette
        rainbow_colors = [
            '#e6194b',  # red
            '#f58231',  # orange
            '#ffe119',  # yellow
            '#3cb44b',  # green
            '#42d4f4',  # cyan
            '#4363d8',  # blue
            '#911eb4',  # purple
            '#f032e6',  # magenta
            '#a9a9a9',  # gray
            '#800000',  # maroon
        ]

        # Baseline is always grey; other scenarios cycle through the palette.
        colors = []
        n_non_baseline = 0
        for scenario in scenarios_to_plot:
            if scenario == 'baseline':
                colors.append('#6c757d')
            else:
                colors.append(rainbow_colors[n_non_baseline % len(rainbow_colors)])
                n_non_baseline += 1

        set_font(14)
        _, ax = plt.subplots(figsize=(14, 6))

        for scenario, color in zip(scenarios_to_plot, colors):
            stats = all_results[scenario][outcome]
            mean = stats['mean'][si:] * 100
            lower = stats['lower'][si:] * 100
            upper = stats['upper'][si:] * 100

            if scenario == 'baseline':
                label, linestyle = 'Baseline', '--'
            else:
                # Keys look like 'scenario_offer-25.0_accept-75.0'.
                offer_str, accept_str = scenario.replace('scenario_offer-', '').split('_accept-')
                label = f'Offer {int(float(offer_str))}%, Accept {int(float(accept_str))}%'
                linestyle = '-'

            ax.plot(years, mean, label=label, color=color, linewidth=2.5, linestyle=linestyle)
            ax.fill_between(years, lower, upper, color=color, alpha=0.15)

        ax.axvline(x=2026, color='k', ls='--', linewidth=1.5, alpha=0.7)

        # Place the "Intervention start" text near the top of the y-range
        if fixed_scale:
            text_height = 90
        else:
            ax.autoscale()
            y_lo, y_hi = ax.get_ylim()
            text_height = y_lo + (y_hi - y_lo) * 0.9

        ax.text(2025.5, text_height, 'Intervention\nstart',
                ha='right', va='top', fontsize=10, color='#4d4d4d')

        ax.set_title(f'{outcome_labels_full[outcome]} - Scenario Comparison',
                     fontsize=14, fontweight='bold')
        ax.set_xlabel('Year')
        ax.set_ylabel('Prevalence (%)')

        if fixed_scale:
            ax.set_ylim(0, 100)

        # Whole-number years on the x-axis
        ax.xaxis.set_major_locator(plt.MaxNLocator(integer=True))

        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)

        # Legend outside the axes on the right
        ax.legend(loc='center left', bbox_to_anchor=(1.02, 0.5), frameon=True, fancybox=True)

        plt.tight_layout()
        plt.subplots_adjust(right=0.72)
        plt.show()

# Redraw the comparison figure whenever the button is clicked.
compare_btn.on_click(compare_scenarios)

# Left panel: bulk-toggle buttons, then every scenario checkbox in the
# insertion order of `scenario_checkboxes`.
scenario_box = widgets.VBox(
    children=[
        widgets.HTML("<b>Select Scenarios:</b>"),
        widgets.HBox([select_all_scenarios_btn, deselect_all_scenarios_btn]),
        *scenario_checkboxes.values(),
    ],
    layout=widgets.Layout(padding='10px', border='1px solid #ddd', margin='5px', width='280px'),
)

# Right panel: outcome picker, y-axis scaling and the compare button.
options_box = widgets.VBox(
    children=[
        widgets.HTML("<b>Options:</b>"),
        outcome_dropdown,
        scale_radio_compare,
        compare_btn,
    ],
    layout=widgets.Layout(padding='10px', border='1px solid #ddd',
                          margin='5px 5px 5px 50px', width='250px'),
)

comparison_interface = widgets.VBox([
    widgets.HTML("""
    <h2>Multi-Scenario Comparison</h2>
    <p>Compare multiple parameter combinations for a single outcome.</p>
    <hr>
    """),
    widgets.HBox([scenario_box, options_box]),
    comparison_output
])

display(comparison_interface)

# -- Summary Table 

# Output area that `generate_summary_table` clears and renders the styled table into.
table_output = widgets.Output()

def generate_summary_table(b=None):
    """Render a table of post-intervention % change vs. baseline for each scenario.

    Bound to ``summary_btn`` (``b``, the clicked button, is unused) and also
    called once directly. For each offer/accept combination present in
    ``all_results``, the following is done for each outcome:

    * average the ``'mean'`` series from index 26 onward (2026, assuming the
      series start in 2000);
    * compare that average with the baseline's average over the same range.

    Negative changes are shown green, positive red. When the baseline average
    is zero the relative change is undefined and shown as ``n/a``. Output is
    rendered into ``table_output``.
    """
    with table_output:
        clear_output(wait=True)

        if all_results is None:
            print("Results not loaded.")
            return
        if 'baseline' not in all_results:
            print("Baseline results not found; cannot compute % change from baseline.")
            return

        prob_offer_values = [0.25, 0.5, 0.75]
        prob_accept_values = [0.25, 0.5, 0.75]
        outcomes = ['hmb', 'poor_mh', 'anemic', 'pain']

        # Index 26 corresponds to 2026 when the series start in 2000.
        intervention_idx = 26

        def post_intervention_mean(stats, outcome):
            """Average of ``stats[outcome]['mean']`` from the intervention index onward."""
            return np.mean(stats[outcome]['mean'][intervention_idx:])

        baseline_avgs = {outcome: post_intervention_mean(all_results['baseline'], outcome)
                         for outcome in outcomes}

        summary_data = []
        for prob_offer in prob_offer_values:
            for prob_accept in prob_accept_values:
                scenario_name = f'scenario_offer-{prob_offer*100}_accept-{prob_accept*100}'
                if scenario_name not in all_results:
                    continue

                row = {
                    'Prob Offer': f'{int(prob_offer*100)}%',
                    'Prob Accept': f'{int(prob_accept*100)}%'
                }
                for outcome in outcomes:
                    baseline_avg = baseline_avgs[outcome]
                    if baseline_avg == 0:
                        # Avoid a divide-by-zero inf/nan; the relative change is undefined.
                        row[outcome_labels_full[outcome]] = 'n/a'
                        continue
                    scenario_avg = post_intervention_mean(all_results[scenario_name], outcome)
                    pct_change = ((scenario_avg - baseline_avg) / baseline_avg) * 100
                    row[outcome_labels_full[outcome]] = f'{pct_change:+.1f}%'

                summary_data.append(row)

        if not summary_data:
            print("No scenario results found for the parameter grid.")
            return

        df = pd.DataFrame(summary_data)

        def style_cells(val, column):
            """CSS for one cell.

            Parameter columns are teal; signed '%' strings are green when
            negative and red when positive; anything else (e.g. 'n/a') is unstyled.
            """
            if column in ['Prob Offer', 'Prob Accept']:
                return 'color: #3c6e71; font-weight: bold'
            elif isinstance(val, str) and '%' in val:
                num = float(val.replace('%', '').replace('+', ''))
                if num < 0:
                    return 'color: green; font-weight: bold'
                elif num > 0:
                    return 'color: red'
            return ''

        # Apply styling column by column so each cell knows its column name
        styled_df = df.style.apply(lambda col: [style_cells(v, col.name) for v in col], axis=0)

        print("Summary: % Change from Baseline (Post-Intervention Average)")
        print("="*80)
        print("Statistics calculated from start of intervention (2026) to end of simulation (2032)")
        print("Green = Reduction (Improvement), Red = Increase")
        print()
        display(styled_df)

# Button that (re)builds the summary table into `table_output`.
summary_btn = widgets.Button(
    description='Generate Summary Table',
    button_style='info',
    layout=widgets.Layout(width='200px', height='40px'),
)
summary_btn.on_click(generate_summary_table)

# Heading, button and table stacked vertically.
summary_interface = widgets.VBox(children=[
    widgets.HTML("""
    <h2>Summary Statistics Table</h2>
    <p>View percentage change from baseline for all parameter combinations.<br>
    Statistics are calculated as the average from the start of intervention (2026) to the end of simulation (2032).</p>
    <hr>
    """),
    summary_btn,
    table_output,
])
display(summary_interface)

# Render the table once so the section is populated on first run.
generate_summary_table()

Loaded 10 scenarios
Available scenarios: ['baseline', 'scenario_offer-25.0_accept-25.0', 'scenario_offer-25.0_accept-50.0', 'scenario_offer-25.0_accept-75.0', 'scenario_offer-50.0_accept-25.0', 'scenario_offer-50.0_accept-50.0', 'scenario_offer-50.0_accept-75.0', 'scenario_offer-75.0_accept-25.0', 'scenario_offer-75.0_accept-50.0', 'scenario_offer-75.0_accept-75.0']


VBox(children=(HTML(value='\n    <h2>HMB Intervention Interactive Analysis</h2>\n    <p>Use the controls below…

VBox(children=(HTML(value='\n    <h2>Multi-Scenario Comparison</h2>\n    <p>Compare multiple parameter combina…

VBox(children=(HTML(value='\n    <h2>Summary Statistics Table</h2>\n    <p>View percentage change from baselin…