In [1]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.tri import Triangulation
from math import sqrt

# =========================================================
# Global style: match your 8x6, large-font figure
# =========================================================
# Install the shared figure style in matplotlib's global rcParams. Figures
# created after this call pick these values up as defaults; explicit
# fontsize/figsize arguments at a call site still take precedence.
plt.rcParams.update({
    "figure.figsize": (8, 6),     # default figure size (inches)
    "savefig.dpi": 300,           # resolution used by savefig
    "font.family": "DejaVu Sans",
    "font.size": 18,      # base font
    "axes.labelsize": 24,
    "axes.titlesize": 24,
    "xtick.labelsize": 18,
    "ytick.labelsize": 18,
    "legend.fontsize": 24,
    "axes.grid": False,           # no grid in this cell's figure
    "axes.spines.top": False,     # hide top/right spines
    "axes.spines.right": False,
})

# =========================================================
# Non-adaptive (ε-critical) primitives
# =========================================================
def rho(payment_k: float, penalty_k: float, u: float) -> float:
    """Return ``(payment_k - u) / penalty_k``.

    This is the value ``x`` solving ``payment_k - x * penalty_k == u``.
    The result is not clipped to [0, 1]; callers do that themselves.
    """
    surplus = payment_k - u
    return surplus / penalty_k

def _equalize(u: float, A: set[int], eps: float,
              payments: np.ndarray, penalties: np.ndarray) -> np.ndarray:
    """Build a probability vector from a target level ``u``.

    Let ``iota`` be the first index whose payment is at least ``u`` (or
    ``len(payments)`` when there is none).  Entries below ``iota`` stay 0.
    From ``iota`` on, an index in ``A`` gets ``rho(payment, penalty, u)``
    and every other index gets ``rho(payment, penalty, u - eps)``.
    The vector is clipped to [0, 1] before it is returned.
    """
    pay = np.asarray(payments, dtype=float)
    pen = np.asarray(penalties, dtype=float)
    n_buckets = len(pay)

    reachable = np.flatnonzero(pay >= u)
    first = int(reachable.min()) if reachable.size else n_buckets

    audit = np.zeros(n_buckets, dtype=float)
    for idx in range(first, n_buckets):
        target = u if idx in A else u - eps
        audit[idx] = rho(pay[idx], pen[idx], target)
    return np.clip(audit, 0.0, 1.0)

def equalize_plus(i: int, k: int, eps: float,
                  payments: np.ndarray, penalties: np.ndarray) -> np.ndarray:
    """'+' candidate for the pair (i, k).

    The target level sits ``eps`` above ``payments[i-1]`` (or is ``eps``
    itself when ``i == 0``); only index ``k`` is tied at that level.
    """
    if i > 0:
        level = payments[i-1] + eps
    else:
        level = eps
    tied = {k}
    return _equalize(level, tied, eps, payments, penalties)

def equalize_minus(i: int, k: int, eps: float,
                   payments: np.ndarray, penalties: np.ndarray) -> np.ndarray:
    """'-' candidate for the pair (i, k).

    The target level sits ``eps`` below ``payments[i]``; only index ``k``
    is tied at that level.
    """
    tied = {k}
    level = payments[i] - eps
    return _equalize(level, tied, eps, payments, penalties)

def principal_utility_ik(q, p, cost, i_cut, k_pool, payments, penalties, valuation):
    """Weighted principal utility under the (i_cut, k_pool) report map.

    Type ``j`` reports ``k_pool`` when ``j < i_cut`` and reports ``j``
    otherwise.  A type whose report equals its own index contributes
    ``valuation[j, j] - payments[j] - p[j] * cost``; any other type
    contributes ``valuation[j, r] - payments[r] + p[r] * (penalties[r] - cost)``
    with ``r`` its report.  Contributions are weighted by ``q`` and summed.
    """
    expected = 0.0
    for agent, weight in enumerate(q):
        if agent >= i_cut or agent == k_pool:
            # Report coincides with the type's own index.
            payoff = valuation[agent, agent] - payments[agent] - p[agent]*cost
        else:
            # Type is pooled onto k_pool.
            payoff = (valuation[agent, k_pool] - payments[k_pool]
                      + p[k_pool]*(penalties[k_pool] - cost))
        expected += weight * payoff
    return float(expected)

def enumerate_na_candidates(eps: float, payments: np.ndarray, penalties: np.ndarray):
    """Generate every candidate as ``(p, i, k, (i, k, sign))``.

    Pairs run over ``0 <= i <= k < len(payments)``; for each pair the '+'
    candidate is yielded first, then the '-' candidate.
    """
    count = len(payments)
    builders = (('+', equalize_plus), ('-', equalize_minus))
    for i in range(count):
        for k in range(i, count):
            for sign, build in builders:
                yield build(i, k, eps, payments, penalties), i, k, (i, k, sign)

def best_policy_value_NA(q: np.ndarray, valuation: np.ndarray,
                         payments: np.ndarray, penalties: np.ndarray,
                         cost: float, eps: float):
    """Scan all candidates and keep the one with the highest principal utility.

    Returns ``(value, label, p)``.  Ties keep the earliest candidate; if no
    candidate scores above ``-inf`` the result is ``(-inf, None, None)``.
    """
    champion = (-np.inf, None, None)
    for audit, i_cut, k_pool, label in enumerate_na_candidates(eps, payments, penalties):
        score = principal_utility_ik(q, audit, cost, i_cut, k_pool,
                                     payments, penalties, valuation)
        if score > champion[0]:
            champion = (score, label, audit)
    return champion

def _safe_epsilon(eps, payments):
    """Clamp ``eps`` so it stays below half the smallest payment gap.

    The result is never smaller than 1e-12.  With fewer than two payments
    there is no gap, so only the lower bound applies.
    """
    floor = 1e-12
    gaps = np.abs(np.diff(np.asarray(payments, dtype=float)))
    if gaps.size == 0:
        return float(max(floor, eps))
    cap = 0.5 * float(gaps.min()) - floor
    if not cap > 0:
        # Coincident payments leave no room: fall back to the floor.
        cap = floor
    return float(max(floor, min(float(eps), cap)))

