# In [None]:
import numpy as np
import matplotlib.pyplot as plt

# -----------------------------
#   Helper functions
# -----------------------------

def compute_Ai(x: np.ndarray, y: np.ndarray, w: np.ndarray, Ui_L: np.ndarray) -> np.ndarray:
    """Return A_i = U_i^L + Σ_j w_ij[-y_j x_j^2 + (1+y_j)x_j] for every row i.

    The bracketed factor depends only on the column index j, so it is formed
    once and broadcast across the rows of ``w`` before summing over axis 1.
    """
    site_gain = (1.0 + y) * x - y * x ** 2
    row_totals = (w * site_gain).sum(axis=1)
    return Ui_L + row_totals


def compute_Bi(x: np.ndarray, y: np.ndarray, w: np.ndarray,
               Ui_L: np.ndarray, Ui_F: np.ndarray) -> np.ndarray:
    """Return B_i = U_i^L + U_i^F + Σ_j w_ij[(1-y_j)x_j + y_j] for every row i.

    The per-column factor is built once and broadcast over the rows of ``w``;
    the weighted factors are then summed along axis 1.
    """
    per_site = (1.0 - y) * x + y
    row_totals = (w * per_site).sum(axis=1)
    return Ui_L + Ui_F + row_totals


def compute_Lhat(x: np.ndarray, y: np.ndarray, w: np.ndarray,
                 Ui_L: np.ndarray, Ui_F: np.ndarray, h: np.ndarray) -> float:
    r"""Evaluate the objective \hat{L}(x, y) = Σ_i h_i A_i / B_i.

    A_i and B_i are computed by ``compute_Ai`` and ``compute_Bi``; the
    per-row ratios A_i / B_i are weighted by ``h`` and summed.

    Returns
    -------
    float
        The h-weighted sum of the ratios A_i / B_i.
    """
    Ai = compute_Ai(x, y, w, Ui_L)
    Bi = compute_Bi(x, y, w, Ui_L, Ui_F)
    return float(np.dot(h, Ai / Bi))

# -----------------------------
#   Gradients
# -----------------------------

def grad_x(x: np.ndarray, y: np.ndarray, w: np.ndarray,
           Ui_L: np.ndarray, Ui_F: np.ndarray, h: np.ndarray) -> np.ndarray:
    r"""Gradient of \hat{L} with respect to x (ascent direction).

    With \hat{L} = Σ_i h_i A_i / B_i, the quotient rule gives
        ∂\hat{L}/∂x_j = Σ_i h_i (B_i ∂A_i/∂x_j - A_i ∂B_i/∂x_j) / B_i^2,
    where
        ∂A_i/∂x_j = w_ij[(1 + y_j) - 2 y_j x_j],
        ∂B_i/∂x_j = w_ij (1 - y_j).

    Returns
    -------
    np.ndarray
        One entry per column j of ``w`` (the row axis i is summed out).
    """
    Ai = compute_Ai(x, y, w, Ui_L)
    Bi = compute_Bi(x, y, w, Ui_L, Ui_F)

    dA_dx = w * ((1.0 + y) - 2.0 * y * x)     # shape (I, J)
    dB_dx = w * (1.0 - y)

    # Column views so the per-row scalars A_i, B_i broadcast across j.
    A_col = Ai[:, None]
    B_col = Bi[:, None]
    frac = (B_col * dA_dx - A_col * dB_dx) / (B_col ** 2)
    return (frac * h[:, None]).sum(axis=0)


def grad_y(x: np.ndarray, y: np.ndarray, w: np.ndarray,
           Ui_L: np.ndarray, Ui_F: np.ndarray, h: np.ndarray) -> np.ndarray:
    r"""Gradient of \hat{L} with respect to y (descent direction).

    With \hat{L} = Σ_i h_i A_i / B_i, the quotient rule gives
        ∂\hat{L}/∂y_j = Σ_i h_i (B_i ∂A_i/∂y_j - A_i ∂B_i/∂y_j) / B_i^2,
    where
        ∂A_i/∂y_j = w_ij (x_j - x_j^2),
        ∂B_i/∂y_j = w_ij (1 - x_j).

    Returns
    -------
    np.ndarray
        One entry per column j of ``w`` (the row axis i is summed out).
    """
    Ai = compute_Ai(x, y, w, Ui_L)
    Bi = compute_Bi(x, y, w, Ui_L, Ui_F)

    dA_dy = w * (-x ** 2 + x)
    dB_dy = w * (1.0 - x)

    # Column views so the per-row scalars A_i, B_i broadcast across j.
    A_col = Ai[:, None]
    B_col = Bi[:, None]
    frac = (B_col * dA_dy - A_col * dB_dy) / (B_col ** 2)
    return (frac * h[:, None]).sum(axis=0)

