In [1]:
import jax
import jax.numpy as jnp
from jax import random, vmap
from flax import linen as nn
import optax

# Neural network

In [45]:
batch_size = 16
input_dim = 700
x = random.normal(key=random.PRNGKey(42), shape=(batch_size, input_dim))
y = random.normal(key=random.PRNGKey(42), shape=(batch_size, 1))

In [46]:
print(x.shape, y.shape)

(16, 700) (16, 1)


In [51]:
class JaxNet(nn.Module):
    def setup(self):
        #self.W1 = nn.Dense(features=754)
        self.W2 = nn.Dense(features=700)
        self.W3 = nn.Dense(features=754)
        self.W4 = nn.Dense(features=1)
    
    def __call__(self, x):
        #x = self.W1(x)
        #x = nn.relu(x)
        x = self.W2(x)
        x = self.W3(x)
        x = self.W4(x)
        return x
    
def mse_loss(params, model, x, y):
    pred = model.apply(params, x)
    return (pred - y) ** 2


In [52]:
model = JaxNet()
params = model.init(random.PRNGKey(42), x[0])

optimizer = optax.adam(learning_rate=3e-4)
opt_state = optimizer.init(params)

In [6]:
a = jnp.array([[1.,2.], [3., 4.]])
b = jax.nn.one_hot(1, a.size).reshape(a.shape)

def f(W):
    return jnp.sum(W)

jax.jvp(f, (a, ), (b, ))

(Array(10., dtype=float32), Array(1., dtype=float32))

# Using chunked jvp

In [55]:
gradients = jax.tree_map(lambda x: jnp.sum(x, axis=(0)), out)
gradients

