In [None]:
import jax
import jax.numpy as jnp
from sklearn.datasets import make_regression
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

In [2]:
X, y = make_regression(n_samples=1000, n_features=20, noise=0.1)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

In [3]:
X_train

array([[-1.02077309, -0.48696616, -0.48421983, ..., -1.38741447,
        -1.3479442 ,  0.75929868],
       [ 0.54267285, -0.56201687, -0.84109614, ...,  0.42826231,
         1.18534808,  0.4762074 ],
       [ 0.56625935,  1.08510127, -0.01581252, ...,  0.87658043,
         0.49213668,  0.42274501],
       ...,
       [ 0.9860877 , -0.25821279,  2.02669309, ...,  1.085418  ,
        -0.70901641,  1.27596923],
       [ 0.74752392,  0.77974478,  1.69728358, ...,  1.03238153,
        -1.33953144, -0.78206895],
       [-0.00640283,  0.77753109, -0.4224415 , ...,  0.54716341,
        -1.571583  ,  0.42241783]])

In [33]:
params = {
    'w': jnp.zeros(X.shape[1:]),
    'b': 0.
}

def forward(params, X):
    return jnp.dot(X, params['w']) + params['b']

@jax.jit
def loss_fn(params, X, y):
    error = forward(params, X) - y
    return jnp.mean(jnp.square(error))
grad_fn = jax.grad(loss_fn)

for _ in range(100):
    loss = loss_fn(params, X_test, y_test)  # Monitor loss on test data
    print(loss)
    grads = grad_fn(params, X_train, y_train)  # Compute gradients on train data
    params['w'] -= 0.05 * grads['w']  # Update weights
    params['b'] -= 0.05 * grads['b']  # Update bias




29019.84
23508.777
19065.955
15480.141
12582.683
10238.759
8340.475
6801.378
5552.12
4537.0073
3711.2622
3038.841
2490.697
2043.3971
1678.0146
1379.2472
1134.7075
934.3578
770.05585
635.18884
524.3813
433.2588
358.2573
296.47156
245.52925
203.49213
168.77545
140.08128
116.34631
96.69847
80.421906
66.92818
55.7336
46.43998
38.719227
32.30098
26.96207
22.51808
18.816814
15.732281
13.160183
11.014235
9.222804
7.726504
6.476105
5.430629
4.556072
3.8241634
3.2113106
2.697968
2.2677352
1.9070276
1.6044858
1.3506162
1.1375304
0.95858437
0.8082756
0.6819602
0.5757848
0.4865
0.41141036
0.3482236
0.29504904
0.2502865
0.2125924
0.18084738
0.15410051
0.13156472
0.11257393
0.096566215
0.08306889
0.071688905
0.062092837
0.05400071
0.04717503
0.0414176
0.036560718
0.032465078
0.029010946
0.02609632
0.023637494
0.021564374
0.019816428
0.018341115
0.017099304
0.016050981
0.015168404
0.014424893
0.013798906
0.013271661
0.012829266
0.012456546
0.012143085
0.01188058
0.011660062
0.011475269
0.011320574
0.