In [2]:
import matplotlib.pyplot as plt
import matplotlib as mpl
import matplotlib.patches as mpatches
import numpy as np
import pandas as pd
import cvxpy as cp
from scipy.stats import poisson, uniform, expon, pareto, hypergeom
from scipy.optimize import minimize, fsolve, least_squares
import scipy as sc
from tqdm import tqdm
from mdptoolbox import mdp, util
import itertools
from scipy.sparse import csr_matrix, lil_matrix
from matplotlib.patches import Patch
import math
import random
import sympy as sp
from sympy.printing.latex import print_latex

# Notebook-wide plot styling: larger axis-label (14) and title (16) font sizes.
plt.rcParams['axes.labelsize'] = 14
plt.rcParams['axes.titlesize'] = 16

In [3]:
import itertools
import numpy as np
from dataclasses import dataclass
from typing import Dict, Tuple, List, Optional

from scipy.optimize import minimize, Bounds, NonlinearConstraint


def all_bitstrings(n: int) -> np.ndarray:
    """Enumerate {0,1}^n as the rows of a (2^n, n) integer array.

    Rows follow itertools.product order, so row r is the binary expansion of r
    with the most significant bit in column 0 (row 0 is the all-zero profile).
    """
    rows = [list(bits) for bits in itertools.product((0, 1), repeat=n)]
    return np.array(rows, dtype=int)


def branch_list(n: int, h: int):
    """Return every branch (A, aA) with |A| = n - h.

    A ranges over the (n - h)-subsets of range(n) in combinations order, and
    aA over all 0/1 assignments to the coordinates in A (product order).
    """
    size = n - h
    return [
        (tuple(A), tuple(aA))
        for A in itertools.combinations(range(n), size)
        for aA in itertools.product([0, 1], repeat=size)
    ]


@dataclass
class SciPySolveResult:
    """Best start of the multi-start trust-constr run in solve_full_program_scipy."""

    success: bool                 # SciPy's success flag for the selected start
    message: str                  # SciPy's termination message for the selected start
    t: float                      # value of the epigraph variable t at the solution
    s: np.ndarray                 # shape (n,)
    f: np.ndarray                 # shape (n, 2^n)  f[i, idx(a)], idx(a) = row of a in all_bitstrings(n)
    fun: float                    # objective
    nit: int                      # iteration count reported by SciPy (-1 if absent)