def deterministic_mapping_tuple(m: int, i_cut: int, k_pool: int):
    """Return the report tuple (r0, ..., r_{m-1}).

    The first ``i_cut`` entries are ``k_pool``; every later entry equals
    its own index.
    """
    pooled = max(0, min(i_cut, m))
    return (k_pool,) * pooled + tuple(range(pooled, m))

# =========================================================
# Ternary heatmap: q0,q1,q2 simplex -> equilateral triangle
# =========================================================
def bary_to_xy(q0, q1, q2):
    """Convert barycentric weights to 2-D coordinates in a unit-side equilateral triangle.

    The corners are (0, 0) for q0, (1, 0) for q1 and (0.5, sqrt(3)/2) for q2;
    the result is the weighted sum of the three corners.
    """
    corners = np.array([
        [0.0, 0.0],             # q0 corner (left)
        [1.0, 0.0],             # q1 corner (right)
        [0.5, sqrt(3)/2.0],     # q2 corner (top)
    ])
    return q0*corners[0] + q1*corners[1] + q2*corners[2]

def plot_heatmap_na_ternary(
    payments=np.array([0.3, 0.8, 1.3]),
    penalties=np.array([1.0, 1.2, 1.4]),
    valuation=np.array([[0.5,0,0],[0,1.4,0],[0,0,3.0]], dtype=float),
    cost=0.7,
    epsilon=1e-3,
    lattice_n=60,
    fname="Figure1_Ternary_Heatmap_NA.png"
):
    """Save a ternary heatmap of the best NA principal utility over the prior simplex.

    The simplex ``q0 + q1 + q2 = 1`` is sampled on the lattice
    ``(i0, i1, i2) / lattice_n``.  At every lattice point the best candidate
    policy is found with ``best_policy_value_NA`` and its report mapping is
    recorded.  The figure shows the utility as a Gouraud-shaded triangulation,
    white segments on triangulation edges whose endpoints have different
    mappings, and one text label per distinct mapping placed at the mean
    position of its lattice points.

    Parameters
    ----------
    payments, penalties : array-like
        Per-report payments and penalties passed to the policy search.
    valuation : 2-D array
        Indexed as ``valuation[type, report]``.
    cost : float
        Cost multiplier applied to the audit probability.
    epsilon : float
        Requested epsilon; clamped by ``_safe_epsilon`` before use.
    lattice_n : int
        Number of lattice subdivisions per simplex edge (must be >= 1).
    fname : str
        Output path handed to ``savefig``.

    Returns
    -------
    str
        ``fname``.

    Raises
    ------
    ValueError
        If ``lattice_n`` is smaller than 1 (the lattice would be empty or
        require a division by zero).
    """
    lattice_n = int(lattice_n)
    if lattice_n < 1:
        raise ValueError(f"lattice_n must be >= 1, got {lattice_n}")

    m = len(payments)
    eps_eff = _safe_epsilon(epsilon, payments)

    # Sample the simplex with i0 + i1 + i2 = lattice_n.
    xs, ys, Z = [], [], []
    mapping = []          # one report tuple per lattice point, kept as plain tuples
    for i0 in range(lattice_n + 1):
        for i1 in range(lattice_n + 1 - i0):
            i2 = lattice_n - i0 - i1
            q0 = i0 / lattice_n
            q1 = i1 / lattice_n
            q2 = i2 / lattice_n
            q  = np.array([q0, q1, q2], dtype=float)

            x, y = bary_to_xy(q0, q1, q2)

            best_val, lbl, _ = best_policy_value_NA(q, valuation, payments, penalties, cost, eps_eff)
            i_cut, k_pool, _ = lbl
            mapping.append(deterministic_mapping_tuple(m, i_cut, k_pool))

            xs.append(x); ys.append(y); Z.append(best_val)

    xs = np.array(xs); ys = np.array(ys); Z = np.array(Z)

    # Triangulate the lattice points.
    tri = Triangulation(xs, ys)

    fig, ax = plt.subplots()
    # No rectangular axes: only triangle + colorbar.
    ax.set_axis_off()

    # Heatmap
    col = ax.tripcolor(tri, Z, shading="gouraud")
    cbar = fig.colorbar(col, ax=ax, pad=0.02)
    cbar.set_label("Principal Utility", fontsize=24)
    cbar.ax.tick_params(labelsize=18)

    # Triangle boundary
    v0 = bary_to_xy(1, 0, 0)
    v1 = bary_to_xy(0, 1, 0)
    v2 = bary_to_xy(0, 0, 1)
    for start, end in ((v0, v1), (v1, v2), (v2, v0)):
        ax.plot([start[0], end[0]], [start[1], end[1]], color="black", linewidth=1.2)

    # Equilibrium boundaries: triangulation edges whose endpoints have a
    # different mapping.  The set removes edges shared by two triangles.
    edges = set()
    for tri_inds in tri.triangles:
        a, b, c = (int(v) for v in tri_inds)
        for s, t in ((a, b), (b, c), (c, a)):
            if mapping[s] != mapping[t]:
                edges.add((min(s, t), max(s, t)))
    for s, t in edges:
        ax.plot([xs[s], xs[t]], [ys[s], ys[t]], color="white", linewidth=1.2)

    # Label each distinct mapping once (fontsize 18).  Grouping the point
    # indices in a single pass avoids rescanning all points per mapping.
    members = {}
    for idx, mp in enumerate(mapping):
        members.setdefault(mp, []).append(idx)
    for mp in sorted(members):
        idx = members[mp]
        cx = float(xs[idx].mean())
        cy = float(ys[idx].mean())
        ax.text(cx, cy, str(mp),
                ha="center", va="center",
                fontsize=18, color="black",
                bbox=dict(facecolor="white", edgecolor="none",
                          alpha=0.7, pad=2.5))

    # Corner labels: q0, q1, q2
    ax.text(v0[0]-0.03, v0[1]-0.03, r"$q_0$", ha="right", va="top",   fontsize=22)
    ax.text(v1[0]+0.03, v1[1]-0.03, r"$q_1$", ha="left",  va="top",   fontsize=22)
    ax.text(v2[0],      v2[1]+0.04, r"$q_2$", ha="center",va="bottom",fontsize=22)

    # Square-ish bounds around triangle
    ax.set_xlim(-0.05, 1.05)
    ax.set_ylim(-0.05, v2[1] + 0.08)

    # Operate on this figure explicitly rather than on pyplot's "current" one.
    fig.tight_layout()
    fig.savefig(fname, bbox_inches="tight")
    plt.close(fig)
    return fname

