In [None]:
import numpy as np

from matplotlib import pyplot as plt
from matplotlib import ticker, animation, patches

In [None]:
%matplotlib notebook

In [None]:
def descend():
    """Take one gradient-descent step on the global line parameters.

    Reads the globals ``X``, ``y``, ``w``, ``b`` and ``learning_rate``. Stores the
    per-parameter step direction in the global ``gradients`` (slope first, then
    intercept) and writes the updated slope and intercept back to ``w`` and ``b``.
    """
    global w, b, gradients, learning_rate
    # Append a column of ones so the intercept is treated like one more weight.
    design = np.concatenate((X, np.ones((len(X), 1))), axis=1)
    params = np.array([w, b])
    residuals = y - (design * params).sum(axis=1)
    # Mean of residual * feature for each column of the padded design matrix.
    gradients = (residuals * design.T).mean(axis=1)
    params += learning_rate * gradients
    w, b = params[0], params[1]
    
def update_regression():
    """Move the fitted line to match the current global ``w`` and ``b``.

    Only the y-values of ``regression_line`` are replaced; they are evaluated at
    the fixed ``regression_X`` positions.
    """
    global regression_line
    fitted_y = regression_X * w + b
    regression_line.set_ydata(fitted_y)
    
def plot_gradients(i):
    """Append the latest ``gradients`` to both gradient panels.

    ``i`` is the iteration number used as the new x-value. ``gradients[0]`` is
    added to ``w_gradient_line`` and ``gradients[1]`` to ``b_gradient_line``.
    When ``i`` is a multiple of ``interval`` the scatter markers on ``ax3`` and
    ``ax4`` are rebuilt from every ``interval``-th stored point.
    """
    global gradients, w_gradient_line, w_gradient_scatter, b_gradient_line, b_gradient_scatter
    # Both lines share the same x history; it is read from the intercept line.
    iterations = np.append(b_gradient_line.get_xdata(), i)
    w_history = np.append(w_gradient_line.get_ydata(), gradients[0])
    b_history = np.append(b_gradient_line.get_ydata(), gradients[1])

    for line, history in ((w_gradient_line, w_history), (b_gradient_line, b_history)):
        line.set_xdata(iterations)
        line.set_ydata(history)

    if i % interval != 0:
        return

    # Drop the old marker collections and draw fresh ones on the sampled points.
    w_gradient_scatter.remove()
    b_gradient_scatter.remove()
    sampled_x = iterations[::interval]
    w_gradient_scatter = ax3.scatter(sampled_x, w_history[::interval], color='grey', s=4.0)
    b_gradient_scatter = ax4.scatter(sampled_x, b_history[::interval], color='grey', s=4.0)
    
def plot_mse(i):
    """Record the current mean squared error and refresh the R^2 label.

    ``i`` is the iteration number used as the new x-value. The MSE of the fit
    described by the globals ``w`` and ``b`` on ``X``/``y`` is appended to
    ``mse_line``. When ``i`` is a multiple of ``interval`` the scatter markers
    on ``ax2`` are rebuilt from every ``interval``-th stored point. The R^2
    label is updated on every call.
    """
    global mse_line, mse_scatter, r_squared_text, w, b
    residuals = y - (X * w).sum(axis=1) - b
    new_x = np.append(mse_line.get_xdata(), i)
    new_mse_y = np.append(mse_line.get_ydata(), (residuals ** 2).mean())
    mse_line.set_xdata(new_x)
    mse_line.set_ydata(new_mse_y)
    if i % interval == 0:
        # The marker collection is replaced rather than edited in place.
        mse_scatter.remove()
        mse_scatter = ax2.scatter(new_x[::interval], new_mse_y[::interval], color='grey', s=4.0)
    # Go through the public setter instead of the private ``_text`` attribute
    # so the text artist itself is notified that its content changed.
    r_squared_text.set_text('$R^2 = {:.3f}$'.format(r_squared()))

    
def r_squared():
    """Return R^2 of the current global ``w``/``b`` fit on ``X`` and ``y``.

    Computed as one minus the residual sum of squares divided by the sum of
    squared deviations of ``y`` from the global ``y_mean``.
    """
    global w, b
    residuals = y - (X * w).sum(axis=1) - b
    residual_ss = (residuals ** 2).sum()
    total_ss = ((y - y_mean) ** 2).sum()
    return 1 - residual_ss / total_ss
    
