# In [3]:
import math
import time
from dataclasses import dataclass
from typing import Optional, Tuple

import numpy as np


# Upper edges of the rectangular search domain; every axis starts at 0.
P1_MAX = math.e * 2.0  # p1 ranges over [0, 2e]
P2_MAX = math.e * 2.0  # p2 ranges over [0, 2e]
P3_MAX = math.e * 4.0  # p3 ranges over [0, 4e]

# Threshold at or below which a denominator's magnitude is treated as zero.
EPS = 1.0e-15


@dataclass
class SearchResult:
    """Best grid point found by a search, plus bookkeeping about the run.

    Returned by ``brute_force_search`` and ``refine_around_best``.
    """

    p1: float  # p1 coordinate of the best grid point
    p2: float  # p2 coordinate of the best grid point
    p3: float  # p3 coordinate of the best grid point
    value: float  # objective min(max_gamma_k1, max_gamma_k2) at (p1, p2, p3)
    evals: int  # number of (p1, p2, p3) triples evaluated during the search
    seconds: float  # wall-clock duration of the search, from time.perf_counter()


# ---------- Helper functions for bounds and objective ----------

def _lower_bound_gamma_k1(p1: float, p2: float) -> float:
    """
    Lower bound on gamma for the k=1 term:

        lb = min(1, max{ 1/p1, p1 / (2*p1 - p2) })

    This is min(1, max{1/(k*p_k), p1/(k*p_k - p_{k+1} + p1)}) at k=1.
    A non-positive p1, or a denominator within EPS of zero, turns the
    corresponding term into +inf, which saturates the bound at 1.
    """
    inv_p1 = float('inf') if p1 <= 0.0 else 1.0 / p1

    denominator = 2.0 * p1 - p2
    ratio = float('inf') if abs(denominator) <= EPS else p1 / denominator

    bound = min(1.0, max(inv_p1, ratio))
    # Defensive clamp into [0, 1].
    if bound < 0.0:
        return 0.0
    if bound > 1.0:
        return 1.0
    return bound


def _lower_bound_gamma_k2_vectorized(p1: float, p2: float, p3_vec: np.ndarray) -> np.ndarray:
    """
    Lower bound on gamma for the k=2 term, for every entry of p3_vec:

        lb = min(1, max{ 1/(2*p2), p1 / (p1 + 2*p2 - p3) })

    A non-positive p2 makes the first term +inf. Entries whose denominator is
    within EPS of zero make the second term +inf. Either case forces lb to 1.
    The result is clipped into [0, 1].
    """
    inv_term = np.inf if p2 <= 0.0 else 1.0 / (2.0 * p2)

    denominator = p1 + 2.0 * p2 - p3_vec
    near_zero = np.isclose(denominator, 0.0, atol=EPS)
    ratio = np.full_like(denominator, np.inf, dtype=float)
    ratio[~near_zero] = p1 / denominator[~near_zero]

    bound = np.minimum(1.0, np.maximum(inv_term, ratio))
    np.clip(bound, 0.0, 1.0, out=bound)
    return bound


def _Ak(gamma, k: int):
    """
    Evaluate A_k(gamma) = gamma/(k+1) + (1 - gamma), i.e. 1 - gamma * k/(k+1).

    gamma may be a scalar or a numpy array; the result has the same form.
    """
    slope = k / (k + 1.0)
    return 1.0 - slope * gamma


def _Bk(p1: float, pk1, gamma):
    """
    B_k(gamma) = min(1, 1 / (p_{k+1} * gamma + p1 * (1 - gamma))).

    The denominator is computed as p1 + gamma * (p_{k+1} - p1). Where it lies
    within EPS of zero, 1/0 is read as +inf, so B_k is 1 there. pk1 and gamma
    may be scalars or numpy arrays. An ndarray denominator takes the
    vectorized path; anything else takes the scalar path.
    """
    denom = p1 + gamma * (pk1 - p1)

    if not isinstance(denom, np.ndarray):
        return 1.0 if abs(denom) <= EPS else min(1.0, 1.0 / denom)

    near_zero = np.isclose(denom, 0.0, atol=EPS)
    reciprocal = np.full_like(denom, np.inf, dtype=float)
    reciprocal[~near_zero] = 1.0 / denom[~near_zero]
    return np.minimum(1.0, reciprocal)


