# In [2]:
import numpy as np
from dataclasses import dataclass
from typing import Tuple, Optional
import time


# Domain bounds: default upper limits of the search range for p1, p2 and p3.
# The grid searches in this file sample each parameter from 0.0 up to its
# limit and take these values as the defaults for p1_max / p2_max / p3_max.
P1_MAX = 9.0 / 4.0
P2_MAX = 9.0 / 4.0
P3_MAX = 18.0

@dataclass
class SearchResult:
    """Best grid point found by a search, plus bookkeeping about the run."""
    p1: float  # p1 coordinate of the best point found
    p2: float  # p2 coordinate of the best point found
    p3: float  # p3 coordinate of the best point found
    value: float  # objective value at (p1, p2, p3)
    evals: int  # number of grid points evaluated during the search
    seconds: float  # elapsed time of the search, measured with time.perf_counter


def _compute_x1(p1: float, p2: float) -> float:
    """
    Evaluate x1 (the k=1 case):

        x1 = min(1, max{ 1/p1, p1 / (2*p1 - p2) })

    Divisions by zero are mapped to +inf instead of raising:
    - p1 == 0 makes 1/p1 infinite, which forces x1 = 1.
    - 2*p1 - p2 == 0 makes the ratio infinite, which forces x1 = 1.
    A negative denominator is not special-cased; the (possibly negative)
    ratio simply competes with 1/p1 inside the max.

    The result is always limited to the interval [0, 1].
    """
    # First candidate: 1 / (k * p_k) with k = 1.
    inv_p1 = 1.0 / p1 if p1 != 0 else np.inf

    # Second candidate: p1 / (p1 + k*p1 - p2) with k = 1.
    gap = 2.0 * p1 - p2
    ratio = p1 / gap if gap != 0.0 else np.inf

    candidate = min(1.0, max(inv_p1, ratio))

    # Final clamp into [0, 1]; the lower bound matters when both candidates
    # are negative.
    if candidate < 0.0:
        return 0.0
    if candidate > 1.0:
        return 1.0
    return candidate


def _compute_x2_vectorized(p1: float, p2: float, p3_vec: np.ndarray) -> np.ndarray:
    """
    Evaluate x2 (the k=2 case) for every entry of p3_vec at once:

        x2 = min(1, max{ 1/(2*p2), p1 / (p1 + 2*p2 - p3) })

    Divisions by zero are mapped to +inf instead of raising:
    - p2 == 0 makes 1/(2*p2) infinite, so every x2 equals 1.
    - Entries where p1 + 2*p2 - p3 == 0 get an infinite ratio.
    Entries with a negative denominator keep their (possibly negative) ratio;
    the max against 1/(2*p2) decides the outcome.

    Returns a float array shaped like p3_vec with all values in [0, 1].
    """
    # First candidate: 1 / (k * p_k) with k = 2 (same scalar for every p3).
    inv_2p2 = np.inf if p2 == 0 else 1.0 / (2.0 * p2)

    # Second candidate: start from +inf everywhere, then overwrite only the
    # entries whose denominator is nonzero so that no division by zero runs.
    gap = p1 + 2.0 * p2 - p3_vec
    ratio = np.full_like(gap, np.inf, dtype=float)
    divisible = gap != 0.0
    ratio[divisible] = p1 / gap[divisible]

    capped = np.minimum(1.0, np.maximum(inv_2p2, ratio))
    # Limit to [0, 1]; the lower bound matters when both candidates are negative.
    return np.clip(capped, 0.0, 1.0)


def _A_from_xk(xk: np.ndarray | float, k: int) -> np.ndarray | float:
    """
    A_k = x_k/(1+k) + (1 - x_k)
        = 1 - x_k * (k/(k+1))
    """
    return 1.0 - (k / (k + 1.0)) * xk


def _B_from_xk(p1: float, pk1: np.ndarray | float, xk: np.ndarray | float) -> np.ndarray | float:
    """
    B_k = min(1, 1 / (p_{k+1} * x_k + p1 * (1 - x_k)) ).
    If denominator is 0, treat 1/0 = +inf, and min(1, +inf) = 1.
    """
    denom = pk1 * xk + p1 * (1.0 - xk)
    if isinstance(denom, np.ndarray):
        inv = np.empty_like(denom, dtype=float)
        mask_zero = (denom == 0.0)
        mask_nz = ~mask_zero
        inv[mask_zero] = np.inf
        inv[mask_nz] = 1.0 / denom[mask_nz]
        return np.minimum(1.0, inv)
    else:
        if denom == 0.0:
            return 1.0
        else:
            return min(1.0, 1.0 / denom)