{'params': {'W2': {'bias': Array([-2.61358678e-01,  2.31760576e-01, -8.12946141e-01, -1.07932937e+00,
           3.70122015e-01, -7.87013829e-01, -3.35031003e-01, -2.08256912e+00,
           4.37163502e-01, -7.42606521e-01, -7.92251706e-01,  4.02134508e-02,
           2.02033475e-01,  4.97104824e-02, -6.68506622e-02, -2.43432205e-02,
          -4.21549380e-01, -4.77232277e-01,  4.00792122e-01, -9.39639747e-01,
           9.02613401e-01,  8.28535557e-01,  4.28926274e-02,  1.20348394e+00,
          -7.95918927e-02, -2.25953400e-01,  6.31774068e-01, -1.92610487e-01,
           4.02806252e-01,  2.29138756e+00, -1.21205854e+00, -9.06038642e-01,
          -8.98797274e-01, -6.64001167e-01,  1.59861743e-01,  4.33588058e-01,
          -6.24723196e-01,  4.80631024e-01, -2.45005399e-01,  2.48027239e-02,
          -9.15537417e-01,  6.14193857e-01,  1.75818372e+00, -9.22377646e-01,
          -4.58616942e-01, -6.81993723e-01,  1.59832680e+00, -6.45181477e-01,
          -1.19969928e+00,  1.07032299e+

In [None]:
def jacobian_column(idx, flatparams):
    basis_vector = jax.nn.one_hot(idx, flatparams.size)
    primal, tangent = jax.jvp(lambda par: mse_loss(unravel_fn(par), model, evaluation_point, evaluation_target), (flatparams,), (basis_vector,))
    return tangent

flat_params, unravel_fn = jax.flatten_util.ravel_pytree(params)
size = flat_params.size
new_params = jax.vmap(lambda x: jacobian_column(x, flat_params), in_axes=0)(jnp.arange(size))
unravel_fn(new_params)

# Using jvp vmap

In [None]:
jax.nn.one_hot(2, 79000)

In [None]:
evaluation_point = jnp.array([1.0, 0.5, 1.5, 2.0, 1.0])
evaluation_target = jnp.array([10.0])

model.apply(params, evaluation_point)

In [None]:
def jacobian_column(idx, flatparams):
    basis_vector = jax.nn.one_hot(idx, flatparams.size)
    primal, tangent = jax.jvp(lambda par: mse_loss(unravel_fn(par), model, evaluation_point, evaluation_target), (flatparams,), (basis_vector,))
    return tangent

flat_params, unravel_fn = jax.flatten_util.ravel_pytree(params)
size = flat_params.size
new_params = jax.vmap(lambda x: jacobian_column(x, flat_params), in_axes=0)(jnp.arange(size))
unravel_fn(new_params)

In [None]:
def jacobian_column(idx, flatparams, xi, yi):
    basis_vector = jax.nn.one_hot(idx, flatparams.size)
    primal, tangent = jax.jvp(lambda par: mse_loss(unravel_fn(par), model, xi, yi), (flatparams,), (basis_vector,))
    return tangent

flat_params, unravel_fn = jax.flatten_util.ravel_pytree(params)
size = flat_params.size

def sample_jacobian(xi, yi):
    new_params = jax.vmap(lambda idx: jacobian_column(idx, flat_params, xi, yi), in_axes=0)(jnp.arange(size))
    return unravel_fn(new_params)

#out = jax.vmap(jax.jacfwd(mse_loss, argnums=0), in_axes=(None, None, 0, 0), out_axes=0)(params, model, x, y)
out = jax.vmap(lambda xi, yi: sample_jacobian(xi, yi), in_axes=(0,0), out_axes=0)(x, y)
out

In [None]:
out["params"]["W2"]["kernel"].shape

In [None]:
gradients = jax.tree_map(lambda x: jnp.sum(x, axis=(0)), out)
gradients

In [None]:
# def jacobian_column(param, idx):
#     basis_vector = jax.nn.one_hot(idx, param.size).reshape(param.shape)
#     primal, tangent = jax.jvp(lambda par: mse_loss(params, model, evaluation_point, evaluation_target), (param,), (basis_vector,))
#     return tangent
# 
# def jacobian_param(param):
#     size = param.size
#     return jax.vmap(jacobian_column, in_axes=(None, 0))(param, jnp.arange(size))
# 
# new_params = jax.tree_map(jacobian_param, params)
# new_params

# jvp vmap

In [None]:
flat_params, unravel_fn = jax.flatten_util.ravel_pytree(params)
jacobian_matrix = jnp.zeros((len(flat_params), 1))

In [None]:
evaluation_point = jnp.array([1.0, 0.5, 1.5, 2.0, 1.0])
evaluation_target = jnp.array([10])
basis_vector = jnp.zeros((len(flat_params)))
basis_vector = basis_vector.at[0].set(1.0)

primal, tangent = jax.vmap( jax.jvp(lambda par: mse_loss(unravel_fn(par), model, evaluation_point, evaluation_target), (flat_params,), (basis_vector,)) )
print(primal, tangent)
jacobian_matrix = jacobian_matrix.at[0].set(tangent)

In [None]:
evaluation_point = jnp.array([1.0, 0.5, 1.5, 2.0, 1.0])
evaluation_target = jnp.array([10])
basis_vector = jnp.zeros((len(flat_params)))
basis_vector = basis_vector.at[0].set(1.0)

primal, tangent = jax.jvp(lambda par: mse_loss(unravel_fn(par), model, evaluation_point, evaluation_target), (flat_params,), (basis_vector,))
print(primal, tangent)
jacobian_matrix = jacobian_matrix.at[0].set(tangent)

In [None]:
flat_params.shape

In [None]:
unravel_fn(jacobian_matrix)

In [None]:
def loop_jvp(flat_params, model, x_point, y_point):
    basis_vector = jnp.zeros((len(flat_params)))
    jacobian_matrix = jnp.zeros((len(flat_params), 1))
    for i in range(len(flat_params)):
        e_i = basis_vector.at[i].set(1.0)
        primal, tangent = jax.jvp(lambda par: mse_loss(unravel_fn(par), model, x_point, y_point), (flat_params,), (e_i,))
        jacobian_matrix = jacobian_matrix.at[i].set(tangent)
    return jacobian_matrix

In [None]:
jacobian_sum = jnp.zeros((len(flat_params), 1))
l = []
for i in range(batch_size):
    jacobian = loop_jvp(flat_params, model, x[i], y[i])
    jacobian_sum += jacobian
    l.append(jacobian)
    #jacobians = vmap(loop_jvp, in_axes=(None, None, 0, 0), out_axes=0)(flat_params, model, x, y)
# jacobian_matrix /= len(flat_params)

In [None]:
#jnp.sum(jnp.array(l), 0)

In [None]:
unravel_fn(jacobian_sum)

In [None]:
def scan_batch(carry, xy):
    xi, yi = xy
    jacobian = loop_jvp(flat_params, model, xi, yi)
    carry = carry + jacobian
    return carry, None

jacobian_init = jnp.zeros((len(flat_params), 1))
#xs = jnp.concatenate((x, y), axis=1)
xs = (x, y)
jacobian_matrix, _ = jax.lax.scan(scan_batch, jacobian_init, xs)

In [None]:
jacobian_matrix

In [None]:
unravel_fn(jacobian_matrix)

In [None]:
def scan_jvp_old(flat_params, model, x_point, y_point):
    basis_vector = jnp.zeros((len(flat_params)))
    jacobian_matrix = jnp.zeros((len(flat_params), 1))
    for i in range(len(flat_params)):
        e_i = basis_vector.at[i].set(1.0)
        primal, tangent = jax.jvp(lambda par: mse_loss(unravel_fn(par), model, x_point, y_point), (flat_params,), (e_i,))
        jacobian_matrix = jacobian_matrix.at[i].set(tangent)
    return jacobian_matrix

def scan_jvp(carry, xs):
    jacobian_sum, x_point, y_point  = carry
    param_idx = xs
    basis_vector = jnp.zeros((len(flat_params)))
    e_i = basis_vector.at[param_idx].set(1.0)
    primal, tangent = jax.jvp(lambda par: mse_loss(unravel_fn(par), model, x_point, y_point), (flat_params,), (e_i,))
    # jacobian_sum += tangent * e_i[:, None]
    jacobian_sum = jacobian_sum.at[param_idx].set(tangent)
    return (jacobian_sum, x_point, y_point), None

def scan_foo(carry, xs):
    jacobian_sum, param_idxes = carry
    x_point, y_point = xs
    (jacobian, x_point, y_point), _ = jax.lax.scan(scan_jvp, (jacobian_sum, x_point, y_point), param_idxes)
    jacobian_sum += jacobian
    return (jacobian_sum, param_index), None

param_index = jnp.arange(flat_params.shape[0], dtype=int)
jacobian_init = jnp.zeros((len(flat_params), 1))

(jacobian_matrix, _), _ = jax.lax.scan(scan_foo, (jacobian_init, param_index), (x, y))

In [None]:
unravel_fn(jacobian_matrix)

In [None]:
jacobian_sum = jnp.zeros((len(flat_params), 1))
for i in range(batch_size):
    param_index = jnp.arange(flat_params.shape[0], dtype=int)
    jacobian_init = jnp.zeros((len(flat_params), 1))
    (jacobian, x_point, y_point), _ = jax.lax.scan(scan_jvp, (jacobian_init, x[i], y[i]), param_index)
    jacobian_sum += jacobian

In [None]:
unravel_fn(jacobian_sum)

In [None]:
jacobian_matrix[param_index[0]]

In [None]:
unravel_fn(jacobian_matrix)

In [None]:
unravel_fn(jacobian_matrix)

# Using jacfwd

In [None]:
loss_jacobian = jax.jacfwd(mse_loss, argnums=0)(params, model, evaluation_point, evaluation_target)
loss_jacobian

In [None]:
loss_jacobian["params"]["W1"]["kernel"].shape

In [None]:
gradients = jax.tree_map(lambda x: jnp.mean(x, axis=(0,1)), loss_jacobian)
gradients

In [None]:
gradients["params"]["W1"]["kernel"].shape

In [None]:
loss_jacobian = vmap(jax.jacfwd(mse_loss, argnums=0), in_axes=(None, None, 0, 0), out_axes=0)(params, model, x, y)
loss_jacobian

In [None]:
loss_jacobian["params"]["W1"]["kernel"].shape #.mean((0,1,2))

In [None]:
gradients = jax.tree_map(lambda x: jnp.sum(x, axis=(0,1)), loss_jacobian)
gradients

In [None]:
updates, opt_state = optimizer.update(loss_jacobian, opt_state)

In [None]:
params = optax.apply_updates(params, updates)
params

In [None]:
gradients

In [None]:
model2 = JaxNet()
params2 = model.init(random.PRNGKey(42), x[0])

optimizer_def = optax.adam(learning_rate=1e-3)
optimizer = optimizer_def.init(params2)

def loss_fn(params2, x, y):
    preds = model.apply(params2, x)
    loss = jnp.mean((preds - y)**2)
    return loss

grad_fn = jax.jit(jax.grad(loss_fn))

grads = grad_fn(params2, x, y)

In [None]:
grads

In [None]:
updates, optimizer_state = optimizer.update(grads, optimizer.state)
updates

In [None]:
gradients["params"]["W2"]["kernel"] == params2["params"]["W2"]["kernel"]