def _objective_for_gamma_k1(p1: float, p2: float, gamma: float) -> float:
    """k=1 objective at one gamma: A_1(gamma) * B_1(gamma), with p_{k+1} = p2."""
    a_factor = _Ak(gamma, k=1)
    b_factor = _Bk(p1, p2, gamma)
    return a_factor * b_factor


def _objective_for_gamma_k2_vectorized(p1: float, p3_vec: np.ndarray, gamma_vec: np.ndarray) -> np.ndarray:
    """Elementwise k=2 objective A_2(gamma) * B_2(gamma), with p_{k+1} = p3."""
    a_factor = _Ak(gamma_vec, k=2)
    b_factor = _Bk(p1, p3_vec, gamma_vec)
    return a_factor * b_factor


# ---------- Core: evaluate max over gamma for each k ----------

def _max_over_gamma_k1(p1: float, p2: float, clip_gamma_le_one: bool = False) -> float:
    """
    Largest k=1 objective A_1(gamma) * B_1(gamma) over a small candidate set.

    Candidates (each capped at 1 first when clip_gamma_le_one is True):
      * gamma = lb1 from _lower_bound_gamma_k1 -- always evaluated;
      * gamma = p1 / (p2 - p1) -- evaluated only when |p2 - p1| > EPS and the
        (possibly capped) value is >= lb1.

    NOTE(review): with _Ak/_Bk as implemented in this file, the derivative of
    A_1(gamma) / (p1 + gamma*(p2 - p1)) has numerator -p1/2 - (p2 - p1), which
    does not depend on gamma. On each piece where B_1 is either 1 or
    1/denominator, the objective is therefore monotone in gamma. Its maximum
    over gamma >= lb1 then sits at lb1, at a switch point (denominator equal
    to 1 or crossing 0), or toward the upper end of the gamma range.
    p1/(p2 - p1) is none of these in general -- confirm this candidate
    against the original derivation.
    """
    lower = _lower_bound_gamma_k1(p1, p2)

    gamma_at_bound = min(lower, 1.0) if clip_gamma_le_one else lower
    best = max(-float('inf'), _objective_for_gamma_k1(p1, p2, gamma_at_bound))

    spread = p2 - p1
    if not abs(spread) > EPS:
        return best

    gamma_alt = p1 / spread
    if clip_gamma_le_one:
        gamma_alt = min(gamma_alt, 1.0)
    if not gamma_alt >= lower:
        return best

    # max(a, b) keeps a unless b > a, mirroring "replace only if strictly better".
    return max(best, _objective_for_gamma_k1(p1, p2, gamma_alt))


def _max_over_gamma_k2_vectorized(
    p1: float,
    p2: float,
    p3_vals: np.ndarray,
    clip_gamma_le_one: bool = False
) -> np.ndarray:
    """
    Vectorized over p3: for each entry, the larger k=2 objective among
    gamma = lb2 (always) and gamma = p1 / (p3 - p1). The second candidate is
    used only where |p3 - p1| > EPS and the (possibly capped) value is >= lb2.
    Elsewhere it contributes -inf. With clip_gamma_le_one=True both candidates
    are capped at 1 first.

    NOTE(review): as for k=1, with _Ak/_Bk as implemented here, the derivative
    of A_2(gamma) / (p1 + gamma*(p3 - p1)) has the gamma-independent numerator
    -2*p1/3 - (p3 - p1). The objective is therefore piecewise monotone in
    gamma, and p1/(p3 - p1) is not a stationary point -- confirm this
    candidate against the original derivation.
    """
    lower = _lower_bound_gamma_k2_vectorized(p1, p2, p3_vals)

    # Candidate 1: the lower bound itself.
    gamma_lo = np.minimum(lower, 1.0) if clip_gamma_le_one else lower
    score_lo = _objective_for_gamma_k2_vectorized(p1, p3_vals, gamma_lo)

    # Candidate 2: p1 / (p3 - p1), NaN where the denominator is ~0.
    gap = p3_vals - p1
    ok = ~np.isclose(gap, 0.0, atol=EPS)
    gamma_alt = np.full_like(p3_vals, np.nan, dtype=float)
    gamma_alt[ok] = p1 / gap[ok]
    if clip_gamma_le_one:
        defined = ok & ~np.isnan(gamma_alt)
        gamma_alt[defined] = np.minimum(gamma_alt[defined], 1.0)

    # Feasibility is compared only on defined entries; the rest stay False.
    usable = ok.copy()
    usable[ok] = gamma_alt[ok] >= lower[ok]

    score_alt = np.full_like(p3_vals, -np.inf, dtype=float)
    if usable.any():
        score_alt[usable] = _objective_for_gamma_k2_vectorized(
            p1, p3_vals[usable], gamma_alt[usable]
        )

    return np.maximum(score_lo, score_alt)


