Skip to content

Commit

Permalink
Don't leak the keepalive in debug_callback lowering
Browse files Browse the repository at this point in the history
  • Loading branch information
sharadmv committed May 4, 2022
1 parent a9c0a97 commit 78a4e30
Showing 1 changed file with 1 addition and 5 deletions.
6 changes: 1 addition & 5 deletions jax/_src/debugging.py
Expand Up @@ -69,10 +69,6 @@ def debug_callback_transpose_rule(*flat_args, callback: Callable[..., Any],
raise NotImplementedError('Transpose not supported for `debug_callback`.')
ad.primitive_transposes[debug_callback_p] = debug_callback_transpose_rule

# TODO(sharadmv): remove this global keepalive list in favor of attaching
# keepalives to the module context.
_keepalives = []

def _ordered_effect_lowering(ctx, token, *args, **params):
avals_in = [core.abstract_token, *ctx.avals_in]
avals_out = [core.abstract_token, *ctx.avals_out]
Expand All @@ -97,7 +93,7 @@ def _callback(*flat_args):
*flat_args, effect=effect, callback=callback, **params))
result, keepalive = mlir.emit_python_callback(ctx.module_context.platform,
_callback, list(args), ctx.avals_in, ctx.avals_out, True)
_keepalives.append(keepalive)
ctx.module_context.add_keepalive(keepalive)
return result
mlir.register_lowering(debug_callback_p, debug_callback_lowering,
platform="cpu")
Expand Down

0 comments on commit 78a4e30

Please sign in to comment.