In [2]:
import matplotlib.pyplot as plt
import matplotlib as mpl
import matplotlib.patches as mpatches
import numpy as np
import pandas as pd
import cvxpy as cp
from scipy.stats import poisson, uniform, expon, pareto, hypergeom
from scipy.optimize import minimize, fsolve, least_squares
import scipy as sc
from tqdm import tqdm
from mdptoolbox import mdp, util
import itertools
from scipy.sparse import csr_matrix, lil_matrix
from matplotlib.patches import Patch
import math
import random
import sympy as sp
from sympy.printing.latex import print_latex

plt.rcParams['axes.labelsize'] = 14
plt.rcParams['axes.titlesize'] = 16

In [3]:
import itertools
import numpy as np
from dataclasses import dataclass
from typing import Dict, Tuple, List, Optional

from scipy.optimize import minimize, Bounds, NonlinearConstraint


def all_bitstrings(n: int) -> np.ndarray:
    """Return array of shape (2^n, n) with rows in {0,1}^n."""
    return np.array(list(itertools.product([0, 1], repeat=n)), dtype=int)


def branch_list(n: int, h: int):
    """All branches (A, aA) where |A|=n-h and aA is assignment on A."""
    k = n - h
    subsets_A = list(itertools.combinations(range(n), k))
    branches = []
    for A in subsets_A:
        for aA in itertools.product([0, 1], repeat=k):
            branches.append((tuple(A), tuple(aA)))
    return branches


@dataclass
class SciPySolveResult:
    success: bool
    message: str
    t: float
    s: np.ndarray                 # shape (n,)
    f: np.ndarray                 # shape (n, 2^n)  f[i, idx(a)]
    fun: float                    # objective
    nit: int


def solve_full_program_scipy(
    n: int,
    h: int,
    C: float,
    n_starts: int = 10,
    seed: int = 0,
    maxiter: int = 2000,
    verbose: bool = False,
) -> SciPySolveResult:
    """
    Solve the full program with SciPy (trust-constr) using multi-start.

    Variables:
      s in [0,1]^n
      f_i(a) >= 0 for i in [n], a in {0,1}^n
      t free (but we bound t>=0 since objective is a max of nonneg-ish terms)

    Constraints (implemented):
      - Eq constraints enforced as equality for all i:
            E_{a_-i}[ f_i(1,a_-i) ] - 1  =  E_{a_-i}[ f_i(0,a_-i) ]
        where expectation uses product distribution over coordinates != i.
      - IR:
            sum_i s_i <= sum_a pi_s(a) * f(a)
      - Branch constraints:
            For each (A,aA),  E_{a_-A}[ f(aA,a_-A) + C*1[a=0] ] <= t
    """

    assert 1 <= n
    assert 0 <= h <= n
    assert C >= 0

    bits = all_bitstrings(n)          # (m, n), m=2^n
    m = bits.shape[0]
    branches = branch_list(n, h)

    # Precompute useful masks
    all_zero_mask = (bits.sum(axis=1) == 0).astype(float)  # (m,)

    # Variable packing: x = [s (n), f (n*m), t (1)]
    def unpack(x: np.ndarray):
        s = x[:n]
        f_flat = x[n:n + n * m]
        f = f_flat.reshape((n, m))
        t = x[-1]
        return s, f, t

    def pack(s: np.ndarray, f: np.ndarray, t: float):
        return np.concatenate([s.ravel(), f.ravel(), np.array([t], dtype=float)])

    # Probability helpers
    def prob_full(s: np.ndarray) -> np.ndarray:
        # pi_s(a) = ∏_j s_j^{a_j} (1-s_j)^{1-a_j}
        # compute vectorized: for each a row, product over j
        s = np.clip(s, 0.0, 1.0)
        term = np.where(bits == 1, s[None, :], (1.0 - s)[None, :])  # (m,n)
        return term.prod(axis=1)  # (m,)

    def prob_over_complement(s: np.ndarray, fixed_idx: Tuple[int, ...]) -> np.ndarray:
        # product over j not in fixed_idx
        s = np.clip(s, 0.0, 1.0)
        fixed = np.zeros(n, dtype=bool)
        fixed[list(fixed_idx)] = True
        term = np.where(bits == 1, s[None, :], (1.0 - s)[None, :])  # (m,n)
        if fixed.all():
            return np.ones(m, dtype=float)
        return term[:, ~fixed].prod(axis=1)  # (m,)

    # Objective: minimize t
    def objective(x: np.ndarray) -> float:
        return x[-1]

    # Constraints
    def eq_constraints(x: np.ndarray) -> np.ndarray:
        # size n: for each i, E[f_i | a_i=1] - 1 - E[f_i | a_i=0] = 0
        s, f, _t = unpack(x)
        out = np.zeros(n, dtype=float)
        for i in range(n):
            w = prob_over_complement(s, (i,))  # (m,)
            mask1 = (bits[:, i] == 1)
            mask0 = ~mask1
            E1 = np.dot(w[mask1], f[i, mask1])
            E0 = np.dot(w[mask0], f[i, mask0])
            out[i] = (E1 - 1.0) - E0
        return out

    def ir_constraint(x: np.ndarray) -> float:
        # sum_i s_i - sum_a pi(a)*f(a) <= 0
        s, f, _t = unpack(x)
        pi = prob_full(s)
        fsum = f.sum(axis=0)  # (m,)
        rhs = np.dot(pi, fsum)
        return s.sum() - rhs

    def branch_constraints(x: np.ndarray) -> np.ndarray:
        # for each branch b: E_{a_-A}[ f(a) + C*1[a=0] ] - t <= 0
        s, f, t = unpack(x)
        out = np.zeros(len(branches), dtype=float)

        fsum = f.sum(axis=0)  # (m,)

        for b_idx, (A, aA) in enumerate(branches):
            # build mask for rows matching assignment on A
            mask = np.ones(m, dtype=bool)
            for pos, j in enumerate(A):
                mask &= (bits[:, j] == aA[pos])

            w = prob_over_complement(s, A)  # (m,)
            expr = np.dot(w[mask], fsum[mask] + C * all_zero_mask[mask])
            out[b_idx] = expr - t
        return out

    # SciPy constraint objects:
    eq_con = NonlinearConstraint(eq_constraints, 0.0, 0.0)          # equality
    ir_con = NonlinearConstraint(ir_constraint, -np.inf, 0.0)       # <= 0
    br_con = NonlinearConstraint(branch_constraints, -np.inf, 0.0)  # <= 0

    # Bounds:
    # s in [0,1], f >= 0, t >= 0
    lb = np.zeros(n + n * m + 1, dtype=float)
    ub = np.full(n + n * m + 1, np.inf, dtype=float)
    ub[:n] = 1.0  # s upper bounds
    # f already has lb=0 and ub=inf
    # t already has lb=0 and ub=inf
    bounds = Bounds(lb, ub)

    rng = np.random.default_rng(seed)
    best_res = None

    for start in range(n_starts):
        # init s in (0,1), f positive, t something
        s0 = 0.1 + 0.8 * rng.random(n)
        f0 = np.full((n, m), 0.1, dtype=float)
        t0 = 1.0
        x0 = pack(s0, f0, t0)

        res = minimize(
            objective,
            x0,
            method="trust-constr",
            bounds=bounds,
            constraints=[eq_con, ir_con, br_con],
            options={
                "maxiter": maxiter,
                "verbose": 3 if (verbose and start == 0) else 0,
            },
        )

        if best_res is None or (res.success and res.fun < best_res.fun):
            best_res = res

    if best_res is None:
        raise RuntimeError("No result returned by SciPy (unexpected).")

    s_sol, f_sol, t_sol = unpack(best_res.x)

    return SciPySolveResult(
        success=bool(best_res.success),
        message=str(best_res.message),
        t=float(t_sol),
        s=s_sol.astype(float),
        f=f_sol.astype(float),
        fun=float(best_res.fun),
        nit=int(getattr(best_res, "nit", -1)),
    )

In [7]:
%%time
sol = solve_full_program_scipy(n=4, h=1, C=3.0, n_starts=5, seed=1, verbose=False)
print("success:", sol.success)
print("message:", sol.message)
print("t* =", sol.t)
print("s* =", sol.s)

print("f_1(000) =", sol.f[0, 0])

success: True
message: `gtol` termination condition is satisfied.
t* = 2.0000029661913885
s* = [0.33333293 0.33333304 0.99999981 0.33333299]
f_1(000) = 1.7706752292408065e-07
CPU times: user 1min 59s, sys: 9min 24s, total: 11min 24s
Wall time: 1min 8s


In [17]:
%%time
out = minimize_over_s(n=6, h=4, C=8.0, n_random=300, n_local_starts=8, seed=1, local_method="Powell")
print("success:", out.success)
print("t* =", out.t)
print("s* =", out.s)

success: True
t* = 2.3885382809802693
s* = [0.18550826 0.40629624 0.4039483  1.         0.37186237 0.02092311]
CPU times: user 23.9 s, sys: 55.1 ms, total: 24 s
Wall time: 24 s


In [21]:
%%time
n = 6
h = 4
C = 8.0

solver = FullProgramLPSolver(n=n, h=h, C=C)

interior = minimize_over_s_sorted(
    solver,
    n_random=300,
    n_local_starts=8,
    seed=1,
    local_method="Powell",
)

faces = minimize_over_faces(
    solver,
    seed=1,
    local_method="Powell",
    n_local_starts_per_face=3,
    include_endpoints=True,
    max_faces=None,   # set to e.g. 200 if you want a quick partial scan
)

print("=== Sorted 'interior-ish' search (always sorts s decreasing) ===")
print("ok:", interior.ok)
print("t:", interior.t)
print("s:", interior.s)

print("\n=== Best face over 1>=s1>=...>=sn>=0 (tight adjacencies + endpoints) ===")
print("ok:", faces.ok)
print("t:", faces.t)
print("s:", faces.s)
print("face (tight_adj, top_fixed, bottom_fixed):", faces.face)
if faces.face is not None:
    tight_adj, top_fixed, bottom_fixed = faces.face
    print("  tight_adj indices (0-based):", tight_adj)
    print("  meaning: s_i = s_{i+1} for i in tight_adj")
    print("  top_fixed (s1=1):", top_fixed)
    print("  bottom_fixed (sn=0):", bottom_fixed)


=== Sorted 'interior-ish' search (always sorts s decreasing) ===
ok: True
t: 2.3599996845764584
s: [1.         0.3107921  0.29620201 0.28817095 0.25090803 0.21392658]

=== Best face over 1>=s1>=...>=sn>=0 (tight adjacencies + endpoints) ===
ok: True
t: 2.3245553182504723
s: [0.70943058 0.70943058 0.70943058 0.         0.         0.        ]
face (tight_adj, top_fixed, bottom_fixed): ((0, 1, 3, 4), False, True)
  tight_adj indices (0-based): (0, 1, 3, 4)
  meaning: s_i = s_{i+1} for i in tight_adj
  top_fixed (s1=1): False
  bottom_fixed (sn=0): True
CPU times: user 5min 33s, sys: 549 ms, total: 5min 34s
Wall time: 5min 35s


In [26]:
%%time
run_grid()

n= 2, h= 1, C= 3.0
  face:     t= 1.5, s=[1.000000 0.500000]
  interior: t= 1.5, s=[1.000000 0.500000]
n= 3, h= 1, C= 3.0
  face:     t= 1.8, s=[1.000000 0.400000 0.400000]
  interior: t= 1.80001, s=[0.999853 0.400006 0.399996]
n= 3, h= 2, C= 3.0
  face:     t= 1.5, s=[1.000000 0.500000 0.000000]
  interior: t= 1.50053, s=[1.000000 0.499356 0.001170]
n= 4, h= 1, C= 3.0
  face:     t= 2, s=[1.000000 0.333333 0.333333 0.333333]
  interior: t= 2.00001, s=[1.000000 0.333352 0.333331 0.333329]
n= 4, h= 2, C= 3.0
  face:     t= 1.72508, s=[1.000000 0.241694 0.241694 0.241694]
  interior: t= 1.77783, s=[1.000000 0.333455 0.333455 0.110922]
n= 4, h= 3, C= 3.0
  face:     t= 1.5, s=[1.000000 0.500000 0.000000 0.000000]
  interior: t= 1.5, s=[1.000000e+00 5.000000e-01 4.248354e-18 4.248354e-18]
n= 5, h= 1, C= 3.0
  face:     t= 2.14286, s=[1.000000 0.285714 0.285714 0.285714 0.285714]
  interior: t= 2.16192, s=[1.000000 0.322750 0.279511 0.279360 0.279360]
n= 5, h= 2, C= 3.0
  face:     t= 1.854

KeyboardInterrupt: 

In [28]:
%%time
run_grid(Ns = [2, 3, 4, 5], Cs = [7.0, 11.0, 15.0])

n= 2, h= 1, C= 7.0
  face:     t= 1.75, s=[1.000000 0.750000]
  interior: t= 1.75, s=[1.000000 0.750000]
n= 3, h= 1, C= 7.0
  face:     t= 2.26628, s=[0.676246 0.676246 0.676246]
  interior: t= 2.2668, s=[0.677773 0.676743 0.676171]
n= 3, h= 2, C= 7.0
  face:     t= 1.75, s=[1.000000 0.750000 0.000000]
  interior: t= 1.75004, s=[1.000000e+00 7.499926e-01 4.388965e-05]
n= 4, h= 1, C= 7.0
  face:     t= 2.63488, s=[0.623589 0.623589 0.623589 0.623589]
  interior: t= 2.67581, s=[0.672342 0.635874 0.617760 0.617741]
n= 4, h= 2, C= 7.0
  face:     t= 2.26628, s=[0.676246 0.676246 0.676246 0.000000]
  interior: t= 2.32923, s=[0.655447 0.616391 0.615949 0.133584]
n= 4, h= 3, C= 7.0
  face:     t= 1.75, s=[1.000000 0.750000 0.000000 0.000000]
  interior: t= 1.88536, s=[1.000000 0.647167 0.231520 0.006670]
