# Gradient Descent Visualization

In [1]:
from mpl_toolkits import mplot3d

In [2]:
import matplotlib
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation, PillowWriter
from IPython.display import HTML


# Raise matplotlib's cap on animations embedded as HTML/JS to an astronomically
# large value so it effectively never triggers -- presumably so the (commented
# out) ``HTML(ani.to_jshtml())`` preview can embed all frames.
matplotlib.rcParams.update({'animation.embed_limit': 4**128})

### Function to be Optimized

In [3]:
def f(x, y):
    """Loss being minimized: f(x, y) = (x / 4)**2 + (y / 2)**2.

    Only arithmetic operators are used, so scalars and NumPy arrays
    (evaluated element-wise) are both accepted.
    """
    scaled_x = x / 4
    scaled_y = y / 2
    return scaled_x ** 2 + scaled_y ** 2

def grad_f(x, y):
    """Return the gradient of ``f(x, y) = (x/4)**2 + (y/2)**2`` as ``[df/dx, df/dy]``.

    By the chain rule each squared term contributes an extra inner
    derivative factor:
        df/dx = 2 * (x / 4) * (1 / 4) = x / 8
        df/dy = 2 * (y / 2) * (1 / 2) = y / 2

    Returns
    -------
    numpy.ndarray
        Two-element array (or a 2 x ... array when x, y are arrays).
    """
    return np.array([2 * (x / 4) * (1 / 4), 2 * (y / 2) * (1 / 2)])


### Optimizer (plain gradient-descent step)

In [4]:
def SGD(learning_rate):
    """Build a step function performing one gradient-descent update.

    The returned callable maps a point ``pt`` (two coordinates, since they
    are unpacked into ``grad_f``) to ``pt - learning_rate * grad_f(*pt)``;
    it returns a new value and does not modify ``pt``. Despite the name,
    nothing here is stochastic: the exact gradient is used on every step.
    """
    def optimize(pt):
        gradient = grad_f(*pt)
        return pt - learning_rate * gradient
    return optimize

### Generating the Trajectory Data for the Animation

In [5]:
def gen_data(num_iters, pt, opt):
    """Run ``opt`` for ``num_iters`` steps from ``pt`` and record the path.

    Each recorded row is ``[x, y, f(x, y)]`` for the point *before* the
    step is applied, so the starting point is the first row. The point
    produced by the last ``opt`` call is not recorded.

    Returns
    -------
    numpy.ndarray
        The rows stacked and transposed: for ``num_iters >= 1`` the shape is
        ``(3, num_iters)``, with row 0 = x, row 1 = y, row 2 = f(x, y).
    """
    trajectory = []
    current = pt
    for _ in range(num_iters):
        trajectory.append([*current, f(*current)])
        current = opt(current)
    return np.array(trajectory).T

In [6]:
# Step size and number of recorded optimization steps (one animation frame each).
learning_rate = 1e-2
num_iters = 1000
opt = SGD(learning_rate)

In [7]:
# Starting point (10, 10): the corner of the plotted [-10, 10] x [-10, 10] grid.
pt = np.full(2, 10)
data = gen_data(num_iters, pt, opt)




In [8]:
# Sample the loss on a 30 x 30 grid covering [-10, 10] x [-10, 10].
x = np.linspace(-10, 10, 30)
y = np.linspace(-10, 10, 30)

X, Y = np.meshgrid(x, y)
Z = f(X, Y)

# A single 3-D axes filling the figure, with 50 contour levels of the loss.
fig = plt.figure(figsize=(8, 8))
ax = fig.add_subplot(111, projection='3d')
ax.contour3D(X, Y, Z, 50, cmap='binary')
ax.set_xlabel('x')
ax.set_ylabel('y')
ax.set_zlabel('z')
# Raw string: the mathtext uses backslashes (\mathcal, \div, \lambda) that are
# invalid escape sequences in an ordinary Python string literal.
ax.set_title(r'SGD visualization: $\mathcal{L} = (x \div 4)^{2} + (y \div 2)^{2}: \lambda = 1e-2$')

# Trajectory artist, initialised with only the starting point; the animation
# callback grows it frame by frame.
line, = ax.plot(data[0, 0:1], data[1, 0:1], data[2, 0:1])

def update(num, data, line):
    """Animation callback: draw the first ``num + 1`` trajectory points.

    With an integer frame count, FuncAnimation passes frame indices
    0, 1, ..., frames - 1. Slicing ``:num + 1`` makes frame 0 show the
    starting point and the last frame show every recorded column of
    ``data``, instead of drawing an empty line first and never the final
    point.

    Parameters
    ----------
    num : int
        Frame index supplied by FuncAnimation.
    data : numpy.ndarray
        Trajectory array; rows 0, 1, 2 hold x, y and f(x, y).
    line
        3-D line artist (must support ``set_data`` and
        ``set_3d_properties``).

    Returns
    -------
    tuple
        ``(line,)``, the artists modified this frame, as FuncAnimation
        expects (required only when blitting).
    """
    end = num + 1
    line.set_data(data[:2, :end])
    line.set_3d_properties(data[2, :end])
    return line,

# One frame per recorded trajectory point; ``data`` and ``line`` are forwarded
# to update() after the frame index.
ani = FuncAnimation(
    fig,
    update,
    frames=num_iters,
    fargs=(data, line),
    blit=False,
)
# Close the figure -- presumably so the notebook does not also render it as a
# static image -- then write the animation out as a 60 fps GIF.
plt.close()
ani.save('sgd.gif', writer=PillowWriter(fps=60))
# Alternative inline preview instead of saving: HTML(ani.to_jshtml())
