In [None]:
import string
import sys
import warnings

import autograd
import autograd.numpy as np
import matplotlib
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import axes3d, Axes3D
import pandas as pd

import autocrit

In [None]:
sys.path.append("..")

import shared.format
import shared.tools

In [None]:
# Use the project-wide font size from shared.format for all matplotlib text.
plt.rcParams["font.size"] = shared.format.FONTSIZE

In [None]:
# Silence FutureWarning and RuntimeWarning for the rest of the notebook.
# NOTE(review): this presumably also hides numerical warnings (e.g. divide-by-zero)
# raised by the cells below — keep that in mind when debugging NaNs.
warnings.simplefilter("ignore", FutureWarning)
warnings.simplefilter("ignore", RuntimeWarning)

In [None]:
# Font size for axis labels in the final figure.
LABEL_FONTSIZE = 22

# Colormap for the contours of f, and the color for field arrows / flow lines.
F_CMAP = "Greys_r"
FIELD_COLOR = "xkcd:brick"

In [None]:
# Colors for runs / points that did (success) or did not (failure) reach a critical point.
SUCCESS_COLOR = "C0"
FAILURE_COLOR = "C1"

In [None]:
# Arrow geometry used by add_field:
#   ARROW_WIDTH  - shaft width passed to ax.arrow
#   ARROW_RATIO  - head width is ARROW_WIDTH * ARROW_RATIO
#   ARROW_LENGTH - scale factor applied to every field vector before drawing
#   ARROW_ALPHA  - arrow opacity
ARROW_WIDTH = 0.01
ARROW_RATIO = 10
ARROW_LENGTH = 0.12
ARROW_ALPHA = 1.

In [None]:
# Pseudo-inverse cutoffs. rcond_f is passed to make_flow_fields further down.
# NOTE(review): rcond_salmon is not referenced anywhere in the visible notebook — confirm it is still needed.
rcond_f = 5e-2
rcond_salmon = 1e-3

In [None]:
def f(r):
    """Evaluate the 2D toy objective at ``r = (x, y)``.

    f(x, y) = x**4 / 4 - 3 x**2 + 9 x + 5 y**2 + 0.9 y**4 + 40

    ``r`` only needs to support ``r[0]`` (x) and ``r[1]`` (y); the entries may be
    scalars or arrays, in which case the result is computed elementwise.
    """
    x, y = r[0], r[1]
    x_terms = 0.25 * x**4 - 3.0 * x**2 + 9.0 * x
    return x_terms + 5.0 * y**2 + 0.9 * y**4 + 40.0

In [None]:
def to_column_vector(lst):
    """Convert a sequence into a float64 array with a new axis inserted at position 1.

    For a flat sequence of n numbers the result has shape (n, 1).
    """
    as_floats = np.array(lst, dtype=np.float64)
    column = as_floats[:, None]
    return column


def compute_surface(f, xdim=4, ydim=4, num=30):
    """Evaluate ``f`` on a regular ``num`` x ``num`` grid.

    The grid spans [-xdim, xdim] in x and [-ydim, ydim] in y. ``f`` is called
    once with ``[x_values, y_values]`` (both flattened) and its result is
    reshaped back to the grid shape.

    Returns the meshgrid arrays ``X``, ``Y`` and the values ``Z``.
    """
    X, Y = np.meshgrid(
        np.linspace(-xdim, xdim, num=num),
        np.linspace(-ydim, ydim, num=num),
    )
    flat_values = f([X.flatten(), Y.flatten()])
    Z = flat_values.reshape(X.shape)

    return X, Y, Z


def plot_surface3d(X, Y, Z, cmap=F_CMAP, ax=None):
    """Draw a 50-level 3D contour plot of the surface ``Z`` over ``X``, ``Y``.

    Parameters
    ----------
    X, Y, Z : arrays as returned by ``compute_surface``.
    cmap : colormap used for the contours (defaults to ``F_CMAP``).
    ax : existing 3D axes to draw on; a new figure with 3D axes is created
        when omitted.

    Returns
    -------
    The axes that were drawn on, with the axis decorations turned off.
    """
    if ax is None:
        fig = plt.figure()
        ax = fig.add_subplot(111, projection="3d")
    # Honor the caller's colormap instead of always falling back to F_CMAP.
    ax.contour3D(X, Y, Z, 50, cmap=cmap)
    ax.axis("off")

    return ax