n= 5, h= 1, C= 7.0
  face:     t= 2.97311, s=[0.575271 0.575271 0.575271 0.575271 0.575271]
  interior: t= 3.03858, s=[0.608803 0.595602 0.593073 0.565917 0.565917]
n= 5, h= 2, C= 7.0
  face:

In [31]:
%%time
run_grid(Ns = [6], Cs = [7.0, 11.0, 15.0])

n= 6, h= 1, C= 7.0
  face:     t= 3.26993, s=[0.532867 0.532867 0.532867 0.532867 0.532867 0.532867]
  interior: t= 3.75315, s=[0.996676 0.629296 0.582244 0.472825 0.463836 0.463836]
n= 6, h= 2, C= 7.0
  face:     t= 2.82413, s=[1.000000 0.364825 0.364825 0.364825 0.364825 0.364825]
  interior: t= 3.00814, s=[0.550214 0.550214 0.550214 0.442343 0.424130 0.253764]
n= 6, h= 3, C= 7.0
  face:     t= 2.46783, s=[1.000000 0.293565 0.293565 0.293565 0.293565 0.293565]
  interior: t= 2.67365, s=[0.525013 0.525013 0.525013 0.525013 0.107660 0.107660]


KeyboardInterrupt: 

In [49]:
%%time
run_grid(Ns = [6], Cs = [7.0, 11.0, 15.0])

n= 6, h= 1, C= 7.0
  face:     t= 3.27024, s=[0.532799 0.532799 0.532799 0.532799 0.532799 0.532799]
  interior: t= 3.98865, s=[0.999654 0.798359 0.736484 0.554120 0.469698 0.430155]


KeyboardInterrupt: 

In [30]:
def run_grid(Ns = [2, 3, 4, 5, 6, 7, 8], Cs = [3.0, 7.0, 11.0, 15.0]):

    # Tuning knobs (increase for better quality, decrease for speed)
    interior_n_random = 100
    interior_n_local_starts = 3

    face_local_starts_per_face = 1   # keep small; full face scan is already a lot
    face_include_endpoints = True
    face_max_faces = None            # full enumeration; set e.g. 400 to cap runtime
    face_maxiter = 160               # local optimizer effort per face-start
    interior_maxiter = 250

    # Monkey-patch maxiter controls (your functions call minimize(..., options={...}))
    # If you want cleaner control, just edit the maxiter numbers inside those functions.
    import types
    from scipy.optimize import minimize as _minimize

    # Wrap scipy.minimize so we can override maxiter consistently (optional).
    def minimize_with_cap(fun, x0, method, options):
        opts = dict(options or {})
        if method == "Powell":
            # apply caps
            opts["maxiter"] = min(opts.get("maxiter", interior_maxiter), interior_maxiter)
        else:
            opts["maxiter"] = min(opts.get("maxiter", interior_maxiter), interior_maxiter)
        return _minimize(fun, x0, method=method, options=opts)

    # If you DON'T want this wrapper trick, delete it and just edit the maxiter values
    # inside minimize_over_s_sorted / minimize_over_faces.

    # Run
    for C in Cs:
        for n in Ns:
            for h in range(1, n):  # all h up to n-1
                solver = FullProgramLPSolver(n=n, h=h, C=C)

                # Best "interior-ish" (sorted)
                interior = minimize_over_s_sorted(
                    solver,
                    n_random=interior_n_random,
                    n_local_starts=interior_n_local_starts,
                    seed=1,
                    local_method="Powell",
                )

                # Best face on chain polytope faces
                faces = minimize_over_faces(
                    solver,
                    seed=1,
                    local_method="Powell",
                    n_local_starts_per_face=face_local_starts_per_face,
                    include_endpoints=face_include_endpoints,
                    max_faces=face_max_faces,
                )

                # Print summary line
                print(
                    f"n={n:>2}, h={h:>2}, C={C:>4}\n"
                    f"  face:     t={faces.t: .6g}, "
                    f"s={np.array2string(faces.s, precision=6, floatmode='fixed')}\n"
                    f"  interior: t={interior.t: .6g}, "
                    f"s={np.array2string(interior.s, precision=6, floatmode='fixed')}"
                )


In [20]:
import itertools
import numpy as np
from dataclasses import dataclass
from typing import Tuple, List, Optional

from scipy.optimize import linprog, minimize


# --------------------------
# Utilities
# --------------------------

def all_bitstrings(n: int) -> np.ndarray:
    """Array shape (2^n, n) containing all bitstrings in lexicographic order."""
    return np.array(list(itertools.product([0, 1], repeat=n)), dtype=int)


def all_branches(n: int, h: int) -> List[Tuple[Tuple[int, ...], Tuple[int, ...]]]:
    """All branches (A, aA) where |A|=n-h and aA is an assignment on A."""
    k = n - h
    branches = []
    for A in itertools.combinations(range(n), k):
        for aA in itertools.product([0, 1], repeat=k):
            branches.append((tuple(A), tuple(aA)))
    return branches


def sigmoid(z):
    z = np.clip(z, -40, 40)
    return 1.0 / (1.0 + np.exp(-z))


def sort_desc(s: np.ndarray) -> np.ndarray:
    return np.sort(np.asarray(s, dtype=float))[::-1]


def linprog_any(c, A_ub, b_ub, A_eq, b_eq, bounds):
    """
    Try modern HiGHS methods; fall back to older ones if SciPy doesn't have HiGHS.
    """
    methods = ("highs", "highs-ds", "highs-ipm", "interior-point", "revised simplex", "simplex")
    last_err = None
    for method in methods:
        try:
            return linprog(
                c=c, A_ub=A_ub, b_ub=b_ub,
                A_eq=A_eq, b_eq=b_eq,
                bounds=bounds,
                method=method
            )
        except ValueError as e:
            last_err = e
            continue
    if last_err is not None:
        raise last_err
    raise RuntimeError("linprog failed unexpectedly.")


# --------------------------
# Inner LP solver
# --------------------------

@dataclass
class LPValue:
    ok: bool
    t: float
    f: Optional[np.ndarray]  # shape (n, 2^n)
    msg: str


class FullProgramLPSolver:
    """
    For fixed s, solve the inner LP in (f,t).

    Variables: f_{i,a} >= 0 for i in [n], a in {0,1}^n, and t >= 0.
    Objective: minimize t.

    Constraints implemented (linear for fixed s):
      - Eq (interior equality form) for each i:
          E_{a_-i}[ f_i(1,a_-i) ] - E_{a_-i}[ f_i(0,a_-i) ] = 1
        where expectation uses product distribution over coordinates != i.
      - IR:
          sum_i s_i <= sum_a pi_s(a) * sum_i f_i(a)
      - Branch epigraph constraints for each (A,aA):
          E_{a_-A}[ sum_i f_i(aA,a_-A) + C*1[a=0] ] <= t
    """

    def __init__(self, n: int, h: int, C: float):
        assert 1 <= n
        assert 0 <= h <= n
        assert C >= 0

        self.n = n
        self.h = h
        self.C = float(C)

        self.bits = all_bitstrings(n)  # (m,n)
        self.m = self.bits.shape[0]

        self.branches = all_branches(n, h)
        self.num_branches = len(self.branches)

        self.all_zero = (self.bits.sum(axis=1) == 0).astype(float)  # (m,)

        # For each branch, store indices of a that match the fixed assignment on A
        self.branch_match_idx = []
        for (A, aA) in self.branches:
            mask = np.ones(self.m, dtype=bool)
            for pos, j in enumerate(A):
                mask &= (self.bits[:, j] == aA[pos])
            self.branch_match_idx.append(np.where(mask)[0])

    def _pi_full(self, s: np.ndarray) -> np.ndarray:
        # pi_s(a) = ∏_j s_j^{a_j} (1-s_j)^{1-a_j}
        term = np.where(self.bits == 1, s[None, :], (1.0 - s)[None, :])
        return term.prod(axis=1)  # (m,)

    def _pi_over_complement(self, s: np.ndarray, fixed: Tuple[int, ...]) -> np.ndarray:
        # product over j not in fixed
        fixed_mask = np.zeros(self.n, dtype=bool)
        fixed_mask[list(fixed)] = True
        if fixed_mask.all():
            return np.ones(self.m, dtype=float)
        term = np.where(self.bits == 1, s[None, :], (1.0 - s)[None, :])
        return term[:, ~fixed_mask].prod(axis=1)

    def solve_lp(self, s: np.ndarray, return_f: bool = False) -> LPValue:
        n, m, C = self.n, self.m, self.C
        s = np.clip(np.asarray(s, dtype=float), 0.0, 1.0)

        # x = [f_{0,0..m-1}, f_{1,0..m-1}, ..., f_{n-1,0..m-1}, t]
        N = n * m + 1
        t_idx = n * m

        # Objective min t
        c = np.zeros(N, dtype=float)
        c[t_idx] = 1.0

        # Bounds: f >= 0, t >= 0
        bounds = [(0.0, None)] * (n * m) + [(0.0, None)]

        # ----------------------
        # Equality constraints (Eq): A_eq x = b_eq
        A_eq = np.zeros((n, N), dtype=float)
        b_eq = np.ones(n, dtype=float)  # E1 - E0 = 1

        for i in range(n):
            w = self._pi_over_complement(s, (i,))  # (m,)
            mask1 = (self.bits[:, i] == 1)
            mask0 = ~mask1

            base = i * m
            A_eq[i, base + np.where(mask1)[0]] = w[mask1]
            A_eq[i, base + np.where(mask0)[0]] = -w[mask0]

        # ----------------------
        # Inequality constraints: A_ub x <= b_ub
        A_ub_rows = []
        b_ub_rows = []

        # IR: sum s_i <= sum_a pi(a) * sum_i f_i(a)
        # => - sum_a pi(a)*sum_i f_i(a) <= -sum s_i
        pi = self._pi_full(s)
        row_ir = np.zeros(N, dtype=float)
        for i in range(n):
            row_ir[i * m:(i + 1) * m] = -pi
        A_ub_rows.append(row_ir)
        b_ub_rows.append(-float(s.sum()))

        # Branch constraints: for each branch, expectation <= t
        for b, (A, _aA) in enumerate(self.branches):
            idx = self.branch_match_idx[b]
            w = self._pi_over_complement(s, A)  # (m,)

            row = np.zeros(N, dtype=float)
            for i in range(n):
                base = i * m
                row[base + idx] += w[idx]
            row[t_idx] = -1.0

            # Move constant C term to RHS:
            # sum w*(fsum + C*I0) <= t
            # sum w*fsum - t <= - sum w*(C*I0)
            rhs = -np.dot(w[idx], C * self.all_zero[idx])

            A_ub_rows.append(row)
            b_ub_rows.append(float(rhs))

        A_ub = np.vstack(A_ub_rows)
        b_ub = np.array(b_ub_rows, dtype=float)

        res = linprog_any(c, A_ub, b_ub, A_eq, b_eq, bounds)

        if not res.success:
            return LPValue(ok=False, t=float("inf"), f=None, msg=str(res.message))

        t_val = float(res.fun)
        if not return_f:
            return LPValue(ok=True, t=t_val, f=None, msg="OK")

        x = res.x
        f = x[:n * m].reshape((n, m))
        return LPValue(ok=True, t=t_val, f=f, msg="OK")


# --------------------------
# Face enumeration & optimization
# --------------------------

def blocks_from_equalities(n: int, tight_adj: Tuple[int, ...]) -> List[List[int]]:
    """
    tight_adj are 0-based adjacency indices: i in {0,...,n-2} means s_i = s_{i+1}.
    Returns contiguous blocks of equal coordinates.
    """
    tight = set(tight_adj)
    blocks = []
    cur = [0]
    for i in range(n - 1):
        if i in tight:
            cur.append(i + 1)
        else:
            blocks.append(cur)
            cur = [i + 1]
    blocks.append(cur)
    return blocks


def enumerate_faces(n: int, include_endpoints: bool = True):
    """
    Faces of chain polytope 1>=s1>=...>=sn>=0 determined by:
      - which adjacent inequalities are tight (s_i=s_{i+1})
      - optional endpoint tightness: s1=1 and/or sn=0
    """
    faces = []
    for mask in range(1 << (n - 1)):
        tight_adj = tuple(i for i in range(n - 1) if (mask >> i) & 1)
        if include_endpoints:
            for top_fixed in (False, True):
                for bottom_fixed in (False, True):
                    faces.append((tight_adj, top_fixed, bottom_fixed))
        else:
            faces.append((tight_adj, False, False))
    return faces


def face_parameterization(z: np.ndarray, k: int, top_fixed: bool, bottom_fixed: bool) -> np.ndarray:
    """
    Build monotone block-values v0>=v1>=...>=v_{k-1} in [0,1],
    with optional v0=1 and/or v_{k-1}=0.

    Parameterization:
      if top_fixed: v0=1 else v0=sigmoid(z0)
      for j>=1: vj = v_{j-1} * sigmoid(zj)   (monotone decreasing)
      if bottom_fixed: set last block exactly 0
    """
    z = np.asarray(z, dtype=float)
    v = np.zeros(k, dtype=float)
    if k == 0:
        return v

    idx = 0
    if top_fixed:
        v[0] = 1.0
    else:
        v[0] = sigmoid(z[idx]); idx += 1

    last_free = k - (1 if bottom_fixed else 0)
    for j in range(1, last_free):
        v[j] = v[j - 1] * sigmoid(z[idx])
        idx += 1

    if bottom_fixed:
        v[-1] = 0.0

    return v


def blocks_to_full_s(n: int, blocks: List[List[int]], v: np.ndarray) -> np.ndarray:
    s = np.zeros(n, dtype=float)
    for b_idx, idxs in enumerate(blocks):
        s[idxs] = v[b_idx]
    return s


@dataclass
class SearchResult:
    ok: bool
    t: float
    s: np.ndarray
    face: Optional[Tuple[Tuple[int, ...], bool, bool]]
    msg: str


