Skip to content

Commit

Permalink
Reapply #2017 (Allow shapecheck of PixelCNN++), fixing #2245 (#2800)
Browse files Browse the repository at this point in the history
* Unrevert "Allow shapecheck of PixelCNN++ (#2017)"

This reverts commit ceab1e3.

* Fix out-of-bound slices (#2245)

* Minor

* Add type annotations

* Fix Poly.__rsub__

* any -> _any

* tweaks, mostly comments/whitespace

* separate polymorphic code path, patch _slice_sizes

* put back some logic for handling Poly sizes

* improve test_slice_indices

* Remove to_index, replace with canonicalize_shape

* Fix slicing with polymorphic start/stop

* Test negative step for polymorphic slicing

* Refactor polymorphic slicing

* Simplify diff

* Fix shapecheck(iota)

Co-authored-by: Matthew Johnson <mattjj@google.com>
  • Loading branch information
juliuskunze and mattjj committed May 1, 2020
1 parent 1b56428 commit c00e9a2
Show file tree
Hide file tree
Showing 8 changed files with 312 additions and 300 deletions.
26 changes: 11 additions & 15 deletions jax/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,14 +53,14 @@
from .lib.xla_bridge import (device_count, local_device_count, devices, local_devices,
host_id, host_ids, host_count)
from .abstract_arrays import ConcreteArray, ShapedArray, raise_to_shaped
from .interpreters.masking import eval_polymorphic_shape, Poly, Mon
from .interpreters import partial_eval as pe
from .interpreters import xla
from .interpreters import pxla
from .interpreters import ad
from .interpreters import batching
from .interpreters import parallel
from .interpreters import masking
from .interpreters.masking import ensure_poly
from .custom_derivatives import custom_jvp, custom_vjp
from .config import flags, config, bool_env

Expand Down Expand Up @@ -1175,24 +1175,23 @@ def wrapped_fun(args, logical_env):
out_shapes = map(masking.finalize_spec, out_specs, map(onp.shape, outs))
if not out_shapes == list(out_shapes_):
raise masking.ShapeError
if not all(onp.shape(out) == masking.eval_shape_expr(padded_env, expr)
for out, expr in zip(outs, out_shapes)):
if not all(onp.shape(out) == eval_polymorphic_shape(shape, padded_env)
for out, shape in zip(outs, out_shapes)):
raise masking.ShapeError
return tree_unflatten(out_tree(), outs)
return wrapped_fun

def _remap_ids(names, shape_spec):
ShapeSpec, Poly, Mon = masking.ShapeSpec, masking.Poly, masking.Mon
mdim = masking.monomorphic_dim
return ShapeSpec(Poly({Mon({names[id] : deg for id, deg in mon.items()})
return masking.ShapeSpec(Poly({Mon({names[id] : deg for id, deg in mon.items()})
: coeff for mon, coeff in poly.items()})
if poly is not mdim else mdim for poly in shape_spec)
if poly is not masking._monomorphic_dim else
masking._monomorphic_dim for poly in shape_spec)

def _bind_shapes(shape_exprs, shapes):
env = {}
for shape_expr, shape in zip(shape_exprs, shapes):
for poly, d in zip(shape_expr, shape):
if ensure_poly(poly).is_constant:
if type(poly) is not Poly or poly.is_constant:
continue
else:
(binder,), = poly # TODO generalize to handle striding
Expand All @@ -1201,23 +1200,20 @@ def _bind_shapes(shape_exprs, shapes):


@curry
def shapecheck(in_shapes, out_shape, fun):
def shapecheck(in_shapes, out_shape, fun: Callable):
_check_callable(fun)
in_shapes, in_tree = tree_flatten(in_shapes)
in_shapes = map(masking.parse_spec, in_shapes)
out_shapes, out_tree = tree_flatten(out_shape)
out_shapes = map(masking.parse_spec, out_shapes)
flat_fun, out_tree_ = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
out_shapes_ = masking.shapecheck(flat_fun, in_shapes)
avals = map(partial(ShapedArray, dtype=onp.float32), in_shapes)
out_shapes_ = [o.shape for o in pe.abstract_eval_fun(flat_fun.call_wrapped, *avals)]
if out_tree != out_tree_(): raise TypeError("pytree mismatch")
if not all(map(_shape_spec_consistent, out_shapes, out_shapes_)):
if not all(map(masking._shape_spec_consistent, out_shapes, out_shapes_)):
raise masking.ShapeError
return fun

def _shape_spec_consistent(spec, expr):
return all(a == b for a, b in zip(spec, expr) if a is not masking.monomorphic_dim)


def jvp(fun: Callable, primals, tangents) -> Tuple[Any, Any]:
"""Computes a (forward-mode) Jacobian-vector product of ``fun``.
Expand Down

0 comments on commit c00e9a2

Please sign in to comment.