# In [None]:
import numpy as np
import matplotlib.pyplot as plt

# -----------------------------
#   Helper functions
# -----------------------------

def compute_Ai(x: np.ndarray, y: np.ndarray, w: np.ndarray, Ui_L: np.ndarray) -> np.ndarray:
    """Return the numerator vector A for every row i.

    A_i = U_i^L + Σ_j w_ij [-y_j x_j^2 + (1 + y_j) x_j]

    The per-column factor is broadcast across the rows of *w*, and the weighted
    entries are summed along axis 1.
    """
    column_gain = (1.0 + y) * x - y * x ** 2
    return Ui_L + np.sum(w * column_gain, axis=1)


def compute_Bi(x: np.ndarray, y: np.ndarray, w: np.ndarray,
               Ui_L: np.ndarray, Ui_F: np.ndarray) -> np.ndarray:
    """Return the denominator vector B for every row i.

    B_i = U_i^L + U_i^F + Σ_j w_ij [(1 - y_j) x_j + y_j]

    The per-column factor is broadcast across the rows of *w*, and the weighted
    entries are summed along axis 1.
    """
    column_gain = (1.0 - y) * x + y
    row_totals = np.sum(w * column_gain, axis=1)
    return Ui_L + Ui_F + row_totals


def compute_Lhat(x: np.ndarray, y: np.ndarray, w: np.ndarray,
                 Ui_L: np.ndarray, Ui_F: np.ndarray, h: np.ndarray) -> float:
    r"""Evaluate the objective \hat{L}(x, y) = Σ_i h_i A_i / B_i.

    A_i and B_i come from ``compute_Ai`` and ``compute_Bi``. The docstring is a
    raw string, so ``\hat`` is not parsed as an (invalid) escape sequence.

    Returns
    -------
    float
        The h-weighted sum of the ratios A_i / B_i.
    """
    Ai = compute_Ai(x, y, w, Ui_L)
    Bi = compute_Bi(x, y, w, Ui_L, Ui_F)
    return float(np.dot(h, Ai / Bi))

# -----------------------------
#   Gradients
# -----------------------------

def grad_x(x: np.ndarray, y: np.ndarray, w: np.ndarray,
           Ui_L: np.ndarray, Ui_F: np.ndarray, h: np.ndarray) -> np.ndarray:
    r"""Gradient of \hat{L} with respect to x (used as the ascent direction).

    By the quotient rule, for each column j:

        d\hat{L}/dx_j = Σ_i h_i (B_i dA_i/dx_j - A_i dB_i/dx_j) / B_i^2

    with dA_i/dx_j = w_ij [(1 + y_j) - 2 y_j x_j] and dB_i/dx_j = w_ij (1 - y_j).

    Returns
    -------
    np.ndarray
        One entry per column j of *w* (the row dimension i is summed out).
    """
    Ai = compute_Ai(x, y, w, Ui_L)
    Bi = compute_Bi(x, y, w, Ui_L, Ui_F)

    dA_dx = w * ((1.0 + y) - 2.0 * y * x)     # shape (I, J)
    dB_dx = w * (1.0 - y)

    # Quotient rule per (i, j); Ai/Bi are broadcast along the column axis.
    frac = (Bi[:, None] * dA_dx - Ai[:, None] * dB_dx) / (Bi[:, None] ** 2)
    return (frac * h[:, None]).sum(axis=0)


def grad_y(x: np.ndarray, y: np.ndarray, w: np.ndarray,
           Ui_L: np.ndarray, Ui_F: np.ndarray, h: np.ndarray) -> np.ndarray:
    r"""Gradient of \hat{L} with respect to y (used as the descent direction).

    By the quotient rule, for each column j:

        d\hat{L}/dy_j = Σ_i h_i (B_i dA_i/dy_j - A_i dB_i/dy_j) / B_i^2

    with dA_i/dy_j = w_ij (x_j - x_j^2) and dB_i/dy_j = w_ij (1 - x_j).

    Returns
    -------
    np.ndarray
        One entry per column j of *w* (the row dimension i is summed out).
    """
    Ai = compute_Ai(x, y, w, Ui_L)
    Bi = compute_Bi(x, y, w, Ui_L, Ui_F)

    dA_dy = w * (-x ** 2 + x)
    dB_dy = w * (1.0 - x)

    # Quotient rule per (i, j); Ai/Bi are broadcast along the column axis.
    frac = (Bi[:, None] * dA_dy - Ai[:, None] * dB_dy) / (Bi[:, None] ** 2)
    return (frac * h[:, None]).sum(axis=0)

# -----------------------------
#   Projection
# -----------------------------

def project_cardinality(v: np.ndarray, k: int, mask: np.ndarray | None = None) -> np.ndarray:
    """Project a vector onto the set {0,1}^J with at most *k* ones.

    Parameters
    ----------
    v : array_like
        Continuous scores for each coordinate.
    k : int
        Maximum number of 1s allowed.
    mask : array_like of bool, optional
        Positions that are forcibly set to 0 (e.g., columns already chosen by *x* when projecting *y*).
    """
    v = np.clip(v, 0.0, 1.0)
    if mask is not None:
        v = v * (~mask)

    if k >= v.size:
        return (v > 0).astype(float)

    # pick indices of the k largest values
    idx = np.argpartition(-v, k)[:k]
    out = np.zeros_like(v)
    out[idx] = 1.0
    return out

# -----------------------------
#   OGDA Solver (with history)
# -----------------------------

