In [2]:
import math
import numpy as np
from dataclasses import dataclass
from typing import Tuple, Optional

# ----------------------------
# Utility functions
# ----------------------------

def div_safe(a: float, b: float) -> float:
    """Return a / b, resolving a zero denominator by the sign of the numerator.

    When b == 0.0 (including -0.0): +inf if a > 0, -inf if a < 0, and NaN
    otherwise (the 0/0 case, or a NaN numerator).
    """
    if b != 0.0:
        return a / b
    # Zero denominator: choose the result from the numerator's sign.
    if a > 0:
        return math.inf
    if a < 0:
        return -math.inf
    return math.nan  # 0/0 is undefined

def recip_safe(x: float) -> float:
    """Return 1/x, or +inf when x is zero (of either sign)."""
    return math.inf if x == 0.0 else 1.0 / x

def is_finite(x: float) -> bool:
    """Report whether x is neither infinite nor NaN (delegates to numpy.isfinite)."""
    finite = np.isfinite(x)
    return finite

# ----------------------------
# Core math for the problem
# ----------------------------

def compute_A_B(k: int, p1: float, p2: float, p3: float) -> Tuple[float, float]:
    """
    Return the pair (A_k, B_k) for k in {1, 2}.

    With (p_k, p_{k+1}) = (p1, p2) for k == 1 and (p2, p3) for k == 2:
      A_k = min{1, max{ 1/(p_k * k),  p1 / (p_k*k - p_{k+1} + p1) } }
      B_k = p1 / (p_{k+1} - p1)
    Zero denominators are resolved by recip_safe / div_safe.
    """
    assert k in (1, 2), "k must be 1 or 2"
    p_cur, p_next = (p1, p2) if k == 1 else (p2, p3)

    scaled = p_cur * k
    # Argument order (reciprocal first) is kept as-is because it decides how
    # max() treats a NaN from div_safe.
    lower = max(recip_safe(scaled), div_safe(p1, scaled - p_next + p1))
    a_k = min(1.0, lower)
    b_k = div_safe(p1, p_next - p1)
    return a_k, b_k

def objective_for_k(k: int, gamma: float, p1: float, p2: float, p3: float) -> float:
    """
    Evaluate, for k in {1, 2},

      f(k, gamma) = (gamma/(1+k) + (1-gamma)) * min{ 1, 1 / (p_{k+1}*gamma + p1*(1-gamma)) }

    where p_{k+1} is p2 for k == 1 and p3 for k == 2.

    Returns -inf for a non-finite gamma (NaN or +/-inf), so such a candidate
    never wins a max.

    Raises:
        AssertionError: if k is not 1 or 2 (p_{k+1} only exists for those),
            matching the check in compute_A_B.
    """
    assert k in (1, 2), "k must be 1 or 2"
    if not is_finite(gamma):
        return -math.inf

    pk1 = p2 if k == 1 else p3
    factor1 = (gamma / (1.0 + k)) + (1.0 - gamma)

    denom3 = pk1 * gamma + p1 * (1.0 - gamma)
    inv = recip_safe(denom3)  # +inf when denom3 == 0 (either sign of zero)
    term2 = min(1.0, inv)

    # A zero denominator gives inv == +inf, so term2 == 1.
    # For gamma in [0, 1] and non-negative p's, denom3 is a convex combination
    # of p_{k+1} and p1 and hence >= 0. Outside [0, 1] it can be negative:
    # inv is then a finite negative number, min(1, inv) == inv, and that
    # negative value is used as-is, consistent with the literal formula.
    return factor1 * term2

def best_value_for_p(p1: float, p2: float, p3: float, tol: float = 1e-12) -> Tuple[float, dict]:
    """
    For a fixed (p1, p2, p3), compute
      min_{k=1,2} max_{gamma in feasible set} f(k, gamma).

    For each k, (A_k, B_k) come from compute_A_B and f from objective_for_k.
    The feasible set is {A_k (if finite), B_k (if finite, A_k finite and
    B_k >= A_k - tol)}. On a tie, the earlier candidate (A_k) is chosen.

    Returns:
        (value, details). value is the min over k of the per-k best f.
        details is {"k1": {...}, "k2": {...}}, each with keys "A", "B",
        "gamma_star" (maximizing candidate, or None if nothing is feasible)
        and "f_k" (its objective value, or -inf if nothing is feasible).
    """
    fks = []
    details = {}

    for k in (1, 2):
        # Computed once per k and reused for `details` (this function runs in
        # the innermost loop of the brute-force search).
        A, B = compute_A_B(k, p1, p2, p3)

        candidates = []
        if is_finite(A):
            candidates.append(A)
        # B is feasible only if it's finite and B >= A (within tolerance)
        if is_finite(B) and is_finite(A) and (B >= A - tol):
            candidates.append(B)

        if candidates:
            vals = [objective_for_k(k, g, p1, p2, p3) for g in candidates]
            idx = int(np.argmax(vals))  # first occurrence wins ties
            f_k, gamma_star = vals[idx], candidates[idx]
        else:
            # No feasible gamma: this (p1, p2, p3) is infeasible for this k
            f_k, gamma_star = -math.inf, None

        fks.append(f_k)
        details[f"k{k}"] = {"A": A, "B": B, "gamma_star": gamma_star, "f_k": f_k}

    # Overall objective is the min over k
    overall = min(fks[0], fks[1])
    return overall, details

