In [None]:
import numpy as np
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib import ticker, cm

### Random Number Generation

In [None]:
# JAX random functions take an explicit PRNG key; build one from a fixed seed.
seed = 0
key = jax.random.PRNGKey(seed)
# Split before sampling: `randkey` is used for the draw below, `key` is kept
# so that later cells can split it again.
key, randkey = jax.random.split(key)
samples = jax.random.normal(randkey, (15,))
print(samples)

In [None]:
# Split the carried-over key again so this draw uses a different subkey
# from the one used in the previous cell.
key, randkey = jax.random.split(key)
samples = jax.random.normal(randkey, (15,))

In [None]:
# Display the 15 values drawn in the previous cell.
samples

### Beware of Side Effects - Use Pure Functions - Transformation & Compilation

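Pure function: a function that always returns the same outputs for the same inputs and has no side effects (it neither reads nor modifies state outside its arguments).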

Example 1: Print side effect

In [None]:
def print_side_effect(x):
    """Return ``x`` squared, printing the dtype of ``x`` as a side effect.

    The ``print`` is the point of the example: it is an effect outside the
    returned value, so it makes the function impure.
    """
    print("input type ", x.dtype)
    squared = x ** 2
    return squared

In [None]:
# Under jit the Python body (including the print) runs while JAX traces the
# function, so the number of printed lines shows how often tracing happened.
# NOTE(review): 4 and 5 are both ints, so the second call is expected to reuse
# the first trace and print nothing, while 6. is a float and should trigger a
# new trace -- confirm on the installed JAX version.
jax.jit(print_side_effect)(4)
jax.jit(print_side_effect)(5)
jax.jit(print_side_effect)(6.)

In [None]:
y = 5


def global_side_effect(x):
    """Return ``x`` plus the module-level ``y``.

    Reading the global makes the result depend on state that is not an
    argument, which is the impurity this example demonstrates.
    """
    shifted = x + y
    return shifted

In [None]:
# First jitted call with an int argument; the global `y` is read when the
# function body runs (y is 5 here if the cells are executed in order).
jax.jit(global_side_effect)(5)

In [None]:
# Rebind the global, then call the jitted function again with another int.
# NOTE(review): if the trace from the earlier int call is reused, the result
# still reflects the old value of `y` -- that stale read is the side effect
# this example is about; confirm by running.
y = 10
jax.jit(global_side_effect)(6)

In [None]:
# Same call with a float argument instead of an int.
# NOTE(review): the different input dtype is expected to force a re-trace that
# picks up the current value of `y` -- confirm by running.
jax.jit(global_side_effect)(6.)

### Himmelblau test function:

$f(x, y) = (x^2 + y - 11)^2 + (x + y^2 - 7)^2$

In [None]:
def el_himmelblau(x, y):
    """Himmelblau test function with the two coordinates as separate arguments.

    Computes ``(x^2 + y - 11)^2 + (x + y^2 - 7)^2``.
    """
    first_term = x**2 + y - 11
    second_term = x + y**2 - 7
    return first_term ** 2 + second_term ** 2

### jax.grad

In [None]:
# Evaluate the function and its first derivatives at the origin.
p_x, p_y = 0., 0.
print(el_himmelblau(p_x, p_y))
# The second argument of jax.grad selects the positional argument to
# differentiate with respect to: 0 -> x, 1 -> y.
print(jax.grad(el_himmelblau, 0)(p_x, p_y))
print(jax.grad(el_himmelblau, 1)(p_x, p_y))
# A tuple of indices returns one gradient per listed argument.
print(jax.grad(el_himmelblau, (0,1))(p_x, p_y))
# value_and_grad returns the function value together with the gradients.
print(jax.value_and_grad(el_himmelblau, (0,1))(p_x, p_y))

In [None]:
# Higher-order derivatives by nesting jax.grad:
# second derivative in x, second derivative in y, third derivative in x.
print(jax.grad(jax.grad(el_himmelblau, 0), 0)(p_x, p_y))
print(jax.grad(jax.grad(el_himmelblau, 1), 1)(p_x, p_y))
print(jax.grad(jax.grad(jax.grad(el_himmelblau, 0), 0), 0)(p_x, p_y))

### jax.jacfwd & jax.jacrev

In [None]:
def himmelblau(inp):
    """Himmelblau test function taking both coordinates in one indexable input.

    ``inp[0]`` is used as x and ``inp[1]`` as y; the result is
    ``(x^2 + y - 11)^2 + (x + y^2 - 7)^2``.
    """
    x = inp[0]
    y = inp[1]
    first_term = x**2 + y - 11
    second_term = x + y**2 - 7
    return first_term ** 2 + second_term ** 2

In [None]:
# Gradient of the array-input version at the origin, computed once in forward
# mode (jacfwd) and once in reverse mode (jacrev) for comparison.
t = jnp.array([0.,0.])
print(himmelblau(t))
jac_fwd_t = jax.jacfwd(himmelblau)(t)
print(jac_fwd_t)
jac_rev_t = jax.jacrev(himmelblau)(t)
print(jac_rev_t)

### Hessian

In [None]:
def hessian(func):
    """Return a function that computes the Hessian of ``func``.

    Built forward-over-reverse: ``jax.jacrev`` gives the first derivative and
    ``jax.jacfwd`` differentiates that result once more.
    """
    first_derivative = jax.jacrev(func)
    return jax.jacfwd(first_derivative)

In [None]:
# Matrix of second derivatives of himmelblau at t (t has two entries, so 2x2).
print(hessian(himmelblau)(t))

### Automatic Vectorization - jax.vmap

In [None]:
def batch_himmelblau(data):
    """Evaluate ``himmelblau`` on every entry along the leading axis of ``data``.

    ``jax.vmap`` maps over axis 0 of the input and stacks the results along
    axis 0 of the output.
    """
    vectorized = jax.vmap(himmelblau, in_axes=0, out_axes=0)
    return vectorized(data)

In [None]:
# 100 x 100 evaluation grid over [-5, 5] x [-5, 5].
# NOTE(review): this rebinds the global `y`, which global_side_effect reads;
# re-running that earlier example after this cell will see an array, not 10.
x = jnp.linspace(-5, 5, 100)
y = jnp.linspace(-5, 5, 100)
X, Y = np.meshgrid(x, y)
# One (x, y) row per grid point, in the same row-major order as X.ravel().
# Stacking the flattened grids is vectorised; it avoids building 10,000
# Python tuples with list(zip(...)) and yields the same (10000, 2) array.
data = jnp.stack([X.ravel(), Y.ravel()], axis=-1)

In [None]:
# Inspect the grid coordinate arrays and the stacked (x, y) points.
X, Y, data

In [None]:
# Shapes of the two grid coordinate arrays and of the flattened point list.
X.shape, Y.shape, data.shape

In [None]:
# Evaluate the function on every row of `data`, then reshape the flat result
# back to the grid layout of X so it can be passed to contourf.
Z = batch_himmelblau(data)
print(Z.shape)
Z = Z.reshape(X.shape)
print(Z.shape)

In [None]:
# Gradient at every grid point: vmap maps jacfwd(himmelblau) over the leading
# axis of `data`, giving one gradient row per point.
jacob = jax.vmap(jax.jacfwd(himmelblau), 0, 0)(data)

### JIT Compilation (Just In Time Compilation)

In [None]:
# Baseline: Hessian at a single point without jit. block_until_ready() makes
# the timing wait for the computed result.
%timeit hessian(himmelblau)(data[0]).block_until_ready()

In [None]:
# Build the jitted callables once and bind them to names, so that the timing
# cells call the same jitted object repeatedly.
jit_hessian = jax.jit(hessian(himmelblau))
jit_jacfwd = jax.jit(jax.jacfwd(himmelblau))
jit_vmap_hessian = jax.jit(jax.vmap(hessian(himmelblau), 0, 0))
jit_himmelblau = jax.jit(himmelblau)

In [None]:
# Single-point Hessian through the jitted function.
%timeit jit_hessian(data[0]).block_until_ready()

In [None]:
# Batched Hessian over every row of `data`, without jit.
%timeit jax.vmap(hessian(himmelblau), 0, 0)(data).block_until_ready()

In [None]:
# Batched Hessian over every row of `data`, through the jitted function.
%timeit jit_vmap_hessian(data).block_until_ready()

### plot arrows

In [None]:
def plot_arrows(X, Y, jacob):
    """Draw one descent-direction arrow per grid point on the current axes.

    Parameters
    ----------
    X, Y : 2-D arrays
        Grid coordinates of the arrow tails (same shape).
    jacob : array of shape (X.size, 2)
        Gradient at each grid point, ordered like ``X.ravel()``; column 0 is
        the x-derivative and column 1 the y-derivative.

    Each arrow points along the negative gradient and has length 0.5.
    """
    grad_x = np.asarray(jacob[:, 0]).reshape(np.shape(X))
    grad_y = np.asarray(jacob[:, 1]).reshape(np.shape(X))
    # Pair every point with its own (dx, dy). The y-derivative varies along a
    # grid row, so it must not be taken from the first column of the row.
    for x, y, dx, dy in zip(np.ravel(X), np.ravel(Y),
                            grad_x.ravel(), grad_y.ravel()):
        norm = np.hypot(dx, dy)
        if norm == 0:
            # Stationary point: no direction to draw, and dividing would
            # produce NaN arrow components.
            continue
        plt.arrow(x, y, -dx / (norm * 2), -dy / (norm * 2),
                  width=1e-5, length_includes_head=True,
                  head_width=1e-1)

In [None]:
plt.figure(figsize=(8, 6), dpi=80)
c = plt.contourf(X, Y, Z, locator=ticker.LogLocator(), levels=100)
# Coarser 10 x 10 grid over [-4, 4] x [-4, 4] for the gradient arrows.
ax = jnp.linspace(-4, 4, 10)
ay = jnp.linspace(-4, 4, 10)
aX, aY = np.meshgrid(ax, ay)
# One (x, y) row per arrow position, ordered like aX.ravel(); stacking the
# flattened grids replaces the Python-level list(zip(...)) construction.
adata = jnp.stack([aX.ravel(), aY.ravel()], axis=-1)
ajacob = jax.vmap(jax.jacfwd(himmelblau), 0, 0)(adata)
plot_arrows(aX, aY, ajacob)
plt.colorbar(c)
plt.xlabel('x')
plt.ylabel('y')
plt.title('Default Himmelblau Contour Plot')
plt.show()

In [None]:
# Left panel: x-component of the gradient; right panel: y-component.
fig, axes = plt.subplots(1, 2, figsize=(12, 6), dpi=80)
c1 = axes[0].contourf(X, Y, jacob[:, 0].reshape(X.shape), levels=50)
c2 = axes[1].contourf(X, Y, jacob[:, 1].reshape(X.shape), levels=50)
fig.colorbar(c1, ax=axes[0])
fig.colorbar(c2, ax=axes[1])
for panel, wrt in zip(axes, ('x', 'y')):
    panel.set_xlabel('x')
    panel.set_ylabel('y')
    panel.set_title(f'Default Himmelblau Jacobian Plot (wrt {wrt})')
# Render explicitly, like the other plotting cells, so the cell does not echo
# the repr of the last set_title call.
plt.show()

### Gradient Descent

In [None]:
# Starting point (x, y) shared by the descent runs in the following cells.
init_point = jnp.array([-1.0,2.0])

In [None]:
def gradient_descent(func, point, lr, eps=1e-2, max_it=500):
    """Minimise ``func`` with fixed-step gradient descent.

    Parameters
    ----------
    func : callable
        Scalar-valued objective of one array argument; differentiated with
        ``jax.jacfwd``.
    point : array
        Starting point.
    lr : float
        Constant step size.
    eps : float
        The loop stops after a step whose gradient norm (measured at the
        point the step started from) is below this value.
    max_it : int
        Maximum number of update steps.

    Returns
    -------
    tuple
        ``(p, hist, it)``: the final point, a NumPy array with every visited
        point (start included, so ``it + 1`` rows), and the number of steps.
    """
    # Differentiate the objective that was passed in. Hard-coding a specific
    # function here would make every call optimise that function regardless
    # of `func`.
    grad_fn = jax.jacfwd(func)

    def gradient_descent_1it(p, lr):
        jacob = grad_fn(p)
        return p - lr * jacob, jnp.linalg.norm(jacob)

    p = point
    hist = [np.asarray(p)]
    it = 0  # stays 0 when max_it is 0 and the loop body never runs
    for it in range(1, max_it + 1):
        p, norm = gradient_descent_1it(p, lr)
        hist.append(np.asarray(p))
        if norm < eps:
            break
    return p, np.array(hist), it

In [None]:
# Fixed step size 5e-3 from init_point. g_s is the final point, g_h the
# history of visited points, g_it the number of steps taken.
g_s, g_h, g_it = gradient_descent(himmelblau, init_point, 5e-3)
print("Number of iterations for gradient descent: ", g_it)

### Steepest Descent

In [None]:
def steepest_descent(func, point, lr, eps=1e-2, max_it=500, newton_max_it=100):
    """Minimise ``func`` by steepest descent with a Newton line search.

    At every step the step size ``alpha`` along the negative gradient is
    refined by Newton's method on ``alpha -> func(p - alpha * grad)``,
    warm-started from the step size found in the previous step.

    Parameters
    ----------
    func : callable
        Scalar-valued objective of one array argument.
    point : array
        Starting point.
    lr : float
        Initial guess for the step size.
    eps : float
        The loop stops after a step whose gradient norm (measured at the
        point the step started from) is below this value.
    max_it : int
        Maximum number of descent steps.
    newton_max_it : int
        Upper bound on Newton refinements per line search, so that a
        non-converging search cannot loop forever.

    Returns
    -------
    tuple
        ``(p, hist, it)``: the final point, a NumPy array with every visited
        point (start included), and the number of descent steps.
    """
    grad_fn = jax.jacfwd(func)

    def get_alpha(p, grad, lr0):
        # The gradient at p does not depend on alpha, so it is computed once
        # by the caller instead of inside every evaluation of the objective.
        f = lambda alpha: func(p - alpha * grad)
        d_f = jax.jacfwd(f)
        dd_f = hessian(f)

        def newton_step(alpha):
            candidate = alpha - d_f(alpha) / dd_f(alpha)
            # Zero curvature (e.g. zero gradient at a stationary point) gives
            # NaN/inf; keep the current step size instead of poisoning p.
            return jnp.where(jnp.isfinite(candidate), candidate, alpha)

        lr = lr0
        lr_next = newton_step(lr)
        for _ in range(newton_max_it):
            if not jnp.abs(lr_next - lr) > 1e-4:
                break
            lr = lr_next
            # Refined step sizes are clamped so the move never goes uphill
            # along the gradient.
            lr_next = jnp.maximum(newton_step(lr), 0.0)
        return lr_next

    def steepest_descent_1it(p, lr):
        jacob = grad_fn(p)
        lr = get_alpha(p, jacob, lr)
        return p - lr * jacob, lr, jnp.linalg.norm(jacob)

    p = point
    hist = [np.asarray(p)]
    it = 0  # stays 0 when max_it is 0 and the loop body never runs
    for it in range(1, max_it + 1):
        p, lr, norm = steepest_descent_1it(p, lr)
        hist.append(np.asarray(p))
        if norm < eps:
            break
    return p, np.array(hist), it

In [None]:
# Same start as the gradient-descent run; here 5e-3 is passed as `lr`, which
# steepest_descent hands to its step-size search.
s_s, s_h, s_it = steepest_descent(himmelblau, init_point, 5e-3)
print("Number of iterations for steepest descent: ", s_it)

In [None]:
# Contour plot with gradient arrows, overlaid with the two optimisation paths.
plt.figure(figsize=(16, 12), dpi=80)
c = plt.contourf(X, Y, Z, locator=ticker.LogLocator(), levels=100)
plot_arrows(aX, aY, ajacob)
plt.colorbar(c)
# The history arrays have one (x, y) row per visited point; transposing
# gives the x-coordinates in row 0 and the y-coordinates in row 1.
plt.plot(s_h.T[0], s_h.T[1], c='m', label='Steepest Descent')
plt.plot(g_h.T[0], g_h.T[1], c='b', label='Gradient Descent')
# Fix the view to the extent of the evaluation grid.
plt.ylim(-5, 5)
plt.xlim(-5, 5)
plt.legend(loc="upper right")
plt.xlabel('x')
plt.ylabel('y')
plt.title('Default Himmelblau Contour Plot')
plt.show()

### Sensitivity to Initial Alpha

In [None]:
def get_alpha(p, lr0):
    """Refine the step size ``lr0`` with a single Newton step.

    The Newton step is taken on the 1-D objective
    ``step -> himmelblau(p - step * grad himmelblau(p))`` and the result is
    clamped to be non-negative.
    """
    def line_objective(step):
        return himmelblau(p - step * jax.jacfwd(himmelblau)(p))

    slope = jax.jacfwd(line_objective)(lr0)
    curvature = hessian(line_objective)(lr0)
    return jnp.maximum(lr0 - slope / curvature, 0.0)


jit_get_alpha = jax.jit(get_alpha)


def steepest_descent_1it(p, lr):
    """One steepest-descent update of ``p`` on himmelblau.

    Returns the updated point, the refined step size, and the gradient norm
    at the input point.
    """
    step = jit_get_alpha(p, lr)
    grad = jax.jacfwd(himmelblau)(p)
    return p - step * grad, step, jnp.linalg.norm(grad)


def f_d(lr):
    """Objective value after one update from init_point with initial step ``lr``."""
    return himmelblau(steepest_descent_1it(init_point, lr)[0])


# Derivative of that value with respect to the initial step size, evaluated
# on 100 initial step sizes between 0 and 0.1.
djacob = jax.vmap(jax.jacfwd(f_d), 0, 0)(jnp.linspace(0.0, 1e-1, 100))

In [None]:
plt.figure(figsize=(8, 6), dpi=80)
# djacob was evaluated on jnp.linspace(0.0, 1e-1, 100); the x-axis must use
# the same grid, otherwise the curve is drawn against step sizes ten times
# smaller than the ones it was computed for.
plt.plot(jnp.linspace(0.0, 1e-1, 100), djacob)
plt.xlabel('initial alpha')
plt.ylabel('gradient wrt initial alpha')
plt.title('Gradient of 1 iteration steepest descent update wrt initial alpha')
plt.show()

### Himmelblau Function constrained to x < 0 & y < 0

In [None]:
def constrained_himmelblau(inp, gamma=1.):
    """Himmelblau function plus a cubic penalty on positive coordinates.

    ``inp[0]`` is x and ``inp[1]`` is y. For each coordinate the term
    ``gamma * max(0, 100 * coord**3)`` is added, which is zero for
    non-positive values and grows cubically for positive ones.
    """
    scale = 1e2
    x = inp[0]
    y = inp[1]
    first_term = (x**2 + y - 11) ** 2
    second_term = (x + y**2 - 7) ** 2
    penalty_x = gamma * jnp.maximum(0, scale * x**3)
    penalty_y = gamma * jnp.maximum(0, scale * y**3)
    return first_term + second_term + penalty_x + penalty_y

def batch_constrainedhimmelblau(data):
    """Evaluate ``constrained_himmelblau`` on every entry along axis 0 of ``data``.

    Uses the default ``gamma`` of ``constrained_himmelblau``.
    """
    vectorized = jax.vmap(constrained_himmelblau, 0, 0)
    return vectorized(data)

In [None]:
# Evaluate the penalised function on every row of `data`, then reshape the
# flat result back to the grid layout of X for contourf.
constrained_Z = batch_constrainedhimmelblau(data)
print(constrained_Z.shape)
constrained_Z = constrained_Z.reshape(X.shape)
print(constrained_Z.shape)

In [None]:
# Gradient of the penalised function at every row of `data`.
constrained_jacob = jax.vmap(jax.jacfwd(constrained_himmelblau), 0, 0)(data)

In [None]:
# Contour plot of the penalised function with gradient arrows.
plt.figure(figsize=(8, 6), dpi=80)
c = plt.contourf(X, Y, constrained_Z, locator=ticker.LogLocator(), levels=100)
# Gradients on the arrow positions in `adata`, which are laid out on the
# (aX, aY) grid passed to plot_arrows.
cajacob = jax.vmap(jax.jacfwd(constrained_himmelblau), 0, 0)(adata)
plot_arrows(aX, aY, cajacob)
plt.colorbar(c)
plt.xlabel('x')
plt.ylabel('y')
plt.title('Constrained Himmelblau Contour Plot')
plt.show()

In [None]:
# Left panel: x-component of the gradient; right panel: y-component.
fig, axes = plt.subplots(1, 2, figsize=(12, 6), dpi=80)
c1 = axes[0].contourf(X, Y, constrained_jacob[:, 0].reshape(X.shape), levels=50)
c2 = axes[1].contourf(X, Y, constrained_jacob[:, 1].reshape(X.shape), levels=50)
fig.colorbar(c1, ax=axes[0])
fig.colorbar(c2, ax=axes[1])
for panel, wrt in zip(axes, ('x', 'y')):
    panel.set_xlabel('x')
    panel.set_ylabel('y')
    panel.set_title(f'Constrained Himmelblau Jacobian Plot (wrt {wrt})')
# Render explicitly, like the other plotting cells, so the cell does not echo
# the repr of the last set_title call.
plt.show()

In [None]:
# Same call as the unconstrained gradient-descent run, with
# constrained_himmelblau passed as the objective argument.
cg_s, cg_h, cg_it = gradient_descent(constrained_himmelblau, init_point, 5e-3)
print("Number of iterations for gradient descent: ", cg_it)

In [None]:
# Same call as the unconstrained steepest-descent run, with
# constrained_himmelblau passed as the objective argument.
cs_s, cs_h, cs_it = steepest_descent(constrained_himmelblau, init_point, 5e-3)
print("Number of iterations for steepest descent: ", cs_it)

In [None]:
# Contour plot of the penalised function with gradient arrows, overlaid with
# the two optimisation paths.
plt.figure(figsize=(16, 12), dpi=80)
c = plt.contourf(X, Y, constrained_Z, locator=ticker.LogLocator(), levels=100)
plot_arrows(aX, aY, cajacob)
plt.colorbar(c)
# The history arrays have one (x, y) row per visited point; transposing
# gives the x-coordinates in row 0 and the y-coordinates in row 1.
plt.plot(cs_h.T[0], cs_h.T[1], c='m', label='Steepest Descent')
plt.plot(cg_h.T[0], cg_h.T[1], c='b', label='Gradient Descent')
# Fix the view to the extent of the evaluation grid.
plt.ylim(-5, 5)
plt.xlim(-5, 5)
plt.legend(loc="upper right")
plt.xlabel('x')
plt.ylabel('y')
plt.title('Constrained Himmelblau Contour Plot')
plt.show()