In [235]:
import numpy as onp
import jax.numpy as jnp
from jax import grad, jit, vmap, value_and_grad
import jax

# Generate key which is used to generate random numbers
init_rng_key = jax.random.PRNGKey(1)




In [236]:
n_samples, n_dims = 100, 20

x = jnp.array(onp.random.randn(n_samples, n_dims))
w_true = jnp.array(onp.random.randn(n_dims))
y_true = jnp.dot(x, w_true)
print(x.shape)
print(w_true.shape)
print(y_true.shape)
w_init = jnp.array(onp.random.randn(n_dims))
print(w.shape)
true_params = {"w": w_true}
params = {"w": w_init}

print(params)

(100, 20)
(20,)
(100,)
(20,)
{'w': DeviceArray([-0.5754662 , -1.6014344 ,  0.87189704,  0.61361986,
              0.23964736, -0.24000783, -0.8138886 , -0.78609693,
             -0.8947132 , -0.9696329 , -0.60194707, -0.9117358 ,
              0.1608854 ,  1.2640678 , -0.8562055 ,  1.8879801 ,
             -1.0424665 , -1.8190886 ,  1.7076255 , -1.9316444 ],            dtype=float32)}


In [215]:
x

DeviceArray([[-1.4514358 , -0.8476368 , -1.491045  , ..., -0.5829066 ,
              -2.2127461 ,  1.0088211 ],
             [ 0.58676153, -0.8094    , -0.48540178, ..., -0.4163852 ,
              -0.44798017,  0.9589344 ],
             [-0.990291  ,  0.6244451 ,  0.7188759 , ...,  0.9558566 ,
               0.26782858, -0.37546408],
             ...,
             [ 0.42144024, -0.5205132 ,  2.552735  , ...,  0.95577615,
               0.48141402, -0.12665674],
             [-0.82951343, -1.1792121 , -0.30609933, ...,  1.6789333 ,
               1.1527575 , -0.27306867],
             [ 1.1216229 ,  0.70948994, -1.017506  , ...,  0.39141247,
               0.6852563 , -0.16375953]], dtype=float32)

In [216]:
w_true

DeviceArray([-0.27517602, -0.87972355,  0.64689726, -1.4476798 ,
             -0.08543438,  0.59927756, -1.2518536 , -0.42516187,
             -0.87735945,  1.1498739 , -0.9430346 ,  0.59459394,
             -1.12569   ,  2.0530133 ,  0.71003985,  0.41845223,
              0.19534698, -0.1552832 ,  1.123011  , -0.02014724],            dtype=float32)

In [217]:
y_true

