Skip to content

Commit

Permalink
Set up an API to stop the trace and get the FDO profile in memory.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 573276173
  • Loading branch information
wang12tao authored and jax authors committed Oct 13, 2023
1 parent e088a8e commit c568110
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 6 deletions.
37 changes: 31 additions & 6 deletions jax/_src/profiler.py
Expand Up @@ -30,6 +30,7 @@

from jax._src import xla_bridge
from jax._src.lib import xla_client
from jax._src.lib import xla_extension_version

_profiler_server: Optional[xla_client.profiler.ProfilerServer] = None

Expand Down Expand Up @@ -76,6 +77,13 @@ def __init__(self):
self.create_perfetto_trace = False
self.lock = threading.Lock()

def reset(self):
  """Clears the per-trace state held by this object.

  Drops the profile session handle, turns off both Perfetto flags and
  forgets the log directory. The ``lock`` attribute is left untouched.
  """
  # Write through `self` rather than a module-level instance so the method
  # acts on whichever object it is invoked on.
  self.profile_session = None
  self.create_perfetto_link = False
  self.create_perfetto_trace = False
  self.log_dir = None


_profile_state = _ProfileState()


Expand Down Expand Up @@ -193,15 +201,33 @@ def stop_trace():
with _profile_state.lock:
if _profile_state.profile_session is None:
raise RuntimeError("No profile started")
_profile_state.profile_session.stop_and_export(_profile_state.log_dir)
if xla_extension_version > 205:
sess = _profile_state.profile_session
sess.export(sess.stop(), _profile_state.log_dir)
else:
_profile_state.profile_session.stop_and_export(_profile_state.log_dir) # pytype: disable=attribute-error
if _profile_state.create_perfetto_trace:
abs_filename = _write_perfetto_trace_file(_profile_state.log_dir)
if _profile_state.create_perfetto_link:
_host_perfetto_trace_file(abs_filename)
_profile_state.profile_session = None
_profile_state.create_perfetto_link = False
_profile_state.create_perfetto_trace = False
_profile_state.log_dir = None
_profile_state.reset()


def stop_and_get_fdo_profile() -> bytes:
  """Stops the currently-running profiler trace and returns its FDO profile.

  The session is stopped and its result is passed to
  ``xla_client.profiler.get_fdo_profile``; no export to the log directory is
  performed here. Currently, this is only supported for GPU.

  Returns:
    The FDO profile as bytes.

  Raises:
    NotImplementedError: if ``xla_extension_version`` is older than 206.
    RuntimeError: if a trace hasn't been started.
  """
  if xla_extension_version < 206:
    raise NotImplementedError("stop and get fdo profile is not supported.")
  with _profile_state.lock:
    session = _profile_state.profile_session
    if session is None:
      raise RuntimeError("No profile started")
    try:
      return xla_client.profiler.get_fdo_profile(session.stop())
    finally:
      # Clear the state even when stopping or converting fails; otherwise the
      # dead session stays registered and no new trace can be started.
      _profile_state.reset()


@contextmanager
Expand Down Expand Up @@ -316,7 +342,6 @@ def wrapper(*args, **kwargs):
return wrapper



def device_memory_profile(backend: Optional[str] = None) -> bytes:
"""Captures a JAX device memory profile as ``pprof``-format protocol buffer.
Expand Down
14 changes: 14 additions & 0 deletions tests/profiler_test.py
Expand Up @@ -27,6 +27,7 @@
import jax.numpy as jnp
import jax.profiler
from jax import config
from jax._src.lib import xla_extension_version
import jax._src.test_util as jtu

try:
Expand Down Expand Up @@ -103,6 +104,19 @@ def testProgrammaticProfiling(self):
self.assertIn(b"/device:TPU", proto)
self.assertIn(b"pxla.py", proto)

def testProfilerGetFDOProfile(self):
  """Checks that stop_and_get_fdo_profile runs on an active trace."""
  if xla_extension_version < 206:
    # Report the test as skipped instead of silently passing.
    self.skipTest(
        "stop_and_get_fdo_profile requires xla_extension_version >= 206")
  # Tests stop_and_get_fdo_profile could run.
  # Start the trace outside the try block: if starting fails, the error
  # should surface directly rather than be replaced by a "No profile started"
  # error raised from the cleanup path.
  jax.profiler.start_trace("test")
  try:
    jax.pmap(lambda x: jax.lax.psum(x + 1, "i"), axis_name="i")(
        jnp.ones(jax.local_device_count())
    )
  finally:
    fdo_profile = jax._src.profiler.stop_and_get_fdo_profile()
  self.assertEqual(fdo_profile, b"")

def testProgrammaticProfilingErrors(self):
with self.assertRaisesRegex(RuntimeError, "No profile started"):
jax.profiler.stop_trace()
Expand Down

0 comments on commit c568110

Please sign in to comment.