In [45]:
import jax
import jax.numpy as jnp
from jax import random, vmap
from flax import linen as nn
import optax

In [46]:
def f(x):
    u = jnp.array([
        x[0]**6 * x[1]**4 * x[2]**9 * x[3]**2,
        x[0]**2 * x[1]**3 * x[2]**5 * x[3]**3,
        x[0]**5 * x[1]**7 * x[2]**7 * x[3]**6,
    ])
    return u

In [47]:
evaluation_point = jnp.array([1.0, 0.5, 1.5, 2.0])

In [48]:
f(evaluation_point)

Array([9.61084 , 7.59375 , 8.542969], dtype=float32)

In [49]:
jax.jacfwd(f)(evaluation_point)

Array([[ 57.66504 ,  76.88672 ,  57.66504 ,   9.61084 ],
       [ 15.1875  ,  45.5625  ,  25.3125  ,  11.390625],
       [ 42.714844, 119.60156 ,  39.867188,  25.628906]], dtype=float32)

In [50]:
multiplication_point = jnp.array([0.2, 0.3, 0.4, 0.8])
jax.jvp(f, (evaluation_point,), (multiplication_point,))

(Array([9.61084 , 7.59375 , 8.542969], dtype=float32),
 Array([65.353714, 35.943752, 80.87344 ], dtype=float32))

In [51]:
unit_vector = jnp.array([1., 0., 0., 0.])
jax.jvp(f, (evaluation_point,), (unit_vector,))

(Array([9.61084 , 7.59375 , 8.542969], dtype=float32),
 Array([57.66504 , 15.1875  , 42.714844], dtype=float32))

# Neural network

In [52]:
batch_size = 8
input_dim = 5
x = random.normal(key=random.PRNGKey(42), shape=(batch_size, input_dim))
y = random.normal(key=random.PRNGKey(42), shape=(batch_size, 1))

In [53]:
class JaxNet(nn.Module):
    def setup(self):
        self.W1 = nn.Dense(features=10)
        self.W2 = nn.Dense(features=1)
    
    def __call__(self, x):
        x = self.W1(x)
        x = nn.relu(x)
        x = self.W2(x)
        return x
    
model = JaxNet()
params = model.init(random.PRNGKey(42), x[0])

def mse_loss(params, model, x, y):
    pred = model.apply(params, x)
    return (pred - y)**2
    #return jnp.mean((pred - y)**2)

optimizer = optax.adam(learning_rate=3e-4)
opt_state = optimizer.init(params)

In [54]:
loss_jacobian = jax.jacfwd(mse_loss, argnums=0)(params, model, x, y)
loss_jacobian