def ogda_solver(x0: np.ndarray, y0: np.ndarray, w: np.ndarray,
                Ui_L: np.ndarray, Ui_F: np.ndarray, h: np.ndarray,
                p: int, r: int, *, eta: float = 0.01, max_iter: int = 500,
                tol: float = 1e-6, return_history: bool = False):
    """Optimistic Gradient Descent–Ascent (OGDA) solver that records its full history.

    Each iteration takes an optimistic step ``eta * (2 * g_t - g_{t-1})``: *ascent*
    for x and *descent* for y. The results are then projected with
    ``project_cardinality`` so that x has at most *p* ones, and y has at most *r*
    ones outside the positions selected in x.

    The objective value, ||dx|| and ||dy|| are recorded for every iteration so the
    convergence process can be plotted easily.

    Parameters
    ----------
    x0, y0 : np.ndarray
        Starting points (copied; the caller's arrays are not modified).
    w, Ui_L, Ui_F, h : np.ndarray
        Problem data forwarded to the gradient and objective helpers.
    p, r : int
        Cardinality budgets for x and y respectively.
    eta : float
        Step size.
    max_iter : int
        Maximum number of iterations. With ``max_iter <= 0`` no step is taken and
        copies of *x0*/*y0* are returned together with an empty history.
    tol : float
        Stop once both ||dx|| and ||dy|| are below this value. The projected
        iterates are 0/1 vectors, so for ``tol < 1`` this in practice triggers only
        when an iterate repeats exactly.
    return_history : bool
        If True, also return the history dictionary.

    Returns
    -------
    (x, y) or (x, y, history)
        ``history`` maps "objective", "dx" and "dy" to 1-D arrays with one entry
        per iteration performed.
    """
    x_prev = x0.copy()
    y_prev = y0.copy()
    # Ensure the return values are bound even if the loop body never runs
    # (max_iter <= 0); previously this raised UnboundLocalError.
    x_next, y_next = x_prev, y_prev

    # Initial gradients (the first optimistic step therefore reduces to a plain step)
    gx_prev = grad_x(x_prev, y_prev, w, Ui_L, Ui_F, h)
    gy_prev = grad_y(x_prev, y_prev, w, Ui_L, Ui_F, h)

    obj_vals, dx_vals, dy_vals = [], [], []

    for _ in range(max_iter):
        # Current gradients
        gx = grad_x(x_prev, y_prev, w, Ui_L, Ui_F, h)
        gy = grad_y(x_prev, y_prev, w, Ui_L, Ui_F, h)

        # OGDA update (predictor–corrector style)
        x_tmp = x_prev + eta * (2.0 * gx - gx_prev)  # ascent step for x
        y_tmp = y_prev - eta * (2.0 * gy - gy_prev)  # descent step for y

        # Projection to maintain feasibility; y may not reuse columns chosen by x
        x_next = project_cardinality(x_tmp, p)
        y_next = project_cardinality(y_tmp, r, mask=(x_next > 0))

        # Metrics for this iteration
        dx = np.linalg.norm(x_next - x_prev)
        dy = np.linalg.norm(y_next - y_prev)
        obj = compute_Lhat(x_next, y_next, w, Ui_L, Ui_F, h)

        obj_vals.append(obj)
        dx_vals.append(dx)
        dy_vals.append(dy)

        # Convergence check
        if max(dx, dy) < tol:
            break

        # Prepare for next iteration
        gx_prev, gy_prev = gx, gy
        x_prev, y_prev = x_next, y_next

    history = {
        "objective": np.array(obj_vals),
        "dx": np.array(dx_vals),
        "dy": np.array(dy_vals),
    }

    if return_history:
        return x_next, y_next, history
    return x_next, y_next

# -----------------------------
#   Plot helper
# -----------------------------

def plot_history(history: dict[str, np.ndarray], *, logy: bool = True,
                 show_objective: bool = True):
    """Quick plotting utility for convergence curves.

    The objective goes on the left axis, and the step norms ||dx|| and ||dy|` go on
    a twin right axis. A single combined legend is drawn on the left axis.

    Parameters
    ----------
    history : dict
        Dictionary returned by *ogda_solver(return_history=True)*; must contain
        the keys "objective", "dx" and "dy".
    logy : bool, default True
        Plot the y‑axis in log scale for dx/dy curves.
    show_objective : bool, default True
        Draw the objective curve on the left axis. Previously that curve was
        commented out, which left an empty "objective" axis.
    """
    iters = np.arange(1, len(history["objective"]) + 1)

    fig, ax1 = plt.subplots(figsize=(6, 4))

    # Objective value on primary axis
    ax1.set_xlabel("iteration")
    if show_objective:
        ax1.plot(iters, history["objective"], label="objective", linewidth=1.5)
        ax1.set_ylabel("objective")

    # dx & dy on secondary axis
    ax2 = ax1.twinx()
    ax2.plot(iters, history["dx"], linestyle="--", label="‖dx‖")
    ax2.plot(iters, history["dy"], linestyle=":", label="‖dy‖")
    ax2.set_ylabel("step size")
    if logy:
        ax2.set_yscale("log")

    # Merge handles from both axes into one legend
    lines, labels = ax1.get_legend_handles_labels()
    l2, lab2 = ax2.get_legend_handles_labels()
    ax1.legend(lines + l2, labels + lab2, loc="best")

    plt.tight_layout()
    plt.show()