def solve_full_program_scipy(
    n: int,
    h: int,
    C: float,
    n_starts: int = 10,
    seed: int = 0,
    maxiter: int = 2000,
    verbose: bool = False,
) -> SciPySolveResult:
    """
    Solve the full program with SciPy (trust-constr) using multi-start.

    Variables:
      s in [0,1]^n
      f_i(a) >= 0 for i in [n], a in {0,1}^n
      t free (but we bound t>=0 since objective is a max of nonneg-ish terms)

    Constraints (implemented):
      - Eq constraints enforced as equality for all i:
            E_{a_-i}[ f_i(1,a_-i) ] - 1  =  E_{a_-i}[ f_i(0,a_-i) ]
        where expectation uses product distribution over coordinates != i.
      - IR:
            sum_i s_i <= sum_a pi_s(a) * sum_i f_i(a)
      - Branch constraints:
            For each (A,aA) with |A| = n - h,
            E_{a_-A}[ sum_i f_i(aA,a_-A) + C*1[a=0] ] <= t

    Args:
      n: number of coordinates (>= 1).
      h: 0 <= h <= n; each branch fixes n - h coordinates.
      C: nonnegative constant charged on the all-zero profile.
      n_starts: number of random restarts (s0 uniform in [0.1, 0.9]^n).
      seed: seed for numpy's default_rng used for the restarts.
      maxiter: per-start iteration cap passed to trust-constr.
      verbose: if True, trust-constr prints progress for the first start only.

    Returns:
      SciPySolveResult for the best start. Successful starts are always
      preferred over failed ones; among equals, the lower objective wins.

    Raises:
      AssertionError: if n, h or C are out of range.
      RuntimeError: if no start was run (n_starts < 1).
    """

    assert 1 <= n
    assert 0 <= h <= n
    assert C >= 0

    bits = all_bitstrings(n)          # (m, n), m=2^n
    m = bits.shape[0]
    branches = branch_list(n, h)

    all_zero_mask = (bits.sum(axis=1) == 0).astype(float)  # (m,)
    bit_is_one = bits == 1                                  # (m, n) bool

    # Everything below depends only on (n, h), never on x. trust-constr calls the
    # constraint functions many times per iteration (finite-difference Jacobians),
    # so build these masks once instead of inside every callback.
    def free_columns(fixed_idx) -> np.ndarray:
        # Coordinates the expectation integrates over (the complement of fixed_idx).
        free = np.ones(n, dtype=bool)
        free[list(fixed_idx)] = False
        return free

    eq_free_cols = [free_columns((i,)) for i in range(n)]

    branch_row_masks = []   # rows a that agree with aA on A
    branch_free_cols = []   # coordinates outside A
    for A, aA in branches:
        rows = np.ones(m, dtype=bool)
        for j, val in zip(A, aA):
            rows &= bits[:, j] == val
        branch_row_masks.append(rows)
        branch_free_cols.append(free_columns(A))

    # Variable packing: x = [s (n), f (n*m), t (1)]
    def unpack(x: np.ndarray):
        s = x[:n]
        f = x[n:n + n * m].reshape((n, m))
        t = x[-1]
        return s, f, t

    def pack(s: np.ndarray, f: np.ndarray, t: float):
        return np.concatenate([s.ravel(), f.ravel(), np.array([t], dtype=float)])

    def factor_matrix(s: np.ndarray) -> np.ndarray:
        # term[a, j] = s_j if a_j == 1 else 1 - s_j (s clipped to [0, 1]).
        # pi_s(a) is the row product; a marginal over a coordinate subset is the
        # row product over those columns (an empty product is 1).
        s = np.clip(s, 0.0, 1.0)
        return np.where(bit_is_one, s[None, :], (1.0 - s)[None, :])

    # Objective: minimize t
    def objective(x: np.ndarray) -> float:
        return x[-1]

    def eq_constraints(x: np.ndarray) -> np.ndarray:
        # size n: for each i, E[f_i | a_i=1] - 1 - E[f_i | a_i=0] = 0
        s, f, _t = unpack(x)
        term = factor_matrix(s)
        out = np.empty(n, dtype=float)
        for i in range(n):
            w = term[:, eq_free_cols[i]].prod(axis=1)  # (m,)
            ones = bit_is_one[:, i]
            E1 = np.dot(w[ones], f[i, ones])
            E0 = np.dot(w[~ones], f[i, ~ones])
            out[i] = (E1 - 1.0) - E0
        return out

    def ir_constraint(x: np.ndarray) -> float:
        # sum_i s_i - sum_a pi(a)*sum_i f_i(a) <= 0
        s, f, _t = unpack(x)
        pi = factor_matrix(s).prod(axis=1)
        return s.sum() - np.dot(pi, f.sum(axis=0))

    def branch_constraints(x: np.ndarray) -> np.ndarray:
        # for each branch b: E_{a_-A}[ sum_i f_i(a) + C*1[a=0] ] - t <= 0
        s, f, t = unpack(x)
        term = factor_matrix(s)
        payoff = f.sum(axis=0) + C * all_zero_mask  # (m,)
        out = np.empty(len(branches), dtype=float)
        for b_idx, (rows, free) in enumerate(zip(branch_row_masks, branch_free_cols)):
            w = term[:, free].prod(axis=1)
            out[b_idx] = np.dot(w[rows], payoff[rows]) - t
        return out

    eq_con = NonlinearConstraint(eq_constraints, 0.0, 0.0)          # equality
    ir_con = NonlinearConstraint(ir_constraint, -np.inf, 0.0)       # <= 0
    br_con = NonlinearConstraint(branch_constraints, -np.inf, 0.0)  # <= 0

    # Bounds: s in [0,1], f >= 0, t >= 0
    lb = np.zeros(n + n * m + 1, dtype=float)
    ub = np.full(n + n * m + 1, np.inf, dtype=float)
    ub[:n] = 1.0
    bounds = Bounds(lb, ub)

    def result_rank(res):
        # Successful runs always beat failed ones; then lower objective wins.
        # (Previously a failed first start could never be displaced by a
        # successful start with a larger objective.)
        return (not bool(res.success), float(res.fun))

    rng = np.random.default_rng(seed)
    best_res = None

    for start in range(n_starts):
        s0 = 0.1 + 0.8 * rng.random(n)
        f0 = np.full((n, m), 0.1, dtype=float)
        t0 = 1.0
        x0 = pack(s0, f0, t0)

        res = minimize(
            objective,
            x0,
            method="trust-constr",
            bounds=bounds,
            constraints=[eq_con, ir_con, br_con],
            options={
                "maxiter": maxiter,
                "verbose": 3 if (verbose and start == 0) else 0,
            },
        )

        if best_res is None or result_rank(res) < result_rank(best_res):
            best_res = res

    if best_res is None:
        raise RuntimeError("No result returned by SciPy (unexpected).")

    s_sol, f_sol, t_sol = unpack(best_res.x)

    return SciPySolveResult(
        success=bool(best_res.success),
        message=str(best_res.message),
        t=float(t_sol),
        s=s_sol.astype(float),
        f=f_sol.astype(float),
        fun=float(best_res.fun),
        nit=int(getattr(best_res, "nit", -1)),
    )

In [7]:
%%time
# Multi-start trust-constr on the full (s, f, t) program for n=4, h=1, C=3.
sol = solve_full_program_scipy(n=4, h=1, C=3.0, n_starts=5, seed=1, verbose=False)
print("success:", sol.success)
print("message:", sol.message)
print("t* =", sol.t)
print("s* =", sol.s)

