diff --git a/src/operator/random/sample_op.cc b/src/operator/random/sample_op.cc
index 8d87d2b99d14..363163cbc697 100644
--- a/src/operator/random/sample_op.cc
+++ b/src/operator/random/sample_op.cc
@@ -61,7 +61,8 @@ Example::
     [ 0.54488319,  0.84725171]]
 
 )code" ADD_FILELINE)
-.set_attr<FCompute>("FCompute<cpu>", SampleUniform_<cpu>);
+.set_attr<FCompute>("FCompute<cpu>", SampleUniform_<cpu>)
+.set_attr<FComputeEx>("FComputeEx<cpu>", SampleUniformEx_<cpu>);
 
 // Add "normal" alias for backward compatibility
 MXNET_OPERATOR_REGISTER_SAMPLE(random_normal, SampleNormalParam)
@@ -78,7 +79,8 @@ Example::
    random_normal(loc=0, scale=1, shape=(2,2)) = [[ 1.89171135, -1.16881478],
                                                  [-1.23474145,  1.55807114]]
 
 )code" ADD_FILELINE)
-.set_attr<FCompute>("FCompute<cpu>", SampleNormal_<cpu>);
+.set_attr<FCompute>("FCompute<cpu>", SampleNormal_<cpu>)
+.set_attr<FComputeEx>("FComputeEx<cpu>", SampleNormalEx_<cpu>);
 
 MXNET_OPERATOR_REGISTER_SAMPLE(random_gamma, SampleGammaParam)
 .add_alias("_sample_gamma")
@@ -91,7 +93,8 @@ Example::
    random_gamma(alpha=9, beta=0.5, shape=(2,2)) = [[ 7.10486984,  3.37695289],
                                                    [ 3.91697288,  3.65933681]]
 
 )code" ADD_FILELINE)
-.set_attr<FCompute>("FCompute<cpu>", SampleGamma_<cpu>);
+.set_attr<FCompute>("FCompute<cpu>", SampleGamma_<cpu>)
+.set_attr<FComputeEx>("FComputeEx<cpu>", SampleGammaEx_<cpu>);
 
 MXNET_OPERATOR_REGISTER_SAMPLE(random_exponential, SampleExponentialParam)
 .add_alias("_sample_exponential")
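Each registration above keeps the dense FCompute kernel and adds an FComputeEx kernel alongside it; the executor dispatches to FComputeEx when the output's storage type is not dense. From Python, the same sampling operator can therefore write into either a dense or a row_sparse NDArray. A minimal sketch of the intended call pattern, mirroring the unit test at the end of this patch (illustrative, not the canonical API surface):

    import mxnet as mx

    # dense output -> dispatched to the FCompute kernel
    dns = mx.nd.zeros(shape=(3, 4), stype='default')
    mx.nd.random_uniform(shape=(3, 4), out=dns)

    # row_sparse output -> dispatched to the FComputeEx kernel
    rsp = mx.nd.zeros(shape=(3, 4), stype='row_sparse')
    mx.nd.random_uniform(shape=(3, 4), out=rsp)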
diff --git a/src/operator/random/sample_op.h b/src/operator/random/sample_op.h
index a1a6a2345b1b..0cd3f6bc2efb 100644
--- a/src/operator/random/sample_op.h
+++ b/src/operator/random/sample_op.h
@@ -232,29 +232,75 @@ struct SampleGenNegBinomialParam : public dmlc::Parameter<SampleGenNegBinomialParam> {
 
+using FSampleCompute = std::function<void (const nnvm::NodeAttrs& attrs,
+                                           const OpContext& ctx,
+                                           const OpReqType& req,
+                                           TBlob* outputs)>;
+
 template<typename xpu>
-void SampleUniform_(const nnvm::NodeAttrs& attrs,
-                    const OpContext& ctx,
-                    const std::vector<TBlob>& inputs,
-                    const std::vector<OpReqType>& req,
-                    const std::vector<TBlob>& outputs) {
+void SampleComputeEx_(const nnvm::NodeAttrs& attrs,
+                      const OpContext& ctx,
+                      const std::vector<NDArray>& inputs,
+                      const std::vector<OpReqType>& req,
+                      const std::vector<NDArray>& outputs,
+                      FSampleCompute fcomp) {
+  NDArray output = outputs[0];
+  mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
+  if (output.storage_type() == kRowSparseStorage) {
+    // indices
+    nnvm::dim_t nnr = output.shape()[0];
+    output.CheckAndAlloc({mshadow::Shape1(nnr)});
+    PopulateFullIdxRspImpl(s, &output);
+    // data
+    TBlob out_blob = output.data();
+    fcomp(attrs, ctx, req[0], &out_blob);
+  } else {
+    LOG(FATAL) << "Unexpected storage type for SampleComputeEx_: "
+               << output.storage_type();
+  }
+}
+
+template<typename xpu>
+void SampleUniformDnsImpl(const nnvm::NodeAttrs& attrs,
+                          const OpContext& ctx,
+                          const OpReqType& req,
+                          TBlob* output) {
   using namespace mxnet::op;
   using namespace mshadow::expr;
   mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
   const SampleUniformParam& param = nnvm::get<SampleUniformParam>(attrs.parsed);
-  MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
+  MSHADOW_REAL_TYPE_SWITCH(output->type_flag_, DType, {
     mshadow::Random<xpu, DType> *prnd = ctx.requested[0].get_random<xpu, DType>(s);
-    mshadow::Tensor<xpu, 2, DType> out = outputs[0].FlatTo2D<xpu, DType>(s);
+    mshadow::Tensor<xpu, 2, DType> out = output->FlatTo2D<xpu, DType>(s);
     prnd->SampleUniform(&out, param.low, param.high);
   });
 }
 
 template<typename xpu>
-void SampleNormal_(const nnvm::NodeAttrs& attrs,
-                   const OpContext& ctx,
-                   const std::vector<TBlob>& inputs,
-                   const std::vector<OpReqType>& req,
-                   const std::vector<TBlob>& outputs) {
+void SampleUniform_(const nnvm::NodeAttrs& attrs,
+                    const OpContext& ctx,
+                    const std::vector<TBlob>& inputs,
+                    const std::vector<OpReqType>& req,
+                    const std::vector<TBlob>& outputs) {
+  TBlob out = outputs[0];
+  SampleUniformDnsImpl<xpu>(attrs, ctx, req[0], &out);
+}
+
+template<typename xpu>
+void SampleUniformEx_(const nnvm::NodeAttrs& attrs,
+                      const OpContext& ctx,
+                      const std::vector<NDArray>& inputs,
+                      const std::vector<OpReqType>& req,
+                      const std::vector<NDArray>& outputs) {
+  SampleComputeEx_<xpu>(attrs, ctx, inputs, req, outputs, SampleUniformDnsImpl<xpu>);
+}
+
+template<typename xpu>
+void SampleNormalDnsImpl(const nnvm::NodeAttrs& attrs,
+                         const OpContext& ctx,
+                         const OpReqType& req,
+                         TBlob* outputs) {
   using namespace mxnet::op;
   using namespace mshadow::expr;
   mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
@@ -268,11 +314,29 @@ void SampleNormal_(const nnvm::NodeAttrs& attrs,
 }
 
 template<typename xpu>
-void SampleGamma_(const nnvm::NodeAttrs& attrs,
+void SampleNormal_(const nnvm::NodeAttrs& attrs,
                   const OpContext& ctx,
                   const std::vector<TBlob>& inputs,
                   const std::vector<OpReqType>& req,
                   const std::vector<TBlob>& outputs) {
+  TBlob out = outputs[0];
+  SampleNormalDnsImpl<xpu>(attrs, ctx, req[0], &out);
+}
+
+template<typename xpu>
+void SampleNormalEx_(const nnvm::NodeAttrs& attrs,
+                     const OpContext& ctx,
+                     const std::vector<NDArray>& inputs,
+                     const std::vector<OpReqType>& req,
+                     const std::vector<NDArray>& outputs) {
+  SampleComputeEx_<xpu>(attrs, ctx, inputs, req, outputs, SampleNormalDnsImpl<xpu>);
+}
+
+template<typename xpu>
+void SampleGammaDnsImpl(const nnvm::NodeAttrs& attrs,
+                        const OpContext& ctx,
+                        const OpReqType& req,
+                        TBlob* outputs) {
   using namespace mxnet::op;
   using namespace mshadow::expr;
   mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
@@ -286,6 +350,25 @@ void SampleGamma_(const nnvm::NodeAttrs& attrs,
   });
 }
 
+template<typename xpu>
+void SampleGamma_(const nnvm::NodeAttrs& attrs,
+                  const OpContext& ctx,
+                  const std::vector<TBlob>& inputs,
+                  const std::vector<OpReqType>& req,
+                  const std::vector<TBlob>& outputs) {
+  TBlob out = outputs[0];
+  SampleGammaDnsImpl<xpu>(attrs, ctx, req[0], &out);
+}
+
+template<typename xpu>
+void SampleGammaEx_(const nnvm::NodeAttrs& attrs,
+                    const OpContext& ctx,
+                    const std::vector<NDArray>& inputs,
+                    const std::vector<OpReqType>& req,
+                    const std::vector<NDArray>& outputs) {
+  SampleComputeEx_<xpu>(attrs, ctx, inputs, req, outputs, SampleGammaDnsImpl<xpu>);
+}
+
 template<typename xpu>
 void SampleExponential_(const nnvm::NodeAttrs& attrs,
                         const OpContext& ctx,
diff --git a/src/operator/tensor/init_op.h b/src/operator/tensor/init_op.h
index 1b244251fca1..1ac933ddaef5 100644
--- a/src/operator/tensor/init_op.h
+++ b/src/operator/tensor/init_op.h
@@ -167,6 +167,26 @@ inline void FillDnsZerosRspImpl(mshadow::Stream<xpu> *s, NDArray *dst) {
   });
 }
 
+struct PopulateFullIdxRspKernel {
+  template<typename IType>
+  MSHADOW_XINLINE static void Map(int i, IType* out) {
+    KERNEL_ASSIGN(out[i], kWriteTo, i);
+  }
+};
+
+// Fill the row index array of a row_sparse NDArray with the full set of
+// row indices [0, num_rows) and update the aux shape accordingly.
+template<typename xpu>
+void PopulateFullIdxRspImpl(mshadow::Stream<xpu> *s, NDArray *dst) {
+  using namespace rowsparse;
+  CHECK_EQ(dst->storage_type(), kRowSparseStorage);
+  nnvm::dim_t nnr = dst->shape()[0];
+  dst->CheckAndAllocAuxData(kIdx, mshadow::Shape1(nnr));
+  MSHADOW_IDX_TYPE_SWITCH(dst->aux_type(kIdx), IType, {
+    IType* idx = dst->aux_data(kIdx).dptr<IType>();
+    mxnet_op::Kernel<PopulateFullIdxRspKernel, xpu>::Launch(s, nnr, idx);
+  });
+}
+
 // Fill a rsp NDArray with zeros by updating the aux shape.
 template<typename xpu>
 void FillZerosRspImpl(mshadow::Stream<xpu> *s, NDArray *dst) {
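Since sampling writes every element, every row of a row_sparse output is retained: the value array covers all num_rows rows, and the aux index array is exactly arange(num_rows), which is what PopulateFullIdxRspKernel writes above. A numpy sketch of the storage layout this produces (illustrative only):

    import numpy as np

    num_rows, num_cols = 5, 3
    data = np.random.uniform(size=(num_rows, num_cols))  # value slab, one row per retained row
    indices = np.arange(num_rows, dtype=np.int64)        # full index set: every row retained
    # reconstruction rule for row_sparse: dense[indices[i], :] = data[i, :]
    dense = np.zeros((num_rows, num_cols))
    dense[indices] = data
    assert np.array_equal(dense, data)  # a fully populated rsp equals the dense array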
diff --git a/tests/python/unittest/test_sparse_ndarray.py b/tests/python/unittest/test_sparse_ndarray.py
index 63113fd03f01..cbce588549a1 100644
--- a/tests/python/unittest/test_sparse_ndarray.py
+++ b/tests/python/unittest/test_sparse_ndarray.py
@@ -351,6 +351,20 @@ def test_sparse_nd_output_fallback():
     mx.nd.random_normal(shape=shape, out=out)
     assert(np.sum(out.asnumpy()) != 0)
 
+def test_sparse_nd_random():
+    shape = (100, 100)
+    fns = [mx.nd.random_uniform, mx.nd.random_normal, mx.nd.random_gamma]
+    for fn in fns:
+        rsp_out = mx.nd.zeros(shape=shape, stype='row_sparse')
+        dns_out = mx.nd.zeros(shape=shape, stype='default')
+        mx.random.seed(0)
+        np.random.seed(0)
+        fn(shape=shape, out=dns_out)
+        mx.random.seed(0)
+        np.random.seed(0)
+        fn(shape=shape, out=rsp_out)
+        assert_almost_equal(dns_out.asnumpy(), rsp_out.asnumpy())
+
 def test_sparse_nd_astype():
     stypes = ['row_sparse', 'csr']
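The test reseeds both MXNet and numpy before each call so the dense and sparse paths consume identical random streams, then compares the two results elementwise. One could additionally verify that the sparse output comes back fully populated; a sketch, assuming the row indices of a row_sparse NDArray are exposed as .indices (the accessor name on this branch may differ):

    import mxnet as mx
    import numpy as np

    shape = (100, 100)
    rsp_out = mx.nd.zeros(shape=shape, stype='row_sparse')
    mx.nd.random_uniform(shape=shape, out=rsp_out)
    # every row was written, so the aux indices should be 0..99
    assert np.array_equal(rsp_out.indices.asnumpy(), np.arange(shape[0]))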