Skip to content

Commit

Permalink
Internal change
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 447549479
  • Loading branch information
mattjj authored and jax authors committed May 9, 2022
1 parent 4bc1c1c commit bb56f40
Show file tree
Hide file tree
Showing 7 changed files with 78 additions and 15 deletions.
2 changes: 1 addition & 1 deletion jax/_src/checkify.py
Original file line number Diff line number Diff line change
Expand Up @@ -618,7 +618,7 @@ def ignore_errors_jaxpr(jaxpr, error):
payload_aval = core.raise_to_shaped(core.get_aval(error.payload))
consts = jaxpr.consts
jaxpr = jaxpr.jaxpr
new_vars = core.gensym()
new_vars = core.gensym([jaxpr])
new_invars = (new_vars(err_aval), new_vars(code_aval),
new_vars(payload_aval), *jaxpr.invars)
new_jaxpr = jaxpr.replace(invars=new_invars)
Expand Down
6 changes: 3 additions & 3 deletions jax/_src/lax/control_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def _initial_style_jaxprs_with_common_consts(
unzip3(_initial_style_open_jaxpr(fun, in_tree, in_avals, primitive_name)
for fun in funs)

newvar = core.gensym(suffix='_')
newvar = core.gensym(jaxprs, suffix='_')
all_const_avals = [[raise_to_shaped(core.get_aval(c)) for c in consts]
for consts in all_consts]
unused_const_vars = [[newvar(aval) for aval in const_avals]
Expand Down Expand Up @@ -451,7 +451,7 @@ def _while_loop_jvp(primals, tangents, cond_nconsts, cond_jaxpr, body_nconsts,
[body_nconsts, num_carry], [len(bconst_dot), len(init_dot)],
[num_carry], [len(init_dot)])

newvar = core.gensym()
newvar = core.gensym([cond_jaxpr.jaxpr])
invars_aug = (
cond_jaxpr.jaxpr.invars + [newvar(core.get_aval(x)) for x in init_dot])
cond_jaxpr_augmented = core.Jaxpr(cond_jaxpr.jaxpr.constvars,
Expand Down Expand Up @@ -1147,7 +1147,7 @@ def f_aug(*args):
# that it does not read.
def _join_cond_pe_staged_jaxpr_inputs(jaxprs, all_res_avals,
res_aval_indices_per_jaxpr):
newvar = core.gensym(suffix='_')
newvar = core.gensym([j.jaxpr for j in jaxprs], suffix='_')
all_res_vars = _map(newvar, all_res_avals)

def augment_jaxpr(jaxpr, res_indices):
Expand Down
41 changes: 35 additions & 6 deletions jax/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,9 @@ class JaxprEqn(NamedTuple):
effects: Effects
source_info: source_info_util.SourceInfo

def __repr__(self):
return str(pp_eqn(self, JaxprPpContext(), JaxprPpSettings())).rstrip()

def replace(self, *args, **kwargs):
return self._replace(*args, **kwargs)

Expand All @@ -210,18 +213,27 @@ def new_jaxpr_eqn(invars, outvars, primitive, params, effects, source_info=None)
return JaxprEqn(invars, outvars, primitive, params, effects, source_info)


@total_ordering
class Var:
# TODO(frostig,mattjj): We don't override __eq__ or __hash__, so comparison is
# by object id, but pretty printing might collide.
count: int
suffix: str
aval: AbstractValue

def __init__(self, suffix: str, aval: AbstractValue):
def __init__(self, count: int, suffix: str, aval: AbstractValue):
self.count = count
self.suffix = suffix
self.aval = raise_to_shaped(aval)

def __lt__(self, other):
if not isinstance(other, Var):
return NotImplemented
else:
return (self.count, self.suffix) < (other.count, other.suffix)

def __repr__(self):
return f'Var(id={id(self)}){self.suffix}:{self.aval.str_short()}'
return _encode_digits_alphabetic(self.count) + self.suffix

def _encode_digits_alphabetic(n):
if n == -1:
Expand All @@ -232,16 +244,33 @@ def _encode_digits_alphabetic(n):
s = chr(97 + i % 26) + s
return s

def gensym(suffix: str = '') -> Callable[[AbstractValue], Var]:
return functools.partial(Var, suffix)
def _jaxpr_vars(jaxpr):
return it.chain(
jaxpr.invars, jaxpr.constvars,
(v for eqn in jaxpr.eqns for v in eqn.outvars))

def gensym(jaxprs: Optional[Sequence[Jaxpr]] = None,
suffix: str = '') -> Callable[[AbstractValue], Var]:
"""Produce distinct variables, printed with the optional suffix.
If `jaxprs` is provided, the variables produced will be distinct from those in
any of the given jaxprs.
"""
if jaxprs is None:
start = 0
else:
all_vars = it.chain.from_iterable(_jaxpr_vars(j) for j in jaxprs)
start = 1 + max((v.count for v in all_vars), default=-1)
counter = it.count(start=start)
return lambda aval: Var(next(counter), suffix, aval)

# In a jaxpr, `dropvar` can appear in place of a bound variable to indicate that
# the assignment is dropped, i.e. that an expression's output value will never
# be read. In that sense, `dropvar` is not a variable, but it is convenient to
# treat it as a special case of one. Its `aval` is similarly inexact.
class DropVar(Var):
def __init__(self, aval: AbstractValue):
super().__init__('', aval)
super().__init__(-1, '', aval)
def __repr__(self): return '_'

class Literal:
Expand Down Expand Up @@ -2046,7 +2075,7 @@ def subst_axis_names_var(v: Var, subst: AxisSubst, var_map: Dict[Var, Var]) -> V
named_shape = {name: axis_frame(name).size for name in names}
if len(named_shape) != len(names):
raise DuplicateAxisNameError(v)
new_v = Var(v.suffix, v.aval.update(named_shape=named_shape))
new_v = Var(v.count, v.suffix, v.aval.update(named_shape=named_shape))
var_map[v] = new_v
return new_v

Expand Down
2 changes: 1 addition & 1 deletion jax/experimental/host_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -1500,7 +1500,7 @@ def _rewrite_jaxpr(jaxpr: core.Jaxpr, has_input_token: bool,
if not has_input_token and not core.jaxpr_uses_outfeed(jaxpr):
return jaxpr

mk_new_var = core.gensym()
mk_new_var = core.gensym([jaxpr])

eqns: List[core.JaxprEqn] = []
# store the incoming tokens
Expand Down
2 changes: 1 addition & 1 deletion jax/experimental/maps.py
Original file line number Diff line number Diff line change
Expand Up @@ -1833,7 +1833,7 @@ def _fix_inferred_spmd_sharding(jaxpr, resource_env, gen_fresh_name = None):
return jaxpr.map_jaxpr(rec)
assert isinstance(jaxpr, core.Jaxpr)
if gen_fresh_name is None:
gen_fresh_name = core.gensym()
gen_fresh_name = core.gensym([jaxpr])
new_eqns = []
for eqn in jaxpr.eqns:
new_jaxpr_params = core.traverse_jaxpr_params(rec, eqn.params)
Expand Down
2 changes: 1 addition & 1 deletion jax/interpreters/partial_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -1119,7 +1119,7 @@ def call_partial_eval_custom_rule(
out_binders_known, _ = partition_list(unks_out, eqn.outvars)
_, out_binders_staged = partition_list(inst_out, eqn.outvars)
kept_outs_staged = inst_out
newvar = core.gensym()
newvar = core.gensym([jaxpr_known, jaxpr_staged])
residuals = [newvar(v.aval) for v in jaxpr_staged.invars[:num_res]]
params_known = {**eqn.params, jaxpr_param_name: jaxpr_known}
params_staged = {**eqn.params, jaxpr_param_name: jaxpr_staged}
Expand Down
38 changes: 36 additions & 2 deletions tests/core_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from collections import namedtuple
from functools import partial
import gc
import itertools as it
import operator

import numpy as np
Expand All @@ -29,7 +30,8 @@
from jax import linear_util as lu
from jax import jvp, linearize, vjp, jit, make_jaxpr
from jax.core import UnshapedArray, ShapedArray
from jax.tree_util import tree_flatten, tree_unflatten, tree_map, tree_reduce
from jax.tree_util import (tree_flatten, tree_unflatten, tree_map, tree_reduce,
tree_leaves)
from jax.interpreters import partial_eval as pe

from jax._src import test_util as jtu
Expand Down Expand Up @@ -297,6 +299,38 @@ def f(x):
finally:
gc.set_debug(debug)

def test_comparing_var(self):
newsym = core.gensym()
a = newsym(core.ShapedArray((), np.dtype('int32')))
b = newsym(core.ShapedArray((), np.dtype('int32')))
c = newsym(core.ShapedArray((), np.dtype('int32')))
assert a < b < c
assert c > b > a
assert a != b and b != c and a != c

def test_var_ordering(self):
newsym = core.gensym()
a = newsym(core.ShapedArray((), np.dtype('int32')))
b = newsym(core.ShapedArray((), np.dtype('int32')))
c = newsym(core.ShapedArray((), np.dtype('int32')))
for ordering in it.permutations([a, b, c]):
assert sorted(list(ordering)) == [a, b, c]

def test_var_compared_by_identity(self):
a1 = core.gensym()(core.ShapedArray((), np.dtype('int32')))
a2 = core.gensym()(core.ShapedArray((), np.dtype('int32')))
assert str(a1) == str(a2)
assert a1 != a2

def test_var_tree_flatten(self):
newsym = core.gensym()
aval = core.ShapedArray((), np.dtype('int32'))
a, b, c, d = (
newsym(aval), newsym(aval),
newsym(aval), newsym(aval))
syms = {c: d, a: b}
assert 'bd' == ''.join(map(str, tree_leaves(syms)))

def test_concrete_array_string_representation(self):
# https://github.com/google/jax/issues/5364
self.assertEqual(
Expand Down Expand Up @@ -457,7 +491,7 @@ def f(x):
def test_jaxpr_undefined_eqn_invar(self):
jaxpr = make_jaxpr(lambda x: jnp.sin(x) + jnp.cos(x))(1.).jaxpr
cos = next(eqn for eqn in jaxpr.eqns if eqn.primitive.name == 'cos')
cos.invars[0] = core.gensym(suffix='_test')(cos.invars[0].aval)
cos.invars[0] = core.gensym([jaxpr], suffix='_test')(cos.invars[0].aval)
self.assertRaisesRegex(
core.JaxprTypeError,
r"Variable '.+_test' not defined\n\nin equation:",
Expand Down

0 comments on commit bb56f40

Please sign in to comment.