# sol.f[0, 0] is f for coordinate index 0 at row 0 of all_bitstrings(4), the
# all-zero profile. NOTE(review): the label prints "000" although n = 4 here.
print("f_1(000) =", sol.f[0, 0])

success: True
message: `gtol` termination condition is satisfied.
t* = 2.0000029661913885
s* = [0.33333293 0.33333304 0.99999981 0.33333299]
f_1(000) = 1.7706752292408065e-07
CPU times: user 1min 59s, sys: 9min 24s, total: 11min 24s
Wall time: 1min 8s


In [17]:
%%time
# NOTE(review): minimize_over_s is not defined in any cell shown here; presumably an
# earlier (unsorted) variant of minimize_over_s_sorted -- confirm before re-running.
out = minimize_over_s(n=6, h=4, C=8.0, n_random=300, n_local_starts=8, seed=1, local_method="Powell")
print("success:", out.success)
print("t* =", out.t)
print("s* =", out.s)

success: True
t* = 2.3885382809802693
s* = [0.18550826 0.40629624 0.4039483  1.         0.37186237 0.02092311]
CPU times: user 23.9 s, sys: 55.1 ms, total: 24 s
Wall time: 24 s


In [21]:
%%time
# Compare the two outer searches on one instance (n=6, h=4, C=8).
n = 6
h = 4
C = 8.0

solver = FullProgramLPSolver(n=n, h=h, C=C)

# Random screening + Powell refinement over sorted s (no face structure).
interior = minimize_over_s_sorted(
    solver,
    n_random=300,
    n_local_starts=8,
    seed=1,
    local_method="Powell",
)

# Powell refinement separately on each face of 1 >= s1 >= ... >= sn >= 0.
faces = minimize_over_faces(
    solver,
    seed=1,
    local_method="Powell",
    n_local_starts_per_face=3,
    include_endpoints=True,
    max_faces=None,   # set to e.g. 200 if you want a quick partial scan
)

print("=== Sorted 'interior-ish' search (always sorts s decreasing) ===")
print("ok:", interior.ok)
print("t:", interior.t)
print("s:", interior.s)

print("\n=== Best face over 1>=s1>=...>=sn>=0 (tight adjacencies + endpoints) ===")
print("ok:", faces.ok)
print("t:", faces.t)
print("s:", faces.s)
print("face (tight_adj, top_fixed, bottom_fixed):", faces.face)
# face is None when no face produced an accepted result.
if faces.face is not None:
    tight_adj, top_fixed, bottom_fixed = faces.face
    print("  tight_adj indices (0-based):", tight_adj)
    print("  meaning: s_i = s_{i+1} for i in tight_adj")
    print("  top_fixed (s1=1):", top_fixed)
    print("  bottom_fixed (sn=0):", bottom_fixed)


=== Sorted 'interior-ish' search (always sorts s decreasing) ===
ok: True
t: 2.3599996845764584
s: [1.         0.3107921  0.29620201 0.28817095 0.25090803 0.21392658]

=== Best face over 1>=s1>=...>=sn>=0 (tight adjacencies + endpoints) ===
ok: True
t: 2.3245553182504723
s: [0.70943058 0.70943058 0.70943058 0.         0.         0.        ]
face (tight_adj, top_fixed, bottom_fixed): ((0, 1, 3, 4), False, True)
  tight_adj indices (0-based): (0, 1, 3, 4)
  meaning: s_i = s_{i+1} for i in tight_adj
  top_fixed (s1=1): False
  bottom_fixed (sn=0): True
CPU times: user 5min 33s, sys: 549 ms, total: 5min 34s
Wall time: 5min 35s


In [26]:
%%time
# Default grid; the recorded run below was interrupted (KeyboardInterrupt) at n=5, h=2.
run_grid()

n= 2, h= 1, C= 3.0
  face:     t= 1.5, s=[1.000000 0.500000]
  interior: t= 1.5, s=[1.000000 0.500000]
n= 3, h= 1, C= 3.0
  face:     t= 1.8, s=[1.000000 0.400000 0.400000]
  interior: t= 1.80001, s=[0.999853 0.400006 0.399996]
n= 3, h= 2, C= 3.0
  face:     t= 1.5, s=[1.000000 0.500000 0.000000]
  interior: t= 1.50053, s=[1.000000 0.499356 0.001170]
n= 4, h= 1, C= 3.0
  face:     t= 2, s=[1.000000 0.333333 0.333333 0.333333]
  interior: t= 2.00001, s=[1.000000 0.333352 0.333331 0.333329]
n= 4, h= 2, C= 3.0
  face:     t= 1.72508, s=[1.000000 0.241694 0.241694 0.241694]
  interior: t= 1.77783, s=[1.000000 0.333455 0.333455 0.110922]
n= 4, h= 3, C= 3.0
  face:     t= 1.5, s=[1.000000 0.500000 0.000000 0.000000]
  interior: t= 1.5, s=[1.000000e+00 5.000000e-01 4.248354e-18 4.248354e-18]
