In [None]:
import jax
import jax.numpy as jnp
import jax.random as jr
import jax.scipy as jsc
import matplotlib.pyplot as plt
from functools import partial
import scipy.stats as stats

# Number of regression weights (the bias term is not counted here).
n_weight = 4

# Root PRNG key with a fixed seed so the notebook is reproducible.
key = jr.PRNGKey(4)


def generate_data(key, n_data, n_weight):
    """Draw a synthetic linear-regression data set.

    The targets follow ``y_i = <weights, x_i> + bias + noise_i`` with a fixed
    bias of 1.0, weights and noise drawn from a standard normal distribution
    and features drawn uniformly from [0, 1).

    Args:
        key: JAX PRNG key; it is split into independent keys for the
            features, the noise and the weights.
        n_data: Number of samples (rows) to generate.
        n_weight: Number of features / weights per sample.

    Returns:
        Tuple ``(x, y)`` where ``x`` has shape ``(n_data, n_weight)`` and
        ``y`` has shape ``(n_data, 1)``.
    """
    x_key, noise_key, weight_key = jr.split(key, 3)
    weights = jr.normal(weight_key, (1, n_weight))
    bias = 1.0
    x = jr.uniform(x_key, (n_data, n_weight))
    noise = jr.normal(noise_key, (n_data, 1))
    # Reduce over the feature axis only. A plain ``.sum()`` would collapse the
    # whole array into one scalar, making every target independent of its own
    # row of ``x``.
    y = (weights * x).sum(axis=1, keepdims=True) + bias + noise
    return x, y


# Sample the regression problem used throughout the notebook.
n_data = 1000
x, y = generate_data(key, n_data, n_weight)

# Design matrix: the features followed by a column of ones, so that the last
# entry of the parameter vector acts as the bias.
X = jnp.concatenate((x, jnp.ones((n_data, 1))), axis=1)


In [None]:
@jax.vmap
def expm(x):
    """Matrix exponential applied independently to each matrix of a batch.

    ``jax.vmap`` maps over the leading axis (its default ``in_axes=0``), so an
    input of shape ``(batch, n, n)`` yields an output of the same shape.
    """
    return jsc.linalg.expm(x)


@jax.jit
def loss(w, X, y):
    """Half mean-squared error of the linear model ``X @ w`` against ``y``.

    ``y`` is reshaped to a column so it lines up with the ``(n, 1)``
    prediction instead of broadcasting to an ``(n, n)`` array.
    """
    targets = y.reshape((-1, 1))
    residual = X @ w - targets
    return 0.5 * jnp.mean(residual ** 2)


def update(carry, _):
    """Single gradient-descent step, written as a ``jax.lax.scan`` body.

    Args:
        carry: Tuple ``(weights, learning_rate)`` threaded through the scan.
        _: Per-step scan input; ignored.

    Returns:
        ``(new_carry, emitted)`` where ``new_carry`` holds the updated weights
        and the unchanged learning rate, and ``emitted`` is the updated
        weights so that the scan collects every iterate.

    The gradient is taken of the module-level ``loss`` on the module-level
    data ``X`` and ``y``.
    """
    weights, step_size = carry
    next_weights = weights - step_size * jax.grad(loss)(weights, X, y)
    return (next_weights, step_size), next_weights


def fit_convergance_line(error, learning_rates):
    """Least-squares fit of a straight line in log-log space.

    Fits ``log(error) ~ slope * log(learning_rate) + intercept``.

    Args:
        error: Sequence of positive error values, one per learning rate.
        learning_rates: Sequence of positive learning rates.

    Returns:
        Array ``[slope, intercept]``; the slope is the empirical order of
        convergence.
    """
    log_rates = jnp.log(jnp.array(learning_rates))
    intercept_column = jnp.ones((len(learning_rates),))
    design = jnp.column_stack((log_rates, intercept_column))
    targets = jnp.log(jnp.array(error))
    solution = jnp.linalg.lstsq(design, targets)
    return solution[0]