DeviceArray([ -1.6046159 ,  -1.1347995 ,  -0.05299273,   3.5625432 ,
               0.16229987,   4.9081597 ,   4.66898   ,   3.6767163 ,
              -3.4984467 ,  -5.0213184 , -10.533674  ,   1.4393327 ,
               1.931407  ,   5.1674695 ,   1.4575765 ,  -5.28413   ,
              -3.3337026 ,   3.2233655 ,   1.3536702 ,   1.4192331 ,
              -1.2004356 ,   2.4223394 ,   3.0290937 ,   0.20612033,
              -1.38643   ,   3.3155084 ,  -2.349163  ,   6.3354745 ,
               1.3867663 ,  -6.871331  ,  -2.4555058 ,   1.4839106 ,
              -6.140254  ,   4.957026  ,  -1.5521696 ,   1.8523246 ,
               2.6613998 ,  -1.4244708 ,  -0.6872809 ,  -6.7345295 ,
               2.9102345 ,   3.2274504 ,   0.9045044 ,  -5.465003  ,
              -4.377287  ,   2.4935548 ,  -0.92376435,  -3.0080094 ,
               3.1115081 , -10.174513  ,   5.9762163 ,  -3.847435  ,
               0.42686108,  -0.703447  ,   0.5415129 ,   5.6658487 ,
               5.1659036 ,   3.678

In [230]:
def forward(x, w):
    return jnp.dot(x, w) 

def mse(w, x, y_true):
    diff = y_true - jnp.dot(x, w)  
    return jnp.mean(diff**2)


def full_pass(w, x, y_true, lr):
    dloss = jax.grad(mse)
    dw = dloss(w, x, y_true)

    return w - lr * dw
   
def opt(w, x, y_true, lr=0.1, iters=100):
    for i in range(iters):
        old_w = w
        w = full_pass(w, x, y_true, lr=lr)
        y_pred = forward(x, w)
        print("diff in outputs", ((y_pred-y_true) ** 2).mean())
        print("diff in weights", ((w - old_w) ** 2).mean())
    return params


In [231]:
new_w = opt(w, x, y_true)


diff in outputs 9.3014965
diff in weights 0.04971885
diff in outputs 5.329583
diff in weights 0.02337109
diff in outputs 3.328307
diff in weights 0.01161578
diff in outputs 2.2454681
diff in weights 0.006184096
diff in outputs 1.6115594
diff in weights 0.0035582788
diff in outputs 1.2104039
diff in weights 0.002215139
diff in outputs 0.9384528
diff in weights 0.0014808502
diff in outputs 0.7436202
diff in weights 0.0010494071
diff in outputs 0.59814537
diff in weights 0.000777281
diff in outputs 0.48624828
diff in weights 0.00059443846
diff in outputs 0.39834636
diff in weights 0.00046505002
diff in outputs 0.32824776
diff in weights 0.00036975677
diff in outputs 0.27172935
diff in weights 0.00029746274
diff in outputs 0.2257824
diff in weights 0.00024141264
diff in outputs 0.18818833
diff in weights 0.00019725792
diff in outputs 0.1572693
diff in weights 0.00016205404
diff in outputs 0.13173111
diff in weights 0.00013372747
diff in outputs 0.11056087
diff in weights 0.00011076716
diff

In [239]:
def mse(w, x, y_true):
    return jnp.mean((y_true - jnp.dot(x, w)) ** 2)

dloss = jax.grad(mse)
lr = 0.1
for i in range(100):
    dw = dloss(w, x, y_true)
    w -= lr * dw
    

In [240]:
w - w_true

DeviceArray([-5.9604645e-07, -7.1525574e-07,  2.0861626e-07,
              1.1920929e-07,  5.9604645e-08, -4.7683716e-07,
             -8.3446503e-07,  1.1920929e-07, -2.3841858e-07,
              1.7881393e-07, -5.2154064e-07,  5.9604645e-07,
              3.5762787e-07, -3.5762787e-07,  5.3644180e-07,
             -2.3841858e-07,  1.0617077e-07,  3.5762787e-07,
             -5.9604645e-07, -9.8347664e-07], dtype=float32)

In [247]:
w = jnp.array(onp.random.randn(n_dims))

y_true = forward(x, w_true)

def forward(x, w):
    return jnp.tanh(jnp.dot(x, w))

def mse(w, x, y_true):
    return jnp.mean((y_true - forward(x,w)) ** 2)

dloss = jax.grad(mse)
lr = 0.5
for i in range(100):
    dw = dloss(w, x, y_true)
    w -= lr * dw
    print( ((y_true - forward(x, w))**2).mean())

1.1080879
1.0532216
1.0174576
0.98665637
0.95628023
0.9267193
0.90058535
0.87823457
0.85860944
0.840466
0.82188416
0.8022505
0.78295106
0.7631966
0.7405787
0.7158583
0.6916823
0.6688055
0.64536095
0.6186625
0.59225816
0.57421744
0.5644643
0.55686396
0.5496638
0.54237026
0.5352684
0.5292953
0.5249361
0.5217982
0.519377
0.5173796
0.51565146
0.51410115
0.51266485
0.5112913
0.509933
0.5085388
0.50704736
0.5053768
0.5034105
0.5009799
0.49786186
0.4938654
0.48916587
0.48469406
0.48141545
0.47921497
0.4775651
0.47617444
0.47492582
0.4737678
0.47267407
0.4716295
0.47062537
0.4696567
0.46872163
0.4678195
0.46695092
0.46611664
0.46531713
0.4645526
0.46382248
0.46312553
0.46245995
0.4618235
0.46121338
0.46062663
0.46006015
0.45951056
0.45897454
0.4584485
0.45792884
0.45741174
0.4568934
0.45636964
0.45583564
0.45528668
0.45471704
0.4541202
0.45348886
0.45281434
0.4520863
0.45129243
0.45041803
0.4494452
0.44835225
0.44711342
0.44569778
0.44406986
0.44218895
0.44001034
0.4374849
0.43455496
0.4311412

In [248]:
jax.grad??