Skip to content

Commit

Permalink
Fix Index and IndexSetValue operand when new chunks without indexes p…
Browse files Browse the repository at this point in the history
…arameter (#70)

* fix setitem & getitem

* use contextmanager to set inputs

* refine code for chunks to chunk_size

* refine code
  • Loading branch information
hekaisheng authored and wjsi committed Dec 28, 2018
1 parent 8870a95 commit 2699114
Show file tree
Hide file tree
Showing 6 changed files with 98 additions and 68 deletions.
30 changes: 30 additions & 0 deletions mars/deploy/local/tests/test_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,36 @@ def testMultipleOutputTensorExecute(self):
np.testing.assert_allclose(U_result, U_expected + 1)
np.testing.assert_allclose(s_result, s_expectd + 1)

def testIndexTensorExecute(self):
with new_cluster(scheduler_n_process=2, worker_n_process=2) as cluster:
session = cluster.session

a = mt.random.rand(10, 5)
idx = slice(0, 5), slice(0, 5)
a[idx] = 2
a_splits = mt.split(a, 2)
r1, r2 = session.run(a_splits[0], a[idx])

np.testing.assert_array_equal(r1, r2)
np.testing.assert_array_equal(r1, np.ones((5, 5)) * 2)

with new_session(cluster.endpoint) as session2:
a = mt.random.rand(10, 5)
idx = slice(0, 5), slice(0, 5)

a[idx] = mt.ones((5, 5)) * 2
r = session2.run(a[idx])

np.testing.assert_array_equal(r, np.ones((5, 5)) * 2)

def testExecutableTuple(self):
with new_cluster(scheduler_n_process=2, worker_n_process=2, web=True) as cluster:
with new_session('http://' + cluster._web_endpoint).as_default() as _:
a = mt.ones((20, 10), chunk_size=10)
u, s, v = (mt.linalg.svd(a)).execute()
np.testing.assert_allclose(u.dot(np.diag(s).dot(v)), np.ones((20, 10)))


def testGraphFail(self):
op = SerializeMustFailOperand(f=3)
tensor = op.new_tensor(None, (3, 3))
Expand Down
2 changes: 1 addition & 1 deletion mars/tensor/execution/indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def _index(ctx, chunk):
def _index_set_value(ctx, chunk):
indexes = [ctx[index.key] if hasattr(index, 'key') else index
for index in chunk.op.indexes]
input = ctx[chunk.inputs[0].key]
input = ctx[chunk.inputs[0].key].copy()
value = ctx[chunk.op.value.key] if hasattr(chunk.op.value, 'key') else chunk.op.value
if hasattr(input, 'flags') and not input.flags.writeable:
input.setflags(write=True)
Expand Down
54 changes: 30 additions & 24 deletions mars/tensor/expressions/indexing/getitem.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from numbers import Integral
import operator
import itertools
import contextlib

import numpy as np

Expand All @@ -34,34 +35,39 @@ class TensorIndex(Index, TensorOperandMixin):
def __init__(self, dtype=None, sparse=False, **kw):
super(TensorIndex, self).__init__(_dtype=dtype, _sparse=sparse, **kw)

@classmethod
def _handle_inputs(cls, inputs):
tensor, indexes = inputs
indexes_inputs = [ind for ind in indexes if isinstance(ind, (BaseWithKey, Entity))]
return [tensor] + indexes_inputs

def _set_inputs(self, inputs):
super(TensorIndex, self)._set_inputs(inputs)
self._input = self._inputs[0]
indexes_iter = iter(self._inputs[1:])
new_indexes = [next(indexes_iter) if isinstance(index, (BaseWithKey, Entity)) else index
@contextlib.contextmanager
def _handle_params(self, inputs, indexes):
"""
The Index operator is special: it has an additional parameter `indexes`, which may itself contain tensors.
Normally `indexes` is provided when the operand is called from `tile` or `TensorIndex.__call__`. However, calls
in `GraphActor.get_executable_operand_dag` only provide inputs; in that situation we need to take `indexes`
from the operand itself and replace the tensor-like indexes with the new ones found in `inputs`.
"""
if indexes is not None:
indexes_inputs = [ind for ind in indexes if isinstance(ind, (BaseWithKey, Entity))]
inputs = inputs + indexes_inputs
yield inputs

if indexes is not None:
self._indexes = indexes

inputs_iter = iter(self._inputs[1:])
new_indexes = [next(inputs_iter) if isinstance(index, (BaseWithKey, Entity)) else index
for index in self._indexes]
self._indexes = new_indexes

def new_tensors(self, inputs, shape, **kw):
tensor, indexes = inputs
self._indexes = indexes
inputs = self._handle_inputs(inputs)
return super(TensorIndex, self).new_tensors(inputs, shape, **kw)
indexes = kw.pop('indexes', None)
with self._handle_params(inputs, indexes) as mix_inputs:
return super(TensorIndex, self).new_tensors(mix_inputs, shape, **kw)

def new_chunks(self, inputs, shape, **kw):
chunk, indexes = inputs
self._indexes = indexes
inputs = self._handle_inputs(inputs)
return super(TensorIndex, self).new_chunks(inputs, shape, **kw)
indexes = kw.pop('indexes', None)
with self._handle_params(inputs, indexes) as mix_inputs:
return super(TensorIndex, self).new_chunks(mix_inputs, shape, **kw)

def __call__(self, a, index, shape):
return self.new_tensor([a, index], shape)
return self.new_tensor([a], shape, indexes=index)

@classmethod
def tile(cls, op):
Expand Down Expand Up @@ -172,14 +178,14 @@ def tile(cls, op):

chunk_input = in_tensor.cix[tuple(chunk_idx)]
chunk_op = op.copy().reset_key()
chunk = chunk_op.new_chunk([chunk_input, chunk_index], tuple(chunk_shape), index=output_idx)
chunk = chunk_op.new_chunk([chunk_input], tuple(chunk_shape), indexes=chunk_index, index=output_idx)
out_chunks.append(chunk)

nsplits = [tuple(c.shape[i] for c in out_chunks
if all(idx == 0 for j, idx in enumerate(c.index) if j != i))
for i in range(len(out_chunks[0].shape))]
new_op = op.copy().reset_key()
tensor = new_op.new_tensor([op.input, op.indexes], tensor.shape, chunks=out_chunks, nsplits=nsplits)
tensor = new_op.new_tensor([op.input], tensor.shape, indexes=op.indexes, chunks=out_chunks, nsplits=nsplits)

if len(to_concat_axis_index) > 1:
raise NotImplementedError
Expand All @@ -204,11 +210,11 @@ def tile(cls, op):
axis=axis, dtype=chunks[0].dtype, sparse=chunks[0].op.sparse)
concat_chunk = concat_chunk_op.new_chunk(chunks, tuple(s), index=new_idx)
out_chunk_op = TensorIndex(dtype=concat_chunk.dtype, sparse=concat_chunk.op.sparse)
out_chunk = out_chunk_op.new_chunk([concat_chunk, indexobj], tuple(s), index=new_idx)
out_chunk = out_chunk_op.new_chunk([concat_chunk], tuple(s), indexes=indexobj, index=new_idx)
output_chunks.append(out_chunk)

new_op = tensor.op.copy()
tensor = new_op.new_tensor([op.input, op.indexes], tuple(output_shape),
tensor = new_op.new_tensor([op.input], tuple(output_shape), indexes=op.indexes,
chunks=output_chunks, nsplits=output_nsplits)

return [tensor]
Expand Down
67 changes: 36 additions & 31 deletions mars/tensor/expressions/indexing/setitem.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
# limitations under the License.

from numbers import Integral
import contextlib

import numpy as np

Expand All @@ -30,41 +31,45 @@ class TensorIndexSetValue(IndexSetValue, TensorOperandMixin):
def __init__(self, dtype=None, sparse=False, **kw):
super(TensorIndexSetValue, self).__init__(_dtype=dtype, _sparse=sparse, **kw)

@classmethod
def _handle_inputs(cls, inputs):
tensor, indexes, value = inputs
indexes_inputs = [ind for ind in indexes if isinstance(ind, TENSOR_TYPE + CHUNK_TYPE)]
inputs = [tensor] + indexes_inputs
if isinstance(value, TENSOR_TYPE + CHUNK_TYPE):
inputs += [value]
return inputs

def _set_inputs(self, inputs):
super(TensorIndexSetValue, self)._set_inputs(inputs)
self._input = self._inputs[0]
indexes_iter = iter(self._inputs[1:])
new_indexes = [next(indexes_iter) if isinstance(index, (BaseWithKey, Entity)) else index
@contextlib.contextmanager
def _handle_params(self, inputs, indexes, value):
"""
The TensorIndexSetValue operator is like the Index operand: it has the additional parameters `indexes` and `value`,
both of which may be of tensor type. As explained in TensorIndex, when indexes and value are not provided, we take
them from the operand itself and replace the tensor-like objects by iterating over the inputs.
"""
if indexes is not None and value is not None:
indexes_inputs = [ind for ind in indexes if isinstance(ind, TENSOR_TYPE + CHUNK_TYPE)]
inputs += indexes_inputs
if isinstance(value, TENSOR_TYPE + CHUNK_TYPE):
inputs += [value]
yield inputs

if indexes is not None:
self._indexes = indexes
if value is not None:
self._value = value
inputs_iter = iter(self._inputs[1:])
new_indexes = [next(inputs_iter) if isinstance(index, (BaseWithKey, Entity)) else index
for index in self._indexes]
self._indexes = new_indexes
if isinstance(self._value, Entity):
self._value = self._value.data
if isinstance(self._value, (BaseWithKey, Entity)):
self._value = next(inputs_iter)

def new_tensors(self, inputs, shape, **kw):
tensor, indexes, value = inputs
self._indexes = indexes
self._value = value
inputs = self._handle_inputs(inputs)
return super(TensorIndexSetValue, self).new_tensors(inputs, shape, **kw)
indexes = kw.pop('indexes', None)
value = kw.pop('value', None)
with self._handle_params(inputs, indexes, value) as mix_inputs:
return super(TensorIndexSetValue, self).new_tensors(mix_inputs, shape, **kw)

def new_chunks(self, inputs, shape, **kw):
chunk, indexes, value = inputs
self._indexes = indexes
self._value = value
inputs = self._handle_inputs(inputs)
return super(TensorIndexSetValue, self).new_chunks(inputs, shape, **kw)
indexes = kw.pop('indexes', None)
value = kw.pop('value', None)
with self._handle_params(inputs, indexes, value) as mix_inputs:
return super(TensorIndexSetValue, self).new_chunks(mix_inputs, shape, **kw)

def __call__(self, a, index, value):
return self.new_tensor([a, index, value], a.shape)
return self.new_tensor([a], a.shape, indexes=index, value=value)

@classmethod
def tile(cls, op):
Expand All @@ -73,7 +78,7 @@ def tile(cls, op):
is_value_tensor = isinstance(value, TENSOR_TYPE)

index_tensor_op = TensorIndex(dtype=tensor.dtype, sparse=op.sparse)
index_tensor = index_tensor_op.new_tensor([op.input, op.indexes], tensor.shape).single_tiles()
index_tensor = index_tensor_op.new_tensor([op.input], tensor.shape, indexes=op.indexes).single_tiles()

nsplits = index_tensor.nsplits
if any(any(np.isnan(ns) for ns in nsplit) for nsplit in nsplits):
Expand All @@ -92,12 +97,12 @@ def tile(cls, op):

value_chunk = value.cix[index_chunk.index] if is_value_tensor else value
chunk_op = op.copy().reset_key()
out_chunk = chunk_op.new_chunk([chunk, index_chunk.op.indexes, value_chunk],
chunk.shape, index=chunk.index)
out_chunk = chunk_op.new_chunk([chunk], chunk.shape, indexes=index_chunk.op.indexes,
value=value_chunk, index=chunk.index)
out_chunks.append(out_chunk)

new_op = op.copy()
return new_op.new_tensors([op.input, op.indexes, op.value], tensor.shape,
return new_op.new_tensors([op.input], tensor.shape, indexes=op.indexes, value=op.value,
chunks=out_chunks, nsplits=op.input.nsplits)


Expand Down
2 changes: 1 addition & 1 deletion mars/tensor/expressions/tests/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,7 @@ def testIsIn(self):
self.assertEqual(len(mask.op.test_elements.chunks), 1)
self.assertIs(mask.chunks[0].inputs[0], element.chunks[0].data)

element = 2 * arange(4, chunks=1).reshape(2, 2)
element = 2 * arange(4, chunk_size=1).reshape(2, 2)
test_elements = tensor([1, 2, 4, 8], chunk_size=2)

mask = isin(element, test_elements, invert=True)
Expand Down
11 changes: 0 additions & 11 deletions mars/web/tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,11 +170,6 @@ def testApi(self):
value = sess.run(c, timeout=120)
assert_array_equal(value, va.dot(vb))

# test multiple outputs
a = mt.random.rand(10, 10)
U, s, V, raw = sess.run(list(mt.linalg.svd(a)) + [a])
np.testing.assert_allclose(U.dot(np.diag(s).dot(V)), raw)

# check web UI requests
res = requests.get(service_ep)
self.assertEqual(res.status_code, 200)
Expand All @@ -185,12 +180,6 @@ def testApi(self):
res = requests.get(service_ep + '/worker')
self.assertEqual(res.status_code, 200)

# test default session run with multiple inputs
with new_session(service_ep).as_default() as sess:
a = mt.ones((20, 10), chunk_size=10)
u, s, v = (mt.linalg.svd(a)).execute()
np.testing.assert_allclose(u.dot(np.diag(s).dot(v)), np.ones((20, 10)))


class MockResponse:
def __init__(self, status_code, json_text=None, data=None):
Expand Down

0 comments on commit 2699114

Please sign in to comment.