# Gradient Ascent in 2D

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import sympy as sym

from IPython import display
# Ask IPython's inline backend to emit figures in SVG (vector) format.
# NOTE(review): IPython.display.set_matplotlib_formats is deprecated in recent
# IPython releases; matplotlib_inline.backend_inline.set_matplotlib_formats is
# the suggested replacement — confirm against the installed IPython version.
display.set_matplotlib_formats('svg')

In [90]:
# The "peaks" test surface (a sum of three Gaussian-weighted terms)
def peaks(x, y):
    """Evaluate the peaks surface on the grid spanned by ``x`` and ``y``.

    The inputs are expanded with ``np.meshgrid`` (default 'xy' indexing), so
    for 1-D inputs the result has shape ``(len(y), len(x))``: rows follow
    ``y`` and columns follow ``x``.
    """
    xx, yy = np.meshgrid(x, y)

    # Each term is a polynomial/constant weighted by a Gaussian envelope
    bump = 3 * (1 - xx) ** 2 * np.exp(-(xx ** 2) - (yy + 1) ** 2)
    ripple = 10 * (xx / 5 - xx ** 3 - yy ** 5) * np.exp(-xx ** 2 - yy ** 2)
    dip = 1/3 * np.exp(-(xx + 1) ** 2 - yy ** 2)

    return bump - ripple - dip

In [91]:
# Build the landscape: 201 evenly spaced samples on [-3, 3] along each axis
x, y = np.linspace(-3, 3, 201), np.linspace(-3, 3, 201)

Z = peaks(x, y)

In [None]:
# Show the landscape as an image; origin='lower' puts row 0 at the bottom
plt.imshow(Z, origin='lower', vmin=-5, vmax=5,
           extent=[x[0], x[-1], y[0], y[-1]])
plt.show()

In [93]:
# Symbolic version of the peaks surface, used to obtain its gradient with sympy
sx, sy = sym.symbols('sx,sy')

# Same formula as the numpy `peaks` function. The 1/3 coefficient is an exact
# sympy Rational: a bare `1/3` is evaluated by Python first and would inject a
# rounded float into an otherwise exact symbolic expression.
sZ = 3 * (1 - sx) ** 2 * sym.exp(-(sx ** 2) - (sy + 1) ** 2) \
        - 10 * (sx / 5 - sx ** 3 - sy ** 5) * sym.exp(-sx ** 2 - sy ** 2) \
        - sym.Rational(1, 3) * sym.exp(-(sx + 1) ** 2 - sy ** 2)

# Partial derivatives dz/dx and dz/dy as callables f(x, y).
# With the 'sympy' backend each call returns a sympy expression (use .evalf()
# or float() to get a number). A 'numpy' backend would be faster but would
# change what callers receive, so it is kept as 'sympy'.
df_x = sym.lambdify((sx, sy), sym.diff(sZ, sx), 'sympy')
df_y = sym.lambdify((sx, sy), sym.diff(sZ, sy), 'sympy')

In [None]:
# Spot check: numeric value of dz/dx at the point (1, 1)
sym.N(df_x(1, 1))

In [None]:
# Gradient *ascent*: each step moves along +gradient, climbing toward a nearby
# local maximum. The running point keeps the name `local_minima` because the
# plotting cell below reads it under that name.

# Random starting point in [-2, 2) for each coordinate
local_minima = np.random.rand(2) * 4 - 2
# Explicit copy: basic slicing of a numpy array (`arr[:]`) returns a view that
# shares memory with the original, so any in-place update of the running point
# would silently move the recorded start point as well.
start_point = local_minima.copy()

# Learning parameters
learning_rate = 0.01
training_epochs = 1000

# Training: trajectory[i] holds the point reached after step i
trajectory = np.zeros((training_epochs, 2))
for i in range(training_epochs):
    # float() evaluates the sympy result to a plain Python float, so the
    # gradient (and therefore the position) stays a float64 array instead of
    # turning into an object array of sympy Floats.
    gradient = np.array([float(df_x(local_minima[0], local_minima[1])),
                         float(df_y(local_minima[0], local_minima[1]))])
    local_minima = local_minima + learning_rate * gradient
    trajectory[i, :] = local_minima

print(local_minima)
print(start_point)

In [None]:
# Visualize the ascent path on top of the landscape
plt.imshow(Z, extent=[x[0], x[-1], y[0], y[-1]], vmin=-5, vmax=5, origin='lower')
# Draw the path first so the start/end markers render on top of the line
plt.plot(trajectory[:, 0], trajectory[:, 1], 'r')
plt.plot(start_point[0], start_point[1], 'bs', label='rnd start')
plt.plot(local_minima[0], local_minima[1], 'ro', label='local max')
# Only artists given a label appear in the legend, so the path line is left
# out without relying on the order in which things were plotted.
plt.legend()
plt.colorbar()
plt.show()