In [1]:
import jax
import jax.numpy as jnp
from jax import random, vmap
from flax import linen as nn
import optax

# Neural network

In [2]:
batch_size = 8
input_dim = 5
x = random.normal(key=random.PRNGKey(42), shape=(batch_size, input_dim))
y = random.normal(key=random.PRNGKey(42), shape=(batch_size, 1))

2023-09-15 20:32:36.919205: W external/xla/xla/service/platform_util.cc:198] unable to create StreamExecutor for CUDA:0: failed initializing StreamExecutor for CUDA device ordinal 0: INTERNAL: failed call to cuDevicePrimaryCtxRetain: CUDA_ERROR_OUT_OF_MEMORY: out of memory; total memory reported: 25385107456
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


In [3]:
print(x.shape, y.shape)

(8, 5) (8, 1)


In [4]:
class JaxNet(nn.Module):
    def setup(self):
        self.W1 = nn.Dense(features=10)
        self.W2 = nn.Dense(features=1)
    
    def __call__(self, x):
        x = self.W1(x)
        x = nn.relu(x)
        x = self.W2(x)
        return x
    
def mse_loss(params, model, x, y):
    pred = model.apply(params, x)
    return (pred - y) ** 2


In [5]:
model = JaxNet()
params = model.init(random.PRNGKey(42), x[0])

optimizer = optax.adam(learning_rate=3e-4)
opt_state = optimizer.init(params)

# Using jvp and loops

In [6]:
flat_params, unravel_fn = jax.flatten_util.ravel_pytree(params)
jacobian_matrix = jnp.zeros((len(flat_params), 1))

In [7]:
evaluation_point = jnp.array([1.0, 0.5, 1.5, 2.0, 1.0])
evaluation_target = jnp.array([10])
basis_vector = jnp.zeros((len(flat_params)))
basis_vector = basis_vector.at[0].set(1.0)

primal, tangent = jax.jvp(lambda par: mse_loss(unravel_fn(par), model, evaluation_point, evaluation_target), (flat_params,), (basis_vector,))
print(primal, tangent)
jacobian_matrix = jacobian_matrix.at[0].set(tangent)

[114.70784] [-8.955847]


In [8]:
flat_params.shape

(71,)

In [9]:
unravel_fn(jacobian_matrix)

{'params': {'W1': {'bias': Array([-8.955847,  0.      ,  0.      ,  0.      ,  0.      ,  0.      ,
           0.      ,  0.      ,  0.      ,  0.      ], dtype=float32),
   'kernel': Array([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]], dtype=float32)},
  'W2': {'bias': Array([0.], dtype=float32),
   'kernel': Array([[0.],
          [0.],
          [0.],
          [0.],
          [0.],
          [0.],
          [0.],
          [0.],
          [0.],
          [0.]], dtype=float32)}}}

In [10]:
def loop_jvp(flat_params, model, x_point, y_point):
    basis_vector = jnp.zeros((len(flat_params)))
    jacobian_matrix = jnp.zeros((len(flat_params), 1))
    for i in range(len(flat_params)):
        e_i = basis_vector.at[i].set(1.0)
        primal, tangent = jax.jvp(lambda par: mse_loss(unravel_fn(par), model, x_point, y_point), (flat_params,), (e_i,))
        jacobian_matrix = jacobian_matrix.at[i].set(tangent)
    return jacobian_matrix

In [11]:
jacobian_sum = jnp.zeros((len(flat_params), 1))
l = []
for i in range(batch_size):
    jacobian = loop_jvp(flat_params, model, x[i], y[i])
    jacobian_sum += jacobian
    l.append(jacobian)
    #jacobians = vmap(loop_jvp, in_axes=(None, None, 0, 0), out_axes=0)(flat_params, model, x, y)
# jacobian_matrix /= len(flat_params)

In [12]:
#jnp.sum(jnp.array(l), 0)

