[dynamic-shapes] Expand the handling of dynamic shapes for reshape and iota.

Also add more tests.
gnecula committed Jul 5, 2022
1 parent 5d6f81c commit 5983d38
Showing 8 changed files with 276 additions and 52 deletions.
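A hedged sketch of the kind of program this commit targets (illustrative only: the `abstracted_axes` staging entry point and its axis-spec format are assumptions about the experimental dynamic-shapes API of this era, and the snippet may not run unchanged at this exact commit):

import jax
import jax.numpy as jnp
import numpy as np

jax.config.update("jax_dynamic_shapes", True)   # see the config.py change below

def f(x):                    # x is given an abstracted leading axis of size n
  n = x.shape[0]             # under the flag, n is a tracer rather than an int
  return jnp.arange(n) + jnp.reshape(x, (n,))   # dynamic iota + dynamic reshape

# Assumed staging entry point and axis spec for one abstracted axis named 'n':
print(jax.make_jaxpr(f, abstracted_axes=({0: 'n'},))(np.ones((3, 1), np.float32)))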
2 changes: 1 addition & 1 deletion jax/_src/config.py
@@ -840,7 +840,7 @@ def _update_disable_jit_thread_local(val):
# if the intended backend can handle lowering the result
config.define_bool_state(
name='jax_dynamic_shapes',
default=False,
default=bool(os.getenv('JAX_DYNAMIC_SHAPES', '')),
help=('Enables experimental features for staging out computations with '
'dynamic shapes.'),
update_global_hook=lambda val: \
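With this new default, the experimental flag can be enabled without code changes. A minimal sketch of the two mechanisms that follow from this define_bool_state entry:

# Option 1: set the environment variable before importing jax; note that the
# new default calls bool() on the raw string, so any non-empty value -- even
# "0" -- turns the flag on.
#   JAX_DYNAMIC_SHAPES=1 python my_script.py
# Option 2: flip it explicitly after import.
import jax
jax.config.update("jax_dynamic_shapes", True)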
1 change: 1 addition & 0 deletions jax/_src/dispatch.py
@@ -302,6 +302,7 @@ def lower_xla_callable(fun: lu.WrappedFun, device, backend, name,
shape=tuple(expected_shape), dtype=expected_type.dtype,
weak_type=expected_type.weak_type)
assert core.typematch(expected_aval, aval)

with log_elapsed_time(f"Finished tracing + transforming {fun.__name__} "
"for jit in {elapsed_time} sec"):
jaxpr, out_type, consts = pe.trace_to_jaxpr_final2(
162 changes: 120 additions & 42 deletions jax/_src/lax/lax.py
@@ -18,8 +18,8 @@
from functools import partial
import itertools
import operator
from typing import (Any, Callable, Optional, Sequence, Tuple, List, TypeVar,
Union, cast as type_cast)
from typing import (Any, Callable, Dict, Optional, Sequence, Tuple,
List, TypeVar, Union, cast as type_cast)
import warnings

import numpy as np
@@ -156,6 +156,60 @@ def _broadcast_ranks(s1, s2):

def _identity(x): return x

def _extract_tracers_dyn_shape(shape: Sequence[Union[int, core.Tracer]]
) -> Tuple[Sequence[core.Tracer],
Sequence[Optional[int]]]:
"""Returns the list of tracers in `shape`, and a static versio of `shape`
with tracers replaced with None"""
if config.jax_dynamic_shapes:
# We must gate this behavior under a flag because otherwise the errors
# raised are different (and have worse source provenance information).
dyn_shape = [d for d in shape if isinstance(d, core.Tracer)]
static_shape = [d if not isinstance(d, core.Tracer) else None for d in shape]
return dyn_shape, static_shape
else:
return [], shape


def _merge_dyn_shape(static_shape: Sequence[Optional[int]],
dyn_shape: Sequence[mlir.Value],
) -> Sequence[mlir.Value]:
"""Returns static_shape with None values filled in from dyn_shape."""
dyn_shape = iter(dyn_shape)
shape = [next(dyn_shape) if d is None else d for d in static_shape]
assert next(dyn_shape, None) is None
return shape
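A self-contained sketch of the contract shared by these two helpers and by _stage_with_dyn_shape below; a local FakeTracer class stands in for core.Tracer and the flag check is elided, so this runs outside a trace:

class FakeTracer:                 # stand-in for core.Tracer, illustration only
  def __repr__(self):
    return "d"

def extract(shape):               # mirrors _extract_tracers_dyn_shape (flag on)
  dyn = [d for d in shape if isinstance(d, FakeTracer)]
  static = [d if not isinstance(d, FakeTracer) else None for d in shape]
  return dyn, static

def merge(static, dyn):           # mirrors _merge_dyn_shape
  dyn = iter(dyn)
  out = [next(dyn) if d is None else d for d in static]
  assert next(dyn, None) is None  # every dynamic dimension must be consumed
  return out

d = FakeTracer()
dyn, static = extract((d, 2, 3))
assert dyn == [d] and static == [None, 2, 3]
assert merge(static, dyn) == [d, 2, 3]    # round-trips back to the mixed shape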

def _stage_with_dyn_shape(trace: core.Trace,
prim: core.Primitive,
args: Sequence[core.Tracer],
dyn_shape_args: Sequence[core.Tracer],
params: Dict[str, Any],
static_shape: Sequence[Optional[int]],
out_dtype: core.Type,
out_weak_type: bool,
) -> core.Tracer:
"""Stages out a primitive that takes dynamic shapes.
dyn_shape_args are the tracers corresponding to the None values in static_shape.
"""
if not dyn_shape_args:
return trace.default_process_primitive(prim, args, params)
assert len(dyn_shape_args) == sum(d is None for d in static_shape)
source_info = source_info_util.current()

ds = iter(dyn_shape_args)
out_shape_for_tracer: List[Union[int, core.Tracer]] = [
next(ds) if d is None else d for d in static_shape]
aval = core.DShapedArray(tuple(out_shape_for_tracer), out_dtype, out_weak_type)
out_tracer = pe.DynamicJaxprTracer(trace, aval, source_info)
invars = [*(trace.getvar(x) for x in args), *(trace.getvar(d) for d in dyn_shape_args)]
eqn = pe.new_jaxpr_eqn(invars, [trace.makevar(out_tracer)],
prim, params, core.no_effects, source_info)
trace.frame.eqns.append(eqn)

return out_tracer

### traceables

def neg(x: Array) -> Array:
@@ -740,16 +794,9 @@ def broadcast_in_dim(operand: Array, shape: Shape,
if (np.ndim(operand) == len(shape) and not len(broadcast_dimensions)
and isinstance(operand, (device_array.DeviceArray, core.Tracer))):
return operand
if config.jax_dynamic_shapes:
# We must gate this behavior under a flag because otherwise the errors
# raised are different (and have worse source provenance information).
dyn_shape = [d for d in shape if isinstance(d, core.Tracer)]
shape_ = [d if not isinstance(d, core.Tracer) else None for d in shape]
else:
dyn_shape = []
shape_ = shape # type: ignore
dyn_shape, static_shape = _extract_tracers_dyn_shape(shape)
return broadcast_in_dim_p.bind(
operand, *dyn_shape, shape=tuple(shape_),
operand, *dyn_shape, shape=tuple(static_shape),
broadcast_dimensions=tuple(broadcast_dimensions))

def broadcast_to_rank(x: Array, rank: int) -> Array:
@@ -808,8 +855,10 @@ def reshape(operand: Array, new_sizes: Shape,
and isinstance(operand, (core.Tracer, device_array.DeviceArray))):
return operand
else:
dyn_shape, static_new_sizes = _extract_tracers_dyn_shape(new_sizes)

return reshape_p.bind(
operand, new_sizes=new_sizes,
operand, *dyn_shape, new_sizes=static_new_sizes,
dimensions=None if dims is None or same_dims else dims)

def pad(operand: Array, padding_value: Array,
@@ -1118,7 +1167,8 @@ def iota(dtype: DType, size: int) -> Array:
"""
dtype = dtypes.canonicalize_dtype(dtype)
size, = canonicalize_shape((size,))
return iota_p.bind(dtype=dtype, shape=(size,), dimension=0)
dyn_shape, static_shape = _extract_tracers_dyn_shape((size,))
return iota_p.bind(*dyn_shape, dtype=dtype, shape=static_shape, dimension=0)

def broadcasted_iota(dtype: DType, shape: Shape, dimension: int) -> Array:
"""Convenience wrapper around ``iota``."""
@@ -2638,9 +2688,7 @@ def _broadcast_in_dim_typecheck_rule(
return [out_aval], effects
else:
# TODO(mattjj): perform more checks like _broadcast_in_dim_shape_rule
dyn_shape_ = iter(dyn_shape)
out_shape = [next(dyn_shape_) if d is None else d for d in shape]
assert next(dyn_shape_, None) is None
out_shape = _merge_dyn_shape(shape, dyn_shape)
out_shape = [x.val if type(x) is core.Literal else x for x in out_shape]
out_aval = core.DShapedArray(tuple(out_shape), operand.aval.dtype,
operand.aval.weak_type)
@@ -2676,22 +2724,9 @@ def _broadcast_in_dim_fwd_rule(eqn):
def _broadcast_in_dim_staging_rule(
trace, x, *dyn_shape, shape, broadcast_dimensions):
params = dict(shape=shape, broadcast_dimensions=broadcast_dimensions)
if not dyn_shape:
return trace.default_process_primitive(broadcast_in_dim_p, (x,), params)
assert len(dyn_shape) == sum(d is None for d in shape)
source_info = source_info_util.current()

ds = iter(dyn_shape)
out_shape_for_tracer: List[Union[int, core.Tracer]] = [
next(ds) if d is None else d for d in shape]
aval = core.DShapedArray(tuple(out_shape_for_tracer), x.dtype, x.weak_type)
out_tracer = pe.DynamicJaxprTracer(trace, aval, source_info)
invars = [trace.getvar(x), *(trace.getvar(d) for d in dyn_shape)]
eqn = pe.new_jaxpr_eqn(invars, [trace.makevar(out_tracer)],
broadcast_in_dim_p, params, core.no_effects, source_info)
trace.frame.eqns.append(eqn)

return out_tracer
return _stage_with_dyn_shape(trace, broadcast_in_dim_p,
(x,), dyn_shape, params,
shape, x.dtype, x.weak_type)

def _broadcast_in_dim_padding_rule(in_avals, out_avals, x, *dyn_shape,
shape, broadcast_dimensions):
@@ -2746,9 +2781,7 @@ def _broadcast_in_dim_partial_eval(
def _broadcast_in_dim_lower(ctx, x, *dyn_shape, shape, broadcast_dimensions):
aval_out, = ctx.avals_out
if dyn_shape:
dyn_shape = iter(dyn_shape)
shape = [next(dyn_shape) if d is None else d for d in shape]
assert next(dyn_shape, None) is None
shape = _merge_dyn_shape(shape, dyn_shape)
return mhlo.DynamicBroadcastInDimOp(
mlir.aval_to_ir_type(aval_out), x,
mlir.shape_tensor(shape),
@@ -3115,7 +3148,8 @@ def _reshape_shape_rule(operand, *, new_sizes, dimensions):
if not all(core.greater_equal_dim(d, 0) for d in new_sizes):
msg = 'reshape new_sizes must all be positive, got {}.'
raise TypeError(msg.format(new_sizes))
if not core.same_shape_sizes(np.shape(operand), new_sizes):
# TODO(necula): re-enable this check
if not config.jax_dynamic_shapes and not core.same_shape_sizes(np.shape(operand), new_sizes):
msg = 'reshape total size must be unchanged, got new_sizes {} for shape {}.'
raise TypeError(msg.format(new_sizes, np.shape(operand)))
if dimensions is not None:
@@ -3168,13 +3202,30 @@ def merge_const_sizes(shape):
batching.primitive_batchers[reshape_p] = _reshape_batch_rule
masking.masking_rules[reshape_p] = _reshape_masking_rule

def _reshape_lower(ctx, x, *, new_sizes, dimensions):
def _reshape_lower(ctx, x, *dyn_shape, new_sizes, dimensions):
aval_out, = ctx.avals_out
if dimensions is not None:
x = mhlo.TransposeOp(x, mlir.dense_int_elements(dimensions)).result
return mhlo.ReshapeOp(mlir.aval_to_ir_type(aval_out), x).results
if dyn_shape:
shape = _merge_dyn_shape(new_sizes, dyn_shape)
return mhlo.DynamicReshapeOp(
mlir.aval_to_ir_type(aval_out), x,
mlir.shape_tensor(shape),
).results
else:
return mhlo.ReshapeOp(mlir.aval_to_ir_type(aval_out), x).results

mlir.register_lowering(reshape_p, _reshape_lower)

def _reshape_staging_rule(
trace, x, *dyn_shape, new_sizes, dimensions):
params = dict(new_sizes=new_sizes, dimensions=dimensions)
# TODO(necula): shouldn't this include the same checks as in reshape_shape_rule?
return _stage_with_dyn_shape(trace, reshape_p, (x,), dyn_shape, params,
new_sizes, x.dtype, x.weak_type)

pe.custom_staging_rules[reshape_p] = _reshape_staging_rule
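The TODO above, like the one added in _reshape_shape_rule, concerns the total-size check that is now skipped when jax_dynamic_shapes is on. A plain-integer illustration of what that check guards against:

import numpy as np

old_shape, new_sizes = (3, 4), (5,)
# Statically, core.same_shape_sizes rejects this: 12 elements cannot become 5.
# With jax_dynamic_shapes enabled, the shape rule above skips that check, so
# an invalid reshape is only caught later, if at all.
assert np.prod(old_shape) != np.prod(new_sizes)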

def _rev_shape_rule(operand, *, dimensions):
_check_shapelike('rev', 'dimensions', dimensions)
if len(set(dimensions)) != len(dimensions):
@@ -4325,11 +4376,38 @@ def _iota_abstract_eval(*, dtype, shape, dimension):
iota_p.def_impl(partial(xla.apply_primitive, iota_p))
iota_p.def_abstract_eval(_iota_abstract_eval)

def _iota_lower(ctx, *, dtype, shape, dimension):
del dtype, shape
def _iota_staging_rule(
trace, *dyn_shape, dtype, shape, dimension):
params = dict(dtype=dtype, shape=shape, dimension=dimension)
return _stage_with_dyn_shape(trace, iota_p, (), dyn_shape, params,
shape, dtype, False)
pe.custom_staging_rules[iota_p] = _iota_staging_rule

def _iota_typecheck_rule(*dyn_shape, dtype, shape, dimension):
if not dyn_shape:
out_aval, effects = iota_p.abstract_eval(
dtype=dtype, shape=shape, dimension=dimension)
return [out_aval], effects
else:
out_shape = _merge_dyn_shape(shape, dyn_shape)
out_shape = [x.val if type(x) is core.Literal else x for x in out_shape]
out_aval = core.DShapedArray(tuple(out_shape), dtype, False)
return [out_aval], core.no_effects
core.custom_typechecks[iota_p] = _iota_typecheck_rule

def _iota_lower(ctx, *dyn_shape, dtype, shape, dimension):
del dtype
aval_out, = ctx.avals_out
return mhlo.IotaOp(mlir.aval_to_ir_type(aval_out),
mlir.i64_attr(dimension)).results
if dyn_shape:
shape = _merge_dyn_shape(shape, dyn_shape)
return mhlo.DynamicIotaOp(
mlir.aval_to_ir_type(aval_out),
mlir.shape_tensor(shape),
mlir.i64_attr(dimension),
).results
else:
return mhlo.IotaOp(mlir.aval_to_ir_type(aval_out),
mlir.i64_attr(dimension)).results
mlir.register_lowering(iota_p, _iota_lower)


2 changes: 1 addition & 1 deletion jax/_src/numpy/lax_numpy.py
@@ -2094,7 +2094,7 @@ def arange(start: core.DimSize, stop: Optional[core.DimSize]=None,
raise ValueError(
"jax.numpy.arange supports non-constant arguments only in single-argument form. "
f"Found jax.numpy.arange(start={start}, stop={stop}, step={step})")
return lax.iota(int_, start)
return lax.iota(dtype or int_, start)
if dtype is None:
dtype = result_type(start, *(x for x in [stop, step] if x is not None))
dtype = _jnp_dtype(dtype)
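The one-line change above matters on the non-constant `start` path: the old code always built lax.iota(int_, start) and discarded an explicit dtype. A small sketch of the corrected expression, where int_ is a stand-in for lax_numpy's module-level default integer dtype:

import jax.numpy as jnp

int_ = jnp.int32           # stand-in for the default integer dtype
requested = jnp.float32    # explicit dtype passed as jnp.arange(n, dtype=...)

before = int_              # old code: lax.iota(int_, start) ignored `requested`
after = requested or int_  # new code: lax.iota(dtype or int_, start)
assert before == jnp.int32 and after == jnp.float32
# With dtype=None the fallback to int_ is unchanged.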
4 changes: 0 additions & 4 deletions jax/core.py
@@ -1653,10 +1653,6 @@ def symbolic_equal_shape(s1: Shape, s2: Shape) -> bool:
all(unsafe_map(symbolic_equal_dim, s1, s2)))

def greater_equal_dim(d1: DimSize, d2: DimSize) -> bool:
# TODO(mattjj): revise this temporary workaround for dynamic shapes
if isinstance(d1, Tracer) or isinstance(d2, Tracer):
return True

handler, ds = _dim_handler_and_canonical(d1, d2)
return handler.greater_equal(*ds)

2 changes: 2 additions & 0 deletions jax/interpreters/mlir.py
@@ -54,6 +54,8 @@

T = typing.TypeVar("T")

Value = ir.Value

# mypy implicitly sets this variable to true when type checking.
MYPY = False

38 changes: 38 additions & 0 deletions jax/interpreters/partial_eval.py
@@ -2365,6 +2365,44 @@ def _substitute_vars_in_type(
else:
return a


class DimensionHandlerTracer(core.DimensionHandler):
"""See core.DimensionHandler.
Most methods are inherited.
"""
def is_constant(self, d: core.DimSize) -> bool:
assert isinstance(d, Tracer)
return False

def symbolic_equal(self, d1: core.DimSize, d2: core.DimSize) -> bool:
return d1 is d2

def greater_equal(self, d1: core.DimSize, d2: core.DimSize):
raise core.InconclusiveDimensionOperation("TODO")

def divide_shape_sizes(self, s1: core.Shape, s2: core.Shape) -> core.DimSize:
"""Computes integer "i" such that i * size(s2) == size(s1).
Raise InconclusiveDimensionOperation if there is no such integer for all
contexts.
"""
s1_size = functools.reduce(op.mul, s1, 1)
s2_size = functools.reduce(op.mul, s2, 1)
q, r = divmod(s1_size, s2_size)
# TODO(necula): must check that r == 0!
return q

def stride(self, d: core.DimSize, window_size: core.DimSize, window_stride: core.DimSize) -> core.DimSize:
"""Implements `(d - window_size) // window_stride + 1`"""
raise core.InconclusiveDimensionOperation("TODO")

def as_value(self, d: core.DimSize):
"""Turns a dimension size into a Jax value that we can compute with."""
raise core.InconclusiveDimensionOperation("TODO")

core._SPECIAL_DIMENSION_HANDLERS[DynamicJaxprTracer] = DimensionHandlerTracer()
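divide_shape_sizes computes the integer factor between two shape sizes. With concrete integers the arithmetic it performs looks like this (in the traced case the sizes are tracers and, per the TODO, the remainder check is still missing):

import functools
import operator as op

s1, s2 = (6, 4), (8,)                      # size(s1) == 24, size(s2) == 8
s1_size = functools.reduce(op.mul, s1, 1)
s2_size = functools.reduce(op.mul, s2, 1)
q, r = divmod(s1_size, s2_size)
assert (q, r) == (3, 0)                    # i == 3 satisfies i * size(s2) == size(s1)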

Const = Any
Val = Any

