# In [5]:
import numpy as np
import matplotlib.pyplot as plt

# ------------------ Paper-wide style ------------------
# Shared matplotlib defaults, applied once at import time and grouped by concern.
FIGSIZE = (7.0, 4.8)  # consistent size for both figures

# Canvas and export resolution.
plt.rcParams.update({"figure.figsize": FIGSIZE, "savefig.dpi": 300})

# Typography.
plt.rcParams.update({
    "font.family": "DejaVu Sans",
    "font.size": 11,
    "axes.titlesize": 13,
    "axes.labelsize": 12,
    "xtick.labelsize": 11,
    "ytick.labelsize": 11,
    "legend.fontsize": 10,
})

# Faint solid grid on every axes.
plt.rcParams.update({
    "axes.grid": True,
    "grid.linestyle": "-",
    "grid.alpha": 0.25,
})

# Open frame: hide the top and right spines.
plt.rcParams.update({
    "axes.spines.top": False,
    "axes.spines.right": False,
})

# =========================================================
# ============= Critical policy constructors ==============
# (NA uses ε-critical; AD (ε=0) intentionally omitted here)
# =========================================================

def rho(payment_k: float, penalty_k: float, u: float) -> float:
    """
    Solve ``payment_k - p * penalty_k = u`` for ``p``.

    The result is returned as-is: it is not clipped to [0, 1] and there is no
    guard against ``penalty_k == 0``.
    """
    surplus = payment_k - u
    return surplus / penalty_k

def _equalize(u: float, A: set[int], eps: float,
              payment: np.ndarray, penalty: np.ndarray) -> np.ndarray:
    """
    Build a probability vector for the NA critical-face construction.

    Let ``first`` be the smallest index whose payment is >= u (or ``len(payment)``
    when there is none).  Entries below ``first`` stay 0.  From ``first`` onward,
    an index in ``A`` gets ``rho(payment, penalty, u)`` and every other index
    gets ``rho(payment, penalty, u - eps)``.  The vector is finally clipped to
    [0, 1].

    NOTE(review): the cut-off is positional, so this presumably expects
    ``payment`` sorted in ascending order -- confirm against callers.
    """
    pay = np.asarray(payment, dtype=float)
    pen = np.asarray(penalty, dtype=float)
    n_buckets = len(pay)

    eligible = np.nonzero(pay >= u)[0]
    # Fall back to n_buckets so that nothing is filled in when no payment reaches u.
    first = int(eligible.min()) if eligible.size > 0 else n_buckets

    probs = np.zeros(n_buckets, dtype=float)
    for idx in range(first, n_buckets):
        target = u if idx in A else u - eps
        probs[idx] = rho(pay[idx], pen[idx], target)
    return np.clip(probs, 0.0, 1.0)

def equalize_plus(i: int, k: int, eps: float,
                  payment: np.ndarray, penalty: np.ndarray) -> np.ndarray:
    """
    '+' variant: equalize bucket ``k`` at a target just above the previous payment.

    The target is ``payment[i-1] + eps`` when ``i > 0`` and plain ``eps`` when
    ``i == 0``; the vector itself is built by ``_equalize`` with ``A = {k}``.
    """
    if i > 0:
        target = payment[i - 1] + eps
    else:
        target = eps
    return _equalize(target, {k}, eps, payment, penalty)

def equalize_minus(i: int, k: int, eps: float,
                   payment: np.ndarray, penalty: np.ndarray) -> np.ndarray:
    """
    '-' variant: equalize bucket ``k`` at the target ``payment[i] - eps``.

    The vector itself is built by ``_equalize`` with ``A = {k}``.
    """
    target = payment[i] - eps
    pooled = {k}
    return _equalize(target, pooled, eps, payment, penalty)

# =========================================================
# ========= Deterministic (i,k) evaluation (NA) ===========
# =========================================================
# report(j) = k  if j < i
# report(j) = j  if j >= i

