Skip to content

Commit

Permalink
Delete cached_call_jaxpr_lowerings since a more general cached_primit…
Browse files Browse the repository at this point in the history
…ive_lowerings is available

PiperOrigin-RevId: 577993595
  • Loading branch information
yashk2810 authored and jax authors committed Oct 30, 2023
1 parent 85af862 commit 20255dc
Showing 1 changed file with 2 additions and 8 deletions.
10 changes: 2 additions & 8 deletions jax/_src/interpreters/mlir.py
Expand Up @@ -473,7 +473,6 @@ class ModuleContext:

# Cached primitive lowerings.
cached_primitive_lowerings: dict[Any, func_dialect.FuncOp]
cached_call_jaxpr_lowerings: dict[Any, func_dialect.FuncOp]

lowering_parameters: LoweringParameters

Expand All @@ -498,8 +497,6 @@ def __init__(
symbol_table: ir.SymbolTable | None = None,
cached_primitive_lowerings: None | (dict[Any,
func_dialect.FuncOp]) = None,
cached_call_jaxpr_lowerings: None | (dict[Any,
func_dialect.FuncOp]) = None,
shape_poly_state = None):

self.context = context or make_ir_context()
Expand All @@ -515,9 +512,6 @@ def __init__(
self.channel_iterator = channel_iterator
self.keepalives = keepalives
self.host_callbacks = host_callbacks
self.cached_call_jaxpr_lowerings = ({}
if cached_call_jaxpr_lowerings is None
else cached_call_jaxpr_lowerings)
self.shape_poly_state = (
shape_poly_state or ShapePolyLoweringState((), tuple(platforms)))
self.lowering_parameters = lowering_parameters
Expand Down Expand Up @@ -1646,12 +1640,12 @@ def _lower_jaxpr_to_fun_cached(ctx, fn_name, call_jaxpr, effects,
# Cacheable.
key = (fn_name, call_jaxpr.jaxpr, tuple(effects))
try:
func_op = ctx.cached_call_jaxpr_lowerings[key]
func_op = ctx.cached_primitive_lowerings[key]
except KeyError:
func_op = lower_jaxpr_to_fun(
ctx, fn_name, call_jaxpr, effects, arg_names=arg_names,
result_names=result_names)
ctx.cached_call_jaxpr_lowerings[key] = func_op
ctx.cached_primitive_lowerings[key] = func_op
else:
func_op = lower_jaxpr_to_fun(
ctx, fn_name, call_jaxpr, effects, arg_names=arg_names,
Expand Down

0 comments on commit 20255dc

Please sign in to comment.