# ========================= RUN ==========================
if __name__ == "__main__":
    # Parameters of the ternary figure, gathered in one place.
    ternary_settings = dict(
        payments=np.array([0.3, 0.8, 1.3]),
        penalties=np.array([1.0, 1.2, 1.4]),
        valuation=np.array([[0.5, 0.0, 0.0],
                            [0.0, 1.4, 0.0],
                            [0.0, 0.0, 3.0]], dtype=float),
        cost=0.7,
        epsilon=1e-3,
        lattice_n=60,
        fname="Figure1_Ternary_Heatmap_NA.png",
    )
    _ = plot_heatmap_na_ternary(**ternary_settings)
    print("Saved Figure1_Ternary_Heatmap_NA.png")


Saved Figure1_Ternary_Heatmap_NA.png


In [3]:
# -*- coding: utf-8 -*-
"""
Non-adaptive figures (heatmap + welfare vs pay(1)) with
figure size and fonts unified to match an 8×6, large-font style.

Outputs:
  - Figure1_Heatmap_NA.png
  - Figure2_NA_Welfare_vs_pay1.png
"""

import numpy as np
import matplotlib.pyplot as plt

# ------------------ Paper-wide style (match 8×6 & large fonts) ------------------
# Default figure size (inches) shared by the figures produced in this cell.
PAPER_FIGSIZE = (8, 6)  # same size as your 2-type proposition figure
# Install the shared style in matplotlib's global rcParams; figures created
# afterwards use these defaults unless a call passes explicit sizes.
plt.rcParams.update({
    "figure.figsize": PAPER_FIGSIZE,
    "savefig.dpi": 600,
    "font.family": "DejaVu Sans",
    "font.size": 18,          # base
    "axes.labelsize": 24,     # axis labels (match your reference)
    "xtick.labelsize": 18,
    "ytick.labelsize": 18,
    "legend.fontsize": 20,    # balanced on 8×6 (your ref used 24; 20 avoids crowding)
    "axes.grid": True,        # light grid on every axes (twin axes included)
    "grid.linestyle": "-",
    "grid.alpha": 0.25,
    "axes.spines.top": False,     # hide top/right spines
    "axes.spines.right": False,
})

# =========================================================
# ============= Critical policy constructors (NA) =========
# =========================================================

def rho(payment_k: float, penalty_k: float, u: float) -> float:
    """Return ``(payment_k - u) / penalty_k``.

    Equivalently, the ``x`` for which ``payment_k - x * penalty_k == u``.
    No clipping is applied here.
    """
    gap = payment_k - u
    return gap / penalty_k

def _equalize(u: float, A: set[int], eps: float,
              payment: np.ndarray, penalty: np.ndarray) -> np.ndarray:
    """
    Build a probability vector from a target level ``u``.

    Indices before the first one with ``payment >= u`` keep probability 0.
    From that index on, members of ``A`` get ``rho(payment, penalty, u)`` and
    all others get ``rho(payment, penalty, u - eps)``.  The vector is clipped
    to [0, 1] before being returned.
    """
    pay = np.asarray(payment, dtype=float)
    pen = np.asarray(penalty, dtype=float)
    n_buckets = len(pay)

    reachable = np.flatnonzero(pay >= u)
    # With no reachable bucket the start index is n_buckets, so nothing is set.
    first = int(reachable.min()) if reachable.size else n_buckets

    audit = np.zeros(n_buckets, dtype=float)
    for idx in range(first, n_buckets):
        target = u if idx in A else u - eps
        audit[idx] = rho(pay[idx], pen[idx], target)
    return np.clip(audit, 0.0, 1.0)

def equalize_plus(i: int, k: int, eps: float,
                  payment: np.ndarray, penalty: np.ndarray) -> np.ndarray:
    """'+' candidate for (i, k): level ``payment[i-1] + eps`` (``eps`` if i == 0), tie on k."""
    if i > 0:
        level = payment[i-1] + eps
    else:
        level = eps
    tied = {k}
    return _equalize(level, tied, eps, payment, penalty)

def equalize_minus(i: int, k: int, eps: float,
                   payment: np.ndarray, penalty: np.ndarray) -> np.ndarray:
    """'-' candidate for (i, k): level ``payment[i] - eps``, tie on k."""
    tied = {k}
    level = payment[i] - eps
    return _equalize(level, tied, eps, payment, penalty)

# =========================================================
# ========= Deterministic (i,k) evaluation (NA) ===========
# =========================================================
# report(j) = k  if j < i;  report(j) = j  if j >= i

def principal_utility_ik(q, p, cost, i_cut, k_pool, payments, penalties, valuation):
    """Weighted principal utility under the (i_cut, k_pool) report map.

    Types below ``i_cut`` report ``k_pool``; the rest report their own index.
    Own-index reports contribute ``valuation[j, j] - payments[j] - p[j]*cost``;
    pooled reports contribute
    ``valuation[j, r] - payments[r] + p[r]*(penalties[r] - cost)``.
    """
    expected = 0.0
    for agent, weight in enumerate(q):
        if agent >= i_cut or agent == k_pool:
            # Report coincides with the type's own index.
            payoff = valuation[agent, agent] - payments[agent] - p[agent]*cost
        else:
            # Type is pooled onto k_pool.
            payoff = (valuation[agent, k_pool] - payments[k_pool]
                      + p[k_pool]*(penalties[k_pool] - cost))
        expected += weight * payoff
    return float(expected)

def welfare_ik(q, p, cost, i_cut, k_pool, payments, penalties, valuation):
    """Weighted welfare under the (i_cut, k_pool) report map.

    Each type ``j`` with report ``r`` contributes ``valuation[j, r] - p[r]*cost``.
    ``payments`` and ``penalties`` are accepted for signature parity with
    ``principal_utility_ik`` but are not used.
    """
    total = 0.0
    for agent, weight in enumerate(q):
        reported = agent if agent >= i_cut else k_pool
        surplus = valuation[agent, reported] - p[reported]*cost
        total += weight * surplus
    return float(total)