n= 5, h= 1, C= 3.0
  face:     t= 2.14286, s=[1.000000 0.285714 0.285714 0.285714 0.285714]
  interior: t= 2.16192, s=[1.000000 0.322750 0.279511 0.279360 0.279360]
n= 5, h= 2, C= 3.0
  face:     t= 1.854

KeyboardInterrupt: 

In [28]:
%%time
# Smaller grid: n in 2..5, only the larger C values.
run_grid(Ns = [2, 3, 4, 5], Cs = [7.0, 11.0, 15.0])

n= 2, h= 1, C= 7.0
  face:     t= 1.75, s=[1.000000 0.750000]
  interior: t= 1.75, s=[1.000000 0.750000]
n= 3, h= 1, C= 7.0
  face:     t= 2.26628, s=[0.676246 0.676246 0.676246]
  interior: t= 2.2668, s=[0.677773 0.676743 0.676171]
n= 3, h= 2, C= 7.0
  face:     t= 1.75, s=[1.000000 0.750000 0.000000]
  interior: t= 1.75004, s=[1.000000e+00 7.499926e-01 4.388965e-05]
n= 4, h= 1, C= 7.0
  face:     t= 2.63488, s=[0.623589 0.623589 0.623589 0.623589]
  interior: t= 2.67581, s=[0.672342 0.635874 0.617760 0.617741]
n= 4, h= 2, C= 7.0
  face:     t= 2.26628, s=[0.676246 0.676246 0.676246 0.000000]
  interior: t= 2.32923, s=[0.655447 0.616391 0.615949 0.133584]
n= 4, h= 3, C= 7.0
  face:     t= 1.75, s=[1.000000 0.750000 0.000000 0.000000]
  interior: t= 1.88536, s=[1.000000 0.647167 0.231520 0.006670]
n= 5, h= 1, C= 7.0
  face:     t= 2.97311, s=[0.575271 0.575271 0.575271 0.575271 0.575271]
  interior: t= 3.03858, s=[0.608803 0.595602 0.593073 0.565917 0.565917]
n= 5, h= 2, C= 7.0
  face:

In [27]:
def run_grid(
    Ns=(2, 3, 4, 5, 6, 7, 8),
    Cs=(3.0, 7.0, 11.0, 15.0),
    *,
    interior_n_random: int = 250,
    interior_n_local_starts: int = 6,
    face_local_starts_per_face: int = 1,
    face_include_endpoints: bool = True,
    face_max_faces: Optional[int] = None,
    seed: int = 1,
):
    """
    Compare the face search and the sorted "interior" search over a parameter grid.

    For every C in Cs, every n in Ns and every h in 1..n-1, build a
    FullProgramLPSolver(n, h, C), run minimize_over_s_sorted and
    minimize_over_faces (both with Powell and the same seed), and print one
    summary entry per (n, h, C).

    Args:
      Ns: problem sizes to scan (any iterable of ints).
      Cs: values of C to scan (outer loop).
      interior_n_random: random screening points for the interior search.
      interior_n_local_starts: Powell starts for the interior search.
      face_local_starts_per_face: Powell starts per face (keep small; the full
        face scan is already expensive).
      face_include_endpoints: also enumerate faces with s1=1 and/or sn=0.
      face_max_faces: cap on enumerated faces; None means all of them.
      seed: RNG seed passed to both searches.

    Per-start iteration caps are whatever minimize_over_s_sorted /
    minimize_over_faces use by default; this function does not override them.

    Returns:
      None (results are printed).
    """
    for C in Cs:
        for n in Ns:
            for h in range(1, n):  # all h up to n-1
                solver = FullProgramLPSolver(n=n, h=h, C=C)

                # Best "interior-ish" (sorted)
                interior = minimize_over_s_sorted(
                    solver,
                    n_random=interior_n_random,
                    n_local_starts=interior_n_local_starts,
                    seed=seed,
                    local_method="Powell",
                )

                # Best face on chain polytope faces
                faces = minimize_over_faces(
                    solver,
                    seed=seed,
                    local_method="Powell",
                    n_local_starts_per_face=face_local_starts_per_face,
                    include_endpoints=face_include_endpoints,
                    max_faces=face_max_faces,
                )

                print(
                    f"n={n:>2}, h={h:>2}, C={C:>4}\n"
                    f"  face:     t={faces.t: .6g}, "
                    f"s={np.array2string(faces.s, precision=6, floatmode='fixed')}\n"
                    f"  interior: t={interior.t: .6g}, "
                    f"s={np.array2string(interior.s, precision=6, floatmode='fixed')}"
                )


In [20]:
import itertools
import numpy as np
from dataclasses import dataclass
from typing import Tuple, List, Optional

from scipy.optimize import linprog, minimize


# --------------------------
# Utilities
# --------------------------

def all_bitstrings(n: int) -> np.ndarray:
    """List {0,1}^n as a (2^n, n) int array, one profile per row.

    The order is lexicographic (itertools.product order), so row 0 is the
    all-zero profile and the last row is all ones.
    """
    grid = itertools.product(range(2), repeat=n)
    return np.asarray([*grid], dtype=int)


