Skip to content

Commit

Permalink
Fix lax imports
Browse files Browse the repository at this point in the history
  • Loading branch information
sharadmv committed Sep 27, 2022
1 parent 82636b0 commit 1d895b2
Showing 1 changed file with 7 additions and 5 deletions.
12 changes: 7 additions & 5 deletions jax/_src/ad_checkpoint.py
Expand Up @@ -36,6 +36,8 @@
from jax._src import source_info_util
from jax._src import traceback_util
from jax._src.api_util import flatten_fun, shaped_abstractify
from jax._src.lax import lax as lax_internal
from jax._src.lax import convolution as lax_convolution
from jax._src.lib.mlir.dialects import mhlo
from jax._src.traceback_util import api_boundary
from jax._src.util import (unzip2, wraps, split_list, partition_list, safe_map,
Expand All @@ -60,12 +62,12 @@ def nothing_saveable(*_, **__) -> bool:

def checkpoint_dots(prim, *_, **__) -> bool:
# Matrix multiplies are expensive, so let's save them (and nothing else).
return prim in {lax.lax.dot_general_p,
lax.convolution.conv_general_dilated_p}
return prim in {lax_internal.dot_general_p,
lax_convolution.conv_general_dilated_p}

def dot_with_no_batch_dims(prim, *_, **params) -> bool:
# This is a useful heuristic for transformers.
if prim is lax.lax.dot_general_p:
if prim is lax_internal.dot_general_p:
(_, _), (lhs_b, rhs_b) = params['dimension_numbers']
if not lhs_b and not rhs_b:
return True
Expand Down Expand Up @@ -439,8 +441,8 @@ def remat_jvp(primals, tangents, jaxpr, prevent_cse, differentiated, policy):
ad.primitive_jvps[remat_p] = remat_jvp

remat_allowed_effects: Set[core.Effect] = set()
remat_allowed_effects.add(lax.lax.InOutFeedEffect.Infeed)
remat_allowed_effects.add(lax.lax.InOutFeedEffect.Outfeed)
remat_allowed_effects.add(lax_internal.InOutFeedEffect.Infeed)
remat_allowed_effects.add(lax_internal.InOutFeedEffect.Outfeed)

def remat_partial_eval(trace, *tracers, jaxpr, **params):
assert not jaxpr.constvars
Expand Down

0 comments on commit 1d895b2

Please sign in to comment.