def add_field(X, Y, field, ax, color=FIELD_COLOR, f=f, max_f=np.inf):
    """Draw one arrow per grid point for the vector field ``field`` on ``ax``.

    ``X`` and ``Y`` are flattened and paired with the entries of ``field``;
    each vector is scaled by ``ARROW_LENGTH`` before drawing. Grid points where
    ``f([x, y])`` is greater than ``max_f`` get no arrow.
    """
    head_width = ARROW_WIDTH * ARROW_RATIO
    anchored_vectors = zip(X.flatten(), Y.flatten(), field)

    for x0, y0, direction in anchored_vectors:
        if f([x0, y0]) > max_f:
            continue
        scaled = ARROW_LENGTH * direction
        ax.arrow(
            x0, y0, *scaled,
            width=ARROW_WIDTH,
            length_includes_head=True,
            head_width=head_width,
            color=color, alpha=ARROW_ALPHA, zorder=4)


def plot_field(X, Y, Z, field, contour_levels=None, cmap=F_CMAP, ax=None,
               field_X=None, field_Y=None, max_f=np.inf):
    """Plot filled + line contours of ``Z`` and overlay a vector field.

    Parameters
    ----------
    X, Y, Z : grid and values used for the contour plot.
    field : sequence of 2D vectors, one per point of the field grid.
    contour_levels : levels forwarded to ``contourf`` / ``contour``.
    cmap : colormap for both contour plots (defaults to ``F_CMAP``).
    ax : axes to draw on; a new 8x8 inch figure is created when omitted.
    field_X, field_Y : grid on which ``field`` is defined. Default to ``X``
        and ``Y`` so the field may live on a coarser grid than the contours.
    max_f : forwarded to ``add_field``; arrows where f exceeds it are skipped.
    """
    if ax is None:
        fig, ax = plt.subplots(figsize=(8, 8))
    if field_X is None:
        field_X = X
    if field_Y is None:
        field_Y = Y

    ax.contourf(X, Y, Z, levels=contour_levels, cmap=cmap)
    ax.contour(X, Y, Z, levels=contour_levels, cmap=cmap, linewidths=2, zorder=2)
    # Use the resolved field grid, not notebook-level globals, so the function
    # works regardless of what happens to be defined in the kernel.
    add_field(field_X, field_Y, field, ax, max_f=max_f)
    ax.axis('off')
    

def truncate_to_unit_norm(vecs):
    """Rescale every vector whose Euclidean norm exceeds 1 to unit norm.

    Vectors with norm <= 1 are passed through unchanged. Returns a new array.
    """
    clipped = []
    for vec in vecs:
        length = np.linalg.norm(vec)
        clipped.append(vec if length <= 1 else vec / length)
    return np.array(clipped)


def make_flow_fields(X, Y, grad_f, hess_f, rcond, pinv=True):
    """Compute gradient-based vector fields on the grid ``X``, ``Y``.

    Returns a 4-tuple:
      grads          - ``grad_f`` evaluated on the flattened grid, one row per point
      hessians       - list with ``hess_f`` evaluated at each grid point
      newton_steps   - per point, (inverse Hessian) applied to the negated gradient
      gradnorm_grads - per point, Hessian applied to the negated gradient

    The inverse is ``np.linalg.pinv`` with cutoff ``rcond`` when ``pinv`` is
    true, otherwise ``np.linalg.inv`` (in which case ``rcond`` is unused).
    """
    grads = grad_f(np.array([X.flatten(), Y.flatten()])).T
    hessians = [
        hess_f(to_column_vector([x, y]))
        for x, y in zip(X.flatten(), Y.flatten())
    ]

    if pinv:
        def invert(M):
            return np.linalg.pinv(M, rcond=rcond)
    else:
        invert = np.linalg.inv

    # 'ij,ikj->ik': for every point i, multiply its matrix by its vector.
    per_point_matvec = 'ij,ikj->ik'
    newton_steps = np.einsum(per_point_matvec, -1 * grads, invert(hessians))
    gradnorm_grads = np.einsum(per_point_matvec, -1 * grads, hessians)

    return grads, hessians, newton_steps, gradnorm_grads


