# Combined Sensitivity Analysis

12-panel figure comparing YLL sensitivity (left column), GHG price sensitivity (middle column), and combined GHG+YLL sensitivity (right column).

- Row 1: Food consumption (kcal/person/day)
- Row 2: GHG emissions by food group (GtCO2eq)
- Row 3: Health cost by food group (million YLL)
- Row 4: Objective breakdown (billion USD)

In [None]:
from pathlib import Path

import matplotlib.pyplot as plt
from sensitivity_utils import (
    FONTSIZE_AXIS_LABEL,
    FONTSIZE_PANEL_LABEL,
    FONTSIZE_TICK_LABEL,
    FONTSIZE_TITLE,
    PRETTY_NAMES_HEALTH,
    aggregate_food_groups,
    assign_food_colors,
    extract_combined_scenarios,
    extract_consumption_data,
    extract_ghg_data,
    extract_health_data,
    extract_objective_data,
    extract_scenarios_with_param,
    get_log_ticks,
    load_food_to_group,
    log_scale_zero_position,
    plot_objective_sensitivity,
    plot_stacked_sensitivity,
    prepare_objective_data,
    set_dual_xaxis_labels,
    set_dual_xlabel,
)

In [None]:
# Configuration
PROJECT_ROOT = Path("..").resolve()
CACHE_DIR = Path("cache")  # Relative to notebooks/

# YLL config
YLL_CONFIG_NAME = "yll"
YLL_RESULTS_DIR = PROJECT_ROOT / "results" / YLL_CONFIG_NAME
YLL_PROCESSING_DIR = PROJECT_ROOT / "processing" / YLL_CONFIG_NAME
YLL_CACHE_DIR = CACHE_DIR / YLL_CONFIG_NAME

# GHG config
GHG_CONFIG_NAME = "ghg"
GHG_RESULTS_DIR = PROJECT_ROOT / "results" / GHG_CONFIG_NAME
GHG_PROCESSING_DIR = PROJECT_ROOT / "processing" / GHG_CONFIG_NAME
GHG_CACHE_DIR = CACHE_DIR / GHG_CONFIG_NAME

# Combined GHG+YLL config
GHG_YLL_CONFIG_NAME = "ghg_yll"
GHG_YLL_RESULTS_DIR = PROJECT_ROOT / "results" / GHG_YLL_CONFIG_NAME
GHG_YLL_PROCESSING_DIR = PROJECT_ROOT / "processing" / GHG_YLL_CONFIG_NAME
GHG_YLL_CACHE_DIR = CACHE_DIR / GHG_YLL_CONFIG_NAME

# Load food to group mapping
FOOD_TO_GROUP = load_food_to_group(PROJECT_ROOT)

# Constants
CONSTANT_HEALTH_VALUE_PER_YLL = 10000
CONSTANT_GHG_PRICE = 100
N_WORKERS = 8

## Load YLL Data

In [None]:
# Extract YLL scenarios from config
yll_scenarios = extract_scenarios_with_param(
    PROJECT_ROOT,
    YLL_CONFIG_NAME,
    param_path=["health", "value_per_yll"],
    scenario_prefix="yll_",
)

# Filter to only include scenarios with existing network files
yll_scenarios = [(p, s, f) for p, s, f in yll_scenarios if f.exists()]

print(f"Found {len(yll_scenarios)} YLL scenarios")
yll_param_values = [p for p, _, _ in yll_scenarios]
print(f"YLL values: {yll_param_values}")

In [None]:
# Extract YLL consumption data
df_yll_consumption = extract_consumption_data(
    yll_scenarios,
    FOOD_TO_GROUP,
    YLL_CACHE_DIR / "consumption.csv",
    param_name="yll_value",
    n_workers=N_WORKERS,
)

# Aggregate and prepare for plotting
df_yll_consumption_plot = aggregate_food_groups(df_yll_consumption)
min_yll = df_yll_consumption_plot.index.min()
yll_group_order = (
    df_yll_consumption_plot.loc[min_yll].sort_values(ascending=False).index.tolist()
)
df_yll_consumption_plot = df_yll_consumption_plot[yll_group_order]
yll_colors = assign_food_colors(df_yll_consumption_plot)

print(f"YLL consumption data shape: {df_yll_consumption_plot.shape}")

In [None]:
# Extract YLL objective data
df_yll_obj = extract_objective_data(
    yll_scenarios,
    YLL_CACHE_DIR / "objective_breakdown.csv",
    param_name="yll_value",
    constant_health_value=CONSTANT_HEALTH_VALUE_PER_YLL,
    constant_ghg_price=CONSTANT_GHG_PRICE,
    n_workers=N_WORKERS,
)
df_yll_obj = prepare_objective_data(df_yll_obj)
print(f"YLL objective data shape: {df_yll_obj.shape}")

In [None]:
# Extract YLL GHG emissions data
df_yll_ghg = extract_ghg_data(
    yll_scenarios,
    FOOD_TO_GROUP,
    YLL_CACHE_DIR / "ghg_by_food_group.csv",
    param_name="yll_value",
    n_workers=N_WORKERS,
)

# Aggregate and use same order as consumption plot
df_yll_ghg_plot = aggregate_food_groups(df_yll_ghg)
available_groups = [g for g in yll_group_order if g in df_yll_ghg_plot.columns]
df_yll_ghg_plot = df_yll_ghg_plot[available_groups]

print(f"YLL GHG data shape: {df_yll_ghg_plot.shape}")

In [None]:
# Extract YLL health cost data
df_yll_health = extract_health_data(
    yll_scenarios,
    YLL_PROCESSING_DIR,
    YLL_CACHE_DIR / "health_by_food_group.csv",
    param_name="yll_value",
    n_workers=N_WORKERS,
)

# Aggregate fruits+vegetables to match other panels
df_yll_health_plot = aggregate_food_groups(df_yll_health)

# Use available groups that match the health risk factors
yll_health_groups = [g for g in yll_group_order if g in df_yll_health_plot.columns]
df_yll_health_plot = df_yll_health_plot[yll_health_groups]

print(f"YLL health data shape: {df_yll_health_plot.shape}")

## Load GHG Data

In [None]:
# Extract GHG scenarios from config
ghg_scenarios = extract_scenarios_with_param(
    PROJECT_ROOT,
    GHG_CONFIG_NAME,
    param_path=["emissions", "ghg_price"],
    scenario_prefix="ghg_",
)

# Filter to only include scenarios with existing network files
ghg_scenarios = [(p, s, f) for p, s, f in ghg_scenarios if f.exists()]

print(f"Found {len(ghg_scenarios)} GHG scenarios")
ghg_param_values = [p for p, _, _ in ghg_scenarios]
print(f"GHG prices: {ghg_param_values}")

In [None]:
# Extract GHG consumption data
df_ghg_consumption = extract_consumption_data(
    ghg_scenarios,
    FOOD_TO_GROUP,
    GHG_CACHE_DIR / "consumption.csv",
    param_name="ghg_price",
    n_workers=N_WORKERS,
)

# Aggregate and use same order as YLL plots (for consistent colors across columns)
df_ghg_consumption_plot = aggregate_food_groups(df_ghg_consumption)
available_groups_ghg = [
    g for g in yll_group_order if g in df_ghg_consumption_plot.columns
]
df_ghg_consumption_plot = df_ghg_consumption_plot[available_groups_ghg]

print(f"GHG consumption data shape: {df_ghg_consumption_plot.shape}")

In [None]:
# Extract GHG objective data
df_ghg_obj = extract_objective_data(
    ghg_scenarios,
    GHG_CACHE_DIR / "objective_breakdown.csv",
    param_name="ghg_price",
    constant_health_value=CONSTANT_HEALTH_VALUE_PER_YLL,
    constant_ghg_price=CONSTANT_GHG_PRICE,
    n_workers=N_WORKERS,
)
df_ghg_obj = prepare_objective_data(df_ghg_obj)
print(f"GHG objective data shape: {df_ghg_obj.shape}")

In [None]:
# Extract GHG emissions data for the GHG price scenarios
df_ghg_ghg = extract_ghg_data(
    ghg_scenarios,
    FOOD_TO_GROUP,
    GHG_CACHE_DIR / "ghg_by_food_group.csv",
    param_name="ghg_price",
    n_workers=N_WORKERS,
)

# Aggregate and use same order as YLL plots (for consistent colors across columns)
df_ghg_ghg_plot = aggregate_food_groups(df_ghg_ghg)
available_groups_ghg_ghg = [g for g in yll_group_order if g in df_ghg_ghg_plot.columns]
df_ghg_ghg_plot = df_ghg_ghg_plot[available_groups_ghg_ghg]

print(f"GHG GHG data shape: {df_ghg_ghg_plot.shape}")

In [None]:
# Extract GHG health cost data
df_ghg_health = extract_health_data(
    ghg_scenarios,
    GHG_PROCESSING_DIR,
    GHG_CACHE_DIR / "health_by_food_group.csv",
    param_name="ghg_price",
    n_workers=N_WORKERS,
)

# Aggregate fruits+vegetables to match other panels
df_ghg_health_plot = aggregate_food_groups(df_ghg_health)

# Use available groups that match YLL health data (for consistent colors)
ghg_health_groups = [g for g in yll_health_groups if g in df_ghg_health_plot.columns]
df_ghg_health_plot = df_ghg_health_plot[ghg_health_groups]

print(f"GHG health data shape: {df_ghg_health_plot.shape}")

## Load GHG+YLL Data (Combined)

In [None]:
# Extract combined GHG+YLL scenarios from config
ghg_yll_scenarios_all = extract_combined_scenarios(
    PROJECT_ROOT,
    GHG_YLL_CONFIG_NAME,
    ghg_param_path=["emissions", "ghg_price"],
    yll_param_path=["health", "value_per_yll"],
    scenario_prefix="ghg_yll_",
)

# Filter to only include scenarios with existing network files
# Also exclude experimental scenarios (e.g., "break15" variants)
ghg_yll_scenarios_full = [
    (ghg, yll, s, f)
    for ghg, yll, s, f in ghg_yll_scenarios_all
    if f.exists() and "break" not in s
]

# Convert to format expected by extraction functions: (ghg_price, scenario_name, path)
# Use ghg_price as the primary index for plotting
ghg_yll_scenarios = [(ghg, s, f) for ghg, yll, s, f in ghg_yll_scenarios_full]

# Keep track of yll values for dual x-axis labels
ghg_yll_param_pairs = [(ghg, yll) for ghg, yll, s, f in ghg_yll_scenarios_full]

print(f"Found {len(ghg_yll_scenarios)} GHG_YLL scenarios")
ghg_yll_ghg_values = [ghg for ghg, _ in ghg_yll_param_pairs]
ghg_yll_yll_values = [yll for _, yll in ghg_yll_param_pairs]
print(f"GHG prices: {ghg_yll_ghg_values}")
print(f"YLL values: {ghg_yll_yll_values}")

In [None]:
# Extract GHG_YLL consumption data
df_ghg_yll_consumption = extract_consumption_data(
    ghg_yll_scenarios,
    FOOD_TO_GROUP,
    GHG_YLL_CACHE_DIR / "consumption.csv",
    param_name="ghg_price",
    n_workers=N_WORKERS,
)

# Aggregate and use same order as YLL plots (for consistent colors)
df_ghg_yll_consumption_plot = aggregate_food_groups(df_ghg_yll_consumption)
available_groups_ghg_yll = [
    g for g in yll_group_order if g in df_ghg_yll_consumption_plot.columns
]
df_ghg_yll_consumption_plot = df_ghg_yll_consumption_plot[available_groups_ghg_yll]

print(f"GHG_YLL consumption data shape: {df_ghg_yll_consumption_plot.shape}")

In [None]:
# Extract GHG_YLL objective data
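# For combined sensitivity, use actual values from scenarios (varying with ghg_price)
# NOTE(review): the call below passes the same constant health value and GHG
# price as the single-parameter analyses - confirm that extract_objective_data
# picks up the per-scenario values as this comment intends.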
df_ghg_yll_obj = extract_objective_data(
    ghg_yll_scenarios,
    GHG_YLL_CACHE_DIR / "objective_breakdown.csv",
    param_name="ghg_price",
    constant_health_value=CONSTANT_HEALTH_VALUE_PER_YLL,
    constant_ghg_price=CONSTANT_GHG_PRICE,
    n_workers=N_WORKERS,
)
df_ghg_yll_obj = prepare_objective_data(df_ghg_yll_obj)
print(f"GHG_YLL objective data shape: {df_ghg_yll_obj.shape}")

In [None]:
# Extract GHG_YLL GHG emissions data
df_ghg_yll_ghg = extract_ghg_data(
    ghg_yll_scenarios,
    FOOD_TO_GROUP,
    GHG_YLL_CACHE_DIR / "ghg_by_food_group.csv",
    param_name="ghg_price",
    n_workers=N_WORKERS,
)

# Aggregate and use same order as YLL plots
df_ghg_yll_ghg_plot = aggregate_food_groups(df_ghg_yll_ghg)
available_groups_ghg_yll_ghg = [
    g for g in yll_group_order if g in df_ghg_yll_ghg_plot.columns
]
df_ghg_yll_ghg_plot = df_ghg_yll_ghg_plot[available_groups_ghg_yll_ghg]

print(f"GHG_YLL GHG data shape: {df_ghg_yll_ghg_plot.shape}")

In [None]:
# Extract GHG_YLL health cost data
df_ghg_yll_health = extract_health_data(
    ghg_yll_scenarios,
    GHG_YLL_PROCESSING_DIR,
    GHG_YLL_CACHE_DIR / "health_by_food_group.csv",
    param_name="ghg_price",
    n_workers=N_WORKERS,
)

# Aggregate fruits+vegetables to match other panels
df_ghg_yll_health_plot = aggregate_food_groups(df_ghg_yll_health)

# Use available groups that match YLL health data (for consistent colors)
ghg_yll_health_groups = [
    g for g in yll_health_groups if g in df_ghg_yll_health_plot.columns
]
df_ghg_yll_health_plot = df_ghg_yll_health_plot[ghg_yll_health_groups]

print(f"GHG_YLL health data shape: {df_ghg_yll_health_plot.shape}")

## Combined 12-Panel Figure

In [None]:
# X-axis configuration - derived from scenarios
YLL_XTICKS, YLL_XTICKLABELS = get_log_ticks(yll_param_values)
YLL_XLABEL = "Value per Year of Life Lost [USD/YLL]"

GHG_XTICKS, GHG_XTICKLABELS = get_log_ticks(ghg_param_values)
GHG_XLABEL = "GHG price [USD/tCO2eq]"

# Combined GHG+YLL x-axis: both parameters vary together, so every GHG tick
# also needs the YLL value that is displayed underneath it.
GHG_YLL_XTICKS, GHG_YLL_XTICKLABELS = get_log_ticks(ghg_yll_ghg_values)
GHG_YLL_GHG_VALUES = GHG_YLL_XTICKS.copy()

# Lookup from a scenario's GHG price to the YLL value of that same scenario
tick_to_yll = dict(zip(ghg_yll_ghg_values, ghg_yll_yll_values))

# Scenario pairs with a strictly positive GHG price; the first of them supplies
# the YLL-per-GHG ratio used for ticks that match no scenario exactly.
nonzero_pairs = [
    (g, y) for g, y in zip(ghg_yll_ghg_values, ghg_yll_yll_values) if g > 0
]

GHG_YLL_YLL_VALUES = []
for tick in GHG_YLL_XTICKS:
    if tick in tick_to_yll:
        # Tick coincides with a scenario: reuse that scenario's YLL value
        yll_for_tick = tick_to_yll[tick]
    elif tick == 1 and 0 in tick_to_yll:
        # The zero scenario is drawn at tick 1 on the log scale
        yll_for_tick = tick_to_yll[0]
    elif nonzero_pairs:
        # No matching scenario: scale the tick by the YLL/GHG ratio of the
        # first positive-price scenario (a proportional estimate, not a
        # nearest-neighbour interpolation)
        ratio = nonzero_pairs[0][1] / nonzero_pairs[0][0]
        yll_for_tick = tick * ratio
    else:
        # No positive-price scenario to derive a ratio from
        yll_for_tick = 0
    GHG_YLL_YLL_VALUES.append(yll_for_tick)

print(f"YLL ticks: {YLL_XTICKS} -> {YLL_XTICKLABELS}")
print(f"GHG ticks: {GHG_XTICKS} -> {GHG_XTICKLABELS}")
print(
    f"GHG_YLL ticks: {GHG_YLL_XTICKS} -> GHG {GHG_YLL_GHG_VALUES}, YLL {GHG_YLL_YLL_VALUES}"
)

# Manual label positions for YLL plots
YLL_CONSUMPTION_LABEL_X = {
    "grain": 7,
    "dairy": 30,
    "starchy_vegetable": 5,
    "legumes": 3000,
    "oil": 5,
    "red_meat": 3,
    "sugar": 10,
    "nuts_seeds": 10000,
    "whole_grains": 1000,
    "fruits_vegetables": 300,
    "eggs_poultry": 30,
}

YLL_GHG_LABEL_X = {
    "red_meat": 3,
    "dairy": 30,
    "grain": 7,
    "oil": 10,
    "starchy_vegetable": 5,
    "legumes": 3000,
    "sugar": 5,
    "nuts_seeds": 30000,
    "whole_grains": 500,
    "fruits_vegetables": 3000,
    "eggs_poultry": 30,
}

YLL_OBJ_LABEL_X = {
    "Crop production": 50000,
    "Health burden": 10,
    "GHG cost": 3,
    "Trade": 100000,
    "Consumer values": 5000,
}

# Manual label positions for GHG plots
GHG_CONSUMPTION_LABEL_X = {
    "eggs_poultry": 8,
}

GHG_GHG_LABEL_SKIP = {"legumes", "fruits_vegetables"}

GHG_HEALTH_LABEL_SKIP = {"legumes"}

# GHG_YLL plots - skip small groups
GHG_YLL_LABEL_SKIP = {"legumes"}

In [None]:
# Configuration option
INCLUDE_OBJECTIVE_ROW = True  # Set to False to hide the objective breakdown row

# Create multipanel figure with shared axes within rows
n_rows = 4 if INCLUDE_OBJECTIVE_ROW else 3
fig_height = 9.5 if INCLUDE_OBJECTIVE_ROW else 7.5
fig, axes = plt.subplots(n_rows, 3, figsize=(10.5, fig_height))

# Row 1: Food consumption
# Panel a: YLL food consumption (top-left)
plot_stacked_sensitivity(
    df_yll_consumption_plot,
    yll_colors,
    axes[0, 0],
    xlabel=YLL_XLABEL,
    ylabel="Food consumption [kcal/person/day]",
    panel_label="a",
    x_ticks=YLL_XTICKS,
    x_ticklabels=YLL_XTICKLABELS,
    label_x_positions=YLL_CONSUMPTION_LABEL_X,
    y_max=2400,
)

# Panel b: GHG food consumption (top-middle)
plot_stacked_sensitivity(
    df_ghg_consumption_plot,
    yll_colors,
    axes[0, 1],
    xlabel=GHG_XLABEL,
    ylabel="Food consumption [kcal/person/day]",
    panel_label="b",
    x_ticks=GHG_XTICKS,
    x_ticklabels=GHG_XTICKLABELS,
    label_x_positions=GHG_CONSUMPTION_LABEL_X,
    y_max=2400,
)

# Panel c: GHG_YLL food consumption (top-right)
plot_stacked_sensitivity(
    df_ghg_yll_consumption_plot,
    yll_colors,
    axes[0, 2],
    xlabel="",  # Will be set with dual labels
    ylabel="Food consumption [kcal/person/day]",
    panel_label="c",
    x_ticks=GHG_YLL_XTICKS,
    x_ticklabels=[""] * len(GHG_YLL_XTICKS),  # Will be replaced
    label_skip=GHG_YLL_LABEL_SKIP,
    y_max=2400,
)

# Row 2: GHG emissions
# Panel d: YLL GHG emissions
plot_stacked_sensitivity(
    df_yll_ghg_plot,
    yll_colors,
    axes[1, 0],
    xlabel=YLL_XLABEL,
    ylabel="GHG emissions [GtCO2eq]",
    panel_label="d",
    x_ticks=YLL_XTICKS,
    x_ticklabels=YLL_XTICKLABELS,
    label_x_positions=YLL_GHG_LABEL_X,
    min_height_for_label=0.08,
)

# Panel e: GHG GHG emissions
plot_stacked_sensitivity(
    df_ghg_ghg_plot,
    yll_colors,
    axes[1, 1],
    xlabel=GHG_XLABEL,
    ylabel="GHG emissions [GtCO2eq]",
    panel_label="e",
    x_ticks=GHG_XTICKS,
    x_ticklabels=GHG_XTICKLABELS,
    label_skip=GHG_GHG_LABEL_SKIP,
    min_height_for_label=0.08,
)

# Panel f: GHG_YLL GHG emissions
plot_stacked_sensitivity(
    df_ghg_yll_ghg_plot,
    yll_colors,
    axes[1, 2],
    xlabel="",
    ylabel="GHG emissions [GtCO2eq]",
    panel_label="f",
    x_ticks=GHG_YLL_XTICKS,
    x_ticklabels=[""] * len(GHG_YLL_XTICKS),
    label_skip=GHG_YLL_LABEL_SKIP,
    min_height_for_label=0.08,
)

# Row 3: Health cost by food group
# Panel g: YLL health cost
plot_stacked_sensitivity(
    df_yll_health_plot,
    yll_colors,
    axes[2, 0],
    xlabel=YLL_XLABEL,
    ylabel="Health cost [million YLL]",
    panel_label="g",
    x_ticks=YLL_XTICKS,
    x_ticklabels=YLL_XTICKLABELS,
    min_height_for_label=0.08,
    pretty_names=PRETTY_NAMES_HEALTH,
)

# Panel h: GHG health cost
plot_stacked_sensitivity(
    df_ghg_health_plot,
    yll_colors,
    axes[2, 1],
    xlabel=GHG_XLABEL,
    ylabel="Health cost [million YLL]",
    panel_label="h",
    x_ticks=GHG_XTICKS,
    x_ticklabels=GHG_XTICKLABELS,
    label_skip=GHG_HEALTH_LABEL_SKIP,
    min_height_for_label=0.08,
    pretty_names=PRETTY_NAMES_HEALTH,
)

# Panel i: GHG_YLL health cost
plot_stacked_sensitivity(
    df_ghg_yll_health_plot,
    yll_colors,
    axes[2, 2],
    xlabel="",
    ylabel="Health cost [million YLL]",
    panel_label="i",
    x_ticks=GHG_YLL_XTICKS,
    x_ticklabels=[""] * len(GHG_YLL_XTICKS),
    label_skip=GHG_YLL_LABEL_SKIP,
    min_height_for_label=0.08,
    pretty_names=PRETTY_NAMES_HEALTH,
)

# Row 4: Objective breakdown (optional)
if INCLUDE_OBJECTIVE_ROW:
    # Panel j: YLL objective breakdown
    plot_objective_sensitivity(
        df_yll_obj,
        axes[3, 0],
        xlabel=YLL_XLABEL,
        panel_label="j",
        x_ticks=YLL_XTICKS,
        x_ticklabels=YLL_XTICKLABELS,
        health_value=CONSTANT_HEALTH_VALUE_PER_YLL,
        ghg_price=CONSTANT_GHG_PRICE,
        label_x_positions=YLL_OBJ_LABEL_X,
        highlight_cat="GHG cost",
    )

    # Panel k: GHG objective breakdown
    plot_objective_sensitivity(
        df_ghg_obj,
        axes[3, 1],
        xlabel=GHG_XLABEL,
        panel_label="k",
        x_ticks=GHG_XTICKS,
        x_ticklabels=GHG_XTICKLABELS,
        health_value=CONSTANT_HEALTH_VALUE_PER_YLL,
        ghg_price=CONSTANT_GHG_PRICE,
        highlight_cat="Health burden",
    )

    # Panel l: GHG_YLL objective breakdown
    plot_objective_sensitivity(
        df_ghg_yll_obj,
        axes[3, 2],
        xlabel="",
        panel_label="l",
        x_ticks=GHG_YLL_XTICKS,
        x_ticklabels=[""] * len(GHG_YLL_XTICKS),
        health_value=CONSTANT_HEALTH_VALUE_PER_YLL,
        ghg_price=CONSTANT_GHG_PRICE,
    )

# Add column titles
axes[0, 0].set_title("Health value sensitivity", fontsize=9, fontweight="bold", pad=10)
axes[0, 1].set_title("GHG price sensitivity", fontsize=9, fontweight="bold", pad=10)
axes[0, 2].set_title("Combined sensitivity", fontsize=9, fontweight="bold", pad=10)

# Share y-axis limits within each row
for row in range(n_rows):
    y_min = min(ax.get_ylim()[0] for ax in axes[row, :])
    y_max = max(ax.get_ylim()[1] for ax in axes[row, :])
    for col in range(3):
        axes[row, col].set_ylim(y_min, y_max)

# Remove x-axis labels except for bottom row
for row in range(n_rows - 1):
    for col in range(3):
        axes[row, col].set_xlabel("")
        axes[row, col].set_xticklabels([])

# Set dual-colored x-axis labels for right column (bottom row only)
set_dual_xaxis_labels(
    axes[n_rows - 1, 2],
    GHG_YLL_XTICKS,
    GHG_YLL_GHG_VALUES,
    GHG_YLL_YLL_VALUES,
)
set_dual_xlabel(axes[n_rows - 1, 2])

# Remove y-axis labels from middle and right columns
for row in range(n_rows):
    for col in [1, 2]:
        axes[row, col].set_ylabel("")
        axes[row, col].set_yticklabels([])

# Align y-axis labels horizontally across rows
fig.align_ylabels(axes[:, 0])

plt.tight_layout()

# Add extra bottom margin for dual x-axis label on right column
plt.subplots_adjust(bottom=0.08)

# Save to notebooks/figures/
output_dir = PROJECT_ROOT / "notebooks" / "figures"
output_dir.mkdir(parents=True, exist_ok=True)
output_path = output_dir / "combined_sensitivity.pdf"
plt.savefig(output_path, dpi=300, bbox_inches="tight")
print(f"Saved to: {output_path}")

plt.show()

In [None]:
# Simplified 2-row figure: diet composition + total health/GHG line plots
# Row 1: Diet composition (same as main figure)
# Row 2: Line graphs showing total health cost (MYLL) and total GHG emissions (GtCO2eq)

import numpy as np

# Compute total health cost and GHG emissions for each sensitivity analysis
yll_total_health = df_yll_health_plot.sum(axis=1)  # MYLL
yll_total_ghg = df_yll_ghg_plot.sum(axis=1)  # GtCO2eq

ghg_total_health = df_ghg_health_plot.sum(axis=1)  # MYLL
ghg_total_ghg = df_ghg_ghg_plot.sum(axis=1)  # GtCO2eq

ghg_yll_total_health = df_ghg_yll_health_plot.sum(axis=1)  # MYLL
ghg_yll_total_ghg = df_ghg_yll_ghg_plot.sum(axis=1)  # GtCO2eq

# Colors for the lines
HEALTH_COLOR = "darkblue"
GHG_COLOR = "darkgreen"

# Spine styling
SPINE_LINEWIDTH = 0.5
SPINE_COLOR = "0.7"

# Custom label positions for panel a (move sugar to the right to avoid overlap with red meat)
PANEL_A_LABEL_X = YLL_CONSUMPTION_LABEL_X.copy()
PANEL_A_LABEL_X["sugar"] = 50  # Move sugar label to the right

# Figure size: 180mm width, maintain ~2:1 aspect ratio with 2:1 row heights
fig_width_mm = 180
fig_width_in = fig_width_mm / 25.4  # Convert mm to inches
fig_height_in = fig_width_in * 0.52  # Aspect ratio for 2:1 row heights

# Calculate zero position for log scale based on data
# YLL x-values (for panel d)
yll_x_values = np.array(yll_total_health.index.values)
yll_zero_pos = log_scale_zero_position(yll_x_values)

# GHG x-values (for panels b, e)
ghg_x_values = np.array(ghg_total_health.index.values)
ghg_zero_pos = log_scale_zero_position(ghg_x_values)

# Combined GHG+YLL x-values (for panels c, f)
ghg_yll_x_values = np.array(ghg_yll_total_health.index.values)
ghg_yll_zero_pos = log_scale_zero_position(ghg_yll_x_values)

# Fix YLL_XTICKS for panel d - replace first tick (1) with calculated zero position
YLL_XTICKS_FIXED = [
    yll_zero_pos if x == 1 and i == 0 else x for i, x in enumerate(YLL_XTICKS)
]

# Modify x_ticks for GHG panels to use calculated zero position
GHG_XTICKS_FIXED = [
    ghg_zero_pos if x == 1 and i == 0 else x for i, x in enumerate(GHG_XTICKS)
]
# Start from the existing labels so ticks and labels stay aligned one-to-one;
# both lists are later handed to set_xticks / set_xticklabels together.
GHG_XTICKLABELS_FIXED = list(GHG_XTICKLABELS)
# Add the last data point (500) if not already in ticks. The label is appended
# only together with the tick: appending it unconditionally would leave one
# label too many whenever 500 is already a tick.
if 500 not in GHG_XTICKS_FIXED:
    GHG_XTICKS_FIXED.append(500)
    GHG_XTICKLABELS_FIXED.append("500")

# For panel f, add tick at 500 to show the last data point
GHG_YLL_XTICKS_FIXED = [
    ghg_yll_zero_pos if x == 1 and i == 0 else x for i, x in enumerate(GHG_YLL_XTICKS)
]
# Add the last data point (500) if not already in ticks
if 500 not in GHG_YLL_XTICKS_FIXED:
    GHG_YLL_XTICKS_FIXED.append(500)
GHG_YLL_GHG_VALUES_FIXED = GHG_YLL_XTICKS_FIXED.copy()


# Compute correct YLL values for tick positions using log interpolation
def interpolate_yll_for_ghg(ghg_tick, ghg_values, yll_values, zero_pos):
    """Return the YLL value that corresponds to a GHG tick position.

    The scenario pairs ``(ghg_values[i], yll_values[i])`` define the mapping.
    They may be given in any order; pairs with a non-positive GHG price are
    ignored for interpolation because they cannot be placed on a log axis.

    Args:
        ghg_tick: GHG price at which a YLL value is wanted.
        ghg_values: GHG prices of the scenarios.
        yll_values: YLL values of the same scenarios, in matching order.
        zero_pos: x-position used to draw the zero scenario on the log axis.

    Returns:
        A float:
        - ``0.0`` if ``ghg_tick`` is 0 or equals ``zero_pos``;
        - the scenario's own YLL value if ``ghg_tick`` matches a GHG price;
        - a proportional extrapolation from the smallest/largest positive
          scenario if ``ghg_tick`` lies outside their range;
        - otherwise a log-log interpolation between the two bracketing
          scenarios (linear in YLL if a bracketing YLL value is not positive);
        - ``0.0`` if there is no scenario with a positive GHG price.
    """
    # The zero scenario is labelled 0 regardless of where it is drawn
    if ghg_tick == 0 or ghg_tick == zero_pos:
        return 0.0

    ghg_arr = np.asarray(ghg_values, dtype=float)
    yll_arr = np.asarray(yll_values, dtype=float)

    # Exact match: use that scenario's value (first occurrence wins)
    exact = np.flatnonzero(ghg_arr == ghg_tick)
    if exact.size:
        return float(yll_arr[exact[0]])

    # Keep only positive prices and sort them; searchsorted and the
    # extrapolation anchors below both require ascending order.
    positive = ghg_arr > 0
    if not positive.any():
        return 0.0
    order = np.argsort(ghg_arr[positive], kind="stable")
    ghg_nz = ghg_arr[positive][order]
    yll_nz = yll_arr[positive][order]

    if ghg_tick < ghg_nz[0]:
        # Extrapolate below, keeping the YLL/GHG ratio of the smallest scenario
        return float(yll_nz[0] * (ghg_tick / ghg_nz[0]))
    if ghg_tick > ghg_nz[-1]:
        # Extrapolate above, keeping the YLL/GHG ratio of the largest scenario
        return float(yll_nz[-1] * (ghg_tick / ghg_nz[-1]))

    # ghg_tick lies strictly between two scenario prices here, so the
    # bracketing prices differ and the log difference below is non-zero.
    idx_upper = int(np.searchsorted(ghg_nz, ghg_tick))
    idx_lower = idx_upper - 1
    ghg_low, ghg_high = ghg_nz[idx_lower], ghg_nz[idx_upper]
    yll_low, yll_high = yll_nz[idx_lower], yll_nz[idx_upper]

    # Fractional position of the tick between the brackets on the log axis
    t = (np.log(ghg_tick) - np.log(ghg_low)) / (np.log(ghg_high) - np.log(ghg_low))

    if yll_low > 0 and yll_high > 0:
        log_yll = np.log(yll_low) + t * (np.log(yll_high) - np.log(yll_low))
        return float(np.exp(log_yll))

    # log() is undefined for non-positive YLL values: interpolate linearly
    return float(yll_low + t * (yll_high - yll_low))


# Compute correct YLL values for each GHG tick
GHG_YLL_YLL_VALUES_CORRECT = []
for ghg_tick in GHG_YLL_GHG_VALUES_FIXED:
    yll_val = interpolate_yll_for_ghg(
        ghg_tick, ghg_yll_ghg_values, ghg_yll_yll_values, ghg_yll_zero_pos
    )
    GHG_YLL_YLL_VALUES_CORRECT.append(yll_val)

print(f"YLL zero position: {yll_zero_pos}")
print(f"YLL ticks (fixed): {YLL_XTICKS_FIXED}")
print(f"GHG ticks (fixed): {GHG_XTICKS_FIXED}")
print(f"GHG tick labels (fixed): {GHG_XTICKLABELS_FIXED}")
print(f"GHG_YLL ticks: {GHG_YLL_GHG_VALUES_FIXED}")
print(f"YLL values (corrected): {GHG_YLL_YLL_VALUES_CORRECT}")

# Create 2x3 figure with 2:1 height ratio between rows
fig, axes = plt.subplots(
    2, 3, figsize=(fig_width_in, fig_height_in), gridspec_kw={"height_ratios": [2, 1]}
)

# Row 1: Food consumption (same as before)
plot_stacked_sensitivity(
    df_yll_consumption_plot,
    yll_colors,
    axes[0, 0],
    xlabel=YLL_XLABEL,
    ylabel="Food consumption [kcal/person/day]",
    panel_label="a",
    x_ticks=YLL_XTICKS_FIXED,
    x_ticklabels=YLL_XTICKLABELS,
    label_x_positions=PANEL_A_LABEL_X,
    y_max=2400,
)

plot_stacked_sensitivity(
    df_ghg_consumption_plot,
    yll_colors,
    axes[0, 1],
    xlabel=GHG_XLABEL,
    ylabel="Food consumption [kcal/person/day]",
    panel_label="b",
    x_ticks=GHG_XTICKS_FIXED,
    x_ticklabels=GHG_XTICKLABELS_FIXED,
    label_x_positions=GHG_CONSUMPTION_LABEL_X,
    y_max=2400,
)

plot_stacked_sensitivity(
    df_ghg_yll_consumption_plot,
    yll_colors,
    axes[0, 2],
    xlabel="",
    ylabel="Food consumption [kcal/person/day]",
    panel_label="c",
    x_ticks=GHG_YLL_XTICKS_FIXED,
    x_ticklabels=[""] * len(GHG_YLL_XTICKS_FIXED),
    y_max=2400,
)


# Row 2: Line plots with total health cost and GHG emissions
def plot_totals_line(
    ax,
    x_values,
    health_data,
    ghg_data,
    xlabel,
    panel_label,
    x_ticks,
    x_ticklabels,
    show_legend=False,
):
    """Draw total health cost and total GHG emissions as two lines in one panel.

    Health cost goes on ``ax`` (left y-axis); GHG emissions go on a twin
    y-axis created from ``ax``. The x-axis is log-scaled, so x-values equal to
    0 are moved to the position returned by ``log_scale_zero_position``.

    Args:
        ax: Axes that receives the health line, x-axis setup, grid and legend.
        x_values: Parameter values, one per data point.
        health_data: Object whose ``.values`` holds the health totals.
        ghg_data: Object whose ``.values`` holds the GHG totals.
        xlabel: Label for the x-axis.
        panel_label: Bold panel letter drawn above the top-left corner.
        x_ticks: Tick positions for the x-axis.
        x_ticklabels: Tick labels matching ``x_ticks``.
        show_legend: Whether to add a combined legend for both lines.

    Returns:
        Tuple ``(twin_axis, zero_pos)``: the GHG axis and the x-position
        substituted for zero.
    """
    # Shared line appearance for both series
    line_kwargs = {"linewidth": 1, "markersize": 2}

    # Relocate x == 0 so it can be shown on the log axis
    param_values = np.array(x_values)
    zero_pos = log_scale_zero_position(param_values)
    x_positions = np.where(param_values == 0, zero_pos, x_values)

    # Left axis: health cost
    health_line = ax.plot(
        x_positions,
        health_data.values,
        color=HEALTH_COLOR,
        marker="o",
        label="Health cost",
        **line_kwargs,
    )[0]
    ax.set_xscale("log")
    ax.set_ylabel(
        "Health cost\n[million YLL]", fontsize=FONTSIZE_AXIS_LABEL, color=HEALTH_COLOR
    )
    ax.tick_params(axis="y", labelcolor=HEALTH_COLOR, labelsize=FONTSIZE_TICK_LABEL)
    ax.tick_params(axis="x", labelsize=FONTSIZE_TICK_LABEL)

    # Right (twin) axis: GHG emissions
    twin_axis = ax.twinx()
    ghg_line = twin_axis.plot(
        x_positions,
        ghg_data.values,
        color=GHG_COLOR,
        marker="s",
        label="GHG emissions",
        **line_kwargs,
    )[0]
    twin_axis.set_ylabel(
        "GHG emissions\n[GtCO2eq]", fontsize=FONTSIZE_AXIS_LABEL, color=GHG_COLOR
    )
    twin_axis.tick_params(
        axis="y", labelcolor=GHG_COLOR, labelsize=FONTSIZE_TICK_LABEL
    )

    # X-axis label and ticks
    ax.set_xlabel(xlabel, fontsize=FONTSIZE_AXIS_LABEL)
    ax.set_xticks(x_ticks)
    ax.set_xticklabels(x_ticklabels)

    # Panel letter, placed in axes coordinates just outside the top-left corner
    panel_text_style = {
        "transform": ax.transAxes,
        "fontsize": FONTSIZE_PANEL_LABEL,
        "fontweight": "bold",
        "va": "top",
        "ha": "left",
    }
    ax.text(-0.10, 1.05, panel_label, **panel_text_style)

    # Grid is drawn on the primary axis only, behind the data
    ax.grid(True, alpha=0.3, which="both")
    ax.set_axisbelow(True)

    # One legend covering the lines of both axes
    if show_legend:
        handles = [health_line, ghg_line]
        ax.legend(
            handles,
            [handle.get_label() for handle in handles],
            loc="upper right",
            fontsize=FONTSIZE_TICK_LABEL,
            framealpha=0.9,
        )

    return twin_axis, zero_pos


# Panel d: YLL sensitivity - totals (with fixed zero position)
ax2_yll, yll_zero = plot_totals_line(
    axes[1, 0],
    yll_total_health.index.values,
    yll_total_health,
    yll_total_ghg,
    xlabel=YLL_XLABEL,
    panel_label="d",
    x_ticks=YLL_XTICKS_FIXED,
    x_ticklabels=YLL_XTICKLABELS,
)

# Panel e: GHG sensitivity - totals (with 500 tick)
ax2_ghg, _ = plot_totals_line(
    axes[1, 1],
    ghg_total_health.index.values,
    ghg_total_health,
    ghg_total_ghg,
    xlabel=GHG_XLABEL,
    panel_label="e",
    x_ticks=GHG_XTICKS_FIXED,
    x_ticklabels=GHG_XTICKLABELS_FIXED,
)

# Panel f: Combined sensitivity - totals (with legend)
ax2_combined, combined_zero_pos = plot_totals_line(
    axes[1, 2],
    ghg_yll_total_health.index.values,
    ghg_yll_total_health,
    ghg_yll_total_ghg,
    xlabel="",
    panel_label="f",
    x_ticks=GHG_YLL_XTICKS_FIXED,
    x_ticklabels=[""] * len(GHG_YLL_XTICKS_FIXED),
    show_legend=True,
)

# Add column titles
axes[0, 0].set_title("Health costs priced in", fontsize=FONTSIZE_TITLE, pad=8)
axes[0, 1].set_title("GHG costs priced in", fontsize=FONTSIZE_TITLE, pad=8)
axes[0, 2].set_title("Both priced in simultaneously", fontsize=FONTSIZE_TITLE, pad=8)

# Share y-axis limits for row 1 (consumption)
y_min = min(ax.get_ylim()[0] for ax in axes[0, :])
y_max = max(ax.get_ylim()[1] for ax in axes[0, :])
for col in range(3):
    axes[0, col].set_ylim(y_min, y_max)

# Share y-axis limits for row 2 (health - left axes), starting from 0
health_max = max(axes[1, col].get_ylim()[1] for col in range(3))
for col in range(3):
    axes[1, col].set_ylim(0, health_max)

# Share y-axis limits for row 2 (GHG - right axes), starting from 0
ghg_max = max(ax.get_ylim()[1] for ax in [ax2_yll, ax2_ghg, ax2_combined])
for ax in [ax2_yll, ax2_ghg, ax2_combined]:
    ax.set_ylim(0, ghg_max)

# Share x-axis limits between rows (for each column)
for col in range(3):
    # Get x limits from top row and apply to both rows
    x_lim = axes[0, col].get_xlim()
    axes[1, col].set_xlim(x_lim)

# Remove x-axis labels from top row
for col in range(3):
    axes[0, col].set_xlabel("")
    axes[0, col].set_xticklabels([])

# Set dual-colored x-axis labels for panel f with custom spacing
# Clear default labels
axes[1, 2].set_xticks(GHG_YLL_XTICKS_FIXED)
axes[1, 2].set_xticklabels([])
axes[1, 2].set_xlabel("")

# Add dual tick labels manually with increased spacing
trans = axes[1, 2].get_xaxis_transform()
for x, ghg, yll in zip(
    GHG_YLL_XTICKS_FIXED, GHG_YLL_GHG_VALUES_FIXED, GHG_YLL_YLL_VALUES_CORRECT
):
    # Format values - use 0 for the zero position
    if x == combined_zero_pos:
        ghg_str = "0"
    elif ghg < 1000:
        ghg_str = f"{int(ghg)}"
    else:
        ghg_str = f"{int(ghg/1000)}k"

    if x == combined_zero_pos:
        yll_str = "0"
    elif yll < 1000:
        yll_str = f"{int(round(yll))}"
    else:
        yll_str = f"{int(round(yll/1000))}k"

    # GHG label (top row of tick labels) - shifted down
    axes[1, 2].text(
        x,
        -0.06,
        ghg_str,
        transform=trans,
        ha="center",
        va="top",
        fontsize=FONTSIZE_TICK_LABEL,
        color=GHG_COLOR,
        fontweight="bold",
    )
    # YLL label (bottom row of tick labels) - shifted down
    axes[1, 2].text(
        x,
        -0.18,
        yll_str,
        transform=trans,
        ha="center",
        va="top",
        fontsize=FONTSIZE_TICK_LABEL,
        color=HEALTH_COLOR,
        fontweight="bold",
    )

# Add dual axis labels with more spacing from tick labels
axes[1, 2].text(
    0.5,
    -0.36,
    "GHG price [USD/tCO2eq]",
    transform=axes[1, 2].transAxes,
    ha="center",
    va="top",
    fontsize=FONTSIZE_AXIS_LABEL,
    color=GHG_COLOR,
)
axes[1, 2].text(
    0.5,
    -0.50,
    "Health value [USD/YLL]",
    transform=axes[1, 2].transAxes,
    ha="center",
    va="top",
    fontsize=FONTSIZE_AXIS_LABEL,
    color=HEALTH_COLOR,
)

# Remove y-axis labels from middle column (for both rows)
axes[0, 1].set_ylabel("")
axes[0, 1].set_yticklabels([])
axes[1, 1].set_ylabel("")
axes[1, 1].set_yticklabels([])

# Remove y-axis labels from right column (for both rows)
axes[0, 2].set_ylabel("")
axes[0, 2].set_yticklabels([])
axes[1, 2].set_ylabel("")
axes[1, 2].set_yticklabels([])

# Keep only right-side y-axis label for rightmost column in row 2
ax2_yll.set_ylabel("")
ax2_yll.set_yticklabels([])
ax2_ghg.set_ylabel("")
ax2_ghg.set_yticklabels([])
# ax2_combined keeps its label

# Style all spines with 0.5 linewidth and grey color
all_axes = [*list(axes.flat), ax2_yll, ax2_ghg, ax2_combined]
for ax in all_axes:
    for spine in ax.spines.values():
        spine.set_linewidth(SPINE_LINEWIDTH)
        spine.set_color(SPINE_COLOR)

plt.tight_layout()
plt.subplots_adjust(bottom=0.22, wspace=0.15, hspace=0.25)

# Save
output_path_simple = output_dir / "combined_sensitivity_simple.pdf"
plt.savefig(output_path_simple, dpi=300, bbox_inches="tight")
print(f"Saved to: {output_path_simple}")
print(f"Figure size: {fig_width_mm}mm x {fig_height_in * 25.4:.1f}mm")

plt.show()