In [75]:
import jax
import jax.numpy as jnp
from jax import random, vmap
from flax import linen as nn
import optax

# Neural network

In [76]:
batch_size = 8
input_dim = 5
x = random.normal(key=random.PRNGKey(42), shape=(batch_size, input_dim))
y = random.normal(key=random.PRNGKey(42), shape=(batch_size, 1))

In [77]:
print(x.shape, y.shape)

(8, 5) (8, 1)


In [78]:
class JaxNet(nn.Module):
    def setup(self):
        self.W1 = nn.Dense(features=10)
        self.W2 = nn.Dense(features=1)
    
    def __call__(self, x):
        x = self.W1(x)
        x = nn.relu(x)
        x = self.W2(x)
        return x
    
def mse_loss(params, model, x, y):
    pred = model.apply(params, x)
    return (pred - y) ** 2


In [79]:
model = JaxNet()
params = model.init(random.PRNGKey(42), x[0])

optimizer = optax.adam(learning_rate=3e-4)
opt_state = optimizer.init(params)

# Using jvp and loops

In [80]:
flat_params, unravel_fn = jax.flatten_util.ravel_pytree(params)
jacobian_matrix = jnp.zeros((len(flat_params), 1))

In [81]:
evaluation_point = jnp.array([1.0, 0.5, 1.5, 2.0, 1.0])
evaluation_target = jnp.array([10])
basis_vector = jnp.zeros((len(flat_params)))
basis_vector = basis_vector.at[0].set(1.0)

primal, tangent = jax.jvp(lambda par: mse_loss(unravel_fn(par), model, evaluation_point, evaluation_target), (flat_params,), (basis_vector,))
print(primal, tangent)
jacobian_matrix = jacobian_matrix.at[0].set(tangent)

[114.70784] [-8.955848]


In [82]:
flat_params.shape

(71,)

In [83]:
unravel_fn(jacobian_matrix)

{'params': {'W1': {'bias': Array([-8.955848,  0.      ,  0.      ,  0.      ,  0.      ,  0.      ,
           0.      ,  0.      ,  0.      ,  0.      ], dtype=float32),
   'kernel': Array([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]], dtype=float32)},
  'W2': {'bias': Array([0.], dtype=float32),
   'kernel': Array([[0.],
          [0.],
          [0.],
          [0.],
          [0.],
          [0.],
          [0.],
          [0.],
          [0.],
          [0.]], dtype=float32)}}}

In [84]:
def loop_jvp(flat_params, model, x_point, y_point):
    basis_vector = jnp.zeros((len(flat_params)))
    jacobian_matrix = jnp.zeros((len(flat_params), 1))
    for i in range(len(flat_params)):
        e_i = basis_vector.at[i].set(1.0)
        primal, tangent = jax.jvp(lambda par: mse_loss(unravel_fn(par), model, x_point, y_point), (flat_params,), (e_i,))
        jacobian_matrix = jacobian_matrix.at[i].set(tangent)
    return jacobian_matrix

In [85]:
jacobian_sum = jnp.zeros((len(flat_params), 1))
l = []
for i in range(batch_size):
    jacobian = loop_jvp(flat_params, model, x[i], y[i])
    jacobian_sum += jacobian
    l.append(jacobian)
    #jacobians = vmap(loop_jvp, in_axes=(None, None, 0, 0), out_axes=0)(flat_params, model, x, y)
# jacobian_matrix /= len(flat_params)

In [86]:
jnp.sum(jnp.array(l), 0)