def make_mesh_r0s(xlim, ylim, num=10):
    """Return the points of a ``num`` x ``num`` grid as an array of (x, y) rows.

    The grid covers [-xlim, xlim] x [-ylim, ylim]; the result has shape
    (num * num, 2).
    """
    grid_x, grid_y = np.meshgrid(
        np.linspace(-xlim, xlim, num=num),
        np.linspace(-ylim, ylim, num=num),
    )
    stacked = np.vstack([grid_x.flatten(), grid_y.flatten()])

    return stacked.T


def extract_mrqlp_outputs(row):
    """Unpack ``row["mrqlp_outputs"]`` into one per-step array per quantity.

    The first entry of ``row["mrqlp_outputs"]`` is dropped (initial None).
    Every remaining entry is a sequence whose elements are squeezed to
    scalars; column k across all steps is stored on ``row`` under the k-th
    name below. ``row`` is modified in place and returned.
    """
    column_names = (
        "flags", "iters", "m_iters", "qlp_iters", "relres",
        "relares", "anorm", "acond", "xnorm", "axnorm",
    )

    per_step = np.array([
        [np.squeeze(elem).item() for elem in step_outputs]
        for step_outputs in row["mrqlp_outputs"][1:]  # drop initial None
    ])

    for column_index, name in enumerate(column_names):
        row[name] = per_step[:, column_index]

    return row


def run_newton(r0, num_iters, newton_maker):
    """Run one optimizer from ``r0`` and return its log plus summary fields.

    ``newton_maker`` is called with the logging options and must return an
    object exposing ``run(r0, num_iters=...)`` and a ``log`` mapping. The
    returned dict is a copy of that log with two extra keys:

      "end_point"            - last entry of ``log["theta"]``
      "found_critical_point" - whether the last ``log["g_theta"]`` is below 1e-10
    """
    log_kwargs = dict(
        track_f=True,
        track_g=True,
        track_theta=True,
        track_grad_f=True,
        track_update=True,
    )

    newton = newton_maker(log_kwargs)
    newton.run(r0, num_iters=num_iters)

    result = dict(newton.log)

    found_critical_point = result["g_theta"][-1] < 1e-10
    end_point = result["theta"][-1]

    result["end_point"] = end_point
    result["found_critical_point"] = found_critical_point

    return result


def calculate_newton_flow(grad, hess, r, rcond=1e-1,
                          pinv=True):
    """Return the unit-norm Newton direction at ``r``.

    Parameters
    ----------
    grad, hess : callables returning the gradient and Hessian at a point.
    r : point at which to evaluate them.
    rcond : cutoff forwarded to ``np.linalg.pinv`` (unused when ``pinv`` is false).
    pinv : use the pseudo-inverse when true, ``np.linalg.inv`` otherwise.

    Returns
    -------
    ``-inv(hess(r)) . grad(r)`` divided by its norm. When the step is exactly
    zero, a zero vector of the same shape is returned instead of dividing by
    zero and producing NaNs.
    """
    if pinv:
        inv = lambda M: np.linalg.pinv(M, rcond=rcond)
    else:
        inv = np.linalg.inv

    newton_step = -inv(hess(r)).dot(grad(r))
    step_norm = np.linalg.norm(newton_step)

    # A zero step has no direction; report it as zero so callers can test
    # the norm against 0 without NaNs.
    if step_norm == 0:
        return np.zeros_like(newton_step)

    return newton_step / step_norm