def enumerate_na_candidates(eps: float, payment: np.ndarray, penalty: np.ndarray):
    """Generate every candidate as ``(p, i, k, (i, k, sign))`` for 0 <= i <= k < m.

    For each pair the '+' candidate comes first, then the '-' candidate.
    """
    count = len(payment)
    builders = (('+', equalize_plus), ('-', equalize_minus))
    for i in range(count):
        for k in range(i, count):
            for sign, build in builders:
                yield build(i, k, eps, payment, penalty), i, k, (i, k, sign)

def best_policy_value_NA(q: np.ndarray, val: np.ndarray,
                         payment: np.ndarray, penalty: np.ndarray,
                         cost: float, eps: float):
    """
    Pick the candidate with the highest principal utility.

    Every candidate is scored with ``principal_utility_ik`` under its own
    (i, k) report map.  Returns ``(value, label, p)``; ties keep the earliest
    candidate and an empty scan yields ``(-inf, None, None)``.
    """
    champion = (-np.inf, None, None)
    for audit, i_cut, k_pool, label in enumerate_na_candidates(eps, payment, penalty):
        score = principal_utility_ik(q, audit, cost, i_cut, k_pool, payment, penalty, val)
        if score > champion[0]:
            champion = (score, label, audit)
    return champion

def best_policy_welfare_NA(q: np.ndarray, val: np.ndarray,
                           payment: np.ndarray, penalty: np.ndarray,
                           cost: float, eps: float):
    """Pick the candidate with the highest welfare (scored by ``welfare_ik``).

    Returns ``(welfare, label, p)``; ties keep the earliest candidate and an
    empty scan yields ``(-inf, None, None)``.
    """
    champion = (-np.inf, None, None)
    for audit, i_cut, k_pool, label in enumerate_na_candidates(eps, payment, penalty):
        score = welfare_ik(q, audit, cost, i_cut, k_pool, payment, penalty, val)
        if score > champion[0]:
            champion = (score, label, audit)
    return champion

# ---------------- Helpers ----------------

def _safe_epsilon(eps, payments):
    """Clamp ``eps`` below half the smallest payment gap, never under 1e-12."""
    floor = 1e-12
    gaps = np.abs(np.diff(np.asarray(payments, dtype=float)))
    if gaps.size == 0:
        # Fewer than two payments: only the lower bound applies.
        return float(max(floor, eps))
    cap = 0.5 * float(gaps.min()) - floor
    if not cap > 0:
        cap = floor
    return float(max(floor, min(float(eps), cap)))

def deterministic_mapping_tuple(m: int, i_cut: int, k_pool: int):
    """Report tuple of length m: ``k_pool`` for the first ``i_cut`` slots, own index after."""
    pooled = max(0, min(i_cut, m))
    return (k_pool,) * pooled + tuple(range(pooled, m))

# =========================================================
# ================= Heatmap (NA only) =====================
# =========================================================

def plot_heatmap_na(
    payments=np.array([0.3, 0.8, 1.3]),
    penalties=np.array([1.0, 1.2, 1.4]),
    valuation=np.array([[0.5,0,0],[0,1.4,0],[0,0,3.0]], dtype=float),
    cost=0.7, epsilon=1e-3, resolution=250,
    objective="utility",  # or "welfare"
    fname="Figure1_Heatmap_NA.png"
):
    """Save a heatmap of the best NA objective over priors (q0, q1, 1 - q0 - q1).

    ``q0`` runs along the x-axis and ``q1`` along the y-axis.  For every grid
    point inside the simplex the best candidate policy is found and its report
    mapping stored.  White contours mark where the mapping changes, and each
    distinct mapping is labelled once at the mean position of its grid points.

    Parameters
    ----------
    payments, penalties : array-like
        Per-report payments and penalties passed to the policy search.
    valuation : 2-D array
        Indexed as ``valuation[type, report]``.
    cost : float
        Cost multiplier applied to the audit probability.
    epsilon : float
        Requested epsilon; clamped by ``_safe_epsilon`` before use.
    resolution : int
        Number of grid points per axis (must be >= 2).
    objective : str
        ``"utility"`` maximizes principal utility; any other value maximizes
        welfare.
    fname : str
        Output path handed to ``savefig``.

    Returns
    -------
    str
        ``fname``.

    Raises
    ------
    ValueError
        If ``resolution`` is smaller than 2 (no contour can be drawn).
    """
    # Normalize once so linspace and every range() agree on the grid size.
    resolution = int(resolution)
    if resolution < 2:
        raise ValueError(f"resolution must be >= 2, got {resolution}")

    m = len(payments)
    eps_eff = _safe_epsilon(epsilon, payments)
    # The objective does not change inside the loop, so choose the solver once.
    solve = best_policy_value_NA if objective == "utility" else best_policy_welfare_NA

    grid = np.linspace(0.0, 1.0, resolution)
    # indexing='ij': first array index <-> q0, second array index <-> q1.
    X, Y = np.meshgrid(grid, grid, indexing='ij')
    valid = (X >= 0) & (Y >= 0) & ((X + Y) <= 1.0 + 1e-12)

    Z = np.full(X.shape, np.nan, dtype=float)
    eq_map = np.empty(X.shape, dtype=object)
    region_points = {}    # mapping tuple -> list of (q0, q1) grid points

    for ix in range(resolution):
        for iy in range(resolution):
            if not valid[ix, iy]:
                continue
            q0, q1 = float(X[ix, iy]), float(Y[ix, iy])
            q2 = 1.0 - q0 - q1
            if q2 < -1e-12:
                continue
            q = np.array([q0, q1, q2], dtype=float)

            best_val, lbl, _ = solve(q, valuation, payments, penalties, cost, eps_eff)

            Z[ix, iy] = best_val
            i_cut, k_pool, _ = lbl
            mapping = deterministic_mapping_tuple(m, i_cut, k_pool)
            eq_map[ix, iy] = mapping
            region_points.setdefault(mapping, []).append((q0, q1))

    fig, ax = plt.subplots()  # size & fonts from rcParams
    Z_masked = np.ma.array(Z, mask=~valid)
    # imshow draws array rows along y and columns along x.  Z is indexed
    # [q0, q1], so it is transposed here to put q0 on the x-axis, matching the
    # axis labels and the contour overlay below.
    im = ax.imshow(
        Z_masked.T, origin="lower",
        extent=[0.0, 1.0, 0.0, 1.0],
        aspect="equal", interpolation="nearest"
    )

    cbar = fig.colorbar(im, ax=ax)
    cbar.set_label("Principal Utility" if objective == "utility" else "Welfare")
    cbar.ax.tick_params(labelsize=18)

    # Show the simplex boundary q0 + q1 = 1 (muted)
    ax.plot([0, 1], [1, 0], linestyle=":", linewidth=1, color="0.5")

    # Boundaries where the mapping changes: flag a cell when either the next
    # cell along q1 or the next cell along q0 carries a different mapping.
    boundaries = np.zeros_like(Z_masked, dtype=float)
    for ix in range(resolution - 1):
        for iy in range(resolution - 1):
            if not valid[ix, iy]:
                continue
            here = eq_map[ix, iy]
            change_q1 = valid[ix, iy+1] and (eq_map[ix, iy+1] is not None) and (eq_map[ix, iy+1] != here)
            change_q0 = valid[ix+1, iy] and (eq_map[ix+1, iy] is not None) and (eq_map[ix+1, iy] != here)
            if change_q1 or change_q0:
                boundaries[ix, iy] = 1.0

    # contour places boundaries[i, j] at (X[i, j], Y[i, j]) = (q0, q1).
    ax.contour(X, Y, boundaries, levels=[0.5], linewidths=1.0, colors="white")

    # Label each region once with its (r0,r1,r2) at the mean (q0, q1) of its
    # grid points, i.e. in the same coordinates as the axes.
    for mapping in sorted(region_points):
        pts = np.asarray(region_points[mapping], dtype=float)
        cx = float(pts[:, 0].mean())
        cy = float(pts[:, 1].mean())
        ax.text(cx, cy, str(mapping), ha="center", va="center",
                fontsize=24, color="black",
                bbox=dict(facecolor="white", edgecolor="none", alpha=0.6, pad=2.0))

    ax.set_xlabel("Proportion Type 0")
    ax.set_ylabel("Proportion Type 1")

    # Operate on this figure explicitly rather than on pyplot's "current" one.
    fig.tight_layout()
    fig.savefig(fname, bbox_inches="tight")
    plt.close(fig)
    return fname

