From 1d895b2c85e17b9f563cd41d9a340528179d29aa Mon Sep 17 00:00:00 2001 From: Sharad Vikram Date: Mon, 26 Sep 2022 17:29:08 -0700 Subject: [PATCH] Fix lax imports --- jax/_src/ad_checkpoint.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/jax/_src/ad_checkpoint.py b/jax/_src/ad_checkpoint.py index 616503f073ef..d2e35184a096 100644 --- a/jax/_src/ad_checkpoint.py +++ b/jax/_src/ad_checkpoint.py @@ -36,6 +36,8 @@ from jax._src import source_info_util from jax._src import traceback_util from jax._src.api_util import flatten_fun, shaped_abstractify +from jax._src.lax import lax as lax_internal +from jax._src.lax import convolution as lax_convolution from jax._src.lib.mlir.dialects import mhlo from jax._src.traceback_util import api_boundary from jax._src.util import (unzip2, wraps, split_list, partition_list, safe_map, @@ -60,12 +62,12 @@ def nothing_saveable(*_, **__) -> bool: def checkpoint_dots(prim, *_, **__) -> bool: # Matrix multiplies are expensive, so let's save them (and nothing else). - return prim in {lax.lax.dot_general_p, - lax.convolution.conv_general_dilated_p} + return prim in {lax_internal.dot_general_p, + lax_convolution.conv_general_dilated_p} def dot_with_no_batch_dims(prim, *_, **params) -> bool: # This is a useful heuristic for transformers. - if prim is lax.lax.dot_general_p: + if prim is lax_internal.dot_general_p: (_, _), (lhs_b, rhs_b) = params['dimension_numbers'] if not lhs_b and not rhs_b: return True @@ -439,8 +441,8 @@ def remat_jvp(primals, tangents, jaxpr, prevent_cse, differentiated, policy): ad.primitive_jvps[remat_p] = remat_jvp remat_allowed_effects: Set[core.Effect] = set() -remat_allowed_effects.add(lax.lax.InOutFeedEffect.Infeed) -remat_allowed_effects.add(lax.lax.InOutFeedEffect.Outfeed) +remat_allowed_effects.add(lax_internal.InOutFeedEffect.Infeed) +remat_allowed_effects.add(lax_internal.InOutFeedEffect.Outfeed) def remat_partial_eval(trace, *tracers, jaxpr, **params): assert not jaxpr.constvars