def all_branches(n: int, h: int) -> List[Tuple[Tuple[int, ...], Tuple[int, ...]]]:
    """Enumerate branches (A, aA): A an (n - h)-subset of range(n), aA a 0/1 assignment on A.

    Subsets come in itertools.combinations order; for each subset the
    assignments come in itertools.product order.
    """
    width = n - h
    subsets = itertools.combinations(range(n), width)
    assignments = itertools.product((0, 1), repeat=width)
    return [
        (tuple(subset), tuple(assignment))
        for subset, assignment in itertools.product(subsets, assignments)
    ]


def sigmoid(z):
    """Logistic function 1 / (1 + exp(-z)), evaluated on z clipped to [-40, 40].

    The clip keeps exp() from overflowing for large negative inputs; works on
    scalars and arrays alike.
    """
    bounded = np.clip(z, -40, 40)
    denominator = 1.0 + np.exp(-bounded)
    return 1.0 / denominator


def sort_desc(s: np.ndarray) -> np.ndarray:
    """Return a float copy of s sorted in non-increasing order."""
    ascending = np.sort(np.asarray(s, dtype=float))
    return np.flip(ascending)


def linprog_any(c, A_ub, b_ub, A_eq, b_eq, bounds):
    """
    Solve an LP with scipy.optimize.linprog, trying methods in preference order.

    The HiGHS methods are tried first. The legacy methods are fall-backs for old
    SciPy versions without HiGHS; newer SciPy versions reject them.

    Returns:
      The OptimizeResult of the first method that does not raise ValueError
      (an infeasible/unbounded LP is reported via res.success, not an exception).

    Raises:
      ValueError: if every method raises. The message lists each method's error,
        chained from the first (most preferred) one. Re-raising only the last
        error would hide the real cause (e.g. malformed input detected by HiGHS)
        behind an "unknown/removed method" error from a legacy fall-back.
    """
    methods = ("highs", "highs-ds", "highs-ipm", "interior-point", "revised simplex", "simplex")
    errors = []
    for method in methods:
        try:
            return linprog(
                c=c, A_ub=A_ub, b_ub=b_ub,
                A_eq=A_eq, b_eq=b_eq,
                bounds=bounds,
                method=method
            )
        except ValueError as e:
            errors.append((method, e))

    details = "; ".join(f"{method}: {err}" for method, err in errors)
    raise ValueError(f"linprog failed with every method ({details})") from errors[0][1]


# --------------------------
# Inner LP solver
# --------------------------

@dataclass
class LPValue:
    """Outcome of one inner-LP solve by FullProgramLPSolver.solve_lp."""

    ok: bool                 # True iff linprog reported success
    t: float                 # optimal t; float("inf") when ok is False
    f: Optional[np.ndarray]  # shape (n, 2^n); None unless ok and return_f=True
    msg: str                 # "OK" on success, otherwise linprog's message


