In [2]:
import jax
import jax.numpy as jnp
import jax.random as jr
import plotly.express as px
import pandas as pd


key = jr.PRNGKey(4)
def generate_data(key, n_data):
    x_key, noise_key = jr.split(key)
    weight = 4.0
    bias = 1.0
    x = jr.uniform(x_key, (n_data, ))
    x = jnp.sort(x)
    noise = jr.normal(noise_key, x.shape)
    y = weight * x + bias + noise
    return x, y

n_data = 1000
x, y = generate_data(key, n_data)

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


In [3]:
key, subkey = jr.split(key)
weight, bias = jr.normal(subkey, (2,))
weight_0 = weight.copy()
bias_0 = bias.copy()
@jax.jit
def loss(weight, bias, x, y): return jnp.square(weight * x + bias - y).sum(0)

In [24]:
learning_rates = reversed([0.000005, 0.00001, 0.00005, 0.0001])
max_iter = 1000


def new_recording(data, max_iter, learning_rate):
    return pd.DataFrame(
        {
            "iterations": jnp.arange(0, max_iter + 1),
            "time": jnp.linspace(0, max_iter * learning_rate, max_iter + 1),
            "data": data,
            "type": "weights",
            "Learning rate": f"{learning_rate:.1E}",
        }
    )


df = pd.DataFrame(
    {
        "iterations": [],
        "time": [],
        "data": [],
        "type": [],
        "Learning rate": [],
    }
)
for learning_rate in learning_rates:
    weight = weight_0
    bias = bias_0
    weights = [weight]
    biases = [bias]
    for i in range(max_iter):
        dweight, dbias = jax.grad(loss, argnums=(0, 1))(weight, bias, x, y)
        weight += -learning_rate * dweight
        bias += -learning_rate * dbias
        weights.append(weight)
        biases.append(bias)
    weights = jnp.array(weights)
    biases = jnp.array(biases)
    df = pd.concat((df, new_recording(weights, max_iter, learning_rate)))
    df = pd.concat((df, new_recording(biases, max_iter, learning_rate)))
    
    print(f"Weight: {weight:.3f}, bias: {bias:.3f}")
    print(f"Loss: {loss(weight, bias, x, y)}")


Weight: 3.758, bias: 1.112
Loss: 985.7213745117188
Weight: 3.754, bias: 1.114
Loss: 985.7225341796875
Weight: 3.136, bias: 1.442
Loss: 1017.139892578125
Weight: 2.585, bias: 1.734
Loss: 1097.360107421875


In [31]:
fig = px.line(df, x="iterations", y="data", facet_col="type", color="Learning rate")
fig.show()
fig.write_html('../../jonathan-hellwig.github.io/assets/plotly/linear_model.html')

In [29]:
fig = px.line(
    df,
    x="time",
    y="data",
    facet_col="type",
    color="Learning rate",
)
fig.update_traces(opacity=0.8)
fig.show()
fig.write_html('../../jonathan-hellwig.github.io/assets/plotly/linear_model_scaled.html')
