Skip to content

Commit

Permalink
delete old remat implementation
Browse files Browse the repository at this point in the history
moved lowering rule logic from remat_impl.py (now deleted) to ad_checkpoint.py
  • Loading branch information
mattjj committed Aug 17, 2022
1 parent 332d7d0 commit d19e34f
Show file tree
Hide file tree
Showing 14 changed files with 131 additions and 424 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Expand Up @@ -10,6 +10,10 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.

## jax 0.3.17 (Unreleased)
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.3.16...main).
* Breaking changes
  * {func}`jax.checkpoint`, also known as {func}`jax.remat`, no longer supports
    the `concrete` option, following the previous version's deprecation; see
    [JEP 11830](https://jax.readthedocs.io/en/latest/jep/11830-new-remat-checkpoint.html).

## jax 0.3.16
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.3.15...main).
Expand Down
8 changes: 5 additions & 3 deletions docs/jep/11830-new-remat-checkpoint.md
Expand Up @@ -19,16 +19,18 @@ As of [#11830](https://github.com/google/jax/pull/11830) we're switching on a ne

## How can I disable the change, and go back to the old behavior for now?

In case you have a problem with this change, it will **temporarily** be possible to switch off the new implementation by setting the `jax_new_checkpoint` config option to be False, in any one of these ways:


In case you have a problem with this change, **through version `jax==0.3.16`** it is possible to switch off the new implementation by setting the `jax_new_checkpoint` config option to `False`, in any one of these ways:

1. set the shell environment variable `JAX_NEW_CHECKPOINT=0`;
2. execute `jax.config.update('jax_new_checkpoint', False)`;
3. if you parse flags with `absl`, pass the `--jax_new_checkpoint=False` option.

If you need to revert to the old implementation, **please reach out** on a GitHub issue so that we can make the new implementation work for you.

As of `jax==0.3.17` the `jax_new_checkpoint` config option is no longer
available. If you have an issue, please reach out on [the issue
tracker](https://github.com/google/jax/issues) so we can help fix it!


## Why are we doing this?

Expand Down
107 changes: 103 additions & 4 deletions jax/_src/ad_checkpoint.py
Expand Up @@ -18,29 +18,30 @@
import types

from absl import logging
import numpy as np

import jax
from jax import core
from jax import linear_util as lu
from jax.interpreters import ad
from jax.interpreters import batching
from jax.interpreters import partial_eval as pe
from jax.interpreters import mlir
from jax.interpreters import partial_eval as pe
from jax.interpreters import xla
from jax.tree_util import tree_flatten, tree_unflatten
from jax._src import ad_util
from jax._src import util
from jax._src import source_info_util
from jax._src import traceback_util
from jax._src.api_util import flatten_fun, shaped_abstractify
from jax._src.lib.mlir.dialects import mhlo
from jax._src.traceback_util import api_boundary
from jax._src.util import (unzip2, wraps, split_list, partition_list, safe_map,
safe_zip, merge_lists, weakref_lru_cache)

source_info_util.register_exclusion(__file__)
traceback_util.register_exclusion(__file__)

# TODO(mattjj): before this can be the standard remat implementation, we must:
# [ ] fix up callers who use the 'concrete' option (now removed)

map = safe_map
zip = safe_zip

Expand Down Expand Up @@ -582,6 +583,104 @@ def remat_dce(used_outputs: List[bool], eqn: core.JaxprEqn
pe.dce_rules[remat_p] = remat_dce


def remat_lowering(*args, jaxpr: core.Jaxpr, prevent_cse: bool,
                   differentiated: bool, is_gpu_platform: bool = False,
                   **_):
  """Evaluate ``jaxpr`` on ``args`` through one of several translation rules.

  When both ``differentiated`` and ``prevent_cse`` are truthy, the rule is
  chosen in this order: the optimization-barrier rule if the
  ``jax_remat_opt_barrier`` config flag is set, else the while-loop rule if
  ``is_gpu_platform`` is set, else the cond rule. Otherwise ``jaxpr`` is
  evaluated directly. Any other keyword parameters are accepted and ignored.

  The selected rule is wrapped with ``jax.named_call`` under the name
  ``"remat"`` before being applied.
  """
  assert not jaxpr.constvars

  def _evaluate_directly(*args, jaxpr):
    return core.eval_jaxpr(jaxpr, (), *args)

  needs_special_rule = differentiated and prevent_cse
  if not needs_special_rule:
    rule = _evaluate_directly
  elif jax.config.jax_remat_opt_barrier:
    rule = _remat_translation_using_opt_barrier
  elif is_gpu_platform:
    rule = _remat_translation_using_while
  else:
    rule = _remat_translation_using_cond

  return jax.named_call(rule, name="remat")(*args, jaxpr=jaxpr)

def _remat_translation_using_opt_barrier(*args, jaxpr: core.Jaxpr):
  """Pass ``args`` through ``_optimization_barrier``, then evaluate ``jaxpr``."""
  barriered_args = _optimization_barrier(args)
  return core.eval_jaxpr(jaxpr, (), *barriered_args)

# TODO(mattjj): add core utility for 'create dummy value for this type'?
def _dummy_like(aval: core.AbstractValue) -> Any:
  """Build a placeholder value for the given abstract value.

  Returns a newly created token when ``aval`` is ``core.abstract_token``. For
  a ``core.ShapedArray`` or ``core.DShapedArray``, returns
  ``jax.lax.empty(aval.dtype)`` broadcast to ``aval.shape``.

  Raises:
    ValueError: if ``aval`` is neither of the above; the aval itself is passed
      as the exception argument.
  """
  if aval is core.abstract_token:
    return jax.lax.create_token()
  if not isinstance(aval, (core.ShapedArray, core.DShapedArray)):
    raise ValueError(aval)
  placeholder = jax.lax.empty(aval.dtype)  # type: ignore
  return jax.lax.broadcast(placeholder, aval.shape)  # type: ignore

def _remat_translation_using_while(*args, jaxpr: core.Jaxpr):
  """Evaluate ``jaxpr`` on ``args`` inside a ``jax.lax.while_loop``.

  Implements:
    for(counter=0, result=0; counter < rng(1, 2); counter ++) {
      result = eval_jaxpr(*args)
    }

  Returns the ``result`` component of the final loop carry, as a tuple.
  """
  out_avals = tuple(var.aval for var in jaxpr.outvars)
  placeholders = tuple(_dummy_like(aval) for aval in out_avals)

  def keep_going(state):
    step, _, _ = state
    bound = jax.lax.rng_uniform(np.int32(1), np.int32(2), shape=())
    return step < bound

  def run_once(state):
    step, _, operands = state
    outputs = core.eval_jaxpr(jaxpr, (), *operands)
    return (step + 1, tuple(outputs), operands)

  # The loop carry is a tuple: (counter, result, args)
  initial_state = (np.int32(0), placeholders, args)
  final_state = jax.lax.while_loop(keep_going, run_once, initial_state)
  return final_state[1]

def _remat_translation_using_cond(*args, jaxpr: core.Jaxpr):
  """Evaluate ``jaxpr`` on ``args`` as one branch of a ``jax.lax.cond``.

  Implements:
    if(rng(0, 1) < 2)
      return eval_jaxpr(*args)
    else:
      return 0

  The other branch returns placeholder values built by ``_dummy_like`` from
  the avals of ``jaxpr.outvars``.
  """
  out_avals = tuple(var.aval for var in jaxpr.outvars)

  def evaluate_branch(*operands):
    return tuple(core.eval_jaxpr(jaxpr, (), *operands))

  def placeholder_branch(*operands):
    del operands  # Unused; outputs depend only on the output avals.
    return tuple(_dummy_like(aval) for aval in out_avals)

  sample = jax.lax.rng_uniform(np.float32(0), np.float32(1), shape=())
  predicate = sample < np.float32(2)
  return jax.lax.cond(predicate, evaluate_branch, placeholder_branch, *args)

# Register remat_lowering (with its default is_gpu_platform=False) as the
# lowering for remat_p, with no platform argument given.
mlir.register_lowering(
remat_p, mlir.lower_fun(remat_lowering, multiple_results=True))
# Register a second lowering for platform="gpu" that binds
# is_gpu_platform=True, so remat_lowering can select its while-loop-based
# translation instead of the cond-based one.
mlir.register_lowering(
remat_p,
mlir.lower_fun(partial(remat_lowering, is_gpu_platform=True),
multiple_results=True),
platform="gpu")

def _optimization_barrier_abstract_eval(*args):
return args

def _optimization_barrier_lowering_rule(ctx, *args):
  """Emit an ``mhlo.OptimizationBarrierOp`` over the flattened arguments.

  The IR types for each aval in ``ctx.avals_in`` are flattened to form the op's
  result types. The op's results are then regrouped so that each input aval
  gets back as many results as it had IR types.
  """
  per_aval_types = [mlir.aval_to_ir_types(aval) for aval in ctx.avals_in]
  group_sizes = [len(types) for types in per_aval_types]
  op = mhlo.OptimizationBarrierOp(
      util.flatten(per_aval_types), mlir.flatten_lowering_ir_args(args))
  return util.unflatten(op.results, group_sizes)

def _optimization_barrier(arg):
  """Bind ``optimization_barrier_p`` over the leaves of the pytree ``arg``.

  The primitive's outputs are reassembled into the same tree structure as
  ``arg``.
  """
  leaves, structure = tree_flatten(arg)
  barriered_leaves = optimization_barrier_p.bind(*leaves)
  return tree_unflatten(structure, barriered_leaves)

# Define the 'optimization_barrier' primitive. It produces multiple results,
# its impl rule goes through xla.apply_primitive, its abstract evaluation
# returns the input avals unchanged, and it lowers via
# _optimization_barrier_lowering_rule.
optimization_barrier_p = core.Primitive('optimization_barrier')
optimization_barrier_p.multiple_results = True
optimization_barrier_p.def_impl(
partial(xla.apply_primitive, optimization_barrier_p))
optimization_barrier_p.def_abstract_eval(_optimization_barrier_abstract_eval)
mlir.register_lowering(optimization_barrier_p,
_optimization_barrier_lowering_rule)


def checkpoint_name(x, name):
  """Bind ``name_p`` to ``x`` with the keyword ``name`` and return the result."""
  named_value = name_p.bind(x, name=name)
  return named_value

Expand Down
26 changes: 5 additions & 21 deletions jax/_src/api.py
Expand Up @@ -3103,27 +3103,11 @@ def checkpoint(fun: Callable, *,
" return f(x)\n"
" else:\n"
" return g(x)\n"
"\n")
if config.jax_new_checkpoint:
raise NotImplementedError(msg)
else:
warn(msg, DeprecationWarning)

if config.jax_new_checkpoint:
return new_checkpoint(fun, prevent_cse=prevent_cse, policy=policy,
static_argnums=static_argnums)

@wraps(fun)
@api_boundary
def remat_f(*args, **kwargs):
f, args = _remat_static_argnums(fun, static_argnums, args)
args_flat, in_tree = tree_flatten((args, kwargs))
flat_fun, out_tree = flatten_fun(lu.wrap_init(f), in_tree)
out_flat = pe.remat_call(flat_fun, *args_flat, name=flat_fun.__name__,
concrete=concrete, prevent_cse=prevent_cse,
differentiated=False, policy=policy)
return tree_unflatten(out_tree(), out_flat)
return remat_f
"\n"
"See https://jax.readthedocs.io/en/latest/jep/11830-new-remat-checkpoint.html\n")
raise NotImplementedError(msg)
return new_checkpoint(fun, prevent_cse=prevent_cse, policy=policy,
static_argnums=static_argnums)
remat = checkpoint # type: ignore


Expand Down
7 changes: 0 additions & 7 deletions jax/_src/config.py
Expand Up @@ -875,13 +875,6 @@ def _update_disable_jit_thread_local(val):
default=(lib.version >= (0, 3, 6)),
help=('Enables using optimization-barrier op for lowering remat.'))

# TODO(mattjj): set default to True, then remove
config.define_bool_state(
name='jax_new_checkpoint',
default=True,
upgrade=True,
help='Whether to use the new jax.checkpoint implementation.')

# TODO(b/205307544): Remove flag once coordination service has rolled out.
config.define_bool_state(
name='jax_coordination_service',
Expand Down
4 changes: 2 additions & 2 deletions jax/_src/lax/control_flow/__init__.py
Expand Up @@ -19,8 +19,6 @@
scan, scan_bind, scan_p,
_scan_impl, while_loop, while_p)
from jax._src.lax.control_flow.conditionals import cond, cond_p, switch
from jax._src.lax.control_flow.remat_impl import (remat_impl,
optimization_barrier_p)
from jax._src.lax.control_flow.solves import (custom_linear_solve, custom_root,
_custom_linear_solve_impl,
linear_solve_p)
Expand All @@ -32,3 +30,5 @@
_initial_style_jaxpr,
_initial_style_jaxprs_with_common_consts,
_check_tree_and_avals)
# TODO(mattjj): fix dependent library which expects optimization_barrier_p here
from jax._src.ad_checkpoint import optimization_barrier_p
152 changes: 0 additions & 152 deletions jax/_src/lax/control_flow/remat_impl.py

This file was deleted.

0 comments on commit d19e34f

Please sign in to comment.