In [3]:
import math, time, json
from dataclasses import dataclass, asdict
from typing import List, Dict, Any, Tuple

import numpy as np
import pandas as pd
from scipy.optimize import minimize
from numpy.linalg import LinAlgError
import pyswarms as ps


In [4]:

# ------------------------------
# Model primitives (4×4 Emax info)
# ------------------------------

DOSE_MIN, DOSE_MAX = 0.0, 500.0

@dataclass
class Params:
    ED50: float = 1.0
    Emax: float = 1.0
    SD50: float = 2.0
    Smax: float = 1.0
    sigma1: float = 1.0
    sigma2: float = 1.0
    rho: float = 0.0
    k1: float = 1.0
    k2: float = 1.0

def sigma_inv(s1: float, s2: float, rho: float) -> np.ndarray:
    # inverse of [[s1^2, rho s1 s2],[rho s1 s2, s2^2]]
    det = (s1**2) * (s2**2) * (1 - rho**2)
    return np.array([[ s2**2,        -rho*s1*s2],
                     [-rho*s1*s2,     s1**2   ]], dtype=float) / det

def J_x(x: float, ED50: float, Emax: float, SD50: float, Smax: float) -> np.ndarray:
    # 2×4 Jacobian of (mu1, mu2) wrt (ED50, Emax, SD50, Smax)
    x_ed = x + ED50
    x_sd = x + SD50
    if x_ed <= 0 or x_sd <= 0:
        return np.zeros((2,4), dtype=float)

    dmu1_dED50 = -Emax * x / (x_ed**2)
    dmu1_dEmax =  x / x_ed
    dmu1_dSD50 =  0.0
    dmu1_dSmax =  0.0

    dmu2_dED50 =  0.0
    dmu2_dEmax =  0.0
    dmu2_dSD50 = -Smax * x / (x_sd**2)
    dmu2_dSmax =  x / x_sd

    return np.array([[dmu1_dED50, dmu1_dEmax, dmu1_dSD50, dmu1_dSmax],
                     [dmu2_dED50, dmu2_dEmax, dmu2_dSD50, dmu2_dSmax]], dtype=float)

def M_one_point(x: float, p: Params) -> np.ndarray:
    S_inv = sigma_inv(p.sigma1, p.sigma2, p.rho)      # 2×2
    J     = J_x(x, p.ED50, p.Emax, p.SD50, p.Smax)    # 2×4
    return J.T @ S_inv @ J                            # 4×4

def M_design(xs: List[float], ws: List[float], p: Params, ridge: float = 1e-12) -> np.ndarray:
    M = np.zeros((4,4), dtype=float)
    for x, w in zip(xs, ws):
        if w > 0:
            M += w * M_one_point(x, p)
    # Tikhonov regularization for numerical stability
    return M + ridge * np.eye(4, dtype=float)

# ------------------------------
# g(theta) from Magnusdottir Eq. (4) + numeric gradient
# ------------------------------

def g_theta(theta: np.ndarray, k1: float, k2: float) -> float:
    ED50, Emax, SD50, Smax = theta
    num_sqrt = math.sqrt(max(0.0, k1*ED50*Emax * k2*SD50*Smax))  # guard small neg due to FP
    num = num_sqrt * (ED50 - SD50) - ED50*SD50*(k1*Emax - k2*Smax)
    den = (k1*ED50*Emax - k2*SD50*Smax)
    return num / den

def grad_g_numeric(p: Params, eps: float = 1e-6) -> np.ndarray:
    # central differences w.r.t. (ED50, Emax, SD50, Smax)
    theta0 = np.array([p.ED50, p.Emax, p.SD50, p.Smax], dtype=float)
    g0 = g_theta(theta0, p.k1, p.k2)
    grad = np.zeros(4, dtype=float)
    for i in range(4):
        ei = np.zeros(4, dtype=float)
        ei[i] = 1.0
        fp = g_theta(theta0 + eps*ei, p.k1, p.k2)
        fm = g_theta(theta0 - eps*ei, p.k1, p.k2)
        grad[i] = (fp - fm) / (2*eps)
    if not np.all(np.isfinite(grad)):
        # fallback to forward difference if needed
        for i in range(4):
            ei = np.zeros(4); ei[i] = 1.0
            fp = g_theta(theta0 + eps*ei, p.k1, p.k2)
            grad[i] = (fp - g0) / eps
    return grad

# ------------------------------
# Psi(ξ) and GET check
# ------------------------------