Array([[ 2.62686777e+00],
       [-5.24947122e-02],
       [-1.51924896e+00],
       [ 5.27301896e-03],
       [ 1.26865625e-01],
       [-3.10725188e+00],
       [ 3.30611289e-01],
       [ 1.31503940e-02],
       [-9.10157487e-02],
       [-3.65455389e+00],
       [ 2.60141182e+00],
       [ 7.21702864e-03],
       [-1.21777129e+00],
       [ 2.87887221e-03],
       [-2.91166365e-01],
       [-3.10612392e+00],
       [ 4.27514523e-01],
       [ 6.18888974e-01],
       [-1.17692761e-01],
       [-3.61913967e+00],
       [-2.63472939e+00],
       [-3.64475511e-02],
       [ 1.10690355e+00],
       [-1.68364972e-03],
       [ 1.05206363e-01],
       [ 2.11527300e+00],
       [ 3.41741666e-02],
       [-9.02291238e-02],
       [-9.40798968e-03],
       [ 3.66549134e+00],
       [-2.34140325e+00],
       [-3.72017436e-02],
       [ 1.09190488e+00],
       [-1.56823487e-03],
       [-4.84190583e-01],
       [ 1.88298571e+00],
       [ 6.43614411e-01],
       [ 4.68558490e-01],
       [-1.7

In [13]:
unravel_fn(jacobian_sum)

{'params': {'W1': {'bias': Array([ 2.6268678 , -0.05249471, -1.519249  ,  0.00527302,  0.12686563,
          -3.107252  ,  0.3306113 ,  0.0131504 , -0.09101575, -3.654554  ],      dtype=float32),
   'kernel': Array([[ 2.60141182e+00,  7.21702864e-03, -1.21777129e+00,
            2.87887221e-03, -2.91166365e-01, -3.10612392e+00,
            4.27514523e-01,  6.18888974e-01, -1.17692761e-01,
           -3.61913967e+00],
          [-2.63472939e+00, -3.64475511e-02,  1.10690355e+00,
           -1.68364972e-03,  1.05206355e-01,  2.11527300e+00,
            3.41741666e-02, -9.02291313e-02, -9.40798968e-03,
            3.66549134e+00],
          [-2.34140325e+00, -3.72017436e-02,  1.09190488e+00,
           -1.56823487e-03, -4.84190583e-01,  1.88298559e+00,
            6.43614411e-01,  4.68558490e-01, -1.77184075e-01,
            3.25740981e+00],
          [ 2.15046763e+00,  2.38137618e-02, -1.21183181e+00,
            5.86587982e-03,  2.58109301e-01, -3.15916657e+00,
            7.16617471e-0

In [38]:
def scan_batch(carry, xy):
    xi, yi = xy
    jacobian = loop_jvp(flat_params, model, xi, yi)
    carry = carry + jacobian
    return carry, None

jacobian_init = jnp.zeros((len(flat_params), 1))
#xs = jnp.concatenate((x, y), axis=1)
xs = (x, y)
jacobian_matrix, _ = jax.lax.scan(scan_batch, jacobian_init, xs)

In [39]:
jacobian_matrix

Array([[ 2.62686777e+00],
       [-5.24947122e-02],
       [-1.51924896e+00],
       [ 5.27301896e-03],
       [ 1.26865625e-01],
       [-3.10725188e+00],
       [ 3.30611289e-01],
       [ 1.31504014e-02],
       [-9.10157487e-02],
       [-3.65455389e+00],
       [ 2.60141182e+00],
       [ 7.21702864e-03],
       [-1.21777129e+00],
       [ 2.87887221e-03],
       [-2.91166365e-01],
       [-3.10612392e+00],
       [ 4.27514523e-01],
       [ 6.18888974e-01],
       [-1.17692761e-01],
       [-3.61913967e+00],
       [-2.63472939e+00],
       [-3.64475511e-02],
       [ 1.10690355e+00],
       [-1.68364972e-03],
       [ 1.05206355e-01],
       [ 2.11527300e+00],
       [ 3.41741666e-02],
       [-9.02291313e-02],
       [-9.40798968e-03],
       [ 3.66549134e+00],
       [-2.34140325e+00],
       [-3.72017436e-02],
       [ 1.09190488e+00],
       [-1.56823487e-03],
       [-4.84190583e-01],
       [ 1.88298559e+00],
       [ 6.43614411e-01],
       [ 4.68558490e-01],
       [-1.7

In [40]:
unravel_fn(jacobian_matrix)

{'params': {'W1': {'bias': Array([ 2.6268678 , -0.05249471, -1.519249  ,  0.00527302,  0.12686563,
          -3.107252  ,  0.3306113 ,  0.0131504 , -0.09101575, -3.654554  ],      dtype=float32),
   'kernel': Array([[ 2.60141182e+00,  7.21702864e-03, -1.21777129e+00,
            2.87887221e-03, -2.91166365e-01, -3.10612392e+00,
            4.27514523e-01,  6.18888974e-01, -1.17692761e-01,
           -3.61913967e+00],
          [-2.63472939e+00, -3.64475511e-02,  1.10690355e+00,
           -1.68364972e-03,  1.05206355e-01,  2.11527300e+00,
            3.41741666e-02, -9.02291313e-02, -9.40798968e-03,
            3.66549134e+00],
          [-2.34140325e+00, -3.72017436e-02,  1.09190488e+00,
           -1.56823487e-03, -4.84190583e-01,  1.88298559e+00,
            6.43614411e-01,  4.68558490e-01, -1.77184075e-01,
            3.25740981e+00],
          [ 2.15046763e+00,  2.38137618e-02, -1.21183181e+00,
            5.86587982e-03,  2.58109301e-01, -3.15916657e+00,
            7.16617471e-0

In [61]:
def scan_jvp_old(flat_params, model, x_point, y_point):
    basis_vector = jnp.zeros((len(flat_params)))
    jacobian_matrix = jnp.zeros((len(flat_params), 1))
    for i in range(len(flat_params)):
        e_i = basis_vector.at[i].set(1.0)
        primal, tangent = jax.jvp(lambda par: mse_loss(unravel_fn(par), model, x_point, y_point), (flat_params,), (e_i,))
        jacobian_matrix = jacobian_matrix.at[i].set(tangent)
    return jacobian_matrix

def scan_jvp(carry, xs):
    jacobian_sum, x_point, y_point  = carry
    param_idx = xs
    basis_vector = jnp.zeros((len(flat_params)))
    e_i = basis_vector.at[param_idx].set(1.0)
    primal, tangent = jax.jvp(lambda par: mse_loss(unravel_fn(par), model, x_point, y_point), (flat_params,), (e_i,))
    # jacobian_sum += tangent * e_i[:, None]
    jacobian_sum = jacobian_sum.at[param_idx].set(tangent)
    return (jacobian_sum, x_point, y_point), None

def scan_foo(carry, xs):
    jacobian_sum, param_idxes = carry
    x_point, y_point = xs
    (jacobian, x_point, y_point), _ = jax.lax.scan(scan_jvp, (jacobian_sum, x_point, y_point), param_idxes)
    jacobian_sum += jacobian
    return (jacobian_sum, param_index), None

param_index = jnp.arange(flat_params.shape[0], dtype=int)
jacobian_init = jnp.zeros((len(flat_params), 1))

(jacobian_matrix, _), _ = jax.lax.scan(scan_foo, (jacobian_init, param_index), (x, y))

In [66]:
unravel_fn(jacobian_matrix)

{'params': {'W1': {'bias': Array([ 2.6268678 , -0.05249471, -1.519249  ,  0.00527302,  0.12686563,
          -3.107252  ,  0.3306113 ,  0.0131504 , -0.09101575, -3.654554  ],      dtype=float32),
   'kernel': Array([[ 2.60141182e+00,  7.21702864e-03, -1.21777129e+00,
            2.87887221e-03, -2.91166365e-01, -3.10612392e+00,
            4.27514523e-01,  6.18888974e-01, -1.17692761e-01,
           -3.61913967e+00],
          [-2.63472939e+00, -3.64475511e-02,  1.10690355e+00,
           -1.68364972e-03,  1.05206355e-01,  2.11527300e+00,
            3.41741666e-02, -9.02291313e-02, -9.40798968e-03,
            3.66549134e+00],
          [-2.34140325e+00, -3.72017436e-02,  1.09190488e+00,
           -1.56823487e-03, -4.84190583e-01,  1.88298559e+00,
            6.43614411e-01,  4.68558490e-01, -1.77184075e-01,
            3.25740981e+00],
          [ 2.15046763e+00,  2.38137618e-02, -1.21183181e+00,
            5.86587982e-03,  2.58109301e-01, -3.15916657e+00,
            7.16617471e-0

In [63]:
jacobian_sum = jnp.zeros((len(flat_params), 1))
for i in range(batch_size):
    param_index = jnp.arange(flat_params.shape[0], dtype=int)
    jacobian_init = jnp.zeros((len(flat_params), 1))
    (jacobian, x_point, y_point), _ = jax.lax.scan(scan_jvp, (jacobian_init, x[i], y[i]), param_index)
    jacobian_sum += jacobian

In [65]:
unravel_fn(jacobian_sum)

{'params': {'W1': {'bias': Array([ 2.6268678 , -0.05249471, -1.519249  ,  0.00527302,  0.12686563,
          -3.107252  ,  0.3306113 ,  0.0131504 , -0.09101575, -3.654554  ],      dtype=float32),
   'kernel': Array([[ 2.60141182e+00,  7.21702864e-03, -1.21777129e+00,
            2.87887221e-03, -2.91166365e-01, -3.10612392e+00,
            4.27514523e-01,  6.18888974e-01, -1.17692761e-01,
           -3.61913967e+00],
          [-2.63472939e+00, -3.64475511e-02,  1.10690355e+00,
           -1.68364972e-03,  1.05206355e-01,  2.11527300e+00,
            3.41741666e-02, -9.02291313e-02, -9.40798968e-03,
            3.66549134e+00],
          [-2.34140325e+00, -3.72017436e-02,  1.09190488e+00,
           -1.56823487e-03, -4.84190583e-01,  1.88298559e+00,
            6.43614411e-01,  4.68558490e-01, -1.77184075e-01,
            3.25740981e+00],
          [ 2.15046763e+00,  2.38137618e-02, -1.21183181e+00,
            5.86587982e-03,  2.58109301e-01, -3.15916657e+00,
            7.16617471e-0

In [22]:
jacobian_matrix[param_index[0]]

Array([0.], dtype=float32)

In [23]:
unravel_fn(jacobian_matrix)

{'params': {'W1': {'bias': Array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32),
   'kernel': Array([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]], dtype=float32)},
  'W2': {'bias': Array([0.], dtype=float32),
   'kernel': Array([[0.],
          [0.],
          [0.],
          [0.],
          [0.],
          [0.],
          [0.],
          [0.],
          [0.],
          [0.]], dtype=float32)}}}

In [24]:
unravel_fn(jacobian_matrix)

{'params': {'W1': {'bias': Array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32),
   'kernel': Array([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]], dtype=float32)},
  'W2': {'bias': Array([0.], dtype=float32),
   'kernel': Array([[0.],
          [0.],
          [0.],
          [0.],
          [0.],
          [0.],
          [0.],
          [0.],
          [0.],
          [0.]], dtype=float32)}}}

# Using jacfwd

In [25]:
loss_jacobian = jax.jacfwd(mse_loss, argnums=0)(params, model, evaluation_point, evaluation_target)
loss_jacobian

{'params': {'W1': {'bias': Array([[-8.955847  , -0.        , -0.        , -0.02139482,  4.9518595 ,
            7.8320136 , -6.778864  , -3.852275  , -0.        , 12.459564  ]],      dtype=float32),
   'kernel': Array([[[-8.9558468e+00, -0.0000000e+00, -0.0000000e+00, -2.1394815e-02,
             4.9518595e+00,  7.8320136e+00, -6.7788639e+00, -3.8522749e+00,
            -0.0000000e+00,  1.2459564e+01],
           [-4.4779234e+00, -0.0000000e+00, -0.0000000e+00, -1.0697408e-02,
             2.4759297e+00,  3.9160068e+00, -3.3894320e+00, -1.9261374e+00,
            -0.0000000e+00,  6.2297821e+00],
           [-1.3433770e+01, -0.0000000e+00, -0.0000000e+00, -3.2092221e-02,
             7.4277892e+00,  1.1748021e+01, -1.0168296e+01, -5.7784123e+00,
            -0.0000000e+00,  1.8689344e+01],
           [-1.7911694e+01, -0.0000000e+00, -0.0000000e+00, -4.2789631e-02,
             9.9037189e+00,  1.5664027e+01, -1.3557728e+01, -7.7045498e+00,
            -0.0000000e+00,  2.4919128e+01],
   

In [26]:
loss_jacobian["params"]["W1"]["kernel"].shape

(1, 5, 10)

In [27]:
gradients = jax.tree_map(lambda x: jnp.mean(x, axis=(0,1)), loss_jacobian)
gradients

{'params': {'W1': {'bias': Array(0.56350565, dtype=float32),
   'kernel': Array([-10.747017  ,   0.        ,   0.        ,  -0.02567378,
            5.942231  ,   9.3984165 ,  -8.134636  ,  -4.6227303 ,
            0.        ,  14.951477  ], dtype=float32)},
  'W2': {'bias': Array(-21.42035, dtype=float32),
   'kernel': Array([-8.777163], dtype=float32)}}}

In [28]:
gradients["params"]["W1"]["kernel"].shape

(10,)

In [29]:
loss_jacobian = vmap(jax.jacfwd(mse_loss, argnums=0), in_axes=(None, None, 0, 0), out_axes=0)(params, model, x, y)
loss_jacobian

{'params': {'W1': {'bias': Array([[[-0.66623163, -0.        , -0.        , -0.        ,
             0.36837226, -0.        , -0.        , -0.28657338,
            -0.        ,  0.9268756 ]],
   
          [[-0.        , -0.05249475, -0.        , -0.        ,
             0.02165244, -0.        , -0.02964118, -0.01684441,
             0.00816008, -0.        ]],
   
          [[ 0.        ,  0.        ,  0.        ,  0.        ,
            -0.26315907,  0.        ,  0.36025247,  0.20472331,
            -0.09917583,  0.        ]],
   
          [[ 0.        ,  0.        ,  0.        ,  0.        ,
             0.        , -0.46606568,  0.        ,  0.2292403 ,
             0.        ,  0.        ]],
   
          [[-0.        , -0.        , -0.        , -0.        ,
            -0.        ,  0.47976726, -0.        , -0.23597959,
            -0.        , -0.        ]],
   
          [[ 0.        ,  0.        ,  0.        ,  0.        ,
             0.        , -0.24109204,  0.        ,  

In [30]:
loss_jacobian["params"]["W2"]["kernel"].mean((0,1,2))

Array([0.1828678], dtype=float32)

In [31]:
gradients = jax.tree_map(lambda x: jnp.sum(x, axis=(0, 1,2)), loss_jacobian)
gradients

{'params': {'W1': {'bias': Array(-5.321797, dtype=float32),
   'kernel': Array([-0.7296927 , -0.05353004, -0.1413371 ,  0.00319662, -0.3916064 ,
          -1.7406904 ,  1.3203006 ,  1.0290506 , -0.36347264,  1.0151641 ],      dtype=float32)},
  'W2': {'bias': Array(7.949466, dtype=float32),
   'kernel': Array([14.629423], dtype=float32)}}}

In [32]:
updates, opt_state = optimizer.update(loss_jacobian, opt_state)

In [33]:
params = optax.apply_updates(params, updates)
params

{'params': {'W1': {'bias': Array([[[ 0.0003,  0.    ,  0.    ,  0.    , -0.0003,  0.    ,  0.    ,
             0.0003,  0.    , -0.0003]],
   
          [[ 0.    ,  0.0003,  0.    ,  0.    , -0.0003,  0.    ,  0.0003,
             0.0003, -0.0003,  0.    ]],
   
          [[ 0.    ,  0.    ,  0.    ,  0.    ,  0.0003,  0.    , -0.0003,
            -0.0003,  0.0003,  0.    ]],
   
          [[ 0.    ,  0.    ,  0.    ,  0.    ,  0.    ,  0.0003,  0.    ,
            -0.0003,  0.    ,  0.    ]],
   
          [[ 0.    ,  0.    ,  0.    ,  0.    ,  0.    , -0.0003,  0.    ,
             0.0003,  0.    ,  0.    ]],
   
          [[ 0.    ,  0.    ,  0.    ,  0.    ,  0.    ,  0.0003,  0.    ,
            -0.0003,  0.    ,  0.    ]],
   
          [[-0.0003,  0.    ,  0.0003,  0.    ,  0.    ,  0.0003,  0.    ,
             0.    ,  0.    ,  0.0003]],
   
          [[-0.0003,  0.    ,  0.0003, -0.0003,  0.    ,  0.0003,  0.    ,
             0.    ,  0.    ,  0.0003]]], dtype=float32),
   

In [34]:
gradients

{'params': {'W1': {'bias': Array(-5.321797, dtype=float32),
   'kernel': Array([-0.7296927 , -0.05353004, -0.1413371 ,  0.00319662, -0.3916064 ,
          -1.7406904 ,  1.3203006 ,  1.0290506 , -0.36347264,  1.0151641 ],      dtype=float32)},
  'W2': {'bias': Array(7.949466, dtype=float32),
   'kernel': Array([14.629423], dtype=float32)}}}

In [35]:
model2 = JaxNet()
params2 = model.init(random.PRNGKey(42), x[0])

optimizer_def = optax.adam(learning_rate=1e-3)
optimizer = optimizer_def.init(params2)

def loss_fn(params2, x, y):
    preds = model.apply(params2, x)
    loss = jnp.mean((preds - y)**2)
    return loss

grad_fn = jax.jit(jax.grad(loss_fn))

grads = grad_fn(params2, x, y)

In [36]:
grads

{'params': {'W1': {'bias': Array([ 0.32835847, -0.00656184, -0.18990612,  0.00065913,  0.0158582 ,
          -0.3884065 ,  0.04132641,  0.0016438 , -0.01137697, -0.45681924],      dtype=float32),
   'kernel': Array([[ 3.25176507e-01,  9.02128522e-04, -1.52221411e-01,
            3.59859027e-04, -3.63957956e-02, -3.88265491e-01,
            5.34393154e-02,  7.73611218e-02, -1.47115951e-02,
           -4.52392459e-01],
          [-3.29341173e-01, -4.55594389e-03,  1.38362929e-01,
           -2.10456201e-04,  1.31507935e-02,  2.64409125e-01,
            4.27177083e-03, -1.12786386e-02, -1.17599871e-03,
            4.58186418e-01],
          [-2.92675406e-01, -4.65021748e-03,  1.36488110e-01,
           -1.96029345e-04, -6.05238229e-02,  2.35373229e-01,
            8.04518089e-02,  5.85698113e-02, -2.21480075e-02,
            4.07176226e-01],
          [ 2.68808454e-01,  2.97672022e-03, -1.51478976e-01,
            7.33235036e-04,  3.22636627e-02, -3.94895792e-01,
            8.95771838e-0

In [37]:
updates, optimizer_state = optimizer.update(grads, optimizer.state)
updates

AttributeError: 'tuple' object has no attribute 'update'

In [None]:
gradients["params"]["W2"]["kernel"] == params2["params"]["W2"]["kernel"]