
How to use named_call in jitted functions #7029

woct0rdho opened this issue Jun 19, 2021 · 12 comments

bug Something isn't working



I'm trying to profile my code and use named_call to annotate some functions, but the names will not show up in the captured trace if they're inside jitted functions.

Here is my test script:

import jax
from jax import numpy as jnp

def foo(x, y):
    return (x + y) / 2.

def bar(a):
    def foo2(x, y):
        return foo(x, y), None

    out, _ = jax.lax.scan(foo2, 0., a)
    return out

a = jnp.array([1., 2., 3., 4., 5.])

with jax.profiler.StepTraceAnnotation('step', step_num=0): # JIT warm-up
    out = bar(a)
with jax.profiler.StepTraceAnnotation('step', step_num=1):
    out = bar(a)
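(For reference: as pasted, the script above never actually wraps foo, so presumably the original applied jax.named_call to it. A minimal sketch of that annotation, under the assumption that this is what the report intends, would be:)

```python
import jax

def foo(x, y):
    return (x + y) / 2.

# Wrap foo so its name is attached to the traced ops; 'foo' is the label
# we expect to see in the profiler trace and in HLO op_name metadata.
foo_named = jax.named_call(foo, name='foo')

# named_call also works inside jit-traced code.
out = jax.jit(lambda x, y: foo_named(x, y))(1., 2.)  # -> 1.5
```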

My environment: Ubuntu 20.04, Python 3.8.5, CUDA 11.3, jax 0.2.14, jaxlib 0.1.67+cuda111, tensorflow 2.5.0, tbp-nightly 2.5.0a20210511


Above is an overview of the captured trace. Step 1 takes a very short time after step 0. I can find the name foo in that bunch of functions in step 0, but not in step 1.


Above is a zoomed-in view of step 1. The operations have only generic names like 'fusion' or 'Memcpy'. Because there are 5 repeated operations, I can guess this is the scan loop, but in general it is really hard to associate those operations with Python lines.

Also, @jekbradbury mentioned that in the bottom pane there should be some information like 'source'. Is it available?

@woct0rdho woct0rdho added the bug Something isn't working label Jun 19, 2021

I think the core issue here is the lack of the "TensorFlow" (really TF/JAX) row on the open source profiler trace viewer for GPU. Maybe @yisitu knows if this is a known issue?


FYI this was previously an issue (see internal bug: 183181931) but I think the jaxlib version you're using is new enough to include the appropriate fix so I'm confused.


enolan commented Aug 27, 2021

Any news on this? It's extremely difficult to debug performance when everything is named custom-call or fusion.


danielsnider commented Aug 27, 2021

Hi, I can confirm this is still an issue. I see the "foo" function in the trace for step 0 but not for step 1. I am expecting to see "foo" in step 0 and step 1.



This is using the same code as above in the original post.
My environment is: Ubuntu 20.04 Python 3.8.10, CUDA 11.2.152, jax 0.2.19, jaxlib 0.1.70+cuda111, tbp-nightly 2.5.0a20210511, tensorflow 2.5.0.


Any movement here?

I similarly have beautiful traces in Tensorboard with fusion_*, and I can read between the lines what they likely are.

Is there any way of linking these back to jaxprs, or ideally to the lax primitives (or, better still, the functions) that have been fused?
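One partial way to recover that mapping, at the Python level rather than the kernel level, is to print the jaxpr of the function before XLA fuses anything; a minimal sketch (with a made-up example function) might be:

```python
import jax
from jax import numpy as jnp

def get(values, indices):
    x = values[indices]   # lowers to a gather primitive
    y = values.sum()      # lowers to reduce_sum
    return x + y

values = jnp.arange(4.)
indices = jnp.array([0, 2])

# The jaxpr lists the lax primitives (gather, reduce_sum, add, ...)
# that XLA later fuses into kernels with names like 'fusion'.
jaxpr = jax.make_jaxpr(get)(values, indices)
print(jaxpr)
```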


+1 on this. Also missing the "TensorFlow" row in the trace viewer.
Python 3.9.7, CUDA 11.1.105, jax 0.2.25, jaxlib 0.1.73+cuda11.cudnn805, tbp-nightly 2.5.0a20210511, tensorflow 2.7.0


I can confirm this is still an issue when using: Python 3.8.10, CUDA 11.4 (driver 470.103.01), jax 0.2.28, jaxlib 0.1.76+cuda11.cudnn82, tbp-nightly 2.5.0a20210511, tensorflow 2.8.0

The named_call documentation describes exactly what I want, but it doesn't work. Without this link between generated GPU kernels and source code, performance debugging XLA is extremely difficult.


PS, I also tried all the annotation-related JAX APIs, including:

n_call_func = jax.named_call(func, name='my_func')

They all work as expected except named_call.
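For completeness, the alternatives that did show up for me look roughly like this (a sketch; the names here are placeholders, and TraceAnnotation only has an effect while a profiler trace is being captured):

```python
import jax
from jax import numpy as jnp

def foo(x, y):
    return (x + y) / 2.

# Host-side annotation: labels a span of wall-clock time on the trace.
with jax.profiler.TraceAnnotation('my_region'):
    out1 = foo(1., 2.)

# Compiler-side annotation: prefixes the names of ops staged out under it.
with jax.named_scope('my_scope'):
    out2 = jax.jit(foo)(1., 2.)
```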


oliverdutton commented Apr 19, 2022

I have no solution for named_call, but I can link those fusions to actual lines in your code.

You can lower the jitted function and view the compiler IR, which carries metadata for the individual ops. Here reduce.17 is the sum on line 7 (as expected) and the fusion is everything else. This means you can at least track where in the program the compiler is at any stage.

import jax
from jax import numpy as jnp

n = 1000 # array size
m = n # number of indices to take

values = jax.random.uniform(jax.random.PRNGKey(0), (n,))
indices = ((jax.random.truncated_normal(jax.random.PRNGKey(0), -1,1, (m,)) + 1)*n/2).astype(int)
inputs = (values, indices)

# Write get out to a real file so that XLA records source_file /
# source_line metadata for each op.
s = '''
import jax
from jax import numpy as jnp

def get(values, indices):
    x = values[indices]
    y = values.sum()
    return x+y
'''
with open('/tmp/my_func.py', 'w') as fh:
    fh.write(s)

import sys
sys.path.insert(0, '/tmp')
from my_func import get

# Compile and trace
f = jax.jit(get)
with jax.profiler.trace('/tmp/tensorboard/'):
    _ = jax.block_until_ready(f(*inputs))

# Print the optimized HLO with per-op metadata (AOT API; the method
# names vary somewhat across JAX versions):
print(f.lower(*inputs).compile().as_text())

The printed HLO (note the op_name and source_line metadata on each op):
HloModule jit_get.27

%region_0.13 (Arg_0.14: f32[], Arg_1.15: f32[]) -> f32[] {
  %Arg_0.14 = f32[] parameter(0)
  %Arg_1.15 = f32[] parameter(1)
  ROOT %add.16 = f32[] add(f32[] %Arg_0.14, f32[] %Arg_1.15), metadata={op_name="jit(get)/jit(main)/reduce_sum[axes=(0,)]" source_file="/tmp/" source_line=7}
}

%fused_computation (param_0.1: f32[], param_1.2: f32[1000], param_2.3: s32[1000]) -> f32[1000] {
  %param_1.2 = f32[1000]{0} parameter(1)
  %param_2.3 = s32[1000]{0} parameter(2)
  %constant_1 = s32[] constant(0)
  %broadcast.2 = s32[1000]{0} broadcast(s32[] %constant_1), dimensions={}, metadata={op_name="jit(get)/jit(main)/lt" source_file="/tmp/" source_line=6}
  %compare.0 = pred[1000]{0} compare(s32[1000]{0} %param_2.3, s32[1000]{0} %broadcast.2), direction=LT, metadata={op_name="jit(get)/jit(main)/lt" source_file="/tmp/" source_line=6}
  %constant_0 = s32[] constant(1000)
  %broadcast.1 = s32[1000]{0} broadcast(s32[] %constant_0), dimensions={}, metadata={op_name="jit(get)/jit(main)/add" source_file="/tmp/" source_line=6}
  %add.1 = s32[1000]{0} add(s32[1000]{0} %param_2.3, s32[1000]{0} %broadcast.1), metadata={op_name="jit(get)/jit(main)/add" source_file="/tmp/" source_line=6}
  %select.0 = s32[1000]{0} select(pred[1000]{0} %compare.0, s32[1000]{0} %add.1, s32[1000]{0} %param_2.3), metadata={op_name="jit(get)/jit(main)/select_n" source_file="/tmp/" source_line=6}
  %bitcast.1 = s32[1000,1]{1,0} bitcast(s32[1000]{0} %select.0), metadata={op_name="jit(get)/jit(main)/broadcast_in_dim[shape=(1000, 1) broadcast_dimensions=(0,)]" source_file="/tmp/" source_line=6}
  %gather.0 = f32[1000]{0} gather(f32[1000]{0} %param_1.2, s32[1000,1]{1,0} %bitcast.1), offset_dims={}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=1, slice_sizes={1}, metadata={op_name="jit(get)/jit(main)/gather[dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,)) slice_sizes=(1,) unique_indices=False indices_are_sorted=False mode=GatherScatterMode.PROMISE_IN_BOUNDS fill_value=None]" source_file="/tmp/" source_line=6}
  %param_0.1 = f32[] parameter(0)
  %broadcast.0 = f32[1000]{0} broadcast(f32[] %param_0.1), dimensions={}, metadata={op_name="jit(get)/jit(main)/add" source_file="/tmp/" source_line=8}
  ROOT %add.0 = f32[1000]{0} add(f32[1000]{0} %gather.0, f32[1000]{0} %broadcast.0), metadata={op_name="jit(get)/jit(main)/add" source_file="/tmp/" source_line=8}
}

ENTRY %main.20 (Arg_0.1: f32[1000], Arg_1.2: s32[1000]) -> f32[1000] {
  %Arg_0.1 = f32[1000]{0} parameter(0)
  %constant_3 = f32[] constant(0)
  %reduce.17 = f32[] reduce(f32[1000]{0} %Arg_0.1, f32[] %constant_3), dimensions={0}, to_apply=%region_0.13, metadata={op_name="jit(get)/jit(main)/reduce_sum[axes=(0,)]" source_file="/tmp/" source_line=7}
  %Arg_1.2 = s32[1000]{0} parameter(1)
  ROOT %fusion = f32[1000]{0} fusion(f32[] %reduce.17, f32[1000]{0} %Arg_0.1, s32[1000]{0} %Arg_1.2), kind=kLoop, calls=%fused_computation, metadata={op_name="jit(get)/jit(main)/add" source_file="/tmp/" source_line=8}
}



Has anyone tried whether jax.named_scope preserves annotations in jitted functions?


jax.named_scope uses the same mechanism as jax.named_call to set HLO metadata, so I suspect it will have similar problems. I think this is a profiler issue, because in the HLO we can see the correct information in the op_name field of the metadata.
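To illustrate the claim that the metadata itself is correct, one rough way to check (assuming a recent JAX with the AOT lowering API) is to lower a function under a named_scope and inspect the module text, where the scope name is carried in the op locations/metadata:

```python
import jax
from jax import numpy as jnp

@jax.jit
def f(x):
    with jax.named_scope('my_scope'):  # 'my_scope' is a placeholder name
        return jnp.sin(x) * 2.

# The lowered module text carries the naming information that the
# profiler UI fails to surface for jitted code.
text = f.lower(jnp.ones(3)).as_text()
print(text.splitlines()[0])
```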

@sharadmv sharadmv self-assigned this Oct 24, 2022

@sharadmv Is there any profiler that works well with JAX?


9 participants