# =========================================================
# ======== Line figure: NA welfare vs pay(1) + p's ========
# =========================================================

def plot_line_na_welfare_vs_pay1(
    output_fname="Figure2_NA_Welfare_vs_pay1.png",
    n_points=120,
    fixed_pay0=1.0, fixed_pay2=3.0,
    penalty_margin=0.5,  # penalties = payments + margin
    valuation=np.array([[0.99,0.90,0.50],
                        [0.00,1.50,1.40],
                        [0.00,0.00,4.00]], dtype=float),
    base_cost=1.0,
    epsilon=1e-2,
    prior=np.array([0.4, 0.3, 0.3], dtype=float)
):
    """Save a two-axis figure of optimal NA welfare and audit probabilities vs pay(1).

    pay(1) is swept over ``n_points`` values strictly between ``fixed_pay0``
    and ``fixed_pay2`` (with a ``2*epsilon + 1e-9`` buffer at each end).  At
    every value the welfare-best candidate policy is computed for
    ``payments = (fixed_pay0, pay1, fixed_pay2)`` and
    ``penalties = payments + penalty_margin``.  Welfare is drawn on the left
    axis, the three audit probabilities on the right axis, and a dotted
    vertical line with a "old -> new" label marks every sweep point where the
    report mapping differs from the previous point.

    Returns
    -------
    str
        ``output_fname``.

    Raises
    ------
    ValueError
        If the buffered pay(1) interval is empty, i.e. ``fixed_pay0`` and
        ``fixed_pay2`` are too close (or inverted) for the given ``epsilon``.
    """
    # x-axis domain for pay(1)
    delta = 2*epsilon + 1e-9
    pay1_min = fixed_pay0 + delta
    pay1_max = fixed_pay2 - delta
    if pay1_min > pay1_max:
        # Sweeping an inverted range would put pay(1) outside
        # [fixed_pay0, fixed_pay2] and silently produce a meaningless figure.
        raise ValueError(
            "empty pay(1) range: need fixed_pay0 + delta <= fixed_pay2 - delta "
            f"(got {pay1_min} > {pay1_max})"
        )
    pay1_values = np.linspace(pay1_min, pay1_max, n_points)

    y_welfare = []
    p0, p1, p2 = [], [], []
    mapping_seq = []

    for pay1 in pay1_values:
        payments = np.array([fixed_pay0, pay1, fixed_pay2], dtype=float)
        penalties = payments + penalty_margin
        eps_eff = _safe_epsilon(epsilon, payments)

        best_w, lbl, p = best_policy_welfare_NA(prior, valuation, payments, penalties, base_cost, eps_eff)
        y_welfare.append(best_w)
        p0.append(float(p[0])); p1.append(float(p[1])); p2.append(float(p[2]))

        i_cut, k_pool, _ = lbl
        mapping_seq.append(deterministic_mapping_tuple(len(payments), i_cut, k_pool))

    # Detect transitions in equilibrium mapping
    boundaries = []
    for t in range(1, len(pay1_values)):
        if mapping_seq[t] != mapping_seq[t-1]:
            boundaries.append((pay1_values[t], mapping_seq[t-1], mapping_seq[t]))

    # ---- Plot
    fig, ax_left = plt.subplots()  # size & fonts from rcParams

    # NA welfare (primary axis), distinct style
    ax_left.plot(pay1_values, y_welfare, label="Welfare",
                 linestyle=(0, (5, 1.8)), marker="x", linewidth=2, markersize=6)
    ax_left.set_xlabel("pay(1)")
    ax_left.set_ylabel("Optimal Social Welfare")

    # Vertical boundary lines + horizontal transition text (nudged to the right)
    y_bottom, y_top = ax_left.get_ylim()
    yrange = y_top - y_bottom
    xrange = pay1_values[-1] - pay1_values[0]
    for x, prev_map, new_map in boundaries:
        ax_left.axvline(x=x, linestyle=":", linewidth=1.2, color="0.4")
        ax_left.text(x + 0.02*xrange, y_top - 0.02*yrange,   # nudge right of dotted line
                     f"{prev_map} \N{RIGHTWARDS ARROW} {new_map}",
                     ha="left", va="top", rotation=0, fontsize=18)

    # Secondary axis for audit probabilities (different line styles)
    ax_right = ax_left.twinx()
    ax_right.plot(pay1_values, p0, label="p0", linestyle="--", linewidth=2)
    ax_right.plot(pay1_values, p1, label="p1", linestyle="dashdot", linewidth=2)
    ax_right.plot(pay1_values, p2, label="p2", linestyle="dotted", linewidth=2)
    ax_right.set_ylabel("Audit Probability")
    # rcParams turn the grid on for every axes; a second grid at the right
    # axis' tick positions would not line up with the left one.
    ax_right.grid(False)
    ax_left.tick_params(labelsize=18)
    ax_right.tick_params(labelsize=18)

    # Combined legend in the upper-right.  It is attached to ax_right because
    # the twin axes is drawn after ax_left: a legend owned by ax_left would be
    # painted over by the audit-probability curves.
    h_left, l_left = ax_left.get_legend_handles_labels()
    h_right, l_right = ax_right.get_legend_handles_labels()
    ax_right.legend(h_left + h_right, l_left + l_right,
                    loc="upper right", framealpha=0.9, prop={"size": 20},
                    borderpad=0.25, labelspacing=0.4, handlelength=2.2, handletextpad=0.8)

    # Operate on this figure explicitly rather than on pyplot's "current" one.
    fig.tight_layout()
    fig.savefig(output_fname, bbox_inches="tight")
    plt.close(fig)
    return output_fname

