Skip to content

Commit

Permalink
Change storage_type to stype
Browse files Browse the repository at this point in the history
  • Loading branch information
reminisce committed Jul 5, 2017
1 parent 5b6f5cc commit 90db7d1
Show file tree
Hide file tree
Showing 14 changed files with 81 additions and 79 deletions.
4 changes: 2 additions & 2 deletions python/mxnet/_ctypes/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,8 @@ def _imperative_invoke(handle, ndargs, keys, vals, out):
return _ndarray_cls(ctypes.cast(output_vars[0], NDArrayHandle),
stype=_STORAGE_TYPE_ID_TO_STR[out_stypes[0]])
else:
return [_ndarray_cls(ctypes.cast(output_vars[i], NDArrayHandle,
stype=_STORAGE_TYPE_ID_TO_STR[out_stypes[i]]))
return [_ndarray_cls(ctypes.cast(output_vars[i], NDArrayHandle),
stype=_STORAGE_TYPE_ID_TO_STR[out_stypes[i]])
for i in range(num_output.value)]


Expand Down
2 changes: 1 addition & 1 deletion python/mxnet/module/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,7 +399,7 @@ def bind(self, data_shapes, label_shapes=None, for_training=True,
else:
assert self._arg_params is None and self._aux_params is None
param_arrays = [
mx.nd.zeros(shape=x[0].shape, dtype=x[0].dtype, storage_type=x[0].stype)
mx.nd.zeros(shape=x[0].shape, dtype=x[0].dtype, stype=x[0].stype)
for x in self._exec_group.param_arrays
]
self._arg_params = {name:arr for name, arr in zip(self._param_names, param_arrays)}
Expand Down
4 changes: 2 additions & 2 deletions python/mxnet/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -973,11 +973,11 @@ def backward(self, out_grad=None, retain_graph=False):

def _to_csr(self):
# pylint: disable=undefined-variable
return cast_storage(self, storage_type='csr')
return cast_storage(self, stype='csr')

def _to_rsp(self):
# pylint: disable=undefined-variable
return cast_storage(self, storage_type='row_sparse')
return cast_storage(self, stype='row_sparse')

def onehot_encode(indices, out):
"""One-hot encoding indices into matrix out.
Expand Down
20 changes: 10 additions & 10 deletions python/mxnet/ndarray_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,14 +48,14 @@ def _zeros_ndarray(shape, ctx=None, dtype=None, **kwargs):
# pylint: enable= no-member, protected-access


def _zeros_sparse_ndarray(storage_type, shape, ctx=None, dtype=None, aux_types=None, **kwargs):
def _zeros_sparse_ndarray(stype, shape, ctx=None, dtype=None, aux_types=None, **kwargs):
"""Return a new array of given shape and type, filled with zeros.
Parameters
----------
shape : int or tuple of int
The shape of the empty array
storage_type: string
stype: string
The storage type of the empty array, such as 'row_sparse', 'csr', etc
ctx : Context, optional
An optional device context (default is the current default context)
Expand All @@ -76,26 +76,26 @@ def _zeros_sparse_ndarray(storage_type, shape, ctx=None, dtype=None, aux_types=N
>>> mx.sparse_nd.zeros('row_sparse', (1,2), mx.gpu(0), 'float16').asnumpy()
array([[ 0., 0.]], dtype=float16)
"""
if storage_type == 'default':
if stype == 'default':
return _zeros_ndarray(shape, ctx=ctx, dtype=dtype, **kwargs)
if ctx is None:
ctx = Context.default_ctx
dtype = mx_real_t if dtype is None else dtype
if aux_types is None:
if storage_type == 'row_sparse' or storage_type == 'csr':
aux_types = _STORAGE_AUX_TYPES[storage_type]
if stype == 'row_sparse' or stype == 'csr':
aux_types = _STORAGE_AUX_TYPES[stype]
else:
raise Exception("unknown storage type")
assert(len(aux_types) == len(_STORAGE_AUX_TYPES[storage_type]))
out = _ndarray_cls(_new_alloc_handle(storage_type, shape, ctx, True, dtype, aux_types))
assert(len(aux_types) == len(_STORAGE_AUX_TYPES[stype]))
out = _ndarray_cls(_new_alloc_handle(stype, shape, ctx, True, dtype, aux_types))
return _internal._zeros(shape=shape, ctx=ctx, dtype=dtype, out=out, **kwargs)


def zeros(shape, ctx=None, dtype=None, storage_type=None, aux_types=None, **kwargs):
if storage_type is None:
def zeros(shape, ctx=None, dtype=None, stype=None, aux_types=None, **kwargs):
if stype is None:
return _zeros_ndarray(shape, ctx, dtype, **kwargs)
else:
return _zeros_sparse_ndarray(storage_type, shape, ctx, dtype, aux_types, **kwargs)
return _zeros_sparse_ndarray(stype, shape, ctx, dtype, aux_types, **kwargs)


def load(fname):
Expand Down
2 changes: 1 addition & 1 deletion python/mxnet/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,7 @@ def create_state(self, index, weight):
return None
else:
return mx.nd.zeros(shape=weight.shape, ctx=weight.context,
dtype=weight.dtype, storage_type=weight.stype)
dtype=weight.dtype, stype=weight.stype)

def update(self, index, weight, grad, state):
assert(isinstance(weight, NDArray))
Expand Down
8 changes: 4 additions & 4 deletions python/mxnet/sparse_ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
}


def _new_alloc_handle(storage_type, shape, ctx, delay_alloc, dtype, aux_types, aux_shapes=None):
def _new_alloc_handle(stype, shape, ctx, delay_alloc, dtype, aux_types, aux_shapes=None):
"""Return a new handle with specified storage type, shape, dtype and context.
Empty handle is only used to hold results
Expand All @@ -65,7 +65,7 @@ def _new_alloc_handle(storage_type, shape, ctx, delay_alloc, dtype, aux_types, a
aux_shapes = sum(aux_shapes, ())
num_aux = mx_uint(len(aux_types))
check_call(_LIB.MXNDArrayCreateSparseEx(
ctypes.c_int(int(_STORAGE_TYPE_STR_TO_ID[storage_type])),
ctypes.c_int(int(_STORAGE_TYPE_STR_TO_ID[stype])),
c_array(mx_uint, shape),
mx_uint(len(shape)),
ctypes.c_int(ctx.device_typeid),
Expand Down Expand Up @@ -270,7 +270,7 @@ def astype(self, dtype):
<type 'numpy.int32'>
"""
res = mx.nd.zeros(shape=self.shape, ctx=self.context,
dtype=dtype, storage_type=self.stype)
dtype=dtype, stype=self.stype)
self.copyto(res)
return res

Expand Down Expand Up @@ -522,7 +522,7 @@ def todense(source):
NDArray
The dense array with default storage
"""
return ndarray.cast_storage(source, storage_type='default')
return ndarray.cast_storage(source, stype='default')


def _ndarray_cls(handle, writable=True, stype=None):
Expand Down
18 changes: 10 additions & 8 deletions python/mxnet/symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -1162,7 +1162,7 @@ def _get_ndarray_inputs(arg_key, args, arg_names, allow_missing):
raise TypeError('Only accept list of NDArrays or dict of str to NDArray')
return c_array(NDArrayHandle, arg_handles), arg_arrays

def simple_bind(self, ctx, grad_req='write', type_dict=None, storage_type_dict=None,
def simple_bind(self, ctx, grad_req='write', type_dict=None, stype_dict=None,
group2ctx=None, shared_arg_names=None, shared_exec=None,
shared_buffer=None, **kwargs):
"""Bind current symbol to get an executor, allocate all the arguments needed.
Expand Down Expand Up @@ -1206,7 +1206,7 @@ def simple_bind(self, ctx, grad_req='write', type_dict=None, storage_type_dict=N
type_dict : Dict of str->numpy.dtype
Input type dictionary, name->dtype
storage_type_dict : Dict of str->str
stype_dict : Dict of str->str
Input storage type dictionary, name->storage_type
group2ctx : Dict of string to mx.Context
Expand Down Expand Up @@ -1255,10 +1255,10 @@ def simple_bind(self, ctx, grad_req='write', type_dict=None, storage_type_dict=N
# provided storage type argument names
provided_arg_stype_names = ctypes.POINTER(ctypes.c_char_p)()
provided_arg_stype_data = ctypes.POINTER(mx_uint)() # provided storage types
if storage_type_dict is not None:
if stype_dict is not None:
provided_arg_stype_names = []
provided_arg_stype_data = []
for k, v in storage_type_dict.items():
for k, v in stype_dict.items():
if v in _STORAGE_TYPE_STR_TO_ID:
provided_arg_stype_names.append(c_str(k))
provided_arg_stype_data.append(ctypes.c_int(_STORAGE_TYPE_STR_TO_ID[v]))
Expand Down Expand Up @@ -1339,7 +1339,7 @@ def simple_bind(self, ctx, grad_req='write', type_dict=None, storage_type_dict=N
shared_buffer_names = []
shared_buffer_handles = []
for k, v in shared_buffer.items():
assert(v.storage_type == 'default'), \
assert(v.stype == 'default'), \
"shared_buffer is expected to only contain NDArrays with default storage"
shared_buffer_names.append(c_str(k))
shared_buffer_handles.append(v.handle)
Expand Down Expand Up @@ -1669,7 +1669,7 @@ def reshape(self, shape):
return reshape(self, shape=shape)

def var(name, attr=None, shape=None, lr_mult=None, wd_mult=None, dtype=None,
init=None, storage_type=None, **kwargs):
init=None, stype=None, **kwargs):
"""Creates a symbolic variable with specified name.
Example usage:
Expand All @@ -1696,6 +1696,8 @@ def var(name, attr=None, shape=None, lr_mult=None, wd_mult=None, dtype=None,
The dtype for input variable. If not specified, this value will be inferred.
init : initializer (mxnet.init.*)
Initializer for this variable to (optionally) override the default initializer.
stype : str
The storage type of the variable.
kwargs : Additional attribute variables
Additional attributes must start and end with double underscores.
Expand Down Expand Up @@ -1723,8 +1725,8 @@ def var(name, attr=None, shape=None, lr_mult=None, wd_mult=None, dtype=None,
if not isinstance(init, string_types):
init = init.dumps()
attr['__init__'] = init
if storage_type is not None:
attr['__storage_type__'] = str(_STORAGE_TYPE_STR_TO_ID[storage_type])
if stype is not None:
attr['__storage_type__'] = str(_STORAGE_TYPE_STR_TO_ID[stype])
for k, v in kwargs.items():
if k.startswith('__') and k.endswith('__'):
attr[k] = str(v)
Expand Down
18 changes: 9 additions & 9 deletions python/mxnet/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,10 +76,10 @@ def random_sample(population, k):
return population_copy[0:k]


def rand_sparse_ndarray(shape, storage_type, density=None):
def rand_sparse_ndarray(shape, stype, density=None):
"""Generate a random sparse ndarray. Returns the ndarray, value(np) and indices(np) """
density = rnd.rand() if density is None else density
if storage_type == 'row_sparse':
if stype == 'row_sparse':
# TODO(haibin) support high dim sparse ndarray
assert(len(shape) < 3)
prod = np.prod(shape)
Expand All @@ -88,13 +88,13 @@ def rand_sparse_ndarray(shape, storage_type, density=None):
idx_sample = rnd.rand(shape[0])
indices = np.argwhere(idx_sample < density).flatten()
if indices.shape[0] == 0:
result = mx.nd.zeros(shape, storage_type='row_sparse')
result = mx.nd.zeros(shape, stype='row_sparse')
return result, (np.array([], dtype='int64'), np.array([], dtype='int64'))
# generate random values
val = rnd.rand(indices.shape[0], num_cols)
arr = mx.sparse_nd.row_sparse(val, indices, shape, indices_type=np.int64)
return arr, (val, indices)
elif storage_type == 'csr':
elif stype == 'csr':
assert(len(shape) == 2)
csr = sp.rand(shape[0], shape[1], density=density, format='csr')
result = mx.sparse_nd.csr(csr.data, csr.indptr, csr.indices, shape)
Expand All @@ -103,11 +103,11 @@ def rand_sparse_ndarray(shape, storage_type, density=None):
assert(False), "unknown storage type"


def rand_ndarray(shape, storage_type, density=None):
if storage_type == 'default':
def rand_ndarray(shape, stype, density=None):
if stype == 'default':
arr = mx.nd.array(random_arrays(shape))
else:
arr, _ = rand_sparse_ndarray(shape, storage_type, density=density)
arr, _ = rand_sparse_ndarray(shape, stype, density=density)
return arr


Expand Down Expand Up @@ -554,7 +554,7 @@ def random_projection(shape):
assert isinstance(grad_stype_dict, dict), "grad_stype_dict must be a dict"
for k, v in grad_stype_dict.items():
if k in args_grad and v in _STORAGE_TYPE_STR_TO_ID and v != 'default':
args_grad[k] = mx.nd.cast_storage(args_grad[k], storage_type=v)
args_grad[k] = mx.nd.cast_storage(args_grad[k], stype=v)

executor = out.bind(ctx, grad_req=grad_req,
args=location, args_grad=args_grad, aux_states=aux_states)
Expand Down Expand Up @@ -724,7 +724,7 @@ def check_symbolic_backward(sym, location, out_grads, expected, rtol=1e-5, atol=
grad_stype = attr.get('grad_stype_hint', None)
nd = mx.nd.array(v, ctx=ctx)
if grad_stype is not None:
out = mx.nd.cast_storage(nd, storage_type=grad_stype)
out = mx.nd.cast_storage(nd, stype=grad_stype)
args_grad_data[k] = out
else:
args_grad_data[k] = nd
Expand Down
8 changes: 4 additions & 4 deletions src/operator/tensor/cast_storage-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -291,9 +291,9 @@ void CastStorageComputeImpl(mshadow::Stream<xpu>* s,
}

struct CastStorageParam : public dmlc::Parameter<CastStorageParam> {
int storage_type;
int stype;
DMLC_DECLARE_PARAMETER(CastStorageParam) {
DMLC_DECLARE_FIELD(storage_type)
DMLC_DECLARE_FIELD(stype)
.add_enum("default", kDefaultStorage)
.add_enum("row_sparse", kRowSparseStorage)
.add_enum("csr", kCSRStorage)
Expand All @@ -310,9 +310,9 @@ inline bool CastStorageInferStorageType(const nnvm::NodeAttrs& attrs,
CHECK_NE(in_attrs->at(0), kUndefinedStorage)
<< "src ndarray's storage type must be specified";
const CastStorageParam& param = nnvm::get<CastStorageParam>(attrs.parsed);
CHECK_NE(param.storage_type, kUndefinedStorage)
CHECK_NE(param.stype, kUndefinedStorage)
<< "dst ndarray's storage type must be specified";
TYPE_ASSIGN_CHECK(*out_attrs, 0, param.storage_type);
TYPE_ASSIGN_CHECK(*out_attrs, 0, param.stype);
return true;
}

Expand Down
4 changes: 2 additions & 2 deletions tests/python/unittest/test_kvstore.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@ def init_kv(stype='default'):
"""init kv """
kv = mx.kv.create()
# single
kv.init(3, mx.nd.zeros(shape=shape, storage_type=stype))
kv.init(3, mx.nd.zeros(shape=shape, stype=stype))
# list
kv.init(keys, [mx.nd.zeros(shape=shape, storage_type=stype)] * len(keys))
kv.init(keys, [mx.nd.zeros(shape=shape, stype=stype)] * len(keys))
return kv

def init_kv_with_str():
Expand Down
12 changes: 6 additions & 6 deletions tests/python/unittest/test_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,10 +380,10 @@ def test_module_fm():
rnd.seed(11)
def fm_model(k, feature_dim):
norm = mx.initializer.Normal(sigma=0.01)
x = mx.symbol.Variable("data", storage_type='csr')
v = mx.symbol.Variable("v", shape=(feature_dim, k), init=norm, storage_type='row_sparse')
x = mx.symbol.Variable("data", stype='csr')
v = mx.symbol.Variable("v", shape=(feature_dim, k), init=norm, stype='row_sparse')

w1_weight = mx.symbol.var('w1_weight', shape=(feature_dim, 1), init=norm, storage_type='row_sparse')
w1_weight = mx.symbol.var('w1_weight', shape=(feature_dim, 1), init=norm, stype='row_sparse')
w1 = mx.symbol.dot(x, w1_weight)

v_s = mx.symbol.sum(data=mx.symbol.square(data=v), axis=1)
Expand Down Expand Up @@ -443,9 +443,9 @@ def fm_model(k, feature_dim):

def test_module_initializer():
def regression_model(m):
x = mx.symbol.var("data", storage_type='csr')
x = mx.symbol.var("data", stype='csr')
v = mx.symbol.var("v", shape=(m, 1), init=mx.init.Uniform(scale=.1),
storage_type='row_sparse')
stype='row_sparse')
model = mx.symbol.dot(lhs=x, rhs=v)
y = mx.symbol.Variable("label")
model = mx.symbol.LinearRegressionOutput(data=model, label=y, name="out")
Expand All @@ -454,7 +454,7 @@ def regression_model(m):
n, m = 128, 100
model = regression_model(m)

data = mx.nd.zeros(shape=(n, m), storage_type='csr')
data = mx.nd.zeros(shape=(n, m), stype='csr')
label = mx.nd.zeros((n, 1))
iterator = mx.io.NDArrayIter(data=data, label={'label':label}, batch_size=n)

Expand Down
4 changes: 2 additions & 2 deletions tests/python/unittest/test_multi_device_exec.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ def test_ctx_group():

def check_ctx_group_sparse(lhs_stype, rhs_stype):
with mx.AttrScope(ctx_group='stage1'):
lhs = mx.symbol.Variable('lhs', storage_type=lhs_stype)
rhs = mx.symbol.Variable('rhs', storage_type=rhs_stype)
lhs = mx.symbol.Variable('lhs', stype=lhs_stype)
rhs = mx.symbol.Variable('rhs', stype=rhs_stype)
plus = mx.symbol.elemwise_add(lhs, rhs, name='plus')

set_stage1 = set(plus.list_arguments())
Expand Down
Loading

0 comments on commit 90db7d1

Please sign in to comment.