In [2]:
import math
import time
from dataclasses import dataclass
from typing import Optional, Tuple

import numpy as np
# -------------------- Domain bounds --------------------

# p1 ∈ [0, 2e], p2 ∈ [0, 4e], p3 ∈ [0, 12e]
P1_MAX = 2.0 * math.e
P2_MAX = 4.0 * math.e
P3_MAX = 12.0 * math.e

EPS = 1e-15  # tolerance for zero checks and numeric stability


@dataclass
class SearchResult:
    p1: float
    p2: float
    p3: float
    value: float
    evals: int
    seconds: float


# -------------------- Utilities for A_k and B_k --------------------

def A_k(gamma, k: int):
    """
    A_k(gamma) = gamma/(1+k) + (1 - gamma) = 1 - (k/(k+1))*gamma
    Works with scalars or numpy arrays.
    """
    return 1.0 - (k / (k + 1.0)) * gamma
def B_k(p1: float, pkplus_over: np.ndarray | float, gamma: np.ndarray | float):
    """
    B_k(gamma) = min(1, 1 / ( (p_{k+1}/(k+1)) * gamma + p1 * (1 - gamma)) ).

    Note: denom = p1 + gamma*(pkplus_over - p1), which for p1>=0, pkplus_over>=0 and gamma ∈ [0,1]
    is a convex combination of p1 and pkplus_over, hence nonnegative.
    If denom == 0 (only when p1=pkplus_over=0), we set 1/denom = +inf, so min(1, +inf) = 1.
    """
    denom = p1 + gamma * (pkplus_over - p1)
    if isinstance(denom, np.ndarray):
        inv = np.empty_like(denom, dtype=float)
        mask_zero = np.isclose(denom, 0.0, atol=EPS)
        inv[mask_zero] = np.inf
        inv[~mask_zero] = 1.0 / denom[~mask_zero]
        return np.minimum(1.0, inv)
    else:
        if abs(denom) <= EPS:
            return 1.0
        return min(1.0, 1.0 / denom)


# -------------------- Lower bounds and gamma candidates --------------------

def lower_bound_Lk_scalar(p1: float, pk: float, pkplus_over: float) -> float:
    """
    L_k = min(1, max{ 1/p_k, p1 / (p_k - pkplus_over + p1) })
    Scalar version for k=1.
    """
    # Handle 1/pk
    if pk <= 0.0:
        term1 = float('inf')
    else:
        term1 = 1.0 / pk
    denom = pk - pkplus_over + p1
    if abs(denom) <= EPS:
        term2 = float('inf')
    else:
        term2 = p1 / denom

    Lk = min(1.0, max(term1, term2))
    # clamp to [0,1]
    if Lk < 0.0:
        Lk = 0.0
    elif Lk > 1.0:
        Lk = 1.0
    return Lk


def lower_bound_Lk_vec(p1: float, pk: float, pkplus_over_vec: np.ndarray) -> np.ndarray:
    """
    L_k = min(1, max{ 1/p_k, p1 / (p_k - pkplus_over + p1) })
    Vectorized over pkplus_over (used for k=2 with varying p3).
    """
    if pk <= 0.0:
        term1 = np.inf
    else:
        term1 = 1.0 / pk

    denom = pk - pkplus_over_vec + p1
    term2 = np.empty_like(denom, dtype=float)
    mask_zero = np.isclose(denom, 0.0, atol=EPS)
    term2[mask_zero] = np.inf
    term2[~mask_zero] = p1 / denom[~mask_zero]

    Lk = np.minimum(1.0, np.maximum(term1, term2))
    np.clip(Lk, 0.0, 1.0, out=Lk)
    return Lk


def gamma1_candidate_scalar(p1: float, pk: float, pkplus_over: float, k: int) -> float:
    """
    gamma^(1)_k = min(1, max{ 1/(p_k * k), p1 / (p_k - pkplus_over + p1) })
    Scalar version (used for k=1).
    """
    if pk <= 0.0:
        term1 = float('inf')
    else:
        term1 = 1.0 / (pk * k)

    denom = pk - pkplus_over + p1
    if abs(denom) <= EPS:
        term2 = float('inf')
    else:
        term2 = p1 / denom

    g1 = min(1.0, max(term1, term2))
    # ensure [0,1]
    if g1 < 0.0:
        g1 = 0.0
    elif g1 > 1.0:
        g1 = 1.0
    return g1


def gamma1_candidate_vec(p1: float, pk: float, pkplus_over_vec: np.ndarray, k: int) -> np.ndarray:
    """
    Vectorized gamma^(1)_k over pkplus_over_vec (used for k=2).
    """
    if pk <= 0.0:
        term1 = np.inf
    else:
        term1 = 1.0 / (pk * k)

    denom = pk - pkplus_over_vec + p1
    term2 = np.empty_like(denom, dtype=float)
    mask_zero = np.isclose(denom, 0.0, atol=EPS)
    term2[mask_zero] = np.inf
    term2[~mask_zero] = p1 / denom[~mask_zero]

    g1 = np.minimum(1.0, np.maximum(term1, term2))
    np.clip(g1, 0.0, 1.0, out=g1)
    return g1