def psi_of_design(xs: List[float], ws: List[float], p: Params) -> float:
    M = M_design(xs, ws, p)
    try:
        Minv = np.linalg.inv(M)
    except LinAlgError:
        return 1e50
    grad = grad_g_numeric(p)
    return float(grad.T @ Minv @ grad)

def GET_check(xs: List[float], ws: List[float], p: Params,
              grid: np.ndarray) -> Dict[str, Any]:
    M = M_design(xs, ws, p)
    Minv = np.linalg.inv(M)
    grad = grad_g_numeric(p)
    rhs = float(grad.T @ Minv @ grad)

    def phi_x(x):
        Mx = M_one_point(x, p)
        lhs = float(grad.T @ Minv @ Mx @ Minv @ grad)
        return lhs - rhs

    phi_vals = np.array([phi_x(x) for x in grid], dtype=float)
    at_design = np.array([phi_x(x) for x in xs], dtype=float)
    return dict(max_violation=float(np.max(phi_vals)),
                grid_x=grid, phi=phi_vals, at_design=at_design)

# ------------------------------
# Helpers: weights, penalties, encoding
# ------------------------------

def clean_weights(w: np.ndarray) -> np.ndarray:
    w = w.copy()
    w[(w < 0) & (w > -1e-10)] = 0.0
    w = np.clip(w, 0.0, 1.0)
    s = np.sum(w)
    if not np.isfinite(s) or s <= 0:
        return np.ones_like(w) / len(w)
    if abs(s - 1) > 1e-10:
        w = w / s
    return w

def distinct_penalty(xs: np.ndarray, min_gap: float = 1e-3, strength: float = 1e6) -> float:
    xs = np.sort(xs)
    if len(xs) < 2:
        return 0.0
    gmin = float(np.min(np.diff(xs)))
    return 0.0 if gmin >= min_gap else strength * (min_gap - gmin)**2

# v = [x1..xn, w1..w_{n-1}], with wn := 1 - sum(w1..w_{n-1})
def obj_n_point(v: np.ndarray, p: Params, xmin: float, xmax: float, n: int) -> float:
    xs = np.array(v[:n], dtype=float)
    w_head = np.array(v[n:n+(n-1)], dtype=float)
    w_last = 1.0 - np.sum(w_head)
    ws = np.concatenate([w_head, [w_last]])

    penalty = 0.0
    if w_last < -1e-6:
        penalty += 1e6 * (abs(w_last))**2

    # sort doses, reorder weights accordingly
    ord_idx = np.argsort(xs)
    xs = xs[ord_idx]
    ws = ws[ord_idx]

    ws = clean_weights(ws)
    if np.any(xs < xmin) or np.any(xs > xmax):
        return 1e9 + penalty

    penalty += distinct_penalty(xs)
    return psi_of_design(xs.tolist(), ws.tolist(), p) + penalty

def polish_n(sol_vec: np.ndarray, p: Params, xmin: float, xmax: float, n: int) -> Tuple[np.ndarray, float]:
    x0 = sol_vec.copy()

    def fn(v):
        return obj_n_point(v, p, xmin, xmax, n)

    bounds = [(xmin, xmax)] * n + [(0.0, 1.0)] * (n-1)
    res = minimize(fn, x0=x0, bounds=bounds, method="L-BFGS-B", options={"maxiter": 2000})
    if res.success:
        return res.x.copy(), float(res.fun)
    return sol_vec, float(fn(sol_vec))

def rand_w_head(n: int, rng: np.random.Generator) -> np.ndarray:
    w = rng.gamma(shape=1.0, scale=1.0, size=n)
    w = w / np.sum(w)
    return w[:n-1]

def decode_solution(par_vec: np.ndarray, xmin: float, xmax: float, n: int) -> Dict[str, Any]:
    xs = np.array(par_vec[:n], dtype=float)
    w_head = np.array(par_vec[n:n+(n-1)], dtype=float)
    w_last = 1.0 - np.sum(w_head)
    ws = np.concatenate([w_head, [w_last]])
    ord_idx = np.argsort(xs)
    xs = xs[ord_idx]
    ws = clean_weights(ws[ord_idx])
    return {"xs": xs, "ws": ws}

# ------------------------------
# Build nominal parameters (Theorem 2 invariances)
# ------------------------------