def principal_utility_ik(q, p, cost, i_cut, k_pool, payments, penalties, valuation):
    """
    Expected principal utility under the deterministic (i_cut, k_pool) report rule.

    Type j reports ``k_pool`` when ``j < i_cut`` and reports ``j`` otherwise.
    A type whose report equals its own index contributes
    ``valuation[j, j] - payments[j] - p[j] * cost``; any other type contributes
    ``valuation[j, r] - payments[r] + p[r] * (penalties[r] - cost)`` where ``r``
    is its report.  Contributions are weighted by ``q[j]`` and summed.
    """
    total = 0.0
    for j, weight in enumerate(q):
        if j < i_cut:
            report = k_pool
        else:
            report = j

        if report != j:
            payoff = (valuation[j, report] - payments[report]
                      + p[report] * (penalties[report] - cost))
        else:
            # Also reached when a pooled type's target happens to be its own index.
            payoff = valuation[j, j] - payments[j] - p[j] * cost
        total += weight * payoff
    return float(total)

def welfare_ik(q, p, cost, i_cut, k_pool, payments, penalties, valuation):
    """
    Expected welfare under the deterministic (i_cut, k_pool) report rule.

    Type j reports ``r = k_pool`` when ``j < i_cut`` and ``r = j`` otherwise, and
    contributes ``q[j] * (valuation[j, r] - p[r] * cost)``.  ``payments`` and
    ``penalties`` are accepted but not read here.
    """
    total = 0.0
    for j, weight in enumerate(q):
        if j < i_cut:
            report = k_pool
        else:
            report = j
        net_value = valuation[j, report] - p[report] * cost
        total += weight * net_value
    return float(total)

def enumerate_na_candidates(eps: float, payment: np.ndarray, penalty: np.ndarray):
    """
    Yield ``(p, i, k, label)`` for every NA eps-critical candidate policy.

    Pairs are visited with ``i`` ascending and ``k`` running from ``i`` to the last
    index; for each pair the '+' policy is yielded before the '-' policy.
    ``label`` is the tuple ``(i, k, '+')`` or ``(i, k, '-')``.
    """
    n_types = len(payment)
    index_pairs = (
        (cut, pool)
        for cut in range(n_types)
        for pool in range(cut, n_types)
    )
    variants = (('+', equalize_plus), ('-', equalize_minus))
    for cut, pool in index_pairs:
        for sign, build in variants:
            yield build(cut, pool, eps, payment, penalty), cut, pool, (cut, pool, sign)

def best_policy_value_NA(q: np.ndarray, val: np.ndarray,
                         payment: np.ndarray, penalty: np.ndarray,
                         cost: float, eps: float):
    """
    Scan the NA eps-critical candidates and keep the one with the highest
    principal utility.

    Each candidate is scored with ``principal_utility_ik`` using its own (i, k)
    report rule.  A later candidate replaces the incumbent only when strictly
    better, so ties go to the earliest one.  Returns
    ``(best_value, best_label, best_p)``; with no candidates this is
    ``(-inf, None, None)``.
    """
    incumbent = (-np.inf, None, None)
    for policy, i_cut, k_pool, label in enumerate_na_candidates(eps, payment, penalty):
        score = principal_utility_ik(q, policy, cost, i_cut, k_pool, payment, penalty, val)
        if score > incumbent[0]:
            incumbent = (score, label, policy)
    return incumbent

def best_policy_welfare_NA(q: np.ndarray, val: np.ndarray,
                           payment: np.ndarray, penalty: np.ndarray,
                           cost: float, eps: float):
    """
    Scan the NA eps-critical candidates and keep the one with the highest welfare.

    Each candidate is scored with ``welfare_ik`` using its own (i, k) report
    rule.  Only a strictly better score replaces the incumbent, so ties go to
    the earliest candidate.  Returns ``(best_welfare, best_label, best_p)``; with
    no candidates this is ``(-inf, None, None)``.
    """
    incumbent = (-np.inf, None, None)
    for policy, i_cut, k_pool, label in enumerate_na_candidates(eps, payment, penalty):
        score = welfare_ik(q, policy, cost, i_cut, k_pool, payment, penalty, val)
        if score > incumbent[0]:
            incumbent = (score, label, policy)
    return incumbent

# ---------------- Helpers ----------------

def _safe_epsilon(eps, payments):
    """
    Clamp ``eps`` into ``[1e-12, cap]``.

    ``cap`` is half the smallest absolute gap between consecutive payments,
    minus 1e-12; when that is not positive it falls back to 1e-12.  With fewer
    than two payments there is no gap and only the 1e-12 floor applies.
    """
    gaps = np.abs(np.diff(np.asarray(payments, dtype=float)))
    if gaps.size == 0:
        return float(max(1e-12, eps))

    cap = 0.5 * float(np.min(gaps)) - 1e-12
    if not cap > 0:
        # Coincident (or nearly coincident) payments leave no room: use the floor.
        cap = 1e-12
    return float(max(1e-12, min(float(eps), cap)))

