# Gradient Descent Visualization

In [None]:
from mpl_toolkits import mplot3d

In [None]:
import matplotlib
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation, PillowWriter
from IPython.display import HTML


# Lift the cap on embedded animation size to an effectively unbounded value
# (2 ** 256 == 4 ** 128).  NOTE(review): presumably needed only for the
# commented-out to_jshtml() display path -- confirm.
matplotlib.rcParams['animation.embed_limit'] = 2 ** 256

### Function to be Optimized

In [None]:
def f(x, y):
    """Evaluate the loss surface ``(x / 4) ** 2 + (y / 2) ** 2``.

    The arithmetic is elementwise, so ``x`` and ``y`` may be scalars or
    NumPy arrays (e.g. a meshgrid).
    """
    x_term = (x / 4) ** 2
    y_term = (y / 2) ** 2
    return x_term + y_term

def grad_f(x, y):
    """Return the gradient of ``(x / 4) ** 2 + (y / 2) ** 2`` at ``(x, y)``.

    By the chain rule::

        d/dx (x / 4) ** 2 = 2 * (x / 4) * (1 / 4) = x / 8
        d/dy (y / 2) ** 2 = 2 * (y / 2) * (1 / 2) = y / 2

    The previous implementation dropped the inner-derivative factors
    (1/4 and 1/2) and therefore returned ``[x / 2, y]``, which is not the
    gradient of the plotted surface.

    Returns:
        A NumPy array ``[df/dx, df/dy]``.
    """
    df_dx = 2 * (x / 4) * (1 / 4)
    df_dy = 2 * (y / 2) * (1 / 2)
    return np.array([df_dx, df_dy])


### Optimizers to be used

In [None]:
def SGD(learning_rate):
    """Build a plain gradient-descent step function.

    The returned callable maps a point ``pt`` (unpacked as ``x, y``) to
    ``pt - learning_rate * grad_f(x, y)``.  It keeps no state between calls.
    """
    def optimize(pt):
        step = learning_rate * grad_f(*pt)
        return pt - step

    return optimize

In [None]:
def SGD_momentum(learning_rate, beta):
    """Build a gradient-descent-with-momentum step function.

    The returned callable is stateful: it carries a velocity vector (starting
    at zero) in its closure and updates it on every call as
    ``velocity = beta * velocity - learning_rate * grad_f(*pt)``, then returns
    ``pt + velocity``.
    """
    velocity = np.array([0, 0])

    def optimize(pt):
        nonlocal velocity
        gradient_step = learning_rate * grad_f(*pt)
        velocity = beta * velocity - gradient_step
        return pt + velocity

    return optimize

In [None]:
def SGD_nestorov(learning_rate, beta):
    """Build a Nesterov-momentum step function.

    Like momentum SGD, but the gradient is evaluated at the look-ahead point
    ``pt + beta * velocity`` instead of at ``pt`` itself.  The returned
    callable is stateful: the velocity (initially zero) lives in the closure
    and is updated on every call.
    """
    velocity = np.array([0, 0])

    def optimize(pt):
        nonlocal velocity
        # Where the carried-over momentum alone would move the point.
        momentum = beta * velocity
        lookahead = pt + momentum
        velocity = momentum - learning_rate * grad_f(*lookahead)
        return pt + velocity

    return optimize

In [None]:
def RMS_prop(learning_rate, beta):
    """Build an RMSProp step function.

    The returned callable is stateful: it keeps an exponential moving average
    of the squared gradient per coordinate (both initialised to 1.0) and
    divides the learning rate by the square root of that average before
    applying the gradient step.
    """
    mean_squared = np.array([1., 1.])

    def optimize(pt):
        nonlocal mean_squared
        gradient = grad_f(*pt)
        mean_squared = beta * mean_squared + (1 - beta) * (gradient ** 2)
        scaled_rate = learning_rate / np.sqrt(mean_squared)
        return pt - scaled_rate * gradient

    return optimize

In [None]:
def ADAM(learning_rate, beta_1, beta_2, eps=1e-7):
    """Build an Adam step function.

    The returned callable is stateful: it keeps exponential moving averages
    of the gradient (``m``) and of the squared gradient (``v``), both starting
    at zero, plus a step counter ``t`` used for bias correction.

    Args:
        learning_rate: Step size.
        beta_1: Decay rate of the first-moment average ``m``.
        beta_2: Decay rate of the second-moment average ``v``.
        eps: Small constant added to the denominator for numerical stability.

    Returns:
        A function mapping a point ``pt`` to the next point.
    """
    m, v, t = np.array([0., 0.]), np.array([0., 0.]), 1.

    def optimize(pt):
        nonlocal m, v, t
        # Evaluate the gradient once and reuse it for both moment estimates.
        grad = grad_f(*pt)
        m = (beta_1 * m) + (1 - beta_1) * grad
        v = (beta_2 * v) + (1 - beta_2) * (grad ** 2)
        # Bias correction compensates for m and v being initialised at zero.
        m_hat = m / (1 - beta_1 ** t)
        v_hat = v / (1 - beta_2 ** t)
        t += 1
        # eps belongs outside the square root: sqrt(v_hat + eps) would floor
        # the denominator at sqrt(eps), which is far larger than eps itself.
        return pt - learning_rate * (m_hat / (np.sqrt(v_hat) + eps))

    return optimize

### Generating Data to be used for Animation

In [None]:
def gen_data(num_iters, pt, opt):
    """Record the trajectory of an optimizer on the loss surface ``f``.

    Starting from ``pt``, stores ``[x, y, f(x, y)]`` for the current point
    and then advances it with ``opt``; this is repeated ``num_iters`` times.

    Returns:
        The recorded rows transposed, i.e. row 0 holds the x values, row 1
        the y values and row 2 the loss values.
    """
    trajectory = []
    current = pt
    for _ in range(num_iters):
        trajectory.append([*current, f(*current)])
        current = opt(current)

    return np.transpose(np.array(trajectory))

