# Just The Facts...

Let's start with a simple line fit.

In [None]:
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt

Create some random data

In [None]:
# Draw x and the noise from *independent* keys. Calling jax.random.normal
# twice with the same key returns the same samples, so the "noise" would be
# identical to x and the data would secretly follow y = 4x + 2 with no noise
# at all, instead of y = 3x + 2 plus noise.
rng = jax.random.PRNGKey(0)
rng_x, rng_noise = jax.random.split(rng)
x = jax.random.normal(rng_x, (100,))
y = 3*x + 2 + 1.0*jax.random.normal(rng_noise, (100,))

Let's plot the data together with the true line $y = 3x + 2$ so we can see what we are doing.

In [None]:
# Scatter the noisy samples and overlay the true line y = 3x + 2 in red.
plt.scatter(x, y)
plt.plot(x, 3*x + 2, color='red')
plt.show()

In [None]:
# Residuals of the data about the true line y = 3x + 2. Since y was built as
# 3x + 2 + noise, this histogram shows the distribution of the added noise.
delta = y - (3*x + 2)
plt.hist(delta)
plt.show()

In [None]:
def network(params, x):
    """Evaluate the linear model ``slope * x + offset``.

    ``params`` is a two-element sequence ``(slope, offset)``; ``x`` may be a
    scalar or an array, and the result has the same shape as ``x``.
    """
    slope, offset = params
    prediction = slope*x + offset
    return prediction

In [None]:
def loss(params, x, y):
    """Return the mean squared error of ``network(params, x)`` against ``y``."""
    residuals = network(params, x) - y
    return jnp.mean(residuals**2)

In [None]:
# Gradient of ``loss`` with respect to its first argument, ``params``
# (jax.grad differentiates argument 0 by default; spelled out for clarity).
grad_loss = jax.grad(loss, argnums=0)

In [None]:
# Initial guess for (w, b): both parameters start at 1.0.
params = jnp.ones(2)

In [None]:
# Learning rate (gradient-descent step size) read by ``one_epoch``.
lr = 0.1

In [None]:
def one_epoch(params, x, y, i_epoch, learning_rate=None):
    """Take one full-batch gradient-descent step and print the new loss.

    Args:
        params: Parameter vector ``(w, b)`` as consumed by ``network``.
        x: Inputs passed to ``grad_loss`` / ``loss``.
        y: Targets passed to ``grad_loss`` / ``loss``.
        i_epoch: Epoch index; used only in the progress message.
        learning_rate: Optional step size. When ``None`` (the default) the
            notebook-level ``lr`` is read at call time, as before.

    Returns:
        The updated parameters as a new array; ``params`` is not modified.
    """
    step = lr if learning_rate is None else learning_rate
    g = grad_loss(params, x, y)
    # Build a new array instead of using ``-=``: if a caller passes a mutable
    # NumPy array, an in-place update could overwrite it and silently alter
    # any history list still holding a reference to it.
    params = params - step*g
    # The reported loss is evaluated *after* the update.
    print(f'Step {i_epoch}, loss {loss(params, x, y)}')
    return params

In [None]:
# Run 10 full-batch gradient-descent epochs, continuing from the current
# ``params``; each call prints the loss after its update.
for i_epoch in range(10):
    params = one_epoch(params, x, y, i_epoch)

In [None]:
# Show the parameters reached after training.
print(f"Final parameters: {params}")

In [None]:
# Restart from the initial guess and train for 50 epochs, recording the
# parameters after every step so their convergence can be plotted.
params = jnp.array([1.0, 1.0])
param_history = [params]  # entry 0 is the starting point, before any update
for i in range(50):
    params = one_epoch(params, x, y, i)
    param_history.append(params.copy())  # snapshot after epoch i

In [None]:
# Extract the first coordinate from param_history
first_coordinate = [param[0] for param in param_history]

# Plot the first coordinate as a function of the epoch number
plt.plot(range(len(first_coordinate)), first_coordinate)
plt.xlabel('Epoch Number')
plt.ylabel('Slope Coordinate')
plt.title('Slope Coordinate vs. Epoch Number')
plt.show()


In [None]:
# Extract the second coordinate from param_history
second_coordinate = [param[1] for param in param_history]

# Plot the second coordinate as a function of the epoch number
plt.plot(range(len(second_coordinate)), second_coordinate)
plt.xlabel('Epoch Number')
plt.ylabel('Offset')
plt.title('Offset vs. Epoch Number')
plt.show()
