small tweaks for bint ad
mattjj committed Aug 5, 2022
1 parent c5d4eb5 commit fbf6aa2
Showing 4 changed files with 71 additions and 10 deletions.
19 changes: 10 additions & 9 deletions jax/_src/lax/lax.py
@@ -1461,7 +1461,7 @@ def unop(result_dtype, accepted_dtypes, name):
   prim = standard_primitive(_attrgetter('shape'), dtype_rule, name,
                             weak_type_rule=weak_type_rule)
   batching.defvectorized(prim)
-  pe.padding_rules[prim] = lambda _, __, x, **kw: [prim.bind(x, **kw)]
+  pe.def_trivial_padding(prim)
   return prim
 standard_unop = partial(unop, _identity)
 _attrgetter = lambda name: lambda x, **kwargs: getattr(x, name)
@@ -1536,7 +1536,7 @@ def naryop(result_dtype, accepted_dtypes, name):
   prim = standard_primitive(shape_rule, dtype_rule, name,
                             weak_type_rule=weak_type_rule)
   batching.defbroadcasting(prim)
-  pe.padding_rules[prim] = lambda _, __, *xs, **kw: [prim.bind(*xs, **kw)]
+  pe.def_trivial_padding(prim)
   return prim
 standard_naryop = partial(naryop, _input_dtype)
 
@@ -2015,7 +2015,7 @@ def _integer_pow_jvp(g, x, *, y):
   _attrgetter('shape'), _integer_pow_dtype_rule, 'integer_pow')
 batching.defvectorized(integer_pow_p)
 ad.defjvp(integer_pow_p, _integer_pow_jvp)
-pe.padding_rules[integer_pow_p] = lambda _, __, x, y: [integer_pow_p.bind(x, y=y)]
+pe.def_trivial_padding(integer_pow_p)
 
 def _integer_pow(x, *, y):
   # This should be kept in sync with the jax2tf translation rule.
@@ -2927,7 +2927,7 @@ def _clamp_batch_rule(batched_args, batch_dims, **params):
 else:
   mlir.register_lowering(
       clamp_p, partial(_nary_lower_mhlo, mhlo.ClampOp))
-pe.padding_rules[clamp_p] = lambda _, __, a, x, b: [clamp(a, x, b)]
+pe.def_trivial_padding(clamp_p)
 
 def _concatenate_shape_rule(*operands, **kwargs):
   dimension = kwargs.pop('dimension')
@@ -3285,16 +3285,17 @@ def _transpose_batch_rule(batched_args, batch_dims, *, permutation):
   perm = (bdim,) + tuple(i if i < bdim else i+1 for i in permutation)
   return transpose(operand, perm), 0
 
-def _transpose_lower(ctx, x, *, permutation):
-  aval_out, = ctx.avals_out
-  return mhlo.TransposeOp(x, mlir.dense_int_elements(permutation)).results
-
 transpose_p = standard_primitive(_transpose_shape_rule, _input_dtype,
                                  'transpose')
 ad.deflinear2(transpose_p,
               lambda t, _, permutation: [transpose(t, np.argsort(permutation))])  # type: ignore[arg-type]
 batching.primitive_batchers[transpose_p] = _transpose_batch_rule
+
+def _transpose_lower(ctx, x, *, permutation):
+  aval_out, = ctx.avals_out
+  return mhlo.TransposeOp(x, mlir.dense_int_elements(permutation)).results
 mlir.register_lowering(transpose_p, _transpose_lower)
+pe.def_trivial_padding(transpose_p)
 
 
 def _select_shape_rule(which, *cases):
@@ -3386,7 +3387,6 @@ def _select_jvp(primals, tangents):
   out_dot = select_n(which, *case_tangents)
   return out, out_dot
 
-
 def _select_mhlo_lowering(ctx, which, *cases):
   which_aval = ctx.avals_in[0]
   if which_aval.dtype == np.dtype(np.bool_):
@@ -3420,6 +3420,7 @@ def _select(offset, cases):
 ad.primitive_transposes[select_n_p] = _select_transpose_rule
 batching.primitive_batchers[select_n_p] = _select_batch_rule
 mlir.register_lowering(select_n_p, _select_mhlo_lowering)
+pe.def_trivial_padding(select_n_p)
 
 
 def _reduce_shape_rule(*avals, computation, jaxpr, consts, dimensions):
8 changes: 7 additions & 1 deletion jax/core.py
@@ -1417,7 +1417,7 @@ class DShapedArray(UnshapedArray):
   shape: Tuple[AxisSize, ...]  # noqa: F821
   array_abstraction_level: int = 3
 
-  def __init__(self, shape, dtype, weak_type):
+  def __init__(self, shape, dtype, weak_type=False):
     self.shape = shape
     self.dtype = dtype
     self.weak_type = weak_type
@@ -1474,6 +1474,7 @@ def __init__(self, shape, dtype, weak_type, val):
 pytype_aval_mappings: Dict[type, Callable[[Any], AbstractValue]] = {}
 
 
+# TODO(mattjj): remove this, replace with arrays of bints
 class AbstractBInt(AbstractValue):
   __slots__ = ['bound']
   bound: int
@@ -1486,11 +1487,16 @@ def __eq__(self, other):
     return type(other) is AbstractBInt and self.bound == other.bound
   def __hash__(self) -> int:
     return hash((type(self), self.bound))
+  def at_least_vspace(self):
+    return self  # should return float0 array
+  def join(self, other):
+    return self
 
 class BInt:
   val: Any  # Union[int, Array]
   bound: int
   def __init__(self, val, bound):
+    assert 0 <= val <= bound
     self.val = val
     self.bound = bound
   def __repr__(self) -> str:
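
For context, a minimal sketch of how the bint machinery above surfaces to users, via lax.make_bint as the tests in this commit use it; this snippet is illustrative, not part of the diff:

    from jax import lax

    d = lax.make_bint(3, 5)   # a bounded int: concrete value 3, static bound 5
    print(d)                  # 3{≤5}, per test_bint_basic below
    # the new assert in BInt.__init__ rejects values outside [0, bound]
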
12 changes: 12 additions & 0 deletions jax/interpreters/partial_eval.py
@@ -2497,6 +2497,18 @@ def _substitute_axis_sizes(env: Dict, aval: AbstractValue) -> AbstractValue:
 
 padding_rules: Dict[Primitive, Callable] = {}
 
+def def_trivial_padding(prim: Primitive) -> None:
+  if prim.multiple_results:
+    padding_rules[prim] = partial(_trivial_padding_rule_multi, prim)
+  else:
+    padding_rules[prim] = partial(_trivial_padding_rule, prim)
+
+def _trivial_padding_rule(prim, _, __, *args, **params):
+  return [prim.bind(*args, **params)]
+
+def _trivial_padding_rule_multi(prim, _, __, *args, **params):
+  return prim.bind(*args, **params)
+
 def call_padding_rule(prim, in_avals, out_avals, *args, call_jaxpr, **params):
   if call_jaxpr.constvars: raise NotImplementedError
   padded_jaxpr, padded_consts = pad_jaxpr(call_jaxpr, ())
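
For context, a minimal usage sketch of the new helper, mirroring the lax.py call sites above; my_unop_p is a hypothetical example primitive, not one from this diff:

    from jax import core
    from jax.interpreters import partial_eval as pe

    my_unop_p = core.Primitive('my_unop')  # hypothetical primitive

    # Before this commit, each call site spelled out the trivial rule by hand:
    #   pe.padding_rules[my_unop_p] = lambda _, __, x, **kw: [my_unop_p.bind(x, **kw)]
    # The helper registers the same rule, and also handles multiple_results:
    pe.def_trivial_padding(my_unop_p)
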
42 changes: 42 additions & 0 deletions tests/dynamic_api_test.py
@@ -1121,6 +1121,48 @@ def loss_ref(params, batch):
     expected = jax.grad(loss_ref)(params, batch1)
     self.assertAllClose(ans, expected)
 
+  @jax.enable_checks(False)  # TODO(mattjj): upgrade typecompat to handle bints
+  def test_mlp_autodiff_dynamic_batch_bint(self):
+    count = 0
+
+    def predict(params, inputs):
+      for W, b in params:
+        outputs = jnp.dot(inputs, W) + b
+        inputs = jnp.maximum(0, outputs)
+      return outputs
+
+    def loss_ref(params, batch):
+      nonlocal count
+      count += 1  # count traces
+      inputs, targets = batch
+      predictions = predict(params, inputs)
+      return jnp.sum((predictions - targets) ** 2)
+
+    loss = jax.jit(loss_ref, abstracted_axes=({}, {0: 'n'}))
+
+    params = [(jnp.ones((784, 256)), jnp.ones(256)),
+              (jnp.ones((256, 10)), jnp.ones( 10))]
+
+    # two different batch sizes *with bints*
+    bs1 = jax.lax.make_bint(128, 128)
+    batch1 = (jnp.ones((bs1, 784)), jnp.ones((bs1, 10)))
+
+    bs2 = jax.lax.make_bint(32, 128)
+    batch2 = (jnp.ones((bs2, 784)), jnp.ones((bs2, 10)))
+
+    # count retraces (and don't crash)
+    self.assertEqual(count, 0)
+    _ = jax.grad(loss)(params, batch1)
+    self.assertEqual(count, 1)
+    g2 = jax.grad(loss)(params, batch2)
+    self.assertEqual(count, 1)  # cache hit!
+
+    # check the numbers make sense
+    batch = (jnp.ones((32, 784)), jnp.ones((32, 10)))
+    g2_expected = jax.grad(loss_ref)(params, batch)
+    self.assertAllClose(g2, g2_expected, check_dtypes=False,
+                        atol=1e-3, rtol=1e-3)
+
   def test_bint_basic(self):
     d = lax.make_bint(3, 5)
     self.assertEqual(str(d), '3{≤5}')
