Skip to content

Commit

Permalink
change idx types from int32 to int64
Browse files Browse the repository at this point in the history
Conflicts:
	python/mxnet/test_utils.py
	tests/python/unittest/test_sparse_operator.py

update mshadow submodule

fix extra quotes in test script

change indptr type to int64

better err message for rsp"
  • Loading branch information
eric-haibin-lin committed Jun 19, 2017
1 parent 22fb07f commit 1b94692
Show file tree
Hide file tree
Showing 10 changed files with 81 additions and 82 deletions.
8 changes: 2 additions & 6 deletions include/mxnet/ndarray.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,6 @@ class AutogradRuntime;
} // namespace autograd

// enum for storage types
#define CSR_IND_PTR_TYPE mshadow::kInt32
#define CSR_IDX_DTYPE mshadow::kInt32
#define ROW_SPARSE_IDX_TYPE mshadow::kInt32
// FIXME int64_t is not available mshadow
namespace csr {
enum CSRAuxType {kIndPtr, kIdx};
}
Expand Down Expand Up @@ -114,9 +110,9 @@ class NDArray {
// Assign default aux types if not given
if (aux_types.size() == 0) {
if (stype == kRowSparseStorage) {
aux_types = {ROW_SPARSE_IDX_TYPE};
aux_types = {mshadow::kInt64};
} else if (stype == kCSRStorage) {
aux_types = {CSR_IND_PTR_TYPE, CSR_IDX_DTYPE};
aux_types = {mshadow::kInt64, mshadow::kInt64};
} else {
LOG(FATAL) << "Unknown storage type " << stype;
}
Expand Down
2 changes: 1 addition & 1 deletion mshadow
6 changes: 4 additions & 2 deletions python/mxnet/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,14 +53,16 @@
np.float64 : 1,
np.float16 : 2,
np.uint8 : 3,
np.int32 : 4
np.int32 : 4,
np.int64 : 6
}
_DTYPE_MX_TO_NP = {
0 : np.float32,
1 : np.float64,
2 : np.float16,
3 : np.uint8,
4 : np.int32
4 : np.int32,
6 : np.int64
}
_STORAGE_TYPE_ID_TO_STR = {
-1 : 'undefined',
Expand Down
10 changes: 5 additions & 5 deletions python/mxnet/sparse_ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@

# pylint: enable=unused-import
_STORAGE_AUX_TYPES = {
'row_sparse': [np.int32],
'csr': [np.int32, np.int32]
'row_sparse': [np.int64],
'csr': [np.int64, np.int64]
}


Expand Down Expand Up @@ -483,8 +483,8 @@ def csr(values, indptr, indices, shape, ctx=None, dtype=None, indptr_type=None,
indices, indices_type = _prepare_src_array(indices, indices_type,
_STORAGE_AUX_TYPES[storage_type][1])
# verify types
assert('int' in str(indptr_type) or 'long' in str(indptr_type))
assert('int' in str(indices_type) or 'long' in str(indices_type))
assert('int64' in str(indptr_type)), "expected int64 for indptr"
assert('int64' in str(indices_type)), "expected int64 for indices"
# verify shapes
aux_shapes = [indptr.shape, indices.shape]
assert(values.ndim == 1)
Expand Down Expand Up @@ -536,7 +536,7 @@ def row_sparse(values, indices, shape, ctx=None, dtype=None, indices_type=None):
indices, indices_type = _prepare_src_array(indices, indices_type,
_STORAGE_AUX_TYPES[storage_type][0])
# verify types
assert('int' in str(indices_type) or 'long' in str(indices_type))
assert('int64' in str(indices_type)), "expected int64 for indices"
# verify shapes
assert(values.ndim == len(shape))
assert(indices.ndim == 1)
Expand Down
5 changes: 2 additions & 3 deletions python/mxnet/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,6 @@ def random_sample(population, k):
return population_copy[0:k]


# TODO(haibin) also include types in arguments
def rand_sparse_ndarray(shape, storage_type, density=None):
"""Generate a random sparse ndarray. Returns the ndarray, value(np) and indices(np) """
density = rnd.rand() if density is None else density
Expand All @@ -90,10 +89,10 @@ def rand_sparse_ndarray(shape, storage_type, density=None):
indices = np.argwhere(idx_sample < density).flatten()
if indices.shape[0] == 0:
result = mx.nd.zeros(shape, storage_type='row_sparse')
return result, (np.array([]), np.array([], dtype='int32'))
return result, (np.array([], dtype='int64'), np.array([], dtype='int64'))
# generate random values
val = rnd.rand(indices.shape[0], num_cols)
arr = mx.sparse_nd.row_sparse(val, indices, shape, indices_type=np.int32)
arr = mx.sparse_nd.row_sparse(val, indices, shape, indices_type=np.int64)
return arr, (val, indices)
elif storage_type == 'csr':
assert(len(shape) == 2)
Expand Down
24 changes: 12 additions & 12 deletions src/io/iter_libsvm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,11 @@ class LibSVMIter: public SparseIIterator<DataInst> {
// intialize iterator loads data in
virtual void Init(const std::vector<std::pair<std::string, std::string> >& kwargs) {
param_.InitAllowUnknown(kwargs);
data_parser_.reset(dmlc::Parser<uint32_t>::Create(param_.data_libsvm.c_str(),
0, 1, "libsvm"));
CHECK_EQ(param_.data_shape.ndim(), 1) << "dimension of data_shape is expected to be 1";
data_parser_.reset(dmlc::Parser<uint64_t>::Create(param_.data_libsvm.c_str(),
0, 1, "libsvm"));
if (param_.label_libsvm != "NULL") {
label_parser_.reset(dmlc::Parser<uint32_t>::Create(param_.label_libsvm.c_str(),
label_parser_.reset(dmlc::Parser<uint64_t>::Create(param_.label_libsvm.c_str(),
0, 1, "libsvm"));
CHECK_GT(param_.label_shape.Size(), 1)
<< "label_shape is not expected to be (1,) when param_.label_libsvm is set.";
Expand Down Expand Up @@ -129,23 +129,23 @@ class LibSVMIter: public SparseIIterator<DataInst> {
}

private:
inline TBlob AsDataBlob(const dmlc::Row<uint32_t>& row) {
inline TBlob AsDataBlob(const dmlc::Row<uint64_t>& row) {
const real_t* ptr = row.value;
TShape shape(mshadow::Shape1(row.length));
return TBlob((real_t*) ptr, shape, cpu::kDevMask); // NOLINT(*)
}

inline TBlob AsIdxBlob(const dmlc::Row<uint32_t>& row) {
const uint32_t* ptr = row.index;
inline TBlob AsIdxBlob(const dmlc::Row<uint64_t>& row) {
const uint64_t* ptr = row.index;
TShape shape(mshadow::Shape1(row.length));
return TBlob((int32_t*) ptr, shape, cpu::kDevMask, CSR_IDX_DTYPE); // NOLINT(*)
return TBlob((int64_t*) ptr, shape, cpu::kDevMask, mshadow::kInt64); // NOLINT(*)
}

inline TBlob AsIndPtrPlaceholder(const dmlc::Row<uint32_t>& row) {
return TBlob(nullptr, mshadow::Shape1(0), cpu::kDevMask, CSR_IND_PTR_TYPE);
inline TBlob AsIndPtrPlaceholder(const dmlc::Row<uint64_t>& row) {
return TBlob(nullptr, mshadow::Shape1(0), cpu::kDevMask, mshadow::kInt64);
}

inline TBlob AsScalarLabelBlob(const dmlc::Row<uint32_t>& row) {
inline TBlob AsScalarLabelBlob(const dmlc::Row<uint64_t>& row) {
const real_t* ptr = row.label;
return TBlob((real_t*) ptr, mshadow::Shape1(1), cpu::kDevMask); // NOLINT(*)
}
Expand All @@ -160,8 +160,8 @@ class LibSVMIter: public SparseIIterator<DataInst> {
// label parser
size_t label_ptr_{0}, label_size_{0};
size_t data_ptr_{0}, data_size_{0};
std::unique_ptr<dmlc::Parser<uint32_t> > label_parser_;
std::unique_ptr<dmlc::Parser<uint32_t> > data_parser_;
std::unique_ptr<dmlc::Parser<uint64_t> > label_parser_;
std::unique_ptr<dmlc::Parser<uint64_t> > data_parser_;
};


Expand Down
47 changes: 23 additions & 24 deletions src/io/iter_sparse_batchloader.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,32 +71,31 @@ class SparseBatchLoader : public BatchLoader, public SparseIIterator<TBlobBatch>
out_.num_batch_padd = num_overflow_;
CHECK_EQ(inst_cache_.size(), param_.batch_size);
this->InitDataFromBatch();
MSHADOW_INT_TYPE_SWITCH(CSR_IND_PTR_TYPE, IType, {
for (size_t j = 0; j < inst_cache_.size(); j++) {
const auto& d = inst_cache_[j];
out_.inst_index[top] = d.index;
size_t unit_size = 0;
for (size_t i = 0; i < d.data.size(); ++i) {
// indptr tensor
if (IsIndPtr(i)) {
auto indptr = data_[i].get<cpu, 1, IType>();
if (j == 0) indptr[0] = 0;
indptr[j + 1] = indptr[j] + (IType) unit_size;
offsets_[i] = j;
} else {
// indices and values tensor
unit_size = d.data[i].shape_.Size();
MSHADOW_TYPE_SWITCH(data_[i].type_flag_, DType, {
const auto begin = offsets_[i];
const auto end = offsets_[i] + unit_size;
mshadow::Copy(data_[i].get<cpu, 1, DType>().Slice(begin, end),
d.data[i].get_with_shape<cpu, 1, DType>(mshadow::Shape1(unit_size)));
});
offsets_[i] += unit_size;
}
for (size_t j = 0; j < inst_cache_.size(); j++) {
const auto& d = inst_cache_[j];
out_.inst_index[top] = d.index;
// TODO(haibin) double check the type?
int64_t unit_size = 0;
for (size_t i = 0; i < d.data.size(); ++i) {
// indptr tensor
if (IsIndPtr(i)) {
auto indptr = data_[i].get<cpu, 1, int64_t>();
if (j == 0) indptr[0] = 0;
indptr[j + 1] = indptr[j] + unit_size;
offsets_[i] = j;
} else {
// indices and values tensor
unit_size = d.data[i].shape_.Size();
MSHADOW_TYPE_SWITCH(data_[i].type_flag_, DType, {
const auto begin = offsets_[i];
const auto end = offsets_[i] + unit_size;
mshadow::Copy(data_[i].get<cpu, 1, DType>().Slice(begin, end),
d.data[i].get_with_shape<cpu, 1, DType>(mshadow::Shape1(unit_size)));
});
offsets_[i] += unit_size;
}
}
});
}
return true;
}

Expand Down
42 changes: 21 additions & 21 deletions src/operator/tensor/matrix_op-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -1182,28 +1182,28 @@ void SliceCsrImpl(const SliceParam &param, const OpContext& ctx,
out.set_aux_shape(kIndPtr, Shape1(0));
return;
}
CHECK_EQ(in.aux_type(kIndPtr), in.aux_type(kIdx))
<< "The type for indptr and indices are different. This is not implemented yet.";
// assume idx indptr share the same type
MSHADOW_INT_TYPE_SWITCH(in.aux_type(kIndPtr), IType, {
MSHADOW_TYPE_SWITCH(in.dtype(), DType, {
auto in_indptr = in.aux_data(kIndPtr).dptr<IType>();
auto out_indptr = out.aux_data(kIndPtr).dptr<IType>();
SliceCsrIndPtrImpl<cpu, IType>(begin, end, ctx.run_ctx, in_indptr, out_indptr);

// retrieve nnz (CPU implementation)
int nnz = out_indptr[indptr_len - 1];
// copy indices and values
out.CheckAndAllocAuxData(kIdx, Shape1(nnz));
out.CheckAndAllocData(Shape1(nnz));
auto in_idx = in.aux_data(kIdx).dptr<IType>();
auto out_idx = out.aux_data(kIdx).dptr<IType>();
auto in_data = in.data().dptr<DType>();
auto out_data = out.data().dptr<DType>();
int offset = in_indptr[begin];
// this is also a CPU-only implementation
memcpy(out_idx, in_idx + offset, nnz * sizeof(IType));
memcpy(out_data, in_data + offset, nnz * sizeof(DType));
MSHADOW_INT_TYPE_SWITCH(in.aux_type(kIndPtr), RType, {
MSHADOW_INT_TYPE_SWITCH(in.aux_type(kIdx), IType, {
MSHADOW_TYPE_SWITCH(in.dtype(), DType, {
auto in_indptr = in.aux_data(kIndPtr).dptr<RType>();
auto out_indptr = out.aux_data(kIndPtr).dptr<RType>();
SliceCsrIndPtrImpl<cpu, RType>(begin, end, ctx.run_ctx, in_indptr, out_indptr);

// retrieve nnz (CPU implementation)
int nnz = out_indptr[indptr_len - 1];
// copy indices and values
out.CheckAndAllocAuxData(kIdx, Shape1(nnz));
out.CheckAndAllocData(Shape1(nnz));
auto in_idx = in.aux_data(kIdx).dptr<IType>();
auto out_idx = out.aux_data(kIdx).dptr<IType>();
auto in_data = in.data().dptr<DType>();
auto out_data = out.data().dptr<DType>();
int offset = in_indptr[begin];
// this is also a CPU-only implementation
memcpy(out_idx, in_idx + offset, nnz * sizeof(IType));
memcpy(out_data, in_data + offset, nnz * sizeof(DType));
});
});
});
}
Expand Down
15 changes: 9 additions & 6 deletions tests/python/unittest/test_sparse_ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def check_sparse_nd_prop_rsp():
shape = rand_shape_2d()
nd, (v, idx) = rand_sparse_ndarray(shape, storage_type)
assert(nd._num_aux == 1)
assert(nd.indices.dtype == np.int32)
assert(nd.indices.dtype == np.int64)
assert(nd.storage_type == 'row_sparse')
assert_almost_equal(nd.indices.asnumpy(), idx)

Expand All @@ -106,9 +106,12 @@ def check_rsp_creation(values, indices, shape):
dns = mx.nd.zeros(shape)
dns[1] = mx.nd.array(values[0])
dns[3] = mx.nd.array(values[1])
assert_almost_equal(rsp.asnumpy(), dns.asnumpy())
indices = mx.nd.array(indices).asnumpy()
assert_almost_equal(rsp.indices.asnumpy(), indices)
#assert_almost_equal(rsp.asnumpy(), dns.asnumpy())
print('before', indices)
print('mx', mx.nd.array(indices, dtype='int64')[1].asnumpy())
indices_np = mx.nd.array(indices, dtype='int64').asnumpy()
print('after', indices_np)
assert_almost_equal(rsp.indices.asnumpy(), indices_np)

def check_csr_creation(shape):
csr, (indptr, indices, values) = rand_sparse_ndarray(shape, 'csr')
Expand All @@ -118,11 +121,11 @@ def check_csr_creation(shape):

shape = (4,2)
values = np.random.rand(2,2)
indices = np.array([1,3])
indices = np.array([1,3], dtype='int64')
check_rsp_creation(values, indices, shape)

values = mx.nd.array(np.random.rand(2,2))
indices = mx.nd.array([1,3], dtype='int32')
indices = mx.nd.array([1,3], dtype='int64')
check_rsp_creation(values, indices, shape)

values = [[0.1, 0.2], [0.3, 0.4]]
Expand Down
4 changes: 2 additions & 2 deletions tests/python/unittest/test_sparse_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ def test_elemwise_add_ex_multiple_stages():

val1 = mx.nd.array([[5, 10]]);
val2 = mx.nd.array([[5, 10]]);
idx1 = mx.nd.array([0], dtype=np.int32);
idx2 = mx.nd.array([1], dtype=np.int32);
idx1 = mx.nd.array([0], dtype=np.int64);
idx2 = mx.nd.array([1], dtype=np.int64);
sp_nd1 = mx.sparse_nd.row_sparse(val1, idx1, shape)
sp_nd2 = mx.sparse_nd.row_sparse(val2, idx2, shape)
ds_nd = mx.nd.array(ds_np)
Expand Down

0 comments on commit 1b94692

Please sign in to comment.