In [1]:
import jax
import jax.numpy as jnp
import brax.training.agents.diffrl_shac.networks as shac_networks
from brax.training.acme import running_statistics, specs
from brax.envs.inverted_pendulum import InvertedPendulum
from jax import tree_util

In [2]:
# Instantiate the inverted pendulum on the "mjx" backend and record its
# observation and action dimensions for later cells.
env = InvertedPendulum(backend="mjx")
obs_size, action_size = env.observation_size, env.action_size

In [3]:
# Show both dimensions, one "<label> <value>" pair per line.
print(
    "\n".join(
        f"{label} {value}"
        for label, value in (
            ("Observation size", obs_size),
            ("Action size", action_size),
        )
    )
)

Observation size 4
Action size 1


In [4]:
# Size the networks from the environment's reported dimensions instead of
# repeating them as literals, so this cell stays correct if `env` is changed.
network = shac_networks.make_shac_networks(
    obs_size,
    action_size,
    policy_hidden_layer_sizes=(64, 64),
    value_hidden_layer_sizes=(64, 64),
)
# Factory that is later called with (normalizer_params, policy_params) to
# obtain the function mapping (observation, key) -> (action, extras).
make_inference_fn = shac_networks.make_inference_fn(network)

# One root key, split into three streams: policy initialisation, action
# inference, and environment reset.
prng = jax.random.PRNGKey(10)
key_policy, key_inference, key_env = jax.random.split(prng, 3)


  dummy_obs = jnp.zeros((1, obs_size), dtype=dtype)
  dummy_obs = jnp.zeros((1, obs_size), dtype=dtype)


In [5]:
def inference_grad_fn(policy_params):
    """Scalar objective used to profile a backward pass through the policy.

    Builds a fresh observation-normalizer state, creates the inference
    function from it and ``policy_params``, evaluates that function on a
    constant observation (every entry 3.0) with the notebook-level
    ``key_inference``, and returns the square of the summed action.

    Args:
        policy_params: Policy network parameters; this is the argument the
            gradient is taken with respect to.

    Returns:
        ``jnp.square(jnp.sum(action))`` for the action produced by the policy.
    """
    # The spec and the observation are both sized from `obs_size` so they
    # cannot drift apart from the environment the networks were built for.
    normalizer_params = running_statistics.init_state(
        specs.Array((obs_size,), jnp.dtype('float32'))
    )
    inference_fn = make_inference_fn((normalizer_params, policy_params))

    obs = jnp.asarray([3.0] * obs_size)
    # The second return value is not needed for the objective.
    action, _ = inference_fn(obs, key_inference)
    return jnp.square(jnp.sum(action))


# Gradient of the objective with respect to `policy_params`.
inference_grad = jax.grad(inference_grad_fn)

In [6]:
# Initialise the policy parameters and take one gradient of the objective.
policy_params = network.policy_network.init(key_policy)
grad = inference_grad(policy_params)
# Block until every array in the gradient pytree has been computed, so the
# memory snapshot below is taken after the work has actually run.
jax.block_until_ready(grad)

jax.profiler.save_device_memory_profile("memory.prof")

In [7]:
# Record a profiler trace of parameter initialisation plus one gradient
# evaluation; the trace is written under /tmp/tensorboard.
with jax.profiler.trace("/tmp/tensorboard"):
    policy_params = network.policy_network.init(key_policy)
    grad = inference_grad(policy_params)
    # Wait for the result inside the context so the computation is part of
    # the recorded trace.
    jax.block_until_ready(grad)

2025-03-23 19:38:31.014209: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1742758711.029774  130187 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1742758711.034712  130187 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1742758711.046791  130187 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1742758711.046803  130187 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1742758711.046805  130187 computation_placer.cc:177] computation placer alr

In [6]:
# Initial environment state that every gradient evaluation steps from.
state = env.reset(key_env)


def env_step_grad_fn(actions):
    """Reward obtained by taking one environment step from ``state``.

    Args:
        actions: Action array applied to the environment; this is the
            argument the gradient is taken with respect to.

    Returns:
        The ``reward`` field of the state returned by ``env.step``.
    """
    # `env.step` returns a new state rather than updating `state` in place.
    # Reading `state.reward` here would return a value that does not depend
    # on `actions`, making the gradient identically zero.
    next_state = env.step(state, actions)
    return next_state.reward


# Gradient of the one-step reward with respect to `actions`.
env_step_grad = jax.grad(env_step_grad_fn)

In [None]:
# Record a profiler trace of one gradient evaluation through the environment
# step; the trace is written under /tmp/tensorboard.
with jax.profiler.trace("/tmp/tensorboard"):
    actions = jnp.asarray([0.0])
    grad = env_step_grad(actions)
    # Wait for the result inside the context so the computation is part of
    # the recorded trace.
    jax.block_until_ready(grad)

In [None]:
# Evaluate the environment-step gradient once for a zero action.
actions = jnp.asarray([0.0])
grad = env_step_grad(actions)
# Block until the gradient has been computed before snapshotting memory.
jax.block_until_ready(grad)

# NOTE: this is the same path used by the policy-gradient memory profile
# earlier in the notebook, so that file is overwritten here.
jax.profiler.save_device_memory_profile("memory.prof")