In [3]:
%run ../gda/func4.ipynb

In [4]:
def lgda_solver(
    x0: np.ndarray, y0: np.ndarray, w: np.ndarray,
    Ui_L: np.ndarray, Ui_F: np.ndarray, h: np.ndarray,
    p: int, r: int, *,
    eta_x: float = 0.01, eta_y: float = 0.01,
    mu: float = 0.01,
    max_iter: int = 500, tau_interval: int = 10,
    return_history: bool = False
):
    """LGDA with Mutation (Algorithm 4) solver.

    Runs projected gradient descent on ``x`` and projected gradient ascent on
    ``y``. Each step is augmented with a "mutation" term that pulls the
    iterate toward a reference point: ``cp`` for x and ``cq`` for y. The
    reference points start at ``x0`` / ``y0``. They are reset to the current
    iterates every ``tau_interval`` iterations.

    The helpers ``grad_x``, ``grad_y``, ``project_cardinality`` and
    ``compute_Lhat`` are not defined in this cell. They are expected to come
    from the ``%run ../gda/func4.ipynb`` cell.

    Parameters
    ----------
    x0, y0 : np.ndarray
        Initial iterates. They are copied and never modified in place.
    w, Ui_L, Ui_F, h : np.ndarray
        Problem data, forwarded unchanged to ``grad_x``, ``grad_y`` and
        ``compute_Lhat``.
    p, r : int
        Second argument to ``project_cardinality`` for x and y respectively.
        Presumably this is the maximum number of nonzero entries; confirm
        against ``project_cardinality``.
    eta_x, eta_y : float
        Step sizes for the descent (x) and ascent (y) steps.
    mu : float
        Mutation strength, i.e. the weight of the pull toward the reference
        points.
    max_iter : int
        Number of iterations to run.
    tau_interval : int
        Number of iterations between reference-point refreshes. Must be >= 1.
    return_history : bool
        If True, also return per-iteration diagnostics.

    Returns
    -------
    (x, y) or (x, y, history)
        The final iterates. When ``return_history`` is True, also a dict
        with three arrays:

        - ``"objective"``: ``compute_Lhat`` at each new iterate.
        - ``"dx"``: Euclidean length of each x step.
        - ``"dy"``: Euclidean length of each y step.

    Raises
    ------
    ValueError
        If ``tau_interval < 1``. Otherwise the reference points would
        silently never be refreshed.
    """
    if tau_interval < 1:
        raise ValueError(f"tau_interval must be >= 1, got {tau_interval}")

    x = x0.copy()
    y = y0.copy()

    # Mutation reference points and the number of iterations since the last refresh.
    cp = x.copy()
    cq = y.copy()
    tau = 0

    obj_vals, dx_vals, dy_vals = [], [], []

    for _ in range(max_iter):
        gx = grad_x(x, y, w, Ui_L, Ui_F, h)
        gy = grad_y(x, y, w, Ui_L, Ui_F, h)

        # Descent on x and ascent on y. Each update also pulls toward its
        # reference point (cp for x, cq for y).
        x_tmp = x - eta_x * gx - mu * (x - cp)
        y_tmp = y + eta_y * gy + mu * (cq - y)

        x_next = project_cardinality(x_tmp, p)
        # y is projected under a mask built from the positive entries of the new x.
        # NOTE(review): this assumes projected x is nonnegative, so "> 0" means
        # "selected". Confirm against project_cardinality.
        y_next = project_cardinality(y_tmp, r, mask=(x_next > 0))

        # Diagnostics only feed the history. Skip the (potentially expensive)
        # objective evaluation when the caller does not want it.
        if return_history:
            obj_vals.append(compute_Lhat(x_next, y_next, w, Ui_L, Ui_F, h))
            dx_vals.append(np.linalg.norm(x_next - x))
            dy_vals.append(np.linalg.norm(y_next - y))

        x, y = x_next, y_next
        tau += 1

        # Refresh the mutation reference points every tau_interval iterations.
        if tau == tau_interval:
            cp = x.copy()
            cq = y.copy()
            tau = 0

    if not return_history:
        return x, y

    history = {
        "objective": np.array(obj_vals),
        "dx": np.array(dx_vals),
        "dy": np.array(dy_vals),
    }
    return x, y, history