class FullProgramLPSolver:
    """
    For fixed s, solve the inner LP in (f,t).

    Variables: f_{i,a} >= 0 for i in [n], a in {0,1}^n, and t >= 0.
    Objective: minimize t.

    Constraints implemented (linear for fixed s):
      - Eq (interior equality form) for each i:
          E_{a_-i}[ f_i(1,a_-i) ] - E_{a_-i}[ f_i(0,a_-i) ] = 1
        where expectation uses product distribution over coordinates != i.
      - IR:
          sum_i s_i <= sum_a pi_s(a) * sum_i f_i(a)
      - Branch epigraph constraints for each (A,aA):
          E_{a_-A}[ sum_i f_i(aA,a_-A) + C*1[a=0] ] <= t

    solve_lp is called thousands of times by the outer searches with only s
    changing. Everything that depends on (n, h) alone is therefore precomputed
    in __init__: the bit table, the matching rows of each branch, the columns
    each expectation integrates over, and the flat LP column indices of each
    branch row.
    """

    def __init__(self, n: int, h: int, C: float):
        assert 1 <= n
        assert 0 <= h <= n
        assert C >= 0

        self.n = n
        self.h = h
        self.C = float(C)

        self.bits = all_bitstrings(n)  # (m,n)
        self.m = self.bits.shape[0]
        self._bit_is_one = self.bits == 1  # (m,n) bool

        self.branches = all_branches(n, h)
        self.num_branches = len(self.branches)

        self.all_zero = (self.bits.sum(axis=1) == 0).astype(float)  # (m,)
        # Constant part C * 1[a = 0] of every branch expectation.
        self._c_all_zero = self.C * self.all_zero

        # Eq constraint i integrates over every coordinate except i.
        self._eq_free_cols = [self._free_columns((i,)) for i in range(n)]

        # Per branch: indices of profiles a matching aA on A (kept public), the
        # coordinates integrated over (complement of A), and the flat column
        # indices i*m + a of f_{i,a} for every i and every matching a (i-major).
        self.branch_match_idx = []
        self._branch_free_cols = []
        self._branch_cols = []
        player_offsets = np.arange(n)[:, None] * self.m  # (n, 1)
        for (A, aA) in self.branches:
            mask = np.ones(self.m, dtype=bool)
            for pos, j in enumerate(A):
                mask &= (self.bits[:, j] == aA[pos])
            idx = np.where(mask)[0]
            self.branch_match_idx.append(idx)
            self._branch_free_cols.append(self._free_columns(A))
            self._branch_cols.append((player_offsets + idx[None, :]).ravel())

    def _free_columns(self, fixed: Tuple[int, ...]) -> np.ndarray:
        """Boolean mask over coordinates that are NOT in `fixed`."""
        free = np.ones(self.n, dtype=bool)
        free[list(fixed)] = False
        return free

    def _factor_matrix(self, s: np.ndarray) -> np.ndarray:
        """term[a, j] = s_j if a_j == 1 else 1 - s_j; shape (m, n)."""
        return np.where(self._bit_is_one, s[None, :], (1.0 - s)[None, :])

    def _pi_full(self, s: np.ndarray) -> np.ndarray:
        # pi_s(a) = prod_j s_j^{a_j} (1-s_j)^{1-a_j}
        return self._factor_matrix(s).prod(axis=1)  # (m,)

    def _pi_over_complement(self, s: np.ndarray, fixed: Tuple[int, ...]) -> np.ndarray:
        # product over j not in fixed; all-ones weights when every coordinate is fixed
        free = self._free_columns(fixed)
        if not free.any():
            return np.ones(self.m, dtype=float)
        return self._factor_matrix(s)[:, free].prod(axis=1)

    def solve_lp(self, s: np.ndarray, return_f: bool = False) -> LPValue:
        """
        Solve the inner LP for a given s (clipped to [0, 1]).

        Args:
          s: length-n vector of marginals.
          return_f: if True, also return the optimal f with shape (n, 2^n).

        Returns:
          LPValue(ok=True, t=optimum, ...) on success, or
          LPValue(ok=False, t=inf, f=None, msg=linprog message) on failure.

        Raises:
          ValueError: propagated from linprog_any if every method raises.
        """
        n, m = self.n, self.m
        s = np.clip(np.asarray(s, dtype=float), 0.0, 1.0)

        # x = [f_{0,0..m-1}, f_{1,0..m-1}, ..., f_{n-1,0..m-1}, t]
        N = n * m + 1
        t_idx = n * m

        # Objective min t
        c = np.zeros(N, dtype=float)
        c[t_idx] = 1.0

        # Bounds: f >= 0, t >= 0
        bounds = [(0.0, None)] * N

        # Per-coordinate factors, computed once per s and shared by every
        # expectation below (previously rebuilt per Eq row and per branch).
        term = self._factor_matrix(s)

        def complement_weights(free_cols: np.ndarray) -> np.ndarray:
            # Product over the free columns; an empty product is 1.
            return term[:, free_cols].prod(axis=1)

        # ----------------------
        # Equality constraints (Eq): E1 - E0 = 1
        A_eq = np.zeros((n, N), dtype=float)
        b_eq = np.ones(n, dtype=float)
        for i in range(n):
            w = complement_weights(self._eq_free_cols[i])  # (m,)
            A_eq[i, i * m:(i + 1) * m] = np.where(self._bit_is_one[:, i], w, -w)

        # ----------------------
        # Inequality constraints A_ub x <= b_ub: row 0 is IR, rows 1.. are branches
        A_ub = np.zeros((1 + self.num_branches, N), dtype=float)
        b_ub = np.empty(1 + self.num_branches, dtype=float)

        # IR: sum s_i <= sum_a pi(a) * sum_i f_i(a)
        # => - sum_a pi(a)*sum_i f_i(a) <= -sum s_i
        pi = term.prod(axis=1)
        A_ub[0, :t_idx] = np.tile(-pi, n)
        b_ub[0] = -float(s.sum())

        # Branch: sum w*(fsum + C*I0) <= t
        #     =>  sum w*fsum - t <= - sum w*(C*I0)
        for b in range(self.num_branches):
            idx = self.branch_match_idx[b]
            w = complement_weights(self._branch_free_cols[b])  # (m,)
            row = 1 + b
            A_ub[row, self._branch_cols[b]] = np.tile(w[idx], n)
            A_ub[row, t_idx] = -1.0
            b_ub[row] = -float(np.dot(w[idx], self._c_all_zero[idx]))

        res = linprog_any(c, A_ub, b_ub, A_eq, b_eq, bounds)

        if not res.success:
            return LPValue(ok=False, t=float("inf"), f=None, msg=str(res.message))

        t_val = float(res.fun)
        if not return_f:
            return LPValue(ok=True, t=t_val, f=None, msg="OK")

        f = res.x[:n * m].reshape((n, m))
        return LPValue(ok=True, t=t_val, f=f, msg="OK")


# --------------------------
# Face enumeration & optimization
# --------------------------

