In [1]:
import jax
import jax.numpy as jnp  # type: ignore
import numpy as np  # type: ignore

from jax import grad, jit, vmap, pmap  # type: ignore
import matplotlib.pyplot as plt

In [2]:
def value_fn(theta, state):
    """Linear value estimate: the dot product of the weights and the state."""
    return jnp.dot(theta, state)


# Weight vector the loss is differentiated with respect to.
theta = jnp.array([0.1, -0.1, 0.0])

# Example inputs fed to the loss: two state vectors and a scalar reward.
s_tm1 = jnp.array([1.0, 2.0, -1.0])
r_t = jnp.array(1.0)
s_t = jnp.array([2.0, 1.0, 0.0])

def loss_fn(theta, s_tm1, r_t, s_t):
    """Squared gap between a gradient-blocked target and value_fn(theta, s_tm1).

    The target is r_t + value_fn(theta, s_t). It is wrapped in
    jax.lax.stop_gradient, so differentiating with respect to theta only
    flows through the value_fn(theta, s_tm1) term.
    """
    target = jax.lax.stop_gradient(r_t + value_fn(theta, s_t))
    td_error = target - value_fn(theta, s_tm1)
    return td_error ** 2

# Gradient of the loss with respect to its first positional argument (theta).
grad_fn = jax.grad(loss_fn, argnums=0)
update = grad_fn(theta, s_tm1, r_t, s_t)

print(update)


2024-08-21 16:55:20.303821: W external/xla/xla/service/gpu/nvptx_compiler.cc:836] The NVIDIA driver's CUDA version is 12.1 which is older than the PTX compiler version (12.5.82). Because the driver is older than the PTX compiler version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.


[-2.4 -4.8  2.4]


In [3]:
def f(x):
    """Round x with jnp.round."""
    rounded = jnp.round(x)
    return rounded

def straight_through_f(x):
    """Evaluate to f(x) while letting gradients pass through as if it were x.

    The difference f(x) - x is wrapped in jax.lax.stop_gradient, so only the
    bare x term contributes to the derivative.
    """
    correction = jax.lax.stop_gradient(f(x) - x)
    return x + correction

x = 5.6

# Forward values of the plain and straight-through versions.
forward_plain = f(x)
forward_st = straight_through_f(x)
print("f(x): ", forward_plain)
print("Straight through f(x): ", forward_st)

# Derivatives of the two versions at the same point.
grad_plain = jax.grad(f)(x)
grad_st = jax.grad(straight_through_f)(x)
print("grad(f)(x): ", grad_plain)
print("grad(straight_through_f)(x): ", grad_st)

f(x):  6.0
Straight through f(x):  6.0
grad(f)(x):  0.0
grad(straight_through_f)(x):  1.0


In [5]:
# Build a batch of two identical examples by stacking each input with itself
# along a new leading axis.
batched_s_tm1, batched_r_t, batched_s_t = (
    jnp.stack([example, example]) for example in (s_tm1, r_t, s_t)
)

# in_axes=(None, 0, 0, 0): theta is shared across the batch, while the other
# three arguments are mapped over their leading axis, giving one gradient per
# example. The vmapped gradient is then jit-compiled.
per_example_grad = jax.vmap(jax.grad(loss_fn), in_axes=(None, 0, 0, 0))
example_grads = jax.jit(per_example_grad)
example_grads(theta, batched_s_tm1, batched_r_t, batched_s_t)

Array([[-2.4, -4.8,  2.4],
       [-2.4, -4.8,  2.4]], dtype=float32)

In [None]:
def loss_fn(params, data):
    """Mean squared error of a linear model, used as the inner-loop loss.

    NOTE: this rebinds the name `loss_fn` that earlier cells use for the
    four-argument TD loss; re-run that cell before re-running cells that
    depend on the earlier definition.

    Args:
        params: weight vector of shape (n_features,).
        data: tuple ``(inputs, targets)`` where ``inputs`` has shape
            (n_examples, n_features) and ``targets`` has shape (n_examples,).

    Returns:
        Scalar array: the mean of the squared residuals. jax.grad requires a
        scalar output, which the previous `pass` placeholder did not provide.
    """
    inputs, targets = data
    residuals = inputs @ params - targets
    return jnp.mean(residuals ** 2)


# Step size of the single inner gradient-descent update.
lr = 0.1


def meta_loss_fn(params, data):
    """Loss after taking one gradient-descent step of size `lr` on `loss_fn`.

    Differentiating this function with respect to `params` differentiates
    through the inner jax.grad call as well.
    """
    grads = jax.grad(loss_fn)(params, data)
    return loss_fn(params - lr * grads, data)


# Small concrete inputs so the cell runs on a fresh kernel.
params = jnp.array([0.0, 0.0])
data = (jnp.eye(2), jnp.array([1.0, 2.0]))

meta_grads = jax.grad(meta_loss_fn)(params, data)