# ----------------------------
# Grid search / brute force
# ----------------------------

@dataclass
class SearchResult:
    """Outcome of a grid search over (p1, p2, p3).

    brute_force_search starts from best_value=-inf, best_p=(None, None, None)
    and best_details=None, and keeps those sentinels when no evaluated grid
    point beats -inf (e.g. an empty grid). The annotations admit them.
    """

    best_value: float  # largest overall objective found
    best_p: Tuple[Optional[float], Optional[float], Optional[float]]  # argmax (p1, p2, p3)
    best_details: Optional[dict]  # per-k details at best_p, as built by best_value_for_p

def make_grid(start: float, stop: float, num: Optional[int] = None, step: Optional[float] = None) -> np.ndarray:
    """
    Create a 1D float grid over [start, stop] with both endpoints included.

    Exactly one of ``num`` or ``step`` must be given:
      - num:  ``num`` evenly spaced points (``np.linspace``).
      - step: start, start + step, ... up to but never beyond ``stop``. If the
              last of these coincides with ``stop`` (within 1e-9 of a step),
              it is set to ``stop`` exactly. Otherwise ``stop`` is appended,
              so the final interval may be shorter than ``step``. The grid is
              always strictly increasing.

    Raises:
        AssertionError: if both or neither of ``num`` / ``step`` are given.
        ValueError: in step mode, if ``step`` is not positive or ``stop < start``.
    """
    assert (num is not None) ^ (step is not None), "Specify exactly one of num or step."
    if num is not None:
        return np.linspace(start, stop, num=num)

    if not step > 0:  # also rejects NaN
        raise ValueError("step must be positive.")
    if stop < start:
        raise ValueError("stop must be >= start when using step.")

    # Slack, in units of step, that absorbs floating-point error in the ratio
    # (e.g. (stop - start) / step landing just under an integer).
    close = 1e-9
    # Round DOWN, not to nearest. Rounding to nearest placed a point past stop
    # (step=0.4 on [0, 1] gave [0, .4, .8, 1.2, 1.0]).
    n = int(math.floor((stop - start) / step + close))
    grid = start + step * np.arange(n + 1, dtype=float)
    if stop - grid[-1] <= close * step:
        grid[-1] = stop  # last regular point is stop up to rounding: pin it exactly
    else:
        grid = np.append(grid, stop)  # shorter final interval
    return grid

def brute_force_search(
    n1: int = 61,         # grid size for p1 over [0, 2e]
    n2: int = 61,         # grid size for p2 over [0, 2e]
    n3: int = 121,        # grid size for p3 over [0, 4e] (double the p1/p2 span)
    tol: float = 1e-12,
) -> SearchResult:
    """
    Exhaustively scan (p1, p2, p3) over [0, 2e] x [0, 2e] x [0, 4e].

    Each axis is sampled with make_grid(..., num=...). Every grid point is
    scored with best_value_for_p. The first point achieving a strictly larger
    value than all earlier ones is kept. Cost is n1 * n2 * n3 evaluations.

    Returns:
        SearchResult with the best value, its (p1, p2, p3), and the per-k
        details dict from best_value_for_p at that point.
    """
    axis_p1 = make_grid(0.0, 2.0 * math.e, num=n1)
    axis_p2 = make_grid(0.0, 2.0 * math.e, num=n2)
    axis_p3 = make_grid(0.0, 4.0 * math.e, num=n3)

    top_val, top_p, top_details = -math.inf, (None, None, None), None

    for a in axis_p1:
        for b in axis_p2:
            for c in axis_p3:
                score, info = best_value_for_p(a, b, c, tol=tol)
                if not score > top_val:
                    continue
                top_val, top_p, top_details = score, (a, b, c), info

    return SearchResult(best_value=top_val, best_p=top_p, best_details=top_details)

# ----------------------------
# Example usage
# ----------------------------

if __name__ == "__main__":
    # Grid resolution: finer grids take longer (cost is n1 * n2 * n3 evaluations).
    # Quick smoke test: n1=n2=31, n3=61.
    search = brute_force_search(n1=201, n2=201, n3=401, tol=1e-12)

    print("Best objective value found:", search.best_value)
    print("Best (p1, p2, p3):", search.best_p)
    print("Details at best point:")
    for label in ("k1", "k2"):
        info = search.best_details[label]
        print(
            f"  {label}: A={info['A']}, B={info['B']}, "
            f"gamma*={info['gamma_star']}, f_k={info['f_k']}"
        )


Best objective value found: 0.6205379994770268
Best (p1, p2, p3): (np.float64(1.331958095944932), np.float64(0.8970330033914848), np.float64(0.0))
Details at best point:
  k1: A=0.7538461538461538, B=-3.0625000000000004, gamma*=0.7538461538461538, f_k=0.6205379994770268
  k2: A=0.5573930926840036, B=-1.0, gamma*=0.5573930926840036, f_k=0.628404604877331
