In [None]:
! pip install jax jaxlib

# Implicit curves and surfaces

In this activity, we'll investigate methods for finding the nearest point on an [implicit curve](https://en.wikipedia.org/wiki/Implicit_curve) or implicit surface.
Implicit curves are level sets of functions, and can describe complicated shapes. For example, the zero level curve of $f(x, y) = \cos(x+y) - \cos(xy) + 1/2$, plotted below, is the interface between red and blue.

In [None]:
%matplotlib inline
import jax.numpy as np
import jax
import matplotlib.pyplot as plt
import matplotlib.cm as cm
plt.style.use('ggplot')

x1 = np.linspace(-np.pi, np.pi, 30)
X, Y = np.meshgrid(x1, x1)
F = np.cos(X+Y) - np.cos(X*Y) + .5
plt.contourf(X, Y, F, cmap=cm.coolwarm, levels=[-1, -.5, -0.001, 0.001, .5, 1])
plt.colorbar();

Suppose we are given a point $\mathbf x$ and wish to find the closest point $\mathbf y$ on the zero level set, i.e., a point with $f(\mathbf y) = 0$.
This is a constrained optimization problem
$$ \mathbf y_* = \operatorname{argmin}_{\mathbf y} \lVert \mathbf y - \mathbf x \rVert^2, \quad \text{subject to}\, f(\mathbf y) = 0 . $$

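We'll use two strategies to solve this problem. First, we will reformulate the minimization to ask that $\mathbf g = \nabla f(\mathbf y)$ be parallel to $\mathbf y - \mathbf x$: at the closest point, the displacement $\mathbf y - \mathbf x$ must be normal to the level set, and $\nabla f$ is normal to the level set.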

Suppose we consider $\mathbf g$ as a column vector ($2\times 1$ for an implicit curve; we'll extend to $3\times 1$ for an implicit surface later) and compute its full QR factorization,

$$ \begin{bmatrix} \mathbf q_0 & \mathbf q_1 \end{bmatrix} \begin{bmatrix} r_{00} \\ 0 \end{bmatrix} = \mathbf g . $$
Then $\mathbf q_0$ is parallel to $\mathbf g$, and $\mathbf q_1$ is orthogonal to it. As such, our minimization problem can be reformulated as solving the equations

\begin{align}
  f(\mathbf y) &= 0 \\
  \mathbf q_1^T (\mathbf y - \mathbf x) &= 0.
\end{align}
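To see that these two equations really characterize the closest point, here is a small check on a case with a known answer. It is a sketch under assumptions: the unit circle $f(\mathbf y) = \lVert \mathbf y \rVert^2 - 1$ and the point `x` are hypothetical choices (not the curve above), for which the closest point is $\mathbf x / \lVert \mathbf x \rVert$. It uses JAX's built-in `np.linalg.qr` with `mode='complete'` to get $\mathbf q_1$.

```python
import jax
import jax.numpy as np

# Hypothetical test curve: the unit circle f(y) = |y|^2 - 1.
# Its closest point to x is known in closed form: y* = x / |x|.
def f_circle(y):
    return y @ y - 1.0

x = np.array([1.5, 2.0])
y_star = x / np.linalg.norm(x)          # exact closest point on the circle

g = jax.grad(f_circle)(y_star)          # gradient at y*, equal to 2 y*
# Full ("complete") QR of the 2x1 column g: Q is 2x2, R is 2x1
Q, R = np.linalg.qr(g.reshape(2, 1), mode='complete')
q1 = Q[:, 1]                            # unit vector orthogonal to g

print(f_circle(y_star))                 # first equation, approximately 0
print(q1 @ (y_star - x))                # second equation, approximately 0
```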

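Recall that for a single column, Householder QR expresses $Q$ as one reflector $I - 2 \mathbf v \mathbf v^T$, where $\mathbf v$ is a unit vector.
We choose $\mathbf v \propto \mathbf g + \operatorname{sign}(g_0) \lVert \mathbf g \rVert \mathbf e_0$, so the reflector sends $\mathbf g$ to $-\operatorname{sign}(g_0) \lVert \mathbf g \rVert \mathbf e_0$; taking the sign to match $g_0$ avoids cancellation when forming $\mathbf v$.
Because this $Q$ is symmetric, $\mathbf q_1^T (\mathbf y - \mathbf x)$ is entry 1 of $Q(\mathbf y - \mathbf x)$, which we can compute without forming $Q$.
Here's the code for one column.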

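Note that we are using `jax.numpy` (imported as `np`), which doesn't behave exactly like `numpy`: for example, JAX arrays are immutable (use `a.at[i].set(value)` instead of `a[i] = value`), and floating-point computations default to 32-bit precision.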

In [None]:
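def householder1(g):
    """Unit Householder vector v such that (I - 2 v v^T) g = -sign(g[0]) |g| e_0."""
    n = len(g)
    # Use sign(0) = +1: np.sign(0) returns 0, which would break the reflector when g[0] == 0
    s = np.where(g[0] >= 0, 1., -1.)
    v = g + s * np.linalg.norm(g) * np.eye(n, 1).flatten()
    return v / np.linalg.norm(v)

householder1(np.array([1., 2, 3]))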
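A quick way to check a Householder vector is to form the reflector explicitly and apply it to $\mathbf g$. The sketch below does that for two vectors, including one with $g_0 = 0$, where using `np.sign` directly would return 0 and the reflector would not map $\mathbf g$ onto $\mathbf e_0$. The helper `reflector` is hypothetical and only for illustration; the notebook applies $Q$ implicitly.

```python
import jax.numpy as np

def householder1(g):
    # sign(0) treated as +1 so that g[0] == 0 is handled correctly
    s = np.where(g[0] >= 0, 1.0, -1.0)
    e0 = np.zeros_like(g).at[0].set(1.0)
    v = g + s * np.linalg.norm(g) * e0
    return v / np.linalg.norm(v)

def reflector(v):
    # Hypothetical helper: explicit Q = I - 2 v v^T (the notebook never forms Q)
    return np.eye(len(v)) - 2.0 * np.outer(v, v)

for g in [np.array([1., 2., 3.]), np.array([0., 3., 4.])]:
    Q = reflector(householder1(g))
    print(Q @ g)                                        # approximately [-sign(g0)*|g|, 0, 0]
    print(np.allclose(Q @ Q.T, np.eye(3), atol=1e-5))   # Q is orthogonal
```

Since $Q$ is symmetric and orthogonal, its columns $\mathbf q_1, \mathbf q_2$ (for a surface) are the tangent directions used in the residual.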

In [None]:
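class icurve:
    # The implicit curve f(y) = 0 plotted above
    def f(self, y):
        return np.cos(y[0]+y[1]) - np.cos(y[0]*y[1]) + .5
    # Derivatives are computed by automatic differentiation
    def grad(self, y):
        return jax.grad(self.f)(y)
    def hessian(self, y):
        return jax.hessian(self.f)(y)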

ic = icurve()
ic.grad(np.array([1., 2]))

In [None]:
ic.hessian(np.array([1., 3]))
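As a sanity check on automatic differentiation, the gradient can be derived by hand with the chain rule: $\partial f/\partial y_0 = -\sin(y_0+y_1) + y_1 \sin(y_0 y_1)$ and $\partial f/\partial y_1 = -\sin(y_0+y_1) + y_0 \sin(y_0 y_1)$. The sketch below compares this with `jax.grad` and checks that `jax.hessian` returns a symmetric matrix; `grad_by_hand` is a hypothetical name for illustration.

```python
import jax
import jax.numpy as np

def f(y):
    # Same implicit function as icurve.f
    return np.cos(y[0] + y[1]) - np.cos(y[0] * y[1]) + .5

def grad_by_hand(y):
    # d/dy0 cos(y0+y1) = -sin(y0+y1);  d/dy0 [-cos(y0*y1)] = y1 * sin(y0*y1)  (symmetric for y1)
    s = np.sin(y[0] + y[1])
    t = np.sin(y[0] * y[1])
    return np.array([-s + y[1] * t, -s + y[0] * t])

y = np.array([1., 2.])
print(jax.grad(f)(y), grad_by_hand(y))   # should agree to rounding

H = jax.hessian(f)(np.array([1., 3.]))
print(np.allclose(H, H.T))               # mixed partials agree
```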

In [None]:
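class nearest:
    def __init__(self, target, x):
        self.target = target # Curve or surface
        self.x = x           # Point to project onto the target (fixed, not an initial guess)
        
    def residual(self, y):
        f = self.target.f(y)           # zero when y lies on the level set
        g = self.target.grad(y)        # normal direction at y
        v = householder1(g)
        d = y - self.x
        Qd = d - 2 * v * (v @ d)       # Q d with Q = I - 2 v v^T (symmetric, so Q^T d = Q d)
        return np.hstack([f, Qd[1:]])  # Qd[1:] holds the tangential components q_i^T d, i >= 1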
    
#x0 = np.array([1.2, 1.5])
x0 = np.array([1.5, 0.])
n = nearest(ic, x0)
n.residual(np.array([1., 0.]))

* The first component of the residual measures satisfaction of the surface criterion $f(\mathbf y) = 0$.
* The second is the tangential component of the displacement, $\mathbf q_1^T (\mathbf y - \mathbf x)$.

So the point $\mathbf y = [1, 0]$ is close to the implicit curve, but the gradient of the implicit function points in a rather different direction from $\mathbf y - \mathbf x$, so the tangential residual is not small.
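The claim can be checked by hand. At $\mathbf y = [1, 0]$ we get $f(\mathbf y) = \cos 1 - 1 + 0.5 \approx 0.0403$ and $\nabla f(\mathbf y) = -\sin(1)\,[1, 1]$, while $\mathbf y - \mathbf x = [-0.5, 0]$, so the two directions differ by $45^\circ$. The sketch below computes these quantities directly; its magnitude of the tangential part, $0.5 \sin 45^\circ \approx 0.354$, should match the magnitude of the second residual component above.

```python
import jax
import jax.numpy as np

def f(y):
    # Same implicit function as icurve.f
    return np.cos(y[0] + y[1]) - np.cos(y[0] * y[1]) + .5

x = np.array([1.5, 0.])      # the point being projected (x0 above)
y = np.array([1., 0.])       # the trial point

g = jax.grad(f)(y)           # equals -sin(1) * [1, 1] at this y
d = y - x                    # equals [-0.5, 0]
ghat = g / np.linalg.norm(g)

cos_angle = np.abs(ghat @ d) / np.linalg.norm(d)
angle_deg = np.degrees(np.arccos(cos_angle))    # angle between the lines spanned by g and d
d_tangential = d - (d @ ghat) * ghat            # part of y - x orthogonal to the gradient

print(f(y))                          # approximately 0.0403
print(angle_deg)                     # approximately 45
print(np.linalg.norm(d_tangential))  # approximately 0.354
```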

In [None]:
def newton(model, x, jit=False):
    """Newton's method for model.residual(y) = 0, starting from the initial guess x."""
    y = x.copy()
    # F and J are callable functions to compute the residual and Jacobian
    F = model.residual
    J = jax.jacobian(model.residual)
    if jit:
        F = jax.jit(F)
        J = jax.jit(J)
    for i in range(10):  # at most 10 Newton steps
        resid = F(y)
        norm = np.linalg.norm(resid)
        print(f'{i} y: {y}')
        print(f'{i} residual: {resid} ({norm:.2e})')
        if norm < 1e-6:
            break
        y -= np.linalg.solve(J(y), resid)  # Newton update: y <- y - J(y)^{-1} F(y)
    return y

newton(n, np.array([1., 0]))

* Convergence is fast, though not monotone (you can experiment with different initial guesses).
* Execution is slow (there is a perceptible lag between iterations). We'll make it faster by turning on just-in-time (JIT) compilation.
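"Fast" convergence here means Newton's quadratic rate: near the solution, the residual norm is roughly squared at each step. One way to see this is to record the residual norms on a problem with a known answer. The sketch below assumes a hypothetical test case, the unit circle with `x = [1.5, 2.0]` (closest point `[0.6, 0.8]`), and reuses the same residual structure as `nearest.residual`; it is not the notebook's curve.

```python
import jax
import jax.numpy as np

def f_circle(y):
    # Hypothetical test curve: the unit circle |y|^2 - 1 = 0
    return y @ y - 1.0

def householder1(g):
    s = np.where(g[0] >= 0, 1.0, -1.0)       # treat sign(0) as +1
    v = g + s * np.linalg.norm(g) * np.eye(len(g), 1).flatten()
    return v / np.linalg.norm(v)

x = np.array([1.5, 2.0])                  # closest point on the circle is x/|x| = [0.6, 0.8]

def residual(y):
    # Same structure as nearest.residual: [f(y), q1^T (y - x)]
    g = jax.grad(f_circle)(y)
    v = householder1(g)
    d = y - x
    Qd = d - 2 * v * (v @ d)
    return np.hstack([f_circle(y), Qd[1:]])

F = jax.jit(residual)
J = jax.jit(jax.jacobian(residual))
y = np.array([0.8, 0.6])                  # initial guess on the circle, off the answer
history = []
for i in range(10):
    r = F(y)
    history.append(float(np.linalg.norm(r)))
    if history[-1] < 1e-6:
        break
    y = y - np.linalg.solve(J(y), r)

for i, h in enumerate(history):
    print(f'{i}: {h:.1e}')                # number of correct digits roughly doubles per step
print(y)
```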

In [None]:
newton(n, np.array([1., 0.]), jit=True)

* The iterates are the same, but execution is fast after the first iteration, which pays the one-time compilation cost.

## Alternate formulation using Lagrange multipliers

A powerful technique for handling constraints is to enforce the constraints using Lagrange multipliers.
To this end, we write a function

$$ L(\mathbf y, \lambda) = \frac 1 2 \lVert \mathbf y - \mathbf x \rVert^2 + \lambda f(\mathbf y) $$

and seek a point for which $\nabla_{\mathbf y,\lambda} L = 0$.
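This stationary point is generally *not* a minimum of $L$ (it is a saddle point), but setting $\nabla L = 0$ imposes several equations simultaneously. The gradient with respect to $\mathbf y$ gives $\mathbf y - \mathbf x + \lambda \nabla f(\mathbf y) = 0$, i.e., $\mathbf y - \mathbf x$ is parallel to the gradient, and the derivative with respect to $\lambda$ is the implicit surface equation $f(\mathbf y) = 0$.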
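Because $L$ is an ordinary scalar function, `jax.grad` can produce the residual $\nabla_{\mathbf y,\lambda} L$ directly, and `jax.hessian` can confirm the saddle-point claim. The sketch below assumes a hypothetical test problem, the unit circle $f(\mathbf y) = \lVert \mathbf y \rVert^2 - 1$, for which the stationary point is known: $\mathbf y_* = \mathbf x/\lVert \mathbf x \rVert$ and $\lambda_* = (\lVert \mathbf x \rVert - 1)/2$ (substitute $\mathbf y_*$ into $\mathbf y - \mathbf x + 2\lambda \mathbf y = 0$).

```python
import jax
import jax.numpy as np

def f(y):
    # Hypothetical constraint: the unit circle
    return y @ y - 1.0

x = np.array([1.5, 2.0])

def L(ylam):
    # Lagrangian L(y, lam) = 1/2 |y - x|^2 + lam * f(y), packed as ylam = [y, lam]
    y, lam = ylam[:-1], ylam[-1]
    return 0.5 * np.sum((y - x) ** 2) + lam * f(y)

def residual_by_hand(ylam):
    # Same formula as nearest_lagrange.residual: [y - x + lam * grad f(y), f(y)]
    y, lam = ylam[:-1], ylam[-1]
    return np.hstack([y - x + lam * jax.grad(f)(y), f(y)])

ylam = np.array([1.0, 0.0, 0.3])
print(jax.grad(L)(ylam), residual_by_hand(ylam))   # identical up to rounding

# Known stationary point for the circle: y* = x/|x|, lam* = (|x| - 1)/2
ylam_star = np.hstack([x / np.linalg.norm(x), (np.linalg.norm(x) - 1) / 2])
print(jax.grad(L)(ylam_star))                       # approximately 0

# Eigenvalues of the Hessian of L have mixed signs: a saddle, not a minimum
eig = np.linalg.eigvalsh(jax.hessian(L)(ylam_star))
print(eig)
```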

In [None]:
class nearest_lagrange:
    def __init__(self, target, x):
        self.target = target # Curve or surface
        self.x = x           # Point to project onto the target (fixed, not an initial guess)
        
    def residual(self, ylam):
        y = ylam[:-1]
        lam = ylam[-1]
        f = self.target.f(y)
        g = self.target.grad(y)
        return np.hstack([y - self.x + lam * g, f])
    
nl = nearest_lagrange(ic, x0)
newton(nl, np.array([1., 0., 0])) # Now we have to specify an initial guess for lambda

* There are now 3 residual components.
  * The first two represent satisfaction of $\mathbf y - \mathbf x + \lambda \mathbf g = 0$ (i.e., that $\mathbf y - \mathbf x = -\lambda \mathbf g$ is parallel to $\mathbf g$, with $\lambda$ setting the relative length).
  * The last component is satisfaction of $f(\mathbf y) = 0$ (i.e., that $\mathbf y$ is on the implicit curve).
* Convergence is fast in the terminal phase.
* We converge to the same solution (in this case; this is not guaranteed when there are multiple solutions).
* Convergence is much slower in the initial phase, with $\lambda$ becoming huge (thereby making the first part of the residual huge) before converging.

The erratic initial convergence can be alleviated with line searches and other *globalization* methods. It also reflects a common observation: Lagrange multipliers are convenient for modeling (we didn't need to know about Householder QR, or to differentiate through it using `jax`), but the resulting systems are more challenging for optimizers and algebraic solvers.
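A common globalization is a backtracking (Armijo) line search on the merit function $\phi = \tfrac12 \lVert F \rVert^2$: take the Newton direction $\mathbf p = -J^{-1} F$, but halve the step length until $\phi$ decreases enough. Along $\mathbf p$ the slope of $\phi$ at zero is $-2\phi$, which gives the acceptance test used below. This is a minimal sketch, not the notebook's method; the name `newton_linesearch`, the constant `c = 1e-4`, and the unit-circle test problem are hypothetical choices.

```python
import jax
import jax.numpy as np

def f(y):
    # Hypothetical constraint for illustration: the unit circle
    return y @ y - 1.0

x = np.array([1.5, 2.0])

def residual(ylam):
    # Lagrange-multiplier residual, same form as nearest_lagrange.residual
    y, lam = ylam[:-1], ylam[-1]
    return np.hstack([y - x + lam * jax.grad(f)(y), f(y)])

def newton_linesearch(F, z, maxit=50, tol=1e-6, c=1e-4):
    """Newton's method with backtracking (Armijo) line search on phi = 1/2 |F|^2."""
    J = jax.jacobian(F)
    for _ in range(maxit):
        r = F(z)
        phi = 0.5 * (r @ r)
        if np.sqrt(2 * phi) < tol:
            break
        p = -np.linalg.solve(J(z), r)   # full Newton direction
        alpha = 1.0
        # Armijo test: phi(z + alpha p) <= phi(z) + c * alpha * (-2 phi)
        while alpha > 1e-8:
            r_new = F(z + alpha * p)
            if 0.5 * (r_new @ r_new) <= (1 - 2 * c * alpha) * phi:
                break
            alpha /= 2
        z = z + alpha * p
    return z

z = newton_linesearch(residual, np.array([1., 0., 0.]))
print(z)   # [y0, y1, lambda]
```

On this problem the full first step increases $\lVert F \rVert$, so the line search halves it; afterwards full Newton steps are accepted and the terminal convergence is quadratic.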

## Implicit Surfaces

Write a model for an [implicit surface](https://en.wikipedia.org/wiki/Implicit_surface) and experiment with convergence from different initial guesses.  You could use the [Schwarz P](https://en.wikipedia.org/wiki/Schwarz_minimal_surface#Schwarz_P_(%22Primitive%22)) function, for example.
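For orientation before the exercise, here is the same QR-based projection on a different, hypothetical surface whose closest point is known in closed form: a torus with major radius `R = 2` and minor radius `r = 0.5` about the $z$-axis. In 3D the gradient has two orthogonal tangent directions, so the residual has `Qd[1:]` with two entries. The torus, its radii, and the test point are assumptions for illustration, not the Schwarz P surface suggested above.

```python
import jax
import jax.numpy as np

R, r = 2.0, 0.5   # hypothetical torus radii: major R, minor r

def f_torus(y):
    # Implicit torus about the z-axis: (sqrt(y0^2 + y1^2) - R)^2 + y2^2 - r^2 = 0
    rho = np.sqrt(y[0] ** 2 + y[1] ** 2)
    return (rho - R) ** 2 + y[2] ** 2 - r ** 2

def householder1(g):
    s = np.where(g[0] >= 0, 1.0, -1.0)       # sign(0) treated as +1
    v = g + s * np.linalg.norm(g) * np.eye(len(g), 1).flatten()
    return v / np.linalg.norm(v)

def residual(y, x):
    # [f(y), tangential components of y - x]; in 3D there are two tangential rows
    g = jax.grad(f_torus)(y)
    v = householder1(g)
    d = y - x
    Qd = d - 2 * v * (v @ d)
    return np.hstack([f_torus(y), Qd[1:]])

def newton(x, y, maxit=20, tol=1e-6):
    F = lambda y: residual(y, x)
    J = jax.jacobian(F)
    for _ in range(maxit):
        res = F(y)
        if np.linalg.norm(res) < tol:
            break
        y = y - np.linalg.solve(J(y), res)
    return y

x = np.array([3.0, 0.0, 1.0])
y = newton(x, x)      # initial guess equals the given point, as in the exercise cell

# Closed form: closest point c on the core circle, then step a distance r toward x
c = R * np.array([x[0], x[1], 0.0]) / np.linalg.norm(x[:2])
y_exact = c + r * (x - c) / np.linalg.norm(x - c)
print(y, y_exact)
```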

We used the projection technique described in this notebook to 

In [None]:
class isurface:
    def f(self, y):
        """Return a function of your choice"""
        # YOUR CODE HERE
        raise NotImplementedError()
    def grad(self, y):
        return jax.grad(self.f)(y)
    
isurf = isurface()
x0 = np.array([2., 1, 1])
nl = nearest(isurf, x0)
newton(nl, x0, jit=True) # Initial guess equals starting point

In [None]:
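y = newton(nl, x0, jit=True)
# newton returns an array; check explicitly that it satisfies the nearest-point equations
assert np.linalg.norm(nl.residual(y)) < 1e-5
print('Tests pass')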