def minimize_over_s_sorted(
    solver: FullProgramLPSolver,
    n_random: int = 200,
    n_local_starts: int = 8,
    seed: int = 0,
    local_method: str = "Powell",
) -> SearchResult:
    """
    Outer search over s in [0,1]^n, ALWAYS projecting to sorted decreasing order.
    """
    n = solver.n
    rng = np.random.default_rng(seed)

    best_t = float("inf")
    best_s = None

    # Random screening
    for _ in range(n_random):
        s = sort_desc(rng.random(n))
        val = solver.solve_lp(s, return_f=False)
        if val.ok and val.t < best_t:
            best_t = val.t
            best_s = s.copy()

    if best_s is None:
        best_s = np.full(n, 0.5)

    def obj(z):
        s = sort_desc(sigmoid(z))
        val = solver.solve_lp(s, return_f=False)
        return val.t if val.ok else 1e6

    # Build start points in z-space
    def inv_sig(u):
        u = np.clip(u, 1e-6, 1 - 1e-6)
        return np.log(u / (1 - u))

    starts = [inv_sig(best_s)]
    for _ in range(max(0, n_local_starts - 1)):
        s0 = sort_desc(rng.random(n))
        starts.append(inv_sig(s0))

    best_z = None
    best_local = best_t

    for z0 in starts:
        res = minimize(obj, z0, method=local_method, options={"maxiter": 400, "disp": False})
        if float(res.fun) < best_local:
            best_local = float(res.fun)
            best_z = res.x.copy()

    s_star = sort_desc(sigmoid(best_z)) if best_z is not None else best_s
    lp_star = solver.solve_lp(s_star, return_f=False)
    return SearchResult(lp_star.ok, lp_star.t, s_star, None, lp_star.msg)


def minimize_over_faces(
    solver: FullProgramLPSolver,
    seed: int = 0,
    local_method: str = "Powell",
    n_local_starts_per_face: int = 3,
    include_endpoints: bool = True,
    max_faces: Optional[int] = None,
) -> SearchResult:
    """
    Enumerate and optimize over faces of the chain polytope 1>=s1>=...>=sn>=0.

    Face is determined by (tight_adj, top_fixed, bottom_fixed).
    """
    n = solver.n
    rng = np.random.default_rng(seed)

    faces = enumerate_faces(n, include_endpoints=include_endpoints)
    if max_faces is not None:
        faces = faces[:max_faces]

    best = SearchResult(False, float("inf"), np.zeros(n), None, "no feasible face found")

    for (tight_adj, top_fixed, bottom_fixed) in faces:
        blocks = blocks_from_equalities(n, tight_adj)
        k = len(blocks)

        # If k==1 and both endpoints fixed => impossible unless 1==0; skip
        if k == 1 and top_fixed and bottom_fixed:
            continue

        # Dimension of z:
        dim = 0
        if not top_fixed:
            dim += 1
        dim += max(0, (k - 1) - (1 if bottom_fixed else 0))

        if dim == 0:
            # fully fixed (rare); evaluate directly
            v = np.zeros(k)
            if top_fixed:
                v[0] = 1.0
            if bottom_fixed:
                v[-1] = 0.0
            s = blocks_to_full_s(n, blocks, v)
            val = solver.solve_lp(s, return_f=False)
            if val.ok and val.t < best.t:
                best = SearchResult(True, val.t, s, (tight_adj, top_fixed, bottom_fixed), "OK")
            continue

        def obj(z):
            v = face_parameterization(z, k, top_fixed, bottom_fixed)
            s = blocks_to_full_s(n, blocks, v)
            # already satisfies ordering by construction
            val = solver.solve_lp(s, return_f=False)
            return val.t if val.ok else 1e6

        for _ in range(n_local_starts_per_face):
            z0 = rng.normal(size=dim)
            res = minimize(obj, z0, method=local_method, options={"maxiter": 300, "disp": False})
            if float(res.fun) < best.t:
                v = face_parameterization(res.x, k, top_fixed, bottom_fixed)
                s = blocks_to_full_s(n, blocks, v)
                best = SearchResult(True, float(res.fun), s, (tight_adj, top_fixed, bottom_fixed), "OK")

    return best

In [56]:
import itertools
import numpy as np
from dataclasses import dataclass
from typing import Tuple, List, Optional

from scipy.optimize import linprog, minimize
from scipy.sparse import coo_matrix, csr_matrix


# --------------------------
# Utilities
# --------------------------

def all_bitstrings(n: int) -> np.ndarray:
    return np.array(list(itertools.product([0, 1], repeat=n)), dtype=int)


def all_branches(n: int, h: int) -> List[Tuple[Tuple[int, ...], Tuple[int, ...]]]:
    k = n - h
    branches = []
    for A in itertools.combinations(range(n), k):
        for aA in itertools.product([0, 1], repeat=k):
            branches.append((tuple(A), tuple(aA)))
    return branches


def sigmoid(z):
    z = np.clip(z, -40, 40)
    return 1.0 / (1.0 + np.exp(-z))


def sort_desc(s: np.ndarray) -> np.ndarray:
    return np.sort(np.asarray(s, dtype=float))[::-1]


def linprog_any(c, A_ub, b_ub, A_eq, b_eq, bounds):
    methods = ("highs", "highs-ds", "highs-ipm", "interior-point", "revised simplex", "simplex")
    last_err = None
    for method in methods:
        try:
            return linprog(
                c=c, A_ub=A_ub, b_ub=b_ub,
                A_eq=A_eq, b_eq=b_eq,
                bounds=bounds,
                method=method
            )
        except ValueError as e:
            last_err = e
            continue
    if last_err is not None:
        raise last_err
    raise RuntimeError("linprog failed unexpectedly.")


def cache_key_from_s(s: np.ndarray, *, tol: float = 1e-4) -> tuple:
    s = np.clip(np.asarray(s, dtype=float), 0.0, 1.0)
    q = np.round(s / tol).astype(np.int64)
    diffs = np.diff(q)
    return (int(q[0]), *map(int, diffs))


# --------------------------
# Inner LP solver
# --------------------------

@dataclass
class LPValue:
    ok: bool
    t: float
    f: Optional[np.ndarray]
    msg: str


class FullProgramLPSolver:
    """
    Fixes the build_ub bottleneck:
      - precompute A_ub CSR structure ONCE
      - per solve, only fill COO-order data vector then permute into CSR-order data
      - reuse indices/indptr (no COO->CSR conversion per miss)

    Also uses cheap weight computation from term = where(bits, s, 1-s):
      - pi = prod(term, axis=1)
      - eq weights: w_i = pi / term[:, i]
      - branch weights: w_A = prod(term[:, comp(A)], axis=1)   where comp(A) has size h
    """

    def __init__(
        self,
        n: int,
        h: int,
        C: float,
        cache_tol: float = 1e-4,
        cache_max: int = 200_000,
        profile: bool = True,
    ):
        self.n = int(n)
        self.h = int(h)
        self.C = float(C)

        self.bits = all_bitstrings(self.n)  # (m,n)
        self.m = self.bits.shape[0]
        self.all_zero = (self.bits.sum(axis=1) == 0).astype(float)  # (m,)

        self.branches = all_branches(self.n, self.h)
        self.num_branches = len(self.branches)

        # Variable layout
        self.N = self.n * self.m + 1
        self.t_idx = self.n * self.m
        self.bounds = [(0.0, None)] * (self.n * self.m) + [(0.0, None)]
        self.c = np.zeros(self.N, dtype=float)
        self.c[self.t_idx] = 1.0

        # Precompute eq index sets
        self.eq_idx1 = [np.where(self.bits[:, i] == 1)[0] for i in range(self.n)]
        self.eq_idx0 = [np.where(self.bits[:, i] == 0)[0] for i in range(self.n)]
        self.eq_cols1 = [i * self.m + self.eq_idx1[i] for i in range(self.n)]
        self.eq_cols0 = [i * self.m + self.eq_idx0[i] for i in range(self.n)]

        # Branch precompute:
        # - idx of a matching (A,aA)
        # - comp(A) columns (size h)
        # - whether idx contains 0 (all-zero bitstring, which is row 0)
        all_idx = np.arange(self.n)
        self.branch_match_idx: List[np.ndarray] = []
        self.branch_cols_per_i: List[List[np.ndarray]] = []
        self.branch_comp_cols: List[np.ndarray] = []
        self.branch_has_zero = np.zeros(self.num_branches, dtype=bool)

        for b, (A, aA) in enumerate(self.branches):
            mask = np.ones(self.m, dtype=bool)
            for pos, j in enumerate(A):
                mask &= (self.bits[:, j] == aA[pos])
            idx = np.where(mask)[0]
            self.branch_match_idx.append(idx)
            self.branch_cols_per_i.append([i * self.m + idx for i in range(self.n)])

            Aset = set(A)
            comp = np.array([j for j in all_idx if j not in Aset], dtype=int)  # size h
            self.branch_comp_cols.append(comp)

            self.branch_has_zero[b] = (idx.size > 0 and idx[0] == 0) or (0 in set(idx.tolist()))

        # ----------------------
        # Precompute A_ub pattern in COO order (rows, cols) + slices
        # ----------------------
        self.num_ub = 1 + self.num_branches

        total_nnz = self.n * self.m
        for idx in self.branch_match_idx:
            total_nnz += self.n * idx.size + 1

        self._ub_rows = np.empty(total_nnz, dtype=int)
        self._ub_cols = np.empty(total_nnz, dtype=int)

        self._ir_slices: List[slice] = []
        self._branch_slices: List[List[slice]] = []
        self._t_pos = np.empty(self.num_branches, dtype=int)

        off = 0
        # IR row pattern (row 0)
        for i in range(self.n):
            cols = i * self.m + np.arange(self.m, dtype=int)
            self._ub_rows[off:off + self.m] = 0
            self._ub_cols[off:off + self.m] = cols
            self._ir_slices.append(slice(off, off + self.m))
            off += self.m

        # Branch row patterns
        for b in range(self.num_branches):
            row = 1 + b
            idx = self.branch_match_idx[b]
            b_slices = []
            for i in range(self.n):
                cols = self.branch_cols_per_i[b][i]
                L = cols.size
                self._ub_rows[off:off + L] = row
                self._ub_cols[off:off + L] = cols
                b_slices.append(slice(off, off + L))
                off += L

            # -t entry
            self._ub_rows[off] = row
            self._ub_cols[off] = self.t_idx
            self._t_pos[b] = off
            off += 1
            self._branch_slices.append(b_slices)

        assert off == total_nnz

        # ----------------------
        # Build CSR template ONCE + permutation mapping COO-order -> CSR-order
        # Trick: put data = 0..nnz-1, convert to CSR, then csr.data are those ids in CSR order.
        # ----------------------
        ids = np.arange(total_nnz, dtype=float)
        csr_template = coo_matrix((ids, (self._ub_rows, self._ub_cols)),
                                  shape=(self.num_ub, self.N)).tocsr()

        # perm[k] = which COO entry sits at CSR position k
        self._ub_perm = csr_template.data.astype(np.int64)
        self._ub_indices = csr_template.indices.copy()
        self._ub_indptr = csr_template.indptr.copy()

        # A_eq is tiny; keep building it per miss (it wasn't your bottleneck)

        # Cache
        self.cache_tol = float(cache_tol)
        self._cache: dict[tuple, LPValue] = {}
        self._cache_max = int(cache_max)

        # Profiling
        self.profile = bool(profile)
        self._prof = {
            "calls": 0,
            "cache_hits": 0,
            "build_eq_s": 0.0,
            "build_ub_s": 0.0,
            "linprog_s": 0.0,
            "total_s": 0.0,
        }

    def solve_lp(self, s: np.ndarray, return_f: bool = False) -> LPValue:
        import time
        t0 = time.perf_counter()
        if self.profile:
            self._prof["calls"] += 1

        s = np.clip(np.asarray(s, dtype=float), 0.0, 1.0)

        key = cache_key_from_s(s, tol=self.cache_tol)
        cached = self._cache.get(key)
        if cached is not None and ((not return_f) or (cached.f is not None)):
            if self.profile:
                self._prof["cache_hits"] += 1
                self._prof["total_s"] += time.perf_counter() - t0
            return cached

        n, m, C = self.n, self.m, self.C
        N = self.N

        # Precompute term + pi once
        s_safe = np.clip(s, 1e-12, 1 - 1e-12)
        term = np.where(self.bits == 1, s_safe[None, :], (1.0 - s_safe)[None, :])  # (m,n)
        pi = term.prod(axis=1)  # (m,)

        # ----------------------
        # Build A_eq (sparse)
        # ----------------------
        t_eq0 = time.perf_counter()
        eq_rows = []
        eq_cols = []
        eq_data = []

        for i in range(n):
            w_i = pi / term[:, i]
            idx1 = self.eq_idx1[i]
            idx0 = self.eq_idx0[i]
            cols1 = self.eq_cols1[i]
            cols0 = self.eq_cols0[i]

            eq_rows.append(np.full(cols1.shape[0], i, dtype=int))
            eq_cols.append(cols1.astype(int))
            eq_data.append(w_i[idx1])

            eq_rows.append(np.full(cols0.shape[0], i, dtype=int))
            eq_cols.append(cols0.astype(int))
            eq_data.append(-w_i[idx0])

        A_eq = coo_matrix(
            (np.concatenate(eq_data), (np.concatenate(eq_rows), np.concatenate(eq_cols))),
            shape=(n, N),
        ).tocsr()
        b_eq = np.ones(n, dtype=float)

        if self.profile:
            self._prof["build_eq_s"] += time.perf_counter() - t_eq0

        # ----------------------
        # Build A_ub by ONLY filling data (no COO->CSR conversion)
        # ----------------------
        t_ub0 = time.perf_counter()

        data_coo = np.empty_like(self._ub_rows, dtype=float)
        b_ub = np.zeros(self.num_ub, dtype=float)

        # IR row blocks: fill -pi
        for sl in self._ir_slices:
            data_coo[sl] = -pi
        b_ub[0] = -float(s.sum())

        # Branch rows
        for b in range(self.num_branches):
            comp = self.branch_comp_cols[b]  # size h
            if comp.size == 0:
                w = np.ones(m, dtype=float)
            else:
                w = term[:, comp].prod(axis=1)  # (m,)

            idx = self.branch_match_idx[b]
            w_idx = w[idx]

            for i in range(n):
                data_coo[self._branch_slices[b][i]] = w_idx

            data_coo[self._t_pos[b]] = -1.0

            # RHS constant uses only the all-zero assignment (bitstring 000...0 is index 0)
            if self.branch_has_zero[b]:
                b_ub[1 + b] = -C * float(w[0])
            else:
                b_ub[1 + b] = 0.0

        # Reorder data into CSR order using perm computed once
        data_csr = data_coo[self._ub_perm]

        A_ub = csr_matrix((data_csr, self._ub_indices, self._ub_indptr), shape=(self.num_ub, self.N))

        if self.profile:
            self._prof["build_ub_s"] += time.perf_counter() - t_ub0

        # ----------------------
        # Solve LP
        # ----------------------
        t_lp0 = time.perf_counter()
        res = linprog_any(self.c, A_ub, b_ub, A_eq, b_eq, self.bounds)
        if self.profile:
            self._prof["linprog_s"] += time.perf_counter() - t_lp0

        if not res.success:
            val = LPValue(ok=False, t=float("inf"), f=None, msg=str(res.message))
            if len(self._cache) < self._cache_max:
                self._cache[key] = val
            if self.profile:
                self._prof["total_s"] += time.perf_counter() - t0
            return val

        t_val = float(res.fun)
        if not return_f:
            val = LPValue(ok=True, t=t_val, f=None, msg="OK")
            if len(self._cache) < self._cache_max:
                self._cache[key] = val
            if self.profile:
                self._prof["total_s"] += time.perf_counter() - t0
            return val

        x = res.x
        f = x[: n * m].reshape((n, m))
        val = LPValue(ok=True, t=t_val, f=f, msg="OK")
        if len(self._cache) < self._cache_max:
            self._cache[key] = val
        if self.profile:
            self._prof["total_s"] += time.perf_counter() - t0
        return val

    def print_profile(self, label: str = ""):
        p = self._prof
        calls = p["calls"]
        hits = p["cache_hits"]
        misses = calls - hits
        head = f"=== solve_lp profile {label} ===".strip()
        print("\n" + head)
        print(f"calls:      {calls}")
        print(f"cache_hits: {hits} ({hits / max(1, calls):.1%})")
        print(f"misses:     {misses} ({misses / max(1, calls):.1%})")
        print(f"build_eq:   {p['build_eq_s']:.3f}s")
        print(f"build_ub:   {p['build_ub_s']:.3f}s")
        print(f"linprog:    {p['linprog_s']:.3f}s")
        print(f"total:      {p['total_s']:.3f}s")


