Skip to content

Commit

Permalink
[JAX] Use up to date APIs to compute FLOP estimate.
Browse files Browse the repository at this point in the history
Change in preparation for dropping support for XlaComputation arguments to compile().

PiperOrigin-RevId: 498618680
  • Loading branch information
hawkinsp authored and romanngg committed Jan 3, 2023
1 parent 7f43576 commit 7b8cccc
Showing 1 changed file with 7 additions and 5 deletions.
12 changes: 7 additions & 5 deletions neural_tangents/_src/empirical.py
Original file line number Diff line number Diff line change
Expand Up @@ -2265,14 +2265,16 @@ def _get_fwd(


def _get_flops(f: Callable, optimize: bool, *a, **kw) -> float:
m = jax.xla_computation(f)(*a, **kw)
client = jax.lib.xla_bridge.get_backend()
if optimize:
m = client.compile(m).hlo_modules()[0]
e = jax.jit(f).lower(*a, **kw).compile()
return e.cost_analysis()[0]['flops']
else:
m = jax.xla_computation(f)(*a, **kw)
client = jax.lib.xla_bridge.get_backend()
m = m.as_hlo_module()
analysis = jax.lib.xla_client._xla.hlo_module_cost_analysis(client, m)
return analysis['flops']
analysis = jax.lib.xla_client._xla.hlo_module_cost_analysis(client, m)
return analysis['flops']



def _std_basis(pytree: PyTree) -> PyTree:
Expand Down

0 comments on commit 7b8cccc

Please sign in to comment.