# =========================================================
# ======================= Run both ========================
# =========================================================

if __name__ == "__main__":
    # Figure 1 — Heatmap (NA-only; utility shown)
    heatmap_settings = dict(
        payments=np.array([0.3, 0.8, 1.3]),
        penalties=np.array([1.0, 1.2, 1.4]),
        valuation=np.array([[0.5, 0.0, 0.0],
                            [0.0, 1.4, 0.0],
                            [0.0, 0.0, 3.0]], dtype=float),
        cost=0.7,
        epsilon=1e-3,
        resolution=250,
        objective="utility",
        fname="Figure1_Heatmap_NA.png",
    )
    _ = plot_heatmap_na(**heatmap_settings)

    # Figure 2 — NA Social Welfare vs pay(1) (with p0,p1,p2 overlay)
    line_settings = dict(
        output_fname="Figure2_NA_Welfare_vs_pay1.png",
        n_points=120,
        fixed_pay0=1.0, fixed_pay2=3.0,
        penalty_margin=0.5,
        valuation=np.array([[0.99, 0.90, 0.50],
                            [0.00, 1.50, 1.40],
                            [0.00, 0.00, 4.00]], dtype=float),
        base_cost=1.0,
        epsilon=1e-2,
        prior=np.array([0.4, 0.3, 0.3], dtype=float),
    )
    _ = plot_line_na_welfare_vs_pay1(**line_settings)

    print("Saved:\n - Figure1_Heatmap_NA.png\n - Figure2_NA_Welfare_vs_pay1.png")


Saved:
 - Figure1_Heatmap_NA.png
 - Figure2_NA_Welfare_vs_pay1.png


In [None]:
# -*- coding: utf-8 -*-
"""
Figure 2 (split) — YOUR settings that show the non‑monotone segment:
  pay(0)=1.0, pay(2)=3.0, penalties = payments + 0.5,
  cost λ = 1.0, epsilon = 1e-2, prior q = (0.4, 0.3, 0.3),
  welfare‑optimal non‑adaptive policy at each pay(1).

Outputs:
  - Figure2a_NA_Welfare_vs_pay1.png
  - Figure2b_NA_Audits_vs_pay1.png
"""

import numpy as np
import matplotlib.pyplot as plt

# ------------------ Paper-wide style (8×6, large fonts) ------------------
# Style constants for this cell.  FIGSIZE and the first three font sizes feed
# rcParams below; ANNOT_FONTSIZE and LEGEND_FONTSIZE are also passed
# explicitly by the plotting functions.
FIGSIZE = (8, 6)
BASE_FONTSIZE = 18
LABEL_FONTSIZE = 24
LEGEND_FONTSIZE = 20
ANNOT_FONTSIZE  = 20

# Install the shared style in matplotlib's global rcParams; figures created
# afterwards use these defaults unless a call passes explicit sizes.
plt.rcParams.update({
    "figure.figsize": FIGSIZE,
    "savefig.dpi": 600,
    "font.family": "DejaVu Sans",
    "font.size": BASE_FONTSIZE,     # tick/base text
    "axes.labelsize": LABEL_FONTSIZE,
    "xtick.labelsize": BASE_FONTSIZE,
    "ytick.labelsize": BASE_FONTSIZE,
    "legend.fontsize": LEGEND_FONTSIZE,
    "axes.grid": True,              # light grid on every axes
    "grid.linestyle": "-",
    "grid.alpha": 0.25,
    "axes.spines.top": False,       # hide top/right spines
    "axes.spines.right": False,
})

# =========================================================
# ============= NA critical-policy constructors ===========
# =========================================================
def rho(payment_k: float, penalty_k: float, u: float) -> float:
    """Return ``(payment_k - u) / penalty_k`` (unclipped)."""
    numerator = payment_k - u
    return numerator / penalty_k