# --------------------------
# Outer optimization
# --------------------------

def blocks_from_equalities(n: int, tight_adj: Tuple[int, ...]) -> List[List[int]]:
    tight = set(tight_adj)
    blocks = []
    cur = [0]
    for i in range(n - 1):
        if i in tight:
            cur.append(i + 1)
        else:
            blocks.append(cur)
            cur = [i + 1]
    blocks.append(cur)
    return blocks


def enumerate_faces(n: int, include_endpoints: bool = True):
    faces = []
    for mask in range(1 << (n - 1)):
        tight_adj = tuple(i for i in range(n - 1) if (mask >> i) & 1)
        if include_endpoints:
            for top_fixed in (False, True):
                for bottom_fixed in (False, True):
                    faces.append((tight_adj, top_fixed, bottom_fixed))
        else:
            faces.append((tight_adj, False, False))
    return faces


def face_parameterization(z: np.ndarray, k: int, top_fixed: bool, bottom_fixed: bool) -> np.ndarray:
    z = np.asarray(z, dtype=float)
    v = np.zeros(k, dtype=float)
    if k == 0:
        return v

    idx = 0
    if top_fixed:
        v[0] = 1.0
    else:
        v[0] = sigmoid(z[idx]); idx += 1

    last_free = k - (1 if bottom_fixed else 0)
    for j in range(1, last_free):
        v[j] = v[j - 1] * sigmoid(z[idx])
        idx += 1

    if bottom_fixed:
        v[-1] = 0.0

    return v


def blocks_to_full_s(n: int, blocks: List[List[int]], v: np.ndarray) -> np.ndarray:
    s = np.zeros(n, dtype=float)
    for b_idx, idxs in enumerate(blocks):
        s[idxs] = v[b_idx]
    return s


@dataclass
class SearchResult:
    ok: bool
    t: float
    s: np.ndarray
    face: Optional[Tuple[Tuple[int, ...], bool, bool]]
    msg: str


def minimize_over_s_sorted(
    solver: FullProgramLPSolver,
    n_random: int = 120,
    n_local_starts: int = 3,
    seed: int = 0,
    local_method: str = "Powell",
    maxiter: int = 250,
) -> SearchResult:
    n = solver.n
    rng = np.random.default_rng(seed)

    best_t = float("inf")
    best_s = None

    for _ in range(n_random):
        s = sort_desc(rng.random(n))
        val = solver.solve_lp(s, return_f=False)
        if val.ok and val.t < best_t:
            best_t = val.t
            best_s = s.copy()

    if best_s is None:
        best_s = np.full(n, 0.5)

    def obj(z):
        s = sort_desc(sigmoid(z))
        val = solver.solve_lp(s, return_f=False)
        return val.t if val.ok else 1e6

    def inv_sig(u):
        u = np.clip(u, 1e-6, 1 - 1e-6)
        return np.log(u / (1 - u))

    starts = [inv_sig(best_s)]
    for _ in range(max(0, n_local_starts - 1)):
        s0 = sort_desc(rng.random(n))
        starts.append(inv_sig(s0))

    best_z = None
    best_local = best_t

    for z0 in starts:
        res = minimize(obj, z0, method=local_method, options={"maxiter": maxiter, "disp": False})
        if float(res.fun) < best_local:
            best_local = float(res.fun)
            best_z = res.x.copy()

    s_star = sort_desc(sigmoid(best_z)) if best_z is not None else best_s
    lp_star = solver.solve_lp(s_star, return_f=False)
    return SearchResult(lp_star.ok, lp_star.t, s_star, None, lp_star.msg)


def minimize_over_faces(
    solver: FullProgramLPSolver,
    seed: int = 0,
    local_method: str = "Powell",
    n_local_starts_per_face: int = 1,
    include_endpoints: bool = True,
    max_faces: Optional[int] = None,
    maxiter: int = 120,
) -> SearchResult:
    n = solver.n
    rng = np.random.default_rng(seed)

    faces = enumerate_faces(n, include_endpoints=include_endpoints)
    if max_faces is not None:
        faces = faces[:max_faces]

    best = SearchResult(False, float("inf"), np.zeros(n), None, "no feasible face found")

    for (tight_adj, top_fixed, bottom_fixed) in faces:
        blocks = blocks_from_equalities(n, tight_adj)
        k = len(blocks)

        if k == 1 and top_fixed and bottom_fixed:
            continue

        dim = 0
        if not top_fixed:
            dim += 1
        dim += max(0, (k - 1) - (1 if bottom_fixed else 0))

        if dim == 0:
            v = np.zeros(k)
            if top_fixed:
                v[0] = 1.0
            if bottom_fixed:
                v[-1] = 0.0
            s = blocks_to_full_s(n, blocks, v)
            val = solver.solve_lp(s, return_f=False)
            if val.ok and val.t < best.t:
                best = SearchResult(True, val.t, s, (tight_adj, top_fixed, bottom_fixed), "OK")
            continue

        def obj(z):
            v = face_parameterization(z, k, top_fixed, bottom_fixed)
            s = blocks_to_full_s(n, blocks, v)
            val = solver.solve_lp(s, return_f=False)
            return val.t if val.ok else 1e6

        for _ in range(n_local_starts_per_face):
            z0 = rng.normal(size=dim)
            res = minimize(obj, z0, method=local_method, options={"maxiter": maxiter, "disp": False})
            if float(res.fun) < best.t:
                v = face_parameterization(res.x, k, top_fixed, bottom_fixed)
                s = blocks_to_full_s(n, blocks, v)
                best = SearchResult(True, float(res.fun), s, (tight_adj, top_fixed, bottom_fixed), "OK")

    return best


# --------------------------
# Run ONLY n=6, C=7
# --------------------------



In [57]:
n = 6
C = 7.0

interior_n_random = 120
interior_n_local_starts = 3
interior_maxiter = 250

face_local_starts_per_face = 1
face_include_endpoints = True
face_max_faces = None
face_maxiter = 120

for h in range(1, n):
    solver = FullProgramLPSolver(n=n, h=h, C=C, cache_tol=1e-4, cache_max=200_000, profile=True)

    interior = minimize_over_s_sorted(
        solver,
        n_random=interior_n_random,
        n_local_starts=interior_n_local_starts,
        seed=1,
        local_method="Powell",
        maxiter=interior_maxiter,
    )

    faces = minimize_over_faces(
        solver,
        seed=1,
        local_method="Powell",
        n_local_starts_per_face=face_local_starts_per_face,
        include_endpoints=face_include_endpoints,
        max_faces=face_max_faces,
        maxiter=face_maxiter,
    )

    print(
        f"\nRESULT n={n}, h={h}, C={C}\n"
        f"  face:     t={faces.t: .6g}, s={np.array2string(faces.s, precision=6, floatmode='fixed')}\n"
        f"  interior: t={interior.t: .6g}, s={np.array2string(interior.s, precision=6, floatmode='fixed')}"
    )

    solver.print_profile(label=f"(n={n}, h={h}, C={C})")


RESULT n=6, h=1, C=7.0
  face:     t= 3.27024, s=[0.532799 0.532799 0.532799 0.532799 0.532799 0.532799]
  interior: t= 3.56722, s=[1.000000 0.581573 0.502470 0.499465 0.493272 0.490409]

=== solve_lp profile (n=6, h=1, C=7.0) ===
calls:      15430
cache_hits: 8195 (53.1%)
misses:     7235 (46.9%)
build_eq:   1.486s
build_ub:   10.238s
linprog:    18.201s
total:      30.773s

RESULT n=6, h=2, C=7.0
  face:     t= 2.82424, s=[1.000000 0.364764 0.364764 0.364764 0.364764 0.364764]
  interior: t= 2.94924, s=[0.999363 0.421074 0.418034 0.409483 0.379983 0.320468]

=== solve_lp profile (n=6, h=2, C=7.0) ===
calls:      23889
cache_hits: 13075 (54.7%)
misses:     10814 (45.3%)
build_eq:   2.334s
build_ub:   20.277s
linprog:    40.348s
total:      64.307s

RESULT n=6, h=3, C=7.0
  face:     t= 2.46808, s=[1.000000 0.293561 0.293561 0.293561 0.293561 0.293561]
  interior: t= 2.65997, s=[1.000000 0.432503 0.430735 0.430709 0.185300 0.180692]

=== solve_lp profile (n=6, h=3, C=7.0) ===
calls:  

In [None]:
import itertools
import numpy as np
from dataclasses import dataclass
from typing import Tuple, List, Optional

from scipy.optimize import linprog, minimize
from scipy.sparse import coo_matrix, csr_matrix


# --------------------------
# Utilities
# --------------------------

def all_bitstrings(n: int) -> np.ndarray:
    return np.array(list(itertools.product([0, 1], repeat=n)), dtype=int)


def all_branches(n: int, h: int) -> List[Tuple[Tuple[int, ...], Tuple[int, ...]]]:
    k = n - h
    branches = []
    for A in itertools.combinations(range(n), k):
        for aA in itertools.product([0, 1], repeat=k):
            branches.append((tuple(A), tuple(aA)))
    return branches


def sigmoid(z):
    z = np.clip(z, -40, 40)
    return 1.0 / (1.0 + np.exp(-z))


def inv_sig(u):
    u = np.clip(u, 1e-12, 1 - 1e-12)
    return np.log(u / (1 - u))


def sort_desc(s: np.ndarray) -> np.ndarray:
    return np.sort(np.asarray(s, dtype=float))[::-1]


def linprog_any(c, A_ub, b_ub, A_eq, b_eq, bounds):
    # Force HiGHS only (less overhead, consistent performance)
    return linprog(
        c=c, A_ub=A_ub, b_ub=b_ub,
        A_eq=A_eq, b_eq=b_eq,
        bounds=bounds,
        method="highs"
    )


def cache_key_from_s(s: np.ndarray, *, tol: float = 1e-4) -> tuple:
    s = np.clip(np.asarray(s, dtype=float), 0.0, 1.0)
    q = np.round(s / tol).astype(np.int64)
    diffs = np.diff(q)
    return (int(q[0]), *map(int, diffs))


# --------------------------
# Inner LP solver (fast A_ub CSR reuse)
# --------------------------

@dataclass
class LPValue:
    ok: bool
    t: float
    f: Optional[np.ndarray]
    msg: str