def follow_newton_flow(grad, hess, r_init,
                       eps=1e-1, num_steps=1, rcond=1e-1,
                       pinv=True):
    """Integrate the normalized Newton flow from ``r_init`` with fixed step ``eps``.

    At most ``num_steps`` steps are taken. Integration stops early as soon as
    the direction returned by ``calculate_newton_flow`` does not have a norm
    strictly greater than zero. Returns all visited points, starting with
    ``r_init``, stacked into one array.
    """
    trajectory = [r_init]
    steps_taken = 0

    while steps_taken < num_steps:
        current = trajectory[-1]
        direction = calculate_newton_flow(grad, hess, current, rcond, pinv)

        # Written as "not (> 0)" so a NaN norm also terminates the flow.
        if not (np.linalg.norm(direction) > 0):
            break

        trajectory.append(current + eps * direction)
        steps_taken += 1

    return np.array(trajectory)


def make_newton_result_figures(f, results_df, contour_levels=None):
    """Plot optimizer trajectories and per-step diagnostics for every run.

    Creates two figures: one with the trajectories drawn over filled contours
    of ``f``, and a 2x2 grid (shared x and y axes) of semilog-y diagnostics.

    Parameters
    ----------
    f : function passed to ``compute_surface`` for the contour background.
    results_df : DataFrame with one run per row. Reads ``found_critical_point``,
        ``theta`` and ``g_theta``; ``relares`` and ``xnorm`` are plotted when
        present and skipped on ``KeyError``.
    contour_levels : levels forwarded to ``contourf``.

    Nothing is returned; the figures are left as the current matplotlib state.
    """
    fig1, traj_ax = plt.subplots(nrows=1, ncols=1, figsize=(8, 8))

    fig2, axs = plt.subplots(nrows=2, ncols=2, figsize=(8, 8),
                             sharex=True, sharey=True)
    gn_ax, alpha_ax, xnorm_ax, stepnorm_ax  = axs.flatten()

    X, Y, Z = compute_surface(f)
    cp = traj_ax.contourf(X, Y, Z, levels=contour_levels, cmap=F_CMAP)
    traj_ax.axis("off")

    gn_ax.set_title(r"$\|\nabla f\|^2$");
    # NOTE(review): this panel is titled alpha but plots result["relares"] below — confirm the title.
    alpha_ax.set_title(r"$\alpha$");
    xnorm_ax.set_title(r"$\|p\|$")
    stepnorm_ax.set_title(r"$\|\Delta r\|$")
    
    # Only the bottom row gets an x-label (x axes are shared).
    xnorm_ax.set_xlabel("Step")
    stepnorm_ax.set_xlabel(r"Step")
    
    for _, result in results_df.iterrows():

        # Color each run by whether it ended at a critical point.
        if result.found_critical_point:
            color = SUCCESS_COLOR
        else:
            color = FAILURE_COLOR

        # Full trajectory plus a marker on its final point.
        traj_ax.plot(*np.squeeze(result["theta"]).T, color=color, alpha=0.4, lw=2)
        traj_ax.scatter(*result["theta"][-1], color=color, s=24, zorder=3)

        # Plots 2 * g_theta; presumably g_theta is half the squared gradient norm — verify.
        gn_ax.semilogy(np.multiply(2, result["g_theta"]), color=color, alpha=0.4)
        
        # Solver diagnostics are optional columns; runs without them are skipped.
        try:
            alpha_ax.semilogy(result["relares"], color=color, alpha=0.4)
        except KeyError:
            pass

        try:
            xnorm_ax.semilogy(result["xnorm"][1:], color=color, alpha=0.4)
        except KeyError:
            pass
        
        # Fixed y-range, re-applied for every run; the y axis is shared by all four panels.
        xnorm_ax.set_ylim([1e-15, 1e5])

        # Norm of the difference between successive entries of theta.
        stepnorm_ax.semilogy(np.linalg.norm(np.diff(result["theta"], axis=0), axis=1),
                             color=color, alpha=0.4)

    plt.tight_layout();
    

