Skip to content

Commit

Permalink
Support setitem using fancy indexing
Browse files Browse the repository at this point in the history
  • Loading branch information
hekaisheng committed Sep 14, 2021
1 parent 1bfad0c commit 0eb7b8a
Show file tree
Hide file tree
Showing 2 changed files with 185 additions and 34 deletions.
158 changes: 124 additions & 34 deletions mars/tensor/indexing/setitem.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,52 +12,45 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import functools
import operator
from numbers import Integral

import numpy as np

from ... import opcodes as OperandDef
from ...core import ENTITY_TYPE, recursive_tile
from ...serialization.serializables import KeyField, TupleField, AnyField
from ...serialization.serializables import KeyField, TupleField, AnyField, BoolField
from ...tensor import tensor as astensor
from ...utils import has_unknown_shape
from ..core import TENSOR_TYPE
from ..operands import TensorHasInput, TensorOperandMixin
from ..utils import filter_inputs
from .core import process_index
from .core import process_index, calc_shape


class TensorIndexSetValue(TensorHasInput, TensorOperandMixin):
_op_type_ = OperandDef.INDEXSETVALUE

_input = KeyField('input')
_indexes = TupleField('indexes')
_value = AnyField('value')

def __init__(self, indexes=None, value=None, **kw):
super().__init__(_indexes=indexes, _value=value, **kw)

@property
def indexes(self):
return self._indexes

@property
def value(self):
return self._value
_input = KeyField('_input')
indexes = TupleField('indexes')
value = AnyField('value')
is_fancy_index = BoolField('is_fancy_index')
index_offset = TupleField('index_offset')

def _set_inputs(self, inputs):
super()._set_inputs(inputs)
inputs_iter = iter(self._inputs[1:])
new_indexes = [next(inputs_iter) if isinstance(index, ENTITY_TYPE) else index
for index in self._indexes]
self._indexes = tuple(new_indexes)
if isinstance(self._value, ENTITY_TYPE):
self._value = next(inputs_iter)
for index in self.indexes]
self.indexes = tuple(new_indexes)
if isinstance(self.value, ENTITY_TYPE):
self.value = next(inputs_iter)

def __call__(self, a, index, value):
inputs = filter_inputs([a] + list(index) + [value])
self._indexes = tuple(index)
self._value = value
self.indexes = tuple(index)
self.value = value
return self.new_tensor(inputs, a.shape, order=a.order)

def on_output_modify(self, new_output):
Expand All @@ -69,7 +62,77 @@ def on_input_modify(self, new_input):
return new_op.new_tensor(new_inputs, shape=self.outputs[0].shape)

@classmethod
def tile(cls, op: "TensorIndexSetValue"):
def _tile_fancy_index(cls, op: "TensorIndexSetValue"):
from ..merge import column_stack, TensorConcatenate

tensor = op.outputs[0]
inp = op.inputs[0]
value = op.value
indexes = op.indexes

if has_unknown_shape(inp):
yield
axis_to_tensor_index = dict((axis, ind) for axis, ind
in enumerate(indexes) if isinstance(ind, ENTITY_TYPE))
offsets_on_axis = [np.cumsum([0] + list(split)) for split in inp.nsplits]

out_chunks = []
for c in inp.chunks:
chunk_filters = []
chunk_index_offset = []
for axis in range(len(c.shape)):
offset = offsets_on_axis[axis][c.index[axis]]
chunk_index_offset.append(offset)
if axis in axis_to_tensor_index:
index_on_axis = axis_to_tensor_index[axis]
filtered = (index_on_axis >= offset) & \
(index_on_axis < offset + c.shape[axis])
chunk_filters.append(filtered)
combined_filter = functools.reduce(operator.and_, chunk_filters)
if isinstance(value, ENTITY_TYPE):
concat_tensor = column_stack(list(axis_to_tensor_index.values()) + [value])
else:
concat_tensor = column_stack(list(axis_to_tensor_index.values()))
tiled_tensor = yield from recursive_tile(
concat_tensor[combined_filter])

chunk_indexes = []
tensor_index_order = 0
for axis in range(len(c.shape)):
if axis in axis_to_tensor_index:
index_chunks = [tiled_tensor.cix[i, tensor_index_order]
for i in range(tiled_tensor.chunk_shape[0])]
concat_op = TensorConcatenate(axis=0, dtype=index_chunks[0].dtype)
chunk_indexes.append(concat_op.new_chunk(
index_chunks, shape=(tiled_tensor.shape[0],), index=(0,)))
else:
chunk_indexes.append(slice(None))
tensor_index_order += 1

if isinstance(value, ENTITY_TYPE):
value_chunks = [tiled_tensor.cix[i, -1]
for i in range(tiled_tensor.chunk_shape[0])]
concat_op = TensorConcatenate(axis=0, dtype=value_chunks[0].dtype)
chunk_value = concat_op.new_chunk(
value_chunks, shape=(tiled_tensor.shape[0],), index=(0,))
else:
chunk_value = value
chunk_op = TensorIndexSetValue(
dtype=op.dtype, sparse=op.sparse,
indexes=tuple(chunk_indexes),
index_offset=tuple(chunk_index_offset),
value=chunk_value)
input_chunks = filter_inputs([c] + chunk_indexes + [chunk_value])
out_chunk = chunk_op.new_chunk(input_chunks, shape=c.shape,
index=c.index, order=tensor.order)
out_chunks.append(out_chunk)

new_op = op.copy()
return new_op.new_tensors(op.inputs, tensor.shape, order=tensor.order,
chunks=out_chunks, nsplits=op.input.nsplits)

@classmethod
def _tile(cls, op: "TensorIndexSetValue"):
from ..base import broadcast_to
from .getitem import _getitem_nocheck

