diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h index c3fb00e562d1..0633e92d6b90 100644 --- a/include/mxnet/c_api.h +++ b/include/mxnet/c_api.h @@ -604,7 +604,8 @@ MXNET_DLL int MXImperativeInvoke(AtomicSymbolCreator creator, NDArrayHandle **outputs, int num_params, const char **param_keys, - const char **param_vals); + const char **param_vals, + const int** out_stypes); /*! * \brief set whether to record operator for autograd * \param is_train 1 when training, 0 when testing diff --git a/python/mxnet/_ctypes/ndarray.py b/python/mxnet/_ctypes/ndarray.py index a678e1726f02..494db9356271 100644 --- a/python/mxnet/_ctypes/ndarray.py +++ b/python/mxnet/_ctypes/ndarray.py @@ -16,11 +16,20 @@ from .common import CachedOp +_STORAGE_TYPE_ID_TO_STR = { + -1 : 'undefined', + 0 : 'default', + 1 : 'row_sparse', + 2 : 'csr', +} + + class NDArrayBase(object): """Base data structure for ndarray""" - __slots__ = ["handle", "writable"] + __slots__ = ["handle", "writable", "_stype"] # pylint: disable= no-member - def __init__(self, handle, writable=True): + + def __init__(self, handle, writable=True, stype=None): """initialize a new NDArray Parameters @@ -32,6 +41,7 @@ def __init__(self, handle, writable=True): assert isinstance(handle, NDArrayHandle) self.handle = handle self.writable = writable + self._stype = stype def __del__(self): check_call(_LIB.MXNDArrayFree(self.handle)) @@ -62,6 +72,10 @@ def _imperative_invoke(handle, ndargs, keys, vals, out): output_vars = ctypes.POINTER(NDArrayHandle)() num_output = ctypes.c_int(0) + # return output stypes to avoid the c_api call for checking + # a handle's stype in _ndarray_cls + out_stypes = ctypes.POINTER(ctypes.c_int)() + check_call(_LIB.MXImperativeInvoke( ctypes.c_void_p(handle), ctypes.c_int(len(ndargs)), @@ -70,14 +84,17 @@ def _imperative_invoke(handle, ndargs, keys, vals, out): ctypes.byref(output_vars), ctypes.c_int(len(keys)), c_array(ctypes.c_char_p, [c_str(key) for key in keys]), - c_array(ctypes.c_char_p, 
[c_str(str(val)) for val in vals]))) + c_array(ctypes.c_char_p, [c_str(str(val)) for val in vals]), + ctypes.byref(out_stypes))) if original_output is not None: return original_output if num_output.value == 1: - return _ndarray_cls(ctypes.cast(output_vars[0], NDArrayHandle)) + return _ndarray_cls(ctypes.cast(output_vars[0], NDArrayHandle), + stype=_STORAGE_TYPE_ID_TO_STR[out_stypes[0]]) else: - return [_ndarray_cls(ctypes.cast(output_vars[i], NDArrayHandle)) + return [_ndarray_cls(ctypes.cast(output_vars[i], NDArrayHandle), + stype=_STORAGE_TYPE_ID_TO_STR[out_stypes[i]]) for i in range(num_output.value)] diff --git a/python/mxnet/module/module.py b/python/mxnet/module/module.py index 8af84a307a82..680b937d3fd3 100644 --- a/python/mxnet/module/module.py +++ b/python/mxnet/module/module.py @@ -399,7 +399,7 @@ def bind(self, data_shapes, label_shapes=None, for_training=True, else: assert self._arg_params is None and self._aux_params is None param_arrays = [ - mx.nd.zeros(shape=x[0].shape, dtype=x[0].dtype, storage_type=x[0].storage_type) + mx.nd.zeros(shape=x[0].shape, dtype=x[0].dtype, storage_type=x[0].stype) for x in self._exec_group.param_arrays ] self._arg_params = {name:arr for name, arr in zip(self._param_names, param_arrays)} diff --git a/python/mxnet/ndarray.py b/python/mxnet/ndarray.py index 133e30ec6397..948d8dd587e7 100644 --- a/python/mxnet/ndarray.py +++ b/python/mxnet/ndarray.py @@ -32,18 +32,18 @@ # pylint: disable=unused-import try: if int(_os.environ.get("MXNET_ENABLE_CYTHON", True)) == 0: - from ._ctypes.ndarray import NDArrayBase, _set_ndarray_class + from ._ctypes.ndarray import NDArrayBase, _set_ndarray_class, _STORAGE_TYPE_ID_TO_STR from ._ctypes.ndarray import invoke, CachedOp, _imperative_invoke elif _sys.version_info >= (3, 0): - from ._cy3.ndarray import NDArrayBase, _set_ndarray_class, _imperative_invoke + from ._cy3.ndarray import NDArrayBase, _set_ndarray_class, _imperative_invoke, _STORAGE_TYPE_ID_TO_STR from ._cy3.ndarray import invoke, 
CachedOp, _imperative_invoke else: - from ._cy2.ndarray import NDArrayBase, _set_ndarray_class, _imperative_invoke + from ._cy2.ndarray import NDArrayBase, _set_ndarray_class, _imperative_invoke, _STORAGE_TYPE_ID_TO_STR from ._cy2.ndarray import invoke, CachedOp, _imperative_invoke except ImportError: if int(_os.environ.get("MXNET_ENFORCE_CYTHON", False)) != 0: raise ImportError("Cython Module cannot be loaded but MXNET_ENFORCE_CYTHON=1") - from ._ctypes.ndarray import NDArrayBase, _set_ndarray_class, _imperative_invoke + from ._ctypes.ndarray import NDArrayBase, _set_ndarray_class, _imperative_invoke, _STORAGE_TYPE_ID_TO_STR from ._ctypes.ndarray import invoke, CachedOp, _imperative_invoke # pylint: enable=unused-import @@ -64,12 +64,6 @@ 4 : np.int32, 6 : np.int64 } -_STORAGE_TYPE_ID_TO_STR = { - -1 : 'undefined', - 0 : 'default', - 1 : 'row_sparse', - 2 : 'csr', -} _STORAGE_TYPE_STR_TO_ID = { 'undefined' : -1, 'default' : 0, @@ -748,8 +742,10 @@ def dtype(self): return _DTYPE_MX_TO_NP[mx_dtype.value] @property - def storage_type(self): - return _storage_type(self.handle) + def stype(self): + if self._stype is None: + self._stype = _storage_type(self.handle) + return self._stype @property # pylint: disable= invalid-name, undefined-variable diff --git a/python/mxnet/optimizer.py b/python/mxnet/optimizer.py index 10f9f06c11b3..7e46c30c7c79 100644 --- a/python/mxnet/optimizer.py +++ b/python/mxnet/optimizer.py @@ -335,7 +335,7 @@ def create_state(self, index, weight): return None else: return mx.nd.zeros(shape=weight.shape, ctx=weight.context, - dtype=weight.dtype, storage_type=weight.storage_type) + dtype=weight.dtype, storage_type=weight.stype) def update(self, index, weight, grad, state): assert(isinstance(weight, NDArray)) diff --git a/python/mxnet/sparse_ndarray.py b/python/mxnet/sparse_ndarray.py index 2923a4c25292..a438b4d6ec7d 100644 --- a/python/mxnet/sparse_ndarray.py +++ b/python/mxnet/sparse_ndarray.py @@ -175,7 +175,7 @@ def __getitem__(self, key): >>> 
x[1:2].asnumpy() array([[ 3., 4., 5.]], dtype=float32) """ - stype = self.storage_type + stype = self.stype if stype != 'csr': raise Exception("__getitem__ for " + str(stype) + " not implemented yet") if isinstance(key, int): @@ -233,7 +233,7 @@ def values(self): def _num_aux(self): ''' The number of aux data used to help store the sparse ndarray. ''' - return len(_STORAGE_AUX_TYPES[self.storage_type]) + return len(_STORAGE_AUX_TYPES[self.stype]) @property # pylint: disable= invalid-name, undefined-variable @@ -270,7 +270,7 @@ def astype(self, dtype): """ res = mx.nd.zeros(shape=self.shape, ctx=self.context, - dtype=dtype, storage_type=self.storage_type) + dtype=dtype, storage_type=self.stype) self.copyto(res) return res @@ -301,7 +301,7 @@ def copyto(self, other): return return _internal._copyto(self, out=other) elif isinstance(other, Context): - hret = _ndarray_cls(_new_alloc_handle(self.storage_type, self.shape, other, + hret = _ndarray_cls(_new_alloc_handle(self.stype, self.shape, other, True, self.dtype, self.aux_types)) return _internal._copyto(self, out=hret) else: @@ -525,14 +525,15 @@ def todense(source): return ndarray.cast_storage(source, storage_type='default') -def _ndarray_cls(handle, writable=True): - stype = _storage_type(handle) +def _ndarray_cls(handle, writable=True, stype=None): + if stype is None: + stype = _storage_type(handle) if stype == 'default': - return NDArray(handle, writable=writable) + return NDArray(handle, writable=writable, stype=stype) elif stype == 'csr': - return CSRNDArray(handle, writable=writable) + return CSRNDArray(handle, writable=writable, stype=stype) elif stype == 'row_sparse': - return RowSparseNDArray(handle, writable=writable) + return RowSparseNDArray(handle, writable=writable, stype=stype) else: raise Exception("unknown storage type") diff --git a/src/c_api/c_api_ndarray.cc b/src/c_api/c_api_ndarray.cc index 3dd491ea2c30..615b0231d750 100644 --- a/src/c_api/c_api_ndarray.cc +++ b/src/c_api/c_api_ndarray.cc @@ 
-489,13 +489,20 @@ int MXImperativeInvoke(AtomicSymbolCreator creator, NDArrayHandle **outputs, int num_params, const char **param_keys, - const char **param_vals) { + const char **param_vals, + const int** out_stypes) { // outputs storage types const nnvm::Op* op = static_cast<nnvm::Op*>(creator); - + MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get(); API_BEGIN(); nnvm::NodeAttrs attrs; SetOpAttrs(op, &attrs, num_inputs, num_params, param_keys, param_vals); ImperativeInvokeImpl(attrs, num_inputs, inputs, num_outputs, outputs); + NDArray** output_nds = reinterpret_cast<NDArray**>(*outputs); + ret->out_types.resize(*num_outputs); + for (int i = 0; i < *num_outputs; ++i) { + ret->out_types[i] = output_nds[i]->storage_type(); + } + *out_stypes = dmlc::BeginPtr(ret->out_types); API_END(); } diff --git a/tests/python/unittest/test_module.py b/tests/python/unittest/test_module.py index 5801bb1829d3..96fd77334d8d 100644 --- a/tests/python/unittest/test_module.py +++ b/tests/python/unittest/test_module.py @@ -463,7 +463,7 @@ def regression_model(m): mod.bind(data_shapes=iterator.provide_data, label_shapes=iterator.provide_label) mod.init_params() v = mod._arg_params['v'] - assert(v.storage_type == 'row_sparse') + assert(v.stype == 'row_sparse') assert(np.sum(v.asnumpy()) != 0) if __name__ == '__main__': diff --git a/tests/python/unittest/test_sparse_ndarray.py b/tests/python/unittest/test_sparse_ndarray.py index a09857b95efe..66e13801cc30 100644 --- a/tests/python/unittest/test_sparse_ndarray.py +++ b/tests/python/unittest/test_sparse_ndarray.py @@ -96,7 +96,7 @@ def check_sparse_nd_prop_rsp(): nd, (v, idx) = rand_sparse_ndarray(shape, storage_type) assert(nd._num_aux == 1) assert(nd.indices.dtype == np.int64) - assert(nd.storage_type == 'row_sparse') + assert(nd.stype == 'row_sparse') assert_almost_equal(nd.indices.asnumpy(), idx) diff --git a/tests/python/unittest/test_sparse_operator.py b/tests/python/unittest/test_sparse_operator.py index 1fc64a7149ea..f712da00051d 100644 --- 
a/tests/python/unittest/test_sparse_operator.py +++ b/tests/python/unittest/test_sparse_operator.py @@ -109,9 +109,9 @@ def test_dot_csr(lhs_shape, rhs_shape, rhs_stype, trans_lhs, density=1): rhs_dns = rhs_nd if rhs_stype == 'default' else rhs_nd.todense() out = mx.nd.dot(lhs_nd, rhs_dns, transpose_a=trans_lhs) if trans_lhs: - assert out.storage_type == 'row_sparse' + assert out.stype == 'row_sparse' else: - assert out.storage_type == 'default' + assert out.stype == 'default' out_expected = mx.nd.dot(lhs_dns, rhs_dns, transpose_a=trans_lhs) out_np = out_expected.asnumpy() backward_trans = not trans_lhs