def follow_newton_flows(grad, hess, r0s, num_steps=100, newton_flow_eps=0.1, rcond=1e-3):
    """Follow the normalized Newton flow from every starting point in ``r0s``.

    Parameters
    ----------
    grad, hess : callables forwarded to ``follow_newton_flow``.
    r0s : iterable of starting points.
    num_steps : maximum number of steps per flow.
    newton_flow_eps : step size, passed as ``eps``.
    rcond : pseudo-inverse cutoff.

    Returns
    -------
    A list with one trajectory array per starting point, in input order.
    """
    # The original wrapped r0s in tqdm(), which this notebook never imports,
    # so every call raised NameError. Iterate directly instead.
    return [
        follow_newton_flow(
            grad, hess, r0,
            eps=newton_flow_eps, num_steps=num_steps, rcond=rcond)
        for r0 in r0s]


def make_newton_flow_plot(f, flows, flow_field, num=30, success_cutoff=1e1, contour_levels=None, ax=None):
    """Draw flow trajectories and a vector field over filled contours of ``f``.

    ``flow_field`` is drawn on the ``num`` x ``num`` grid from
    ``compute_surface``. Each flow is drawn as a line in ``FIELD_COLOR``; its
    last point is marked in ``SUCCESS_COLOR`` when ``squared_grad_norm`` there
    is below ``success_cutoff`` and in ``FAILURE_COLOR`` otherwise.
    """
    if ax is None:
        _, ax = plt.subplots(figsize=(8, 8))

    X, Y, Z = compute_surface(f, num=num)
    ax.contourf(X, Y, Z, levels=contour_levels, cmap=F_CMAP)
    add_field(X, Y, flow_field, ax)

    for flow in flows:
        end_point = flow[-1]
        reached_target = squared_grad_norm(end_point) < success_cutoff
        marker_color = SUCCESS_COLOR if reached_target else FAILURE_COLOR
        ax.plot(*flow.T, color=FIELD_COLOR)
        ax.scatter(*end_point, color=marker_color, zorder=3)

    ax.axis('off')

# Define 2D toy function

$$\begin{align}
f(x, y) = 1/4 \cdot x^4& - 3 x^2 + 9 x \ + \\
          9/10 \cdot y^4& + 5 y^2 + 40
\end{align}$$

$$
\nabla f =
\left[
    \begin{array}{c}
        x^3 - 6x + 9\\
        3.6y^3 + 10y
    \end{array}
    \right]$$

$$
\nabla f =
\left[
    \begin{array}{c}
        (x + 3) (x^2 - 3 x + 3)\\
        y(3.6y^2 + 10)
    \end{array}
    \right]$$

$$
\nabla^2 f =
\left[
    \begin{array}{cc}
        3 x^2 - 6 & 0\\
        0 & 10.8 y^2 + 10
    \end{array}
    \right]$$

In [None]:
# Gradient of f via autograd.
# NOTE(review): elementwise_grad is presumably chosen so f can be differentiated for a
# whole batch of points at once (make_flow_fields passes a (2, N) array) — confirm.
grad_f = autograd.elementwise_grad(f)

# Hessian of f via autograd; hess_f squeezes singleton axes out of the raw result.
raw_hess_f = autograd.hessian(f)

hess_f = lambda r: np.squeeze(raw_hess_f(r))

# Sum of the squared gradient components, reduced over the first axis.
squared_grad_norm = lambda r: np.sum(np.square(grad_f(r)), axis=0)

# Figure

In [None]:
def add_indicators_to_slice(points, f, color, ax, size=12):
    """Mark ``points`` on a slice plot as scatter markers at height ``f(point)``.

    The x position of each marker is the first coordinate of the point
    (points are indexed as ``[:, 0, :]`` after stacking). ``size`` is squared
    before being handed to ``ax.scatter``.
    """
    x_positions = np.squeeze(np.array(points)[:, 0, :])
    heights = [f(point) for point in points]
    ax.scatter(x_positions, heights, zorder=3, color=color, s=size**2)

    