class FullProgramLPSolver:
    """
    Fast LP assembly:
      - precompute A_ub sparsity in CSR once
      - each solve fills a COO-order data vector then permutes into CSR-order data
      - avoids COO->CSR conversion per miss

    Uses:
      term[a,j] = s_j if bit=1 else (1-s_j)
      pi[a] = prod_j term[a,j]
      eq weight for i: w_i[a] = pi[a] / term[a,i]
      branch weight for A: w_A[a] = prod_{j in comp(A)} term[a,j]
    """

    def __init__(
        self,
        n: int,
        h: int,
        C: float,
        cache_tol: float = 1e-4,
        cache_max: int = 200_000,
        profile: bool = True,
    ):
        self.n = int(n)
        self.h = int(h)
        self.C = float(C)

        self.bits = all_bitstrings(self.n)  # (m,n)
        self.m = self.bits.shape[0]
        self.all_zero = (self.bits.sum(axis=1) == 0).astype(float)  # (m,)

        self.branches = all_branches(self.n, self.h)
        self.num_branches = len(self.branches)

        # Variable layout
        self.N = self.n * self.m + 1
        self.t_idx = self.n * self.m
        self.bounds = [(0.0, None)] * (self.n * self.m) + [(0.0, None)]
        self.c = np.zeros(self.N, dtype=float)
        self.c[self.t_idx] = 1.0

        # Precompute eq index sets
        self.eq_idx1 = [np.where(self.bits[:, i] == 1)[0] for i in range(self.n)]
        self.eq_idx0 = [np.where(self.bits[:, i] == 0)[0] for i in range(self.n)]
        self.eq_cols1 = [i * self.m + self.eq_idx1[i] for i in range(self.n)]
        self.eq_cols0 = [i * self.m + self.eq_idx0[i] for i in range(self.n)]

        # Branch precompute:
        all_idx = np.arange(self.n)
        self.branch_match_idx: List[np.ndarray] = []
        self.branch_cols_per_i: List[List[np.ndarray]] = []
        self.branch_comp_cols: List[np.ndarray] = []
        self.branch_has_zero = np.zeros(self.num_branches, dtype=bool)

        for b, (A, aA) in enumerate(self.branches):
            mask = np.ones(self.m, dtype=bool)
            for pos, j in enumerate(A):
                mask &= (self.bits[:, j] == aA[pos])
            idx = np.where(mask)[0]
            self.branch_match_idx.append(idx)
            self.branch_cols_per_i.append([i * self.m + idx for i in range(self.n)])

            Aset = set(A)
            comp = np.array([j for j in all_idx if j not in Aset], dtype=int)  # size h
            self.branch_comp_cols.append(comp)

            self.branch_has_zero[b] = (0 in set(idx.tolist()))

        # ----------------------
        # Precompute A_ub pattern in COO order + slices
        # ----------------------
        self.num_ub = 1 + self.num_branches

        total_nnz = self.n * self.m
        for idx in self.branch_match_idx:
            total_nnz += self.n * idx.size + 1

        self._ub_rows = np.empty(total_nnz, dtype=int)
        self._ub_cols = np.empty(total_nnz, dtype=int)

        self._ir_slices: List[slice] = []
        self._branch_slices: List[List[slice]] = []
        self._t_pos = np.empty(self.num_branches, dtype=int)

        off = 0
        # IR row pattern (row 0)
        for i in range(self.n):
            cols = i * self.m + np.arange(self.m, dtype=int)
            self._ub_rows[off:off + self.m] = 0
            self._ub_cols[off:off + self.m] = cols
            self._ir_slices.append(slice(off, off + self.m))
            off += self.m

        # Branch row patterns
        for b in range(self.num_branches):
            row = 1 + b
            idx = self.branch_match_idx[b]
            b_slices = []
            for i in range(self.n):
                cols = self.branch_cols_per_i[b][i]
                L = cols.size
                self._ub_rows[off:off + L] = row
                self._ub_cols[off:off + L] = cols
                b_slices.append(slice(off, off + L))
                off += L

            # -t entry
            self._ub_rows[off] = row
            self._ub_cols[off] = self.t_idx
            self._t_pos[b] = off
            off += 1
            self._branch_slices.append(b_slices)

        assert off == total_nnz

        # Build CSR template once; compute permutation COO->CSR
        ids = np.arange(total_nnz, dtype=float)
        csr_template = coo_matrix((ids, (self._ub_rows, self._ub_cols)),
                                  shape=(self.num_ub, self.N)).tocsr()
        self._ub_perm = csr_template.data.astype(np.int64)
        self._ub_indices = csr_template.indices.copy()
        self._ub_indptr = csr_template.indptr.copy()

        # Cache
        self.cache_tol = float(cache_tol)
        self._cache: dict[tuple, LPValue] = {}
        self._cache_max = int(cache_max)

        # Profiling
        self.profile = bool(profile)
        self._prof = {
            "calls": 0,
            "cache_hits": 0,
            "build_eq_s": 0.0,
            "build_ub_s": 0.0,
            "linprog_s": 0.0,
            "total_s": 0.0,
        }

    def solve_lp(self, s: np.ndarray, return_f: bool = False) -> LPValue:
        import time
        t0 = time.perf_counter()
        if self.profile:
            self._prof["calls"] += 1

        s = np.clip(np.asarray(s, dtype=float), 0.0, 1.0)

        key = cache_key_from_s(s, tol=self.cache_tol)
        cached = self._cache.get(key)
        if cached is not None and ((not return_f) or (cached.f is not None)):
            if self.profile:
                self._prof["cache_hits"] += 1
                self._prof["total_s"] += time.perf_counter() - t0
            return cached

        n, m, C = self.n, self.m, self.C
        N = self.N

        # Precompute term + pi once
        s_safe = np.clip(s, 1e-12, 1 - 1e-12)
        term = np.where(self.bits == 1, s_safe[None, :], (1.0 - s_safe)[None, :])  # (m,n)
        pi = term.prod(axis=1)  # (m,)

        # ----------------------
        # Build A_eq (small; keep as COO->CSR)
        # ----------------------
        t_eq0 = time.perf_counter()
        eq_rows = []
        eq_cols = []
        eq_data = []

        for i in range(n):
            w_i = pi / term[:, i]
            idx1 = self.eq_idx1[i]
            idx0 = self.eq_idx0[i]
            cols1 = self.eq_cols1[i]
            cols0 = self.eq_cols0[i]

            eq_rows.append(np.full(cols1.shape[0], i, dtype=int))
            eq_cols.append(cols1.astype(int))
            eq_data.append(w_i[idx1])

            eq_rows.append(np.full(cols0.shape[0], i, dtype=int))
            eq_cols.append(cols0.astype(int))
            eq_data.append(-w_i[idx0])

        A_eq = coo_matrix(
            (np.concatenate(eq_data), (np.concatenate(eq_rows), np.concatenate(eq_cols))),
            shape=(n, N),
        ).tocsr()
        b_eq = np.ones(n, dtype=float)
        if self.profile:
            self._prof["build_eq_s"] += time.perf_counter() - t_eq0

        # ----------------------
        # Build A_ub by filling data only
        # ----------------------
        t_ub0 = time.perf_counter()

        data_coo = np.empty_like(self._ub_rows, dtype=float)
        b_ub = np.zeros(self.num_ub, dtype=float)

        # IR row: -pi blocks
        for sl in self._ir_slices:
            data_coo[sl] = -pi
        b_ub[0] = -float(s.sum())

        # Branch rows
        for b in range(self.num_branches):
            comp = self.branch_comp_cols[b]  # size h
            if comp.size == 0:
                w = np.ones(m, dtype=float)
            else:
                w = term[:, comp].prod(axis=1)  # (m,)

            idx = self.branch_match_idx[b]
            w_idx = w[idx]

            for i in range(n):
                data_coo[self._branch_slices[b][i]] = w_idx

            data_coo[self._t_pos[b]] = -1.0

            # RHS constant uses only all-zero assignment (index 0)
            if self.branch_has_zero[b]:
                b_ub[1 + b] = -C * float(w[0])
            else:
                b_ub[1 + b] = 0.0

        data_csr = data_coo[self._ub_perm]
        A_ub = csr_matrix((data_csr, self._ub_indices, self._ub_indptr), shape=(self.num_ub, self.N))

        if self.profile:
            self._prof["build_ub_s"] += time.perf_counter() - t_ub0

        # ----------------------
        # Solve LP
        # ----------------------
        t_lp0 = time.perf_counter()
        res = linprog_any(self.c, A_ub, b_ub, A_eq, b_eq, self.bounds)
        if self.profile:
            self._prof["linprog_s"] += time.perf_counter() - t_lp0

        if not res.success:
            val = LPValue(ok=False, t=float("inf"), f=None, msg=str(res.message))
            if len(self._cache) < self._cache_max:
                self._cache[key] = val
            if self.profile:
                self._prof["total_s"] += time.perf_counter() - t0
            return val

        t_val = float(res.fun)
        if not return_f:
            val = LPValue(ok=True, t=t_val, f=None, msg="OK")
            if len(self._cache) < self._cache_max:
                self._cache[key] = val
            if self.profile:
                self._prof["total_s"] += time.perf_counter() - t0
            return val

        x = res.x
        f = x[: n * m].reshape((n, m))
        val = LPValue(ok=True, t=t_val, f=f, msg="OK")
        if len(self._cache) < self._cache_max:
            self._cache[key] = val
        if self.profile:
            self._prof["total_s"] += time.perf_counter() - t0
        return val

    def print_profile(self, label: str = ""):
        p = self._prof
        calls = p["calls"]
        hits = p["cache_hits"]
        misses = calls - hits
        head = f"=== solve_lp profile {label} ===".strip()
        print("\n" + head)
        print(f"calls:      {calls}")
        print(f"cache_hits: {hits} ({hits / max(1, calls):.1%})")
        print(f"misses:     {misses} ({misses / max(1, calls):.1%})")
        print(f"build_eq:   {p['build_eq_s']:.3f}s")
        print(f"build_ub:   {p['build_ub_s']:.3f}s")
        print(f"linprog:    {p['linprog_s']:.3f}s")
        print(f"total:      {p['total_s']:.3f}s")


# --------------------------
# Three-block outer search: 1^d t^f 0^z
# --------------------------

@dataclass
class SearchResult:
    ok: bool
    t: float
    s: np.ndarray
    tag: Optional[Tuple[int, int, int]]  # (d,f,z)
    msg: str


def three_block_s(n: int, d: int, z: int, t: float) -> np.ndarray:
    f = n - d - z
    s = np.empty(n, dtype=float)
    s[:d] = 1.0
    s[d:d+f] = float(t)
    s[d+f:] = 0.0
    return s


def minimize_over_three_blocks(
    solver: FullProgramLPSolver,
    seed: int = 0,
    grid_size: int = 21,          # coarse grid over t in [0,1]
    n_local_starts: int = 2,      # Powell starts per (d,z)
    maxiter: int = 60,            # Powell maxiter per start (1D so cheap)
) -> SearchResult:
    n = solver.n
    rng = np.random.default_rng(seed)

    best = SearchResult(False, float("inf"), np.zeros(n), None, "no feasible 3-block found")

    # Iterate all (d,z) with f >= 1
    for d in range(0, n + 1):
        for z in range(0, n - d + 1):
            f = n - d - z
            if f <= 0:
                continue

            # Coarse grid screening
            ts = np.linspace(0.0, 1.0, grid_size)
            best_grid_t = 0.5
            best_grid_val = float("inf")
            best_grid_ok = False

            for t0 in ts:
                s0 = three_block_s(n, d, z, t0)
                val = solver.solve_lp(s0, return_f=False)
                if val.ok and val.t < best_grid_val:
                    best_grid_val = val.t
                    best_grid_t = t0
                    best_grid_ok = True

            # If nothing feasible on grid, still try local starts (sometimes feasibility is narrow)
            starts_z = []
            if best_grid_ok:
                starts_z.append(inv_sig(best_grid_t))
            else:
                starts_z.append(inv_sig(0.5))

            # Additional random starts
            for _ in range(max(0, n_local_starts - len(starts_z))):
                t_rand = rng.random()
                starts_z.append(inv_sig(t_rand))

            # 1D Powell refinement in z-space (t = sigmoid(z))
            def obj(zvec):
                t = sigmoid(zvec[0])
                s = three_block_s(n, d, z, t)
                val = solver.solve_lp(s, return_f=False)
                return val.t if val.ok else 1e6

            best_local_val = best_grid_val
            best_local_z = None

            for z0 in starts_z:
                res = minimize(
                    obj,
                    x0=np.array([z0], dtype=float),
                    method="Powell",
                    options={"maxiter": maxiter, "disp": False},
                )
                if float(res.fun) < best_local_val:
                    best_local_val = float(res.fun)
                    best_local_z = float(res.x[0])

            # Evaluate best candidate for this (d,z)
            if best_local_z is not None:
                t_star = sigmoid(best_local_z)
            else:
                t_star = best_grid_t

            s_star = three_block_s(n, d, z, t_star)
            lp_star = solver.solve_lp(s_star, return_f=False)

            if lp_star.ok and lp_star.t < best.t:
                best = SearchResult(True, lp_star.t, s_star, (d, f, z), "OK")

    return best


# --------------------------
# Baseline interior (optional)
# --------------------------

def minimize_over_s_sorted(
    solver: FullProgramLPSolver,
    n_random: int = 120,
    n_local_starts: int = 3,
    seed: int = 0,
    maxiter: int = 250,
) -> SearchResult:
    n = solver.n
    rng = np.random.default_rng(seed)

    best_t = float("inf")
    best_s = None

    for _ in range(n_random):
        s = sort_desc(rng.random(n))
        val = solver.solve_lp(s, return_f=False)
        if val.ok and val.t < best_t:
            best_t = val.t
            best_s = s.copy()

    if best_s is None:
        best_s = np.full(n, 0.5)

    def obj(z):
        s = sort_desc(sigmoid(z))
        val = solver.solve_lp(s, return_f=False)
        return val.t if val.ok else 1e6

    starts = [inv_sig(best_s)]
    for _ in range(max(0, n_local_starts - 1)):
        s0 = sort_desc(rng.random(n))
        starts.append(inv_sig(s0))

    best_z = None
    best_local = best_t

    for z0 in starts:
        res = minimize(obj, z0, method="Powell", options={"maxiter": maxiter, "disp": False})
        if float(res.fun) < best_local:
            best_local = float(res.fun)
            best_z = res.x.copy()

    s_star = sort_desc(sigmoid(best_z)) if best_z is not None else best_s
    lp_star = solver.solve_lp(s_star, return_f=False)
    return SearchResult(lp_star.ok, lp_star.t, s_star, None, lp_star.msg)