def gamma2_candidate_scalar(p1: float, pkplus_over: float) -> Optional[float]:
    """
    gamma^(2)_k = p1 / (pkplus_over - p1), if denominator != 0, else None.
    """
    denom = pkplus_over - p1
    if abs(denom) <= EPS:
        return None
    return p1 / denom


def gamma2_candidate_vec(p1: float, pkplus_over_vec: np.ndarray) -> np.ndarray:
    """
    Vectorized gamma^(2)_k over pkplus_over_vec.
    Returns array with NaNs where undefined (denominator near zero).
    """
    denom = pkplus_over_vec - p1
    g2 = np.full_like(pkplus_over_vec, np.nan, dtype=float)
    mask = ~np.isclose(denom, 0.0, atol=EPS)
    g2[mask] = p1 / denom[mask]
    return g2


# -------------------- mu(p1,p2,p3,k): max over feasible candidates --------------------

def mu_k1(p1: float, p2: float) -> float:
    """
    k=1:
      pk = p1
      pkplus_over = p2/(1+1) = p2/2
      L1 = min(1, max{ 1/p1, p1/(p1 - p2/2 + p1) })
      candidates: gamma1 = min(1, max{ 1/(p1*1), p1/(p1 - p2/2 + p1) })
                  gamma2 = p1 / (p2/2 - p1)  (if defined)
      feasible if gamma ∈ [L1, 1].
    """
    pk = p1
    pkplus_over = p2 / 2.0
    L1 = lower_bound_Lk_scalar(p1, pk, pkplus_over)

    # Candidate 1
    g1 = gamma1_candidate_scalar(p1, pk, pkplus_over, k=1)
    vals = []

    if g1 >= L1 - 1e-15 and g1 <= 1.0 + 1e-15:
        val1 = A_k(g1, k=1) * B_k(p1, pkplus_over, g1)
        vals.append(val1)

    # Candidate 2
    g2 = gamma2_candidate_scalar(p1, pkplus_over)
    if g2 is not None and g2 >= L1 - 1e-15 and g2 <= 1.0 + 1e-15:
        val2 = A_k(g2, k=1) * B_k(p1, pkplus_over, g2)
        vals.append(val2)

    if not vals:
        return -float('inf')
    return max(vals)


def mu_k2_vec(p1: float, p2: float, p3_vals: np.ndarray) -> np.ndarray:
    """
    k=2 (vectorized across p3_vals):
      pk = p2
      pkplus_over = p3/(1+2) = p3/3
      L2 = min(1, max{ 1/p2, p1/(p2 - p3/3 + p1) })
      candidates: gamma1 = min(1, max{ 1/(p2*2), p1/(p2 - p3/3 + p1) })
                  gamma2 = p1 / (p3/3 - p1)  (if defined)
      feasible if gamma ∈ [L2, 1].
    """
    pk = p2
    pkplus_over = p3_vals / 3.0

    L2 = lower_bound_Lk_vec(p1, pk, pkplus_over)

    # Candidate 1 (vectorized)
    g1 = gamma1_candidate_vec(p1, pk, pkplus_over, k=2)
    val1 = A_k(g1, k=2) * B_k(p1, pkplus_over, g1)
    feas1 = (g1 >= L2 - 1e-15) & (g1 <= 1.0 + 1e-15)
    val1[~feas1] = -np.inf
    # Candidate 2 (vectorized)
    g2 = gamma2_candidate_vec(p1, pkplus_over)
    feas2 = ~np.isnan(g2)
    # Enforce feasibility band [L2, 1]
    feas2 = feas2 & (g2 >= L2 - 1e-15) & (g2 <= 1.0 + 1e-15)
    val2 = np.full_like(p3_vals, -np.inf, dtype=float)
    if np.any(feas2):
        val2[feas2] = A_k(g2[feas2], k=2) * B_k(p1, pkplus_over[feas2], g2[feas2])

    # Take max across feasible candidates
    return np.maximum(val1, val2)


# -------------------- Brute-force search over (p1, p2, p3) --------------------