Expand Down Expand Up @@ -116,26 +179,50 @@ def tile(cls, op: "TensorIndexSetValue"):
return new_op.new_tensors(op.inputs, tensor.shape, order=tensor.order,
chunks=out_chunks, nsplits=op.input.nsplits)

@classmethod
def tile(cls, op: "TensorIndexSetValue"):
if op.is_fancy_index:
return (yield from cls._tile_fancy_index(op))
else:
return (yield from cls._tile(op))

@classmethod
def execute(cls, ctx, op):
indexes = [ctx[index.key] if hasattr(index, 'key') else index
for index in op.indexes]
if getattr(op, 'index_offset', None) is not None:
new_indexs = []
index_iter = iter(indexes)
for ind, offset in zip(indexes, op.index_offset):
if isinstance(ind, np.ndarray):
new_indexs.append(next(index_iter) - offset)
else:
new_indexs.append(ind)
indexes = new_indexs
input_ = ctx[op.inputs[0].key].copy()
value = ctx[op.value.key] if hasattr(op.value, 'key') else op.value
if hasattr(input_, 'flags') and not input_.flags.writeable:
input_.setflags(write=True)
input_[tuple(indexes)] = value
try:
input_[tuple(indexes)] = value
except:
print(indexes)
ctx[op.outputs[0].key] = input_


def _check_support(index):
if isinstance(index, (slice, Integral)):
pass
elif isinstance(index, (np.ndarray, TENSOR_TYPE)) and index.dtype == np.bool_:
pass
else: # pragma: no cover
raise NotImplementedError('Only slice, int, or bool indexing '
f'supported by now, got {type(index)}')
def _check_support(indexes):
if all((isinstance(ix, (TENSOR_TYPE, np.ndarray)) and ix.dtype != np.bool_
or isinstance(ix, slice) and ix == slice(None)) for ix in indexes):
return True
for index in indexes:
if isinstance(index, (slice, Integral)):
pass
elif isinstance(index, (np.ndarray, TENSOR_TYPE)) and index.dtype == np.bool_:
pass
else: # pragma: no cover
raise NotImplementedError('Only slice, int, or bool indexing '
f'supported by now, got {type(index)}')
return False


def _setitem(a, item, value):
Expand All @@ -144,11 +231,14 @@ def _setitem(a, item, value):
# do not convert for tuple when dtype is record type.
value = astensor(value)

for ix in index:
_check_support(ix)
is_fancy_index = _check_support(index)
if is_fancy_index:
index = [astensor(ind) if isinstance(ind, np.ndarray) else ind
for ind in index]

# __setitem__ on a view should be still a view, see GH #732.
op = TensorIndexSetValue(dtype=a.dtype, sparse=a.issparse(),
is_fancy_index=is_fancy_index,
indexes=tuple(index), value=value,
create_view=a.op.create_view)
ret = op(a, index, value)
Expand Down
61 changes: 61 additions & 0 deletions mars/tensor/indexing/tests/test_indexing_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,67 @@ def test_mixed_indexing_execution(setup):
np.testing.assert_array_equal(res, expected)


def test_setitem_fancy_index_execution(setup):
rs = np.random.RandomState(0)

raw = rs.randint(0, 10, size=(11, 12))

# index is a ndarray, value is a scalar
arr = tensor(raw.copy(), chunk_size=5)
idx = rs.randint(0, 11, (5,))
arr[idx] = 20
res = arr.execute().fetch()
expected = raw.copy()
expected[idx] = 20
np.testing.assert_array_equal(res, expected)

# index is a tensor, value is a scalar
arr = tensor(raw.copy(), chunk_size=5)
raw_index = rs.randint(0, 11, (8,))
idx = tensor(raw_index.copy(), chunk_size=5)
arr[idx] = 2
res = arr.execute().fetch()
expected = raw.copy()
expected[raw_index] = 2
np.testing.assert_array_equal(res, expected)

# indexes are all tensors
arr = tensor(raw.copy(), chunk_size=6)
raw_index1 = rs.randint(0, 11, (20,))
idx1 = tensor(raw_index1.copy(), chunk_size=8)
raw_index2 = rs.randint(0, 12, (20,))
idx2 = tensor(raw_index2.copy(), chunk_size=8)
arr[idx1, idx2] = 2
res = arr.execute().fetch()
expected = raw.copy()
expected[raw_index1, raw_index2] = 2
np.testing.assert_array_equal(res, expected)

# indexes all tensors, value is also a tensor
arr = tensor(raw.copy(), chunk_size=6)
raw_index1 = rs.randint(0, 11, (20,))
idx1 = tensor(raw_index1.copy(), chunk_size=8)
raw_index2 = rs.randint(0, 12, (20,))
idx2 = tensor(raw_index2.copy(), chunk_size=8)
raw_value = rs.randint(0, 10, (20,))
arr[idx1, idx2] = tensor(raw_value, chunk_size=4)
res = arr.execute().fetch()
expected = raw.copy()
expected[raw_index1, raw_index2] = raw_value
np.testing.assert_array_equal(res, expected)

raw = rs.randint(0, 10, size=(20,))
arr = tensor(raw.copy(), chunk_size=6)
raw_index = rs.randint(0, 11, (9,))
raw_value = rs.randint(0, 10, (9,))
index = tensor(raw_index, chunk_size=3)
arr[index] = tensor(raw_value, chunk_size=4)
res = arr.execute().fetch()
expected = raw.copy()
expected[raw_index] = raw_value
np.testing.assert_array_equal(res, expected)


def test_setitem_execution(setup):
rs = np.random.RandomState(0)

Expand Down

0 comments on commit 0eb7b8a

Please sign in to comment.