def objective_n3(p1: float, p2: float, p3: float) -> float:
    """
    Evaluate the n=3 objective at a single point (p1, p2, p3).

    The objective is min(A1 * B1, A2 * B2), where the k=1 factors depend on
    (p1, p2) and the k=2 factors depend on (p1, p2, p3).
    """
    # k = 1 factor
    first_x = _compute_x1(p1, p2)
    term_k1 = _A_from_xk(first_x, k=1) * _B_from_xk(p1, p2, first_x)

    # k = 2 factor; the x2 helper only accepts arrays, so wrap p3 in a
    # one-element array and unwrap the single result.
    p3_as_array = np.array([p3], dtype=float)
    second_x = _compute_x2_vectorized(p1, p2, p3_as_array)[0]
    term_k2 = _A_from_xk(second_x, k=2) * _B_from_xk(p1, p3, second_x)

    return min(term_k1, term_k2)


def brute_force_search(
    steps_p1: int = 121,
    steps_p2: int = 121,
    steps_p3: int = 481,
    p1_max: float = P1_MAX,
    p2_max: float = P2_MAX,
    p3_max: float = P3_MAX,
) -> SearchResult:
    """
    Maximize the n=3 objective over a rectilinear grid by exhaustive search.

    Each parameter p_i is sampled at ``steps_pi`` evenly spaced values from
    0.0 to ``pi_max`` inclusive (a step count of 1 samples only 0.0). For each
    (p1, p2) pair the k=1 factor is computed once and the k=2 factor is
    evaluated for all p3 values in a single vectorized pass.

    Args:
        steps_p1, steps_p2, steps_p3: Number of grid points per axis (>= 1).
        p1_max, p2_max, p3_max: Upper end of the sampled range per axis.

    Returns:
        SearchResult with the best grid point, its objective value, the number
        of grid points evaluated and the elapsed time. On ties the point
        encountered first (p1 outermost, then p2, then p3) is kept.

    Raises:
        ValueError: If any step count is smaller than 1.
        RuntimeError: If no grid point produced a comparable objective value
            (for example, every value was NaN).
    """
    # Validate explicitly: an empty axis would otherwise surface as an opaque
    # numpy "argmax of an empty sequence" error or a failed assertion.
    for label, count in (
        ("steps_p1", steps_p1),
        ("steps_p2", steps_p2),
        ("steps_p3", steps_p3),
    ):
        if count < 1:
            raise ValueError(f"{label} must be at least 1, got {count}")

    t0 = time.perf_counter()

    p1_vals = np.linspace(0.0, p1_max, steps_p1)
    p2_vals = np.linspace(0.0, p2_max, steps_p2)
    p3_vals = np.linspace(0.0, p3_max, steps_p3)
    n_p3 = len(p3_vals)

    best_val = -np.inf
    best_tuple: Optional[Tuple[float, float, float]] = None
    evals = 0
    for p1 in p1_vals:
        for p2 in p2_vals:
            # The k=1 factor does not depend on p3, so compute it once per pair.
            x1 = _compute_x1(p1, p2)
            A1 = _A_from_xk(x1, k=1)
            B1 = _B_from_xk(p1, p2, x1)
            obj1 = A1 * B1  # scalar

            # The k=2 factor is evaluated for every p3 at once.
            x2_vec = _compute_x2_vectorized(p1, p2, p3_vals)
            A2_vec = _A_from_xk(x2_vec, k=2)
            B2_vec = _B_from_xk(p1, p3_vals, x2_vec)

            # Overall objective = min(obj1, A2*B2), element-wise over p3.
            obj_vec = np.minimum(obj1, A2_vec * B2_vec)

            # Best p3 for this (p1, p2); argmax returns the first maximum.
            idx = int(np.argmax(obj_vec))
            cand_val = float(obj_vec[idx])

            evals += n_p3

            # Strict comparison keeps the earliest point on ties.
            if cand_val > best_val:
                best_val = cand_val
                best_tuple = (float(p1), float(p2), float(p3_vals[idx]))

    t1 = time.perf_counter()

    # Raise instead of assert so the check survives `python -O`.
    if best_tuple is None:
        raise RuntimeError("no grid point produced a comparable objective value")

    return SearchResult(
        p1=best_tuple[0],
        p2=best_tuple[1],
        p3=best_tuple[2],
        value=best_val,
        evals=evals,
        seconds=(t1 - t0),
    )


def _refine_window(center: float, span: int, coarse_count: int, upper: float) -> Tuple[float, float]:
    """Return (lo, hi): ±span coarse-grid spacings around center, clipped to [0, upper]."""
    # A coarse grid of n points on [0, upper] has n - 1 intervals. With fewer
    # than two points there is no spacing to measure, so treat the whole range
    # as one interval instead of dividing by zero.
    intervals = max(coarse_count - 1, 1)
    half_width = span * (upper / intervals)
    return max(0.0, center - half_width), min(upper, center + half_width)