def animate(i):
    """Advance the animation by one frame.

    Performs one descent step, redraws the fitted line, then passes the frame
    number ``i`` to the gradient and MSE history panels, in that order.
    """
    descend()
    update_regression()
    for draw_history in (plot_gradients, plot_mse):
        draw_history(i)

In [None]:
# Build the 2x2 figure: fitted line, MSE history and the two gradient histories.
fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(8, 8))

(ax1, ax2), (ax3, ax4) = axes

n_frames = 201

titles = ('Regression line', 'Mean squared error', r'MSE-$x_1$ gradient', r'MSE-$x_0$ gradient')
x_labels = ('x', 'Iterations', 'Iterations', 'Iterations')
y_labels = ('y', 'MSE', r'$\frac{dMSE}{dx_1}$', r'$\frac{dMSE}{dx_0}$')

# Fixed seed so the synthetic data (and the hand-tuned axis limits below) are reproducible.
np.random.seed(0)

n_points = 100
actual_w = 4.16
actual_b = -15.13
random_strength = 25
learning_rate = 0.001

# Synthetic data: x uniform in [-12, 12), y on the true line plus uniform noise.
X = np.random.rand(n_points, 1) * 24 - 12
sorted_indices = np.argsort(X[:, 0])
stable_y = (X * actual_w).sum(axis=1) + actual_b
y = stable_y + (np.random.rand(n_points) - 0.5) * random_strength

# Starting point of the descent.
y_mean = y.mean()
w = -1.5
b = -12.0

# Scatter markers are drawn on every ``interval``-th iteration.
interval = 10

gradients = []

regression_X = np.linspace(X.min() - 0.1, X.max() + 0.1)
regression_y = regression_X * w + b

for ax, title, x_label, y_label in zip(axes.flat, titles, x_labels, y_labels):
    ax.set_title(title)
    ax.set_xlabel(x_label)
    ax.set_ylabel(y_label, rotation=0, labelpad=10)

# Top-left panel: data, true relationship and the line being fitted.
ax1.scatter(X[:, 0], y, color='darkgreen', marker='x', alpha=0.35, label='Data')
ax1.plot(X[sorted_indices], stable_y[sorted_indices], color='red', lw=1.5, ls='dashed', dashes=(5.0, 2.5), label='True relationship', zorder=5)
regression_line = ax1.plot(regression_X, regression_y, color='black', label='Regression line', zorder=6)[0]
ax1.xaxis.set_major_locator(ticker.MultipleLocator(10.0))
ax1.yaxis.set_major_locator(ticker.MultipleLocator(20.0))

ax1.legend()

# Fixed axis limits for the history panels, so the view does not rescale during playback.
ax2.axis((-10, n_frames + 10, -20, 1470))
ax3.axis((-10, n_frames + 10, -10, 310))
ax4.axis((-10, n_frames + 10, -6.5, -1.5))

# Empty artists that the per-frame plot functions extend or replace.
mse_line = ax2.plot([], [], color='grey')[0]
mse_scatter = ax2.scatter([], [])

# White box with the current R^2 in the top-right corner of the MSE panel.
ax2.add_artist(patches.Rectangle((0.66, 0.90), 0.315, 0.07, lw=0.75, edgecolor='black', facecolor='white', transform=ax2.transAxes))
r_squared_text = ax2.text(0.67, 0.91, '$R^2 = {:.3f}$'.format(r_squared()), transform=ax2.transAxes)

w_gradient_line = ax3.plot([], [], color='grey')[0]
w_gradient_scatter = ax3.scatter([], [])

b_gradient_line = ax4.plot([], [], color='grey')[0]
b_gradient_scatter = ax4.scatter([], [])

fig.tight_layout()


def init_animation():
    """Leave the figure untouched before the first frame.

    Without an ``init_func`` FuncAnimation initialises the figure by calling
    ``animate`` with the first frame and then calls it again when playback
    starts. ``animate`` mutates global state, so that would apply an extra
    descent step and record iteration 0 twice in the history lines.
    """
    return ()


gd_anim = animation.FuncAnimation(fig, animate, frames=n_frames, init_func=init_animation, interval=100, repeat=False)