# ---------- Brute-force search over (p1, p2, p3) ----------

def _search_grid(
    p1_axis: Tuple[float, float, int],
    p2_axis: Tuple[float, float, int],
    p3_axis: Tuple[float, float, int],
    clip_gamma_le_one: bool,
) -> SearchResult:
    """
    Evaluate obj(p1, p2, p3) = min(max_gamma_k1, max_gamma_k2) on a rectilinear
    grid and return the best point. This is the single evaluation loop shared by
    brute_force_search and refine_around_best.

    Each *_axis is (lo, hi, n) and is sampled as np.linspace(lo, hi, n). The k=1
    term depends only on (p1, p2), so it is computed once per pair. The k=2 term
    is vectorized across the whole p3 axis. On ties, the first point in
    (p1, p2, p3) iteration order wins. `seconds` covers grid construction plus
    evaluation.

    Raises:
        RuntimeError: if no grid point compares greater than -inf (for example
            an empty axis, or every objective value NaN).
    """
    t0 = time.perf_counter()

    p1_vals = np.linspace(*p1_axis)
    p2_vals = np.linspace(*p2_axis)
    p3_vals = np.linspace(*p3_axis)

    best_val = -float('inf')
    best_tuple: Optional[Tuple[float, float, float]] = None
    evals = 0

    for p1 in p1_vals:
        for p2 in p2_vals:
            # k=1 term is a scalar for this (p1, p2); k=2 is a vector over p3.
            max_k1 = _max_over_gamma_k1(p1, p2, clip_gamma_le_one=clip_gamma_le_one)
            max_k2_vec = _max_over_gamma_k2_vectorized(
                p1, p2, p3_vals, clip_gamma_le_one=clip_gamma_le_one
            )

            # Overall objective is the min over k; keep the best p3 for this pair.
            obj_vec = np.minimum(max_k1, max_k2_vec)
            idx = int(np.argmax(obj_vec))
            cand_val = float(obj_vec[idx])

            evals += len(p3_vals)

            if cand_val > best_val:
                best_val = cand_val
                best_tuple = (float(p1), float(p2), float(p3_vals[idx]))

    seconds = time.perf_counter() - t0

    # Explicit check instead of `assert`, so it still fires under `python -O`.
    if best_tuple is None:
        raise RuntimeError("grid search found no point with a comparable objective value")

    return SearchResult(
        p1=best_tuple[0],
        p2=best_tuple[1],
        p3=best_tuple[2],
        value=best_val,
        evals=evals,
        seconds=seconds,
    )


