# Straight Lines & JAX

First, let's make some _messy_ straight-line data and see if we can recover the slope and intercept of the original line.

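We'll use JAX, a numerical computing library developed at Google and widely used for neural-network research (including at DeepMind). Unlike `pytorch` and similar tools, it isn't a full framework; it is closer to a library plus an ecosystem of packages built on top of it. Because it is small and composable, it is easier to see "inside" it. At a very basic level, you can think of it as `numpy`, but differentiable and able to run on accelerators such as GPUs and TPUs.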
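To make the "`numpy`, but differentiable" idea concrete, here is a minimal sketch: `jax.numpy` functions look like their NumPy counterparts, and `jax.grad` turns a scalar-valued function into a new function that returns its gradient. The function `f` is just an illustrative example, not part of the line-fitting problem.

```python
import jax
import jax.numpy as jnp

# A NumPy-style function: the sum of squares of the input vector.
def f(x):
    return jnp.sum(x ** 2)

# jax.grad builds a new function computing df/dx, which here is 2 * x.
grad_f = jax.grad(f)

x = jnp.array([1.0, 2.0, 3.0])
print(f(x))       # 14.0
print(grad_f(x))  # [2. 4. 6.]
```

The same mechanism is what lets JAX fit models by gradient descent: write the loss in `jax.numpy`, then ask `jax.grad` for its derivative.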

```python
import jax  # Access to the library
import jax.numpy as jnp  # NumPy-like array functions in JAX
import matplotlib.pyplot as plt
```

Let's create some data that lies on a straight line, but with some random jitter added.

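```python
# JAX has no global random state: randomness comes from explicit PRNG keys.
# Split the starting key so the x-values and the jitter use independent keys.
rng = jax.random.PRNGKey(0)
rng, new_key = jax.random.split(rng)

# Straight line y = 3x + 2, plus Gaussian jitter with standard deviation 0.5
x = jax.random.normal(rng, (100,))
jitter = jax.random.normal(new_key, (100,))
y = 3 * x + 2 + 0.5 * jitter
```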

We have a slope of **3** and an intercept of **2**.
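The explicit key handling in the data-generation cell may look unusual if you're coming from NumPy. A small sketch of the behaviour it relies on: the same key always yields the same numbers, and `jax.random.split` produces fresh keys whose draws are independent of the original.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)

# Reusing a key reproduces exactly the same draws: there is no hidden global state.
a = jax.random.normal(key, (3,))
b = jax.random.normal(key, (3,))
print(jnp.array_equal(a, b))  # True

# Splitting gives a new key; draws from it differ from those of the original key.
key, subkey = jax.random.split(key)
c = jax.random.normal(subkey, (3,))
print(jnp.array_equal(a, c))  # False
```

This is why the data cell splits its key before drawing both `x` and `jitter`: reusing one key for both would make them identical rather than independent.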

To get comfortable with the data, let's look at the first five points.

```python
[f"({float(f_x):0.2f}, {float(f_y):0.2f})" for f_x, f_y in zip(x[:5], y[:5])]
```

Of course, with this many points, a long list of numbers is hard for our brains to make sense of. Our eyes, however, are excellent big-data sensors!

```python
plt.scatter(x, y)
# plt.plot(x, 3 * x + 2, color="red")
plt.xlabel("x")
plt.ylabel("y")
plt.show()
```

## "Exact" solution

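Ordinary least squares, which picks the line minimising the sum of squared vertical distances between the line and the points, has a closed-form solution. We can code up that derivation directly: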

```python
# Closed-form least-squares estimates: slope (beta_1) and intercept (beta_0)
n = len(x)
beta_1 = (n * jnp.sum(x * y) - jnp.sum(x) * jnp.sum(y)) / (
    n * jnp.sum(x**2) - jnp.sum(x) ** 2
)
beta_0 = (jnp.sum(y) - beta_1 * jnp.sum(x)) / n
```
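For reference, here is a sketch of where those two formulas come from. Write the sum of squared residuals $S$ as a function of the intercept $\beta_0$ and slope $\beta_1$, then set both partial derivatives to zero:

$$
S(\beta_0, \beta_1) = \sum_{i=1}^{n} \left(y_i - \beta_0 - \beta_1 x_i\right)^2
$$

$$
\frac{\partial S}{\partial \beta_0} = -2\sum_{i=1}^{n}\left(y_i - \beta_0 - \beta_1 x_i\right) = 0
\quad\Rightarrow\quad
\beta_0 = \frac{\sum_i y_i - \beta_1 \sum_i x_i}{n}
$$

$$
\frac{\partial S}{\partial \beta_1} = -2\sum_{i=1}^{n} x_i\left(y_i - \beta_0 - \beta_1 x_i\right) = 0
\quad\Rightarrow\quad
\beta_0 \sum_i x_i + \beta_1 \sum_i x_i^2 = \sum_i x_i y_i
$$

Substituting $\beta_0$ into the second condition and multiplying through by $n$ gives the slope used in the code as `beta_1`:

$$
\beta_1 = \frac{n\sum_i x_i y_i - \sum_i x_i \sum_i y_i}{n\sum_i x_i^2 - \left(\sum_i x_i\right)^2}
$$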

And the values of the fit:

```python
print(f"beta_0: {beta_0:.2f}")
print(f"beta_1: {beta_1:.2f}")
```
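As a sanity check, the same fit can be obtained from a general linear least-squares solver. The sketch below regenerates the same data (same keys), builds a design matrix with a column of ones for the intercept and a column of `x` for the slope, and solves it with `jnp.linalg.lstsq`; the names `A` and `coeffs` are illustrative.

```python
import jax
import jax.numpy as jnp

# Recreate the same noisy data as above (same key-splitting scheme).
rng = jax.random.PRNGKey(0)
rng, new_key = jax.random.split(rng)
x = jax.random.normal(rng, (100,))
y = 3 * x + 2 + 0.5 * jax.random.normal(new_key, (100,))

# Closed-form estimates, exactly as in the notebook.
n = len(x)
beta_1 = (n * jnp.sum(x * y) - jnp.sum(x) * jnp.sum(y)) / (
    n * jnp.sum(x**2) - jnp.sum(x) ** 2
)
beta_0 = (jnp.sum(y) - beta_1 * jnp.sum(x)) / n

# Design matrix: column of ones (intercept) and column of x (slope).
A = jnp.stack([jnp.ones_like(x), x], axis=1)
coeffs, residuals, rank, singular_values = jnp.linalg.lstsq(A, y)

print(f"lstsq:       beta_0={float(coeffs[0]):.2f}, beta_1={float(coeffs[1]):.2f}")
print(f"closed form: beta_0={float(beta_0):.2f}, beta_1={float(beta_1):.2f}")
```

The two should agree up to floating-point rounding. The matrix form has the advantage of extending unchanged to more than one input feature, whereas the closed-form formulas above are specific to a single `x`.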

Again, a plot is much easier for our eyes to judge than the raw numbers!

```python
plt.scatter(x, y, label="Data", color="black")
plt.plot(x, 3 * x + 2, color="green", label="Real Line")
plt.plot(x, beta_1 * x + beta_0, color="red", label="Fitted Line")
plt.xlabel("x")
plt.ylabel("y")
plt.legend()
plt.show()
```