def refine_around_best(
    best: SearchResult,
    span_steps: Tuple[int, int, int] = (6, 6, 12),
    refine_steps: Tuple[int, int, int] = (121, 121, 241),
    p1_max: float = P1_MAX,
    p2_max: float = P2_MAX,
    p3_max: float = P3_MAX,
    coarse_steps: Tuple[int, int, int] = (121, 121, 241),
) -> SearchResult:
    """
    Re-run the grid search on a small window centred on a previous best point.

    Per axis, the window is ``best.p ± span_steps * (p_max / (coarse_steps - 1))``
    clipped to [0, p_max], i.e. ``span_steps`` coarse-grid spacings on either
    side. The window is then sampled with ``refine_steps`` evenly spaced points
    (at least 2 per axis).

    Args:
        best: Result of the coarse search; only p1, p2 and p3 are read.
        span_steps: Half-width of the window per axis, in coarse-grid spacings.
        refine_steps: Number of sample points per axis inside the window;
            values below 2 are raised to 2.
        p1_max, p2_max, p3_max: Upper end of the full range per axis.
        coarse_steps: Number of points per axis of the grid that produced
            ``best``. Pass the counts actually used for the coarse search,
            otherwise the window will not match ``span_steps`` spacings. An
            entry below 2 makes the full axis range count as one spacing.

    Returns:
        SearchResult for the refined search. ``evals`` and ``seconds`` cover
        the refinement only. On ties the point encountered first (p1
        outermost, then p2, then p3) is kept.

    Raises:
        RuntimeError: If no grid point produced a comparable objective value
            (for example, every value was NaN).
    """
    (c1, c2, c3) = coarse_steps
    (s1, s2, s3) = span_steps
    (r1, r2, r3) = refine_steps

    p1_lo, p1_hi = _refine_window(best.p1, s1, c1, p1_max)
    p2_lo, p2_hi = _refine_window(best.p2, s2, c2, p2_max)
    p3_lo, p3_hi = _refine_window(best.p3, s3, c3, p3_max)

    # Ensure at least 2 points per axis so both window ends are sampled.
    r1 = max(2, r1)
    r2 = max(2, r2)
    r3 = max(2, r3)

    t0 = time.perf_counter()

    p1_vals = np.linspace(p1_lo, p1_hi, r1)
    p2_vals = np.linspace(p2_lo, p2_hi, r2)
    p3_vals = np.linspace(p3_lo, p3_hi, r3)
    n_p3 = len(p3_vals)

    best_val = -np.inf
    best_tuple: Optional[Tuple[float, float, float]] = None
    evals = 0

    for p1 in p1_vals:
        for p2 in p2_vals:
            # The k=1 factor does not depend on p3, so compute it once per pair.
            x1 = _compute_x1(p1, p2)
            A1 = _A_from_xk(x1, k=1)
            B1 = _B_from_xk(p1, p2, x1)
            obj1 = A1 * B1

            # The k=2 factor is evaluated for every p3 at once.
            x2_vec = _compute_x2_vectorized(p1, p2, p3_vals)
            A2_vec = _A_from_xk(x2_vec, k=2)
            B2_vec = _B_from_xk(p1, p3_vals, x2_vec)

            obj_vec = np.minimum(obj1, A2_vec * B2_vec)

            # Best p3 for this (p1, p2); argmax returns the first maximum.
            idx = int(np.argmax(obj_vec))
            cand_val = float(obj_vec[idx])

            evals += n_p3

            # Strict comparison keeps the earliest point on ties.
            if cand_val > best_val:
                best_val = cand_val
                best_tuple = (float(p1), float(p2), float(p3_vals[idx]))

    t1 = time.perf_counter()

    # Raise instead of assert so the check survives `python -O`.
    if best_tuple is None:
        raise RuntimeError("no grid point produced a comparable objective value")

    return SearchResult(
        p1=best_tuple[0],
        p2=best_tuple[1],
        p3=best_tuple[2],
        value=best_val,
        evals=evals,
        seconds=(t1 - t0),
    )


if __name__ == "__main__":
    # Example usage:
    # Start with a coarse grid (adjust steps to fit your time budget).
    # Total evaluations = steps_p1 * steps_p2 * steps_p3 (but inner loop is vectorized).
    #
    # The coarse resolution is defined once and passed to both calls, so the
    # refinement window is sized from the grid that was actually searched.
    coarse_grid = (121, 121, 241)
    coarse = brute_force_search(
        steps_p1=coarse_grid[0],  # resolution along p1 in [0, 9/4]
        steps_p2=coarse_grid[1],  # resolution along p2 in [0, 9/4]
        steps_p3=coarse_grid[2],  # resolution along p3 in [0, 18]
    )
    print("[COARSE] best =", coarse)

    # Optional: refine around the best coarse point with a local window
    refined = refine_around_best(
        best=coarse,
        span_steps=(6, 6, 12),            # window approx ±6 coarse steps on p1,p2 and ±12 on p3
        refine_steps=(181, 181, 721),     # finer resolution in the local window
        coarse_steps=coarse_grid,         # must match the coarse search above
    )
    print("[REFINED] best =", refined)


# Output recorded from one run of the script above:
# [COARSE] best = SearchResult(p1=1.33125, p2=0.88125, p3=0.0, value=0.6217494089834515, evals=3528481, seconds=0.2252756250090897)
# [REFINED] best = SearchResult(p1=1.33375, p2=0.8887499999999999, p3=0.0, value=0.6249413970932958, evals=23620681, seconds=0.49999020795803517)
