# In [None]:
import os
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import itertools

# Global font-size defaults for this large figure. Applied through matplotlib's
# rcParams, so they act as defaults for any text element that does not pass
# its own explicit `fontsize=`.
_FONT_SIZES = {
    "font.size": 20,
    "axes.titlesize": 26,
    "axes.labelsize": 26,
    "xtick.labelsize": 22,
    "ytick.labelsize": 22,
    "legend.fontsize": 20,
}
for _rc_key, _rc_size in _FONT_SIZES.items():
    plt.rcParams[_rc_key] = _rc_size

# Directory the figure is written into; creating it is a no-op if it exists.
os.makedirs("figures", exist_ok=True)

# Robustness results to plot, grouped into named sections. Section insertion
# order is the left-to-right order on the figure, and rows keep their order.
# Each row: (label, point estimate, CI lower bound, CI upper bound).
#
# The (estimate, lower, upper) triple below recurs in several sections: the
# "Doubly Robust" main estimate, the [0.05,0.95] trim, 2018, MICE, the full
# covariate set and seed 42 all report it. This is presumably the shared
# baseline specification; confirm against the analysis that produced it.
_BASELINE_DR = (19.58, 18.81, 20.32)

sections = {
    "Main Estimators": [
        ("Matching (caliper)", 20.93, 20.51, 21.39),
        ("Doubly Robust", *_BASELINE_DR),
        ("IPW", 15.61, 7.37, 24.11),
        ("T-learner", 20.87, 20.05, 21.69),
    ],
    "Covariate Adjustment": [
        ("Naive (Unadjusted)", 67.80, 67.21, 68.39),
        ("Adjusted", 15.79, 15.19, 16.39),
    ],
    "Trimming Sensitivity": [
        ("None", 23.63, 21.59, 25.73),
        ("[0.01,0.99]", 20.49, 18.95, 22.20),
        ("[0.05,0.95]", *_BASELINE_DR),
        ("[0.10,0.90]", 17.94, 16.69, 19.17),
    ],
    "Matching Methods": [
        ("Caliper (PSM)", 20.93, 20.51, 21.39),
        ("KNN k=1", 23.03, 22.49, 23.56),
        ("KNN k=3", 23.01, 22.47, 23.54),
        ("KNN k=5", 23.01, 22.46, 23.54),
    ],
    "Temporal Stability": [
        ("2018", *_BASELINE_DR),
        ("2015", 26.62, 24.91, 28.34),
        ("2012", 21.72, 19.93, 23.35),
        ("2009", 21.84, 20.34, 23.39),
    ],
    "Imputation": [
        ("MICE+mean/mode", *_BASELINE_DR),
        # NOTE(review): estimate is 0.10 above its lower bound but 1.19 below
        # its upper bound. Confirm this interval is not a transcription slip.
        ("Complete-case", 23.07, 22.97, 24.26),
        ("Simple mean/mode", 19.30, 18.58, 20.01),
    ],
    "Ambiguous Covariates": [
        ("Full set", *_BASELINE_DR),
        ("Excl. home env", 21.55, 20.88, 22.24),
    ],
    "Random Seeds": [
        # NOTE(review): seeds 0 and 10 also sit ~0.1 above their lower bound
        # and ~1.0 below their upper bound. Confirm these intervals.
        ("Seed 0", 21.07, 20.93, 22.12),
        ("Seed 10", 20.92, 20.81, 21.96),
        ("Seed 42", *_BASELINE_DR),
    ],
    "Placebo Tests": [
        ("Rand. Treatment", 0.03, -0.70, 0.71),
        ("Rand. Outcome", -0.15, -3.96, 3.79),
        ("Both Randomized", -0.33, -2.35, 1.65),
    ],
}

# Lay the rows out along x and give each section a background colour.
# Points are 2 units apart, with an extra 2-unit gap between sections.
# `section_bounds` records (x of first point, x of last point, section name).
#
# tab10 provides ten clearly distinct hues. tab20 interleaves dark/light pairs
# of the same hue, which made neighbouring sections' low-alpha bands look
# identical.
color_cycle = itertools.cycle(plt.cm.tab10.colors)