def plot_mr_trajs(results_df, ax=None):
    """Draw each run's trajectory and mark its starting point.

    Runs whose first ``theta`` lies outside the plot limits (per
    ``is_outside_lims``) are skipped. Lines and start markers are colored by
    ``found_critical_point``.
    """
    if ax is None:
        _, ax = plt.subplots(figsize=(8, 8))

    for _, result in results_df.iterrows():
        trajectory = result["theta"]
        start = trajectory[0]

        if is_outside_lims(start):
            continue

        color = SUCCESS_COLOR if result.found_critical_point else FAILURE_COLOR

        ax.plot(*np.squeeze(trajectory).T, color=color, alpha=0.4, lw=2)
        ax.scatter(*start, color=color, s=24, zorder=3)

def is_outside_lims(xy):
    """Return True when ``xy`` lies outside the box |x| <= 4, |y| <= 3."""
    if np.abs(xy[0]) > 4:
        return True
    return bool(np.abs(xy[1]) > 3)
        
def plot_from_df(results_df, column="sgn", ax=None):
    """Plot ``column`` of every run as a line, colored by ``found_critical_point``."""
    if ax is None:
        _, ax = plt.subplots(figsize=(8, 8))

    for _, result in results_df.iterrows():
        color = SUCCESS_COLOR if result.found_critical_point else FAILURE_COLOR
        ax.plot(result[column], color=color, alpha=0.4)

In [None]:
def plot_along_slice(slice, f, ax=None, xs=None, **kwargs):
    """Plot ``f`` evaluated at every point of a 1D slice through the plane.

    Parameters
    ----------
    slice : sequence of points (the name shadows the builtin but is kept for
        compatibility with existing callers).
    f : function evaluated at each point; its values are the y data.
    ax : axes to draw on; a new figure is created when omitted.
    xs : x positions for the plot. When omitted, the first coordinate of each
        point is used, so the function no longer depends on a notebook-level
        ``xs`` defined in some other cell.
    **kwargs : forwarded to ``ax.plot``.

    Returns
    -------
    The axes that were drawn on.
    """
    if ax is None:
        fig, ax = plt.subplots()

    points = list(slice)
    if xs is None:
        xs = [np.ravel(point)[0] for point in points]

    ax.plot(xs, [f(point) for point in points], **kwargs);

    return ax

In [None]:
# Critical point of f: at (-3, 0) both components of the gradient of f vanish.
critical_point = to_column_vector([-3., 0.])

# Point on the x-axis at x = sqrt(2), where d^2 f / dx^2 = 3 x^2 - 6 is zero
# while the gradient itself is not.
gradient_flat_point = to_column_vector([np.sqrt(2), 0])

In [None]:
# Dense sampling of the x-axis (y = 0), as column vectors, for the 1D slice plots.
xs = np.linspace(-5, 5, num=5000)
x_axis = [to_column_vector([x, 0]) for x in xs]

In [None]:
# Fine 100 x 100 grid of f for the contour backgrounds.
X, Y, Z = compute_surface(f, num=100)

# 20 log-spaced contour levels between 1 and 100.
contour_levels = np.logspace(0, 2, 20)

In [None]:
# Coarser 21 x 21 grid on which the vector-field arrows are drawn.
flow_X, flow_Y, flow_Z = compute_surface(f, num=21, xdim=4, ydim=4)

In [None]:
# Fields on the coarse grid. With pinv=False make_flow_fields uses np.linalg.inv,
# so the rcond argument has no effect on this call.
grads, hessians, newton_steps, gradnorm_grads = make_flow_fields(
    flow_X, flow_Y, grad_f, hess_f, rcond=rcond_f, pinv=False)

# Cap arrow lengths at unit norm so large Newton steps do not dominate the plot.
truncated_newton_steps = truncate_to_unit_norm(newton_steps)

In [None]:
# Factory used by run_newton: given the logging options, build a NewtonMR finder on f.
# NOTE(review): the meaning of the solver hyperparameters (maxit, rho, beta, acondlim,
# maxxnorm, trancond, ...) is defined by autocrit — check its documentation before tuning.
nmr_maker = lambda log_kwargs: autocrit.finders.newtons.NewtonMR(
        f, maxit=2, alpha=1., check_pure=True, rho=0.1, beta=0.9, rho_pure=0.1,
        log_kwargs=log_kwargs, rtol=1e-10, log_mrqlp=True,
        acondlim=1e15, maxxnorm=1e2, trancond=1e7)

