How to use named_call in jitted functions #7029

Open
woct0rdho opened this issue Jun 19, 2021 · 12 comments
@woct0rdho

I'm trying to profile my code and use named_call to annotate some functions, but the names do not show up in the captured trace when the annotated functions are called inside jitted functions.

Here is my test script (test_named_call.py):

import jax
from jax import numpy as jnp

@jax.named_call
def foo(x, y):
    return (x + y) / 2.

@jax.jit
def bar(a):
    def foo2(x, y):
        return foo(x, y), None

    out, _ = jax.lax.scan(foo2, 0., a)
    return out

a = jnp.array([1., 2., 3., 4., 5.])

jax.profiler.start_trace('/tmp/tensorboard')
with jax.profiler.StepTraceAnnotation('step', step_num=0): # JIT warm-up
    out = bar(a)
with jax.profiler.StepTraceAnnotation('step', step_num=1):
    out = bar(a)
out.block_until_ready()
jax.profiler.stop_trace()

My environments: Ubuntu 20.04, Python 3.8.5, CUDA 11.3, jax 0.2.14, jaxlib 0.1.67+cuda111, tensorflow 2.5.0, tbp-nightly 2.5.0a20210511

[screenshot: overview of the captured trace]

The screenshot above shows an overview of the captured trace. Step 1 takes a very short time after step 0. I can find the name foo among the functions in step 0, but not in step 1.

[screenshot: zoomed-in view of step 1]

The screenshot above shows a zoomed-in view of step 1. The operations have only generic names like 'fusion' or 'Memcpy'. Because there are 5 repeated operations, I can guess that this is the scan loop, but in general it is very hard to associate those operations with Python lines.

Also, @jekbradbury mentioned that the bottom pane should show some information such as 'source'. Is that available?

@woct0rdho added the bug label on Jun 19, 2021
@jekbradbury
Contributor

I think the core issue here is the lack of the "TensorFlow" (really TF/JAX) row on the open source profiler trace viewer for GPU. Maybe @yisitu knows if this is a known issue?

@tomhennigan
Member

FYI, this was previously an issue (see internal bug 183181931), but I think the jaxlib version you're using is new enough to include the appropriate fix, so I'm confused.

@enolan

enolan commented Aug 27, 2021

Any news on this? It's extremely difficult to debug performance when everything is named custom-call or fusion.

@danielsnider

danielsnider commented Aug 27, 2021

Hi, I can confirm this is still an issue. I see the "foo" function in the trace for step 0 but not for step 1; I expect to see "foo" in both step 0 and step 1.

[screenshots of the captured traces for step 0 and step 1]

This uses the same code as in the original post above.
My environment: Ubuntu 20.04, Python 3.8.10, CUDA 11.2.152, jax 0.2.19, jaxlib 0.1.70+cuda111, tbp-nightly 2.5.0a20210511, tensorflow 2.5.0.

@oliverdutton
Contributor

Any movement here?

I similarly have beautiful traces in TensorBoard full of fusion_* entries, and I can read between the lines to guess what they likely are.

Is there any way of linking these back to jaxprs, or ideally to which lax primitives, or better still which functions, have been fused?

@fabiannagel

+1 on this. Also missing the "TensorFlow" row in the trace viewer.
Python 3.9.7, CUDA 11.1.105, jax 0.2.25, jaxlib 0.1.73+cuda11.cudnn805, tbp-nightly 2.5.0a20210511, tensorflow 2.7.0

@danielsnider

I can confirm this is still an issue when using: Python 3.8.10, CUDA 11.4 (driver 470.103.01), jax 0.2.28, jaxlib 0.1.76+cuda11.cudnn82, tbp-nightly 2.5.0a20210511, tensorflow 2.8.0

The named_call documentation describes exactly what I want, but it doesn't work. Without this link between the generated GPU kernels and the source code, debugging XLA performance is extremely difficult.
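
For reference, this is the pattern from the documentation I mean (a minimal sketch; the function and the name 'scale_op' are just placeholders):

import jax
from jax import numpy as jnp

def scale(x):
    return x * 2.

# named_call wraps the function so that, per the docs, its ops should show up
# under the given name in the profiler trace; as reported in this issue, the
# name does not appear when the call happens inside a jitted function.
named_scale = jax.named_call(scale, name='scale_op')

@jax.jit
def run(x):
    return named_scale(x)

out = run(jnp.ones((4,)))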

@danielsnider

PS: I also tried all of the annotation-related JAX APIs, which include:

jax.profiler.annotate_function
jax.profiler.TraceAnnotation
jax.profiler.StepTraceAnnotation
n_call_func = jax.named_call(func, name='my_func')
@jax.named_call

They all work as expected except named_call.
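
For reference, a minimal sketch of how these host-side annotation APIs are typically combined (the function, shapes, and trace directory here are just illustrative):

import jax
from jax import numpy as jnp

# annotate_function adds a host-side trace annotation around each call.
@jax.profiler.annotate_function
def annotated_matmul(x):
    return x @ x.T

x = jnp.ones((128, 128))

with jax.profiler.trace('/tmp/tensorboard'):
    with jax.profiler.StepTraceAnnotation('step', step_num=0):
        # TraceAnnotation adds a named range around whatever runs inside it.
        with jax.profiler.TraceAnnotation('matmul_range'):
            annotated_matmul(x).block_until_ready()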

@oliverdutton
Contributor

oliverdutton commented Apr 19, 2022

I have no solution for named_call, but I can link those fusions to actual lines in your code.

You can lower the jitted function and view the compiler IR, which carries metadata for the actual ops. In the example below, reduce.17 is the sum on line 7 (as expected) and the fusion is everything else. This means you can at least track which part of the program the compiler is working on at any stage.

import jax
from jax import numpy as jnp

n = 1000 # array size
m = n # number of indices to take

values = jax.random.uniform(jax.random.PRNGKey(0), (n,))
indices = ((jax.random.truncated_normal(jax.random.PRNGKey(0), -1,1, (m,)) + 1)*n/2).astype(int)
inputs = (values, indices)

# Write the function out to a file so that the HLO metadata records a
# source_file and source_line for each op.
s = '''
import jax
from jax import numpy as jnp

def get(values, indices):
    x = values[indices]
    y = values.sum()
    return x+y
'''
open('/tmp/my_func.py','w').write(s)
import sys
sys.path.insert(0, '/tmp')
from my_func import get

# Compile and trace
print(jax.jit(get).lower(*inputs).compile().compiler_ir()[0].to_string())
 
f = jax.jit(get)
f(*inputs)
with jax.profiler.trace('/tmp/tensorboard/'):
    _ = jax.block_until_ready(f(*inputs))
The printed HLO:

HloModule jit_get.27

%region_0.13 (Arg_0.14: f32[], Arg_1.15: f32[]) -> f32[] {
  %Arg_0.14 = f32[] parameter(0)
  %Arg_1.15 = f32[] parameter(1)
  ROOT %add.16 = f32[] add(f32[] %Arg_0.14, f32[] %Arg_1.15), metadata={op_name="jit(get)/jit(main)/reduce_sum[axes=(0,)]" source_file="/tmp/my_func.py" source_line=7}
}

%fused_computation (param_0.1: f32[], param_1.2: f32[1000], param_2.3: s32[1000]) -> f32[1000] {
  %param_1.2 = f32[1000]{0} parameter(1)
  %param_2.3 = s32[1000]{0} parameter(2)
  %constant_1 = s32[] constant(0)
  %broadcast.2 = s32[1000]{0} broadcast(s32[] %constant_1), dimensions={}, metadata={op_name="jit(get)/jit(main)/lt" source_file="/tmp/my_func.py" source_line=6}
  %compare.0 = pred[1000]{0} compare(s32[1000]{0} %param_2.3, s32[1000]{0} %broadcast.2), direction=LT, metadata={op_name="jit(get)/jit(main)/lt" source_file="/tmp/my_func.py" source_line=6}
  %constant_0 = s32[] constant(1000)
  %broadcast.1 = s32[1000]{0} broadcast(s32[] %constant_0), dimensions={}, metadata={op_name="jit(get)/jit(main)/add" source_file="/tmp/my_func.py" source_line=6}
  %add.1 = s32[1000]{0} add(s32[1000]{0} %param_2.3, s32[1000]{0} %broadcast.1), metadata={op_name="jit(get)/jit(main)/add" source_file="/tmp/my_func.py" source_line=6}
  %select.0 = s32[1000]{0} select(pred[1000]{0} %compare.0, s32[1000]{0} %add.1, s32[1000]{0} %param_2.3), metadata={op_name="jit(get)/jit(main)/select_n" source_file="/tmp/my_func.py" source_line=6}
  %bitcast.1 = s32[1000,1]{1,0} bitcast(s32[1000]{0} %select.0), metadata={op_name="jit(get)/jit(main)/broadcast_in_dim[shape=(1000, 1) broadcast_dimensions=(0,)]" source_file="/tmp/my_func.py" source_line=6}
  %gather.0 = f32[1000]{0} gather(f32[1000]{0} %param_1.2, s32[1000,1]{1,0} %bitcast.1), offset_dims={}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=1, slice_sizes={1}, metadata={op_name="jit(get)/jit(main)/gather[dimension_numbers=GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,)) slice_sizes=(1,) unique_indices=False indices_are_sorted=False mode=GatherScatterMode.PROMISE_IN_BOUNDS fill_value=None]" source_file="/tmp/my_func.py" source_line=6}
  %param_0.1 = f32[] parameter(0)
  %broadcast.0 = f32[1000]{0} broadcast(f32[] %param_0.1), dimensions={}, metadata={op_name="jit(get)/jit(main)/add" source_file="/tmp/my_func.py" source_line=8}
  ROOT %add.0 = f32[1000]{0} add(f32[1000]{0} %gather.0, f32[1000]{0} %broadcast.0), metadata={op_name="jit(get)/jit(main)/add" source_file="/tmp/my_func.py" source_line=8}
}

ENTRY %main.20 (Arg_0.1: f32[1000], Arg_1.2: s32[1000]) -> f32[1000] {
  %Arg_0.1 = f32[1000]{0} parameter(0)
  %constant_3 = f32[] constant(0)
  %reduce.17 = f32[] reduce(f32[1000]{0} %Arg_0.1, f32[] %constant_3), dimensions={0}, to_apply=%region_0.13, metadata={op_name="jit(get)/jit(main)/reduce_sum[axes=(0,)]" source_file="/tmp/my_func.py" source_line=7}
  %Arg_1.2 = s32[1000]{0} parameter(1)
  ROOT %fusion = f32[1000]{0} fusion(f32[] %reduce.17, f32[1000]{0} %Arg_0.1, s32[1000]{0} %Arg_1.2), kind=kLoop, calls=%fused_computation, metadata={op_name="jit(get)/jit(main)/add" source_file="/tmp/my_func.py" source_line=8}
}


@danielsnider

Has anyone tried whether jax.named_scope preserves annotations in jitted functions?
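
(For anyone who wants to try, a minimal sketch of what I mean; the names are placeholders:)

import jax
from jax import numpy as jnp

@jax.jit
def bar(a):
    # named_scope attaches the given name to the ops traced inside the block;
    # it ends up in the op_name field of the HLO metadata.
    with jax.named_scope('foo_scope'):
        return (a + 1.) / 2.

out = bar(jnp.arange(5.))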

@sharadmv
Member

jax.named_scope uses the same mechanism to set HLO metadata as jax.named_call, so I suspect it will have similar problems. I think this is a profiler issue, because in the HLO we can see the correct information in the op_name field of the metadata.
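
For example, one way to confirm the metadata is there, following the lowering approach shown earlier in this thread (the function and input are placeholders):

import jax
from jax import numpy as jnp

@jax.jit
def bar(a):
    with jax.named_scope('foo_scope'):
        return (a + 1.) / 2.

# Print the compiled HLO; the op_name fields in the metadata should contain
# 'foo_scope' even though the profiler trace viewer does not display it.
print(bar.lower(jnp.arange(5.)).compile().compiler_ir()[0].to_string())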

@sharadmv self-assigned this Oct 24, 2022
@cagrikymk

@sharadmv Is there any profiler that works well with JAX?