def blocks_from_equalities(n: int, tight_adj: Tuple[int, ...]) -> List[List[int]]:
    """
    Group coordinates 0..n-1 into maximal runs joined by tight adjacencies.

    tight_adj holds 0-based adjacency indices: i in {0,...,n-2} means s_i = s_{i+1},
    so coordinate i+1 joins the block of coordinate i. Returns the contiguous
    blocks in order; coordinate 0 always starts the first block.
    """
    joined = set(tight_adj)
    groups = [[0]]
    for j in range(1, n):
        if j - 1 in joined:
            groups[-1].append(j)
        else:
            groups.append([j])
    return groups


def enumerate_faces(n: int, include_endpoints: bool = True):
    """
    List the faces of the chain polytope 1>=s1>=...>=sn>=0 as (tight_adj, top_fixed, bottom_fixed).

    tight_adj ranges over all subsets of the n-1 adjacencies (bit i of the
    enumeration code set means s_i = s_{i+1}). With include_endpoints each
    subset is paired with all four (s1=1?, sn=0?) combinations, top flag
    varying slowest; otherwise only with (False, False).
    """
    if include_endpoints:
        endpoint_options = list(itertools.product((False, True), repeat=2))
    else:
        endpoint_options = [(False, False)]

    faces = []
    for code in range(1 << (n - 1)):
        tight = tuple(pos for pos in range(n - 1) if code & (1 << pos))
        faces.extend((tight, top, bottom) for top, bottom in endpoint_options)
    return faces


def face_parameterization(z: np.ndarray, k: int, top_fixed: bool, bottom_fixed: bool) -> np.ndarray:
    """
    Map unconstrained parameters z to block values v0>=v1>=...>=v_{k-1} in [0,1].

    - v0 = 1 if top_fixed, else sigmoid(z[0]).
    - Each following free block is its predecessor scaled by a sigmoid of the
      next parameter, which keeps the sequence non-increasing.
    - If bottom_fixed, the last block is set to exactly 0 at the end (for
      k == 1 this overrides v0).

    Uses (not top_fixed) + max(0, k - 1 - bottom_fixed) entries of z.
    k == 0 yields an empty array.
    """
    theta = np.asarray(z, dtype=float)
    v = np.zeros(k, dtype=float)
    if k == 0:
        return v

    v[0] = 1.0 if top_fixed else sigmoid(theta[0])
    offset = 0 if top_fixed else 1

    n_scaled = max(0, (k - 1) - (1 if bottom_fixed else 0))
    for step in range(n_scaled):
        v[step + 1] = v[step] * sigmoid(theta[offset + step])

    if bottom_fixed:
        v[-1] = 0.0

    return v


def blocks_to_full_s(n: int, blocks: List[List[int]], v: np.ndarray) -> np.ndarray:
    """Expand per-block values into a length-n float vector.

    Every coordinate listed in blocks[p] receives v[p]; coordinates that appear
    in no block stay 0.
    """
    full = np.zeros(n, dtype=float)
    for position, members in enumerate(blocks):
        full[members] = v[position]
    return full


@dataclass
class SearchResult:
    """Best point recorded by one of the outer searches over s."""

    ok: bool                 # success flag as set by the producing search
    t: float                 # objective value recorded for s (inf if nothing was recorded)
    s: np.ndarray            # shape (n,)
    face: Optional[Tuple[Tuple[int, ...], bool, bool]]  # (tight_adj, top_fixed, bottom_fixed); None from minimize_over_s_sorted
    msg: str                 # "OK" or a status/failure message


def minimize_over_s_sorted(
    solver: FullProgramLPSolver,
    n_random: int = 200,
    n_local_starts: int = 8,
    seed: int = 0,
    local_method: str = "Powell",
    maxiter: int = 400,
) -> SearchResult:
    """
    Outer search over s in [0,1]^n, ALWAYS projecting to sorted decreasing order.

    Two phases:
      1. Random screening: solve the inner LP at `n_random` uniform random
         points (sorted decreasing) and keep the best feasible one.
      2. Local refinement: run scipy.optimize.minimize(method=local_method) in
         logit space, s = sort_desc(sigmoid(z)). It starts from the screening
         winner plus `n_local_starts - 1` fresh random sorted points. Infeasible
         LPs score 1e6.

    Args:
      solver: inner LP solver; solver.n is the dimension of s.
      n_random: number of random screening points.
      n_local_starts: total number of refinement starts.
      seed: seed for numpy's default_rng.
      local_method: derivative-free method name accepted by minimize.
      maxiter: per-start iteration cap for the refinement (default 400, the
        previously hard-coded value).

    Returns:
      SearchResult with the LP re-solved at the final s (so ok/t are consistent
      with s); face is always None.
    """
    n = solver.n
    rng = np.random.default_rng(seed)

    best_t = float("inf")
    best_s = None

    # Random screening
    for _ in range(n_random):
        s = sort_desc(rng.random(n))
        val = solver.solve_lp(s, return_f=False)
        if val.ok and val.t < best_t:
            best_t = val.t
            best_s = s.copy()

    if best_s is None:
        best_s = np.full(n, 0.5)

    def obj(z):
        s = sort_desc(sigmoid(z))
        val = solver.solve_lp(s, return_f=False)
        return val.t if val.ok else 1e6

    def inv_sig(u):
        # Logit, with u kept away from 0/1 so the result stays finite.
        u = np.clip(u, 1e-6, 1 - 1e-6)
        return np.log(u / (1 - u))

    # Start points in z-space
    starts = [inv_sig(best_s)]
    for _ in range(max(0, n_local_starts - 1)):
        s0 = sort_desc(rng.random(n))
        starts.append(inv_sig(s0))

    best_z = None
    best_local = best_t  # a local result must beat the screening winner to be kept

    for z0 in starts:
        res = minimize(obj, z0, method=local_method, options={"maxiter": maxiter, "disp": False})
        if float(res.fun) < best_local:
            best_local = float(res.fun)
            best_z = res.x.copy()

    s_star = sort_desc(sigmoid(best_z)) if best_z is not None else best_s
    lp_star = solver.solve_lp(s_star, return_f=False)
    return SearchResult(lp_star.ok, lp_star.t, s_star, None, lp_star.msg)