def brute_force_search(
    steps_p1: int = 121,
    steps_p2: int = 161,
    steps_p3: int = 481,
    p1_max: float = P1_MAX,
    p2_max: float = P2_MAX,
    p3_max: float = P3_MAX,
) -> SearchResult:
    """
    Brute-force on a rectilinear grid.
    Returns the best (p1, p2, p3) and the objective value:
        obj(p1,p2,p3) = min( mu_k1, mu_k2 )
    """
    t0 = time.perf_counter()

    p1_vals = np.linspace(0.0, p1_max, steps_p1)
    p2_vals = np.linspace(0.0, p2_max, steps_p2)
    p3_vals = np.linspace(0.0, p3_max, steps_p3)

    best_val = -float('inf')
    best_tuple: Optional[Tuple[float, float, float]] = None
    evals = 0

    for p1 in p1_vals:
        for p2 in p2_vals:
            # k=1 term (scalar for this (p1,p2))
            mu1 = mu_k1(p1, p2)

            # k=2 term (vectorized across all p3)
            mu2_vec = mu_k2_vec(p1, p2, p3_vals)

            # Overall objective at each p3 is min over k
            obj_vec = np.minimum(mu1, mu2_vec)

            idx = int(np.argmax(obj_vec))
            cand_val = float(obj_vec[idx])
            cand_p3 = float(p3_vals[idx])

            evals += len(p3_vals)

            if cand_val > best_val:
                best_val = cand_val
                best_tuple = (float(p1), float(p2), cand_p3)

    t1 = time.perf_counter()

    assert best_tuple is not None
    return SearchResult(
        p1=best_tuple[0],
        p2=best_tuple[1],
        p3=best_tuple[2],
        value=best_val,
        evals=evals,
        seconds=(t1 - t0),
    )


def refine_around_best(
    best: SearchResult,
    span_steps: Tuple[int, int, int] = (6, 8, 12),
    refine_steps: Tuple[int, int, int] = (181, 221, 541),
    p1_max: float = P1_MAX,
    p2_max: float = P2_MAX,
    p3_max: float = P3_MAX,
    coarse_steps: Tuple[int, int, int] = (121, 161, 481),
) -> SearchResult:
    """
    Local refinement around the best coarse point: define a window covering roughly ±span_steps
    coarse-grid increments on each axis and resample it at a finer 'refine_steps' resolution.
    """
    (c1, c2, c3) = coarse_steps
    (s1, s2, s3) = span_steps
    (r1, r2, r3) = refine_steps

    def window(best_v, s, max_v, c):
        step = max_v / max(c - 1, 1)
        return max(0.0, best_v - s * step), min(max_v, best_v + s * step)

    p1_lo, p1_hi = window(best.p1, s1, p1_max, c1)
    p2_lo, p2_hi = window(best.p2, s2, p2_max, c2)
    p3_lo, p3_hi = window(best.p3, s3, p3_max, c3)

    r1 = max(2, r1)
    r2 = max(2, r2)
    r3 = max(2, r3)

    t0 = time.perf_counter()

    p1_vals = np.linspace(p1_lo, p1_hi, r1)
    p2_vals = np.linspace(p2_lo, p2_hi, r2)
    p3_vals = np.linspace(p3_lo, p3_hi, r3)

    best_val = -float('inf')
    best_tuple: Optional[Tuple[float, float, float]] = None
    evals = 0

    for p1 in p1_vals:
        for p2 in p2_vals:
            mu1 = mu_k1(p1, p2)
            mu2_vec = mu_k2_vec(p1, p2, p3_vals)

            obj_vec = np.minimum(mu1, mu2_vec)

            idx = int(np.argmax(obj_vec))
            cand_val = float(obj_vec[idx])
            cand_p3 = float(p3_vals[idx])

            evals += len(p3_vals)

            if cand_val > best_val:
                best_val = cand_val
                best_tuple = (float(p1), float(p2), cand_p3)

    t1 = time.perf_counter()

    assert best_tuple is not None
    return SearchResult(
        p1=best_tuple[0],
        p2=best_tuple[1],
        p3=best_tuple[2],
        value=best_val,
        evals=evals,
        seconds=(t1 - t0),
    )


if __name__ == "__main__":
    # Example: a reasonably fine grid to start with (adjust to your time budget).
    coarse = brute_force_search(
        steps_p1=1210,   # p1 ∈ [0, 2e]
        steps_p2=1610,   # p2 ∈ [0, 4e]
        steps_p3=4810,   # p3 ∈ [0, 12e]
    )
    print("[COARSE] best =", coarse)

    # Optional: refine around the best coarse point
    refined = refine_around_best(
        best=coarse,
        span_steps=(6, 8, 12),
        refine_steps=(1810, 2210, 5410),
        coarse_steps=(1210, 1610, 4810),
    )
    print("[REFINED] best =", refined)


[COARSE] best = SearchResult(p1=1.3355330075307468, p2=1.7772731407948512, p3=2.218036576184923, value=0.6248318783338377, evals=9370361000, seconds=306.56154729193076)
[REFINED] best = SearchResult(p1=1.3333405651175934, p2=1.7777870802916789, p3=2.2222049339278165, value=0.6249967057055402, evals=21640541000, seconds=585.4742142501054)
