In [None]:
import numpy as np
import matplotlib.pyplot as plt

from numpy import array, cos, sin, tan, arctan, exp, log, pi

# This is the key function for solving systems of equations numerically
from scipy.optimize import root 

# Systems of Equations

In the course of solving optimization problems, one often needs to solve systems of equations, and sometimes large ones. One way to do this numerically is the `root` function from `scipy.optimize` (imported above). 

Let's peek at its [documentation](https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.root.html), but don't get bogged down in the details.

In [None]:
help(root)

The important thing is that `root` tries to solve the equation $$f(x) = 0$$ numerically, starting from an initial "guess" $x_0$. 

To do this, the `root` function needs two inputs: a function (or _callable_) `fun` and a point `x0` in its domain. It then calls a solver (you can investigate and customize these, but it is quite the rabbit hole) and reports back the results. 
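For instance, the solver is chosen with the `method` argument. Here is a minimal sketch, assuming a recent SciPy in which `'hybr'` (the default) and `'lm'` are both accepted method names, that solves the made-up equation $x^3 = 2$ both ways:

```python
import numpy as np
from scipy.optimize import root

def f(x):
    # x**3 = 2 rewritten as x**3 - 2 = 0 (illustrative equation, not from the notebook)
    return x**3 - 2

# Default solver: method='hybr' (a modified Powell hybrid method from MINPACK)
sol_hybr = root(f, 1.0)
# method='lm' asks for a Levenberg-Marquardt solver instead
sol_lm = root(f, 1.0, method='lm')

for name, sol in [("hybr", sol_hybr), ("lm", sol_lm)]:
    print(name, sol.success, sol.x)
```

Both should land near $2^{1/3} \approx 1.26$; for most problems in this notebook the default is fine.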

## First example

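Let's start simply and solve $$\cos x = x.$$ We know this has a solution for some $0 < x < \pi/2$ (the continuous function $\cos x - x$ is positive at $0$ and negative at $\pi/2$, so the Intermediate Value Theorem applies), but we don't have a formula for it.  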
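If you already have a bracket $[a, b]$ on which the function changes sign, a bracketing solver is an alternative to `root`. A minimal sketch using `scipy.optimize.brentq`, which this notebook does not otherwise use, on $\cos x - x$ over $[0, \pi/2]$:

```python
import numpy as np
from scipy.optimize import brentq

def f(x):
    return np.cos(x) - x

a, b = 0.0, np.pi/2
# f(a) = 1 > 0 and f(b) = -pi/2 < 0: a sign change, so a zero lies in between
print(f(a), f(b))
# brentq requires exactly such a bracket; it needs no initial guess
x_star = brentq(f, a, b)
print(x_star)
```

Unlike `root`, `brentq` only handles a single scalar equation, which is why `root` is the workhorse for the systems later on.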

In [None]:
# plot cos(x) and x to see where they intersect. 
x = np.linspace(0,pi/2,31)
plt.plot(x,x,x,cos(x));

Now, there is a simple trick for turning the solution of an equation into the zero of a function: good old subtraction. The equation $\cos x = x$ holds exactly when $f(x) = \cos x - x = 0$. 

In [None]:
def f(x):
    return cos(x) - x

In [None]:
x = np.linspace(0,pi/2,31)
plt.plot(x,f(x),x,0*x);

In [None]:
# Now just invoke root and give it a guess. 
root(f,.5)

That's a lot of information, but the most important bit is the solution `x` and the `success` flag. We can capture all this in an object and then just use the pieces we need. 

In [None]:
sol = root(f,.5)
print(sol.message)
x, = sol.x # the trailing comma unpacks the one-element array that root returns as its solution. 

In [None]:
# Note the solution is a numeric approximation and thus not exact. 
print(x,cos(x))
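To see how far from exact the answer is, look at the residual. In this short sketch, besides `x` the result object carries `fun`, the value of the function at the returned point, which should be close to zero:

```python
import numpy as np
from scipy.optimize import root

def f(x):
    return np.cos(x) - x

sol = root(f, 0.5)
# sol.fun is f evaluated at sol.x, i.e. the residual
print(sol.fun)
# .item() is another way to pull the single number out of the length-1 array sol.x
x = sol.x.item()
# Trust the answer only if the solver reports success and cos(x) is close to x
print(sol.success and np.isclose(np.cos(x), x))
```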

## Exercise

Find all the solutions to $$e^x - 2x = 1.$$ 

In [None]:
xs = np.linspace(-1,2,50)
plt.plot(xs,exp(xs) - 2*xs,label="$e^x-2x$")
plt.plot(xs,np.ones_like(xs),label="$1$");
plt.legend();

In [None]:
def f(x):
    return # insert formula here

### Warning

Bad initial guesses can confuse the solver. Why does this code fail to find $\sqrt{2}$? Can you fix it?

In [None]:
def g(x): return x**2 - 2
root(g,0)

# Systems

The solution to any system of equations can be expressed as a root-finding problem by using vectors. `root` can take a vector-valued function as its callable. 

**Important:** When dealing with several variables, the function given to `root` must take a single array as its argument, not several separate variables. 
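If your equations are already written as a function of several separate arguments, a small wrapper that unpacks the array fixes this. A sketch with a hypothetical two-argument function `residuals` (not part of the notebook) encoding $x + y = 3$, $x - y = 1$:

```python
import numpy as np
from scipy.optimize import root

# Hypothetical system written with two separate arguments
def residuals(x, y):
    return [x + y - 3, x - y - 1]

# root passes one array v; the lambda unpacks it into the two arguments
sol = root(lambda v: residuals(*v), [0, 0])
print(sol.x)
```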

## First example

Find two numbers that sum to 51 where one is twice the other. 

That is, we solve the system 

$$x + y = 51 $$
$$x-2y =0 $$


In [None]:
def F(v):
    x,y = v
    return [x+y - 51,x-2*y]

sol = root(F,[0,0])
x,y = sol.x
print(x,y)
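Because this particular system is linear, it can also be cross-checked exactly with `numpy.linalg.solve`, which takes the coefficient matrix and the right-hand side; `root` earns its keep on the nonlinear systems below. A short sketch:

```python
import numpy as np

# Coefficients and right-hand side of  x + y = 51,  x - 2y = 0
A = np.array([[1.0,  1.0],
              [1.0, -2.0]])
b = np.array([51.0, 0.0])
exact = np.linalg.solve(A, b)
print(exact)   # x = 34, y = 17
```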

## Example from class

<img src="halfdisk4.png" width="50%" style="float: right;">
We were trying to maximize and minimize the function $$u(x,y) = x^2 - 6x + 4y^2 - 8y$$ on the upper half-disk of radius $4$ centered at the origin, that is, $x^2 + y^2 \le 16$ with $y \ge 0$. 

Let's do the whole problem. 

In [None]:
def u(x,y):
    return x**2 - 6*x + 4*y**2 - 8*y

In [None]:
# Check the corners, store the candidates in a dictionary
candidates = {(4,0): u(4,0), (-4,0): u(-4,0)}
candidates

Find critical points $$\nabla u = \begin{bmatrix} 2x-6 \\ 8y-8 \end{bmatrix} = \mathbf 0$$

In [None]:
# Play dumb: pretend we cannot solve this by hand

def F(v):
    x,y = v # separate the individual input variables.
    return [2*x - 6, 8*y - 8]

root(F, (1,1))

In [None]:
candidates[(3,1)] = u(3,1)
candidates

For the bottom, we use the simple constraint $g(x,y) = y = 0$ and solve the system 

$$\nabla u = \begin{bmatrix} 2x-6 \\ 8y-8 \end{bmatrix} = \lambda \begin{bmatrix} 0 \\ 1 \end{bmatrix}$$

This is equally trivial to solve, but let's use `root`. Note we have 3 variables now and so we pass in a function that takes in and returns $3$-vectors like so:

In [None]:
def G(v):
    x,y,lam = v #unpack 3 variables
    return [y,2*x-6,8*y - 8 - lam ] # note we move all terms over; the first entry is the constraint

In [None]:
root(G,(1,1,1))

In case you are unfamiliar with the notation, the middle entry of `x` (the value of $y$) prints as something like `-4.03896783e-28`, which means $-4.03896783 \times 10^{-28}$ (the exact value depends on your setup). It is off from $0$ because of rounding errors.
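To confirm the numbers without squinting at exponents, this short sketch compares the result with the hand solution ($y = 0$ from the constraint, $x = 3$, and $\lambda = 8 \cdot 0 - 8 = -8$) and prints the array with tiny entries shown as zero via `np.array2string(..., suppress_small=True)`:

```python
import numpy as np
from scipy.optimize import root

def G(v):
    x, y, lam = v
    return [y, 2*x - 6, 8*y - 8 - lam]

sol = root(G, (1, 1, 1))
# Equal to the hand solution up to rounding error?
print(np.allclose(sol.x, [3, 0, -8]))
# Display only: very small entries are printed as 0
print(np.array2string(sol.x, suppress_small=True))
```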

In [None]:
# add to what we've got. 
candidates[(3,0)] = u(3,0)
candidates

Finally, we turn to the top, where $g(x,y) = x^2 + y^2 = 16$, and solve the system 

$$\nabla u = \begin{bmatrix} 2x-6 \\ 8y-8 \end{bmatrix} = \lambda \begin{bmatrix} 2x \\ 2y \end{bmatrix}$$

Not so simple a system, so we use `root`.

In [None]:
def H(v):
    x,y,lam = v 
    return [x**2 + y**2 - 16,
            2*x-6 - lam*2*x,
            8*y - 8 - lam*2*y ] 
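Optionally, `root` can be handed the Jacobian matrix of partial derivatives through its `jac` argument instead of estimating it by finite differences. A sketch for `H`, with the derivatives worked out by hand (the helper name `H_jac` is ours, not the notebook's):

```python
import numpy as np
from scipy.optimize import root

def H(v):
    x, y, lam = v
    return [x**2 + y**2 - 16,
            2*x - 6 - lam*2*x,
            8*y - 8 - lam*2*y]

def H_jac(v):
    # Row i holds the partial derivatives of entry i of H w.r.t. (x, y, lam)
    x, y, lam = v
    return np.array([[2*x,        2*y,        0.0],
                     [2 - 2*lam,  0.0,       -2*x],
                     [0.0,        8 - 2*lam, -2*y]])

sol = root(H, (-1, 3, 1), jac=H_jac)
print(sol.message, sol.x)
```

For a system this small the gain is modest; supplying `jac` matters more for large or expensive systems.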

In [None]:
sol = root(H,(-1,3,1))
print(sol.message)

In [None]:
# Hooray
x,y,lam = sol.x
print(x,y,u(x,y))
candidates[(x,y)] = u(x,y)

In [None]:
candidates

**Bad news:** we are not done. There are more solutions; the picture above suggests three in all. We go hunting. 
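Rather than guessing starting points one at a time, a multi-start search is one way to hunt: launch `root` from many points spread along the upper semicircle and keep the distinct solutions with $y > 0$. This is a sketch with no guarantee that every root is found, since (as the bonus discussion at the end shows) which root the solver lands on depends delicately on the starting point:

```python
import numpy as np
from scipy.optimize import root

def H(v):
    x, y, lam = v
    return [x**2 + y**2 - 16,
            2*x - 6 - lam*2*x,
            8*y - 8 - lam*2*y]

found = []
# Starting points along the upper semicircle, always with lambda_0 = 1
for t in np.linspace(0, np.pi, 37):
    sol = root(H, (4*np.cos(t), 4*np.sin(t), 1.0))
    if not sol.success:
        continue
    x, y, lam = sol.x
    if y < 0:
        continue   # root on the lower half of the circle: outside our region
    # Keep the point only if it is not (numerically) one we already have
    if not any(np.allclose((x, y), p, atol=1e-6) for p in found):
        found.append((x, y))

for x, y in found:
    print(x, y, x**2 - 6*x + 4*y**2 - 8*y)
```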

In [None]:
# The tricky bit is that the system above has other solutions. 

root(H,(-3,1,1)).x

In [None]:
x,y,lam = root(H,(3,1,1)).x
print(x,y,u(x,y))
candidates[(x,y)] = u(x,y)

In [None]:
x,y,lam = root(H,(-3,1,1)).x
print(x,y,u(x,y))
candidates[(x,y)] = u(x,y)

In [None]:
candidates

That is all of them, so we conclude, finally, that the global max of $u$ is $40$ at the corner $(-4,0)$ and the global min is $-13$ at the critical point $(3,1)$.
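As a sanity check, completing the square gives $u(x,y) = (x-3)^2 + 4(y-1)^2 - 13$, so the minimum $-13$ at $(3,1)$ is clear by inspection. This brute-force sketch evaluates $u$ on a dense grid over the half-disk; a grid search only approximates, but here both extreme points lie on or extremely close to grid points, so it should agree with both values:

```python
import numpy as np

def u(x, y):
    return x**2 - 6*x + 4*y**2 - 8*y

# Grid over the bounding box [-4, 4] x [0, 4], spacing 0.01
X, Y = np.meshgrid(np.linspace(-4, 4, 801), np.linspace(0, 4, 401))
inside = X**2 + Y**2 <= 16          # keep only points of the upper half-disk
U = u(X[inside], Y[inside])
print(U.max(), U.min())             # close to 40 and -13
```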

## Bonus discussion

#### Ignore if you have other things to do

Observe:

In [None]:
x,y,lam = root(H,(0,4,1)).x
print(x,y,u(x,y))

The initial point $(0,4)$ does not lead to the closest root. This raises the question: which initial conditions lead to which roots?

In [None]:
def which_root(x0,y0):
    """Start root at the initial point (x0, y0) with lambda = 1
    and report the angle, between -pi/2 and 3pi/2, of the root it finds.
    This will make the subsequent picture easier to draw."""
    x,y,l = root(H,(x0,y0,1)).x
    if x > 0:
        return arctan(y/x)
    elif x < 0:
        return arctan(y/x) + pi
    else:
        return np.sign(y)*pi/2

In [None]:
X = Y = np.linspace(-4,4,250)
X,Y = np.meshgrid(X,Y)
Z = np.vectorize(which_root)(X,Y)

In [None]:
plt.figure(figsize=(10,8))
plt.pcolormesh(X,Y,Z,cmap='gnuplot')
plt.plot(4*cos(np.linspace(0,2*pi,100)),4*sin(np.linspace(0,2*pi,100)),'-k',lw=4)
plt.colorbar();

What does this mean? It means these solving schemes are wickedly unstable, even chaotic: which root the solver finds can depend very sensitively on the initial guess. Try changing the initial choice $\lambda_0$ in the code above (it is currently 1) and watch the picture change. 

In [None]:
plt.colormaps() #use these to try out other color schemes.