Skip to content

Commit

Permalink
[dynamic-shapes] small fix to einsum (and indexing)
Browse files Browse the repository at this point in the history
  • Loading branch information
mattjj committed Oct 4, 2022
1 parent a60ca9f commit 06a2c85
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 14 deletions.
8 changes: 5 additions & 3 deletions jax/_src/lax/lax.py
Expand Up @@ -1456,7 +1456,11 @@ def _iter(tracer):
raise TypeError("iteration over a 0-d array") # same as numpy error
else:
n = int(tracer.shape[0])
return (slicing.index_in_dim(tracer, i, keepdims=False) for i in range(n))
if any(isinstance(d, core.Tracer) for d in tracer.shape):
return (slicing.dynamic_index_in_dim(tracer, i, keepdims=False)
for i in range(n))
else:
return (slicing.index_in_dim(tracer, i, keepdims=False) for i in range(n))
ShapedArray._iter = staticmethod(_iter)
core.DShapedArray._iter = staticmethod(_iter)

Expand Down Expand Up @@ -3186,11 +3190,9 @@ def _squeeze_lower(ctx, operand, *, dimensions):
).results
else:
return mhlo.ReshapeOp(mlir.aval_to_ir_type(aval_out), operand).results

mlir.register_lowering(squeeze_p, _squeeze_lower)



def shape_as_value(shape: core.Shape):
"""Converts a shape that may contain Poly values into a JAX value."""
if len(shape) == 0:
Expand Down
23 changes: 12 additions & 11 deletions jax/_src/numpy/lax_numpy.py
Expand Up @@ -77,7 +77,8 @@
from jax._src.numpy.vectorize import vectorize
from jax._src.ops import scatter
from jax._src.typing import Array, ArrayLike, DTypeLike
from jax._src.util import (unzip2, prod as _prod, subvals, safe_zip, ceil_of_ratio,
from jax._src.util import (unzip2, prod as _prod, subvals, safe_zip,
ceil_of_ratio, partition_list,
canonicalize_axis as _canonicalize_axis)
from jax._src.array import ArrayImpl

Expand Down Expand Up @@ -2979,15 +2980,11 @@ def sum_repeats(operand, names, counts, keep_names):
return operand, names

def filter_singleton_dims(operand, names, other_shape, other_names):
s = shape(operand)
new_shape = []
new_names = []
for i, d in enumerate(names):
other_i = other_names.find(d)
if not core.symbolic_equal_dim(s[i], 1) or other_i == -1 or core.symbolic_equal_dim(other_shape[other_i], 1):
new_shape.append(s[i])
new_names.append(d)
return reshape(operand, tuple(new_shape)), "".join(new_names)
eq = core.symbolic_equal_dim
keep = [not eq(operand.shape[i], 1) or j == -1 or eq(other_shape[j], 1)
for i, j in enumerate(map(other_names.find, names))]
sqez_axes, keep_axes = partition_list(keep, list(range(operand.ndim)))
return lax.squeeze(operand, sqez_axes), "".join(names[i] for i in keep_axes)

for operand_indices, contracted_names_set, einstr in contractions:
contracted_names = sorted(contracted_names_set)
Expand Down Expand Up @@ -3658,7 +3655,11 @@ def _rewriting_take(arr, idx, indices_are_sorted=False, unique_indices=False,
step = idx.step if idx.step is not None else 1
if (0 <= start < n and 0 <= stop <= n and 0 < step and
(start, stop, step) != (0, n, 1)):
return lax.slice_in_dim(arr, start, stop, step)
if _any(isinstance(d, core.Tracer) for d in arr.shape[1:]):
if step == 1: # TODO(mattjj, sharadmv): handle step != 1
return lax.dynamic_slice_in_dim(arr, start, _max(0, stop - start), 0)
else:
return lax.slice_in_dim(arr, start, stop, step)

# TODO(mattjj,dougalm): expand dynamic shape indexing support
if jax.config.jax_dynamic_shapes and arr.ndim > 0:
Expand Down

0 comments on commit 06a2c85

Please sign in to comment.