def deterministic_mapping_tuple(m: int, i_cut: int, k_pool: int):
    """
    Return the report rule as a tuple of length ``m``.

    Entry j is ``k_pool`` when ``j < i_cut`` and ``j`` itself otherwise.
    """
    reports = []
    for j in range(m):
        if j < i_cut:
            reports.append(k_pool)
        else:
            reports.append(j)
    return tuple(reports)

# =========================================================
# ================= Heatmap (NA only) =====================
# =========================================================

def plot_heatmap_na(
    payments=np.array([0.3, 0.8, 1.3]),
    penalties=np.array([1.0, 1.2, 1.4]),
    valuation=np.array([[0.5,0,0],[0,1.4,0],[0,0,3.0]], dtype=float),
    cost=0.7, epsilon=1e-3, resolution=250,
    objective="utility",  # or "welfare"
    fname="Figure1_Heatmap_NA.png"
):
    """
    Plot the best NA objective over the three-type prior simplex and save it.

    The prior is ``(q0, q1, 1 - q0 - q1)`` with q0 on the x-axis and q1 on the
    y-axis, sampled on a ``resolution`` x ``resolution`` grid.  For every grid
    point inside the simplex the best NA eps-critical policy is found; its value
    colours the heatmap, white contours mark where the report rule changes, and
    each distinct rule is labelled once at the centroid of its cells.

    Parameters:
        payments, penalties, valuation, cost: model inputs, forwarded unchanged
            to the best-policy search.
        epsilon: requested eps; clamped with ``_safe_epsilon`` before use.
        resolution: grid points per axis (cast to int).
        objective: "utility" scores with ``best_policy_value_NA``; any other
            value scores with ``best_policy_welfare_NA``.
        fname: output path handed to ``plt.savefig``.

    Returns:
        ``fname``.

    Raises:
        ValueError: if ``payments`` does not have exactly three entries (the
            prior built here always has three components).
    """
    m = len(payments)
    if m != 3:
        raise ValueError(
            f"plot_heatmap_na needs exactly 3 types to span a 2-D simplex, got {m}"
        )
    n_grid = int(resolution)  # range() needs an int even when a float is passed
    eps_eff = _safe_epsilon(epsilon, payments)
    find_best = best_policy_value_NA if objective == "utility" else best_policy_welfare_NA

    grid = np.linspace(0.0, 1.0, n_grid)
    # 'ij' indexing: the first array axis follows q0 (X), the second follows q1 (Y).
    X, Y = np.meshgrid(grid, grid, indexing='ij')
    valid = (X >= 0) & (Y >= 0) & ((X + Y) <= 1.0 + 1e-12)

    Z = np.full(X.shape, np.nan, dtype=float)
    eq_map = np.empty(X.shape, dtype=object)
    region_cells = {}  # report-rule tuple -> list of (ix, iy) cells that use it

    for ix in range(n_grid):
        for iy in range(n_grid):
            if not valid[ix, iy]:
                continue
            q0, q1 = float(X[ix, iy]), float(Y[ix, iy])
            q2 = 1.0 - q0 - q1
            if q2 < -1e-12:  # numerical guard
                continue
            # Rounding on the q0 + q1 = 1 edge can leave q2 a hair below zero.
            q = np.array([q0, q1, max(q2, 0.0)], dtype=float)

            best_val, lbl, _ = find_best(q, valuation, payments, penalties, cost, eps_eff)

            Z[ix, iy] = best_val
            i_cut, k_pool, _ = lbl
            mapping = deterministic_mapping_tuple(m, i_cut, k_pool)
            eq_map[ix, iy] = mapping
            region_cells.setdefault(mapping, []).append((ix, iy))

    # Plot heatmap (masked where invalid)
    fig, ax = plt.subplots()  # use global FIGSIZE
    Z_masked = np.ma.array(Z, mask=~valid)
    # imshow puts array rows on the y-axis and columns on the x-axis, while Z is
    # indexed [q0, q1].  Transpose so q0 runs along x, matching the axis labels
    # and the contour overlay below.
    im = ax.imshow(
        Z_masked.T, origin="lower",
        extent=[0.0, 1.0, 0.0, 1.0],
        aspect="equal", interpolation="nearest"
    )

    cbar = fig.colorbar(im, ax=ax)
    cbar.set_label("Principal Utility" if objective == "utility" else "Welfare")

    # Optional: show the simplex boundary line (q0 + q1 = 1)
    ax.plot([0, 1], [1, 0], linestyle=":", linewidth=1, color="0.5")

    # Boundaries of equilibrium regions (where mapping changes)
    boundaries = np.zeros_like(Z_masked, dtype=float)
    for ix in range(n_grid - 1):
        for iy in range(n_grid - 1):
            if not valid[ix, iy]:
                continue
            here = eq_map[ix, iy]
            right_change = valid[ix, iy+1] and (eq_map[ix, iy+1] is not None) and (eq_map[ix, iy+1] != here)
            up_change    = valid[ix+1, iy] and (eq_map[ix+1, iy] is not None) and (eq_map[ix+1, iy] != here)
            if right_change or up_change:
                boundaries[ix, iy] = 1.0

    # contour() takes explicit coordinates, so the [q0, q1] layout is used as-is.
    ax.contour(X, Y, boundaries, levels=[0.5], linewidths=1.0, colors="white")

    # Annotate ONLY equilibrium mapping labels (one per region)
    for mapping in sorted(region_cells):
        cells = np.asarray(region_cells[mapping])
        cx = float(np.mean(grid[cells[:, 0]]))  # first index is q0 -> x
        cy = float(np.mean(grid[cells[:, 1]]))  # second index is q1 -> y
        ax.text(cx, cy, str(mapping), ha="center", va="center",
                fontsize=10, color="black",
                bbox=dict(facecolor="white", edgecolor="none", alpha=0.6, pad=1.5))

    ax.set_xlabel("Proportion Type 0")
    ax.set_ylabel("Proportion Type 1")
    ax.set_title("" if objective == "utility"
                 else "Non‑Adaptive Heatmap (Welfare)")

    plt.tight_layout()
    plt.savefig(fname, bbox_inches="tight")
    plt.close(fig)
    return fname

