# In [None]:
import numpy as np

# -----------------------------
#   Helper functions
# -----------------------------

def compute_Ai(x: np.ndarray, y: np.ndarray, w: np.ndarray, Ui_L: np.ndarray) -> np.ndarray:
    """Return the numerator terms A_i for every row i of *w*.

    A_i = U_i^L + Σ_j w_ij * [(1 + y_j) x_j - y_j x_j^2]

    The bracketed factor depends only on column j, so it is built once and
    broadcast across the rows of *w* before summing along axis 1.
    """
    per_column = x * (1.0 + y) - y * x ** 2
    row_totals = np.sum(w * per_column, axis=1)
    return Ui_L + row_totals


def compute_Bi(x: np.ndarray, y: np.ndarray, w: np.ndarray,
               Ui_L: np.ndarray, Ui_F: np.ndarray) -> np.ndarray:
    """Return the denominator terms B_i for every row i of *w*.

    B_i = U_i^L + U_i^F + Σ_j w_ij * [(1 - y_j) x_j + y_j]

    The bracketed factor depends only on column j; it is broadcast across the
    rows of *w* and summed along axis 1.
    """
    per_column = y + x * (1.0 - y)
    baseline = Ui_L + Ui_F
    return baseline + np.sum(w * per_column, axis=1)


def compute_Lhat(x: np.ndarray, y: np.ndarray, w: np.ndarray,
                 Ui_L: np.ndarray, Ui_F: np.ndarray, h: np.ndarray) -> float:
    r"""Evaluate the objective \hat{L}(x, y) = Σ_i h_i * A_i / B_i.

    A_i and B_i come from ``compute_Ai`` and ``compute_Bi``; the ratio is
    weighted by *h* and summed over i.

    Returns
    -------
    float
        The scalar objective value.

    Notes
    -----
    No guard is applied against B_i == 0; NumPy's usual division semantics
    (inf / nan plus a RuntimeWarning) apply in that case.
    """
    # The docstring is a raw string: "\h" is not a valid escape sequence and
    # triggers a SyntaxWarning on recent Python versions otherwise.
    Ai = compute_Ai(x, y, w, Ui_L)
    Bi = compute_Bi(x, y, w, Ui_L, Ui_F)
    return float(np.dot(h, Ai / Bi))

# -----------------------------
#   Gradients
# -----------------------------

def grad_x(x: np.ndarray, y: np.ndarray, w: np.ndarray,
           Ui_L: np.ndarray, Ui_F: np.ndarray, h: np.ndarray) -> np.ndarray:
    r"""Gradient of \hat{L} with respect to x (ascent direction).

    Each ratio A_i / B_i is differentiated with the quotient rule,

        d(A_i / B_i) / dx_j = (B_i * dA_i/dx_j - A_i * dB_i/dx_j) / B_i^2,

    then weighted by h_i and summed over i.

    Returns
    -------
    ndarray
        One partial derivative per column j of *w*.
    """
    # The docstring is a raw string because "\h" is not a valid escape sequence.
    Ai = compute_Ai(x, y, w, Ui_L)
    Bi = compute_Bi(x, y, w, Ui_L, Ui_F)

    # Element-wise partials of A_i and B_i with respect to x_j.
    dA_dx = w * ((1.0 + y) - 2.0 * y * x)     # shape (I, J)
    dB_dx = w * (1.0 - y)

    # Ai / Bi are per-row quantities; add a trailing axis to broadcast over j.
    frac = (Bi[:, None] * dA_dx - Ai[:, None] * dB_dx) / (Bi[:, None] ** 2)
    return (frac * h[:, None]).sum(axis=0)


def grad_y(x: np.ndarray, y: np.ndarray, w: np.ndarray,
           Ui_L: np.ndarray, Ui_F: np.ndarray, h: np.ndarray) -> np.ndarray:
    r"""Gradient of \hat{L} with respect to y (descent direction).

    Each ratio A_i / B_i is differentiated with the quotient rule,

        d(A_i / B_i) / dy_j = (B_i * dA_i/dy_j - A_i * dB_i/dy_j) / B_i^2,

    then weighted by h_i and summed over i.

    Returns
    -------
    ndarray
        One partial derivative per column j of *w*.
    """
    # The docstring is a raw string because "\h" is not a valid escape sequence.
    Ai = compute_Ai(x, y, w, Ui_L)
    Bi = compute_Bi(x, y, w, Ui_L, Ui_F)

    # Element-wise partials of A_i and B_i with respect to y_j.
    dA_dy = w * (-x ** 2 + x)
    dB_dy = w * (1.0 - x)

    # Ai / Bi are per-row quantities; add a trailing axis to broadcast over j.
    frac = (Bi[:, None] * dA_dy - Ai[:, None] * dB_dy) / (Bi[:, None] ** 2)
    return (frac * h[:, None]).sum(axis=0)