def brute_force_search(
    steps_p1: int = 121,
    steps_p2: int = 121,
    steps_p3: int = 241,
    p1_max: float = P1_MAX,
    p2_max: float = P2_MAX,
    p3_max: float = P3_MAX,
    clip_gamma_le_one: bool = False,
) -> SearchResult:
    """
    Brute-force over the rectilinear grid
        linspace(0, p1_max, steps_p1) x linspace(0, p2_max, steps_p2)
            x linspace(0, p3_max, steps_p3),
    vectorizing internally across the p3 axis.

    Returns the best (p1, p2, p3) and the objective value:
        obj(p1,p2,p3) = min( max_gamma_k1, max_gamma_k2 )

    Raises:
        ValueError: if any steps_* is smaller than 1, which would make the grid
            empty.
    """
    for name, count in (("steps_p1", steps_p1), ("steps_p2", steps_p2), ("steps_p3", steps_p3)):
        if count < 1:
            raise ValueError(f"{name} must be >= 1, got {count!r}")

    return _search_grid(
        (0.0, p1_max, steps_p1),
        (0.0, p2_max, steps_p2),
        (0.0, p3_max, steps_p3),
        clip_gamma_le_one,
    )


def refine_around_best(
    best: SearchResult,
    span_steps: Tuple[int, int, int] = (6, 6, 12),
    refine_steps: Tuple[int, int, int] = (161, 161, 321),
    p1_max: float = P1_MAX,
    p2_max: float = P2_MAX,
    p3_max: float = P3_MAX,
    coarse_steps: Tuple[int, int, int] = (121, 121, 241),
    clip_gamma_le_one: bool = False,
) -> SearchResult:
    """
    Local refinement around a coarse-search result.

    For each axis, a window of roughly +/- span_steps coarse-grid increments is
    built around the coarse-best point and clamped to [0, axis max]. That window
    is then resampled with refine_steps points (at least 2 per axis).
    coarse_steps must match the step counts of the coarse search so that the
    coarse increment is reconstructed correctly.

    The returned point is the best point on the refined grid only. It is not
    compared against `best`.
    """
    (c1, c2, c3) = coarse_steps
    (s1, s2, s3) = span_steps
    (r1, r2, r3) = refine_steps

    def window(lo, hi, best_v, s, max_v, c):
        # Coarse spacing on this axis (the coarse grid is linspace(0, max_v, c)).
        step = max_v / max(c - 1, 1)
        return max(lo, best_v - s * step), min(hi, best_v + s * step)

    p1_lo, p1_hi = window(0.0, p1_max, best.p1, s1, p1_max, c1)
    p2_lo, p2_hi = window(0.0, p2_max, best.p2, s2, p2_max, c2)
    p3_lo, p3_hi = window(0.0, p3_max, best.p3, s3, p3_max, c3)

    # At least two samples per axis so every window is actually resampled.
    return _search_grid(
        (p1_lo, p1_hi, max(2, r1)),
        (p2_lo, p2_hi, max(2, r2)),
        (p3_lo, p3_hi, max(2, r3)),
        clip_gamma_le_one,
    )


if __name__ == "__main__":
    # Example run: a coarse sweep of the whole domain, then a finer sweep in a
    # window around the coarse optimum. Tune the step counts to trade
    # resolution against run time.
    coarse_grid = (121, 121, 241)  # points along p1 in [0, 2e], p2 in [0, 2e], p3 in [0, 4e]
    clip_gamma = False  # set True to cap gamma at 1; shared by both stages

    n1, n2, n3 = coarse_grid
    coarse = brute_force_search(
        steps_p1=n1,
        steps_p2=n2,
        steps_p3=n3,
        clip_gamma_le_one=clip_gamma,
    )
    print("[COARSE] best =", coarse)

    # Refinement must be told the coarse step counts to size its window.
    refined = refine_around_best(
        best=coarse,
        span_steps=(6, 6, 12),
        refine_steps=(181, 181, 361),
        coarse_steps=coarse_grid,
        clip_gamma_le_one=clip_gamma,
    )
    print("[REFINED] best =", refined)


# [COARSE] best = SearchResult(p1=1.3138362170885385, p2=0.9060939428196817, p3=0.0, value=0.616777659990171, evals=3528481, seconds=0.9185979999965639)
# [REFINED] best = SearchResult(p1=1.3319580959449322, p2=0.887972063963288, p3=0.0, value=0.6246128151311813, evals=11826721, seconds=1.5772351670020726)