# =========================================================
# ======== Line figure: NA welfare vs pay(1) + p's ========
# =========================================================

def plot_line_na_welfare_vs_pay1(
    output_fname="Figure2_NA_Welfare_vs_pay1.png",
    n_points=100,
    fixed_pay0=1.0, fixed_pay2=3.0,
    penalty_margin=0.5,  # penalties = payments + margin
    valuation=np.array([[0.99,0.90,0.50],
                        [0.00,1.50,1.40],
                        [0.00,0.00,4.00]], dtype=float),
    base_cost=1.0,
    epsilon=1e-2,
    prior=np.array([0.4, 0.3, 0.3], dtype=float)
):
    """
    Plot optimal NA welfare and the chosen probabilities as pay(1) varies.

    pay(1) is swept over ``n_points`` values strictly between ``fixed_pay0`` and
    ``fixed_pay2`` (kept ``2*epsilon + 1e-9`` away from each end).  For each
    value the payments are ``[fixed_pay0, pay1, fixed_pay2]``, the penalties are
    the payments plus ``penalty_margin``, and the welfare-maximising NA
    eps-critical policy is found.  Welfare goes on the left axis, the three
    entries of the chosen policy on a secondary right axis, and a dotted
    vertical line with a "previous -> new" label marks every sample at which
    the report rule changes.

    Returns:
        ``output_fname``.

    Raises:
        ValueError: if the two fixed payments are too close together to leave
            any admissible pay(1) value.
    """
    # x-axis domain for pay(1)
    delta = 2*epsilon + 1e-9  # ensure spacing satisfies epsilon-critical construction
    pay1_min = fixed_pay0 + delta
    pay1_max = fixed_pay2 - delta
    if pay1_min > pay1_max:
        # A reversed sweep would feed out-of-order payments to the solver.
        raise ValueError(
            "fixed_pay2 - fixed_pay0 must be at least 2*(2*epsilon + 1e-9) "
            f"to sweep pay(1); got [{pay1_min}, {pay1_max}]"
        )
    pay1_values = np.linspace(pay1_min, pay1_max, n_points)

    y_welfare = []
    policies = []
    mapping_seq = []

    for pay1 in pay1_values:
        payments = np.array([fixed_pay0, pay1, fixed_pay2], dtype=float)
        penalties = payments + penalty_margin
        eps_eff = _safe_epsilon(epsilon, payments)

        best_w, lbl, p = best_policy_welfare_NA(prior, valuation, payments, penalties, base_cost, eps_eff)
        y_welfare.append(best_w)
        policies.append([float(p[0]), float(p[1]), float(p[2])])

        i_cut, k_pool, _ = lbl
        mapping_seq.append(deterministic_mapping_tuple(len(payments), i_cut, k_pool))

    # reshape keeps the (n, 3) layout even when the sweep is empty.
    audit = np.asarray(policies, dtype=float).reshape(-1, 3)

    # Detect transitions in equilibrium mapping (reported at the first sample
    # that uses the new rule).
    boundaries = [
        (pay1_values[t], mapping_seq[t-1], mapping_seq[t])
        for t in range(1, len(pay1_values))
        if mapping_seq[t] != mapping_seq[t-1]
    ]

    # ---- Plot
    fig, ax_left = plt.subplots()  # size from globals

    # NA welfare (primary axis)
    ax_left.plot(pay1_values, y_welfare, label="Welfare",
                 linestyle=(0, (5, 1.8)), marker="x", linewidth=2)
    ax_left.set_xlabel("pay(1)")
    ax_left.set_ylabel("Optimal Social Welfare")
    ax_left.set_title("")

    # Vertical boundary lines + *horizontal* transition text
    y_bottom, y_top = ax_left.get_ylim()
    yrange = y_top - y_bottom
    for x, prev_map, new_map in boundaries:
        ax_left.axvline(x=x, linestyle=":", linewidth=1.2, color="0.4")
        ax_left.text(x, y_top - 0.02*yrange,
                     f"{prev_map} \N{RIGHTWARDS ARROW} {new_map}",
                     ha="center", va="top", rotation=0, fontsize=10)

    # Secondary axis for audit probabilities
    ax_right = ax_left.twinx()
    # A twin axes restarts the colour cycle, so without explicit colours "p0"
    # would be drawn in the same colour as "Welfare".
    ax_right.plot(pay1_values, audit[:, 0], label="p0", linestyle="--",
                  linewidth=2, color="C1")
    ax_right.plot(pay1_values, audit[:, 1], label="p1", linestyle="dashdot",
                  linewidth=2, color="C2")
    ax_right.plot(pay1_values, audit[:, 2], label="p2", linestyle="dotted",
                  linewidth=2, color="C3")
    ax_right.set_ylabel("Audit Probability")
    # The global style turns the grid on and the right spine off for every
    # axes; on the twin that gives a second, misaligned grid and no axis line.
    ax_right.grid(False)
    ax_right.spines["right"].set_visible(True)

    # Legend in the upper-right (combine both axes).  It is attached to the
    # twin, which is drawn last, so no plotted line is drawn over the box.
    h_left, l_left = ax_left.get_legend_handles_labels()
    h_right, l_right = ax_right.get_legend_handles_labels()
    ax_right.legend(h_left + h_right, l_left + l_right, loc="upper right", framealpha=0.9)

    plt.tight_layout()
    plt.savefig(output_fname, bbox_inches="tight")
    plt.close(fig)
    return output_fname

