In [1]:
import jax
from jax import numpy as jnp, random

import numpy as np # We import the standard NumPy library 

In [9]:
# Root key for JAX's explicit (functional) PRNG, seeded with 0.
key = jax.random.PRNGKey(0)

# Create the predict function from a set of parameters
def make_predict_pytree(params):
    def predict(x):
        return jnp.dot(params['W'],x)+params['b']
    return predict

# Create the loss from the data points set
def make_mse_pytree(x_batched,y_batched):
    def mse(params):
        # Define the squared loss for a single pair (x,y)
        def squared_error(x,y):
            y_pred = make_predict_pytree(params)(x)
            return jnp.inner(y-y_pred,y-y_pred)/2.0
        # We vectorize the previous to compute the average of the loss on all samples.
        return jnp.mean(jax.vmap(squared_error)(x_batched,y_batched), axis=0)
    return jax.jit(mse) # And finally we jit the result.

In [10]:
# Generate MSE for our samples
mse_pytree = make_mse_pytree(x_samples,y_samples)

# Initialize estimated W and b with zeros.
params = {'W': jnp.zeros_like(W), 'b': jnp.zeros_like(b)}

jax.grad(mse_pytree)(params)

{'W': DeviceArray([[-1.9287353e+00,  4.2963773e-01,  7.1613431e-01,
                2.1056123e+00,  5.0405198e-01, -2.4983377e+00,
               -6.3854122e-01, -2.2620230e+00, -1.3365206e+00,
               -2.0426056e-01],
              [ 1.1999468e+00, -9.4563615e-01, -1.0878406e+00,
               -7.0340687e-01,  3.3224657e-01,  1.7538793e+00,
               -7.1916497e-01,  1.0927429e+00, -1.4491038e+00,
                5.9715652e-01],
              [-1.4826512e+00, -7.6116550e-01,  2.2319783e-01,
               -3.0392045e-01,  3.0397046e+00, -3.8419533e-01,
               -1.8290077e+00, -2.3353386e+00, -1.1087129e+00,
               -7.7454048e-01],
              [ 8.2374370e-01, -9.9650651e-01, -7.6030153e-01,
                6.3919228e-01, -6.0864404e-02, -1.0859709e+00,
                1.2923390e+00, -4.9342966e-01, -1.4719218e-03,
                1.2977620e+00],
              [-4.5656392e-01, -1.3063020e-01, -3.9179036e-01,
                2.1743817e+00, -5.3948894e-02,  

In [7]:
# Fit W and b by plain gradient descent on the two-argument loss mse(W, b).
# The estimates start at zero, with the same shapes/dtypes as the true W and b.
What = jnp.zeros_like(W)
bhat = jnp.zeros_like(b)

alpha = 0.3      # Gradient step size
num_steps = 101  # Number of gradient updates (steps 0..100)
log_every = 5    # Print the loss every `log_every` steps

# Differentiate w.r.t. both arguments at once. One forward/backward pass per
# step yields (dL/dW, dL/db); the old code built and ran two separate
# gradient functions on every iteration.
grad_mse = jax.grad(mse, argnums=(0, 1))

print('Loss for "true" W,b: ', mse(W,b))
for i in range(num_steps):
    # Simultaneous update: both gradients are taken at the current (What, bhat).
    grad_W, grad_b = grad_mse(What, bhat)
    What, bhat = What - alpha * grad_W, bhat - alpha * grad_b
    if i % log_every == 0:
        print("Loss step {}: ".format(i), mse(What, bhat))

Loss for "true" W,b:  0.023639796
Loss step 0:  11.096582
Loss step 5:  1.1743387
Loss step 10:  0.3287934
Loss step 15:  0.1398177
Loss step 20:  0.07359567
Loss step 25:  0.04415302
Loss step 30:  0.029408723
Loss step 35:  0.021554684
Loss step 40:  0.017227953
Loss step 45:  0.014798909
Loss step 50:  0.013420274
Loss step 55:  0.012632738
Loss step 60:  0.012181121
Loss step 65:  0.011921484
Loss step 70:  0.011772007
Loss step 75:  0.01168586
Loss step 80:  0.011636183
Loss step 85:  0.011607513
Loss step 90:  0.011590973
Loss step 95:  0.01158142
Loss step 100:  0.011575905