In [None]:
# 10 x 10 grid of starting points on [-4, 4] x [-4, 4]; the trailing axis turns
# each (x, y) row into a column vector, giving shape (100, 2, 1).
mesh_r0s = make_mesh_r0s(4, 4, num=10)
r0s = mesh_r0s[:, :, None]

In [None]:
# Run 15 iterations of NewtonMR from every starting point; one result dict per run.
newton_results = [run_newton(r0, num_iters=15, newton_maker=nmr_maker)
                  for r0 in r0s]

# One row per run, one column per logged quantity.
results_df = pd.DataFrame(newton_results)

In [None]:
# "sgn" column = 2 * g_theta for every run.
# NOTE(review): presumably g_theta is half the squared gradient norm, making this the
# squared gradient norm — verify against autocrit.
results_df["sgn"] = results_df["g_theta"].apply(
    lambda xs: np.multiply(2, xs))

In [None]:
# Expand the raw per-step solver outputs of each run into named columns.
results_df = results_df.apply(extract_mrqlp_outputs, axis=1)

In [None]:
# x tick positions and labels shared by the panels: +/-3, +/-sqrt(2) and the origin.
xticks = [-3, -np.sqrt(2), 0, np.sqrt(2), 3]
xticklabels = ["-3", r"$-\sqrt{2}$", "0", r"$\sqrt{2}$", "3"]

In [None]:
## MAKE FIGURE AND AXES

fig = plt.figure(figsize=(12, 12))

# 4 x 4 grid: top row = f panels, middle two rows = 2D maps, bottom row = sgn panels.
# Left half shows quantities along the x-axis slice / the flow field; right half shows the runs.
gs = fig.add_gridspec(4, 4)

flow_ax = fig.add_subplot(gs[1:3, :2])
mr_ax = fig.add_subplot(gs[1:3, 2:], sharex=flow_ax, sharey=flow_ax)

sgn_slice_ax = fig.add_subplot(gs[3, :2])
fun_slice_ax = fig.add_subplot(gs[0, :2], sharex=sgn_slice_ax)
fun_ax = fig.add_subplot(gs[0, 2:], sharey=fun_slice_ax)
sgn_ax = fig.add_subplot(gs[3, 2:], sharey=sgn_slice_ax)

## PLOT FLOW FIELD

# Contours on the fine grid; arrows are given on the coarse flow grid.
plot_field(X, Y, Z, truncated_newton_steps, contour_levels,
           ax=flow_ax, field_X=flow_X, field_Y=flow_Y, max_f=np.inf)
# Mark (+/- sqrt(2), 0) in the failure color and the critical point in the success color.
flow_ax.scatter(*gradient_flat_point, color=FAILURE_COLOR, s=144, zorder=3)
flow_ax.scatter(*-gradient_flat_point, color=FAILURE_COLOR, s=144, zorder=3)
flow_ax.scatter(*critical_point, color=SUCCESS_COLOR, s=144, zorder=3)
# plot_field switches the axis off; turn it back on for the ticks but hide the frame.
flow_ax.axis("on")
[spine.set_visible(False) for spine in flow_ax.spines.values()]
flow_ax.set_xticks(xticks)
flow_ax.set_xticklabels(xticklabels)

## PLOT F ALONG AXIS

# f along the x-axis slice, with the same three points marked as on the flow-field panel.
plot_along_slice(x_axis, f, lw=4, color="k", ax=fun_slice_ax)
fun_slice_ax.set_xlim([-4, 4])
add_indicators_to_slice(
    [-gradient_flat_point, gradient_flat_point],
    f,
    FAILURE_COLOR, fun_slice_ax)
add_indicators_to_slice([critical_point], f,
    SUCCESS_COLOR, fun_slice_ax)
fun_slice_ax.set_ylabel(r"$f$", fontsize=LABEL_FONTSIZE)
fun_slice_ax.set_xticks(xticks)

## PLOT SGN ALONG AXIS