# --------------------------
# Run ONLY n=6, C=7 (as requested)
# --------------------------

def main():
    n = 8
    C = 7.0

    # Outer knobs
    interior_n_random = 120
    interior_n_local_starts = 3
    interior_maxiter = 250

    three_grid_size = 21
    three_local_starts = 2
    three_maxiter = 60

    for h in range(1, n):
        solver = FullProgramLPSolver(
            n=n, h=h, C=C,
            cache_tol=1e-4,
            cache_max=200_000,
            profile=True,
        )

        interior = minimize_over_s_sorted(
            solver,
            n_random=interior_n_random,
            n_local_starts=interior_n_local_starts,
            seed=1,
            maxiter=interior_maxiter,
        )

        three = minimize_over_three_blocks(
            solver,
            seed=1,
            grid_size=three_grid_size,
            n_local_starts=three_local_starts,
            maxiter=three_maxiter,
        )

        print(
            f"\nRESULT n={n}, h={h}, C={C}\n"
            f"  three-block (d,f,z)={three.tag}: t={three.t: .6g}, s={np.array2string(three.s, precision=6, floatmode='fixed')}\n"
            f"  interior:             t={interior.t: .6g}, s={np.array2string(interior.s, precision=6, floatmode='fixed')}"
        )

        solver.print_profile(label=f"(n={n}, h={h}, C={C})")


if __name__ == "__main__":
    main()



RESULT n=8, h=1, C=7.0
  three-block (d,f,z)=(0, 8, 0): t= 3.81151, s=[0.455471 0.455471 0.455471 0.455471 0.455471 0.455471 0.455471 0.455471]
  interior:             t= 4.71582, s=[0.997939 0.996116 0.899269 0.407854 0.376272 0.372852 0.338958 0.326280]

=== solve_lp profile (n=8, h=1, C=7.0) ===
calls:      4191
cache_hits: 1838 (43.9%)
misses:     2353 (56.1%)
build_eq:   0.784s
build_ub:   19.878s
linprog:    34.718s
total:      55.661s


In [61]:
import itertools
import numpy as np
from dataclasses import dataclass
from typing import Tuple, List, Optional

from scipy.optimize import linprog, minimize
from scipy.sparse import coo_matrix, csr_matrix


# --------------------------
# Utilities
# --------------------------

def all_bitstrings(n: int) -> np.ndarray:
    return np.array(list(itertools.product([0, 1], repeat=n)), dtype=int)


def all_branches(n: int, h: int) -> List[Tuple[Tuple[int, ...], Tuple[int, ...]]]:
    k = n - h
    branches = []
    for A in itertools.combinations(range(n), k):
        for aA in itertools.product([0, 1], repeat=k):
            branches.append((tuple(A), tuple(aA)))
    return branches


def sigmoid(z):
    z = np.clip(z, -40, 40)
    return 1.0 / (1.0 + np.exp(-z))


def inv_sig(u):
    u = np.clip(u, 1e-12, 1 - 1e-12)
    return np.log(u / (1 - u))


def linprog_highs(c, A_ub, b_ub, A_eq, b_eq, bounds):
    return linprog(
        c=c, A_ub=A_ub, b_ub=b_ub,
        A_eq=A_eq, b_eq=b_eq,
        bounds=bounds,
        method="highs"
    )


def cache_key_from_s(s: np.ndarray, *, tol: float = 1e-4) -> tuple:
    s = np.clip(np.asarray(s, dtype=float), 0.0, 1.0)
    q = np.round(s / tol).astype(np.int64)
    diffs = np.diff(q)
    return (int(q[0]), *map(int, diffs))


# --------------------------
# Monotone (sorted) parameterization WITHOUT sorting
# --------------------------
# z in R^n -> s in [0,1]^n with 1 >= s1 >= ... >= sn >= 0
#
# s0 = sigmoid(z0)
# s1 = s0 * sigmoid(z1)
# s2 = s1 * sigmoid(z2)
# ...
#
# This avoids the non-smooth "sort_desc(sigmoid(z))" and makes finite-diff gradients useful.

def z_to_monotone_s(z: np.ndarray) -> np.ndarray:
    z = np.asarray(z, dtype=float)
    n = z.size
    s = np.empty(n, dtype=float)
    s[0] = sigmoid(z[0])
    for i in range(1, n):
        s[i] = s[i - 1] * sigmoid(z[i])
    return s


def monotone_s_to_z(s: np.ndarray) -> np.ndarray:
    """Approx inverse map for warm starts: recover z such that z_to_monotone_s(z) ~ s."""
    s = np.clip(np.asarray(s, dtype=float), 1e-12, 1 - 1e-12)
    n = s.size
    z = np.empty(n, dtype=float)
    z[0] = inv_sig(s[0])
    for i in range(1, n):
        if s[i - 1] <= 1e-12:
            ratio = 1e-12
        else:
            ratio = s[i] / s[i - 1]
        ratio = np.clip(ratio, 1e-12, 1 - 1e-12)
        z[i] = inv_sig(ratio)
    return z


# --------------------------
# Inner LP solver (fast A_ub CSR reuse)
# --------------------------

@dataclass
class LPValue:
    ok: bool
    t: float
    f: Optional[np.ndarray]
    msg: str


class FullProgramLPSolver:
    """
    Fast LP assembly:
      - precompute A_ub sparsity in CSR once
      - each solve fills a COO-order data vector then permutes into CSR-order data
      - avoids COO->CSR conversion per miss

    Uses:
      term[a,j] = s_j if bit=1 else (1-s_j)
      pi[a] = prod_j term[a,j]
      eq weight for i: w_i[a] = pi[a] / term[a,i]
      branch weight for A: w_A[a] = prod_{j in comp(A)} term[a,j]
    """

    def __init__(
        self,
        n: int,
        h: int,
        C: float,
        cache_tol: float = 1e-4,
        cache_max: int = 200_000,
        profile: bool = True,
    ):
        self.n = int(n)
        self.h = int(h)
        self.C = float(C)

        self.bits = all_bitstrings(self.n)  # (m,n)
        self.m = self.bits.shape[0]
        self.all_zero = (self.bits.sum(axis=1) == 0).astype(float)  # (m,)

        self.branches = all_branches(self.n, self.h)
        self.num_branches = len(self.branches)

        # Variable layout
        self.N = self.n * self.m + 1
        self.t_idx = self.n * self.m
        self.bounds = [(0.0, None)] * (self.n * self.m) + [(0.0, None)]
        self.c = np.zeros(self.N, dtype=float)
        self.c[self.t_idx] = 1.0

        # Precompute eq index sets
        self.eq_idx1 = [np.where(self.bits[:, i] == 1)[0] for i in range(self.n)]
        self.eq_idx0 = [np.where(self.bits[:, i] == 0)[0] for i in range(self.n)]
        self.eq_cols1 = [i * self.m + self.eq_idx1[i] for i in range(self.n)]
        self.eq_cols0 = [i * self.m + self.eq_idx0[i] for i in range(self.n)]

        # Branch precompute:
        all_idx = np.arange(self.n)
        self.branch_match_idx: List[np.ndarray] = []
        self.branch_cols_per_i: List[List[np.ndarray]] = []
        self.branch_comp_cols: List[np.ndarray] = []
        self.branch_has_zero = np.zeros(self.num_branches, dtype=bool)

        for b, (A, aA) in enumerate(self.branches):
            mask = np.ones(self.m, dtype=bool)
            for pos, j in enumerate(A):
                mask &= (self.bits[:, j] == aA[pos])
            idx = np.where(mask)[0]
            self.branch_match_idx.append(idx)
            self.branch_cols_per_i.append([i * self.m + idx for i in range(self.n)])

            Aset = set(A)
            comp = np.array([j for j in all_idx if j not in Aset], dtype=int)  # size h
            self.branch_comp_cols.append(comp)

            self.branch_has_zero[b] = (0 in set(idx.tolist()))

        # ----------------------
        # Precompute A_ub pattern in COO order + slices
        # ----------------------
        self.num_ub = 1 + self.num_branches

        total_nnz = self.n * self.m
        for idx in self.branch_match_idx:
            total_nnz += self.n * idx.size + 1

        self._ub_rows = np.empty(total_nnz, dtype=int)
        self._ub_cols = np.empty(total_nnz, dtype=int)

        self._ir_slices: List[slice] = []
        self._branch_slices: List[List[slice]] = []
        self._t_pos = np.empty(self.num_branches, dtype=int)

        off = 0
        # IR row pattern (row 0)
        for i in range(self.n):
            cols = i * self.m + np.arange(self.m, dtype=int)
            self._ub_rows[off:off + self.m] = 0
            self._ub_cols[off:off + self.m] = cols
            self._ir_slices.append(slice(off, off + self.m))
            off += self.m

        # Branch row patterns
        for b in range(self.num_branches):
            row = 1 + b
            idx = self.branch_match_idx[b]
            b_slices = []
            for i in range(self.n):
                cols = self.branch_cols_per_i[b][i]
                L = cols.size
                self._ub_rows[off:off + L] = row
                self._ub_cols[off:off + L] = cols
                b_slices.append(slice(off, off + L))
                off += L

            # -t entry
            self._ub_rows[off] = row
            self._ub_cols[off] = self.t_idx
            self._t_pos[b] = off
            off += 1
            self._branch_slices.append(b_slices)

        assert off == total_nnz

        # Build CSR template once; compute permutation COO->CSR
        ids = np.arange(total_nnz, dtype=float)
        csr_template = coo_matrix((ids, (self._ub_rows, self._ub_cols)),
                                  shape=(self.num_ub, self.N)).tocsr()
        self._ub_perm = csr_template.data.astype(np.int64)
        self._ub_indices = csr_template.indices.copy()
        self._ub_indptr = csr_template.indptr.copy()

        # Cache
        self.cache_tol = float(cache_tol)
        self._cache: dict[tuple, LPValue] = {}
        self._cache_max = int(cache_max)

        # Profiling
        self.profile = bool(profile)
        self._prof = {
            "calls": 0,
            "cache_hits": 0,
            "build_eq_s": 0.0,
            "build_ub_s": 0.0,
            "linprog_s": 0.0,
            "total_s": 0.0,
        }

    def solve_lp(self, s: np.ndarray, return_f: bool = False) -> LPValue:
        import time
        t0 = time.perf_counter()
        if self.profile:
            self._prof["calls"] += 1

        s = np.clip(np.asarray(s, dtype=float), 0.0, 1.0)

        key = cache_key_from_s(s, tol=self.cache_tol)
        cached = self._cache.get(key)
        if cached is not None and ((not return_f) or (cached.f is not None)):
            if self.profile:
                self._prof["cache_hits"] += 1
                self._prof["total_s"] += time.perf_counter() - t0
            return cached

        n, m, C = self.n, self.m, self.C
        N = self.N

        # Precompute term + pi once
        s_safe = np.clip(s, 1e-12, 1 - 1e-12)
        term = np.where(self.bits == 1, s_safe[None, :], (1.0 - s_safe)[None, :])  # (m,n)
        pi = term.prod(axis=1)  # (m,)

        # ----------------------
        # Build A_eq (small)
        # ----------------------
        t_eq0 = time.perf_counter()
        eq_rows = []
        eq_cols = []
        eq_data = []

        for i in range(n):
            w_i = pi / term[:, i]
            idx1 = self.eq_idx1[i]
            idx0 = self.eq_idx0[i]
            cols1 = self.eq_cols1[i]
            cols0 = self.eq_cols0[i]

            eq_rows.append(np.full(cols1.shape[0], i, dtype=int))
            eq_cols.append(cols1.astype(int))
            eq_data.append(w_i[idx1])

            eq_rows.append(np.full(cols0.shape[0], i, dtype=int))
            eq_cols.append(cols0.astype(int))
            eq_data.append(-w_i[idx0])

        A_eq = coo_matrix(
            (np.concatenate(eq_data), (np.concatenate(eq_rows), np.concatenate(eq_cols))),
            shape=(n, N),
        ).tocsr()
        b_eq = np.ones(n, dtype=float)

        if self.profile:
            self._prof["build_eq_s"] += time.perf_counter() - t_eq0

        # ----------------------
        # Build A_ub by filling data only
        # ----------------------
        t_ub0 = time.perf_counter()

        data_coo = np.empty_like(self._ub_rows, dtype=float)
        b_ub = np.zeros(self.num_ub, dtype=float)

        # IR row: -pi blocks
        for sl in self._ir_slices:
            data_coo[sl] = -pi
        b_ub[0] = -float(s.sum())

        # Branch rows
        for b in range(self.num_branches):
            comp = self.branch_comp_cols[b]  # size h
            if comp.size == 0:
                w = np.ones(m, dtype=float)
            else:
                w = term[:, comp].prod(axis=1)  # (m,)

            idx = self.branch_match_idx[b]
            w_idx = w[idx]

            for i in range(n):
                data_coo[self._branch_slices[b][i]] = w_idx

            data_coo[self._t_pos[b]] = -1.0

            # RHS constant uses only all-zero assignment (index 0)
            if self.branch_has_zero[b]:
                b_ub[1 + b] = -C * float(w[0])
            else:
                b_ub[1 + b] = 0.0

        data_csr = data_coo[self._ub_perm]
        A_ub = csr_matrix((data_csr, self._ub_indices, self._ub_indptr), shape=(self.num_ub, self.N))

        if self.profile:
            self._prof["build_ub_s"] += time.perf_counter() - t_ub0

        # ----------------------
        # Solve LP
        # ----------------------
        t_lp0 = time.perf_counter()
        res = linprog_highs(self.c, A_ub, b_ub, A_eq, b_eq, self.bounds)
        if self.profile:
            self._prof["linprog_s"] += time.perf_counter() - t_lp0

        if not res.success:
            val = LPValue(ok=False, t=float("inf"), f=None, msg=str(res.message))
            if len(self._cache) < self._cache_max:
                self._cache[key] = val
            if self.profile:
                self._prof["total_s"] += time.perf_counter() - t0
            return val

        t_val = float(res.fun)
        if not return_f:
            val = LPValue(ok=True, t=t_val, f=None, msg="OK")
            if len(self._cache) < self._cache_max:
                self._cache[key] = val
            if self.profile:
                self._prof["total_s"] += time.perf_counter() - t0
            return val

        x = res.x
        f = x[: n * m].reshape((n, m))
        val = LPValue(ok=True, t=t_val, f=f, msg="OK")
        if len(self._cache) < self._cache_max:
            self._cache[key] = val
        if self.profile:
            self._prof["total_s"] += time.perf_counter() - t0
        return val

    def print_profile(self, label: str = ""):
        p = self._prof
        calls = p["calls"]
        hits = p["cache_hits"]
        misses = calls - hits
        head = f"=== solve_lp profile {label} ===".strip()
        print("\n" + head)
        print(f"calls:      {calls}")
        print(f"cache_hits: {hits} ({hits / max(1, calls):.1%})")
        print(f"misses:     {misses} ({misses / max(1, calls):.1%})")
        print(f"build_eq:   {p['build_eq_s']:.3f}s")
        print(f"build_ub:   {p['build_ub_s']:.3f}s")
        print(f"linprog:    {p['linprog_s']:.3f}s")
        print(f"total:      {p['total_s']:.3f}s")