# -----------------------------
#   Projection
# -----------------------------

def project_cardinality(v: np.ndarray, k: int, mask: np.ndarray | None = None) -> np.ndarray:
    """Project a vector onto the set {0,1}^J with at most *k* ones."""
    v = np.clip(v, 0.0, 1.0)
    if mask is not None:
        v = v * (~mask)

    if k >= v.size:
        return (v > 0).astype(float)

    idx = np.argpartition(-v, k)[:k]
    out = np.zeros_like(v)
    out[idx] = 1.0
    return out

# -----------------------------
#   OGDA Solver (with history)
# -----------------------------

def ogda_solver(x0: np.ndarray, y0: np.ndarray, w: np.ndarray,
                Ui_L: np.ndarray, Ui_F: np.ndarray, h: np.ndarray,
                p: int, r: int, *, eta: float = 0.1, max_iter: int = 500,
                tol: float = 1e-6, return_history: bool = False,
                verbose: bool = True):
    r"""Optimistic Gradient Descent–Ascent (OGDA) solver that **stores full history**.

    x takes projected ascent steps and y projected descent steps on
    \hat{L}(x, y), both along the optimistic direction ``2*g_t - g_{t-1}``.
    After each step x is rounded with ``project_cardinality(., p)`` and y with
    ``project_cardinality(., r)``, with y masked away from sites where the new
    x is positive.

    Parameters
    ----------
    x0, y0 : np.ndarray
        Starting points (copied; the caller's arrays are not modified).
    w, Ui_L, Ui_F, h : np.ndarray
        Model data forwarded to the objective / gradient helpers.
    p, r : int
        Cardinality budgets for x and y respectively.
    eta : float
        Step size.
    max_iter : int
        Maximum number of iterations; ``max_iter <= 0`` returns copies of the
        starting point.
    tol : float
        Stop once both ``‖x_next - x_prev‖`` and ``‖y_next - y_prev‖`` fall
        below this value.
    return_history : bool
        If True, also return a dict with per-iteration arrays
        ``"objective"``, ``"dx"`` and ``"dy"``.
    verbose : bool
        Print per-iteration diagnostics (default True, the previous behaviour).

    Returns
    -------
    tuple
        ``(x, y)``, or ``(x, y, history)`` when *return_history* is True.
    """
    x_prev = x0.copy()
    y_prev = y0.copy()

    # Seed the "previous" gradients at the start point, so the first optimistic
    # step 2*g - g_prev reduces to a plain gradient step.
    gx_prev = grad_x(x_prev, y_prev, w, Ui_L, Ui_F, h)
    gy_prev = grad_y(x_prev, y_prev, w, Ui_L, Ui_F, h)

    # Defined up front so the result is well-defined even if the loop never runs.
    x_next, y_next = x_prev, y_prev

    obj_vals, dx_vals, dy_vals = [], [], []

    for it in range(max_iter):
        gx = grad_x(x_prev, y_prev, w, Ui_L, Ui_F, h)
        gy = grad_y(x_prev, y_prev, w, Ui_L, Ui_F, h)

        # Optimistic extrapolation: ascent in x, descent in y.
        x_tmp = x_prev + eta * (2.0 * gx - gx_prev)
        y_tmp = y_prev - eta * (2.0 * gy - gy_prev)

        x_next = project_cardinality(x_tmp, p)
        y_next = project_cardinality(y_tmp, r, mask=(x_next > 0))

        dx = np.linalg.norm(x_next - x_prev)
        dy = np.linalg.norm(y_next - y_prev)
        obj = compute_Lhat(x_next, y_next, w, Ui_L, Ui_F, h)

        obj_vals.append(obj)
        dx_vals.append(dx)
        dy_vals.append(dy)

        if verbose:
            print(f"\n[Iter {it}]")
            print(f"x_tmp: {x_tmp}")
            print(f"x_proj (x_next): {x_next}")
            print(f"dx: {dx}")
            print(f"gx: {gx}")
            print(f"y_tmp: {y_tmp}")
            print(f"y_proj (y_next): {y_next}")
            print(f"dy: {dy}")
            print(f"gy: {gy}")

        if max(dx, dy) < tol:
            break

        gx_prev, gy_prev = gx, gy
        x_prev, y_prev = x_next, y_next

    if return_history:
        history = {
            "objective": np.array(obj_vals),
            "dx": np.array(dx_vals),
            "dy": np.array(dy_vals),
        }
        return x_next, y_next, history
    return x_next, y_next

# -----------------------------
#   Plot helpers
# -----------------------------