# =========================================================
# ======================= Run both ========================
# =========================================================

if __name__ == "__main__":
    # Figure 1 -- Heatmap (NA-only; utility by default)
    heatmap_settings = dict(
        payments=np.array([0.3, 0.8, 1.3]),
        penalties=np.array([1.0, 1.2, 1.4]),
        valuation=np.array([[0.5, 0.0, 0.0],
                            [0.0, 1.4, 0.0],
                            [0.0, 0.0, 3.0]], dtype=float),
        cost=0.7,
        epsilon=1e-3,
        resolution=250,
        objective="utility",
        fname="Figure1_Heatmap_NA.png",
    )
    _ = plot_heatmap_na(**heatmap_settings)

    # Figure 2 -- NA Social Welfare vs pay(1) (with p's)
    line_settings = dict(
        output_fname="Figure2_NA_Welfare_vs_pay1.png",
        n_points=120,
        fixed_pay0=1.0,
        fixed_pay2=3.0,
        penalty_margin=0.5,
        valuation=np.array([[0.99, 0.90, 0.50],
                            [0.00, 1.50, 1.40],
                            [0.00, 0.00, 4.00]], dtype=float),
        base_cost=1.0,
        epsilon=1e-2,
        prior=np.array([0.4, 0.3, 0.3], dtype=float),
    )
    _ = plot_line_na_welfare_vs_pay1(**line_settings)
