Use shuffle instead of recursive_tile
hekaisheng committed Oct 19, 2021
1 parent 3091c62 commit 65320eb
Showing 4 changed files with 148 additions and 106 deletions.
1 change: 1 addition & 0 deletions mars/opcodes.py
@@ -311,6 +311,7 @@
# according to the original fancy index order
FANCY_INDEX_DISTRIBUTE = 424
FANCY_INDEX_CONCAT = 425
INDEXSETVALUESHUFFLE = 426

# linear algebra
TENSORDOT = 501
5 changes: 4 additions & 1 deletion mars/services/scheduling/worker/execution.py
@@ -123,6 +123,7 @@ async def _get_band_quota_ref(self, band: str) -> Union[mo.ActorRef, QuotaActor]

async def _prepare_input_data(self, subtask: Subtask, band_name: str):
queries = []
shuffle_queries = []
storage_api = await StorageAPI.create(
subtask.session_id, address=self.address, band_name=band_name)
pure_dep_keys = set()
@@ -139,10 +140,12 @@ async def _prepare_input_data(self, subtask: Subtask, band_name: str):
queries.append(storage_api.fetch.delay(chunk.key, band_name=to_fetch_band))
elif isinstance(chunk.op, FetchShuffle):
for key in chunk_key_to_data_keys[chunk.key]:
queries.append(storage_api.fetch.delay(
shuffle_queries.append(storage_api.fetch.delay(
key, band_name=to_fetch_band, error='ignore'))
if queries:
await storage_api.fetch.batch(*queries)
if shuffle_queries:
await storage_api.fetch.batch(*shuffle_queries)

async def _collect_input_sizes(self,
subtask: Subtask,
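
The hunk above splits pre-fetching into two batches: ordinary chunk fetches keep strict error handling, while shuffle-mapper outputs are fetched with error='ignore' because empty partitions may never have been written. Below is a minimal, self-contained sketch of that pattern (not Mars code; the dict-backed fetch and asyncio.gather stand in for StorageAPI.fetch.delay/batch):

import asyncio

storage = {'chunk-key': b'data', ('shuffle-key', 0): b'part-0'}  # ('shuffle-key', 1) was never written

async def fetch(key, *, error='raise'):
    # stand-in for storage_api.fetch: a missing key only raises in strict mode
    if key not in storage and error == 'raise':
        raise KeyError(key)
    return storage.get(key)

async def prepare_input_data(chunk_keys, shuffle_keys):
    queries = [fetch(key) for key in chunk_keys]
    shuffle_queries = [fetch(key, error='ignore') for key in shuffle_keys]
    if queries:
        await asyncio.gather(*queries)          # strict batch for ordinary chunks
    if shuffle_queries:
        await asyncio.gather(*shuffle_queries)  # tolerant batch for shuffle outputs

asyncio.run(prepare_input_data(['chunk-key'], [('shuffle-key', 0), ('shuffle-key', 1)]))
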
18 changes: 12 additions & 6 deletions mars/services/subtask/worker/processor.py
@@ -106,9 +106,10 @@ async def _load_input_data(self):
accept_nones.append(True)
elif isinstance(chunk.op, FetchShuffle):
for key in self._chunk_key_to_data_keys[chunk.key]:
keys.append(key)
gets.append(self._storage_api.get.delay(key, error='ignore'))
accept_nones.append(False)
if key not in keys:
keys.append(key)
gets.append(self._storage_api.get.delay(key, error='ignore'))
accept_nones.append(False)
if keys:
logger.debug('Start getting input data, keys: %s, '
'subtask id: %s', keys, self.subtask.subtask_id)
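
The guard added above avoids queuing the same shuffle data key twice when several FetchShuffle chunks expand to overlapping key lists. A small illustration of the effect, using a hypothetical mapping (not taken from the repository):

# Hypothetical chunk-to-data-key mapping; the overlap is what the guard handles.
chunk_key_to_data_keys = {
    'fetch-shuffle-a': [('proxy', 'mapper-0'), ('proxy', 'mapper-1')],
    'fetch-shuffle-b': [('proxy', 'mapper-1'), ('proxy', 'mapper-2')],
}

keys, gets = [], []
for data_keys in chunk_key_to_data_keys.values():
    for key in data_keys:
        if key not in keys:               # same check as in the diff
            keys.append(key)
            gets.append(('get', key))     # stand-in for storage_api.get.delay(key, error='ignore')

print(keys)  # each shuffle key is requested once: mapper-0, mapper-1, mapper-2
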
@@ -207,19 +208,24 @@ def cb(fut):
if ref_counts[inp.key] == 0:
# ref count reaches 0, remove it
for key in self._chunk_key_to_data_keys[inp.key]:
del self._datastore[key]
if key in self._datastore:
del self._datastore[key]

async def _unpin_data(self, data_keys):
# unpin input keys
unpins = []
shuffle_unpins = []
for key in data_keys:
if isinstance(key, tuple):
# a tuple key means it's a shuffle key,
# some shuffle data is None and not stored in storage
unpins.append(self._storage_api.unpin.delay(key, error='ignore'))
shuffle_unpins.append(self._storage_api.unpin.delay(key, error='ignore'))
else:
unpins.append(self._storage_api.unpin.delay(key))
await self._storage_api.unpin.batch(*unpins)
if unpins:
await self._storage_api.unpin.batch(*unpins)
if shuffle_unpins:
await self._storage_api.unpin.batch(*shuffle_unpins)

async def _store_data(self, chunk_graph: ChunkGraph):
# skip virtual operands for result chunks
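
In the second hunk, unpinning is likewise split: tuple keys identify shuffle data, part of which may never have been stored (empty partitions), so those unpins run in their own batch with error='ignore'. A minimal sketch of the split, with a stand-in unpin instead of StorageAPI.unpin:

def unpin(key, *, error='raise'):
    stored = {'input-chunk', ('shuffle-chunk', 'mapper-0')}   # pretend storage contents
    if key not in stored and error == 'raise':
        raise KeyError(key)

def unpin_data(data_keys):
    unpins, shuffle_unpins = [], []
    for key in data_keys:
        if isinstance(key, tuple):   # tuple key -> shuffle data, may be missing
            shuffle_unpins.append((key, 'ignore'))
        else:                        # plain key -> ordinary chunk, strict unpin
            unpins.append((key, 'raise'))
    for key, error in unpins:
        unpin(key, error=error)
    for key, error in shuffle_unpins:
        unpin(key, error=error)

unpin_data(['input-chunk', ('shuffle-chunk', 'mapper-0'), ('shuffle-chunk', 'mapper-1')])
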
230 changes: 131 additions & 99 deletions mars/tensor/indexing/setitem.py
@@ -13,18 +13,21 @@
# limitations under the License.

import functools
import itertools
import operator
from numbers import Integral

import numpy as np

from ... import opcodes as OperandDef
from ...core import ENTITY_TYPE, recursive_tile
from ...core.operand import OperandStage
from ...serialization.serializables import KeyField, TupleField, AnyField, BoolField
from ...tensor import tensor as astensor
from ...utils import has_unknown_shape
from ..core import TENSOR_TYPE
from ..operands import TensorHasInput, TensorOperandMixin
from ..core import TENSOR_TYPE, TensorOrder
from ..operands import TensorHasInput, TensorMapReduceOperand, \
TensorOperandMixin, TensorShuffleProxy
from ..utils import filter_inputs
from .core import process_index

@@ -33,47 +36,28 @@ class TensorIndexSetValue(TensorHasInput, TensorOperandMixin):
_op_type_ = OperandDef.INDEXSETVALUE

_input = KeyField('input')
_indexes = TupleField('indexes')
_value = AnyField('value')
_is_fancy_index = BoolField('is_fancy_index')
_index_offset = TupleField('index_offset')

def __init__(self, indexes=None, value=None,
is_fancy_index=None, index_offset=None, **kw):
super().__init__(_indexes=indexes, _value=value,
_is_fancy_index=is_fancy_index,
_index_offset=index_offset,
**kw)

@property
def indexes(self):
return self._indexes

@property
def value(self):
return self._value
indexes = TupleField('indexes')
value = AnyField('value')
is_fancy_index = BoolField('is_fancy_index')

@property
def is_fancy_index(self):
return self._is_fancy_index

@property
def index_offset(self):
return self._index_offset
def __init__(self, indexes=None, value=None, is_fancy_index=None, **kw):
super().__init__(indexes=indexes, value=value,
is_fancy_index=is_fancy_index,
**kw)

def _set_inputs(self, inputs):
super()._set_inputs(inputs)
inputs_iter = iter(self._inputs[1:])
new_indexes = [next(inputs_iter) if isinstance(index, ENTITY_TYPE) else index
for index in self._indexes]
self._indexes = tuple(new_indexes)
if isinstance(self._value, ENTITY_TYPE):
self._value = next(inputs_iter)
for index in self.indexes]
self.indexes = tuple(new_indexes)
if isinstance(self.value, ENTITY_TYPE):
self.value = next(inputs_iter)

def __call__(self, a, index, value):
inputs = filter_inputs([a] + list(index) + [value])
self._indexes = tuple(index)
self._value = value
self.indexes = tuple(index)
self.value = value
return self.new_tensor(inputs, a.shape, order=a.order)

def on_output_modify(self, new_output):
@@ -86,7 +70,7 @@ def on_input_modify(self, new_input):

@classmethod
def _tile_fancy_index(cls, op: "TensorIndexSetValue"):
from ..merge import column_stack, TensorConcatenate
from ..utils import unify_chunks

tensor = op.outputs[0]
inp = op.inputs[0]
@@ -95,64 +79,51 @@ def _tile_fancy_index(cls, op: "TensorIndexSetValue"):

if has_unknown_shape(inp):
yield
axis_to_tensor_index = dict((axis, ind) for axis, ind
in enumerate(indexes) if isinstance(ind, ENTITY_TYPE))
offsets_on_axis = [np.cumsum([0] + list(split)) for split in inp.nsplits]

out_chunks = []
for c in inp.chunks:
chunk_filters = []
chunk_index_offset = []
for axis in range(len(c.shape)):
offset = offsets_on_axis[axis][c.index[axis]]
chunk_index_offset.append(offset)
if axis in axis_to_tensor_index:
index_on_axis = axis_to_tensor_index[axis]
filtered = (index_on_axis >= offset) & \
(index_on_axis < offset + c.shape[axis])
chunk_filters.append(filtered)
combined_filter = functools.reduce(operator.and_, chunk_filters)
if isinstance(value, ENTITY_TYPE):
concat_tensor = column_stack(list(axis_to_tensor_index.values()) + [value])
else:
concat_tensor = column_stack(list(axis_to_tensor_index.values()))
tiled_tensor = yield from recursive_tile(
concat_tensor[combined_filter])

chunk_indexes = []
tensor_index_order = 0
for axis in range(len(c.shape)):
if axis in axis_to_tensor_index:
index_chunks = [tiled_tensor.cix[i, tensor_index_order]
for i in range(tiled_tensor.chunk_shape[0])]
concat_op = TensorConcatenate(axis=0, dtype=index_chunks[0].dtype)
chunk_indexes.append(concat_op.new_chunk(
index_chunks, shape=(tiled_tensor.shape[0],), index=(0,)))
else:
chunk_indexes.append(slice(None))
tensor_index_order += 1

if isinstance(value, ENTITY_TYPE):
value_chunks = [tiled_tensor.cix[i, -1]
for i in range(tiled_tensor.chunk_shape[0])]
concat_op = TensorConcatenate(axis=0, dtype=value_chunks[0].dtype)
chunk_value = concat_op.new_chunk(
value_chunks, shape=(tiled_tensor.shape[0],), index=(0,))
else:
chunk_value = value
chunk_op = TensorIndexSetValue(
dtype=op.dtype, sparse=op.sparse,
indexes=tuple(chunk_indexes),
index_offset=tuple(chunk_index_offset),
value=chunk_value)
input_chunks = filter_inputs([c] + chunk_indexes + [chunk_value])
out_chunk = chunk_op.new_chunk(input_chunks, shape=c.shape,
index=c.index, order=tensor.order)
out_chunks.append(out_chunk)
fancy_indexes = [index for index in indexes if isinstance(index, ENTITY_TYPE)]
if isinstance(value, ENTITY_TYPE):
value, *fancy_indexes = yield from unify_chunks(value, *fancy_indexes)
value = value.chunks
else:
fancy_indexes = yield from unify_chunks(*fancy_indexes)
value = [value] * len(fancy_indexes[0].chunks)
input_nsplits = inp.nsplits
shuffle_axes = tuple(axis for axis, ind in enumerate(indexes)
if isinstance(ind, ENTITY_TYPE))

map_chunks = []
for value_chunk, *index_chunks in zip(
value, *[index.chunks for index in fancy_indexes]):
map_op = TensorIndexSetValueShuffle(
stage=OperandStage.map, input_nsplits=input_nsplits,
value=value_chunk, indexes=tuple(index_chunks),
shuffle_axes=shuffle_axes, dtype=tensor.dtype)
inputs = filter_inputs([value_chunk] + list(index_chunks))
map_chunk = map_op.new_chunk(inputs, shape=(np.nan,),
index=index_chunks[0].index,
order=TensorOrder.C_ORDER)
map_chunks.append(map_chunk)

proxy_chunk = TensorShuffleProxy(dtype=tensor.dtype).new_chunk(
map_chunks, shape=(), order=TensorOrder.C_ORDER)

reducer_chunks = []
offsets_on_axis = [np.cumsum([0] + list(split)) for split in input_nsplits]
for input_chunk in inp.chunks:
chunk_offsets = tuple(offsets_on_axis[axis][input_chunk.index[axis]]
for axis in range(len(inp.shape)))
reducer_op = TensorIndexSetValueShuffle(
stage=OperandStage.reduce, dtype=input_chunk.dtype,
shuffle_axes=shuffle_axes, chunk_offsets=chunk_offsets)
reducer_chunk = reducer_op.new_chunk([input_chunk, proxy_chunk],
index=input_chunk.index,
shape=input_chunk.shape,
order=input_chunk.order)
reducer_chunks.append(reducer_chunk)

new_op = op.copy()
return new_op.new_tensors(op.inputs, tensor.shape, order=tensor.order,
chunks=out_chunks, nsplits=op.input.nsplits)
chunks=reducer_chunks, nsplits=op.input.nsplits)

@classmethod
def _tile(cls, op: "TensorIndexSetValue"):
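
The rewritten _tile_fancy_index above replaces the per-chunk recursive_tile/concatenate approach with a map-shuffle-reduce graph: map chunks bucket the fancy-index entries (and values) by destination input chunk, a shuffle proxy routes the buckets, and reducer chunks rebase the indexes to chunk-local coordinates before assigning. The following is a self-contained 1-D NumPy sketch of that idea, not the Mars operand itself:

import numpy as np

data = np.arange(8)                        # pretend tensor, split into two chunks of 4
nsplits = (4, 4)
offsets = np.cumsum([0] + list(nsplits))   # [0, 4, 8], as offsets_on_axis in the diff
chunks = [data[0:4].copy(), data[4:8].copy()]

fancy_index = np.array([1, 6, 5, 2])       # global positions being assigned
values = np.array([10, 60, 50, 20])

# "map" stage: one bucket of (indexes, values) per destination chunk
buckets = []
for i in range(len(nsplits)):
    mask = (fancy_index >= offsets[i]) & (fancy_index < offsets[i + 1])
    buckets.append((fancy_index[mask], values[mask]))

# "reduce" stage: shift to chunk-local coordinates and do an ordinary setitem
for i, (idx, val) in enumerate(buckets):
    chunks[i][idx - offsets[i]] = val

print(np.concatenate(chunks))              # [ 0 10 20  3  4 50 60  7]
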
@@ -213,24 +184,85 @@ def execute(cls, ctx, op):
def execute(cls, ctx, op):
indexes = [ctx[index.key] if hasattr(index, 'key') else index
for index in op.indexes]
if getattr(op, 'index_offset', None) is not None:
new_indexs = []
index_iter = iter(indexes)
for ind, offset in zip(indexes, op.index_offset):
if isinstance(ind, np.ndarray):
new_indexs.append(next(index_iter) - offset)
else:
new_indexs.append(ind)
indexes = new_indexs
input_ = ctx[op.inputs[0].key].copy()
value = ctx[op.value.key] if hasattr(op.value, 'key') else op.value
if hasattr(input_, 'flags') and not input_.flags.writeable:
input_.setflags(write=True)

input_[tuple(indexes)] = value
ctx[op.outputs[0].key] = input_


class TensorIndexSetValueShuffle(TensorMapReduceOperand, TensorOperandMixin):
_op_type_ = OperandDef.INDEXSETVALUESHUFFLE

indexes = TupleField('indexes')
value = AnyField('value')
input_nsplits = TupleField('input_nsplits')
chunk_offsets = TupleField('chunk_offsets')
shuffle_axes = TupleField('shuffle_axes')

def __init__(self, indexes=None, value=None, input_nsplits=None,
chunk_offsets=None, shuffle_axes=None, **kw):
super().__init__(indexes=indexes, value=value,
input_nsplits=input_nsplits,
chunk_offsets=chunk_offsets,
shuffle_axes=shuffle_axes, **kw)

@classmethod
def execute(cls, ctx, op):
if op.stage == OperandStage.map:
cls._execute_map(ctx, op)
else:
cls._execute_reduce(ctx, op)

@classmethod
def _execute_map(cls, ctx, op):
nsplits = op.input_nsplits
shuffle_axes = op.shuffle_axes
all_inputs = [ctx[inp.key] for inp in op.inputs]
if hasattr(op.value, 'key'):
value = ctx[op.value.key]
indexes = all_inputs[1:]
else:
value = op.value
indexes = all_inputs

offsets_on_axis = [np.cumsum([0] + list(split)) for split in nsplits]
for reducer_index in itertools.product(
*(map(range, [len(s) for s in nsplits]))):
chunk_filters = []
indexes_iter = iter(indexes)
for axis, _ in enumerate(reducer_index):
start = offsets_on_axis[axis][reducer_index[axis]]
end = offsets_on_axis[axis][reducer_index[axis] + 1]
if axis in shuffle_axes:
index_on_axis = next(indexes_iter)
filtered = (index_on_axis >= start) & (index_on_axis < end)
chunk_filters.append(filtered)
combined_filter = functools.reduce(operator.and_, chunk_filters)
if hasattr(op.value, 'key'):
ctx[op.outputs[0].key, reducer_index] = tuple(inp[combined_filter]
for inp in all_inputs)
else:
ctx[op.outputs[0].key, reducer_index] = tuple([value] + [inp[combined_filter]
for inp in all_inputs])

@classmethod
def _execute_reduce(cls, ctx, op):
input_data = ctx[op.inputs[0].key].copy()
for index_value in op.iter_mapper_data(ctx, input_id=1):
value = index_value[0]
indexes_with_offset = index_value[1:]
indexes = []
index_iter = iter(indexes_with_offset)
for axis in range(input_data.ndim):
if axis in op.shuffle_axes:
indexes.append(next(index_iter) - op.chunk_offsets[axis])
input_data[indexes] = value

ctx[op.outputs[0].key] = input_data
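
For the multi-dimensional case, _execute_map above builds one boolean filter per shuffled axis and combines them with functools.reduce(operator.and_, ...) to pick the entries owned by each reducer chunk. A plain-NumPy illustration with hypothetical bounds for one reducer:

import functools
import operator

import numpy as np

rows = np.array([0, 3, 1, 2])             # fancy index along axis 0
cols = np.array([1, 2, 3, 0])             # fancy index along axis 1
row_range, col_range = (2, 4), (0, 2)     # bounds of one hypothetical reducer chunk

filters = [
    (rows >= row_range[0]) & (rows < row_range[1]),
    (cols >= col_range[0]) & (cols < col_range[1]),
]
combined = functools.reduce(operator.and_, filters)
print(rows[combined], cols[combined])     # entries routed to that reducer: [2] [0]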


def _check_support(indexes):
if all((isinstance(ix, (TENSOR_TYPE, np.ndarray)) and ix.dtype != np.bool_
or isinstance(ix, slice) and ix == slice(None)) for ix in indexes):