def _equalize(u: float, A: set[int], eps: float,
              payment: np.ndarray, penalty: np.ndarray) -> np.ndarray:
    """Build a probability vector from a target level ``u``.

    Indices before the first one with ``payment >= u`` keep probability 0;
    from there on, members of ``A`` use level ``u`` and the others ``u - eps``
    in ``rho``.  The vector is clipped to [0, 1].
    """
    pay = np.asarray(payment, dtype=float)
    pen = np.asarray(penalty, dtype=float)
    n_buckets = len(pay)

    reachable = np.flatnonzero(pay >= u)
    first = int(reachable.min()) if reachable.size else n_buckets

    audit = np.zeros(n_buckets, dtype=float)
    for idx in range(first, n_buckets):
        target = u if idx in A else u - eps
        audit[idx] = rho(pay[idx], pen[idx], target)
    return np.clip(audit, 0.0, 1.0)

def equalize_plus(i: int, k: int, eps: float,
                  payment: np.ndarray, penalty: np.ndarray) -> np.ndarray:
    """'+' candidate for (i, k): level ``payment[i-1] + eps`` (``eps`` if i == 0), tie on k."""
    if i > 0:
        level = payment[i-1] + eps
    else:
        level = eps
    tied = {k}
    return _equalize(level, tied, eps, payment, penalty)

def equalize_minus(i: int, k: int, eps: float,
                   payment: np.ndarray, penalty: np.ndarray) -> np.ndarray:
    """'-' candidate for (i, k): level ``payment[i] - eps``, tie on k."""
    tied = {k}
    level = payment[i] - eps
    return _equalize(level, tied, eps, payment, penalty)

# =========================================================
# ========= Deterministic (i,k) evaluation (NA) ===========
# =========================================================
# report(j) = k  if j < i;  report(j) = j  if j >= i
def principal_utility_ik(q, p, cost, i_cut, k_pool, payments, penalties, valuation):
    """Weighted principal utility under the (i_cut, k_pool) report map.

    Types below ``i_cut`` report ``k_pool``; the rest report their own index.
    Own-index reports contribute ``valuation[j, j] - payments[j] - p[j]*cost``;
    pooled reports contribute
    ``valuation[j, r] - payments[r] + p[r]*(penalties[r] - cost)``.
    """
    expected = 0.0
    for agent, weight in enumerate(q):
        if agent >= i_cut or agent == k_pool:
            # Report coincides with the type's own index.
            payoff = valuation[agent, agent] - payments[agent] - p[agent]*cost
        else:
            # Type is pooled onto k_pool.
            payoff = (valuation[agent, k_pool] - payments[k_pool]
                      + p[k_pool]*(penalties[k_pool] - cost))
        expected += weight * payoff
    return float(expected)

def welfare_ik(q, p, cost, i_cut, k_pool, payments, penalties, valuation):
    """Weighted welfare under the (i_cut, k_pool) report map.

    Each type ``j`` with report ``r`` contributes ``valuation[j, r] - p[r]*cost``.
    ``payments`` and ``penalties`` are part of the signature but unused here.
    """
    total = 0.0
    for agent, weight in enumerate(q):
        reported = agent if agent >= i_cut else k_pool
        surplus = valuation[agent, reported] - p[reported]*cost
        total += weight * surplus
    return float(total)

def enumerate_na_candidates(eps: float, payment: np.ndarray, penalty: np.ndarray):
    """Generate every candidate as ``(p, i, k, (i, k, sign))`` for 0 <= i <= k < m.

    For each pair the '+' candidate comes first, then the '-' candidate.
    """
    count = len(payment)
    builders = (('+', equalize_plus), ('-', equalize_minus))
    for i in range(count):
        for k in range(i, count):
            for sign, build in builders:
                yield build(i, k, eps, payment, penalty), i, k, (i, k, sign)

def best_policy_welfare_NA(q: np.ndarray, val: np.ndarray,
                           payment: np.ndarray, penalty: np.ndarray,
                           cost: float, eps: float):
    """Pick the candidate with the highest welfare (scored by ``welfare_ik``).

    Returns ``(welfare, label, p)``; ties keep the earliest candidate and an
    empty scan yields ``(-inf, None, None)``.
    """
    champion = (-np.inf, None, None)
    for audit, i_cut, k_pool, label in enumerate_na_candidates(eps, payment, penalty):
        score = welfare_ik(q, audit, cost, i_cut, k_pool, payment, penalty, val)
        if score > champion[0]:
            champion = (score, label, audit)
    return champion

def deterministic_mapping_tuple(m: int, i_cut: int, k_pool: int):
    """Report tuple of length m: ``k_pool`` for the first ``i_cut`` slots, own index after."""
    pooled = max(0, min(i_cut, m))
    return (k_pool,) * pooled + tuple(range(pooled, m))

def _safe_epsilon(eps, payments):
    """Clamp ``eps`` below half the smallest payment gap, never under 1e-12."""
    floor = 1e-12
    gaps = np.abs(np.diff(np.asarray(payments, dtype=float)))
    if gaps.size == 0:
        # Fewer than two payments: only the lower bound applies.
        return float(max(floor, eps))
    cap = 0.5 * float(gaps.min()) - floor
    if not cap > 0:
        cap = floor
    return float(max(floor, min(float(eps), cap)))