Array([[ 2.62686777e+00],
       [-5.24946414e-02],
       [-1.51924932e+00],
       [ 5.27300173e-03],
       [ 1.26865715e-01],
       [-3.10725236e+00],
       [ 3.30611318e-01],
       [ 1.31503642e-02],
       [-9.10157487e-02],
       [-3.65455437e+00],
       [ 2.60141230e+00],
       [ 7.21701840e-03],
       [-1.21777153e+00],
       [ 2.87886267e-03],
       [-2.91166335e-01],
       [-3.10612392e+00],
       [ 4.27514523e-01],
       [ 6.18889093e-01],
       [-1.17692746e-01],
       [-3.61913991e+00],
       [-2.63472939e+00],
       [-3.64475027e-02],
       [ 1.10690379e+00],
       [-1.68364414e-03],
       [ 1.05206370e-01],
       [ 2.11527300e+00],
       [ 3.41741964e-02],
       [-9.02291387e-02],
       [-9.40799620e-03],
       [ 3.66549182e+00],
       [-2.34140301e+00],
       [-3.72016914e-02],
       [ 1.09190500e+00],
       [-1.56822964e-03],
       [-4.84190553e-01],
       [ 1.88298559e+00],
       [ 6.43614471e-01],
       [ 4.68558580e-01],
       [-1.7

In [87]:
unravel_fn(jacobian_sum)

{'params': {'W1': {'bias': Array([ 2.6268678 , -0.05249464, -1.5192493 ,  0.005273  ,  0.12686571,
          -3.1072524 ,  0.33061132,  0.01315036, -0.09101575, -3.6545544 ],      dtype=float32),
   'kernel': Array([[ 2.60141230e+00,  7.21701840e-03, -1.21777153e+00,
            2.87886267e-03, -2.91166335e-01, -3.10612392e+00,
            4.27514523e-01,  6.18889093e-01, -1.17692746e-01,
           -3.61913991e+00],
          [-2.63472939e+00, -3.64475027e-02,  1.10690379e+00,
           -1.68364414e-03,  1.05206370e-01,  2.11527300e+00,
            3.41741964e-02, -9.02291387e-02, -9.40799620e-03,
            3.66549182e+00],
          [-2.34140301e+00, -3.72016914e-02,  1.09190500e+00,
           -1.56822964e-03, -4.84190553e-01,  1.88298559e+00,
            6.43614471e-01,  4.68558580e-01, -1.77184060e-01,
            3.25740957e+00],
          [ 2.15046787e+00,  2.38137301e-02, -1.21183228e+00,
            5.86586073e-03,  2.58109331e-01, -3.15916705e+00,
            7.16615422e-0

In [88]:
def scan_batch(carry, xy):
    xi, yi = xy
    jacobian = loop_jvp(flat_params, model, xi, yi)
    carry = carry + jacobian
    return carry, None

jacobian_init = jnp.zeros((len(flat_params), 1))
#xs = jnp.concatenate((x, y), axis=1)
xs = (x, y)
jacobian_matrix, _ = jax.lax.scan(scan_batch, jacobian_init, xs)

In [89]:
jacobian_matrix

Array([[ 2.62686777e+00],
       [-5.24946414e-02],
       [-1.51924932e+00],
       [ 5.27300173e-03],
       [ 1.26865715e-01],
       [-3.10725236e+00],
       [ 3.30611318e-01],
       [ 1.31503642e-02],
       [-9.10157487e-02],
       [-3.65455437e+00],
       [ 2.60141230e+00],
       [ 7.21701840e-03],
       [-1.21777153e+00],
       [ 2.87886267e-03],
       [-2.91166335e-01],
       [-3.10612392e+00],
       [ 4.27514523e-01],
       [ 6.18889093e-01],
       [-1.17692746e-01],
       [-3.61913991e+00],
       [-2.63472939e+00],
       [-3.64475027e-02],
       [ 1.10690379e+00],
       [-1.68364414e-03],
       [ 1.05206370e-01],
       [ 2.11527300e+00],
       [ 3.41741964e-02],
       [-9.02291387e-02],
       [-9.40799620e-03],
       [ 3.66549182e+00],
       [-2.34140301e+00],
       [-3.72016914e-02],
       [ 1.09190500e+00],
       [-1.56822964e-03],
       [-4.84190553e-01],
       [ 1.88298559e+00],
       [ 6.43614471e-01],
       [ 4.68558580e-01],
       [-1.7

In [90]:
unravel_fn(jacobian_matrix)

{'params': {'W1': {'bias': Array([ 2.6268678 , -0.05249464, -1.5192493 ,  0.005273  ,  0.12686571,
          -3.1072524 ,  0.33061132,  0.01315036, -0.09101575, -3.6545544 ],      dtype=float32),
   'kernel': Array([[ 2.60141230e+00,  7.21701840e-03, -1.21777153e+00,
            2.87886267e-03, -2.91166335e-01, -3.10612392e+00,
            4.27514523e-01,  6.18889093e-01, -1.17692746e-01,
           -3.61913991e+00],
          [-2.63472939e+00, -3.64475027e-02,  1.10690379e+00,
           -1.68364414e-03,  1.05206370e-01,  2.11527300e+00,
            3.41741964e-02, -9.02291387e-02, -9.40799620e-03,
            3.66549182e+00],
          [-2.34140301e+00, -3.72016914e-02,  1.09190500e+00,
           -1.56822964e-03, -4.84190553e-01,  1.88298559e+00,
            6.43614471e-01,  4.68558580e-01, -1.77184060e-01,
            3.25740957e+00],
          [ 2.15046787e+00,  2.38137301e-02, -1.21183228e+00,
            5.86586073e-03,  2.58109331e-01, -3.15916705e+00,
            7.16615422e-0

In [113]:
def scan_jvp_old(flat_params, model, x_point, y_point):
    basis_vector = jnp.zeros((len(flat_params)))
    jacobian_matrix = jnp.zeros((len(flat_params), 1))
    for i in range(len(flat_params)):
        e_i = basis_vector.at[i].set(1.0)
        primal, tangent = jax.jvp(lambda par: mse_loss(unravel_fn(par), model, x_point, y_point), (flat_params,), (e_i,))
        jacobian_matrix = jacobian_matrix.at[i].set(tangent)
    return jacobian_matrix

def scan_jvp(carry, xs):
    jacobian_sum, x_point, y_point  = carry
    param_idx = xs.squeeze()
    basis_vector = jnp.zeros((len(flat_params)))
    e_i = basis_vector.at[param_idx].set(1.0)
    primal, tangent = jax.jvp(lambda par: mse_loss(unravel_fn(par), model, x_point, y_point), (flat_params,), (e_i,))
    #jacobian_sum += tangent * e_i[:, None]
    jacobian_sum.at[param_idx].set(tangent)
    return (jacobian_sum, x_point, y_point), None

def scan_foo(carry, xs):
    jacobian_sum, param_idxes = carry
    x_point, y_point = xs
    (jacobian, x_point, y_point), _ = jax.lax.scan(scan_jvp, (jacobian_sum, x_point, y_point), param_idxes)
    jacobian_sum += jacobian
    return (jacobian_sum, param_index), None

param_index = jnp.arange(flat_params.shape[0], dtype=int)[:,None]
jacobian_init = jnp.zeros((len(flat_params), 1))

(jacobian_matrix, _), _ = jax.lax.scan(scan_foo, (jacobian_init, param_index), (x, y))

In [114]:
unravel_fn(jacobian_matrix)

{'params': {'W1': {'bias': Array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32),
   'kernel': Array([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]], dtype=float32)},
  'W2': {'bias': Array([0.], dtype=float32),
   'kernel': Array([[0.],
          [0.],
          [0.],
          [0.],
          [0.],
          [0.],
          [0.],
          [0.],
          [0.],
          [0.]], dtype=float32)}}}

In [115]:
jacobian_sum = jnp.zeros((len(flat_params), 1))
for i in range(batch_size):
    param_index = jnp.arange(flat_params.shape[0], dtype=int)
    jacobian_init = jnp.zeros((len(flat_params), 1))
    (jacobian, x_point, y_point), _ = jax.lax.scan(scan_jvp, (jacobian_init, x[i], y[i]), param_index)
    jacobian_sum += jacobian

In [118]:
x_point

Array([ 0.5459628 , -0.3192952 , -0.2974074 ,  1.112433  , -0.43547055],      dtype=float32)

In [117]:
unravel_fn(jacobian)

{'params': {'W1': {'bias': Array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32),
   'kernel': Array([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]], dtype=float32)},
  'W2': {'bias': Array([0.], dtype=float32),
   'kernel': Array([[0.],
          [0.],
          [0.],
          [0.],
          [0.],
          [0.],
          [0.],
          [0.],
          [0.],
          [0.]], dtype=float32)}}}

In [95]:
jacobian_matrix[param_index[0]]

Array([0.], dtype=float32)

In [96]:
unravel_fn(jacobian_matrix)

{'params': {'W1': {'bias': Array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32),
   'kernel': Array([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]], dtype=float32)},
  'W2': {'bias': Array([0.], dtype=float32),
   'kernel': Array([[0.],
          [0.],
          [0.],
          [0.],
          [0.],
          [0.],
          [0.],
          [0.],
          [0.],
          [0.]], dtype=float32)}}}

In [97]:
unravel_fn(jacobian_matrix)

{'params': {'W1': {'bias': Array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32),
   'kernel': Array([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]], dtype=float32)},
  'W2': {'bias': Array([0.], dtype=float32),
   'kernel': Array([[0.],
          [0.],
          [0.],
          [0.],
          [0.],
          [0.],
          [0.],
          [0.],
          [0.],
          [0.]], dtype=float32)}}}

# Using jacfwd

In [98]:
loss_jacobian = jax.jacfwd(mse_loss, argnums=0)(params, model, evaluation_point, evaluation_target)
loss_jacobian

{'params': {'W1': {'bias': Array([[-8.955848  , -0.        , -0.        , -0.02139474,  4.95186   ,
            7.832014  , -6.778865  , -3.8522758 , -0.        , 12.459565  ]],      dtype=float32),
   'kernel': Array([[[-8.9558477e+00, -0.0000000e+00, -0.0000000e+00, -2.1394743e-02,
             4.9518600e+00,  7.8320141e+00, -6.7788649e+00, -3.8522758e+00,
            -0.0000000e+00,  1.2459565e+01],
           [-4.4779239e+00, -0.0000000e+00, -0.0000000e+00, -1.0697371e-02,
             2.4759300e+00,  3.9160070e+00, -3.3894324e+00, -1.9261379e+00,
            -0.0000000e+00,  6.2297826e+00],
           [-1.3433770e+01, -0.0000000e+00, -0.0000000e+00, -3.2092113e-02,
             7.4277902e+00,  1.1748021e+01, -1.0168297e+01, -5.7784138e+00,
            -0.0000000e+00,  1.8689348e+01],
           [-1.7911695e+01, -0.0000000e+00, -0.0000000e+00, -4.2789485e-02,
             9.9037199e+00,  1.5664028e+01, -1.3557730e+01, -7.7045517e+00,
            -0.0000000e+00,  2.4919130e+01],
   

In [99]:
loss_jacobian["params"]["W1"]["kernel"].shape

(1, 5, 10)

In [100]:
gradients = jax.tree_map(lambda x: jnp.mean(x, axis=(0,1)), loss_jacobian)
gradients

{'params': {'W1': {'bias': Array(0.5635056, dtype=float32),
   'kernel': Array([-10.747017  ,   0.        ,   0.        ,  -0.02567369,
            5.942232  ,   9.398417  ,  -8.134639  ,  -4.622731  ,
            0.        ,  14.951479  ], dtype=float32)},
  'W2': {'bias': Array(-21.42035, dtype=float32),
   'kernel': Array([-8.777163], dtype=float32)}}}

In [101]:
gradients["params"]["W1"]["kernel"].shape

(10,)

In [102]:
loss_jacobian = vmap(jax.jacfwd(mse_loss, argnums=0), in_axes=(None, None, 0, 0), out_axes=0)(params, model, x, y)
loss_jacobian

{'params': {'W1': {'bias': Array([[[-0.66623175, -0.        , -0.        , -0.        ,
             0.36837232, -0.        , -0.        , -0.2865735 ,
            -0.        ,  0.92687577]],
   
          [[-0.        , -0.05249464, -0.        , -0.        ,
             0.02165239, -0.        , -0.02964111, -0.01684438,
             0.00816006, -0.        ]],
   
          [[ 0.        ,  0.        ,  0.        ,  0.        ,
            -0.263159  ,  0.        ,  0.36025238,  0.2047233 ,
            -0.0991758 ,  0.        ]],
   
          [[ 0.        ,  0.        ,  0.        ,  0.        ,
             0.        , -0.46606576,  0.        ,  0.22924037,
             0.        ,  0.        ]],
   
          [[-0.        , -0.        , -0.        , -0.        ,
            -0.        ,  0.4797673 , -0.        , -0.23597965,
            -0.        , -0.        ]],
   
          [[ 0.        ,  0.        ,  0.        ,  0.        ,
             0.        , -0.24109201,  0.        ,  

In [103]:
loss_jacobian["params"]["W2"]["kernel"].mean((0,1,2))

Array([0.18286781], dtype=float32)

In [104]:
gradients = jax.tree_map(lambda x: jnp.sum(x, axis=(0, 1,2)), loss_jacobian)
gradients

{'params': {'W1': {'bias': Array(-5.3217983, dtype=float32),
   'kernel': Array([-0.72969264, -0.05352993, -0.14133751,  0.00319661, -0.39160612,
          -1.7406905 ,  1.3203003 ,  1.0290507 , -0.36347252,  1.0151639 ],      dtype=float32)},
  'W2': {'bias': Array(7.9494667, dtype=float32),
   'kernel': Array([14.629424], dtype=float32)}}}

In [105]:
updates, opt_state = optimizer.update(loss_jacobian, opt_state)

In [106]:
params = optax.apply_updates(params, updates)
params

{'params': {'W1': {'bias': Array([[[ 0.00030001,  0.        ,  0.        ,  0.        ,
            -0.00030001,  0.        ,  0.        ,  0.00030001,
             0.        , -0.00030001]],
   
          [[ 0.        ,  0.00030001,  0.        ,  0.        ,
            -0.00030001,  0.        ,  0.00030001,  0.00030001,
            -0.00030001,  0.        ]],
   
          [[ 0.        ,  0.        ,  0.        ,  0.        ,
             0.00030001,  0.        , -0.00030001, -0.00030001,
             0.00030001,  0.        ]],
   
          [[ 0.        ,  0.        ,  0.        ,  0.        ,
             0.        ,  0.00030001,  0.        , -0.00030001,
             0.        ,  0.        ]],
   
          [[ 0.        ,  0.        ,  0.        ,  0.        ,
             0.        , -0.00030001,  0.        ,  0.00030001,
             0.        ,  0.        ]],
   
          [[ 0.        ,  0.        ,  0.        ,  0.        ,
             0.        ,  0.00030001,  0.        , -

In [107]:
gradients

{'params': {'W1': {'bias': Array(-5.3217983, dtype=float32),
   'kernel': Array([-0.72969264, -0.05352993, -0.14133751,  0.00319661, -0.39160612,
          -1.7406905 ,  1.3203003 ,  1.0290507 , -0.36347252,  1.0151639 ],      dtype=float32)},
  'W2': {'bias': Array(7.9494667, dtype=float32),
   'kernel': Array([14.629424], dtype=float32)}}}

In [108]:
model2 = JaxNet()
params2 = model.init(random.PRNGKey(42), x[0])

optimizer_def = optax.adam(learning_rate=1e-3)
optimizer = optimizer_def.init(params2)

def loss_fn(params2, x, y):
    preds = model.apply(params2, x)
    loss = jnp.mean((preds - y)**2)
    return loss

grad_fn = jax.jit(jax.grad(loss_fn))

grads = grad_fn(params2, x, y)

In [109]:
grads

{'params': {'W1': {'bias': Array([ 0.32835847, -0.00656183, -0.18990617,  0.00065913,  0.01585821,
          -0.38840654,  0.04132641,  0.00164379, -0.01137697, -0.4568193 ],      dtype=float32),
   'kernel': Array([[ 3.2517651e-01,  9.0212730e-04, -1.5222144e-01,  3.5985786e-04,
           -3.6395788e-02, -3.8826552e-01,  5.3439308e-02,  7.7361137e-02,
           -1.4711591e-02, -4.5239249e-01],
          [-3.2934120e-01, -4.5559378e-03,  1.3836296e-01, -2.1045552e-04,
            1.3150794e-02,  2.6440915e-01,  4.2717736e-03, -1.1278642e-02,
           -1.1759993e-03,  4.5818645e-01],
          [-2.9267538e-01, -4.6502114e-03,  1.3648811e-01, -1.9602870e-04,
           -6.0523812e-02,  2.3537321e-01,  8.0451801e-02,  5.8569815e-02,
           -2.2148004e-02,  4.0717623e-01],
          [ 2.6880848e-01,  2.9767163e-03, -1.5147904e-01,  7.3323259e-04,
            3.2263666e-02, -3.9489585e-01,  8.9576939e-04,  2.7901260e-02,
           -2.4660115e-04, -3.7397209e-01],
          [-6.3180

In [110]:
updates, optimizer_state = optimizer.update(grads, optimizer.state)
updates

AttributeError: 'tuple' object has no attribute 'update'

In [None]:
gradients["params"]["W2"]["kernel"] == params2["params"]["W2"]["kernel"]