
How to use named_call in jitted functions #7029

@woct0rdho


I'm trying to profile my code and use `named_call` to annotate some functions, but the names do not show up in the captured trace when the functions are called inside jitted functions.

Here is my test script (test_named_call.py):

```python
import jax
from jax import numpy as jnp

# Annotate foo with its own name for profiling
@jax.named_call
def foo(x, y):
    return (x + y) / 2.

@jax.jit
def bar(a):
    # scan body: x is the carry, y is the current element of `a`
    def foo2(x, y):
        return foo(x, y), None

    out, _ = jax.lax.scan(foo2, 0., a)
    return out

a = jnp.array([1., 2., 3., 4., 5.])

jax.profiler.start_trace('/tmp/tensorboard')
with jax.profiler.StepTraceAnnotation('step', step_num=0): # JIT warm-up
    out = bar(a)
with jax.profiler.StepTraceAnnotation('step', step_num=1):
    out = bar(a)
out.block_until_ready()  # wait for the asynchronous result before stopping the trace
jax.profiler.stop_trace()
```

My environment: Ubuntu 20.04, Python 3.8.5, CUDA 11.3, jax 0.2.14, jaxlib 0.1.67+cuda111, tensorflow 2.5.0, tbp-nightly 2.5.0a20210511

[Screenshot: overview of the captured trace]

Above is an overview of the captured trace. Step 1 is very short compared with step 0. I can find the name `foo` among the many functions in step 0, but not in step 1.
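Two standard JAX behaviors are likely relevant to what the two steps contain. Step 0 includes tracing and compiling `bar`, which is likely why Python-level names such as `foo` appear there, while step 1 only runs the cached executable. Also, JAX dispatches work asynchronously, so the script above only blocks after the step-1 annotation has closed. The sketch below repeats the `foo`/`bar` definitions, calls `block_until_ready()` inside each annotated step, and records wall-clock times in a hypothetical `step_times` list. It only illustrates what each step measures; it does not make `foo` appear in the compiled step.

```python
import time

import jax
from jax import numpy as jnp


@jax.named_call
def foo(x, y):
    return (x + y) / 2.


@jax.jit
def bar(a):
    def foo2(x, y):
        return foo(x, y), None

    out, _ = jax.lax.scan(foo2, 0., a)
    return out


a = jnp.array([1., 2., 3., 4., 5.])

step_times = []  # hypothetical helper list: seconds per step
for step in range(2):
    # StepTraceAnnotation only records events while a trace is active
    # (between jax.profiler.start_trace/stop_trace); otherwise it is
    # effectively a no-op, so this loop also runs standalone.
    with jax.profiler.StepTraceAnnotation('step', step_num=step):
        start = time.perf_counter()
        out = bar(a)             # step 0 also traces and compiles `bar`
        out.block_until_ready()  # wait for the device inside the step span
        step_times.append(time.perf_counter() - start)

for step, seconds in enumerate(step_times):
    print(f"step {step}: {seconds * 1e3:.3f} ms")
```

With the blocking call inside each annotation, each step span covers the device computation rather than just the dispatch.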

[Screenshot: zoomed-in view of step 1]

Above is a zoomed-in view of step 1. The operations only have generic names like `fusion` or `Memcpy`. Because there are 5 repeated operations, I can guess that this is the scan loop, but in general it's really hard to associate these operations with Python lines.
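One way to relate generic names like `fusion.3` back to Python, outside the trace viewer, is to inspect the optimized HLO. In recent JAX versions each HLO instruction carries `metadata={op_name="..."}` built from JAX's name stack, and `named_call` pushes its name onto that stack. The sketch below assumes a JAX version with the ahead-of-time API (`jit(...).lower(...).compile().as_text()`), which is newer than the 0.2.14 used above. The regex and the `op_name_by_instr` dictionary are illustrative helpers, not a profiler feature.

```python
import re

import jax
from jax import numpy as jnp


@jax.named_call
def foo(x, y):
    return (x + y) / 2.


@jax.jit
def bar(a):
    def foo2(x, y):
        return foo(x, y), None

    out, _ = jax.lax.scan(foo2, 0., a)
    return out


a = jnp.array([1., 2., 3., 4., 5.])

# Optimized (post-fusion) HLO for the current backend.
hlo_text = bar.lower(a).compile().as_text() or ""

# Match "<instruction name> = ... op_name="<name stack>"" on one HLO line.
# The optional "%" covers XLA versions that print instruction names with
# or without that prefix.
INSTR_RE = re.compile(r'^\s*(?:ROOT\s+)?%?([\w.\-]+)\s*=.*?op_name="([^"]*)"')

op_name_by_instr = {}
for line in hlo_text.splitlines():
    m = INSTR_RE.match(line)
    if m:
        op_name_by_instr[m.group(1)] = m.group(2)

# Instruction names such as "fusion.3" are what the trace shows; the
# op_name is the JAX name stack, which should include "foo".
for instr, op_name in sorted(op_name_by_instr.items()):
    print(f"{instr:30s} {op_name}")
```

Instruction names can repeat across HLO computations, so this simple dictionary keeps only the last match per name. That is acceptable for eyeballing, but not a complete mapping.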

Also, @jekbradbury mentioned that the bottom pane of the trace viewer should show some information like `source`. Is that available?
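Whether the trace viewer's details pane shows source information depends on the profiler version. Independently of the viewer, the compiled HLO metadata can carry `source_file` and `source_line` next to `op_name` when JAX records them. The sketch below makes the same assumptions as any AOT-based approach: it needs a JAX version with `jit(...).lower(...).compile().as_text()`, and it relies on the metadata field names as printed by XLA. The `entries` set is a hypothetical helper. Source fields are treated as optional because they may be absent.

```python
import re

import jax
from jax import numpy as jnp


@jax.named_call
def foo(x, y):
    return (x + y) / 2.


@jax.jit
def bar(a):
    def foo2(x, y):
        return foo(x, y), None

    out, _ = jax.lax.scan(foo2, 0., a)
    return out


a = jnp.array([1., 2., 3., 4., 5.])
hlo_text = bar.lower(a).compile().as_text() or ""

# Collect (op_name, source_file, source_line) from every metadata={...}
# block; source fields are None when XLA did not record them.
entries = set()
for block in re.findall(r'metadata=\{([^}]*)\}', hlo_text):
    op_name = re.search(r'op_name="([^"]*)"', block)
    source_file = re.search(r'source_file="([^"]*)"', block)
    source_line = re.search(r'source_line=(\d+)', block)
    if op_name is None:
        continue
    entries.add((
        op_name.group(1),
        source_file.group(1) if source_file else None,
        int(source_line.group(1)) if source_line else None,
    ))

# Sort with placeholders so None never gets compared against str/int.
for op, path, line in sorted(entries, key=lambda e: (e[0], e[1] or "", e[2] or 0)):
    print(op, f"{path}:{line}" if path else "(no source info)")
```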

Labels: bug (Something isn't working)