# -----------------------------
#   Projection
# -----------------------------

def project_cardinality(v: np.ndarray, k: int, mask: np.ndarray | None = None) -> np.ndarray:
    """Project a vector onto the set {0,1}^J with at most *k* ones.

    Parameters
    ----------
    v : array_like
        Continuous scores for each coordinate.
    k : int
        Maximum number of 1s allowed.
    mask : array_like of bool, optional
        Positions that are forcibly set to 0 (e.g., columns already chosen by *x* when projecting *y*).
    """
    v = np.clip(v, 0.0, 1.0)
    if mask is not None:
        v = v * (~mask)

    if k >= v.size:
        return (v > 0).astype(float)

    # pick indices of the k largest values
    idx = np.argpartition(-v, k)[:k]
    out = np.zeros_like(v)
    out[idx] = 1.0
    return out

# -----------------------------
#   OGDA Solver
# -----------------------------

def ogda_solver(x0: np.ndarray, y0: np.ndarray, w: np.ndarray,
                Ui_L: np.ndarray, Ui_F: np.ndarray, h: np.ndarray,
                p: int, r: int, *, eta: float = 0.01, max_iter: int = 500,
                tol: float = 1e-6, return_history: bool = False):
    """Optimistic Gradient Descent–Ascent (OGDA) solver for the bilevel game.

    Each iteration takes an optimistic step ``2 * g_t - g_{t-1}`` (ascent in
    *x*, descent in *y*) and then projects both iterates with
    ``project_cardinality``; *y* is projected with the support of the new *x*
    passed as its mask.

    Parameters
    ----------
    x0, y0 : array_like (J,)
        Initial continuous vectors (will be projected to 0/1 at every step).
        The inputs are copied and never modified.
    w, Ui_L, Ui_F, h : arrays from the original model.
    p, r : int
        Cardinality constraints for x and y respectively.
    eta : float, default 0.01
        Step size.
    max_iter : int, default 500
        Maximum number of OGDA iterations. If <= 0, no gradient step is taken
        and the projected starting point is returned.
    tol : float, default 1e-6
        Convergence tolerance on (x, y) change.
    return_history : bool, default False
        If True, also return a list with (Lhat, ||dx||, ||dy||) per iteration.

    Returns
    -------
    x_opt, y_opt : ndarray
        Binary vectors from the last projection (after convergence or after
        *max_iter* iterations, whichever comes first).
    history : list | None
        Only if *return_history* is True.
    """
    # np.array copies, so callers' arrays are untouched and list inputs work.
    x_prev = np.array(x0, dtype=float)
    y_prev = np.array(y0, dtype=float)

    hist = []

    if max_iter <= 0:
        # Without this guard the loop body never runs and x_next / y_next
        # would be unbound at the return statement.
        x_next = project_cardinality(x_prev, p)
        y_next = project_cardinality(y_prev, r, mask=(x_next > 0))
        if return_history:
            return x_next, y_next, hist
        return x_next, y_next

    # Gradients at the starting point. Using them as both "current" and
    # "previous" makes the first optimistic step a plain gradient step.
    gx = grad_x(x_prev, y_prev, w, Ui_L, Ui_F, h)
    gy = grad_y(x_prev, y_prev, w, Ui_L, Ui_F, h)
    gx_prev, gy_prev = gx, gy

    for _ in range(max_iter):
        # OGDA update (predictor–corrector style)
        x_tmp = x_prev + eta * (2.0 * gx - gx_prev)  # ascent step for x
        y_tmp = y_prev - eta * (2.0 * gy - gy_prev)  # descent step for y

        # Projection to maintain feasibility
        x_next = project_cardinality(x_tmp, p)
        y_next = project_cardinality(y_tmp, r, mask=(x_next > 0))

        # Convergence check
        dx = np.linalg.norm(x_next - x_prev)
        dy = np.linalg.norm(y_next - y_prev)

        if return_history:
            L_val = compute_Lhat(x_next, y_next, w, Ui_L, Ui_F, h)
            hist.append((L_val, dx, dy))

        if max(dx, dy) < tol:
            break

        # Prepare for next iteration: shift gradients and iterates, then
        # evaluate the gradients at the new point (skipped once converged).
        gx_prev, gy_prev = gx, gy
        x_prev, y_prev = x_next, y_next
        gx = grad_x(x_prev, y_prev, w, Ui_L, Ui_F, h)
        gy = grad_y(x_prev, y_prev, w, Ui_L, Ui_F, h)

    if return_history:
        return x_next, y_next, hist
    return x_next, y_next
