Skip to content

Commit

Permalink
Save stype in frontend to avoid c-api call for stype
Browse files Browse the repository at this point in the history
  • Loading branch information
reminisce committed Jul 5, 2017
1 parent 69a75b6 commit 5b6f5cc
Show file tree
Hide file tree
Showing 10 changed files with 57 additions and 35 deletions.
3 changes: 2 additions & 1 deletion include/mxnet/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -604,7 +604,8 @@ MXNET_DLL int MXImperativeInvoke(AtomicSymbolCreator creator,
NDArrayHandle **outputs,
int num_params,
const char **param_keys,
const char **param_vals);
const char **param_vals,
const int** out_stypes);
/*!
* \brief set whether to record operator for autograd
* \param is_train 1 when training, 0 when testing
Expand Down
27 changes: 22 additions & 5 deletions python/mxnet/_ctypes/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,20 @@
from .common import CachedOp


# Maps the backend's storage-type enum ids to their frontend string names.
# Inverse of _STORAGE_TYPE_STR_TO_ID (see ndarray.py); the ids must stay in
# sync with the storage-type enum on the C++ side.
_STORAGE_TYPE_ID_TO_STR = {
-1 : 'undefined',
0 : 'default',
1 : 'row_sparse',
2 : 'csr',
}


class NDArrayBase(object):
    """Base data structure for ndarray.

    Owns the C-API handle and caches the storage type so Python-side code
    can look it up without a round trip into the C API.
    """
    # _stype caches the storage-type string ('default', 'row_sparse', 'csr')
    # or None, meaning "not known yet — query the backend lazily".
    __slots__ = ["handle", "writable", "_stype"]
    # pylint: disable= no-member

    def __init__(self, handle, writable=True, stype=None):
        """initialize a new NDArray

        Parameters
        ----------
        handle : NDArrayHandle
            NDArray handle of C API.
        writable : bool, optional
            Whether the NDArray is writable (default True).
        stype : str, optional
            Storage type of this NDArray if already known by the caller
            (avoids a later C-API query); None to defer the lookup.
        """
        assert isinstance(handle, NDArrayHandle)
        self.handle = handle
        self.writable = writable
        self._stype = stype

    def __del__(self):
        # Release the backend array when the Python wrapper is collected.
        check_call(_LIB.MXNDArrayFree(self.handle))
Expand Down Expand Up @@ -62,6 +72,10 @@ def _imperative_invoke(handle, ndargs, keys, vals, out):
# No explicit `out` given: let the backend allocate the output handles.
output_vars = ctypes.POINTER(NDArrayHandle)()
num_output = ctypes.c_int(0)

# return output stypes to avoid the c_api call for checking
# a handle's stype in _ndarray_cls
out_stypes = ctypes.POINTER(ctypes.c_int)()

check_call(_LIB.MXImperativeInvoke(
ctypes.c_void_p(handle),
ctypes.c_int(len(ndargs)),
Expand All @@ -70,14 +84,17 @@ def _imperative_invoke(handle, ndargs, keys, vals, out):
ctypes.byref(output_vars),
ctypes.c_int(len(keys)),
c_array(ctypes.c_char_p, [c_str(key) for key in keys]),
# NOTE(review): the next three lines are diff residue — the first is the
# removed pre-commit call tail, the two after it are its replacement
# passing &out_stypes as the new final argument.
c_array(ctypes.c_char_p, [c_str(str(val)) for val in vals])))
c_array(ctypes.c_char_p, [c_str(str(val)) for val in vals]),
ctypes.byref(out_stypes)))

if original_output is not None:
return original_output
if num_output.value == 1:
return _ndarray_cls(ctypes.cast(output_vars[0], NDArrayHandle))
return _ndarray_cls(ctypes.cast(output_vars[0], NDArrayHandle),
stype=_STORAGE_TYPE_ID_TO_STR[out_stypes[0]])
else:
# NOTE(review): BUG in the added lines below — the closing paren of
# ctypes.cast() is misplaced, so `stype=...` is passed to ctypes.cast
# (which accepts no such keyword and will raise TypeError at runtime)
# instead of to _ndarray_cls. Compare the correct single-output branch
# above. It should read:
#     _ndarray_cls(ctypes.cast(output_vars[i], NDArrayHandle),
#                  stype=_STORAGE_TYPE_ID_TO_STR[out_stypes[i]])
return [_ndarray_cls(ctypes.cast(output_vars[i], NDArrayHandle))
return [_ndarray_cls(ctypes.cast(output_vars[i], NDArrayHandle,
stype=_STORAGE_TYPE_ID_TO_STR[out_stypes[i]]))
for i in range(num_output.value)]


Expand Down
2 changes: 1 addition & 1 deletion python/mxnet/module/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,7 +399,7 @@ def bind(self, data_shapes, label_shapes=None, for_training=True,
else:
assert self._arg_params is None and self._aux_params is None
param_arrays = [
mx.nd.zeros(shape=x[0].shape, dtype=x[0].dtype, storage_type=x[0].storage_type)
mx.nd.zeros(shape=x[0].shape, dtype=x[0].dtype, storage_type=x[0].stype)
for x in self._exec_group.param_arrays
]
self._arg_params = {name:arr for name, arr in zip(self._param_names, param_arrays)}
Expand Down
20 changes: 8 additions & 12 deletions python/mxnet/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,18 +32,18 @@
# pylint: disable=unused-import
try:
    # Prefer the compiled (Cython) frontend unless explicitly disabled via
    # MXNET_ENABLE_CYTHON=0; pick the py2/py3 build matching the interpreter.
    # NOTE: _imperative_invoke comes from the second import of each branch;
    # the duplicate in the first import line has been removed.
    if int(_os.environ.get("MXNET_ENABLE_CYTHON", True)) == 0:
        from ._ctypes.ndarray import NDArrayBase, _set_ndarray_class, _STORAGE_TYPE_ID_TO_STR
        from ._ctypes.ndarray import invoke, CachedOp, _imperative_invoke
    elif _sys.version_info >= (3, 0):
        from ._cy3.ndarray import NDArrayBase, _set_ndarray_class, _STORAGE_TYPE_ID_TO_STR
        from ._cy3.ndarray import invoke, CachedOp, _imperative_invoke
    else:
        from ._cy2.ndarray import NDArrayBase, _set_ndarray_class, _STORAGE_TYPE_ID_TO_STR
        from ._cy2.ndarray import invoke, CachedOp, _imperative_invoke
except ImportError:
    # The Cython build is optional; fall back to pure ctypes unless the user
    # demanded Cython with MXNET_ENFORCE_CYTHON=1.
    if int(_os.environ.get("MXNET_ENFORCE_CYTHON", False)) != 0:
        raise ImportError("Cython Module cannot be loaded but MXNET_ENFORCE_CYTHON=1")
    from ._ctypes.ndarray import NDArrayBase, _set_ndarray_class, _STORAGE_TYPE_ID_TO_STR
    from ._ctypes.ndarray import invoke, CachedOp, _imperative_invoke
# pylint: enable=unused-import

Expand All @@ -64,12 +64,6 @@
4 : np.int32,
6 : np.int64
}
_STORAGE_TYPE_ID_TO_STR = {
-1 : 'undefined',
0 : 'default',
1 : 'row_sparse',
2 : 'csr',
}
_STORAGE_TYPE_STR_TO_ID = {
'undefined' : -1,
'default' : 0,
Expand Down Expand Up @@ -748,8 +742,10 @@ def dtype(self):
return _DTYPE_MX_TO_NP[mx_dtype.value]

@property
def stype(self):
    """Storage type of this NDArray: 'default', 'row_sparse' or 'csr'.

    Returns the value cached in ``self._stype`` when available; otherwise
    queries the backend once via ``_storage_type(self.handle)`` and caches
    the result for subsequent lookups.

    NOTE(review): this commit renames the public ``storage_type`` property
    to ``stype`` with no deprecated alias, so external callers of
    ``storage_type`` will break — confirm that is intended.
    """
    if self._stype is None:
        self._stype = _storage_type(self.handle)
    return self._stype

@property
# pylint: disable= invalid-name, undefined-variable
Expand Down
2 changes: 1 addition & 1 deletion python/mxnet/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,7 @@ def create_state(self, index, weight):
return None
else:
return mx.nd.zeros(shape=weight.shape, ctx=weight.context,
dtype=weight.dtype, storage_type=weight.storage_type)
dtype=weight.dtype, storage_type=weight.stype)

def update(self, index, weight, grad, state):
assert(isinstance(weight, NDArray))
Expand Down
19 changes: 10 additions & 9 deletions python/mxnet/sparse_ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ def __getitem__(self, key):
>>> x[1:2].asnumpy()
array([[ 3., 4., 5.]], dtype=float32)
"""
stype = self.storage_type
stype = self.stype
if stype != 'csr':
raise Exception("__getitem__ for " + str(stype) + " not implemented yet")
if isinstance(key, int):
Expand Down Expand Up @@ -233,7 +233,7 @@ def values(self):
def _num_aux(self):
''' The number of aux data used to help store the sparse ndarray.
'''
return len(_STORAGE_AUX_TYPES[self.storage_type])
return len(_STORAGE_AUX_TYPES[self.stype])

@property
# pylint: disable= invalid-name, undefined-variable
Expand Down Expand Up @@ -270,7 +270,7 @@ def astype(self, dtype):
<type 'numpy.int32'>
"""
res = mx.nd.zeros(shape=self.shape, ctx=self.context,
dtype=dtype, storage_type=self.storage_type)
dtype=dtype, storage_type=self.stype)
self.copyto(res)
return res

Expand Down Expand Up @@ -301,7 +301,7 @@ def copyto(self, other):
return
return _internal._copyto(self, out=other)
elif isinstance(other, Context):
hret = _ndarray_cls(_new_alloc_handle(self.storage_type, self.shape, other,
hret = _ndarray_cls(_new_alloc_handle(self.stype, self.shape, other,
True, self.dtype, self.aux_types))
return _internal._copyto(self, out=hret)
else:
def todense(source):
    """Return a dense (default-storage) copy of *source*.

    Parameters
    ----------
    source : NDArray
        Array (sparse or dense) to convert.

    Returns
    -------
    NDArray
        A copy of ``source`` with 'default' (dense) storage.
    """
    return ndarray.cast_storage(source, storage_type='default')

def _ndarray_cls(handle, writable=True):
stype = _storage_type(handle)
def _ndarray_cls(handle, writable=True, stype=None):
if stype is None:
stype = _storage_type(handle)
if stype == 'default':
return NDArray(handle, writable=writable)
return NDArray(handle, writable=writable, stype=stype)
elif stype == 'csr':
return CSRNDArray(handle, writable=writable)
return CSRNDArray(handle, writable=writable, stype=stype)
elif stype == 'row_sparse':
return RowSparseNDArray(handle, writable=writable)
return RowSparseNDArray(handle, writable=writable, stype=stype)
else:
raise Exception("unknown storage type")

Expand Down
11 changes: 9 additions & 2 deletions src/c_api/c_api_ndarray.cc
Original file line number Diff line number Diff line change
Expand Up @@ -489,13 +489,20 @@ int MXImperativeInvoke(AtomicSymbolCreator creator,
NDArrayHandle **outputs,
int num_params,
const char **param_keys,
// NOTE(review): the next three lines are diff residue -- the first is the
// removed pre-commit parameter-list tail, the two after it are its
// replacement adding the out_stypes output parameter (matching the new
// declaration in include/mxnet/c_api.h).
const char **param_vals) {
const char **param_vals,
const int** out_stypes) { // outputs storage types
const nnvm::Op* op = static_cast<nnvm::Op*>(creator);

MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get();
API_BEGIN();
nnvm::NodeAttrs attrs;
SetOpAttrs(op, &attrs, num_inputs, num_params, param_keys, param_vals);
ImperativeInvokeImpl(attrs, num_inputs, inputs, num_outputs, outputs);
// After the invoke, *outputs points at an array of NDArray*; copy each
// output's storage type into the thread-local scratch vector.
NDArray** output_nds = reinterpret_cast<NDArray**>(*outputs);
ret->out_types.resize(*num_outputs);
for (int i = 0; i < *num_outputs; ++i) {
ret->out_types[i] = output_nds[i]->storage_type();
}
// NOTE(review): *out_stypes aliases thread-local storage owned by
// MXAPIThreadLocalStore, so it is only valid until the next C-API call on
// this thread; callers (the Python frontend) must copy it out immediately.
// Also assumes MXAPIThreadLocalEntry gained an `out_types` vector member
// elsewhere in this commit -- confirm the companion header change.
*out_stypes = dmlc::BeginPtr(ret->out_types);
API_END();
}

Expand Down
2 changes: 1 addition & 1 deletion tests/python/unittest/test_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,7 +463,7 @@ def regression_model(m):
mod.bind(data_shapes=iterator.provide_data, label_shapes=iterator.provide_label)
mod.init_params()
v = mod._arg_params['v']
assert(v.storage_type == 'row_sparse')
assert(v.stype == 'row_sparse')
assert(np.sum(v.asnumpy()) != 0)

if __name__ == '__main__':
Expand Down
2 changes: 1 addition & 1 deletion tests/python/unittest/test_sparse_ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def check_sparse_nd_prop_rsp():
nd, (v, idx) = rand_sparse_ndarray(shape, storage_type)
assert(nd._num_aux == 1)
assert(nd.indices.dtype == np.int64)
assert(nd.storage_type == 'row_sparse')
assert(nd.stype == 'row_sparse')
assert_almost_equal(nd.indices.asnumpy(), idx)


Expand Down
4 changes: 2 additions & 2 deletions tests/python/unittest/test_sparse_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,9 +109,9 @@ def test_dot_csr(lhs_shape, rhs_shape, rhs_stype, trans_lhs, density=1):
rhs_dns = rhs_nd if rhs_stype == 'default' else rhs_nd.todense()
out = mx.nd.dot(lhs_nd, rhs_dns, transpose_a=trans_lhs)
if trans_lhs:
assert out.storage_type == 'row_sparse'
assert out.stype == 'row_sparse'
else:
assert out.storage_type == 'default'
assert out.stype == 'default'
out_expected = mx.nd.dot(lhs_dns, rhs_dns, transpose_a=trans_lhs)
out_np = out_expected.asnumpy()
backward_trans = not trans_lhs
Expand Down

0 comments on commit 5b6f5cc

Please sign in to comment.