# --------------------------
# Three-block search: 1^d t^f 0^z
# --------------------------

@dataclass
class SearchResult:
    ok: bool
    t: float
    s: np.ndarray
    tag: Optional[Tuple[int, int, int]]  # (d,f,z)
    msg: str


def three_block_s(n: int, d: int, z: int, t: float) -> np.ndarray:
    f = n - d - z
    s = np.empty(n, dtype=float)
    s[:d] = 1.0
    s[d:d+f] = float(t)
    s[d+f:] = 0.0
    return s


def minimize_over_three_blocks(
    solver: FullProgramLPSolver,
    seed: int = 0,
    grid_size: int = 21,
    n_local_starts: int = 2,
    maxiter: int = 60,
) -> SearchResult:
    n = solver.n
    rng = np.random.default_rng(seed)

    best = SearchResult(False, float("inf"), np.zeros(n), None, "no feasible 3-block found")

    for d in range(0, n + 1):
        for z in range(0, n - d + 1):
            f = n - d - z
            if f <= 0:
                continue

            # Coarse grid
            ts = np.linspace(0.0, 1.0, grid_size)
            best_grid_t = 0.5
            best_grid_val = float("inf")
            best_grid_ok = False
            for t0 in ts:
                s0 = three_block_s(n, d, z, t0)
                val = solver.solve_lp(s0, return_f=False)
                if val.ok and val.t < best_grid_val:
                    best_grid_val = val.t
                    best_grid_t = t0
                    best_grid_ok = True

            starts = []
            starts.append(inv_sig(best_grid_t if best_grid_ok else 0.5))
            for _ in range(max(0, n_local_starts - 1)):
                starts.append(inv_sig(rng.random()))

            def obj(zvec):
                t = sigmoid(zvec[0])
                s = three_block_s(n, d, z, t)
                val = solver.solve_lp(s, return_f=False)
                return val.t if val.ok else 1e6

            best_local_val = best_grid_val
            best_local_z = None
            for z0 in starts:
                res = minimize(
                    obj,
                    x0=np.array([z0], dtype=float),
                    method="Powell",
                    options={"maxiter": maxiter, "disp": False},
                )
                if float(res.fun) < best_local_val:
                    best_local_val = float(res.fun)
                    best_local_z = float(res.x[0])

            t_star = sigmoid(best_local_z) if best_local_z is not None else best_grid_t
            s_star = three_block_s(n, d, z, t_star)
            lp_star = solver.solve_lp(s_star, return_f=False)

            if lp_star.ok and lp_star.t < best.t:
                best = SearchResult(True, lp_star.t, s_star, (d, f, z), "OK")

    return best


# --------------------------
# Interior solver WITH finite-difference gradient in z-space
# (z -> monotone s) so no sorting kinks
# --------------------------

def minimize_over_monotone_interior_with_grad(
    solver: FullProgramLPSolver,
    z0_list: List[np.ndarray],
    maxiter: int = 120,
    fd_eps: float = 2e-3,
) -> Tuple[SearchResult, List[Tuple[float, int]]]:
    """
    Returns:
      - best SearchResult among starts
      - list of (best_fun, nfev) per start
    """
    n = solver.n

    def obj_from_z(z):
        s = z_to_monotone_s(z)
        val = solver.solve_lp(s, return_f=False)
        return val.t if val.ok else 1e6

    def obj_and_grad(z):
        f0 = obj_from_z(z)
        g = np.zeros_like(z)
        # central differences
        for k in range(n):
            zp = z.copy(); zp[k] += fd_eps
            zm = z.copy(); zm[k] -= fd_eps
            g[k] = (obj_from_z(zp) - obj_from_z(zm)) / (2.0 * fd_eps)
        return f0, g

    best = SearchResult(False, float("inf"), np.zeros(n), None, "no feasible interior found")
    per_start_stats: List[Tuple[float, int]] = []

    for z0 in z0_list:
        res = minimize(
            fun=lambda z: obj_and_grad(z)[0],
            x0=np.asarray(z0, dtype=float),
            jac=lambda z: obj_and_grad(z)[1],
            method="L-BFGS-B",
            options={"maxiter": maxiter, "disp": False},
        )

        s_star = z_to_monotone_s(res.x)
        lp_star = solver.solve_lp(s_star, return_f=False)
        per_start_stats.append((float(res.fun), int(getattr(res, "nfev", -1))))

        if lp_star.ok and lp_star.t < best.t:
            best = SearchResult(True, lp_star.t, s_star, None, lp_star.msg)

    return best, per_start_stats


# --------------------------
# Run ONLY n=6, C=7
# --------------------------

def main():
    n = 7
    C = 7.0

    # Three-block knobs
    three_grid_size = 21
    three_local_starts = 2
    three_maxiter = 60

    # Interior-gradient knobs
    interior_maxiter = 120
    interior_fd_eps = 2e-3

    # How many extra starts besides the warm-start from 3-block?
    extra_random_starts = 2
    rng = np.random.default_rng(1)

    for h in range(1, n):
        solver = FullProgramLPSolver(
            n=n, h=h, C=C,
            cache_tol=1e-4,
            cache_max=200_000,
            profile=True,
        )

        three = minimize_over_three_blocks(
            solver,
            seed=1,
            grid_size=three_grid_size,
            n_local_starts=three_local_starts,
            maxiter=three_maxiter,
        )

        # Build start list for interior solver:
        # 1) warm-start from best three-block, with tiny tie-breaking jitter
        z0_list = []
        if three.ok:
            eps = 1e-3
            s0 = np.clip(three.s, eps, 1 - eps)
            s0 = np.maximum.accumulate(s0[::-1])[::-1]  # ensure nonincreasing (should already be)
            # break ties slightly while preserving order
            s0 = np.clip(s0 + np.linspace(0, 1e-7, n)[::-1], eps, 1 - eps)
            z0_list.append(monotone_s_to_z(s0))
        else:
            z0_list.append(np.zeros(n, dtype=float))

        # 2) a couple random monotone starts in z-space
        for _ in range(extra_random_starts):
            z0_list.append(rng.normal(size=n))

        interior, stats = minimize_over_monotone_interior_with_grad(
            solver,
            z0_list=z0_list,
            maxiter=interior_maxiter,
            fd_eps=interior_fd_eps,
        )

        print(
            f"\nRESULT n={n}, h={h}, C={C}\n"
            f"  three-block (d,f,z)={three.tag}: t={three.t: .6g}, s={np.array2string(three.s, precision=6, floatmode='fixed')}\n"
            f"  interior+grad:              t={interior.t: .6g}, s={np.array2string(interior.s, precision=6, floatmode='fixed')}\n"
            f"  interior starts (fun,nfev): {stats}"
        )

        solver.print_profile(label=f"(n={n}, h={h}, C={C})")


if __name__ == "__main__":
    main()



RESULT n=7, h=1, C=7.0
  three-block (d,f,z)=(0, 7, 0): t= 3.5414, s=[0.494216 0.494216 0.494216 0.494216 0.494216 0.494216 0.494216]
  interior+grad:              t= 3.5414, s=[0.494216 0.494216 0.494216 0.494216 0.494216 0.494216 0.494216]
  interior starts (fun,nfev): [(3.5414011419647164, 21), (4.988091556451353, 98), (3.9313374049446104, 148)]

=== solve_lp profile (n=7, h=1, C=7.0) ===
calls:      9894
cache_hits: 7135 (72.1%)
misses:     2759 (27.9%)
build_eq:   0.687s
build_ub:   9.440s
linprog:    17.892s
total:      28.485s

RESULT n=7, h=2, C=7.0
  three-block (d,f,z)=(1, 6, 0): t= 3.04364, s=[1.000000 0.340602 0.340602 0.340602 0.340602 0.340602 0.340602]
  interior+grad:              t= 3.04429, s=[0.999000 0.340602 0.340602 0.340602 0.340602 0.340602 0.340602]
  interior starts (fun,nfev): [(3.044290240060617, 20), (3.30250491224694, 123), (5.247146312892605, 46)]

