* **Projected Gradient Descent (PGD)**: take a gradient step on $f(x)$, then **project** back to the affine set $\{x: Ax=b\}$.
* **Equality-Constrained Newton**: solve the **KKT Newton system** each iteration using autograd for $\nabla f$ and $\nabla^2 f$.

### Problem

We’ll use a simple quadratic with a linear equality:

$$
\min_x\; f(x)=\tfrac12\|x - v\|^2 \quad \text{s.t.} \quad Ax=b,
$$

with $v=(1,2)$, $A=\begin{bmatrix}1 & 1\end{bmatrix}$, $b=1$.
The exact solution is the Euclidean projection of $v$ onto the line $x_1+x_2=1$, i.e. $x^\star=(0,1)$.


In [2]:
import torch
torch.set_default_dtype(torch.float64)

# ---------- Toy instance ----------
v = torch.tensor([1.0, 2.0])
A = torch.tensor([[1.0, 1.0]])  # shape (m,n) with m=1, n=2
b = torch.tensor([1.0])

def f(x: torch.Tensor) -> torch.Tensor:
    # 0.5 * ||x - v||^2
    return 0.5 * torch.sum((x - v) ** 2)

# Closed-form projector onto {x : Ax = b}
# P(x) = x - A^T (A A^T)^{-1} (A x - b)
def project_affine(x: torch.Tensor, A: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Assumes A has full row rank (here m=1).
    At = A.T
    M = A @ At                        # (m,m)
    y = torch.linalg.solve(M, (A @ x - b))  # (m,)
    return x - At @ y                 # (n,)

# ---------- 1) Projected Gradient Descent ----------
def projected_gradient_descent(x0, A, b, lr=0.5, max_iter=100, tol=1e-10, verbose=True):
    x = x0.detach().clone().requires_grad_(True)
    hist = []
    for k in range(max_iter):
        # grad f(x)
        fx = f(x)
        (g,) = torch.autograd.grad(fx, x, create_graph=False)
        # gradient step
        x_new = (x - lr * g).detach()
        # projection step
        x_new = project_affine(x_new, A, b)
        # check convergence (stationarity on the feasible set via gradient norm)
        stationarity = g.norm().item()
        feas = torch.norm(A @ x_new - b).item()
        hist.append((k, fx.item(), stationarity, feas))
        if verbose and (k % 10 == 0 or k == max_iter - 1):
            print(f"[PGD] it={k:3d}  f={fx.item():.6e}  ||grad||={stationarity:.3e}  ||Ax-b||={feas:.3e}")
        if stationarity < tol and feas < tol:
            x = x_new.detach()
            break
        x = x_new.detach().requires_grad_(True)
    return x.detach(), hist

# ---------- 2) Equality-Constrained Newton (KKT Newton) ----------
# KKT system for the first-order conditions:
# r_d(x, λ) = ∇f(x) + A^T λ = 0           (dual residual)
# r_p(x)     = A x - b        = 0         (primal residual)
# Newton step solves:
# [ H   A^T ] [dx] = -[ r_d ]
# [ A    0  ] [dλ]   [ r_p ]
#
# where H = ∇^2 f(x) (using autograd.hessian for pedagogy).

def newton_kkt(x0, A, b, lam0=None, max_iter=20, backtrack=True, rho_merit=1.0, tol=1e-12, verbose=True):
    x = x0.detach().clone().requires_grad_(True)
    m, n = A.shape
    lam = torch.zeros(m) if lam0 is None else lam0.detach().clone()

    def grad_f(x):
        fx = f(x)
        (g,) = torch.autograd.grad(fx, x, create_graph=True)
        return g, fx

    def hess_f(x):
        # Pedagogical: use autograd to build H
        return torch.autograd.functional.hessian(lambda z: f(z), x)

    def residuals(x, lam):
        g, fx = grad_f(x)
        rd = g + A.T @ lam               # dual residual
        rp = A @ x - b                   # primal residual
        return rd, rp, fx

    def merit(x, lam):
        # A simple merit function to drive line search: phi = f(x) + (rho/2)||Ax-b||^2
        rd, rp, fx = residuals(x, lam)
        return fx + 0.5 * rho_merit * torch.sum(rp * rp)

    hist = []
    for k in range(max_iter):
        rd, rp, fx = residuals(x, lam)
        H = hess_f(x)

        # Build and solve the KKT linear system
        KKT = torch.zeros((n + m, n + m), dtype=x.dtype)
        KKT[:n, :n] = H
        KKT[:n, n:] = A.T
        KKT[n:, :n] = A

        rhs = torch.cat([-rd, -rp], dim=0)
        sol = torch.linalg.solve(KKT, rhs)
        dx = sol[:n]
        dlam = sol[n:]

        # Backtracking line search on merit function (optional; helps outside quadratics)
        t = 1.0
        phi0 = merit(x, lam)
        if backtrack:
            # Simple Armijo-like decrease on the merit function
            c, beta = 1e-4, 0.5
            # directional derivative of merit ~ grad_phi^T [dx; rp-part], approximated numerically is fine for teaching
            while True:
                x_try = (x + t * dx).detach().requires_grad_(True)
                lam_try = lam + t * dlam
                phi_try = merit(x_try, lam_try)
                if phi_try <= phi0 - c * t * (rd.norm()**2 + rp.norm()**2):
                    break
                t *= beta
                if t < 1e-12:
                    break

        # Update
        x = (x + t * dx).detach().requires_grad_(True)
        lam = (lam + t * dlam).detach()

        # Logging
        res_norm = torch.sqrt(rd.norm()**2 + rp.norm()**2).item()
        hist.append((k, fx.item(), rd.norm().item(), rp.norm().item(), res_norm, t))
        if verbose:
            print(f"[Newton] it={k:2d}  f={fx.item():.6e}  ||rd||={rd.norm().item():.3e}  ||rp||={rp.norm().item():.3e}  step={t:.2e}")
        if res_norm < tol:
            break

    return x.detach(), lam.detach(), hist

# ---------- Run both methods ----------
x0 = torch.tensor([2.5, -1.0])  # arbitrary infeasible start

print("Exact solution (by projection of v onto Ax=b):")
x_star = project_affine(v, A, b)
print("x* =", x_star.numpy(), "   f(x*) =", f(x_star).item())

print("\n--- Projected Gradient Descent ---")
x_pgd, pgd_hist = projected_gradient_descent(x0, A, b, lr=0.5, max_iter=200, tol=1e-12, verbose=True)
print("PGD solution:", x_pgd.numpy(), "  f =", f(x_pgd).item(), "  ||Ax-b|| =", torch.norm(A @ x_pgd - b).item())

print("\n--- Equality-Constrained Newton (KKT) ---")
x_nt, lam_nt, nt_hist = newton_kkt(x0, A, b, max_iter=10, verbose=True)
print("Newton solution:", x_nt.numpy(), "  f =", f(x_nt).item(), "  ||Ax-b|| =", torch.norm(A @ x_nt - b).item(), "  λ =", lam_nt.numpy())


Exact solution (by projection of v onto Ax=b):
x* = [0. 1.]    f(x*) = 1.0

--- Projected Gradient Descent ---
[PGD] it=  0  f=5.625000e+00  ||grad||=3.354e+00  ||Ax-b||=0.000e+00
[PGD] it= 10  f=1.000005e+00  ||grad||=1.414e+00  ||Ax-b||=0.000e+00
[PGD] it= 20  f=1.000000e+00  ||grad||=1.414e+00  ||Ax-b||=0.000e+00
[PGD] it= 30  f=1.000000e+00  ||grad||=1.414e+00  ||Ax-b||=0.000e+00
[PGD] it= 40  f=1.000000e+00  ||grad||=1.414e+00  ||Ax-b||=0.000e+00
[PGD] it= 50  f=1.000000e+00  ||grad||=1.414e+00  ||Ax-b||=0.000e+00
[PGD] it= 60  f=1.000000e+00  ||grad||=1.414e+00  ||Ax-b||=0.000e+00
[PGD] it= 70  f=1.000000e+00  ||grad||=1.414e+00  ||Ax-b||=0.000e+00
[PGD] it= 80  f=1.000000e+00  ||grad||=1.414e+00  ||Ax-b||=0.000e+00
[PGD] it= 90  f=1.000000e+00  ||grad||=1.414e+00  ||Ax-b||=0.000e+00
[PGD] it=100  f=1.000000e+00  ||grad||=1.414e+00  ||Ax-b||=0.000e+00
[PGD] it=110  f=1.000000e+00  ||grad||=1.414e+00  ||Ax-b||=0.000e+00
[PGD] it=120  f=1.000000e+00  ||grad||=1.414e+00  ||Ax-b||=0.