diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index 416c72b117e0..5cfaff6678ea 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ -473,7 +473,6 @@ class ModuleContext: # Cached primitive lowerings. cached_primitive_lowerings: dict[Any, func_dialect.FuncOp] - cached_call_jaxpr_lowerings: dict[Any, func_dialect.FuncOp] lowering_parameters: LoweringParameters @@ -498,8 +497,6 @@ def __init__( symbol_table: ir.SymbolTable | None = None, cached_primitive_lowerings: None | (dict[Any, func_dialect.FuncOp]) = None, - cached_call_jaxpr_lowerings: None | (dict[Any, - func_dialect.FuncOp]) = None, shape_poly_state = None): self.context = context or make_ir_context() @@ -515,9 +512,6 @@ def __init__( self.channel_iterator = channel_iterator self.keepalives = keepalives self.host_callbacks = host_callbacks - self.cached_call_jaxpr_lowerings = ({} - if cached_call_jaxpr_lowerings is None - else cached_call_jaxpr_lowerings) self.shape_poly_state = ( shape_poly_state or ShapePolyLoweringState((), tuple(platforms))) self.lowering_parameters = lowering_parameters @@ -1646,12 +1640,12 @@ def _lower_jaxpr_to_fun_cached(ctx, fn_name, call_jaxpr, effects, # Cacheable. key = (fn_name, call_jaxpr.jaxpr, tuple(effects)) try: - func_op = ctx.cached_call_jaxpr_lowerings[key] + func_op = ctx.cached_primitive_lowerings[key] except KeyError: func_op = lower_jaxpr_to_fun( ctx, fn_name, call_jaxpr, effects, arg_names=arg_names, result_names=result_names) - ctx.cached_call_jaxpr_lowerings[key] = func_op + ctx.cached_primitive_lowerings[key] = func_op else: func_op = lower_jaxpr_to_fun( ctx, fn_name, call_jaxpr, effects, arg_names=arg_names,