def plot_history(history: dict[str, np.ndarray], *, logy: bool = True):
    """Plot convergence curves for OGDA solver history.

    Parameters
    ----------
    history : dict[str, np.ndarray]
        Mapping with keys ``"objective"``, ``"dx"`` and ``"dy"``, as returned
        by ``ogda_solver(..., return_history=True)``.
    logy : bool, default True
        Use a log scale for the dx/dy step-size axis.
    """
    # Delegate to plot_minmax_history so the two plotters share one
    # implementation. That helper also sets explicit colours: a twinx() axis
    # has its own colour cycle, so without them the dashed dx curve would
    # reuse the objective's blue.
    plot_minmax_history(history["objective"], history["dx"], history["dy"],
                        logy=logy)


def plot_minmax_history(L_vals: list | np.ndarray,
                        dx_vals: list | np.ndarray,
                        dy_vals: list | np.ndarray,
                        *, logy: bool = True):
    """Draw objective and step-size curves from a *minmax_solver* run.

    The objective uses the left axis; the x/y update norms share a twin right
    axis so both scales stay readable. A single combined legend is placed on
    the left axis.

    Parameters
    ----------
    L_vals : sequence of float
        Objective value at each iteration (hist_Lcont).
    dx_vals : sequence of float
        Norm of the x update at each iteration.
    dy_vals : sequence of float
        Norm of the y update at each iteration.
    logy : bool, default True
        If True, the step-size axis uses a logarithmic scale.
    """
    objective = np.asarray(L_vals)
    step_series = (
        (np.asarray(dx_vals), "--", "‖dx‖", "tab:orange"),
        (np.asarray(dy_vals), ":", "‖dy‖", "tab:green"),
    )
    iters = np.arange(1, len(objective) + 1)

    _, ax_obj = plt.subplots(figsize=(6, 4))
    ax_obj.plot(iters, objective, label="objective", linewidth=1.5, color="tab:blue")
    ax_obj.set_xlabel("iteration")
    ax_obj.set_ylabel("objective")

    ax_step = ax_obj.twinx()
    for values, style, name, colour in step_series:
        ax_step.plot(iters, values, linestyle=style, label=name, color=colour)
    ax_step.set_ylabel("step size")
    if logy:
        ax_step.set_yscale("log")

    # Merge both axes' entries into one legend (objective first).
    all_handles, all_names = [], []
    for axis in (ax_obj, ax_step):
        handles, names = axis.get_legend_handles_labels()
        all_handles += handles
        all_names += names
    ax_obj.legend(all_handles, all_names, loc="best")
    plt.tight_layout()
    plt.show()


# In [None]:
def plot_facility_selection(candidate_sites, demand_points, x_bin, y_bin):
    """Plot demand points, candidate sites, and the leader/follower selections.

    Parameters
    ----------
    candidate_sites : list of tuple(float, float)
        Candidate facility coordinates [(x1, y1), (x2, y2), ...].
    demand_points : list of tuple(float, float)
        Demand point coordinates [(x1, y1), (x2, y2), ...].
    x_bin : array-like of 0/1
        Sites selected by the leader; entry i == 1 marks candidate_sites[i]
        (drawn as a hollow blue circle).
    y_bin : array-like of 0/1
        Sites selected by the follower; entry i == 1 marks candidate_sites[i]
        (drawn as a filled red dot).
    """
    # Split the coordinate pairs into separate x / y lists.
    candidate_x = [pt[0] for pt in candidate_sites]
    candidate_y = [pt[1] for pt in candidate_sites]
    demand_x = [pt[0] for pt in demand_points]
    demand_y = [pt[1] for pt in demand_points]

    # Indices of the sites each player selected.
    leader_idx = [i for i, val in enumerate(x_bin) if val == 1]
    follower_idx = [i for i, val in enumerate(y_bin) if val == 1]

    plt.figure(figsize=(8, 8))

    # Demand points (black crosses).
    plt.scatter(demand_x, demand_y, color='black', marker='x', label='Demand Points')

    # All candidate sites (grey).
    plt.scatter(candidate_x, candidate_y, color='gray', label='Candidate Sites')

    # Draw each selection set with a single scatter call. This gives exactly
    # one legend entry per set, instead of re-scanning the legend labels once
    # per selected site.
    if leader_idx:
        # Leader selections: hollow blue circles.
        plt.scatter([candidate_x[i] for i in leader_idx],
                    [candidate_y[i] for i in leader_idx],
                    s=200, facecolors='none', edgecolors='blue', linewidths=2,
                    label='x_bin = 1')
    if follower_idx:
        # Follower selections: filled red dots.
        plt.scatter([candidate_x[i] for i in follower_idx],
                    [candidate_y[i] for i in follower_idx],
                    s=100, color='red',
                    label='y_bin = 1')

    # Axes, title, legend.
    plt.xlabel('x')
    plt.ylabel('y')
    plt.title('Demand Points and Candidate Sites with Selections')
    plt.grid(True)
    plt.legend()
    plt.axis('equal')
    plt.show()