# Gradient Descent in 1D

In [None]:
import numpy as np
import matplotlib.pyplot as plt

from IPython import display
# Render inline figures as SVG. IPython.display.set_matplotlib_formats is
# deprecated (since IPython 7.23), so configure the inline backend directly
# through the %config magic instead.
from IPython import get_ipython

_shell = get_ipython()
if _shell is not None:  # None when not running under IPython/Jupyter
    _shell.run_line_magic('config', "InlineBackend.figure_formats = ['svg']")

In [2]:
# Objective function to minimise
def fx(x):
    """Evaluate f(x) = 3x^2 - 3x + 4.

    Works on Python scalars and elementwise on NumPy arrays. The analytic
    minimum is at x = 0.5, where f(0.5) = 3.25.
    """
    quadratic_term = 3*x**2
    linear_term = 3*x
    return quadratic_term - linear_term + 4

# Analytic derivative of the objective
def deriv_fx(x):
    """Evaluate f'(x) = 6x - 3, the derivative of fx.

    Works on Python scalars and elementwise on NumPy arrays; it is zero at
    x = 0.5, the minimum of fx.
    """
    slope_term = 6*x
    return slope_term - 3

In [None]:
# Plot the function and its derivative

# Define range for x: 2001 evenly spaced points on [-1, 2]. The later cells
# reuse this grid for plotting and for drawing the random starting point.
x = np.linspace(-1, 2, 2001)

# Plotting. Label each curve explicitly so the legend always matches the
# line it describes (same labels as the results plot further down).
plt.plot(x, fx(x), label='f(x)')
plt.plot(x, deriv_fx(x), label='df')
plt.xlim(x[[0, -1]])
plt.grid()
plt.xlabel('x')
plt.ylabel('f(x)')
plt.legend()
plt.show()

In [None]:
# Learning by gradient descent

# Random starting point drawn from the plotting grid `x`. Take the single
# element as a scalar (as the "Saving the learning steps" cell does) so that
# later formatting sees a plain number rather than a 1-element array.
local_minima = np.random.choice(x, 1)[0]

# Learning parameters
learning_rate = 0.01
training_epochs = 100

# Training: step against the gradient, x <- x - learning_rate * f'(x).
# With f'(x) = 6x - 3, each step scales the distance to the true minimum
# (x = 0.5) by (1 - 6 * learning_rate), so it shrinks geometrically.
for i in range(training_epochs):
    gradient = deriv_fx(local_minima)
    local_minima = local_minima - learning_rate * gradient

local_minima

In [None]:
# Plot the results

# Plain float copy of the estimate; works whether `local_minima` is a scalar
# or a 1-element array.
x_min = np.asarray(local_minima).item()

# Use explicit labels: with positional legend labels the third entry
# ('f(x) min') was attached to the derivative marker, not the f(x) marker.
plt.plot(x, fx(x), label='f(x)')
plt.plot(x, deriv_fx(x), label='df')
plt.plot(x_min, deriv_fx(x_min), 'ro')  # derivative at the estimate (~0); unlabeled, so not in legend
plt.plot(x_min, fx(x_min), 'ro', label='f(x) min')

plt.xlim(x[[0, -1]])
plt.grid()
plt.xlabel('x')
plt.ylabel('f(x)')
plt.legend()
plt.title(f'Empirical local minimum: {x_min:.5f}')
plt.show()

In [None]:
# Saving the learning steps

# Random starting point: one scalar drawn from the plotting grid `x`
local_minima = np.random.choice(x, 1)[0]

# Learning parameters
learning_rate = 0.01
training_epochs = 100

# Training steps while storing results.
# Row i of model_params is (x_i, f'(x_i)): the point at which the gradient
# was evaluated on step i, together with that gradient. Recording before the
# update keeps each row a consistent (position, slope) pair; the final
# estimate after the last step is left in `local_minima`.
model_params = np.zeros((training_epochs, 2))
for i in range(training_epochs):
    gradient = deriv_fx(local_minima)
    model_params[i] = local_minima, gradient
    local_minima = local_minima - learning_rate * gradient

local_minima

In [None]:
# Plot the estimate and the gradient over iterations

# Plain float copy of the final estimate so the format spec also works when
# `local_minima` is a 1-element array.
final_estimate = np.asarray(local_minima).item()

fig, ax = plt.subplots(1, 2, figsize=(12, 4))
# Column 0 of model_params is the position trace, column 1 the derivative.
for axis, series in zip(ax, model_params.T):
    axis.plot(series, 'o-')
    axis.set_xlabel('Iteration')

ax[0].set_ylabel('Local minima')
ax[1].set_ylabel('Derivative')
# One shared title instead of repeating it on both panels.
fig.suptitle(f'Final estimated minima: {final_estimate:.5f}')

plt.show()