In [None]:
# Number of optimisation steps recorded per optimizer (and animation frames).
num_iters = 1_000

In [None]:
# Plain gradient descent.
learning_rate = 0.01
sgd = SGD(learning_rate=learning_rate)

In [None]:
# Gradient descent with momentum.
learning_rate = 0.01
beta = 0.9
sgd_momentum = SGD_momentum(learning_rate=learning_rate, beta=beta)

In [None]:
# Gradient descent with Nesterov (look-ahead) momentum.
learning_rate = 0.01
beta = 0.9
sgd_nestorov = SGD_nestorov(learning_rate=learning_rate, beta=beta)

In [None]:
# RMSProp.
learning_rate = 0.01
beta = 0.9
rms_prop = RMS_prop(learning_rate=learning_rate, beta=beta)

In [None]:
# Adam.
learning_rate = 0.01
beta_1 = 0.9
beta_2 = 0.99
eps = 0.0001
adam = ADAM(
    learning_rate=learning_rate,
    beta_1=beta_1,
    beta_2=beta_2,
    eps=eps,
)

In [None]:
# Starting points: SGD, RMSProp and Adam each get their own corner of the
# plotted square; the two momentum variants share the (-1, -1) corner.
sgd_pt = np.array([1, 1])
sgd_mom_pt = np.array([-1, -1])
sgd_nes_pt = np.array([-1, -1])
rms_pt = np.array([1, -1])
adam_pt = np.array([-1, 1])

# One (3, num_iters) trajectory per optimizer: rows are x, y and loss.
# NOTE: the momentum, RMSProp and Adam step functions carry state in their
# closures, so re-running this cell without rebuilding them continues from
# the old state.
sgd_data = gen_data(num_iters=num_iters, pt=sgd_pt, opt=sgd)
sgd_mom_data = gen_data(num_iters=num_iters, pt=sgd_mom_pt, opt=sgd_momentum)
sgd_nes_data = gen_data(num_iters=num_iters, pt=sgd_nes_pt, opt=sgd_nestorov)
rms_data = gen_data(num_iters=num_iters, pt=rms_pt, opt=rms_prop)
adam_data = gen_data(num_iters=num_iters, pt=adam_pt, opt=adam)



In [None]:
# Sample the loss surface on a 30 x 30 grid over [-1, 1] x [-1, 1].
x = np.linspace(-1, 1, 30)
y = np.linspace(-1, 1, 30)

X, Y = np.meshgrid(x, y)
Z = f(X, Y)

# Draw the surface as 40 contour levels on a 3D axes.
fig = plt.figure(figsize=(8, 8))
ax = plt.axes(projection='3d')
ax.contour3D(X, Y, Z, 40, cmap='Blues')
ax.view_init(75, 0)

ax.set_xlabel('x')
ax.set_ylabel('y')
ax.set_zlabel('z')
# Raw strings keep the mathtext backslashes (\mathcal, \div, \lambda) from
# being parsed as invalid Python escape sequences.  The step size shown is
# taken from `learning_rate` so the title cannot drift from the value the
# optimizers were built with (it previously read 1e-3 while every optimizer
# used 1e-2).
ax.set_title(
    r'Optimizer visualization: $\mathcal{L} = (x \div 4)^{2} + (y \div 2)^{2}: '
    rf'\lambda = {learning_rate:g}$'
)

# Each line starts with only the first trajectory point; the animation
# callback extends it frame by frame.
sgd_line, = ax.plot(sgd_data[0, 0:1], sgd_data[1, 0:1], sgd_data[2, 0:1], label='sgd', color='r')
sgd_mom_line, = ax.plot(sgd_mom_data[0, 0:1], sgd_mom_data[1, 0:1], sgd_mom_data[2, 0:1], label='sgd momentum', color='b')
sgd_nes_line, = ax.plot(sgd_nes_data[0, 0:1], sgd_nes_data[1, 0:1], sgd_nes_data[2, 0:1], label='nesterov sgd', color='g')
rms_line, = ax.plot(rms_data[0, 0:1], rms_data[1, 0:1], rms_data[2, 0:1], label='rms prop', color='m')
adam_line, = ax.plot(adam_data[0, 0:1], adam_data[1, 0:1], adam_data[2, 0:1], label='adam', color='c')
ax.legend()

# (line artist, full trajectory) pairs consumed by the animation callback.
opt_data = [
    (sgd_line, sgd_data),
    (sgd_mom_line, sgd_mom_data),
    (sgd_nes_line, sgd_nes_data),
    (rms_line, rms_data),
    (adam_line, adam_data),
]

def update(num, opt_data):
    """Animation callback: show the first ``num`` points of every trajectory.

    ``opt_data`` is an iterable of ``(line, data)`` pairs where ``data`` has
    the x, y and z values in rows 0, 1 and 2.
    """
    for artist, trajectory in opt_data:
        visible = trajectory[:, :num]
        artist.set_data(visible[:2])
        artist.set_3d_properties(visible[2])
    

# One frame per recorded optimisation step; `update` receives the frame
# index plus `opt_data`.
ani = FuncAnimation(
    fig,
    update,
    frames=num_iters,
    fargs=(opt_data,),
    blit=False,
)
# Close the figure before saving.  NOTE(review): presumably to stop the
# notebook from also rendering a static copy -- confirm.
plt.close()
ani.save('opt.gif', writer=PillowWriter(fps=60))


#HTML(ani.to_jshtml())