# =========================================================
# ============== Shared sweep for Figure 2 =================
# =========================================================
def compute_na_vs_pay1_sequence(
    n_points=200,
    fixed_pay0=1.0, fixed_pay2=3.0,         # YOUR setting
    penalty_margin=0.5,                      # penalties = payments + 0.5
    valuation=np.array([[0.99,0.90,0.50],
                        [0.00,1.50,1.40],
                        [0.00,0.00,4.00]], dtype=float),
    base_cost=1.0,                           # λ
    epsilon=1e-2,
    prior=np.array([0.4, 0.3, 0.3], dtype=float)  # YOUR prior
):
    """Sweep pay(1) and solve for the welfare-best candidate policy at each value.

    pay(1) takes ``n_points`` evenly spaced values between
    ``fixed_pay0 + delta`` and ``fixed_pay2 - delta`` with
    ``delta = 2*epsilon + 1e-9``.  Penalties are the payments plus
    ``penalty_margin``.

    Returns ``(x, welfare, (p0, p1, p2), boundaries)`` where ``x`` is the sweep
    grid, ``welfare`` and the three ``p`` arrays hold the per-point results,
    and ``boundaries`` lists ``(x_value, previous_mapping, new_mapping)`` for
    every sweep point whose report mapping differs from the preceding point.
    """
    # Keep pay(1) strictly inside (pay0, pay2) with an epsilon-sized buffer.
    buffer = 2*epsilon + 1e-9
    sweep = np.linspace(fixed_pay0 + buffer, fixed_pay2 - buffer, n_points)

    welfare = []
    audit_series = ([], [], [])   # one list per audit probability p0, p1, p2
    report_maps = []
    for pay1 in sweep:
        payments = np.array([fixed_pay0, pay1, fixed_pay2], dtype=float)
        penalties = payments + penalty_margin
        eps_eff = _safe_epsilon(epsilon, payments)

        best_w, label, best_p = best_policy_welfare_NA(
            prior, valuation, payments, penalties, base_cost, eps_eff)
        welfare.append(best_w)
        for bucket, series in enumerate(audit_series):
            series.append(float(best_p[bucket]))
        i_cut, k_pool, _ = label
        report_maps.append(deterministic_mapping_tuple(3, i_cut, k_pool))

    # A boundary is any sweep point whose mapping differs from its predecessor.
    boundaries = [
        (sweep[t], report_maps[t-1], report_maps[t])
        for t in range(1, len(sweep))
        if report_maps[t] != report_maps[t-1]
    ]

    audits = tuple(np.array(series) for series in audit_series)
    return sweep, np.array(welfare), audits, boundaries

# =========================================================
# ============= 2a — Welfare vs pay(1) ====================
# =========================================================
def plot_welfare_vs_pay1(output_fname="Figure2a_NA_Welfare_vs_pay1.png", **kwargs):
    """Save the welfare-vs-pay(1) panel and return its file name.

    ``kwargs`` are forwarded unchanged to ``compute_na_vs_pay1_sequence``.
    Each mapping change found by the sweep is marked with a dotted vertical
    line and a "old -> new" label near the top of the axes.
    """
    pay1_grid, welfare, _, boundaries = compute_na_vs_pay1_sequence(**kwargs)

    fig, ax = plt.subplots()
    ax.plot(pay1_grid, welfare, label="Welfare",
            linestyle=(0, (5, 1.8)), marker="x", linewidth=2, markersize=6)
    ax.set_xlabel("pay(1)")
    ax.set_ylabel("Optimal Social Welfare")

    # Label position: just under the top edge, nudged right of each line.
    y_low, y_high = ax.get_ylim()
    y_span = y_high - y_low
    x_span = pay1_grid[-1] - pay1_grid[0]
    for x_switch, prev_map, new_map in boundaries:
        ax.axvline(x=x_switch, linestyle=":", linewidth=1.2, color="0.4")
        ax.text(x_switch + 0.02*x_span, y_high - 0.02*y_span,
                f"{prev_map} \N{RIGHTWARDS ARROW} {new_map}",
                ha="left", va="top", fontsize=ANNOT_FONTSIZE)

    ax.legend(loc="upper right", framealpha=0.9, fontsize=LEGEND_FONTSIZE)
    plt.tight_layout()
    plt.savefig(output_fname, bbox_inches="tight")
    plt.close(fig)
    return output_fname

# =========================================================
# ===== 2b — Audit probabilities vs pay(1) ================
# =========================================================
def plot_audit_probs_vs_pay1(output_fname="Figure2b_NA_Audits_vs_pay1.png", **kwargs):
    """Save the audit-probability-vs-pay(1) panel and return its file name.

    ``kwargs`` are forwarded unchanged to ``compute_na_vs_pay1_sequence``.
    The three probability curves share one axes fixed to [-0.02, 1.02];
    mapping changes are marked as in the welfare panel.
    """
    pay1_grid, _, audits, boundaries = compute_na_vs_pay1_sequence(**kwargs)

    fig, ax = plt.subplots()
    curve_specs = (("p0", "--"), ("p1", "dashdot"), ("p2", "dotted"))
    for series, (name, dash) in zip(audits, curve_specs):
        ax.plot(pay1_grid, series, label=name, linestyle=dash, linewidth=2)
    ax.set_xlabel("pay(1)")
    ax.set_ylabel("Audit Probability")
    ax.set_ylim(-0.02, 1.02)

    # Label position: just under the top edge, nudged right of each line.
    y_low, y_high = ax.get_ylim()
    y_span = y_high - y_low
    x_span = pay1_grid[-1] - pay1_grid[0]
    for x_switch, prev_map, new_map in boundaries:
        ax.axvline(x=x_switch, linestyle=":", linewidth=1.2, color="0.4")
        ax.text(x_switch + 0.02*x_span, y_high - 0.02*y_span,
                f"{prev_map} \N{RIGHTWARDS ARROW} {new_map}",
                ha="left", va="top", fontsize=ANNOT_FONTSIZE)

    legend_style = dict(loc="upper right", framealpha=0.9, borderpad=0.25,
                        labelspacing=0.4, handlelength=2.2, handletextpad=0.8,
                        fontsize=LEGEND_FONTSIZE)
    ax.legend(**legend_style)
    plt.tight_layout()
    plt.savefig(output_fname, bbox_inches="tight")
    plt.close(fig)
    return output_fname

# ============================== Run ==============================
if __name__ == "__main__":
    # Both panels are generated from the same sweep settings.
    common_kwargs = {
        "n_points": 200,
        "fixed_pay0": 1.0,
        "fixed_pay2": 3.0,
        "penalty_margin": 0.5,
        "base_cost": 1.0,
        "epsilon": 1e-2,
        "prior": np.array([0.4, 0.3, 0.3], dtype=float),
    }
    panels = (
        (plot_welfare_vs_pay1, "Figure2a_NA_Welfare_vs_pay1.png"),
        (plot_audit_probs_vs_pay1, "Figure2b_NA_Audits_vs_pay1.png"),
    )
    for make_panel, target in panels:
        _ = make_panel(target, **common_kwargs)
    print("Saved:\n - Figure2a_NA_Welfare_vs_pay1.png\n - Figure2b_NA_Audits_vs_pay1.png")
