Skip to content

Commit

Permalink
add scalar assignment to row_sparse
Browse files Browse the repository at this point in the history
  • Loading branch information
eric-haibin-lin committed Aug 9, 2017
1 parent 8da42c2 commit 8aef7a5
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 9 deletions.
11 changes: 8 additions & 3 deletions python/mxnet/ndarray/sparse_ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -545,7 +545,7 @@ def __setitem__(self, key, value):
----------
key : slice
The indexing key.
value : NDArray or numpy.ndarray
value : scalar, NDArray or numpy.ndarray
The value to set.
Examples
Expand All @@ -568,6 +568,12 @@ def __setitem__(self, key, value):
array([[ 1., 1., 1.],
[ 1., 1., 1.],
[ 1., 1., 1.]], dtype=float32)
>>> # assign scalar to RowSparseNDArray
>>> x[:] = 7
>>> x.asnumpy()
array([[ 7., 7., 7.],
[ 7., 7., 7.],
[ 7., 7., 7.]], dtype=float32)
"""
if not self.writable:
raise ValueError('Failed to assign to a readonly RowSparseNDArray')
Expand All @@ -580,8 +586,7 @@ def __setitem__(self, key, value):
if value.handle is not self.handle:
value.copyto(self)
elif isinstance(value, numeric_types):
raise ValueError("Assigning numeric types to RowSparseNDArray " \
"is not implemented yet.")
_internal._set_value(float(value), out=self)
elif isinstance(value, (np.ndarray, np.generic)):
warnings.warn('Assigning non-NDArray object to RowSparseNDArray is not efficient',
RuntimeWarning)
Expand Down
35 changes: 30 additions & 5 deletions src/ndarray/ndarray.cc
Original file line number Diff line number Diff line change
Expand Up @@ -295,18 +295,43 @@ void SetValueOp(const real_t &rhs, NDArray *out) {
switch (ret.ctx().dev_mask()) {
case cpu::kDevMask: {
Engine::Get()->PushSync([rhs, ret](RunContext ctx) {
CHECK(ret.storage_type() == kDefaultStorage);
TBlob tmp = ret.data();
ndarray::Eval<cpu>(rhs, &tmp, ctx);
auto ret_stype = ret.storage_type();
mshadow::Stream<cpu> *s = ctx.get_stream<cpu>();
if (ret_stype == kRowSparseStorage) {
NDArray out = ret;
// indices
nnvm::dim_t nnr = ret.shape()[0];
out.CheckAndAlloc({mshadow::Shape1(nnr)});
op::PopulateFullIdxRspImpl(s, &out);
// data
TBlob tmp = out.data();
ndarray::Eval<cpu>(rhs, &tmp, ctx);
} else {
TBlob tmp = ret.data();
ndarray::Eval<cpu>(rhs, &tmp, ctx);
}
}, ret.ctx(), {}, {ret.var()},
FnProperty::kNormal, 0, PROFILER_MESSAGE_FUNCNAME);
break;
}
#if MXNET_USE_CUDA
case gpu::kDevMask: {
Engine::Get()->PushSync([rhs, ret](RunContext ctx) {
TBlob tmp = ret.data();
ndarray::Eval<gpu>(rhs, &tmp, ctx);
auto ret_stype = ret.storage_type();
mshadow::Stream<gpu> *s = ctx.get_stream<gpu>();
if (ret_stype == kRowSparseStorage) {
NDArray out = ret;
// indices
nnvm::dim_t nnr = ret.shape()[0];
out.CheckAndAlloc({mshadow::Shape1(nnr)});
op::PopulateFullIdxRspImpl(s, &out);
// data
TBlob tmp = out.data();
ndarray::Eval<gpu>(rhs, &tmp, ctx);
} else {
TBlob tmp = ret.data();
ndarray::Eval<gpu>(rhs, &tmp, ctx);
}
// Wait GPU kernel to complete
ctx.get_stream<gpu>()->Wait();
}, ret.ctx(), {}, {ret.var()},
Expand Down
7 changes: 6 additions & 1 deletion tests/python/unittest/test_sparse_ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,8 @@ def test_sparse_nd_setitem():
def check_sparse_nd_setitem(stype, shape, dst):
x = mx.nd.zeros(shape=shape, stype=stype)
x[:] = dst
dst_nd = mx.nd.array(dst) if isinstance(dst, (np.ndarray, np.generic)) else dst
dst_nd = mx.nd.zeros(shape=shape)
dst_nd[:] = dst
assert same(x.asnumpy(), dst_nd.asnumpy())

shape = rand_shape_2d()
Expand All @@ -112,6 +113,10 @@ def check_sparse_nd_setitem(stype, shape, dst):
check_sparse_nd_setitem(stype, shape, rand_ndarray(shape, stype))
# numpy assignment
check_sparse_nd_setitem(stype, shape, np.ones(shape))
if stype == 'row_sparse':
# scalar assignment
check_sparse_nd_setitem(stype, shape, 0)
check_sparse_nd_setitem(stype, shape, 1)


def test_sparse_nd_slice():
Expand Down

0 comments on commit 8aef7a5

Please sign in to comment.