=== solve_lp profile (n=7, h=2, C=7.0) ===
calls:      8160
cache_hits: 5823 (71.4%)
misses:     2337 (28.6

In [62]:
import itertools
import numpy as np
from dataclasses import dataclass
from typing import Tuple, List, Optional

from scipy.optimize import linprog, minimize
from scipy.sparse import coo_matrix, csr_matrix


# --------------------------
# Utilities
# --------------------------

def all_bitstrings(n: int) -> np.ndarray:
    return np.array(list(itertools.product([0, 1], repeat=n)), dtype=int)


def all_branches(n: int, h: int) -> List[Tuple[Tuple[int, ...], Tuple[int, ...]]]:
    k = n - h
    branches = []
    for A in itertools.combinations(range(n), k):
        for aA in itertools.product([0, 1], repeat=k):
            branches.append((tuple(A), tuple(aA)))
    return branches


def sigmoid(z):
    z = np.clip(z, -40, 40)
    return 1.0 / (1.0 + np.exp(-z))


def inv_sig(u):
    u = np.clip(u, 1e-12, 1 - 1e-12)
    return np.log(u / (1 - u))


def linprog_highs(c, A_ub, b_ub, A_eq, b_eq, bounds):
    return linprog(
        c=c, A_ub=A_ub, b_ub=b_ub,
        A_eq=A_eq, b_eq=b_eq,
        bounds=bounds,
        method="highs"
    )


def cache_key_from_s(s: np.ndarray, *, tol: float = 1e-4) -> tuple:
    s = np.clip(np.asarray(s, dtype=float), 0.0, 1.0)
    q = np.round(s / tol).astype(np.int64)
    diffs = np.diff(q)
    return (int(q[0]), *map(int, diffs))


# --------------------------
# Smooth monotone parameterization (NO SORTING)
# --------------------------
# z in R^n -> s in [0,1]^n with 1 >= s1 >= ... >= sn >= 0
#
# s0 = sigmoid(z0)
# s1 = s0 * sigmoid(z1)
# ...
#
# This makes finite-difference gradients meaningful (sorting breaks this).

def z_to_monotone_s(z: np.ndarray) -> np.ndarray:
    z = np.asarray(z, dtype=float)
    n = z.size
    s = np.empty(n, dtype=float)
    s[0] = sigmoid(z[0])
    for i in range(1, n):
        s[i] = s[i - 1] * sigmoid(z[i])
    return s


# --------------------------
# Inner LP solver (fast A_ub CSR reuse)
# --------------------------

@dataclass
class LPValue:
    ok: bool
    t: float
    f: Optional[np.ndarray]
    msg: str


class FullProgramLPSolver:
    """
    Fast LP assembly:
      - precompute A_ub sparsity in CSR once
      - each solve fills a COO-order data vector then permutes into CSR-order data
      - avoids COO->CSR conversion per miss

    Weights:
      term[a,j] = s_j if bit=1 else (1-s_j)
      pi[a] = prod_j term[a,j]
      eq: w_i[a] = pi[a] / term[a,i]
      branch: w_A[a] = prod_{j in comp(A)} term[a,j]
    """

    def __init__(
        self,
        n: int,
        h: int,
        C: float,
        cache_tol: float = 1e-4,
        cache_max: int = 200_000,
        profile: bool = True,
    ):
        self.n = int(n)
        self.h = int(h)
        self.C = float(C)

        self.bits = all_bitstrings(self.n)  # (m,n)
        self.m = self.bits.shape[0]
        self.all_zero = (self.bits.sum(axis=1) == 0).astype(float)  # (m,)

        self.branches = all_branches(self.n, self.h)
        self.num_branches = len(self.branches)

        # Variable layout
        self.N = self.n * self.m + 1
        self.t_idx = self.n * self.m
        self.bounds = [(0.0, None)] * (self.n * self.m) + [(0.0, None)]
        self.c = np.zeros(self.N, dtype=float)
        self.c[self.t_idx] = 1.0

        # Precompute eq index sets
        self.eq_idx1 = [np.where(self.bits[:, i] == 1)[0] for i in range(self.n)]
        self.eq_idx0 = [np.where(self.bits[:, i] == 0)[0] for i in range(self.n)]
        self.eq_cols1 = [i * self.m + self.eq_idx1[i] for i in range(self.n)]
        self.eq_cols0 = [i * self.m + self.eq_idx0[i] for i in range(self.n)]

        # Branch precompute:
        all_idx = np.arange(self.n)
        self.branch_match_idx: List[np.ndarray] = []
        self.branch_cols_per_i: List[List[np.ndarray]] = []
        self.branch_comp_cols: List[np.ndarray] = []
        self.branch_has_zero = np.zeros(self.num_branches, dtype=bool)

        for b, (A, aA) in enumerate(self.branches):
            mask = np.ones(self.m, dtype=bool)
            for pos, j in enumerate(A):
                mask &= (self.bits[:, j] == aA[pos])
            idx = np.where(mask)[0]
            self.branch_match_idx.append(idx)
            self.branch_cols_per_i.append([i * self.m + idx for i in range(self.n)])

            Aset = set(A)
            comp = np.array([j for j in all_idx if j not in Aset], dtype=int)  # size h
            self.branch_comp_cols.append(comp)

            self.branch_has_zero[b] = (0 in set(idx.tolist()))

        # ----------------------
        # Precompute A_ub pattern in COO order + slices
        # ----------------------
        self.num_ub = 1 + self.num_branches

        total_nnz = self.n * self.m
        for idx in self.branch_match_idx:
            total_nnz += self.n * idx.size + 1

        self._ub_rows = np.empty(total_nnz, dtype=int)
        self._ub_cols = np.empty(total_nnz, dtype=int)

        self._ir_slices: List[slice] = []
        self._branch_slices: List[List[slice]] = []
        self._t_pos = np.empty(self.num_branches, dtype=int)

        off = 0
        # IR row pattern (row 0)
        for i in range(self.n):
            cols = i * self.m + np.arange(self.m, dtype=int)
            self._ub_rows[off:off + self.m] = 0
            self._ub_cols[off:off + self.m] = cols
            self._ir_slices.append(slice(off, off + self.m))
            off += self.m

        # Branch row patterns
        for b in range(self.num_branches):
            row = 1 + b
            idx = self.branch_match_idx[b]
            b_slices = []
            for i in range(self.n):
                cols = self.branch_cols_per_i[b][i]
                L = cols.size
                self._ub_rows[off:off + L] = row
                self._ub_cols[off:off + L] = cols
                b_slices.append(slice(off, off + L))
                off += L

            # -t entry
            self._ub_rows[off] = row
            self._ub_cols[off] = self.t_idx
            self._t_pos[b] = off
            off += 1
            self._branch_slices.append(b_slices)

        assert off == total_nnz

        # Build CSR template once; compute permutation COO->CSR
        ids = np.arange(total_nnz, dtype=float)
        csr_template = coo_matrix((ids, (self._ub_rows, self._ub_cols)),
                                  shape=(self.num_ub, self.N)).tocsr()
        self._ub_perm = csr_template.data.astype(np.int64)
        self._ub_indices = csr_template.indices.copy()
        self._ub_indptr = csr_template.indptr.copy()

        # Cache
        self.cache_tol = float(cache_tol)
        self._cache: dict[tuple, LPValue] = {}
        self._cache_max = int(cache_max)

        # Profiling
        self.profile = bool(profile)
        self._prof = {
            "calls": 0,
            "cache_hits": 0,
            "build_eq_s": 0.0,
            "build_ub_s": 0.0,
            "linprog_s": 0.0,
            "total_s": 0.0,
        }

    def solve_lp(self, s: np.ndarray, return_f: bool = False) -> LPValue:
        import time
        t0 = time.perf_counter()
        if self.profile:
            self._prof["calls"] += 1

        s = np.clip(np.asarray(s, dtype=float), 0.0, 1.0)

        key = cache_key_from_s(s, tol=self.cache_tol)
        cached = self._cache.get(key)
        if cached is not None and ((not return_f) or (cached.f is not None)):
            if self.profile:
                self._prof["cache_hits"] += 1
                self._prof["total_s"] += time.perf_counter() - t0
            return cached

        n, m, C = self.n, self.m, self.C
        N = self.N

        # Precompute term + pi once
        s_safe = np.clip(s, 1e-12, 1 - 1e-12)
        term = np.where(self.bits == 1, s_safe[None, :], (1.0 - s_safe)[None, :])  # (m,n)
        pi = term.prod(axis=1)  # (m,)

        # ----------------------
        # Build A_eq (small)
        # ----------------------
        t_eq0 = time.perf_counter()
        eq_rows = []
        eq_cols = []
        eq_data = []

        for i in range(n):
            w_i = pi / term[:, i]
            idx1 = self.eq_idx1[i]
            idx0 = self.eq_idx0[i]
            cols1 = self.eq_cols1[i]
            cols0 = self.eq_cols0[i]

            eq_rows.append(np.full(cols1.shape[0], i, dtype=int))
            eq_cols.append(cols1.astype(int))
            eq_data.append(w_i[idx1])

            eq_rows.append(np.full(cols0.shape[0], i, dtype=int))
            eq_cols.append(cols0.astype(int))
            eq_data.append(-w_i[idx0])

        A_eq = coo_matrix(
            (np.concatenate(eq_data), (np.concatenate(eq_rows), np.concatenate(eq_cols))),
            shape=(n, N),
        ).tocsr()
        b_eq = np.ones(n, dtype=float)

        if self.profile:
            self._prof["build_eq_s"] += time.perf_counter() - t_eq0

        # ----------------------
        # Build A_ub by filling data only
        # ----------------------
        t_ub0 = time.perf_counter()

        data_coo = np.empty_like(self._ub_rows, dtype=float)
        b_ub = np.zeros(self.num_ub, dtype=float)

        # IR row: -pi blocks
        for sl in self._ir_slices:
            data_coo[sl] = -pi
        b_ub[0] = -float(s.sum())

        # Branch rows
        for b in range(self.num_branches):
            comp = self.branch_comp_cols[b]  # size h
            if comp.size == 0:
                w = np.ones(m, dtype=float)
            else:
                w = term[:, comp].prod(axis=1)  # (m,)

            idx = self.branch_match_idx[b]
            w_idx = w[idx]

            for i in range(n):
                data_coo[self._branch_slices[b][i]] = w_idx

            data_coo[self._t_pos[b]] = -1.0

            # RHS constant uses only all-zero assignment (index 0)
            if self.branch_has_zero[b]:
                b_ub[1 + b] = -C * float(w[0])
            else:
                b_ub[1 + b] = 0.0

        data_csr = data_coo[self._ub_perm]
        A_ub = csr_matrix((data_csr, self._ub_indices, self._ub_indptr), shape=(self.num_ub, self.N))

        if self.profile:
            self._prof["build_ub_s"] += time.perf_counter() - t_ub0

        # ----------------------
        # Solve LP
        # ----------------------
        t_lp0 = time.perf_counter()
        res = linprog_highs(self.c, A_ub, b_ub, A_eq, b_eq, self.bounds)
        if self.profile:
            self._prof["linprog_s"] += time.perf_counter() - t_lp0

        if not res.success:
            val = LPValue(ok=False, t=float("inf"), f=None, msg=str(res.message))
            if len(self._cache) < self._cache_max:
                self._cache[key] = val
            if self.profile:
                self._prof["total_s"] += time.perf_counter() - t0
            return val

        t_val = float(res.fun)
        if not return_f:
            val = LPValue(ok=True, t=t_val, f=None, msg="OK")
            if len(self._cache) < self._cache_max:
                self._cache[key] = val
            if self.profile:
                self._prof["total_s"] += time.perf_counter() - t0
            return val

        x = res.x
        f = x[: n * m].reshape((n, m))
        val = LPValue(ok=True, t=t_val, f=f, msg="OK")
        if len(self._cache) < self._cache_max:
            self._cache[key] = val
        if self.profile:
            self._prof["total_s"] += time.perf_counter() - t0
        return val

    def print_profile(self, label: str = ""):
        p = self._prof
        calls = p["calls"]
        hits = p["cache_hits"]
        misses = calls - hits
        head = f"=== solve_lp profile {label} ===".strip()
        print("\n" + head)
        print(f"calls:      {calls}")
        print(f"cache_hits: {hits} ({hits / max(1, calls):.1%})")
        print(f"misses:     {misses} ({misses / max(1, calls):.1%})")
        print(f"build_eq:   {p['build_eq_s']:.3f}s")
        print(f"build_ub:   {p['build_ub_s']:.3f}s")
        print(f"linprog:    {p['linprog_s']:.3f}s")
        print(f"total:      {p['total_s']:.3f}s")


# --------------------------
# Interior solver ONLY (monotone parameterization + finite-diff gradients)
# --------------------------

@dataclass
class SearchResult:
    ok: bool
    t: float
    s: np.ndarray
    msg: str


def solve_interior_only_with_grad(
    solver: FullProgramLPSolver,
    seed: int = 0,
    n_random_screen: int = 120,     # how many random monotone samples to screen
    n_local_starts: int = 3,        # how many L-BFGS-B runs
    local_maxiter: int = 120,
    fd_eps: float = 2e-3,
) -> Tuple[SearchResult, dict]:
    """
    Pure interior search:
      1) random screen in z-space (monotone by construction) to pick good start(s)
      2) L-BFGS-B with finite-difference gradient in z-space

    Returns:
      (best result, diagnostics dict)
    """
    n = solver.n
    rng = np.random.default_rng(seed)

    def obj_from_z(z):
        s = z_to_monotone_s(z)
        val = solver.solve_lp(s, return_f=False)
        return val.t if val.ok else 1e6

    def obj_and_grad(z):
        f0 = obj_from_z(z)
        g = np.zeros_like(z)
        for k in range(n):
            zp = z.copy(); zp[k] += fd_eps
            zm = z.copy(); zm[k] -= fd_eps
            g[k] = (obj_from_z(zp) - obj_from_z(zm)) / (2.0 * fd_eps)
        return f0, g

    # ----------------------
    # Random screening to pick good start points
    # ----------------------
    screened: List[Tuple[float, np.ndarray]] = []
    best_screen_val = float("inf")
    best_screen_z = None
    best_screen_s = None

    for _ in range(n_random_screen):
        z = rng.normal(size=n)
        val = obj_from_z(z)
        screened.append((val, z))
        if val < best_screen_val:
            best_screen_val = val
            best_screen_z = z.copy()
            best_screen_s = z_to_monotone_s(z).copy()

    screened.sort(key=lambda x: x[0])

    # Choose starts: top few screened + a couple random
    starts: List[np.ndarray] = []
    if best_screen_z is not None:
        starts.append(best_screen_z)
    for k in range(1, min(n_local_starts, len(screened))):
        starts.append(screened[k][1])
    while len(starts) < n_local_starts:
        starts.append(rng.normal(size=n))

    # ----------------------
    # Local optimization with gradient (finite diff)
    # ----------------------
    best = SearchResult(False, float("inf"), np.zeros(n), "no feasible found")
    local_stats = []

    for z0 in starts:
        res = minimize(
            fun=lambda z: obj_and_grad(z)[0],
            x0=np.asarray(z0, dtype=float),
            jac=lambda z: obj_and_grad(z)[1],
            method="L-BFGS-B",
            options={"maxiter": local_maxiter, "disp": False},
        )
        s_star = z_to_monotone_s(res.x)
        lp_star = solver.solve_lp(s_star, return_f=False)

        local_stats.append({
            "fun": float(res.fun),
            "nfev": int(getattr(res, "nfev", -1)),
            "nit": int(getattr(res, "nit", -1)),
            "success": bool(getattr(res, "success", False)),
            "message": str(getattr(res, "message", "")),
        })

        if lp_star.ok and lp_star.t < best.t:
            best = SearchResult(True, lp_star.t, s_star, lp_star.msg)

    diag = {
        "best_screen_val": float(best_screen_val),
        "best_screen_s": best_screen_s,
        "local_stats": local_stats,
    }
    return best, diag


# --------------------------
# Run ONLY n=6, C=7, interior-only
# --------------------------

def main():
    n = 6
    C = 7.0

    # Interior-only knobs
    n_random_screen = 160
    n_local_starts = 3
    local_maxiter = 120
    fd_eps = 2e-3

    for h in range(1, n):
        solver = FullProgramLPSolver(
            n=n, h=h, C=C,
            cache_tol=1e-4,
            cache_max=200_000,
            profile=True,
        )

        interior, diag = solve_interior_only_with_grad(
            solver,
            seed=1,
            n_random_screen=n_random_screen,
            n_local_starts=n_local_starts,
            local_maxiter=local_maxiter,
            fd_eps=fd_eps,
        )

        print(
            f"\nRESULT (INTERIOR ONLY) n={n}, h={h}, C={C}\n"
            f"  best: t={interior.t: .6g}, s={np.array2string(interior.s, precision=6, floatmode='fixed')}\n"
            f"  best_screen_val={diag['best_screen_val']: .6g}\n"
            f"  local_stats={diag['local_stats']}"
        )

        solver.print_profile(label=f"(n={n}, h={h}, C={C})")


if __name__ == "__main__":
    main()



RESULT (INTERIOR ONLY) n=6, h=1, C=7.0
  best: t= 3.56883, s=[0.779066 0.650286 0.577241 0.535858 0.509676 0.490172]
  best_screen_val= 6.39999
  local_stats=[{'fun': 3.5873000404733717, 'nfev': 89, 'nit': 22, 'success': True, 'message': 'CONVERGENCE: REL_REDUCTION_OF_F_<=_FACTR*EPSMCH'}, {'fun': 3.568827810911045, 'nfev': 152, 'nit': 41, 'success': True, 'message': 'CONVERGENCE: REL_REDUCTION_OF_F_<=_FACTR*EPSMCH'}, {'fun': 3.7147761288269168, 'nfev': 83, 'nit': 16, 'success': True, 'message': 'CONVERGENCE: REL_REDUCTION_OF_F_<=_FACTR*EPSMCH'}]

=== solve_lp profile (n=6, h=1, C=7.0) ===
calls:      8587
cache_hits: 6100 (71.0%)
misses:     2487 (29.0%)
build_eq:   0.513s
build_ub:   3.525s
linprog:    7.736s
total:      12.170s

RESULT (INTERIOR ONLY) n=6, h=2, C=7.0
  best: t= 3.05234, s=[0.736419 0.512630 0.477937 0.428080 0.372866 0.304850]
  best_screen_val= 5.3444
  local_stats=[{'fun': 3.0523414508329005, 'nfev': 108, 'nit': 20, 'success': False, 'message': 'ABNORMAL_TERMINATI