def make_params(r_SE: float, r_k21: float, r_SDED: float, r_var: float, rho: float) -> Params:
    assert r_SE > 0 and r_k21 > 0 and r_SDED > 0 and r_var > 0
    assert -1 < rho < 1
    return Params(
        ED50=1.0,
        Emax=1.0,
        SD50=r_SDED,
        Smax=r_SE,
        sigma1=1.0,
        sigma2=float(np.sqrt(r_var)),
        rho=float(rho),
        k1=1.0,
        k2=r_k21
    )

# ------------------------------
# PSO drivers (generic + n-point wrappers)
# ------------------------------

def _pso_minimize(func, lb, ub, dim, swarm_size=60, iters=300, seed=42, options=None):
    if options is None:
        options = {"c1": 1.7, "c2": 1.7, "w": 0.6}
    optimizer = ps.single.GlobalBestPSO(
        n_particles=swarm_size,
        dimensions=dim,
        options=options,
        bounds=(np.array(lb, dtype=float), np.array(ub, dtype=float)),
        ftol=1e-12,
        init_pos=None
    )
    cost, pos = optimizer.optimize(func, iters=iters, n_processes=None, verbose=False)
    return float(cost), np.array(pos, dtype=float)

def run_copt_once(
    r_SE: float, r_k21: float, r_SDED: float, r_var: float, rho: float,
    xmin: float, xmax: float,
    npoints: int = 2,
    swarm: int = 100,
    iters: int = 500,
    polish: bool = True,
    grid_length: int = 10000,   # can set to 10000 like R; 2000 is faster
    seed: int = None
) -> Dict[str, Any]:

    if seed is not None:
        np.random.seed(seed)

    p = make_params(r_SE, r_k21, r_SDED, r_var, rho)
    n = int(npoints)
    assert xmax > xmin and n >= 2

    def obj_wrapper(X):
        # X shape: (n_particles, dim); dim = n + (n-1)
        vals = []
        for v in X:
            vals.append(obj_n_point(v, p, xmin, xmax, n))
        return np.array(vals, dtype=float)

    lb = [xmin] * n + [0.0] * (n-1)
    ub = [xmax] * n + [1.0] * (n-1)

    # PSO
    t0 = time.perf_counter()
    pso_cost, pso_pos = _pso_minimize(
        obj_wrapper, lb, ub, dim=(n + (n-1)),
        swarm_size=swarm, iters=iters, seed=seed,
        options={"c1": 1.7, "c2": 1.7, "w": 0.6}
    )
    t1 = time.perf_counter()

    sol_vec = pso_pos.copy()
    sol_cost = pso_cost

    # local polish
    if polish:
        sol_vec, sol_cost_pol = polish_n(sol_vec, p, xmin, xmax, n)
        if sol_cost_pol <= sol_cost:
            sol_cost = sol_cost_pol

    dec = decode_solution(sol_vec, xmin, xmax, n)
    xs = dec["xs"]; ws = dec["ws"]
    psi_val = psi_of_design(xs.tolist(), ws.tolist(), p)

    # GET check on dense grid (+ include design points exactly)
    grid = np.unique(np.concatenate([np.linspace(xmin, xmax, grid_length), xs]))
    getd = GET_check(xs.tolist(), ws.tolist(), p, grid)

    return dict(
        input=dict(r_SE=r_SE, r_k21=r_k21, r_SDED=r_SDED, r_var=r_var, rho=rho,
                   xmin=xmin, xmax=xmax, npoints=n, swarm=swarm, iters=iters,
                   polish=polish, grid_length=grid_length, seed=seed),
        params=asdict(p),
        xs=xs, ws=ws, psi=float(psi_val),
        GET_max_violation=float(getd["max_violation"]),
        GET_at_design=getd["at_design"].tolist(),
        objective_value=float(sol_cost),
        runtime_s=float(t1 - t0)
    )

# ------------------------------
# Batch runners over a scenario grid (record runtime)
# ------------------------------