# Squared gradient norm of f along the x-axis slice, shown on a log scale.
plot_along_slice(x_axis, squared_grad_norm, lw=4, color="k", ax=sgn_slice_ax)
sgn_slice_ax.set_xlabel(r"Position Along $x$-axis",
                    fontsize=LABEL_FONTSIZE);

add_indicators_to_slice(
    [-gradient_flat_point, gradient_flat_point],
    squared_grad_norm, FAILURE_COLOR, sgn_slice_ax)
# The gradient vanishes at the critical point, which a log axis cannot show,
# so its marker is pinned to the lower axis limit (1e-3) instead.
add_indicators_to_slice([critical_point], lambda x: 1e-3,
    SUCCESS_COLOR, sgn_slice_ax)

sgn_slice_ax.set_yscale("log")
sgn_slice_ax.set_ylim([1e-3, 1e5])
sgn_slice_ax.set_yticks([1e-3, 1e-1, 1e1, 1e3])
# A single \| already renders the double-bar norm delimiter; the previous
# "\|\|" drew it twice on each side.
sgn_slice_ax.set_ylabel(r"$\|\nabla f\|^2$", fontsize=LABEL_FONTSIZE)
sgn_slice_ax.set_xticks(xticks)
sgn_slice_ax.set_xticklabels(xticklabels)

fun_slice_ax.set_xticks(xticks)
fun_slice_ax.set_xticklabels(xticklabels)
fun_slice_ax.set_xlabel(r"Position Along $x$-axis",
                    fontsize=LABEL_FONTSIZE);

# for tic in fun_slice_ax.xaxis.get_major_ticks():
#     tic.tick1line.set_visible(False)
#     tic.label1.set_visible(False)
    
## PLOT MR TRAJECTORIES

# Same contour background as the flow-field panel (fine grid X, Y, Z), with the runs drawn on top.
# X, Y, Z = compute_surface(f, num=30)
cpf = mr_ax.contourf(X, Y, Z, levels=contour_levels, cmap=F_CMAP)
cp = mr_ax.contour(X, Y, Z, levels=contour_levels, cmap=F_CMAP, linewidths=2, zorder=2)
plot_mr_trajs(results_df, ax=mr_ax)
# These limits match the box tested by is_outside_lims; mr_ax shares both axes with flow_ax.
mr_ax.set_xlim(-4, 4); mr_ax.set_ylim(-3., 3.)
mr_ax.axis("on")
[spine.set_visible(False) for spine in mr_ax.spines.values()]
mr_ax.set_xticks(xticks)
mr_ax.set_xticklabels(xticklabels)

## PLOT F of MR

# Logged f_theta of every run against iteration number.
plot_from_df(results_df, column="f_theta", ax=fun_ax)
fun_ax.set_ylim([0, 100]);
fun_ax.set_xlabel("Iterations", fontsize=LABEL_FONTSIZE)

## PLOT SGN of MR

plot_from_df(results_df, column="sgn", ax=sgn_ax)
sgn_ax.set_xlabel("Iterations", fontsize=LABEL_FONTSIZE)
# Both iteration panels get a tick every 5 iterations, derived from sgn_ax's x-limit.
sgn_ax.set_xticks(range(0, int(sgn_ax.get_xlim()[1])+1, 5))
fun_ax.set_xticks(range(0, int(sgn_ax.get_xlim()[1])+1, 5))

# NOTE(review): label_poss has four entries but label_axs only two; zip stops at the
# shorter list, so only these two panels receive labels ("A" and "B"). Confirm intent.
label_axs = [fun_slice_ax, fun_ax]
label_poss = [(-0.1, 1.2)] * 2  + [(-0.05, 1.2)] * 2

for label, ax, pos in zip(string.ascii_uppercase, label_axs, label_poss):
    shared.tools.add_panel_label(label, ax, pos)
    
plt.tight_layout(pad=0.4);
# Presumably a bare None as last expression keeps the cell from echoing a value.
None

In [None]:
# Write the assembled figure to a PDF in the current working directory.
fig.savefig("toy-problem.pdf", bbox_inches="tight")