Skip to content

Commit

Permalink
fix elemwise_sum test script (apache#8008)
Browse files Browse the repository at this point in the history
  • Loading branch information
eric-haibin-lin authored and mbaijal committed Sep 25, 2017
1 parent 52b79fc commit 99e4b8c
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 5 deletions.
4 changes: 1 addition & 3 deletions src/operator/tensor/elemwise_sum.cu
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,8 @@ void ElementWiseSumComputeExGPU(const nnvm::NodeAttrs& attrs,
CHECK_EQ(req[0], kWriteTo) << "ElementWiseSumComputeExGPU only supports req = kWriteTo";
if (inputs[0].storage_type() == kRowSparseStorage) {
mshadow::Stream<gpu>* s = op_ctx.get_stream<gpu>();
Resource rsc = ResourceManager::Get()->Request(op_ctx.run_ctx.get_ctx(),
ResourceRequest(ResourceRequest::kTempSpace));
NDArray out_nd = outputs[0];
mxnet::ndarray::ElementwiseSum<gpu>(s, rsc, inputs, &out_nd);
mxnet::ndarray::ElementwiseSum<gpu>(s, op_ctx.requested[0], inputs, &out_nd);
} else {
FCompExFallback<gpu>(attrs, op_ctx, inputs, req, outputs,
ElementWiseSumComputeWithHalf2<gpu>,
Expand Down
4 changes: 2 additions & 2 deletions tests/python/unittest/test_sparse_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1498,8 +1498,8 @@ def check_sparse_elementwise_sum_with_shape(stype, shape, n):
inputs = [mx.symbol.Variable('arg%d' % i) for i in range(n)]
out = mx.symbol.sparse.add_n(*inputs, name='esum')
arr = []
arr_grad = [mx.nd.empty(shape) for _ in range(n)]
densities = [0, 0.01, 0.1, 0.2, 0.3, 0.4, 0.5, 1.0]
arr_grad = [mx.nd.empty(shape, stype=stype) for _ in range(n)]
densities = [0, 0.01, 0.5, 1.0]
for i in range(n):
arr.append(rand_ndarray(shape, stype, densities[np.random.randint(0, len(densities))]))

Expand Down

0 comments on commit 99e4b8c

Please sign in to comment.