{'params': {'W1': {'bias': Array([[[-0.66623175, -0.        , -0.        , -0.        ,
             0.36837232, -0.        , -0.        , -0.2865735 ,
            -0.        ,  0.92687577]],
   
          [[-0.        , -0.05249464, -0.        , -0.        ,
             0.02165239, -0.        , -0.02964111, -0.01684438,
             0.00816006, -0.        ]],
   
          [[ 0.        ,  0.        ,  0.        ,  0.        ,
            -0.263159  ,  0.        ,  0.36025238,  0.2047233 ,
            -0.0991758 ,  0.        ]],
   
          [[ 0.        ,  0.        ,  0.        ,  0.        ,
             0.        , -0.46606576,  0.        ,  0.22924037,
             0.        ,  0.        ]],
   
          [[-0.        , -0.        , -0.        , -0.        ,
            -0.        ,  0.4797673 , -0.        , -0.23597965,
            -0.        , -0.        ]],
   
          [[ 0.        ,  0.        ,  0.        ,  0.        ,
             0.        , -0.24109201,  0.        ,  

In [55]:
loss_jacobian["params"]["W1"]["kernel"].shape

(8, 1, 5, 10)

In [56]:
gradients = jax.tree_map(lambda x: jnp.sum(x, axis=(0,1)), loss_jacobian)
gradients

{'params': {'W1': {'bias': Array([ 2.6268678 , -0.05249464, -1.5192493 ,  0.005273  ,  0.12686571,
          -3.1072524 ,  0.33061126,  0.01315033, -0.09101573, -3.6545544 ],      dtype=float32),
   'kernel': Array([[ 2.60141230e+00,  7.21701840e-03, -1.21777153e+00,
            2.87886267e-03, -2.91166306e-01, -3.10612392e+00,
            4.27514464e-01,  6.18889034e-01, -1.17692739e-01,
           -3.61913991e+00],
          [-2.63472939e+00, -3.64475027e-02,  1.10690379e+00,
           -1.68364414e-03,  1.05206355e-01,  2.11527300e+00,
            3.41741890e-02, -9.02291387e-02, -9.40799434e-03,
            3.66549182e+00],
          [-2.34140301e+00, -3.72016914e-02,  1.09190500e+00,
           -1.56822964e-03, -4.84190494e-01,  1.88298559e+00,
            6.43614411e-01,  4.68558550e-01, -1.77184030e-01,
            3.25740957e+00],
          [ 2.15046787e+00,  2.38137301e-02, -1.21183228e+00,
            5.86586073e-03,  2.58109331e-01, -3.15916705e+00,
            7.16615515e-0

In [57]:
gradients["params"]["W1"]["kernel"].shape

(5, 10)

In [58]:
loss_jacobian = vmap(jax.jacfwd(mse_loss, argnums=0), in_axes=(None, None, 0, 0), out_axes=0)(params, model, x, y)
loss_jacobian

{'params': {'W1': {'bias': Array([[[-0.66623175, -0.        , -0.        , -0.        ,
             0.36837232, -0.        , -0.        , -0.2865735 ,
            -0.        ,  0.92687577]],
   
          [[-0.        , -0.05249464, -0.        , -0.        ,
             0.02165239, -0.        , -0.02964111, -0.01684438,
             0.00816006, -0.        ]],
   
          [[ 0.        ,  0.        ,  0.        ,  0.        ,
            -0.263159  ,  0.        ,  0.36025238,  0.2047233 ,
            -0.0991758 ,  0.        ]],
   
          [[ 0.        ,  0.        ,  0.        ,  0.        ,
             0.        , -0.46606576,  0.        ,  0.22924037,
             0.        ,  0.        ]],
   
          [[-0.        , -0.        , -0.        , -0.        ,
            -0.        ,  0.4797673 , -0.        , -0.23597965,
            -0.        , -0.        ]],
   
          [[ 0.        ,  0.        ,  0.        ,  0.        ,
             0.        , -0.24109201,  0.        ,  

In [68]:
loss_jacobian["params"]["W2"]["kernel"].shape

(8, 1, 10, 1)

In [69]:
gradients = jax.tree_map(lambda x: jnp.sum(x, axis=(0, 1)), loss_jacobian)
gradients

{'params': {'W1': {'bias': Array([ 2.6268678 , -0.05249464, -1.5192493 ,  0.005273  ,  0.12686571,
          -3.1072524 ,  0.33061126,  0.01315033, -0.09101573, -3.6545544 ],      dtype=float32),
   'kernel': Array([[ 2.60141230e+00,  7.21701840e-03, -1.21777153e+00,
            2.87886267e-03, -2.91166306e-01, -3.10612392e+00,
            4.27514464e-01,  6.18889034e-01, -1.17692739e-01,
           -3.61913991e+00],
          [-2.63472939e+00, -3.64475027e-02,  1.10690379e+00,
           -1.68364414e-03,  1.05206355e-01,  2.11527300e+00,
            3.41741890e-02, -9.02291387e-02, -9.40799434e-03,
            3.66549182e+00],
          [-2.34140301e+00, -3.72016914e-02,  1.09190500e+00,
           -1.56822964e-03, -4.84190494e-01,  1.88298559e+00,
            6.43614411e-01,  4.68558550e-01, -1.77184030e-01,
            3.25740957e+00],
          [ 2.15046787e+00,  2.38137301e-02, -1.21183228e+00,
            5.86586073e-03,  2.58109331e-01, -3.15916705e+00,
            7.16615515e-0

In [70]:
updates, opt_state = optimizer.update(loss_jacobian, opt_state)

AttributeError: 'tuple' object has no attribute 'update'

In [62]:
params = optax.apply_updates(params, updates)
params

{'params': {'W1': {'bias': Array([[[ 0.00030001,  0.        ,  0.        ,  0.        ,
            -0.00030001,  0.        ,  0.        ,  0.00030001,
             0.        , -0.00030001]],
   
          [[ 0.        ,  0.00030001,  0.        ,  0.        ,
            -0.00030001,  0.        ,  0.00030001,  0.00030001,
            -0.00030001,  0.        ]],
   
          [[ 0.        ,  0.        ,  0.        ,  0.        ,
             0.00030001,  0.        , -0.00030001, -0.00030001,
             0.00030001,  0.        ]],
   
          [[ 0.        ,  0.        ,  0.        ,  0.        ,
             0.        ,  0.00030001,  0.        , -0.00030001,
             0.        ,  0.        ]],
   
          [[ 0.        ,  0.        ,  0.        ,  0.        ,
             0.        , -0.00030001,  0.        ,  0.00030001,
             0.        ,  0.        ]],
   
          [[ 0.        ,  0.        ,  0.        ,  0.        ,
             0.        ,  0.00030001,  0.        , -

In [71]:
gradients

{'params': {'W1': {'bias': Array([ 2.6268678 , -0.05249464, -1.5192493 ,  0.005273  ,  0.12686571,
          -3.1072524 ,  0.33061126,  0.01315033, -0.09101573, -3.6545544 ],      dtype=float32),
   'kernel': Array([[ 2.60141230e+00,  7.21701840e-03, -1.21777153e+00,
            2.87886267e-03, -2.91166306e-01, -3.10612392e+00,
            4.27514464e-01,  6.18889034e-01, -1.17692739e-01,
           -3.61913991e+00],
          [-2.63472939e+00, -3.64475027e-02,  1.10690379e+00,
           -1.68364414e-03,  1.05206355e-01,  2.11527300e+00,
            3.41741890e-02, -9.02291387e-02, -9.40799434e-03,
            3.66549182e+00],
          [-2.34140301e+00, -3.72016914e-02,  1.09190500e+00,
           -1.56822964e-03, -4.84190494e-01,  1.88298559e+00,
            6.43614411e-01,  4.68558550e-01, -1.77184030e-01,
            3.25740957e+00],
          [ 2.15046787e+00,  2.38137301e-02, -1.21183228e+00,
            5.86586073e-03,  2.58109331e-01, -3.15916705e+00,
            7.16615515e-0

In [64]:
model2 = JaxNet()
params2 = model.init(random.PRNGKey(42), x[0])

optimizer_def = optax.adam(learning_rate=1e-3)
optimizer = optimizer_def.init(params2)

def loss_fn(params2, x, y):
    preds = model.apply(params2, x)
    loss = jnp.mean((preds - y)**2)
    return loss

grad_fn = jax.jit(jax.grad(loss_fn))

grads = grad_fn(params2, x, y)

In [65]:
grads

{'params': {'W1': {'bias': Array([ 0.32835847, -0.00656183, -0.18990617,  0.00065913,  0.01585821,
          -0.38840654,  0.04132641,  0.00164379, -0.01137697, -0.4568193 ],      dtype=float32),
   'kernel': Array([[ 3.2517651e-01,  9.0212730e-04, -1.5222144e-01,  3.5985786e-04,
           -3.6395788e-02, -3.8826552e-01,  5.3439308e-02,  7.7361137e-02,
           -1.4711591e-02, -4.5239249e-01],
          [-3.2934120e-01, -4.5559378e-03,  1.3836296e-01, -2.1045552e-04,
            1.3150794e-02,  2.6440915e-01,  4.2717736e-03, -1.1278642e-02,
           -1.1759993e-03,  4.5818645e-01],
          [-2.9267538e-01, -4.6502114e-03,  1.3648811e-01, -1.9602870e-04,
           -6.0523812e-02,  2.3537321e-01,  8.0451801e-02,  5.8569815e-02,
           -2.2148004e-02,  4.0717623e-01],
          [ 2.6880848e-01,  2.9767163e-03, -1.5147904e-01,  7.3323259e-04,
            3.2263666e-02, -3.9489585e-01,  8.9576939e-04,  2.7901260e-02,
           -2.4660115e-04, -3.7397209e-01],
          [-6.3180

In [66]:
updates, optimizer_state = optimizer.update(grads, optimizer.state)
updates

AttributeError: 'tuple' object has no attribute 'update'

In [None]:
gradients["params"]["W2"]["kernel"] == params2["params"]["W2"]["kernel"]