In [52]:
import jax
import jax.numpy as jnp
import jpviz
import brax.training.agents.diffrl_shac.networks as shac_networks
from brax.training.acme import running_statistics, specs
from brax.envs.inverted_pendulum import InvertedPendulum

In [79]:
env = InvertedPendulum(backend="mjx")
obs_size = env.observation_size
action_size = env.action_size

In [80]:
network = shac_networks.make_shac_networks(
    obs_size,
    action_size,
    policy_hidden_layer_sizes=(64, 64),
    value_hidden_layer_sizes=(64, 64),
)
make_inference_fn = shac_networks.make_inference_fn(network)

prng = jax.random.PRNGKey(10)

key_policy, key_inference, key_env = jax.random.split(prng, 3)
policy_params = network.policy_network.init(key_policy)
normalizer_params = running_statistics.init_state(
    specs.Array((obs_size,), jnp.dtype('float32'))
)
inference_fn = make_inference_fn((normalizer_params, policy_params))

In [81]:
obs = jnp.asarray([3.0, 3.0, 3.0, 3.0])
jax.make_jaxpr(inference_fn)(obs, key_inference)

let silu = { lambda ; a:f32[64]. let
    b:f32[64] = logistic a
    c:f32[64] = mul a b
  in (c,) } in
{ lambda d:f32[4,64] e:f32[64] f:f32[64,64] g:f32[64] h:f32[64,2] i:f32[2]; j:f32[4]
    k:u32[2]. let
    l:f32[64] = dot_general[dimension_numbers=(([0], [0]), ([], []))] j d
    m:f32[64] = add l e
    n:f32[64] = pjit[name=silu jaxpr=silu] m
    o:f32[64] = dot_general[dimension_numbers=(([0], [0]), ([], []))] n f
    p:f32[64] = add o g
    q:f32[64] = pjit[name=silu jaxpr=silu] p
    r:f32[2] = dot_general[dimension_numbers=(([0], [0]), ([], []))] q h
    s:f32[2] = add r i
    t:f32[1] u:f32[1] = split[axis=0 sizes=(np.int64(1), np.int64(1))] s
    v:f32[1] = pjit[
      name=softplus
      jaxpr={ lambda ; w:f32[1]. let
          x:f32[1] = custom_jvp_call[
            call_jaxpr={ lambda ; y:f32[1] z:f32[]. let
                ba:f32[1] = max y z
                bb:f32[1] = sub y z
                bc:bool[1] = ne bb bb
                bd:f32[1] = add y z
                be:f3

In [144]:
def sum_cost_utilization(cost_analysis):
    return sum(v for k, v in cost_analysis.items() if k.startswith("utilization"))

In [145]:
print(jax.jit(inference_fn).lower(obs, key_inference))

<jax._src.stages.Lowered object at 0x7d74c82e7ec0>


In [146]:
jax.jit(inference_fn).lower(obs, key_inference).compile().cost_analysis()

{'utilization6{}': 2.0,
 'transcendentals': 133.0,
 'utilization3{}': 4.0,
 'bytes accessed2{}': 12.0,
 'utilization0{}': 43.0,
 'utilization4{}': 3.0,
 'utilization7{}': 2.0,
 'flops': 9579.0,
 'bytes accessed3{}': 8.0,
 'bytes accessed0{}': 1224.0,
 'bytes accessed4{}': 4.0,
 'bytes accessed': 20541.0,
 'bytes accessedout{}': 1333.0,
 'bytes accessed1{}': 19504.0,
 'utilization2{}': 5.0,
 'utilization1{}': 18.0,
 'utilization5{}': 2.0}

In [147]:
jax.jit(inference_fn).lower(obs, key_inference).compile().cost_analysis()["flops"]

9579.0

In [148]:
jitted = jax.jit(inference_fn).lower(obs, key_inference).compile().cost_analysis()
sum_cost_utilization(jitted)

79.0

In [149]:
dot_graph = jpviz.draw(jax.jit(inference_fn))(obs, key_inference)
dot_graph.write_png("computation_graph.png")

In [150]:
state = env.reset(key_env)

actions = jnp.asarray([0.0])
jax.make_jaxpr(env.step)(state, actions)

let _take = { lambda ; a:f32[2,3] b:i32[1,0]. let
    _:i32[1,0] = pjit[
      name=remainder
      jaxpr={ lambda ; c:i32[1,0] d:i32[]. let
          e:bool[] = eq d 0
          f:i32[] = pjit[
            name=_where
            jaxpr={ lambda ; g:bool[] h:i32[] i:i32[]. let
                j:i32[] = select_n g i h
              in (j,) }
          ] e 1 d
          k:i32[1,0] = rem c f
          l:bool[1,0] = ne k 0
          m:bool[1,0] = lt k 0
          n:bool[] = lt f 0
          o:bool[1,0] = ne m n
          p:bool[1,0] = and o l
          q:i32[1,0] = add k f
          r:i32[1,0] = select_n p k q
        in (r,) }
    ] b 2
    s:f32[1,0,3] = broadcast_in_dim[
      broadcast_dimensions=()
      shape=(1, 0, 3)
      sharding=None
    ] 0.0
  in (s,) } in
let _take1 = { lambda ; t:f32[2] u:i32[1,0]. let
    _:i32[1,0] = pjit[
      name=remainder
      jaxpr={ lambda ; c:i32[1,0] d:i32[]. let
          e:bool[] = eq d 0
          f:i32[] = pjit[
            name=_where
      

In [151]:
jax.jit(env.step).lower(state, actions).compile().cost_analysis()["flops"]

5304.0

In [152]:
jitted = jax.jit(env.step).lower(state, actions).compile().cost_analysis()
sum_cost_utilization(jitted)

6137.479248046875