def run_copt_grid(
    scen_df: pd.DataFrame,
    xmin: float, xmax: float,
    swarm: int = 40,
    iters: int = 200,
    polish: bool = True,
    grid_length: int = 2000,
    seed: int = None,
    save_csv_path: str = None
) -> pd.DataFrame:
    req_cols = ["r_SE", "r_k21", "r_SDED", "r_var", "rho", "npoints"]
    assert all(c in scen_df.columns for c in req_cols)

    rows = []
    for i, row in scen_df.reset_index(drop=True).iterrows():
        s = None if seed is None else (int(seed) + i + 1)
        out = run_copt_once(
            r_SE=float(row["r_SE"]),
            r_k21=float(row["r_k21"]),
            r_SDED=float(row["r_SDED"]),
            r_var=float(row["r_var"]),
            rho=float(row["rho"]),
            xmin=float(xmin), xmax=float(xmax),
            npoints=int(row["npoints"]),
            swarm=int(swarm), iters=int(iters),
            polish=bool(polish), grid_length=int(grid_length),
            seed=s
        )
        xs = out["xs"]; ws = out["ws"]
        rows.append({
            "r_SE": row["r_SE"],
            "r_k21": row["r_k21"],
            "r_SDED": row["r_SDED"],
            "r_var": row["r_var"],
            "rho": row["rho"],
            "npoints": row["npoints"],
            "psi": out["psi"],
            "GET_max_violation": out["GET_max_violation"],
            "runtime_s": out["runtime_s"],
            # serialize design (works for 2- or 3-point designs)
            "xs": json.dumps([float(x) for x in xs]),
            "ws": json.dumps([float(w) for w in ws]),
            # just for reference:
            "objective_value": out["objective_value"]
        })

    df = pd.DataFrame(rows)
    if save_csv_path:
        df.to_csv(save_csv_path, index=False)
        print(f"Saved grid results -> {save_csv_path}")
    return df



In [6]:
res = run_copt_once(
        r_SE=1.0, r_k21=1.0, r_SDED=2.0, r_var=1.0, rho=0.0,
        xmin=DOSE_MIN, xmax=DOSE_MAX,
        npoints=2, swarm=40, iters=200, polish=True, seed=2025
    )
print("Single scenario:")
print("xs:", res["xs"], "ws:", res["ws"], "psi:", res["psi"],
          "GET max viol:", res["GET_max_violation"], "time(s):", res["runtime_s"])

Single scenario:
xs: [  1.1078374 500.       ] ws: [0.39440029 0.60559971] psi: 112.3651116245437 GET max viol: 4.663160353857165e-06 time(s): 0.06173950002994388


In [None]:
# Create all combinations of parameters
import itertools

# ------------------------------
# Example usage (matches your R examples)
# ------------------------------

if __name__ == "__main__":

    
    # Define parameter ranges
    r_SE_vals = [0.2, 1, 5, 100]
    r_k21_vals = [0.2, 1, 5, 100]
    r_SDED_vals = [0.2, 1, 5, 100]
    r_var_vals = [0.2, 1, 5, 100]
    rho_vals = [-0.9, -0.45, 0.0, 0.45, 0.9]
    npoints_vals = [2, 3, 4]  # number of design points
    
    # Create Cartesian product
    combinations = list(itertools.product(r_SE_vals, r_k21_vals, r_SDED_vals, 
                                        r_var_vals, rho_vals, npoints_vals))
    
    scen = pd.DataFrame(combinations, columns=['r_SE', 'r_k21', 'r_SDED', 'r_var', 'rho', 'npoints'])
    
    df_grid = run_copt_grid(
        scen_df=scen,
        xmin=DOSE_MIN, xmax=DOSE_MAX,
        swarm=40, iters=800, polish=True,
        grid_length=10000,   # increase to 10000 to mimic your R check
        seed=2025,
        save_csv_path="try_copt_grid_results.csv"
    )
    print("\nGrid summary (runtime):")
    print(df_grid[["r_SDED","psi","runtime_s","xs","ws"]])


  return num / den
  return num / den


Saved grid results -> try_copt_grid_results.csv

Grid summary (runtime):
      r_SDED          psi  runtime_s  \
0        0.2     7.025033   0.031764   
1        0.2     7.025047   0.011914   
2        0.2  7936.795864   0.014220   
3        0.2     6.143006   0.016638   
4        0.2     6.148275   0.019176   
...      ...          ...        ...   
3835   100.0    25.996169   0.015549   
3836   100.0    40.403046   0.028823   
3837   100.0    22.883386   0.015850   
3838   100.0    10.396843   0.012372   
3839   100.0    10.305319   0.029213   

                                                     xs  \
0                          [0.17939570715802333, 500.0]   
1      [0.17938610481683523, 442.67279824182435, 500.0]   
2     [18.155014784081047, 87.40862510221592, 107.19...   
3             [0.16793615923184993, 313.30920027508375]   
4     [0.16787200302795502, 34.4827205052785, 234.15...   
...                                                 ...   
3835  [0.6836750335887322, 182.47