@jax.jit
def max_error(x, y):
    """Largest Euclidean distance between corresponding rows of ``x`` and ``y``.

    Args:
        x: Array of shape ``(n_steps, n_variables)``, one vector per row.
        y: Array of the same shape as ``x``.

    Returns:
        Scalar ``max_k ||x[k] - y[k]||_2``.

    The norm is taken of the difference. Comparing the norms themselves
    (``| ||x|| - ||y|| |``) only gives a lower bound on the error and is zero
    for distinct vectors of equal length.
    """
    return jnp.max(jnp.linalg.norm(x - y, axis=1))


In [None]:
# Step sizes to sweep and the time horizon up to which trajectories are compared.
learning_rates = [
    0.2, 0.1, 0.05, 0.025, 0.01,
    0.005, 0.0025, 0.00125, 0.000625, 0.0003125,
]
final_time = 1

# Unknowns: one weight per feature plus the bias.
n_variables = n_weight + 1
w_0 = jnp.ones((n_variables, 1))

# Normal equations: M w* = X^T y gives the least-squares minimiser w*.
M = X.T @ X
rhs = X.T @ y
w_star = jnp.linalg.solve(M, rhs).reshape((-1, 1))

# Accumulators filled once per learning rate.
first_order_error = []
second_order_error = []
results = []

for learning_rate in learning_rates:
    # round() instead of int(): a quotient such as 39.999... caused by
    # floating-point division must not silently drop the last step.
    max_iter = round(final_time / learning_rate)

    # Analytical solution to the modified equation, evaluated at the times
    # t_k = k * learning_rate of the gradient-descent iterates. Building the
    # grid from the step size keeps both trajectories aligned even when
    # final_time is not an exact multiple of learning_rate.
    t = (learning_rate * jnp.arange(max_iter + 1)).reshape((-1, 1, 1))
    w_first_order = (expm(-1/n_data * M * t) @
                     (w_0 - w_star) + w_star).reshape((-1, n_variables))
    w_second_order = (expm(-1/n_data * (M @ (jnp.eye(n_variables) + learning_rate/(2*n_data) * M)) * t) @
                      (w_0 - w_star) + w_star).reshape((-1, n_variables))

    # Gradient descent iterates; the scan emits the weights after each step.
    _, iterates = jax.lax.scan(
        update, (w_0, learning_rate), jnp.arange(0, max_iter))

    # Prepend the initial point so that row k corresponds to time t_k.
    w_gd = jnp.concatenate(
        (w_0.reshape((1, n_variables, 1)), iterates)).squeeze()
    first_order_error.append(max_error(w_gd, w_first_order))
    second_order_error.append(max_error(w_gd, w_second_order))

# Measured errors against the learning rate, together with reference lines of
# slope 1 and 2 anchored at the first learning rate in the list.
fig, ax = plt.subplots()
ax.plot(learning_rates, first_order_error,
        '*', label='Experimental first order')
ax.plot(learning_rates, second_order_error,
        '*', label='Experimental second order')
ax.plot(learning_rates, jnp.array(learning_rates) * (first_order_error[0]/learning_rates[0]),
        '--', label='Theoretical first order')
ax.plot(learning_rates, jnp.array(learning_rates) **
        2 * (second_order_error[0]/learning_rates[0]**2), '--', label='Theoretical second order')
ax.invert_xaxis()
# Raw string: '\e' is not a valid Python escape sequence and triggers a
# SyntaxWarning on recent interpreters.
ax.set_xlabel(r'$\eta$')
ax.set_ylabel('error')
ax.set_xscale('log')
ax.set_yscale('log')
ax.legend()
# Set the final size first so tight_layout computes the layout for the
# dimensions that are actually saved.
fig.set_figwidth(7)
fig.tight_layout()
fig.savefig('../seminar_talk/plots/linear_regression_error.pdf')
# Empirical convergence orders: slopes of the log-log fits (intercepts discarded).
first_order, _ = fit_convergance_line(
    error=first_order_error, learning_rates=learning_rates)
second_order, _ = fit_convergance_line(
    error=second_order_error, learning_rates=learning_rates)
tuple(slope.item() for slope in (first_order, second_order))