labels, estimates, los, his, x_positions, section_bounds = [], [], [], [], [], []
section_colors = {}
x = 0
for sec_name, items in sections.items():
    # Assign the colour first so every section still gets a legend entry.
    section_colors[sec_name] = next(color_cycle)
    if not items:
        # Nothing to plot: no positions, no band, no extra gap.
        continue
    start = x
    for label, ate, lo, hi in items:
        labels.append(label)
        estimates.append(ate)
        los.append(lo)
        his.append(hi)
        x_positions.append(x)
        x += 2
    # Use the last point's own x (not x - 1, which is one unit past it).
    # A band drawn over [start - pad, end + pad] then sits symmetrically
    # around the section's points.
    end = x_positions[-1]
    section_bounds.append((start, end, sec_name))
    x += 2

# Asymmetric whisker lengths for ax.errorbar, one pair per row:
# the distance from each point estimate down to its CI lower bound,
# and the distance up to its CI upper bound.
lower_err, upper_err = [], []
for point_est, ci_low, ci_high in zip(estimates, los, his):
    lower_err.append(point_est - ci_low)
    upper_err.append(ci_high - point_est)

# Draw the robustness overview (point estimates + CI whiskers, one shaded band
# per section, value labels, section legend) and save it as a PNG.

# Bigger / taller figure
fig, ax = plt.subplots(figsize=(26, 18), dpi=150)

# Opacity of the per-section background bands. The legend swatches reuse the
# same value so each swatch shows exactly the shade of the band it labels.
band_alpha = 0.10

# Point estimates with asymmetric CI whiskers, all drawn in black.
ax.errorbar(
    x_positions, estimates, yerr=[lower_err, upper_err],
    fmt='o', ms=12, lw=2.5, elinewidth=3, capsize=10,
    color='black'
)

# Dashed reference line at a treatment effect of zero.
ax.axhline(y=0, linestyle='--', linewidth=1.6, color='gray')

# Y limits:
# - bottom: at least 10 units below the lowest CI bound, and never above -5;
# - top: 30% of the (bottom .. highest bound) span of headroom, so the value
#   labels above the whiskers fit.
ymin = min(min(los) - 10, -5)
ymax_data = max(his)
yrange = ymax_data - ymin
ymax = ymax_data + 0.30 * yrange
ax.set_ylim(ymin, ymax)

# Light shaded band behind each section, half a unit beyond its recorded bounds.
for start, end, sec in section_bounds:
    ax.axvspan(start - 0.5, end + 0.5,
               facecolor=section_colors[sec], alpha=band_alpha)

# X ticks: one rotated label per row.
ax.set_xticks(x_positions)
ax.set_xticklabels(labels, rotation=45, ha='right', fontsize=22)

# Labels and title
ax.set_ylabel("Estimated Treatment Effect", fontsize=26)
ax.set_title("95% Confidence Intervals and ATEs — Robustness Overview", fontsize=28, fontweight="bold")

# Annotate each estimate. Non-negative values are labelled above the upper
# whisker; negative values below the lower whisker. A thin leader line points
# back to the estimate.
for xi, yi, lo, hi in zip(x_positions, estimates, los, his):
    if yi >= 0:
        y_text = hi + 0.05 * yrange
        va = 'bottom'
    else:
        y_text = lo - 0.05 * yrange
        va = 'top'
    ax.annotate(f"{yi:.2f}", xy=(xi, yi), xytext=(xi, y_text),
                textcoords='data', ha='center', va=va, fontsize=16,
                arrowprops=dict(arrowstyle='-', lw=0.8), clip_on=False)

# Legend: one colour swatch per section, placed outside the axes on the right.
patches = [mpatches.Patch(color=section_colors[sec], alpha=band_alpha, label=sec)
           for sec in sections.keys()]
ax.legend(handles=patches, bbox_to_anchor=(1.02, 1), loc='upper left', title="Sections")

fig.tight_layout()
out_path = "figures/robustness_ci_overview_blackcis_taller.png"
# Save via the explicit figure handle rather than pyplot's implicit
# "current figure", consistent with the fig/ax usage above.
fig.savefig(out_path, bbox_inches="tight")
print(out_path)