def minimize_over_faces(
    solver: FullProgramLPSolver,
    seed: int = 0,
    local_method: str = "Powell",
    n_local_starts_per_face: int = 3,
    include_endpoints: bool = True,
    max_faces: Optional[int] = None,
    maxiter: int = 300,
) -> SearchResult:
    """
    Enumerate and optimize over faces of the chain polytope 1>=s1>=...>=sn>=0.

    A face is (tight_adj, top_fixed, bottom_fixed). Tight adjacencies merge
    coordinates into blocks with a common value. top_fixed pins the first block
    to 1 and bottom_fixed pins the last block to 0. The remaining block values
    are optimized with scipy.optimize.minimize(method=local_method) through
    face_parameterization. Fully pinned faces are evaluated directly.

    Args:
      solver: inner LP solver; solver.n is the dimension of s.
      seed: seed for numpy's default_rng (random starts in z-space).
      local_method: derivative-free method name accepted by minimize.
      n_local_starts_per_face: random starts per face.
      include_endpoints: also enumerate faces with s1=1 and/or sn=0.
      max_faces: keep only the first max_faces enumerated faces (None = all).
      maxiter: per-start iteration cap (default 300, the previously
        hard-coded value).

    Returns:
      SearchResult for the best feasible point found. If nothing feasible was
      found: ok=False, t=inf, s=zeros, face=None.
    """
    n = solver.n
    rng = np.random.default_rng(seed)

    faces = enumerate_faces(n, include_endpoints=include_endpoints)
    if max_faces is not None:
        faces = faces[:max_faces]

    best = SearchResult(False, float("inf"), np.zeros(n), None, "no feasible face found")

    for face in faces:
        tight_adj, top_fixed, bottom_fixed = face
        blocks = blocks_from_equalities(n, tight_adj)
        k = len(blocks)

        # If k==1 and both endpoints fixed => impossible unless 1==0; skip
        if k == 1 and top_fixed and bottom_fixed:
            continue

        # One free parameter per block that is not pinned to 1 (first block) or
        # 0 (last block). For k == 1 with only bottom_fixed, the single block is
        # pinned to 0, so there is nothing to optimize. The old formula gave
        # dim=1 there with a parameter that face_parameterization ignores.
        dim = k - int(top_fixed) - int(bottom_fixed)

        if dim == 0:
            # Fully pinned face: every block value is 1 or 0; evaluate directly.
            v = np.zeros(k)
            if top_fixed:
                v[0] = 1.0
            if bottom_fixed:
                v[-1] = 0.0
            s = blocks_to_full_s(n, blocks, v)
            val = solver.solve_lp(s, return_f=False)
            if val.ok and val.t < best.t:
                best = SearchResult(True, val.t, s, face, "OK")
            continue

        def obj(z):
            v = face_parameterization(z, k, top_fixed, bottom_fixed)
            s = blocks_to_full_s(n, blocks, v)
            # already satisfies ordering by construction
            val = solver.solve_lp(s, return_f=False)
            return val.t if val.ok else 1e6

        for _ in range(n_local_starts_per_face):
            z0 = rng.normal(size=dim)
            res = minimize(obj, z0, method=local_method, options={"maxiter": maxiter, "disp": False})
            if float(res.fun) >= best.t:
                continue
            v = face_parameterization(res.x, k, top_fixed, bottom_fixed)
            s = blocks_to_full_s(n, blocks, v)
            # Re-solve at the returned point. obj() scores infeasible LPs as 1e6,
            # which is < inf and was previously recorded as a feasible optimum
            # with ok=True.
            val = solver.solve_lp(s, return_f=False)
            if val.ok and val.t < best.t:
                best = SearchResult(True, val.t, s, face, "OK")

    return best