From 68c837d7b2269b372b04bcf713fa30296a943011 Mon Sep 17 00:00:00 2001 From: Haibin Lin Date: Mon, 26 Jun 2017 10:01:16 -0700 Subject: [PATCH] [WIP] Sparse Tensor (#5800) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * squash merge with 38f7c5584016e92ba1e0ee1b00ea6632740f67ce compiles on GPU update check alloc: Checkpoint. Pass elem-sum gpu test bug fix for copyfromto. sparse sgd test pass on gpu inefficient implementation for csr copy update submodule fix lint Simple bind with infer storage type (#32) * Symbol binding for sparse tensor development. (#31) * Initial checkin * Add init functions for simple bind in graph_executor * Add simple_bind c_api * Add simple bind c-api * Assign zeros to in_args, arg_grads, and aux_states * Add simple_bind2 python interface * Fix python interface bugs * Interface changes * Fix * Fix core dump * Add bind_ith_exec c_api * Change simple_bind2 * Fix seg fault * Finish simple_bind * Change _bind_ith_exec * Refactor simple_bind initialization flow for bind * Consolidate bind and simple_bind graph init flow * Fix bug * Clean up * Add comments * Clean up * Clean up * Minor correction * Rename APIs in graph executor * Refactor * Rebase * Delete deprecated functions * Move more front-end work to backend * Bug fix * Fix failed tests * Minor fix * Fix lint * Fix lint * Revert unnecessary changes * Revert * Revert * Clean up * Fix lint Conflicts: python/mxnet/symbol.py src/executor/graph_executor.cc * Add inferstorage to graph executor * re-enable tests for sparse embedding with simple_bind * type switch fix in sparse embedding" ; change `default` to `default_storage` for cast storage op (#33) * change default to default_storage * disable cpp test build temporarily attempt to fix windows build error, and fix lint (#34) update nnvm submodule (#37) Scipy build (#38) * update nnvm submodule * add scipy pip install for dockerfile Python3 unit tests (#39) * change xrange to range for python3 compatiblity" * remove more xrange from tests replace long with int for python3 (#40) fix the rest of TShape constructor errors (#41) fix lint (#42) fix wrong usage of mshadow::Shape1" (#43) implementation for Csr slice on cpu (#36) * CPU implementation for CSR remove seg_len from csr slice add some docs for slice csr change indptr, values, etc to be private member bug fix in sparse embedding update nnvm submoduel fix lint update unit test for sparse nd" * add const for SliceCsrIndPtr kernel Fix sparse dot according to the new RSP definition (#35) * Fix csr dot dns * Fix sparse dot * Add fallback and test cases for dot(csr, dns)=dns * Add int type switch * Fix * Fix * Fix update mshadow submodule (#44) Fix dns to rsp (#46) fix lint (#47) add runtime storage fallback detection" (#48) * add runtime storage fallback detection" * replace cast storage ex with cast storage impl Fm example (#45) * update csr slice logic to avoid confusion. add more exmaples. * add hint to module.update * more testcases(fallback) for sparse_nd * add to_csr() and to_rsp() method. More unit test (fallback now) * add fm test. fix lint * register sparse sgd under Optim.SGD * update dmlc-core submoduel * change indptr to _indptr temporarily. add const ref to fname fix lint fix lint; (#51) Guard gpu cast storage (#50) * Clean up * Fix typo Rearrange unit test files (#52) fix lint. add scipy for python_test. fix scipy.sparse import error. fix truediv for python3 fix travis test (#54) * remove pyc files * add verbose for travis nosetests cleanup some testing code and enums (#57) * update Makefile * refactor test_sparse_operator * change `default_storage` back to `default` * remove unused cpp tests port libsvm parser to mxnet as libsvm iter (#55) * copied csv iter to libsvm iter test libsvm iter draft handle round batch == false for csr batch loader code refactoring add get stype, shape interface to iiter separate class for sparse iter add missing file fix mem corruption' rename variables add comments also read label from libsvm add test. update docs. update submodule Conflicts: python/mxnet/sparse_ndarray.py * update submodule * fix lint * update test * revert naming change add benchmark scritp for dot (#59) * add benchmark scritp for dot add gpu option for bench add get_data funciton for benchmark print t_sparse, too; add comment change nnz to dnesity add backward * add comment update fm test (#62) introduce CSRNDarray and rowsparseNDarray to python frontend api (#58) * introduce CSRNDarray and rowsparseNDarray to python frontend api * temporarily disable fm_module test fix lint (#64) fix typo. disable libsvm io test (#65) Improve dot (#61) * Init checkin * Fix * Adjust dot parallelization methods * Set num_omp_threads for benchmark from command line * Fix omp thread number * Clean up * Add scipy as dot baseline * Fix format sparse_retain op (#66) * Initial checkin * Fix bugs * Add unit test for sparse_retain * Add example and modify test add storage cast for outputs that have non-default storage (#67) fix gpu build (#69) Fix test_sparse_retain python3 issue (#68) revert nnvm version * draft for sgd rsp rsp (#75) support sgd(rsp, rsp) support dot(csr, rsp) when rsp is full add ref to const ndarray params support sparse embedding with rsp weight' fix lint modify embedding backward to produce dense grad remove invalid_rid for rsp->dns remove previous embedding op changes pass sparse embedding test add STORAGE_TYPE_ASSIGN_CHECK remove backward storage infer * fix lint (#78) * fix lint (#79) * serial elemwise sum impl (#80) update module kvstore interface add other missing params and functions revert some interface changes revert some more changes reomve explicit casting for gradients on kvstore update Comm interface update fm example Conflicts: python/mxnet/model.py python/mxnet/ndarray.py * bug fix for initializing module with row_sparse weight (#81) * bug fix for initializing module with row_sparse weight * update log message * Sparse ndarray serialization and deserialization (#77) * Initial checkin * Add unit tests * Fix lint * Fix lint (#84) * Sgd with row_sparse weight, dns gradient (#83) * sgd rsp dns draft * support sgd_mom(rsp, dns, rsp) * update doc * remove cast storage for kv updater * code refactoring * update mshadow version (#88) * csr slice bug fix (#90) * benchmark dot code refactor (#87) * q^x6x add some code in benchmark * refactor * minor fixes * fix * lint fix * Add unit test (#91) * add unittest * minor fix * remove commented lines * change test func name * add test rsp * kvstore push row sparse (#93) * Add multi-thread cpu elemwise sum for rsps * Minor fix * Add flag to switch between serial and multi-thread kvstore push * Fix lint in sparse_ndarray.py * Revert "Fix lint in sparse_ndarray.py" This reverts commit d7225ec267a1e8c0c3c8074d25af5844ed39a14d. * Fix ndarray init in copy(ctx) * Add env var to control the flow of serial/parallel reduce * Refactor * Fix copy ndarray bug * Fix lint * Refactor * Fix windows openmp build failure (#94) * update mshadow submoduel (#95) * Revert "update mshadow submoduel (#95)" (#96) This reverts commit 1a129e4cc39514a6c7b3aa1189949969b818aec3. * Refactor sparse tensor code (#99) * Initial checkin test_sparse_ndarray passes * Fix test failure * Clean up * Clean up * Move init backend op to ndarray_utils * Fix lint * Eliminate circular dependency on headers * More refactor * Fix gpu build and consolidate Slice for dense and sparse * Clean up * More refactor * Clean up * Fix gpu build * Fix comment * fix pylint (#100) * Fix refactor sparse gpu test (#104) * Fix gpu build * Fix * Fix gpu test failure * change idx types from int32 to int64 (#101) Conflicts: python/mxnet/test_utils.py tests/python/unittest/test_sparse_operator.py update mshadow submodule fix extra quotes in test script change indptr type to int64 better err message for rsp" * revert LOG(DEBUG) change (#105) * fix undefined zeros in optimizer.py (#106) * move init dns zeros to init_op.h for kvstore to use (#107) * Refactor cast storage (#109) * Refactor cast_storage * Add cast_storage cc and cu files * Remove redundant comments * Replace std::accumulate with ParallelAccumulate * Clean up * Fix windows build * Rowsparse kv (#111) * update kvstore unit test Conflicts: tests/python/unittest/test_kvstore.py update model/module.py Conflicts: python/mxnet/model.py python/mxnet/module/module.py fix lint resolve conflict remove int keys in kvstore update cast to str function * fix failed dist_sync_kv test * bug fix in comm to ensure merged gradient is of the right type bug fix in comm * row sparse dist kvstore draft (push only) row_sparse pull * add ndarray row sparse shared mem constructor * code refactoring * add test for row_sparse weight bug fix for kv server slicing add async support rsolve race condition in kvstore * resolve error after reb ase * fix lint (#113) * rename some python funciton (#114) * _to_rsp * _to_csr. raise NotImplementedError * todense * fix lint (#115) enable libsvm uniit test (#6839) remove shared mem slice for csr add csr ndarray iter test make osx nose test verbose disable libsvm iter test Move InferAttr to mxnet from nnvm (#6830) * Move InferAttr to mxnet from nnvm Replace nnvm infer attr functions in c_api Initial checkin Clean up Remove nnvm namespace for FInferShape, FInferType, and FInferStorageType Add new interface for InferStorageType Revert "Remove nnvm namespace for FInferShape, FInferType, and FInferStorageType" This reverts commit 8aedf054bfe29b076c6fcb6f54d996fd2752e4de. Fix and clean up Fix lint Add nnvm changes Change infer function interface to accept only rvalue reference of graph Clean up Flush commits to show up in PR Add error handling for storage type inference failure Update nnvm * Fix pylint Change idx type switch for aux data (#6860) * Change idx type switch for aux data * Add mshadow commit Sparse dot enhancement (#6842) * Initial checkin Initial checkin Fix sparse dot test Fix unitest and add fallback for sparse dot * Add benchmark code * Revert "Add benchmark code" This reverts commit be009fe4c5a2a321aa92e99ac6e9cc511198c742. * Fix bug * Fix storage shape * Remove unnecessary test code * Use idx type switch Implement dot(csr, rsp)=dns and dot(csr.T, rsp)=rsp and refactor (#6902) * Initial checkin Add dot(csr.T, rsp)=rsp2 Add infer storage for dot(csr, rsp)=dns and dot(csr.T, rsp)=rsp2 * Fix comments * Replace std::lower_bound with own impl for gpu use too * Add time profiling * Revert "Add time profiling" This reverts commit 8f5bb982867731df0305148b1b150b05661f8529. * Move dot and batch_dot to a single file * Move dot gpu impl to a .cuh file * More refactor * Fix include error LibsvmIter fix (#6898) * fix bug in libsvm iter which causes mem corruption * add test for news dataset * fix wrong path in test * fix import error for urllib * update url * replace bz command with bz module Optimized gpu dot kernels (#6937) * pulled update to mshadow * mshadow update * added optimized gpu kernels for dot(csr,dns)=dns and dot(csr.T,dns)=dns, and unit test * added __syncwarp to vector kernel and reduced number of writes to shared memory Refactor sparse tensor code (#6955) * Save stype in frontend to avoid c-api call for stype * Change storage_type to stype * Revert "Change storage_type to stype" This reverts commit 90db7d18b624f3ee4ffd37bf5680205e77ca2763. * Revert "Revert "Change storage_type to stype"" This reverts commit 09328382e926b92a42ba5b3df6f169f825975d88. Move ndarray.py, sparse_ndarray.py, ndarray_utils.py, and _ndarray_internal to ndarrary folder More refactor Move elementwise sum for rsp to ndarray_function.cc Remove unnecessary import in ndarray module Fix pylint Remove redundant code Remove _stype from slots Fix cpp-package build error caused by the change to imperative invoke interface Use relative import Remove print line Rename _ndarray_internal.py to _internal.py * Relaunch test... minor bug fix in warp synchronous code (#7029) --- benchmark/python/sparse_op.py | 228 +++++ benchmark/python/util.py | 33 + include/mxnet/c_api.h | 177 +++- include/mxnet/executor.h | 1 + include/mxnet/ndarray.h | 454 ++++++++- include/mxnet/op_attr_types.h | 18 +- include/mxnet/storage.h | 4 +- mshadow | 2 +- nnvm | 2 +- python/mxnet/__init__.py | 3 +- python/mxnet/_ctypes/ndarray.py | 24 +- python/mxnet/contrib/autograd.py | 1 + python/mxnet/executor.py | 5 +- python/mxnet/image.py | 6 +- python/mxnet/io.py | 5 +- python/mxnet/kvstore.py | 9 +- python/mxnet/model.py | 9 +- python/mxnet/module/module.py | 12 +- python/mxnet/ndarray/__init__.py | 12 + .../_internal.py} | 0 python/mxnet/{ => ndarray}/ndarray.py | 385 ++------ python/mxnet/ndarray/ndarray_utils.py | 99 ++ python/mxnet/ndarray/op.py | 189 ++++ python/mxnet/ndarray/sparse_ndarray.py | 628 ++++++++++++ python/mxnet/optimizer.py | 8 +- python/mxnet/random.py | 14 +- python/mxnet/symbol.py | 51 +- python/mxnet/test_utils.py | 94 +- python/setup.py | 2 +- src/c_api/c_api.cc | 71 ++ src/c_api/c_api_common.h | 2 + src/c_api/c_api_executor.cc | 32 +- src/c_api/c_api_ndarray.cc | 132 ++- src/c_api/c_api_symbolic.cc | 5 +- src/c_api/c_predict_api.cc | 3 +- src/common/utils.cc | 23 + src/common/utils.cu | 21 + src/common/utils.h | 165 +++- src/executor/attach_op_execs_pass.cc | 112 ++- src/executor/exec_pass.h | 51 +- src/executor/graph_executor.cc | 321 ++++-- src/executor/graph_executor.h | 8 +- src/executor/infer_graph_attr_pass.cc | 337 +++++++ src/executor/inplace_addto_detect_pass.cc | 2 + src/io/inst_vector.h | 1 + src/io/iter_batchloader.h | 17 +- src/io/iter_libsvm.cc | 258 +++++ src/io/iter_prefetcher.h | 32 +- src/io/iter_sparse.h | 27 + src/io/iter_sparse_batchloader.h | 185 ++++ src/io/iter_sparse_prefetcher.h | 135 +++ src/kvstore/comm.h | 179 +++- src/kvstore/kvstore_dist.h | 234 ++++- src/kvstore/kvstore_dist_server.h | 208 +++- src/kvstore/kvstore_local.h | 8 +- src/ndarray/ndarray.cc | 315 +++++- src/ndarray/ndarray_function-inl.h | 61 +- src/ndarray/ndarray_function.cc | 130 +++ src/ndarray/ndarray_function.h | 9 + src/operator/elemwise_op_common.h | 63 ++ src/operator/mxnet_op.h | 14 +- src/operator/operator_common.h | 61 ++ src/operator/optimizer_op-inl.h | 368 +++++++ src/operator/optimizer_op.cc | 8 + src/operator/optimizer_op.cu | 6 +- src/operator/tensor/cast_storage-inl.cuh | 26 + src/operator/tensor/cast_storage-inl.h | 336 +++++++ src/operator/tensor/cast_storage.cc | 31 + src/operator/tensor/cast_storage.cu | 17 + src/operator/tensor/dot-inl.cuh | 382 ++++++++ src/operator/tensor/dot-inl.h | 925 ++++++++++++++++++ src/operator/tensor/dot.cc | 114 +++ src/operator/tensor/dot.cu | 27 + .../elemwise_binary_broadcast_op_basic.cc | 1 + src/operator/tensor/elemwise_binary_op.h | 157 ++- .../tensor/elemwise_binary_op_basic.cc | 7 +- .../tensor/elemwise_binary_op_basic.cu | 7 +- src/operator/tensor/elemwise_unary_op.cc | 3 + src/operator/tensor/elemwise_unary_op.cu | 4 +- src/operator/tensor/elemwise_unary_op.h | 58 +- src/operator/tensor/indexing_op.cc | 83 ++ src/operator/tensor/indexing_op.cu | 6 + src/operator/tensor/indexing_op.h | 295 ++++++ src/operator/tensor/init_op.cc | 1 + src/operator/tensor/init_op.cu | 3 +- src/operator/tensor/init_op.h | 69 +- src/operator/tensor/matrix_op-inl.h | 449 ++------- src/operator/tensor/matrix_op.cc | 94 +- src/operator/tensor/matrix_op.cu | 12 - .../ci_build/install/ubuntu_install_python.sh | 4 +- tests/cpp/include/test_ndarray_utils.h | 115 +++ tests/cpp/operator/batchnorm_test.cc | 6 +- tests/cpp/operator/ndarray_test.cc | 6 + tests/cpp/unittest.mk | 2 +- tests/nightly/dist_sync_kvstore.py | 81 +- tests/python/gpu/test_operator_gpu.py | 1 + tests/python/unittest/test_infer_shape.py | 18 + tests/python/unittest/test_io.py | 98 ++ tests/python/unittest/test_kvstore.py | 49 +- tests/python/unittest/test_module.py | 95 +- .../python/unittest/test_multi_device_exec.py | 31 + tests/python/unittest/test_ndarray.py | 1 + tests/python/unittest/test_optimizer.py | 132 ++- tests/python/unittest/test_sparse_ndarray.py | 395 ++++++++ tests/python/unittest/test_sparse_operator.py | 214 ++++ tests/travis/run_test.sh | 20 +- tests/travis/setup.sh | 4 +- 107 files changed, 9160 insertions(+), 1298 deletions(-) create mode 100644 benchmark/python/sparse_op.py create mode 100644 benchmark/python/util.py create mode 100644 python/mxnet/ndarray/__init__.py rename python/mxnet/{_ndarray_internal.py => ndarray/_internal.py} (100%) rename python/mxnet/{ => ndarray}/ndarray.py (87%) create mode 100644 python/mxnet/ndarray/ndarray_utils.py create mode 100644 python/mxnet/ndarray/op.py create mode 100644 python/mxnet/ndarray/sparse_ndarray.py create mode 100644 src/common/utils.cc create mode 100644 src/common/utils.cu create mode 100644 src/executor/infer_graph_attr_pass.cc create mode 100644 src/io/iter_libsvm.cc create mode 100644 src/io/iter_sparse.h create mode 100644 src/io/iter_sparse_batchloader.h create mode 100644 src/io/iter_sparse_prefetcher.h create mode 100644 src/operator/tensor/cast_storage-inl.cuh create mode 100644 src/operator/tensor/cast_storage-inl.h create mode 100644 src/operator/tensor/cast_storage.cc create mode 100644 src/operator/tensor/cast_storage.cu create mode 100644 src/operator/tensor/dot-inl.cuh create mode 100644 src/operator/tensor/dot-inl.h create mode 100644 src/operator/tensor/dot.cc create mode 100644 src/operator/tensor/dot.cu create mode 100644 tests/cpp/include/test_ndarray_utils.h create mode 100644 tests/cpp/operator/ndarray_test.cc create mode 100644 tests/python/unittest/test_sparse_ndarray.py create mode 100644 tests/python/unittest/test_sparse_operator.py diff --git a/benchmark/python/sparse_op.py b/benchmark/python/sparse_op.py new file mode 100644 index 000000000000..15ca4df1be73 --- /dev/null +++ b/benchmark/python/sparse_op.py @@ -0,0 +1,228 @@ +import ctypes + +from mxnet.test_utils import * +import scipy.sparse as sp +import os +import time +import argparse + +from mxnet.base import check_call, _LIB +from util import get_data, estimate_density + +parser = argparse.ArgumentParser(description="Benchmark sparse operators", + formatter_class=argparse.ArgumentDefaultsHelpFormatter) +parser.add_argument('--num-omp-threads', type=int, default=1, help='number of omp threads to set in MXNet') +args = parser.parse_args() + +# some data information +kdda = { + 'data_mini': 'kdda.t.mini', + 'data_name': 'kdda.t', + 'data_origin_name': 'kdda.t.bz2', + 'url': "https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary/kdda.t.bz2", + 'feature_dim': 20216830, + 'm': 200, + 'batch_size': [64] +} + +avazu = { + 'data_mini': 'avazu-app.t.mini', + 'data_name': 'avazu-app.t', + 'data_origin_name': 'avazu-app.t.bz2', + 'url': "https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary/avazu-app.t.bz2", + 'feature_dim': 1000000, + 'm': 500, + 'batch_size': [64, 128] +} + + +def measure_cost(repeat, f, *args, **kwargs): + # start bench + start = time.time() + results = [] + for i in range(repeat): + results.append(f(*args, **kwargs)) + for result in results: + result.wait_to_read() + end = time.time() + diff = end - start + return diff / repeat + + +def test_dot_real(data_dict): + def get_iter(path, data_shape, batch_size): + data_train = mx.io.LibSVMIter(data_libsvm=path, + data_shape=data_shape, + batch_size=batch_size) + data_iter = iter(data_train) + return data_iter + + data_dir = os.path.join(os.getcwd(), 'data') + + path = os.path.join(data_dir, data_dict['data_name']) + if not os.path.exists(path): + get_data( + data_dir, + data_dict['data_name'], + data_dict['url'], + data_dict['data_origin_name'] + ) + assert os.path.exists(path) + + k = data_dict['feature_dim'] + m = data_dict['m'] + density = estimate_density(path, data_dict['feature_dim']) + + mini_path = os.path.join(data_dir, data_dict['data_mini']) + if not os.path.exists(mini_path): + os.system("head -n 2000 %r > %r" % (path, mini_path)) + assert os.path.exists(mini_path) + + print "Running Benchmarking on %r data" % data_dict['data_mini'] + for batch_size in data_dict['batch_size']: # iterator through different batch size of choice + print "batch_size is %d" % batch_size + # model + data_shape = (k, ) + train_iter = get_iter(mini_path, data_shape, batch_size) + weight = mx.nd.random_uniform(low=0, high=1, shape=(k, m)) + + csr_data = [] + dns_data = [] + num_batch = 0 + for batch in train_iter: + data = train_iter.getdata() + csr_data.append(data) + dns_data.append(data.todense()) + num_batch += 1 + bag_of_data = [csr_data, dns_data] + num_repeat = 5 + costs = [] + for d in bag_of_data: + weight.wait_to_read() + cost = 0. + count = 0 + for d_batch in d: + d_batch.wait_to_read() + cost += measure_cost(num_repeat, mx.nd.dot, d_batch, weight) + count += 1 + costs.append(cost/count) + t_sparse = costs[0] + t_dense = costs[1] + ratio = t_dense / t_sparse + print('density(%)\tn\tm\tk\tt_dense/t_sparse\tt_dense\tt_sparse') + fmt = "%0.4f\t\t%d\t%d\t%d\t%0.2f\t\t\t%0.4f\t%0.6f" + print(fmt % (density * 100, batch_size, m, k, ratio, t_dense, t_sparse)) + + +def test_dot_synthetic(): + """benchmark mx.nd.dot(sparse_ndarray, dense_ndarray) with given density. + `t_sparse` is the time cost of dot(csr, dns), while `t_dense` is the time cost + of dot(dns, dns), with the same matrix except that it is in default storage type. + """ + def measure_cost_forward_baseline(repeat, dot, lhs, rhs): + start = time.time() + for i in range(repeat): + dot(lhs, rhs) + end = time.time() + diff = end - start + return diff / repeat + + def measure_cost_backward_baseline(repeat, dot, transpose, lhs, rhs): + start = time.time() + for i in range(repeat): + dot(transpose(lhs), rhs) + end = time.time() + diff = end - start + return diff / repeat + + def bench_dot_forward(m, k, n, density, ctx, repeat): + set_default_context(ctx) + dns = mx.nd.random_uniform(shape=(k, n)).copyto(ctx) + data_shape = (m, k) + csr_data = rand_ndarray(data_shape, 'csr', density) + dns_data = csr_data.todense() + rhs_dns_np = dns.asnumpy() + lhs_csr_sp = sp.csr_matrix(dns_data.asnumpy()) # csr in scipy + lhs_dns_np = lhs_csr_sp.todense() + + data = [dns_data, csr_data] + costs = [] + for d in data: + dns.wait_to_read() + d.wait_to_read() + cost = measure_cost(repeat, mx.nd.dot, d, dns) + costs.append(cost) + ratio = costs[0] / costs[1] + + costs_baseline = [] + cost = measure_cost_forward_baseline(repeat, np.dot, lhs_dns_np, rhs_dns_np) + costs_baseline.append(cost) + cost = measure_cost_forward_baseline(repeat, sp.spmatrix.dot, lhs_csr_sp, rhs_dns_np) + costs_baseline.append(cost) + ratio_baseline = costs_baseline[0] / costs_baseline[1] + fmt = "%0.1f\t\t%s\t%d\t%d\t%d\t%0.2f\t\t\t%0.2f\t%0.5f\t\t%0.2f\t\t\t\t%0.6f\t%0.5f" + print(fmt % (density * 100, str(ctx), n, m, k, ratio, costs[0], costs[1], + ratio_baseline, costs_baseline[0], costs_baseline[1])) + + def bench_dot_backward(m, k, n, density, ctx, repeat): + set_default_context(ctx) + dns = mx.nd.random_uniform(shape=(m, n)).copyto(ctx) + data_shape = (m, k) + csr_data = rand_ndarray(data_shape, 'csr', density) + dns_data = csr_data.todense() + rhs_dns_np = dns.asnumpy() + lhs_csr_sp = sp.csr_matrix(dns_data.asnumpy()) + lhs_dns_np = lhs_csr_sp.todense() + + data = [dns_data, csr_data] + costs = [] + for d in data: + dns.wait_to_read() + d.wait_to_read() + cost = measure_cost(repeat, mx.nd.dot, d, dns, transpose_a=True) + costs.append(cost) + ratio = costs[0] / costs[1] + + costs_baseline = [] + cost = measure_cost_backward_baseline(repeat, np.dot, np.transpose, lhs_dns_np, rhs_dns_np) + costs_baseline.append(cost) + cost = measure_cost_backward_baseline(repeat, sp.spmatrix.dot, sp.spmatrix.transpose, lhs_csr_sp, rhs_dns_np) + costs_baseline.append(cost) + ratio_baseline = costs_baseline[0] / costs_baseline[1] + fmt = "%0.1f\t\t%s\t%d\t%d\t%d\t%0.2f\t\t\t%0.2f\t%0.5f\t\t%0.2f\t\t\t\t%0.6f\t%0.5f" + print(fmt % (density * 100, str(ctx), n, m, k, ratio, costs[0], costs[1], + ratio_baseline, costs_baseline[0], costs_baseline[1])) + + print("A = sparse NDArray of shape(m, k)") + print("B = dense NDArray of shape(k, n)") + print("dot_forward\tdot(csr, dns)") + print('density(%)\tcontext\tn\tm\tk\tt_dense/t_sparse\tt_dense\tt_sparse' + '\tt_scipy_dense/t_scipy_sparse\tt_scipy_dense\tt_scipy_sparse') + + check_call(_LIB.MXSetNumOMPThreads(ctypes.c_int(args.num_omp_threads))) + # TODO(haibin) make these runtime options + m = 512 + k = [50000, 100000] + n = [64, 128] + density = [1.00, 0.90, 0.70, 0.50, 0.30, 0.20, 0.10, 0.07, 0.05, 0.02, 0.01, 0.005, 0.001] + num_repeat = 10 + # contexts = [mx.cpu(), mx.gpu(0)] + contexts = [mx.cpu()] + for i in range(2): + for ctx in contexts: + for den in density: + bench_dot_forward(m, k[i], n[i], den, ctx, num_repeat) + + print("dot_backward\tdot(csr.T, dns)") + print('density(%)\tcontext\tn\tm\tk\tt_dense/t_sparse\tt_dense\tt_sparse' + '\tt_scipy_dense/t_scipy_sparse\tt_scipy_dense\tt_scipy_sparse') + for i in range(2): + for ctx in contexts: + for den in density: + bench_dot_backward(m, k[i], n[i], den, ctx, num_repeat) + + +if __name__ == "__main__": + test_dot_real(avazu) + test_dot_real(kdda) + test_dot_synthetic() diff --git a/benchmark/python/util.py b/benchmark/python/util.py new file mode 100644 index 000000000000..86e67d0f8a20 --- /dev/null +++ b/benchmark/python/util.py @@ -0,0 +1,33 @@ +import os +import random + + +def get_data(data_dir, data_name, url, data_origin_name): + if not os.path.isdir(data_dir): + os.system("mkdir " + data_dir) + os.chdir(data_dir) + if (not os.path.exists(data_name)): + import urllib + zippath = os.path.join(data_dir, data_origin_name) + urllib.urlretrieve(url, zippath) + os.system("bzip2 -d %r" % data_origin_name) + os.chdir("..") + + +def estimate_density(DATA_PATH, feature_size): + """sample 10 times of a size of 1000 for estimating the density of the sparse dataset""" + if not os.path.exists(DATA_PATH): + raise Exception("Data is not there!") + density = [] + P = 0.01 + for _ in xrange(10): + num_non_zero = 0 + num_sample = 0 + with open(DATA_PATH) as f: + for line in f: + if (random.random() < P): + num_non_zero += len(line.split(" ")) - 1 + num_sample += 1 + density.append(num_non_zero * 1.0 / (feature_size * num_sample)) + return sum(density) / len(density) + diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h index 47447fb37196..d7811d8a1b60 100644 --- a/include/mxnet/c_api.h +++ b/include/mxnet/c_api.h @@ -246,6 +246,38 @@ MXNET_DLL int MXNDArrayCreateEx(const mx_uint *shape, int delay_alloc, int dtype, NDArrayHandle *out); + + +/*! + * \brief create an empty sparse NDArray with specified shape and data type + * \param storage_type the storage type of the ndarray + * \param shape the pointer to the shape + * \param ndim the dimension of the shape + * \param dev_type device type, specify device we want to take + * \param dev_id the device id of the specific device + * \param delay_alloc whether to delay allocation until + * the narray is first mutated + * \param dtype data type of created array + * \param num_aux the number of aux data to support this ndarray + * \param aux_type data type of the aux data for the created array + * \param aux_ndims the dimension of the shapes of aux data + * \param aux_shape the shapes of aux data + * \param out the returning handle + * \return 0 when success, -1 when failure happens + */ +MXNET_DLL int MXNDArrayCreateSparseEx(int storage_type, + const mx_uint *shape, + mx_uint ndim, + int dev_type, + int dev_id, + int delay_alloc, + int dtype, + mx_uint num_aux, + int *aux_type, + mx_uint *aux_ndims, + const mx_uint *aux_shape, + NDArrayHandle *out); + /*! * \brief create a NDArray handle that is loaded from raw bytes. * \param buf the head of the raw bytes @@ -358,6 +390,7 @@ MXNET_DLL int MXNDArraySlice(NDArrayHandle handle, mx_uint slice_begin, mx_uint slice_end, NDArrayHandle *out); + /*! * \brief Index the NDArray along axis 0. * \param handle the handle to the NDArray @@ -368,6 +401,13 @@ MXNET_DLL int MXNDArraySlice(NDArrayHandle handle, MXNET_DLL int MXNDArrayAt(NDArrayHandle handle, mx_uint idx, NDArrayHandle *out); + +/*! + * \brief get the storage type of the array + */ +MXNET_DLL int MXNDArrayGetStorageType(NDArrayHandle handle, + int *out_storage_type); + /*! * \brief Reshape the NDArray. * \param handle the handle to the narray @@ -406,6 +446,26 @@ MXNET_DLL int MXNDArrayGetData(NDArrayHandle handle, */ MXNET_DLL int MXNDArrayGetDType(NDArrayHandle handle, int *out_dtype); + +/*! + * \brief get the type of the ith aux data in NDArray + * \param handle the handle to the narray + * \param i the index of the aux data + * \param out_type pointer holder to get type of aux data + * \return 0 when success, -1 when failure happens + */ +MXNET_DLL int MXNDArrayGetAuxType(NDArrayHandle handle, + mx_uint i, + int *out_type); + +// Get the ith aux data blob wrapped in an NDArray +MXNET_DLL int MXNDArrayGetAuxNDArray(NDArrayHandle handle, + mx_uint i, + NDArrayHandle *out); + +// Get the data blob wrapped in an NDArray +MXNET_DLL int MXNDArrayGetDataNDArray(NDArrayHandle handle, + NDArrayHandle *out); /*! * \brief get the context of the NDArray * \param handle the handle to the narray @@ -551,6 +611,28 @@ MXNET_DLL int MXImperativeInvoke(AtomicSymbolCreator creator, int num_params, const char **param_keys, const char **param_vals); +/*! + * \brief invoke a nnvm op and imperative function + * \param creator the op + * \param num_inputs number of input NDArrays + * \param inputs input NDArrays + * \param num_outputs number of output NDArrays + * \param outputs output NDArrays + * \param num_params number of keyword parameters + * \param param_keys keys for keyword parameters + * \param param_vals values for keyword parameters + * \param out_stypes output ndarrays' stypes + * \return 0 when success, -1 when failure happens + */ +MXNET_DLL int MXImperativeInvokeEx(AtomicSymbolCreator creator, + int num_inputs, + NDArrayHandle *inputs, + int *num_outputs, + NDArrayHandle **outputs, + int num_params, + const char **param_keys, + const char **param_vals, + const int **out_stypes); /*! * \brief set whether to record operator for autograd * \param is_train 1 when training, 0 when testing @@ -948,20 +1030,20 @@ MXNET_DLL int MXSymbolInferShape(SymbolHandle sym, * \return 0 when success, -1 when failure happens */ MXNET_DLL int MXSymbolInferShapePartial(SymbolHandle sym, - mx_uint num_args, - const char** keys, - const mx_uint *arg_ind_ptr, - const mx_uint *arg_shape_data, - mx_uint *in_shape_size, - const mx_uint **in_shape_ndim, - const mx_uint ***in_shape_data, - mx_uint *out_shape_size, - const mx_uint **out_shape_ndim, - const mx_uint ***out_shape_data, - mx_uint *aux_shape_size, - const mx_uint **aux_shape_ndim, - const mx_uint ***aux_shape_data, - int *complete); + mx_uint num_args, + const char** keys, + const mx_uint *arg_ind_ptr, + const mx_uint *arg_shape_data, + mx_uint *in_shape_size, + const mx_uint **in_shape_ndim, + const mx_uint ***in_shape_data, + mx_uint *out_shape_size, + const mx_uint **out_shape_ndim, + const mx_uint ***out_shape_data, + mx_uint *aux_shape_size, + const mx_uint **aux_shape_ndim, + const mx_uint ***aux_shape_data, + int *complete); /*! * \brief infer type of unknown input types given the known one. @@ -992,6 +1074,10 @@ MXNET_DLL int MXSymbolInferType(SymbolHandle sym, mx_uint *aux_type_size, const int **aux_type_data, int *complete); + + + + //-------------------------------------------- // Part 4: Executor interface //-------------------------------------------- @@ -1140,36 +1226,39 @@ MXNET_DLL int MXExecutorBindEX(SymbolHandle symbol_handle, ExecutorHandle *out); MXNET_DLL int MXExecutorSimpleBind(SymbolHandle symbol_handle, - int dev_type, - int dev_id, - const mx_uint num_g2c_keys, - const char** g2c_keys, - const int* g2c_dev_types, - const int* g2c_dev_ids, - const mx_uint provided_grad_req_list_len, - const char** provided_grad_req_names, - const char** provided_grad_req_types, - const mx_uint num_provided_arg_shapes, - const char** provided_arg_shape_names, - const mx_uint* provided_arg_shape_data, - const mx_uint* provided_arg_shape_idx, - const mx_uint num_provided_arg_dtypes, - const char** provided_arg_dtype_names, - const int* provided_arg_dtypes, - const mx_uint num_shared_arg_names, - const char** shared_arg_name_list, - int* shared_buffer_len, - const char** shared_buffer_name_list, - NDArrayHandle* shared_buffer_handle_list, - const char*** updated_shared_buffer_name_list, - NDArrayHandle** updated_shared_buffer_handle_list, - mx_uint* num_in_args, - NDArrayHandle** in_args, - NDArrayHandle** arg_grads, - mx_uint* num_aux_states, - NDArrayHandle** aux_states, - ExecutorHandle shared_exec_handle, - ExecutorHandle* out); + int dev_type, + int dev_id, + const mx_uint num_g2c_keys, + const char** g2c_keys, + const int* g2c_dev_types, + const int* g2c_dev_ids, + const mx_uint provided_grad_req_list_len, + const char** provided_grad_req_names, + const char** provided_grad_req_types, + const mx_uint num_provided_arg_shapes, + const char** provided_arg_shape_names, + const mx_uint* provided_arg_shape_data, + const mx_uint* provided_arg_shape_idx, + const mx_uint num_provided_arg_dtypes, + const char** provided_arg_dtype_names, + const int* provided_arg_dtypes, + const mx_uint num_provided_arg_stypes, + const char** provided_arg_stype_names, + const int* provided_arg_stypes, + const mx_uint num_shared_arg_names, + const char** shared_arg_name_list, + int* shared_buffer_len, + const char** shared_buffer_name_list, + NDArrayHandle* shared_buffer_handle_list, + const char*** updated_shared_buffer_name_list, + NDArrayHandle** updated_shared_buffer_handle_list, + mx_uint* num_in_args, + NDArrayHandle** in_args, + NDArrayHandle** arg_grads, + mx_uint* num_aux_states, + NDArrayHandle** aux_states, + ExecutorHandle shared_exec_handle, + ExecutorHandle* out); /*! * \brief set a call back to notify the completion of operation */ diff --git a/include/mxnet/executor.h b/include/mxnet/executor.h index 40bd60f5f405..5856b87cf859 100644 --- a/include/mxnet/executor.h +++ b/include/mxnet/executor.h @@ -115,6 +115,7 @@ class Executor { const std::vector& aux_state_ctxes, const std::unordered_map& arg_shape_map, const std::unordered_map& arg_dtype_map, + const std::unordered_map& arg_stype_map, const std::vector& grad_req_types, const std::unordered_set& param_names, std::vector* in_args, diff --git a/include/mxnet/ndarray.h b/include/mxnet/ndarray.h index e349b3091c56..3a02972cab06 100644 --- a/include/mxnet/ndarray.h +++ b/include/mxnet/ndarray.h @@ -29,7 +29,6 @@ namespace mxnet { -// forward declaration namespace autograd { class AGNode; @@ -53,6 +52,23 @@ class AGNodeEntry { class AutogradRuntime; } // namespace autograd +// enum for storage types +namespace csr { +enum CSRAuxType {kIndPtr, kIdx}; +} + +namespace rowsparse { +enum RowSparseAuxType {kIdx}; +} + +enum NDArrayStorageType { + kUndefinedStorage = -1, // undefined storage + kDefaultStorage, // dense + kRowSparseStorage, // row sparse + kCSRStorage, // csr +}; + + /*! * \brief ndarray interface */ @@ -73,10 +89,55 @@ class NDArray { */ NDArray(const TShape &shape, Context ctx, bool delay_alloc = false, int dtype = mshadow::default_type_flag) - : ptr_(std::make_shared(shape.Size(), ctx, delay_alloc, dtype)), + : ptr_(std::make_shared(shape, ctx, delay_alloc, dtype)), shape_(shape), dtype_(dtype), entry_({nullptr, 0, 0}) { #if MKL_EXPERIMENTAL == 1 Mkl_mem_ = std::make_shared(); +#endif + } + /*! \brief constructor for NDArray with storage type + */ + NDArray(const NDArrayStorageType stype, const TShape &shape, Context ctx, + bool delay_alloc = true, int dtype = mshadow::default_type_flag, + std::vector aux_types = {}, std::vector aux_shapes = {}, + TShape storage_shape = TShape(mshadow::Shape1(0))) + : shape_(shape), dtype_(dtype), entry_({nullptr, 0, 0}) { + // Assign default aux types if not given + if (aux_types.size() == 0) { + if (stype == kRowSparseStorage) { + aux_types = {mshadow::kInt64}; + } else if (stype == kCSRStorage) { + aux_types = {mshadow::kInt64, mshadow::kInt64}; + } else { + LOG(FATAL) << "Unknown storage type " << stype; + } + } + // Assign default shapes if not given + // unknown shapes are intialized as {0} such that Size() would return 0 + if (aux_shapes.size() == 0) { + if (stype == kRowSparseStorage) { + aux_shapes = {TShape(mshadow::Shape1(0))}; + } else if (stype == kCSRStorage) { + // aux shapes for indptr and indices + aux_shapes = {TShape(mshadow::Shape1(0)), TShape(mshadow::Shape1(0))}; + } else { + LOG(FATAL) << "Unknown storage type " << stype; + } + } + if (storage_shape.Size() == 0) { + if (stype == kRowSparseStorage) { + storage_shape = shape; + storage_shape[0] = aux_shapes[rowsparse::kIdx][0]; + } else if (stype == kCSRStorage) { + storage_shape = aux_shapes[csr::kIdx]; + } else { + LOG(FATAL) << "Unknown storage type " << stype; + } + } + ptr_ = std::make_shared(stype, storage_shape, ctx, delay_alloc, + dtype, aux_types, aux_shapes); +#if MKL_EXPERIMENTAL == 1 + Mkl_mem_ = std::make_shared(); #endif } /*! @@ -85,25 +146,94 @@ class NDArray { * make sure the memory region is available through out the life of NDArray * \param data the memory content of static data * \param dev_id the device id this tensor sits at + * \param shared_var the same var handle shared with others. + It will not be deleted during destruction. + */ + NDArray(const TBlob &data, int dev_id, Engine::VarHandle shared_var = nullptr) + : ptr_(std::make_shared(data, dev_id, shared_var)), shape_(data.shape_), + dtype_(data.type_flag_), entry_({nullptr, 0, 0}) { +#if MKL_EXPERIMENTAL == 1 + Mkl_mem_ = std::make_shared(); +#endif + } + + /*! + * \brief constructing a static NDArray of non-default storage that shares data with TBlob + * Use with caution: allocate ONLY ONE NDArray for each TBlob, + * make sure the memory region is available through out the life of NDArray + * \param stype the storage type of NDArray + * \param shape the shape of NDArray + * \param data the memory content of static data + * \param aux_data the memory content of static aux data + * \param dev_id the device id this tensor sits at + * \param shared_var the same var handle shared with others. + It will not be deleted during destruction. */ - NDArray(const TBlob &data, int dev_id) - : ptr_(std::make_shared(data, dev_id)), shape_(data.shape_), + NDArray(const NDArrayStorageType stype, const TShape &shape, + const TBlob &data, const std::vector &aux_data, int dev_id) + : ptr_(std::make_shared(stype, data, aux_data, dev_id)), shape_(shape), dtype_(data.type_flag_), entry_({nullptr, 0, 0}) { #if MKL_EXPERIMENTAL == 1 Mkl_mem_ = std::make_shared(); #endif } + + /*! - * \return the shape of current NDArray + * \return the shape of current NDArray. */ inline const TShape& shape() const { return shape_; } + /*! + * \return the shape of underlying chunk which stores the NDArray values. + * For default storage, it is the same as shape(). For row-sparse storage, it is the shape of + * the tensor which stores the non-zero values. + */ + inline const TShape &storage_shape() const { + CHECK(ptr_ != nullptr); + return ptr_->storage_shape; + } + + /*! + * \brief For sparse operations, the storage shape is an estimated value + * in the beginning for allocating enough capacity for the final result. + * After the operation is done, the exact size of the shape is known + * and need to be reset using this function. For example, adding + * two CSRs with nnz1 and nnz2 as their numbers of non-zero values, respectively, + * would allocate the array of size nnz1+nnz2 first and get the final + * nnz that is smaller than nnz1+nnz2. Therefore, the storage shape's size + * needs to be shrunk from nnz1+nnz2 to nnz. + */ + inline void set_storage_shape(const TShape& sshape) { + CHECK(storage_type() != kDefaultStorage); + ptr_->storage_shape = sshape; + } + + /*! + * \return the shape of aux data at ith index. If it doesn't exist, return an empty one. + */ + inline const TShape aux_shape(size_t i) const { + CHECK(storage_type() != kDefaultStorage); + return ptr_->aux_shapes[i]; + } + + /*! + * \brief For a sparse operation on a csr matrix for example, + * the size of the column index array + * is an estimated value in the beginning for allocating enough capacity + * for the final result. After the operation is done, the exact size of + * the shape is known and need to be reset using this function. + */ + inline void set_aux_shape(size_t i, const TShape& shape) const { + ptr_->aux_shapes[i] = shape; + } + /*! * \return the data TBlob */ inline const TBlob& data() const { - CheckAndAlloc(); + if (storage_type() == kDefaultStorage) CheckAndAlloc(); SetTBlob(); return tblob_; } @@ -111,6 +241,26 @@ class NDArray { * \return the gradient ndarray. */ NDArray grad() const; + + /*! + * \return the aux TBlob + */ + inline TBlob aux_data(size_t i) const { + auto stype = storage_type(); + TBlob res; + auto shape = aux_shape(i); + auto type = aux_type(i); + MSHADOW_TYPE_SWITCH(type, DType, { + auto dptr = static_cast(ptr_->aux_handles[i].dptr); + CHECK(stype == kRowSparseStorage || stype == kCSRStorage) + << "Unexpected storage type: " << stype; + res = TBlob(dptr, shape, ptr_->aux_handles[i].ctx.dev_mask(), type); + }); +#if MKL_EXPERIMENTAL == 1 + res.Mkl_mem_ = Mkl_mem_; +#endif + return res; + } /*! * \return the context of NDArray, this function is only valid when the NDArray is not empty */ @@ -123,6 +273,15 @@ class NDArray { inline int dtype() const { return dtype_; } + inline int aux_type(size_t i) const { + CHECK(!is_none()); + return ptr_->aux_types[i]; + } + + inline NDArrayStorageType storage_type() const { + if (is_none()) return kUndefinedStorage; + return ptr_->storage_type; + } /*! \return whether this ndarray is not initialized */ inline bool is_none() const { return ptr_.get() == nullptr; @@ -131,6 +290,18 @@ class NDArray { bool fresh_out_grad() const; /*! \return updated grad state in entry_ */ void set_fresh_out_grad(bool state) const; + // returns true if a sparse ndarray's aux_data and storage are initialized + inline bool storage_initialized() const { + if (is_none()) return false; + auto stype = storage_type(); + CHECK_NE(stype, kDefaultStorage); + if (stype == kRowSparseStorage || stype == kCSRStorage) { + return aux_shape(0).Size() != 0; + } else { + LOG(FATAL) << "Unknown storage type"; + } + return true; + } /*! * \brief Block until all the pending write operations with respect * to current NDArray are finished, and read can be performed. @@ -161,6 +332,12 @@ class NDArray { * \param strm the output stream */ void Save(dmlc::Stream *strm) const; + /*! + * \brief load ndarrays before supporting sparse ndarrays + * \param strm the output stream + * \param magic the magic number used for version control + */ + bool LegacyLoad(dmlc::Stream *strm, const uint32_t magic); /*! * \brief load the content from binary stream * \param strm the output stream @@ -264,17 +441,31 @@ class NDArray { void SyncCopyToCPU(void *data, size_t size) const; /*! * \brief Slice a NDArray - * \param begin begin index in first dim - * \param end end index in first dim + * \param begin begin index in first dim (inclusive) + * \param end end index in first dim (exclusive) * \return sliced NDArray */ NDArray Slice(index_t begin, index_t end) const; + /*! * \brief Index a NDArray * \param idx the index * \return idx-th sub array NDArray */ NDArray At(index_t idx) const; + // Wrap the tblob of aux data into an NDArray which shares the same variable with the + // current one. + inline const NDArray aux_ndarray(size_t i) const { + CHECK_NE(storage_type(), kDefaultStorage); + CHECK(i < ptr_->aux_shapes.size()); + return NDArray(aux_data(i), ctx().dev_id, var()); + } + // Wrap the tblob of data into an NDArray which shares the same variable with the + // current one. + inline const NDArray data_ndarray() const { + CHECK_NE(storage_type(), kDefaultStorage); + return NDArray(data(), ctx().dev_id, var()); + } /*! * \brief Create a NDArray that shares memory with current one * The new array must have smaller memory size than the current array. @@ -283,6 +474,7 @@ class NDArray { * \return NDArray in new shape and type. */ inline NDArray AsArray(const TShape &shape, int dtype) const { + CHECK_EQ(storage_type(), kDefaultStorage) << "Not implemented yet"; CHECK_GE(shape_.Size() * mshadow::mshadow_sizeof(dtype_), shape.Size() * mshadow::mshadow_sizeof(dtype)) << "NDArray.AsArray: target memory size is bigger"; @@ -316,8 +508,25 @@ class NDArray { * This is an internal function used by system that normal user should not use */ inline void CheckAndAlloc() const { + CHECK_EQ(storage_type(), kDefaultStorage); ptr_->CheckAndAlloc(); } + /* ! + * \brief Alloc memory for non-default storage + * aux_shape is only known at run time + */ + inline void CheckAndAlloc(const std::vector &aux_shapes) const { + CHECK_NE(storage_type(), kDefaultStorage); + ptr_->CheckAndAlloc(shape_, aux_shapes, dtype_); + } + inline void CheckAndAllocData(const TShape &storage_shape) const { + CHECK_NE(storage_type(), kDefaultStorage); + ptr_->CheckAndAllocData(storage_shape, dtype_); + } + inline void CheckAndAllocAuxData(size_t i, const TShape &aux_shape) const { + CHECK_NE(storage_type(), kDefaultStorage); + ptr_->CheckAndAllocAuxData(i, aux_shape); + } /*! * \brief Save list of ndarray into the Stream.x * \param fo The stream of output. @@ -340,44 +549,132 @@ class NDArray { private: friend class autograd::AutogradRuntime; /*! \brief the real data chunk that backs NDArray */ + // shandle is used to store the actual values in the NDArray + // aux_handles store the aux data(such as indices) if it's needed by non-default storage. struct Chunk { - /*! \brief storage handlefrom storage engine */ + /*! \brief storage handle from storage engine. + for non-default storage, shandle stores the data(value) array. + */ Storage::Handle shandle; + /*! \brief storage handles for aux data (e.g index) + for row_sparse, aux_handles[0] = indices + for csr, aux_handles[0] = indptr, aux_handles[1] = indices + */ + std::vector aux_handles; /*! \brief variable from engine */ Engine::VarHandle var; /*! * \brief if this is true, this means the data do not come * from Storage, and do not need to be freed */ + /*! \brief construct from static data */ bool static_data; - /*! \brief whether allocation is delayed */ + /*! \brief whether data allocation is delayed. This doesn't indicate whether aux data + allocation is delayed. */ bool delay_alloc; + // the type of the storage. The storage_type is never kUndefinedStorage once the chunk + // is constructed. + NDArrayStorageType storage_type = kDefaultStorage; + /*! \brief type of aux */ + std::vector aux_types; + // context of data + Context ctx; + // The shape of the chunk data. + // This might not be the same shape as the NDArray, since the storage may be sparse. + // The default value for storage_shape is {0} when an empty non-default NDArray is created. + TShape storage_shape; + // The shape of aux data. The default value for the shape depends on the type of storage. + // If aux_shapes[i].Size() is zero, aux data i is empty. + std::vector aux_shapes; + // \brief skip the deletion of var handle. Usually set when shared_var is present. + bool skip_delete_var = false; + /*! \brief default cosntructor */ - Chunk() : static_data(true), delay_alloc(false) { - var = Engine::Get()->NewVariable(); - } - /*! \brief construct from static data */ - Chunk(const TBlob &data, int dev_id) - : static_data(true), - delay_alloc(false) { + Chunk() : static_data(true), delay_alloc(false) {} + + /*! \brief construct a new chunk */ + Chunk(TShape shape, Context ctx_, bool delay_alloc_, int dtype) + : static_data(false), delay_alloc(true), ctx(ctx_) { + auto size = shape.Size(); + storage_shape = shape; var = Engine::Get()->NewVariable(); + shandle.size = size * mshadow::mshadow_sizeof(dtype); + shandle.ctx = ctx_; + if (!delay_alloc_) this->CheckAndAlloc(); + } + + Chunk(const TBlob &data, int dev_id, Engine::VarHandle shared_var) + : static_data(true), delay_alloc(false) { + CHECK(storage_type == kDefaultStorage); + // init var + if (shared_var == nullptr) { + var = Engine::Get()->NewVariable(); + } else { + skip_delete_var = true; + var = shared_var; + } + // init ctx if (data.dev_mask() == cpu::kDevMask) { - shandle.ctx = Context::CPU(); + ctx = Context::CPU(); } else { CHECK_EQ(data.dev_mask(), gpu::kDevMask); - shandle.ctx = Context::GPU(dev_id); + ctx = Context::GPU(dev_id); } + // init shandle + shandle.ctx = ctx; shandle.dptr = data.dptr_; shandle.size = data.shape_.Size() * mshadow::mshadow_sizeof(data.type_flag_); + storage_shape = data.shape_; } - /*! \brief construct a new chunk */ - Chunk(uint64_t size, Context ctx, bool delay_alloc_, int dtype) - : static_data(false), delay_alloc(true) { + // Constructor for a non-default storage chunk + Chunk(NDArrayStorageType storage_type_, const TShape &storage_shape_, Context ctx_, + bool delay_alloc_, int dtype, const std::vector &aux_types_, + const std::vector &aux_shapes_) + : static_data(false), delay_alloc(delay_alloc_), storage_type(storage_type_), + aux_types(aux_types_), ctx(ctx_), storage_shape(storage_shape_), + aux_shapes(aux_shapes_) { + shandle.ctx = ctx; var = Engine::Get()->NewVariable(); - shandle.size = size * mshadow::mshadow_sizeof(dtype); + // aux_handles always reflect the correct number of aux data + for (size_t i = 0; i < aux_shapes.size(); i++) { + CheckAndAllocAuxData(i, aux_shapes[i]); + } + if (!delay_alloc) { + CheckAndAllocData(storage_shape, dtype); + } + } + + Chunk(const NDArrayStorageType storage_type_, const TBlob &data, + const std::vector &aux_data, int dev_id) + : static_data(true), delay_alloc(false), storage_type(storage_type_) { + using namespace mshadow; + CHECK_NE(storage_type, kDefaultStorage); + // init var + var = Engine::Get()->NewVariable(); + // init ctx + if (data.dev_mask() == cpu::kDevMask) { + ctx = Context::CPU(); + } else { + CHECK_EQ(data.dev_mask(), gpu::kDevMask); + ctx = Context::GPU(dev_id); + } + // init shandle shandle.ctx = ctx; - if (!delay_alloc_) this->CheckAndAlloc(); + shandle.dptr = data.dptr_; + shandle.size = data.shape_.Size() * mshadow_sizeof(data.type_flag_); + storage_shape = data.shape_; + // init aux handles + for (const auto &aux : aux_data) { + Storage::Handle aux_handle; + aux_handle.ctx = ctx; + aux_handle.dptr = aux.dptr_; + aux_handle.size = aux.shape_.Size() * mshadow_sizeof(aux.type_flag_); + aux_handles.push_back(aux_handle); + aux_types.emplace_back(aux.type_flag_); + aux_shapes.emplace_back(aux.shape_); + } } + /*! \brief check if delay alloc is on, do alloc if not yet done */ inline void CheckAndAlloc(void) { if (delay_alloc) { @@ -385,22 +682,98 @@ class NDArray { delay_alloc = false; } } - /*! \brief destructor */ - ~Chunk() { - if (static_data || delay_alloc) { - Engine::Get()->DeleteVariable([](RunContext s) {}, shandle.ctx, var); + inline void CheckAndAlloc(const TShape &shape, const std::vector &aux_shapes, + int dtype) { + // calculate size, perform allocation + if (kRowSparseStorage == storage_type) { + // For row sparse, aux_shape indicates the number of rows to allocate + auto aux_shape = aux_shapes[rowsparse::kIdx]; + CHECK_EQ(shape.ndim(), 2) << "High dim RowSparse not yet implemented"; + CheckAndAllocAuxData(rowsparse::kIdx, aux_shape); + TShape storage_shape(shape); + storage_shape[0] = aux_shape[0]; + CheckAndAllocData(storage_shape, dtype); + } else if (kCSRStorage == storage_type) { + CheckAndAllocAuxData(csr::kIndPtr, aux_shapes[csr::kIndPtr]); + CheckAndAllocAuxData(csr::kIdx, aux_shapes[csr::kIdx]); + CheckAndAllocData(aux_shapes[csr::kIdx], dtype); } else { - Storage::Handle h = this->shandle; - Engine::Get()->DeleteVariable([h](RunContext s) { - Storage::Get()->Free(h); - }, shandle.ctx, var); + LOG(FATAL) << "Storage type " << storage_type << " not implemented for CheckAndAlloc"; } } - }; + // create storage handle for data based on shape and dtype, assuming ctx is set + // storage shape is also updated + // if data is already allocated, try reuse the storage. Otherwise, free the current one + // and allocate new storage + inline void CheckAndAllocData(const TShape &shape, int dtype) { + CHECK_NE(aux_shapes.size(), 0) << "data is expected to be allocated after aux_data"; + auto dbytes = shape.Size() * mshadow::mshadow_sizeof(dtype); + if (shandle.size < dbytes) { + // free storage if necessary and alloc again + if (shandle.size > 0) Storage::Get()->Free(shandle); + // init storage + shandle = Storage::Get()->Alloc(dbytes, ctx); + } + // init shape + storage_shape = shape; + // delay_alloc is only set when data storage handle is present + delay_alloc = false; + } + // create storage handle for aux data based on shape + // this function assumes ctx, aux shapes and aux types are set + // aux shape is also updated + // if aux data is already allocated, try reuse the storage. Otherwise, free the current one + // and allocate new storage + inline void CheckAndAllocAuxData(size_t i, const TShape &shape) { + CHECK_EQ(shape.ndim(), 1) << "shape must be 1D in CheckAndAllocAuxData"; + CHECK_NE(storage_type, kUndefinedStorage) + << "storage type cannot be kUndefinedStorage in CheckAndAllocAuxData"; + CHECK_NE(storage_type, kDefaultStorage) + << "storage type cannot be kDefaultStorage in CheckAndAllocAuxData"; + if (aux_handles.size() <= i) { + aux_handles.resize(i + 1); + } + size_t aux_bytes = shape.Size() * mshadow::mshadow_sizeof(aux_types[i]); + if (aux_handles[i].size < aux_bytes) { + // free storage if necessary and alloc again + if (aux_handles[i].size > 0) Storage::Get()->Free(aux_handles[i]); + // init aux storage + aux_handles[i] = Storage::Get()->Alloc(aux_bytes, ctx); + } + // init shape + aux_shapes[i] = shape; + } + /*! \brief destructor */ + ~Chunk() { + if (skip_delete_var) return; + bool skip_free = static_data || delay_alloc; + Storage::Handle h = this->shandle; + std::vector aux_h = this->aux_handles; + Engine::Get()->DeleteVariable([h, aux_h, skip_free](RunContext s) { + if (skip_free == false) { + Storage::Get()->Free(h); + for (size_t i = 0; i < aux_h.size(); i++) { + if (aux_h[i].size > 0) Storage::Get()->Free(aux_h[i]); + } + } + }, shandle.ctx, var); + } + }; // struct Chunk void SetTBlob() const { - tblob_.dptr_ = static_cast(ptr_->shandle.dptr) + byte_offset_; - tblob_.shape_ = shape_; + CHECK(ptr_ != nullptr); + TShape shape = shape_; + char *dptr = static_cast(ptr_->shandle.dptr); + auto stype = storage_type(); + if (stype == kDefaultStorage) { + dptr += byte_offset_; + } else if (stype == kCSRStorage || stype == kRowSparseStorage) { + shape = storage_shape(); + } else { + LOG(FATAL) << "unknown storage type " << stype; + } + tblob_.dptr_ = dptr; + tblob_.shape_ = shape; tblob_.type_flag_ = dtype_; tblob_.SetDLTensor(ptr_->shandle.ctx.dev_mask(), ptr_->shandle.ctx.dev_id); #if MKL_EXPERIMENTAL == 1 @@ -412,7 +785,7 @@ class NDArray { std::shared_ptr Mkl_mem_; #endif /*! \brief internal data of NDArray */ - std::shared_ptr ptr_; + std::shared_ptr ptr_{nullptr}; /*! \brief shape of current NDArray */ TShape shape_; /*! \brief byte offset in chunk */ @@ -429,7 +802,12 @@ class NDArray { * this situation. */ mutable TBlob tblob_; -}; +}; // class NDArray + +/*! + * \return the number of aux data used for given storage type + */ +size_t num_aux_data(NDArrayStorageType stype); /*! * \brief issue an copy operation from one NDArray to another @@ -439,12 +817,12 @@ class NDArray { * \param from the ndarray we want to copy data from * \param to the target ndarray * \param priority Priority of the action. + * \param alloc_output whether to allocate memory for the output ndarray * \note The function name explicitly marks the order of from and to * due to different possible convention carried by copy function. */ void CopyFromTo(const NDArray &from, NDArray *to, int priority = 0); - /*! * \brief Perform elementwise sum over each data from source, store result into out. * \param source the ndarray we want to sum diff --git a/include/mxnet/op_attr_types.h b/include/mxnet/op_attr_types.h index dbf9a07e0bcb..8cbf035888d4 100644 --- a/include/mxnet/op_attr_types.h +++ b/include/mxnet/op_attr_types.h @@ -7,7 +7,6 @@ #ifndef MXNET_OP_ATTR_TYPES_H_ #define MXNET_OP_ATTR_TYPES_H_ - #include #include @@ -207,6 +206,23 @@ using FCompute = std::function& inputs, const std::vector& req, const std::vector& outputs)>; +/*! + * \brief Resiger an NDArray compute function for simple stateless forward only operator + * + * \note Register under "FComputeEx" and "FComputeEx" + * Dispatched only when operators process non-default storage inputs or outputs + */ +using FComputeEx = std::function& inputs, + const std::vector& req, + const std::vector& outputs)>; + +using FInferStorageType = std::function* in_attrs, + std::vector* out_attrs)>; + } // namespace mxnet #endif // MXNET_OP_ATTR_TYPES_H_ diff --git a/include/mxnet/storage.h b/include/mxnet/storage.h index 1b765233947d..e236a9cf313b 100644 --- a/include/mxnet/storage.h +++ b/include/mxnet/storage.h @@ -23,11 +23,11 @@ class Storage { /*! * \brief Pointer to the data. */ - void* dptr; + void* dptr{nullptr}; /*! * \brief Size of the storage. */ - size_t size; + size_t size{0}; /*! * \brief Context information about device and ID. */ diff --git a/mshadow b/mshadow index d32b5dacf2bb..5a11d7544841 160000 --- a/mshadow +++ b/mshadow @@ -1 +1 @@ -Subproject commit d32b5dacf2bb5af4121df5fd60eb7775704f9131 +Subproject commit 5a11d7544841b55a8ac1a65081759dc2289c335d diff --git a/nnvm b/nnvm index c96dd0e126a7..d02104dca1ee 160000 --- a/nnvm +++ b/nnvm @@ -1 +1 @@ -Subproject commit c96dd0e126a788089fe700cf6effe4e87bc40e05 +Subproject commit d02104dca1eeb174a063aa06b54b774875a9106f diff --git a/python/mxnet/__init__.py b/python/mxnet/__init__.py index d878f9bb0594..d796f4c0818f 100644 --- a/python/mxnet/__init__.py +++ b/python/mxnet/__init__.py @@ -8,6 +8,7 @@ from . import base from . import contrib from . import ndarray +from . import ndarray as nd from . import name # use mx.sym as short for symbol from . import symbol as sym @@ -16,8 +17,6 @@ from . import io from . import recordio from . import operator -# use mx.nd as short for mx.ndarray -from . import ndarray as nd # use mx.rnd as short for mx.random from . import random as rnd from . import random diff --git a/python/mxnet/_ctypes/ndarray.py b/python/mxnet/_ctypes/ndarray.py index 396c57a41dfb..2a82c0899ea2 100644 --- a/python/mxnet/_ctypes/ndarray.py +++ b/python/mxnet/_ctypes/ndarray.py @@ -15,10 +15,19 @@ from ..ndarray_doc import _build_doc +_STORAGE_TYPE_ID_TO_STR = { + -1 : 'undefined', + 0 : 'default', + 1 : 'row_sparse', + 2 : 'csr', +} + + class NDArrayBase(object): """Base data structure for ndarray""" __slots__ = ["handle", "writable"] # pylint: disable= no-member + def __init__(self, handle, writable=True): """initialize a new NDArray @@ -61,7 +70,11 @@ def _imperative_invoke(handle, ndargs, keys, vals, out): output_vars = ctypes.POINTER(NDArrayHandle)() num_output = ctypes.c_int(0) - check_call(_LIB.MXImperativeInvoke( + # return output stypes to avoid the c_api call for checking + # a handle's stype in _ndarray_cls + out_stypes = ctypes.POINTER(ctypes.c_int)() + + check_call(_LIB.MXImperativeInvokeEx( ctypes.c_void_p(handle), ctypes.c_int(len(ndargs)), c_array(NDArrayHandle, [arr.handle for arr in ndargs]), @@ -69,14 +82,17 @@ def _imperative_invoke(handle, ndargs, keys, vals, out): ctypes.byref(output_vars), ctypes.c_int(len(keys)), c_array(ctypes.c_char_p, [c_str(key) for key in keys]), - c_array(ctypes.c_char_p, [c_str(str(val)) for val in vals]))) + c_array(ctypes.c_char_p, [c_str(str(val)) for val in vals]), + ctypes.byref(out_stypes))) if original_output is not None: return original_output if num_output.value == 1: - return _ndarray_cls(ctypes.cast(output_vars[0], NDArrayHandle)) + return _ndarray_cls(ctypes.cast(output_vars[0], NDArrayHandle), + stype=_STORAGE_TYPE_ID_TO_STR[out_stypes[0]]) else: - return [_ndarray_cls(ctypes.cast(output_vars[i], NDArrayHandle)) + return [_ndarray_cls(ctypes.cast(output_vars[i], NDArrayHandle), + stype=_STORAGE_TYPE_ID_TO_STR[out_stypes[i]]) for i in range(num_output.value)] diff --git a/python/mxnet/contrib/autograd.py b/python/mxnet/contrib/autograd.py index e56361efdb1f..aa212c72fc9a 100644 --- a/python/mxnet/contrib/autograd.py +++ b/python/mxnet/contrib/autograd.py @@ -7,6 +7,7 @@ import functools from ..base import _LIB, check_call, string_types from ..base import mx_uint, NDArrayHandle, c_array +# pylint: disable= unused-import from ..ndarray import NDArray, zeros_like from ..symbol import _GRAD_REQ_MAP diff --git a/python/mxnet/executor.py b/python/mxnet/executor.py index 6b9aab2de6f1..ead7709f14ad 100644 --- a/python/mxnet/executor.py +++ b/python/mxnet/executor.py @@ -11,6 +11,7 @@ from .base import mx_uint, NDArrayHandle, ExecutorHandle from .base import check_call, c_array, py_str from .ndarray import NDArray +from .ndarray import _ndarray_cls from . import ndarray as nd # those functions are not used here, we just import them to keep backward compatibility @@ -90,7 +91,9 @@ def _get_outputs(self): handles = ctypes.POINTER(NDArrayHandle)() check_call(_LIB.MXExecutorOutputs(self.handle, ctypes.byref(out_size), ctypes.byref(handles))) - return [NDArray(NDArrayHandle(handles[i])) for i in range(out_size.value)] + num_output = out_size.value + outputs = [_ndarray_cls(NDArrayHandle(handles[i])) for i in range(num_output)] + return outputs def forward(self, is_train=False, **kwargs): """Calculate the outputs specified by the bound symbol. diff --git a/python/mxnet/image.py b/python/mxnet/image.py index 890de7d0ffb8..9a9aedec8254 100644 --- a/python/mxnet/image.py +++ b/python/mxnet/image.py @@ -16,9 +16,9 @@ from .base import numeric_types from . import ndarray as nd -from . import _ndarray_internal as _internal -from ._ndarray_internal import _cvimresize as imresize -from ._ndarray_internal import _cvcopyMakeBorder as copyMakeBorder +from .ndarray import _internal +from .ndarray._internal import _cvimresize as imresize +from .ndarray._internal import _cvcopyMakeBorder as copyMakeBorder from . import io from . import recordio diff --git a/python/mxnet/io.py b/python/mxnet/io.py index bb791cef035e..a7d359a328c4 100644 --- a/python/mxnet/io.py +++ b/python/mxnet/io.py @@ -17,6 +17,7 @@ from .base import mx_real_t from .base import check_call, build_param_doc as _build_param_doc from .ndarray import NDArray +from .ndarray import _ndarray_cls from .ndarray import array from .ndarray import concatenate @@ -784,12 +785,12 @@ def iter_next(self): def getdata(self): hdl = NDArrayHandle() check_call(_LIB.MXDataIterGetData(self.handle, ctypes.byref(hdl))) - return NDArray(hdl, False) + return _ndarray_cls(hdl, False) def getlabel(self): hdl = NDArrayHandle() check_call(_LIB.MXDataIterGetLabel(self.handle, ctypes.byref(hdl))) - return NDArray(hdl, False) + return _ndarray_cls(hdl, False) def getindex(self): index_size = ctypes.c_uint64(0) diff --git a/python/mxnet/kvstore.py b/python/mxnet/kvstore.py index 10b83b04db97..d1907e9d0e66 100644 --- a/python/mxnet/kvstore.py +++ b/python/mxnet/kvstore.py @@ -16,7 +16,7 @@ def _ctype_key_value(keys, vals): c_keys = [] c_vals = [] for key, val in zip(keys, vals): - c_key_i, c_val_i = _ctype_key_value(key, val) + c_key_i, c_val_i = _ctype_str_key_value(key, val) c_keys += c_key_i c_vals += c_val_i return (c_array(ctypes.c_char_p, c_keys), c_array(NDArrayHandle, c_vals)) @@ -44,7 +44,7 @@ def updater_handle(key, lhs_handle, rhs_handle, _): class KVStore(object): """A key-value store for synchronization of values, over multiple devices.""" - def __init__(self, handle): + def __init__(self, handle, name2idx=None): """Initializes a new KVStore. Parameters @@ -54,6 +54,7 @@ def __init__(self, handle): """ assert isinstance(handle, KVStoreHandle) self.handle = handle + self.name2idx = name2idx if name2idx is not None else {} self._updater = None self._updater_func = None @@ -391,7 +392,7 @@ def _send_command_to_servers(self, head, body): check_call(_LIB.MXKVStoreSendCommmandToServers( self.handle, mx_uint(head), c_str(body))) -def create(name='local'): +def create(name='local', name2idx=None): """Creates a new KVStore. For single machine training, there are two commonly used types: @@ -431,4 +432,4 @@ def create(name='local'): handle = KVStoreHandle() check_call(_LIB.MXKVStoreCreate(c_str(name), ctypes.byref(handle))) - return KVStore(handle) + return KVStore(handle, name2idx=name2idx) diff --git a/python/mxnet/model.py b/python/mxnet/model.py index a476d84efd92..c91ef5474601 100644 --- a/python/mxnet/model.py +++ b/python/mxnet/model.py @@ -37,7 +37,7 @@ 'eval_metric', 'locals']) -def _create_kvstore(kvstore, num_device, arg_params): +def _create_kvstore(kvstore, num_device, arg_params, name2idx=None): """Create kvstore This function select and create a proper kvstore if given the kvstore type. @@ -61,7 +61,7 @@ def _create_kvstore(kvstore, num_device, arg_params): # no need to use kv for single device and single machine kv = None else: - kv = kvs.create(kvstore) + kv = kvs.create(kvstore, name2idx=name2idx) if kvstore == 'local': # automatically select a proper local max_size = max(np.prod(param.shape) for param in @@ -101,10 +101,11 @@ def _update_params_on_kvstore(param_arrays, grad_arrays, kvstore, param_names): def _update_params(param_arrays, grad_arrays, updater, num_device, kvstore=None, param_names=None): """Perform update of param_arrays from grad_arrays not on kvstore.""" - for index, pair in enumerate(zip(param_arrays, grad_arrays)): + for i, pair in enumerate(zip(param_arrays, grad_arrays)): arg_list, grad_list = pair if grad_list[0] is None: continue + index = i if kvstore: name = param_names[index] # push gradient, priority is negative index @@ -114,7 +115,7 @@ def _update_params(param_arrays, grad_arrays, updater, num_device, for k, p in enumerate(zip(arg_list, grad_list)): # faked an index here, to make optimizer create diff # state for the same index but on diff devs, TODO(mli) - # use a better solution latter + # use a better solution later w, g = p updater(index*num_device+k, g, w) diff --git a/python/mxnet/module/module.py b/python/mxnet/module/module.py index 2a36c6ad7e7e..519499ebcf7f 100644 --- a/python/mxnet/module/module.py +++ b/python/mxnet/module/module.py @@ -16,6 +16,7 @@ from ..model import load_checkpoint from ..initializer import Uniform, InitDesc from ..io import DataDesc +from ..ndarray import zeros from .base_module import BaseModule, _check_input_names, _parse_data_desc @@ -409,13 +410,13 @@ def bind(self, data_shapes, label_shapes=None, for_training=True, else: assert self._arg_params is None and self._aux_params is None param_arrays = [ - nd.zeros(x[0].shape, dtype=x[0].dtype) + zeros(shape=x[0].shape, dtype=x[0].dtype, stype=x[0].stype) for x in self._exec_group.param_arrays ] self._arg_params = {name:arr for name, arr in zip(self._param_names, param_arrays)} aux_arrays = [ - nd.zeros(x[0].shape, dtype=x[0].dtype) + zeros(x[0].shape, dtype=x[0].dtype) for x in self._exec_group.aux_arrays ] self._aux_params = {name:arr for name, arr in zip(self._aux_names, aux_arrays)} @@ -423,7 +424,6 @@ def bind(self, data_shapes, label_shapes=None, for_training=True, if shared_module is not None and shared_module.optimizer_initialized: self.borrow_optimizer(shared_module) - def reshape(self, data_shapes, label_shapes=None): """Reshapes the module for new input shapes. @@ -465,8 +465,12 @@ def init_optimizer(self, kvstore='local', optimizer='sgd', if self._params_dirty: self._sync_params_from_devices() + name2idx = {} + for idx, name in enumerate(self._exec_group.param_names): + name2idx[name] = idx + (kvstore, update_on_kvstore) = \ - _create_kvstore(kvstore, len(self._context), self._arg_params) + _create_kvstore(kvstore, len(self._context), self._arg_params, name2idx=name2idx) batch_size = self._exec_group.batch_size if kvstore and 'dist' in kvstore.type and '_sync' in kvstore.type: diff --git a/python/mxnet/ndarray/__init__.py b/python/mxnet/ndarray/__init__.py new file mode 100644 index 000000000000..d7accf092ced --- /dev/null +++ b/python/mxnet/ndarray/__init__.py @@ -0,0 +1,12 @@ +"""ndarray module""" + +from . import _internal +from . import op +from .op import CachedOp, invoke +from .ndarray import NDArray, array, concatenate, _DTYPE_NP_TO_MX, _DTYPE_MX_TO_NP +from .ndarray import empty, ones, add, arange, divide, equal, full, greater, greater_equal, imdecode +from .ndarray import lesser, lesser_equal, maximum, minimum, moveaxis, multiply, negative, not_equal +from .ndarray import onehot_encode, power, subtract, true_divide, waitall, _new_empty_handle +from .ndarray_utils import load, save, zeros +from .sparse_ndarray import _ndarray_cls +from .sparse_ndarray import csr, row_sparse, SparseNDArray, todense, RowSparseNDArray, CSRNDArray diff --git a/python/mxnet/_ndarray_internal.py b/python/mxnet/ndarray/_internal.py similarity index 100% rename from python/mxnet/_ndarray_internal.py rename to python/mxnet/ndarray/_internal.py diff --git a/python/mxnet/ndarray.py b/python/mxnet/ndarray/ndarray.py similarity index 87% rename from python/mxnet/ndarray.py rename to python/mxnet/ndarray/ndarray.py index 4939b6c221a5..c9000251dccb 100644 --- a/python/mxnet/ndarray.py +++ b/python/mxnet/ndarray/ndarray.py @@ -4,6 +4,7 @@ """NDArray API of MXNet.""" from __future__ import absolute_import from __future__ import division + try: from __builtin__ import slice as py_slice except ImportError: @@ -11,40 +12,16 @@ import ctypes import warnings - -import os as _os -import sys as _sys - import operator import numpy as np -from .base import _LIB, string_types, numeric_types, integer_types -from .base import c_array, py_str, c_str, mx_real_t, _Null # pylint: disable=unused-import -from .base import mx_uint, NDArrayHandle, check_call, OpHandle -from .base import ctypes2buffer -from .context import Context -from . import _ndarray_internal as _internal -from .ndarray_doc import _build_doc - - -# Use different version of SymbolBase -# When possible, use cython to speedup part of computation. -# pylint: disable=unused-import -try: - if int(_os.environ.get("MXNET_ENABLE_CYTHON", True)) == 0: - from ._ctypes.ndarray import NDArrayBase, _set_ndarray_class - from ._ctypes.ndarray import CachedOp, _imperative_invoke - elif _sys.version_info >= (3, 0): - from ._cy3.ndarray import NDArrayBase, _set_ndarray_class, _imperative_invoke - from ._cy3.ndarray import CachedOp, _imperative_invoke - else: - from ._cy2.ndarray import NDArrayBase, _set_ndarray_class, _imperative_invoke - from ._cy2.ndarray import CachedOp, _imperative_invoke -except ImportError: - if int(_os.environ.get("MXNET_ENFORCE_CYTHON", False)) != 0: - raise ImportError("Cython Module cannot be loaded but MXNET_ENFORCE_CYTHON=1") - from ._ctypes.ndarray import NDArrayBase, _set_ndarray_class, _imperative_invoke - from ._ctypes.ndarray import CachedOp, _imperative_invoke -# pylint: enable=unused-import +from ..base import _LIB, numeric_types +from ..base import c_array, mx_real_t +from ..base import mx_uint, NDArrayHandle, check_call +from ..base import ctypes2buffer +from ..context import Context +from . import _internal +from .op import NDArrayBase, _STORAGE_TYPE_ID_TO_STR +from . import * # pylint: disable= no-member _DTYPE_NP_TO_MX = { @@ -53,18 +30,24 @@ np.float64 : 1, np.float16 : 2, np.uint8 : 3, - np.int32 : 4 + np.int32 : 4, + np.int64 : 6 } - _DTYPE_MX_TO_NP = { -1 : None, 0 : np.float32, 1 : np.float64, 2 : np.float16, 3 : np.uint8, - 4 : np.int32 + 4 : np.int32, + 6 : np.int64 +} +_STORAGE_TYPE_STR_TO_ID = { + 'undefined' : -1, + 'default' : 0, + 'row_sparse' : 1, + 'csr' : 2, } - _GRAD_REQ_MAP = { 'null': 0, 'write': 1, @@ -114,6 +97,11 @@ def waitall(): """ check_call(_LIB.MXNDArrayWaitAll()) +def _storage_type(handle): + storage_type = ctypes.c_int(0) + check_call(_LIB.MXNDArrayGetStorageType(handle, ctypes.byref(storage_type))) + return _STORAGE_TYPE_ID_TO_STR[storage_type.value] + class NDArray(NDArrayBase): """An array object representing a multidimensional, homogeneous array of fixed-size items. @@ -121,6 +109,7 @@ class NDArray(NDArrayBase): """ __slots__ = [] # pylint: disable= no-member, undefined-variable + def __repr__(self): """Returns a string representation of the array.""" shape_info = 'x'.join(['%d' % x for x in self.shape]) @@ -128,6 +117,9 @@ def __repr__(self): self.__class__.__name__, shape_info, self.context) + def __reduce__(self): + return NDArray, (None,), self.__getstate__() + def __add__(self, other): """x.__add__(y) <=> x+y <=> mx.nd.add(x, y) """ return add(self, other) @@ -699,7 +691,6 @@ def wait_to_read(self): """ check_call(_LIB.MXNDArrayWaitToRead(self.handle)) - @property def ndim(self): """Returns the number of dimensions of this array @@ -734,6 +725,7 @@ def shape(self): self.handle, ctypes.byref(ndim), ctypes.byref(pdata))) return tuple(pdata[:ndim.value]) + @property def size(self): """Number of elements in the array. @@ -795,6 +787,10 @@ def dtype(self): self.handle, ctypes.byref(mx_dtype))) return _DTYPE_MX_TO_NP[mx_dtype.value] + @property + def stype(self): + return _storage_type(self.handle) + @property # pylint: disable= invalid-name, undefined-variable def T(self): @@ -1047,6 +1043,13 @@ def backward(self, out_grad=None, retain_graph=False): c_array(NDArrayHandle, ograd_handles), ctypes.c_int(retain_graph))) + def _to_csr(self): + # pylint: disable=undefined-variable + return cast_storage(self, stype='csr') + + def _to_rsp(self): + # pylint: disable=undefined-variable + return cast_storage(self, stype='row_sparse') def onehot_encode(indices, out): """One-hot encoding indices into matrix out. @@ -1091,42 +1094,8 @@ def empty(shape, ctx=None, dtype=mx_real_t): ctx = Context.default_ctx return NDArray(handle=_new_alloc_handle(shape, ctx, False, dtype)) -def zeros(shape, ctx=None, dtype=mx_real_t, **kwargs): - """Returns a new array filled with all zeros, with the given shape and type. - - Parameters - ---------- - shape : int or tuple of int - The shape of the empty array. - ctx : Context, optional - An optional device context (default is the current default context). - dtype : str or numpy.dtype, optional - An optional value type (default is `float32`). - out : NDArray, optional - The output NDArray (default is `None`). - - Returns - ------- - NDArray - A created array - Examples - -------- - >>> mx.nd.zeros(1).asnumpy() - array([ 0.], dtype=float32) - >>> mx.nd.zeros((1,2), mx.gpu(0)) - - >>> mx.nd.zeros((1,2), mx.gpu(0), 'float16').asnumpy() - array([[ 0., 0.]], dtype=float16) - """ - # pylint: disable= unused-argument - if ctx is None: - ctx = Context.default_ctx - # pylint: disable= no-member, protected-access - return _internal._zeros(shape=shape, ctx=ctx, dtype=dtype, **kwargs) - # pylint: enable= no-member, protected-access - -def ones(shape, ctx=None, dtype=mx_real_t, **kwargs): +def ones(shape, ctx=None, dtype=None, **kwargs): """Returns a new array filled with all ones, with the given shape and type. Parameters @@ -1158,6 +1127,7 @@ def ones(shape, ctx=None, dtype=mx_real_t, **kwargs): # pylint: disable= unused-argument if ctx is None: ctx = Context.default_ctx + dtype = mx_real_t if dtype is None else dtype # pylint: disable= no-member, protected-access return _internal._ones(shape=shape, ctx=ctx, dtype=dtype, **kwargs) # pylint: enable= no-member, protected-access @@ -2255,89 +2225,6 @@ def negative(arr): """ return multiply(arr, -1.0) -def load(fname): - """Loads an array from file. - - See more details in ``save``. - - Parameters - ---------- - fname : str - The filename. - - Returns - ------- - list of NDArray or dict of str to NDArray - Loaded data. - """ - if not isinstance(fname, string_types): - raise TypeError('fname required to be a string') - out_size = mx_uint() - out_name_size = mx_uint() - handles = ctypes.POINTER(NDArrayHandle)() - names = ctypes.POINTER(ctypes.c_char_p)() - check_call(_LIB.MXNDArrayLoad(c_str(fname), - ctypes.byref(out_size), - ctypes.byref(handles), - ctypes.byref(out_name_size), - ctypes.byref(names))) - if out_name_size.value == 0: - return [NDArray(NDArrayHandle(handles[i])) for i in range(out_size.value)] - else: - assert out_name_size.value == out_size.value - return dict( - (py_str(names[i]), NDArray(NDArrayHandle(handles[i]))) for i in range(out_size.value)) - - -def save(fname, data): - """Saves a list of arrays or a dict of str->array to file. - - Examples of filenames: - - - ``/path/to/file`` - - ``s3://my-bucket/path/to/file`` (if compiled with AWS S3 supports) - - ``hdfs://path/to/file`` (if compiled with HDFS supports) - - Parameters - ---------- - fname : str - The filename. - data : list of ``NDArray` or dict of str to ``NDArray`` - The data to save. - - Examples - -------- - >>> x = mx.nd.zeros((2,3)) - >>> y = mx.nd.ones((1,4)) - >>> mx.nd.save('my_list', [x,y]) - >>> mx.nd.save('my_dict', {'x':x, 'y':y}) - >>> mx.nd.load('my_list') - [, ] - >>> mx.nd.load('my_dict') - {'y': , 'x': } - """ - handles = [] - if isinstance(data, dict): - keys = [] - for key, val in data.items(): - if not isinstance(key, string_types): - raise TypeError('save only accept dict str->NDArray or list of NDArray') - if not isinstance(val, NDArray): - raise TypeError('save only accept dict str->NDArray or list of NDArray') - keys.append(c_str(key)) - handles.append(val.handle) - keys = c_array(ctypes.c_char_p, keys) - else: - for val in data: - if not isinstance(val, NDArray): - raise TypeError('save only accept dict str->NDArray or list of NDArray') - handles.append(val.handle) - keys = None - check_call(_LIB.MXNDArraySave(c_str(fname), - mx_uint(len(handles)), - c_array(NDArrayHandle, handles), - keys)) - def concatenate(arrays, axis=0, always_copy=True): """DEPRECATED, use ``concat`` instead @@ -2437,160 +2324,38 @@ def imdecode(str_img, clip_rect=(0, 0, 0, 0), out=None, index=0, channels=3, mea out=out) -# pylint: disable=too-many-locals, invalid-name -def _make_ndarray_function(handle, name): - """Create a NDArray function from the FunctionHandle.""" - real_name = ctypes.c_char_p() - desc = ctypes.c_char_p() - num_args = mx_uint() - arg_names = ctypes.POINTER(ctypes.c_char_p)() - arg_types = ctypes.POINTER(ctypes.c_char_p)() - arg_descs = ctypes.POINTER(ctypes.c_char_p)() - key_var_num_args = ctypes.c_char_p() - ret_type = ctypes.c_char_p() - - check_call(_LIB.MXSymbolGetAtomicSymbolInfo( - handle, ctypes.byref(real_name), ctypes.byref(desc), - ctypes.byref(num_args), - ctypes.byref(arg_names), - ctypes.byref(arg_types), - ctypes.byref(arg_descs), - ctypes.byref(key_var_num_args), - ctypes.byref(ret_type))) - narg = int(num_args.value) - arg_names = [py_str(arg_names[i]) for i in range(narg)] - arg_types = [py_str(arg_types[i]) for i in range(narg)] - func_name = name - key_var_num_args = py_str(key_var_num_args.value) - ret_type = py_str(ret_type.value) if ret_type.value is not None else '' - doc_str = _build_doc(func_name, - py_str(desc.value), - arg_names, - arg_types, - [py_str(arg_descs[i]) for i in range(narg)], - key_var_num_args, - ret_type) - - dtype_name = None - arr_name = None - ndsignature = [] - signature = [] - ndarg_names = [] - kwarg_names = [] - for i in range(narg): - name, atype = arg_names[i], arg_types[i] - if name == 'dtype': - dtype_name = name - signature.append('%s=_Null'%name) - elif atype.startswith('NDArray') or atype.startswith('Symbol'): - assert not arr_name, \ - "Op can only have one argument with variable " \ - "size and it must be the last argument." - if atype.endswith('[]'): - ndsignature.append('*%s'%name) - arr_name = name - else: - ndsignature.append('%s=None'%name) - ndarg_names.append(name) - else: - signature.append('%s=_Null'%name) - kwarg_names.append(name) - #signature.append('is_train=False') - signature.append('out=None') - signature.append('name=None') - signature.append('**kwargs') - signature = ndsignature + signature - - code = [] - if arr_name: - code.append(""" -def %s(*%s, **kwargs):"""%(func_name, arr_name)) - code.append(""" - ndargs = [] - for i in {}: - assert isinstance(i, NDArrayBase), \\ - "Positional arguments must have NDArray type, " \\ - "but got %s"%str(i) - ndargs.append(i)""".format(arr_name)) - if dtype_name is not None: - code.append(""" - if '%s' in kwargs: - kwargs['%s'] = np.dtype(kwargs['%s']).name"""%( - dtype_name, dtype_name, dtype_name)) - code.append(""" - _ = kwargs.pop('name', None) - out = kwargs.pop('out', None) - keys = list(kwargs.keys()) - vals = list(kwargs.values())""") - else: - code.append(""" -def %s(%s): - ndargs = [] - keys = list(kwargs.keys()) - vals = list(kwargs.values())"""%(func_name, ', '.join(signature))) - # NDArray args - for name in ndarg_names: # pylint: disable=redefined-argument-from-local - code.append(""" - if {name} is not None: - assert isinstance({name}, NDArrayBase), \\ - "Argument {name} must have NDArray type, but got %s"%str({name}) - ndargs.append({name})""".format(name=name)) - # kwargs - for name in kwarg_names: # pylint: disable=redefined-argument-from-local - code.append(""" - if %s is not _Null: - keys.append('%s') - vals.append(%s)"""%(name, name, name)) - # dtype - if dtype_name is not None: - code.append(""" - if %s is not _Null: - keys.append('%s') - vals.append(np.dtype(%s).name)"""%(dtype_name, dtype_name, dtype_name)) - - code.append(""" - return _imperative_invoke(%d, ndargs, keys, vals, out)"""%( - handle.value)) - - local = {} - exec(''.join(code), None, local) # pylint: disable=exec-used - ndarray_function = local[func_name] - ndarray_function.__name__ = func_name - ndarray_function.__doc__ = doc_str - ndarray_function.__module__ = 'mxnet.ndarray' - return ndarray_function - - -# pylint: enable=too-many-locals, invalid-name -def _init_ndarray_module(ndarray_class, root_namespace): - """List and add all the ndarray functions to current module.""" - _set_ndarray_class(ndarray_class) - plist = ctypes.POINTER(ctypes.c_char_p)() - size = ctypes.c_uint() - - check_call(_LIB.MXListAllOpNames(ctypes.byref(size), - ctypes.byref(plist))) - op_names = [] - for i in range(size.value): - op_names.append(py_str(plist[i])) - - module_obj = _sys.modules["%s.ndarray" % root_namespace] - module_internal = _sys.modules["%s._ndarray_internal" % root_namespace] - module_contrib = _sys.modules["%s.contrib.ndarray" % root_namespace] - for name in op_names: - hdl = OpHandle() - check_call(_LIB.NNGetOpHandle(c_str(name), ctypes.byref(hdl))) - function = _make_ndarray_function(hdl, name) - if function.__name__.startswith('_contrib_'): - function.__name__ = function.__name__[9:] - function.__module__ = 'mxnet.contrib.ndarray' - setattr(module_contrib, function.__name__, function) - elif function.__name__.startswith('_'): - setattr(module_internal, function.__name__, function) - else: - setattr(module_obj, function.__name__, function) +def _zeros_ndarray(shape, ctx=None, dtype=None, **kwargs): + """Returns a new array filled with all zeros, with the given shape and type. -_init_ndarray_module(NDArray, "mxnet") + Parameters + ---------- + shape : int or tuple of int + The shape of the empty array. + ctx : Context, optional + An optional device context (default is the current default context). + dtype : str or numpy.dtype, optional + An optional value type (default is `float32`). + out : NDArray, optional + The output NDArray (default is `None`). -# from .base import add_fileline_to_docstring -# add_fileline_to_docstring(__name__) + Returns + ------- + NDArray + A created array + + Examples + -------- + >>> mx.nd.zeros(1).asnumpy() + array([ 0.], dtype=float32) + >>> mx.nd.zeros((1,2), mx.gpu(0)) + + >>> mx.nd.zeros((1,2), mx.gpu(0), 'float16').asnumpy() + array([[ 0., 0.]], dtype=float16) + """ + # pylint: disable= unused-argument + if ctx is None: + ctx = Context.default_ctx + dtype = mx_real_t if dtype is None else dtype + # pylint: disable= no-member, protected-access + return _internal._zeros(shape=shape, ctx=ctx, dtype=dtype, **kwargs) + # pylint: enable= no-member, protected-access diff --git a/python/mxnet/ndarray/ndarray_utils.py b/python/mxnet/ndarray/ndarray_utils.py new file mode 100644 index 000000000000..2516372d1b55 --- /dev/null +++ b/python/mxnet/ndarray/ndarray_utils.py @@ -0,0 +1,99 @@ +# coding: utf-8 +"""Utility functions for NDArray and SparseNDArray.""" +import ctypes + +from ..base import _LIB, check_call, py_str, c_str, string_types, mx_uint, NDArrayHandle, c_array +from .ndarray import NDArray, _zeros_ndarray +from .sparse_ndarray import _ndarray_cls, _zeros_sparse_ndarray + + +def zeros(shape, ctx=None, dtype=None, stype=None, aux_types=None, **kwargs): + if stype is None: + return _zeros_ndarray(shape, ctx, dtype, **kwargs) + else: + return _zeros_sparse_ndarray(stype, shape, ctx, dtype, aux_types, **kwargs) + + +def load(fname): + """Loads an array from file. + + See more details in ``save``. + + Parameters + ---------- + fname : str + The filename. + + Returns + ------- + list of NDArray or dict of str to NDArray + Loaded data. + """ + if not isinstance(fname, string_types): + raise TypeError('fname required to be a string') + out_size = mx_uint() + out_name_size = mx_uint() + handles = ctypes.POINTER(NDArrayHandle)() + names = ctypes.POINTER(ctypes.c_char_p)() + check_call(_LIB.MXNDArrayLoad(c_str(fname), + ctypes.byref(out_size), + ctypes.byref(handles), + ctypes.byref(out_name_size), + ctypes.byref(names))) + if out_name_size.value == 0: + return [_ndarray_cls(NDArrayHandle(handles[i])) for i in range(out_size.value)] + else: + assert out_name_size.value == out_size.value + return dict( + (py_str(names[i]), _ndarray_cls(NDArrayHandle(handles[i]))) + for i in range(out_size.value)) + + +def save(fname, data): + """Saves a list of arrays or a dict of str->array to file. + + Examples of filenames: + + - ``/path/to/file`` + - ``s3://my-bucket/path/to/file`` (if compiled with AWS S3 supports) + - ``hdfs://path/to/file`` (if compiled with HDFS supports) + + Parameters + ---------- + fname : str + The filename. + data : list of ``NDArray` or dict of str to ``NDArray`` + The data to save. + + Examples + -------- + >>> x = mx.nd.zeros((2,3)) + >>> y = mx.nd.ones((1,4)) + >>> mx.nd.save('my_list', [x,y]) + >>> mx.nd.save('my_dict', {'x':x, 'y':y}) + >>> mx.nd.load('my_list') + [, ] + >>> mx.nd.load('my_dict') + {'y': , 'x': } + """ + handles = [] + if isinstance(data, dict): + keys = [] + for key, val in data.items(): + if not isinstance(key, string_types): + raise TypeError('save only accept dict str->NDArray or list of NDArray') + if not isinstance(val, NDArray): + raise TypeError('save only accept dict str->NDArray or list of NDArray') + keys.append(c_str(key)) + handles.append(val.handle) + keys = c_array(ctypes.c_char_p, keys) + else: + for val in data: + if not isinstance(val, NDArray): + raise TypeError('save only accept dict str->NDArray or list of NDArray') + handles.append(val.handle) + keys = None + check_call(_LIB.MXNDArraySave(c_str(fname), + mx_uint(len(handles)), + c_array(NDArrayHandle, handles), + keys)) diff --git a/python/mxnet/ndarray/op.py b/python/mxnet/ndarray/op.py new file mode 100644 index 000000000000..ba64e68c9394 --- /dev/null +++ b/python/mxnet/ndarray/op.py @@ -0,0 +1,189 @@ +"""Register backend ops in mxnet.ndarray namespace""" + +import sys as _sys +import os as _os +import ctypes +import numpy as np # pylint: disable=unused-import + +from ..ndarray_doc import _build_doc + +# Use different verison of SymbolBase +# When possible, use cython to speedup part of computation. +# pylint: disable=unused-import +try: + if int(_os.environ.get("MXNET_ENABLE_CYTHON", True)) == 0: + from .._ctypes.ndarray import NDArrayBase, _set_ndarray_class, _STORAGE_TYPE_ID_TO_STR + from .._ctypes.ndarray import invoke, CachedOp, _imperative_invoke + elif _sys.version_info >= (3, 0): + from .._cy3.ndarray import NDArrayBase, _set_ndarray_class,\ + _imperative_invoke, _STORAGE_TYPE_ID_TO_STR + from .._cy3.ndarray import invoke, CachedOp, _imperative_invoke + else: + from .._cy2.ndarray import NDArrayBase, _set_ndarray_class,\ + _imperative_invoke, _STORAGE_TYPE_ID_TO_STR + from .._cy2.ndarray import invoke, CachedOp, _imperative_invoke +except ImportError: + if int(_os.environ.get("MXNET_ENFORCE_CYTHON", False)) != 0: + raise ImportError("Cython Module cannot be loaded but MXNET_ENFORCE_CYTHON=1") + from .._ctypes.ndarray import NDArrayBase, _set_ndarray_class,\ + _imperative_invoke, _STORAGE_TYPE_ID_TO_STR + from .._ctypes.ndarray import invoke, CachedOp, _imperative_invoke + +from ..base import mx_uint, check_call, _LIB, py_str, OpHandle, c_str, _Null +# pylint: enable=unused-import + + +# pylint: disable=too-many-locals, invalid-name +def _make_ndarray_function(handle, name): + """Create a NDArray function from the FunctionHandle.""" + real_name = ctypes.c_char_p() + desc = ctypes.c_char_p() + num_args = mx_uint() + arg_names = ctypes.POINTER(ctypes.c_char_p)() + arg_types = ctypes.POINTER(ctypes.c_char_p)() + arg_descs = ctypes.POINTER(ctypes.c_char_p)() + key_var_num_args = ctypes.c_char_p() + ret_type = ctypes.c_char_p() + + check_call(_LIB.MXSymbolGetAtomicSymbolInfo( + handle, ctypes.byref(real_name), ctypes.byref(desc), + ctypes.byref(num_args), + ctypes.byref(arg_names), + ctypes.byref(arg_types), + ctypes.byref(arg_descs), + ctypes.byref(key_var_num_args), + ctypes.byref(ret_type))) + narg = int(num_args.value) + arg_names = [py_str(arg_names[i]) for i in range(narg)] + arg_types = [py_str(arg_types[i]) for i in range(narg)] + func_name = name + key_var_num_args = py_str(key_var_num_args.value) + ret_type = py_str(ret_type.value) if ret_type.value is not None else '' + doc_str = _build_doc(func_name, + py_str(desc.value), + arg_names, + arg_types, + [py_str(arg_descs[i]) for i in range(narg)], + key_var_num_args, + ret_type) + + dtype_name = None + arr_name = None + ndsignature = [] + signature = [] + ndarg_names = [] + kwarg_names = [] + for i in range(narg): + name, atype = arg_names[i], arg_types[i] + if name == 'dtype': + dtype_name = name + signature.append('%s=_Null'%name) + elif atype.startswith('NDArray') or atype.startswith('Symbol'): + assert not arr_name, \ + "Op can only have one argument with variable " \ + "size and it must be the last argument." + if atype.endswith('[]'): + ndsignature.append('*%s'%name) + arr_name = name + else: + ndsignature.append('%s=None'%name) + ndarg_names.append(name) + else: + signature.append('%s=_Null'%name) + kwarg_names.append(name) + # signature.append('is_train=False') + signature.append('out=None') + signature.append('name=None') + signature.append('**kwargs') + signature = ndsignature + signature + + code = [] + if arr_name: + code.append(""" +def %s(*%s, **kwargs):"""%(func_name, arr_name)) + code.append(""" + ndargs = [] + for i in {}: + assert isinstance(i, NDArrayBase), \\ + "Positional arguments must have NDArray type, " \\ + "but got %s"%str(i) + ndargs.append(i)""".format(arr_name)) + if dtype_name is not None: + code.append(""" + if '%s' in kwargs: + kwargs['%s'] = np.dtype(kwargs['%s']).name"""%( + dtype_name, dtype_name, dtype_name)) + code.append(""" + _ = kwargs.pop('name', None) + out = kwargs.pop('out', None) + keys = list(kwargs.keys()) + vals = list(kwargs.values())""") + else: + code.append(""" +def %s(%s): + ndargs = [] + keys = list(kwargs.keys()) + vals = list(kwargs.values())"""%(func_name, ', '.join(signature))) + # NDArray args + for name in ndarg_names: # pylint: disable=redefined-argument-from-local + code.append(""" + if {name} is not None: + assert isinstance({name}, NDArrayBase), \\ + "Argument {name} must have NDArray type, but got %s"%str({name}) + ndargs.append({name})""".format(name=name)) + # kwargs + for name in kwarg_names: # pylint: disable=redefined-argument-from-local + code.append(""" + if %s is not _Null: + keys.append('%s') + vals.append(%s)"""%(name, name, name)) + # dtype + if dtype_name is not None: + code.append(""" + if %s is not _Null: + keys.append('%s') + vals.append(np.dtype(%s).name)"""%(dtype_name, dtype_name, dtype_name)) + + code.append(""" + return _imperative_invoke(%d, ndargs, keys, vals, out)"""%( + handle.value)) + + local = {} + exec(''.join(code), None, local) # pylint: disable=exec-used + ndarray_function = local[func_name] + ndarray_function.__name__ = func_name + ndarray_function.__doc__ = doc_str + ndarray_function.__module__ = 'mxnet.ndarray' + return ndarray_function + + +# pylint: enable=too-many-locals, invalid-name +def _init_ndarray_module(root_namespace): + """List and add all the ndarray functions to current module.""" + plist = ctypes.POINTER(ctypes.c_char_p)() + size = ctypes.c_uint() + + check_call(_LIB.MXListAllOpNames(ctypes.byref(size), + ctypes.byref(plist))) + op_names = [] + for i in range(size.value): + op_names.append(py_str(plist[i])) + + module_obj = _sys.modules["%s.ndarray" % root_namespace] + module_internal = _sys.modules["%s.ndarray._internal" % root_namespace] + module_contrib = _sys.modules["%s.contrib.ndarray" % root_namespace] + for name in op_names: + hdl = OpHandle() + check_call(_LIB.NNGetOpHandle(c_str(name), ctypes.byref(hdl))) + function = _make_ndarray_function(hdl, name) + if function.__name__.startswith('_contrib_'): + function.__name__ = function.__name__[9:] + function.__module__ = 'mxnet.contrib.ndarray' + setattr(module_contrib, function.__name__, function) + elif function.__name__.startswith('_'): + setattr(module_internal, function.__name__, function) + else: + setattr(module_obj, function.__name__, function) + +# register backend operators in mx.nd +_init_ndarray_module("mxnet") diff --git a/python/mxnet/ndarray/sparse_ndarray.py b/python/mxnet/ndarray/sparse_ndarray.py new file mode 100644 index 000000000000..720d44586a74 --- /dev/null +++ b/python/mxnet/ndarray/sparse_ndarray.py @@ -0,0 +1,628 @@ +# coding: utf-8 +"""SparseNDArray API of mxnet.""" +from __future__ import absolute_import +from __future__ import division +try: + from __builtin__ import slice as py_slice +except ImportError: + from builtins import slice as py_slice + +import ctypes +import warnings + +import os as _os +import sys as _sys + +# import operator +import numpy as np +from ..base import _LIB, numeric_types +from ..base import c_array, mx_real_t +from ..base import mx_uint, NDArrayHandle, check_call +from ..context import Context +from . import _internal +from . import ndarray +from .ndarray import _DTYPE_NP_TO_MX, _DTYPE_MX_TO_NP +from .ndarray import _STORAGE_TYPE_STR_TO_ID +from .ndarray import NDArray, _storage_type, _zeros_ndarray +from . import cast_storage +from . import slice as nd_slice + +# Use different verison of SymbolBase +# When possible, use cython to speedup part of computation. +# pylint: disable=unused-import +try: + if int(_os.environ.get("MXNET_ENABLE_CYTHON", True)) == 0: + from .._ctypes.ndarray import NDArrayBase, _set_ndarray_class + elif _sys.version_info >= (3, 0): + from .._cy3.ndarray import NDArrayBase, _set_ndarray_class + else: + from .._cy2.ndarray import NDArrayBase, _set_ndarray_class +except ImportError: + if int(_os.environ.get("MXNET_ENFORCE_CYTHON", False)) != 0: + raise ImportError("Cython Module cannot be loaded but MXNET_ENFORCE_CYTHON=1") + from .._ctypes.ndarray import NDArrayBase, _set_ndarray_class + +# pylint: enable=unused-import +_STORAGE_AUX_TYPES = { + 'row_sparse': [np.int64], + 'csr': [np.int64, np.int64] +} + + +def _new_alloc_handle(stype, shape, ctx, delay_alloc, dtype, aux_types, aux_shapes=None): + """Return a new handle with specified storage type, shape, dtype and context. + + Empty handle is only used to hold results + + Returns + ------- + handle + A new empty ndarray handle + """ + hdl = NDArrayHandle() + aux_type_ids = [int(_DTYPE_NP_TO_MX[np.dtype(aux_t).type]) for aux_t in aux_types] + aux_shapes = [(0,) for aux_t in aux_types] if aux_shapes is None else aux_shapes + aux_shape_lens = [len(aux_shape) for aux_shape in aux_shapes] + aux_shapes = sum(aux_shapes, ()) + num_aux = mx_uint(len(aux_types)) + check_call(_LIB.MXNDArrayCreateSparseEx( + ctypes.c_int(int(_STORAGE_TYPE_STR_TO_ID[stype])), + c_array(mx_uint, shape), + mx_uint(len(shape)), + ctypes.c_int(ctx.device_typeid), + ctypes.c_int(ctx.device_id), + ctypes.c_int(int(delay_alloc)), + ctypes.c_int(int(_DTYPE_NP_TO_MX[np.dtype(dtype).type])), + num_aux, + c_array(ctypes.c_int, aux_type_ids), + c_array(mx_uint, aux_shape_lens), + c_array(mx_uint, aux_shapes), + ctypes.byref(hdl))) + return hdl + + +class SparseNDArray(NDArray): + """An array object representing a multidimensional, homogeneous array of + fixed-size items, stored in sparse format. See CSRNDArray and RowSparseNDArray + for more details. + """ + def __iadd__(self, other): + raise NotImplementedError("SparseND doesn't support __iadd__") + + def __isub__(self, other): + raise NotImplementedError("SparseND doesn't support __isub__") + + def __imul__(self, other): + raise NotImplementedError("SparseND doesn't support __imul__") + + def __idiv__(self, other): + raise NotImplementedError("SparseND doesn't support __idiv__") + + def __itruediv__(self, other): + raise NotImplementedError("SparseND doesn't support __itruediv__") + + def __setitem__(self, key, value): + """x.__setitem__(i, y) <=> x[i]=y + + Set self[key] to value. Only slice [:] is supported. + + Parameters + ---------- + key : slice + The indexing key. + value : NDArray or numpy.ndarray + The value to set. + + Examples + -------- + >>> src = mx.nd.row_sparse([[1, 0, 2], [4, 5, 6]], [0, 2], (3,3)) + >>> src.asnumpy() + array([[ 1., 0., 2.], + [ 0., 0., 0.], + [ 4., 5., 6.]], dtype=float32) + >>> # assign SparseNDArray with same storage type + >>> x = mx.nd.zeros('row_sparse', (3,3)) + >>> x[:] = src + >>> x.asnumpy() + array([[ 1., 0., 2.], + [ 0., 0., 0.], + [ 4., 5., 6.]], dtype=float32) + >>> # assign NDArray to SparseNDArray + >>> x[:] = mx.nd.ones((3,3)) + >>> x.asnumpy() + array([[ 1., 1., 1.], + [ 1., 1., 1.], + [ 1., 1., 1.]], dtype=float32) + """ + if not self.writable: + raise ValueError('Failed to assign to a readonly NDArray') + if isinstance(key, py_slice): + if key.step is not None or key.start is not None or key.stop is not None: + raise ValueError('Assignment with slicing not supported in SparseNDArray.') + if isinstance(value, NDArray): + # avoid copying to itself + if value.handle is not self.handle: + value.copyto(self) + elif isinstance(value, numeric_types): + raise Exception("Assigning numeric types to SparseNDArray not supported yet.") + elif isinstance(value, (np.ndarray, np.generic)): + # TODO(haibin) Implement _sync_copyfrom for sparse ndarray to avoid an extra copy + warnings.warn('Assigning non-NDArray object to SparseNDArray is not efficient', + RuntimeWarning) + tmp = ndarray.array(value) + tmp.copyto(self) + else: + raise TypeError('type %s not supported' % str(type(value))) + else: + assert(isinstance(key, (int, tuple))) + raise Exception('SparseNDArray only supports [:] for assignment') + + def __getitem__(self, key): + """x.__getitem__(i) <=> x[i] + + Returns a sliced view of this array. + + Parameters + ---------- + key : int or slice + Indexing key. + + Examples + -------- + >>> x = mx.nd.zeros((2, 3), stype='row_sparse') + >>> x[:] = mx.nd.arange(0,6).reshape((2,3)) + >>> x.asnumpy() + array([[ 0., 1., 2.], + [ 3., 4., 5.]], dtype=float32) + >>> x[1:2].asnumpy() + array([[ 3., 4., 5.]], dtype=float32) + """ + stype = self.stype + if stype != 'csr': + raise Exception("__getitem__ for " + str(stype) + " not implemented yet") + if isinstance(key, int): + raise Exception("Not implemented yet") + if isinstance(key, py_slice): + if key.step is not None: + raise ValueError('NDArray only supports continuous slicing on axis 0') + if key.start is not None or key.stop is not None: + begin = key.start if key.start else 0 + end = key.stop if key.stop else self.shape[0] + return nd_slice(self, begin=begin, end=end) + else: + return self + if isinstance(key, tuple): + raise ValueError('Multi-dimension indexing is not supported') + + def _sync_copyfrom(self, source_array): + raise Exception('Not implemented for SparseND yet!') + + def _at(self, idx): + raise Exception('at operator for SparseND is not supported.') + + def reshape(self, shape): + raise Exception('Not implemented for SparseND yet!') + + def broadcast_to(self, shape): + raise Exception('Not implemented for SparseND yet!') + + def _aux_type(self, i): + """Data-type of the array’s ith aux data. + + Returns + ------- + numpy.dtype + This SparseNDArray's aux data type. + """ + aux_type = ctypes.c_int() + check_call(_LIB.MXNDArrayGetAuxType(self.handle, i, ctypes.byref(aux_type))) + return _DTYPE_MX_TO_NP[aux_type.value] + + @property + def data(self): + """The values array of the SparseNDArray. This is a read-only view of the values array. + They reveal internal implementation details and should be used with care. + + Returns + ------- + NDArray + This SparseNDArray's values array. + """ + return self._data() + + @property + def _num_aux(self): + ''' The number of aux data used to help store the sparse ndarray. + ''' + return len(_STORAGE_AUX_TYPES[self.stype]) + + @property + # pylint: disable= invalid-name, undefined-variable + def T(self): + raise Exception('Transpose is not supported for SparseNDArray.') + + @property + def _aux_types(self): + """The data types of the aux data for the SparseNDArray. + """ + aux_types = [] + num_aux = self._num_aux + for i in range(num_aux): + aux_types.append(self._aux_type(i)) + return aux_types + + def asnumpy(self): + """Return a dense ``numpy.ndarray`` object with value copied from this array + + """ + return self.todense().asnumpy() + + def astype(self, dtype): + """Returns a copy of the array after casting to a specified type. + Parameters + ---------- + dtype : numpy.dtype or str + The type of the returned array. + Examples + -------- + >>> x = mx.nd.zeros('row_sparse', (2,3), dtype='float32') + >>> y = x.astype('int32') + >>> y.dtype + + """ + res = _zeros_sparse_ndarray(shape=self.shape, ctx=self.context, + dtype=dtype, stype=self.stype) + self.copyto(res) + return res + + def copyto(self, other): + """Copies the value of this array to another array. + + If ``other`` is a ``NDArray`` object, then ``other.shape`` and + ``self.shape`` should be the same. This function copies the value from + ``self`` to ``other``. + + If ``other`` is a context, a new ``NDArray`` will be first created on + the target context, and the value of ``self`` is copied. + + Parameters + ---------- + other : NDArray or Context + The destination array or context. + + Returns + ------- + NDArray + The copied array. If ``other`` is an ``NDArray``, then the return value + and ``other`` will point to the same ``NDArray``. + """ + if isinstance(other, NDArray): + if other.handle is self.handle: + warnings.warn('You are attempting to copy an array to itself', RuntimeWarning) + return + return _internal._copyto(self, out=other) + elif isinstance(other, Context): + hret = _ndarray_cls(_new_alloc_handle(self.stype, self.shape, other, + True, self.dtype, self._aux_types)) + return _internal._copyto(self, out=hret) + else: + raise TypeError('copyto does not support type ' + str(type(other))) + + def todense(self): + return todense(self) + + def _aux_data(self, i, writable=False): + """ Get an NDArray referencing the ith aux data array associated with the SparseNDArray. + """ + self.wait_to_read() + hdl = NDArrayHandle() + check_call(_LIB.MXNDArrayGetAuxNDArray(self.handle, i, ctypes.byref(hdl))) + return NDArray(hdl, writable) + + def _data(self, writable=False): + """ Get an NDArray referencing the value array associated with the SparseNDArray. + """ + self.wait_to_read() + hdl = NDArrayHandle() + check_call(_LIB.MXNDArrayGetDataNDArray(self.handle, ctypes.byref(hdl))) + return NDArray(hdl, writable) + +# pylint: disable=abstract-method +class CSRNDArray(SparseNDArray): + """A CSRNDArray represents a NDArray as three separate arrays: `values`, + `indptr` and `indices`. It uses the standard CSR representation where the column indices for + row i are stored in indices[indptr[i]:indptr[i+1]] and their corresponding values are stored + in values[indptr[i]:indptr[i+1]]. + + Example + ------- + >>> a = mx.nd.array([[0, 1, 0], [2, 0, 0], [0, 0, 0], [0, 0, 3]]) + >>> a = a._to_csr() + >>> a.indices.asnumpy() + array([1, 0, 2]) + >>> a.indptr.asnumpy() + array([0, 1, 2, 2, 3]) + >>> a.data.asnumpy() + array([ 1., 2., 3.], dtype=float32) + """ + + def __reduce__(self): + return CSRNDArray, (None,), super(CSRNDArray, self).__getstate__() + + @property + def indices(self): + """The indices array of the SparseNDArray. This is a read-only view of the indices array. + They reveal internal implementation details and should be used with care. + + Returns + ------- + NDArray + This SparseNDArray's indices array. + """ + return self._aux_data(1) + + @property + def indptr(self): + """The indptr array of the SparseNDArray with `csr` storage type. + This is a read-only view of the indptr array. + They reveal internal implementation details and should be used with care. + + Returns + ------- + NDArray + This SparseNDArray's indptr array. + """ + return self._aux_data(0) + +# pylint: disable=abstract-method +class RowSparseNDArray(SparseNDArray): + """A RowSparseNDArray is typically used to represent a subset of a larger + NDArray with `default` of shape [LARGE0, D1, .. , DN] where LARGE0 >> D0. The values + in indices are the indices in the first dimension of the slices that have been extracted from + the larger NDArray. The indices are expected to be sorted in ascending order. + + The corresponding NDArray ``dense`` with `default` storage represented by a ``rsp`` + RowSparseNDArray + + ``dense[rsp.indices[i], :, :, :, ...] = rsp.values[i, :, :, :, ...]`` + + RowSparseNDArray is used principally in the definition of gradients for operations + that have sparse gradients (e.g. SparseEmbedding). + + Examples + -------- + >>> import mxnet as mx + >>> dense = mx.nd.array([[1,2],[0,0],[3,0],[0,0]]) + >>> rsp = dense._to_rsp() + >>> rsp.indices.asnumpy() + array([0, 2], dtype=int32) + >>> rsp.data.asnumpy() + array([[ 1., 2.], + [ 3., 0.]], dtype=float32) + """ + def __reduce__(self): + return RowSparseNDArray, (None,), super(RowSparseNDArray, self).__getstate__() + + @property + def indices(self): + """The indices array of the SparseNDArray. This is a read-only view of the indices array. + They reveal internal implementation details and should be used with care. + + Returns + ------- + NDArray + This SparseNDArray's indices array. + """ + return self._aux_data(0) + + +def _prepare_src_array(src, dtype, default_dtype): + if isinstance(src, NDArray): + dtype = src.dtype if dtype is None else dtype + else: + dtype = default_dtype if dtype is None else dtype + if not isinstance(src, np.ndarray): + try: + src = np.array(src, dtype=dtype) + except: + raise TypeError('values must be array like object') + return src, dtype + + +def csr(data, indptr, indices, shape, ctx=None, dtype=None, indptr_type=None, indices_type=None): + """Creates a 2D array with compressed sparse row format. + + Parameters + ---------- + data: array_like + An object exposing the array interface, with shape [nnz], where D0 is the number of + non-zero entries. + indptr: array_like + An object exposing the array interface, with shape [D0 + 1]. The first element in indptr + should always be zero. + indices: array_like + An object exposing the array interface, with shape [nnz]. + ctx: Context, optional + Device context (default is the current default context). + dtype: str or numpy.dtype, optional + The data type of the output array. The default dtype is ``values.dtype`` + if `values` is an `NDArray`, `float32` otherwise. + indptr_type: str or numpy.dtype, optional + The data type of the indices array. The default dtype is ``indptr.dtype`` + if `indptr` is an `NDArray`, `int32` otherwise. + indices_type: str or numpy.dtype, optional + The data type of the indices array. The default dtype is ``indices.dtype`` + if `indicies` is an `NDArray`, `int32` otherwise. + + Returns + ------- + CSRNDArray + A `CSRNDArray` with the `csr` storage representation. + + Example + ------- + >>> import mxnet as mx + >>> a = mx.nd.csr([1, 2, 3], [0, 1, 2, 2, 3], [1, 0, 2], (4, 3)) + >>> a.asnumpy() + array([[ 0., 1., 0.], + [ 2., 0., 0.], + [ 0., 0., 0.], + [ 0., 0., 3.]], dtype=float32) + """ + storage_type = 'csr' + # context + if ctx is None: + ctx = Context.default_ctx + # prepare src array and types + data, dtype = _prepare_src_array(data, dtype, mx_real_t) + indptr, indptr_type = _prepare_src_array(indptr, indptr_type, + _STORAGE_AUX_TYPES[storage_type][0]) + indices, indices_type = _prepare_src_array(indices, indices_type, + _STORAGE_AUX_TYPES[storage_type][1]) + # verify types + assert('int64' in str(indptr_type)), "expected int64 for indptr" + assert('int64' in str(indices_type)), "expected int64 for indices" + # verify shapes + aux_shapes = [indptr.shape, indices.shape] + assert(data.ndim == 1) + assert(indptr.ndim == 1) + assert(indices.ndim == 1) + assert(len(shape) == 2) + result = CSRNDArray(_new_alloc_handle(storage_type, shape, ctx, False, dtype, + [indptr_type, indices_type], aux_shapes)) + # assign indptr, indices and data + data_ref = result._data(True) + indptr_ref = result._aux_data(0, True) + indices_ref = result._aux_data(1, True) + data_ref[:] = data + indptr_ref[:] = indptr + indices_ref[:] = indices + return result + + +def row_sparse(values, indices, shape, ctx=None, dtype=None, indices_type=None): + """Creates a row sparse array with a set of tensor slices at given indices. + + Parameters + ---------- + values: array_like + An object exposing the array interface, with shape [D0, D1, .. Dn], where D0 is + the number of rows with non-zeros entries. + indices: array_like + An object exposing the array interface, with shape [D0]. + ctx : Context, optional + Device context (default is the current default context). + dtype : str or numpy.dtype, optional + The data type of the output array. The default dtype is ``values.dtype`` + if `values` is an `NDArray`, `float32` otherwise. + indices_type: str or numpy.dtype, optional + The data type of the indices array. The default dtype is ``indices.dtype`` + if `indicies` is an `NDArray`, `int32` otherwise. + + Returns + ------- + RowSparseNDArray + An `RowSparseNDArray` with the `row_sparse` storage representation. + + Example + ------- + >>> a = mx.nd.row_sparse([[1, 2], [3, 4]], [1, 4], (6, 2)) + >>> a.asnumpy() + array([[ 0., 0.], + [ 1., 2.], + [ 0., 0.], + [ 0., 0.], + [ 3., 4.], + [ 0., 0.]], dtype=float32) + """ + storage_type = 'row_sparse' + # context + if ctx is None: + ctx = Context.default_ctx + # prepare src array and types + values, dtype = _prepare_src_array(values, dtype, mx_real_t) + indices, indices_type = _prepare_src_array(indices, indices_type, + _STORAGE_AUX_TYPES[storage_type][0]) + # verify types + assert('int64' in str(indices_type)), "expected int64 for indices" + # verify shapes + assert(values.ndim == len(shape)) + assert(indices.ndim == 1) + result = RowSparseNDArray(_new_alloc_handle(storage_type, shape, ctx, False, dtype, + [indices_type], [indices.shape])) + # assign indices and values + values_ref = result._data(True) + indices_ref = result._aux_data(0, True) + values_ref[:] = values + indices_ref[:] = indices + return result + + +def todense(source): + """ Return a dense array representation of this SparseNDArray. + + Returns + ------- + NDArray + The dense array with default storage + """ + return cast_storage(source, stype='default') + + +def _ndarray_cls(handle, writable=True, stype=None): + if stype is None: + stype = _storage_type(handle) + if stype == 'default': + return NDArray(handle, writable=writable) + elif stype == 'csr': + return CSRNDArray(handle, writable=writable) + elif stype == 'row_sparse': + return RowSparseNDArray(handle, writable=writable) + else: + raise Exception("unknown storage type") + + +_set_ndarray_class(_ndarray_cls) + + +def _zeros_sparse_ndarray(stype, shape, ctx=None, dtype=None, aux_types=None, **kwargs): + """Return a new array of given shape and type, filled with zeros. + + Parameters + ---------- + shape : int or tuple of int + The shape of the empty array + stype: string + The storage type of the empty array, such as 'row_sparse', 'csr', etc + ctx : Context, optional + An optional device context (default is the current default context) + dtype : str or numpy.dtype, optional + An optional value type (default is `float32`) + aux_types: list of numpy.dtype, optional + An optional type for the aux data for SparseNDArray (default values depends + on the storage type) + + Returns + ------- + SparseNDArray + A created array + Examples + -------- + >>> mx.nd.zeros('csr', (1,2), mx.gpu(0)) + + >>> mx.nd.zeros('row_sparse', (1,2), mx.gpu(0), 'float16').asnumpy() + array([[ 0., 0.]], dtype=float16) + """ + if stype == 'default': + return _zeros_ndarray(shape, ctx=ctx, dtype=dtype, **kwargs) + if ctx is None: + ctx = Context.default_ctx + dtype = mx_real_t if dtype is None else dtype + if aux_types is None: + if stype == 'row_sparse' or stype == 'csr': + aux_types = _STORAGE_AUX_TYPES[stype] + else: + raise Exception("unknown storage type") + assert(len(aux_types) == len(_STORAGE_AUX_TYPES[stype])) + out = _ndarray_cls(_new_alloc_handle(stype, shape, ctx, True, dtype, aux_types)) + return _internal._zeros(shape=shape, ctx=ctx, dtype=dtype, out=out, **kwargs) diff --git a/python/mxnet/optimizer.py b/python/mxnet/optimizer.py index 57fadf44335e..cc6542b4d61f 100644 --- a/python/mxnet/optimizer.py +++ b/python/mxnet/optimizer.py @@ -342,9 +342,11 @@ def create_state(self, index, weight): momentum = None weight_master_copy = None if self.multi_precision and weight.dtype == numpy.float16: + assert(weight.stype == 'default'), \ + "multi-precision doesn't supprot non-default weight yet" weight_master_copy = array(weight, ctx=weight.context, dtype=numpy.float32) if self.momentum != 0.0: - momentum = zeros(weight.shape, weight.context, dtype=numpy.float32) + momentum = zeros(weight.shape, weight.context, dtype=numpy.float32, stype=weight.stype) return (momentum, weight_master_copy) if weight.dtype == numpy.float16 and not self.multi_precision: warnings.warn("Accumulating with float16 in optimizer can lead to " @@ -352,7 +354,7 @@ def create_state(self, index, weight): "Consider using multi_precision=True option of the " "SGD optimizer") if self.momentum != 0.0: - momentum = zeros(weight.shape, weight.context, dtype=weight.dtype) + momentum = zeros(weight.shape, weight.context, dtype=weight.dtype, stype=weight.stype) return momentum def update(self, index, weight, grad, state): @@ -650,7 +652,7 @@ def create_state(self, index, weight): zeros(weight.shape, weight.context), # g zeros(weight.shape, weight.context)) # delta else: - return (zeros(weight.shape, weight.context), ) # n + return (zeros(weight.shape, weight.context),) # n def update(self, index, weight, grad, state): assert(isinstance(weight, NDArray)) diff --git a/python/mxnet/random.py b/python/mxnet/random.py index 91c2f5035ffa..5707632c83c1 100644 --- a/python/mxnet/random.py +++ b/python/mxnet/random.py @@ -5,13 +5,13 @@ import ctypes from .base import _LIB, check_call -from ._ndarray_internal import _sample_uniform as uniform -from ._ndarray_internal import _sample_normal as normal -from ._ndarray_internal import _sample_gamma as gamma -from ._ndarray_internal import _sample_exponential as exponential -from ._ndarray_internal import _sample_poisson as poisson -from ._ndarray_internal import _sample_negbinomial as negative_binomial -from ._ndarray_internal import _sample_gennegbinomial as generalized_negative_binomial +from .ndarray._internal import _sample_uniform as uniform +from .ndarray._internal import _sample_normal as normal +from .ndarray._internal import _sample_gamma as gamma +from .ndarray._internal import _sample_exponential as exponential +from .ndarray._internal import _sample_poisson as poisson +from .ndarray._internal import _sample_negbinomial as negative_binomial +from .ndarray._internal import _sample_gennegbinomial as generalized_negative_binomial def seed(seed_state): """Seeds the random number generators in MXNet. diff --git a/python/mxnet/symbol.py b/python/mxnet/symbol.py index 4a9a3f4550c8..e55d654039e7 100644 --- a/python/mxnet/symbol.py +++ b/python/mxnet/symbol.py @@ -17,8 +17,10 @@ from .base import NDArrayHandle, ExecutorHandle, SymbolHandle, OpHandle from .base import check_call, MXNetError, NotImplementedForSymbol, _Null # pylint: disable=unused-import from .context import Context -from .ndarray import NDArray, _DTYPE_NP_TO_MX, _DTYPE_MX_TO_NP, _GRAD_REQ_MAP +from .ndarray.ndarray import NDArray, _DTYPE_NP_TO_MX, _DTYPE_MX_TO_NP, _GRAD_REQ_MAP from .name import NameManager # pylint: disable=unused-import +from .ndarray.ndarray import _STORAGE_TYPE_STR_TO_ID +from .ndarray.sparse_ndarray import _ndarray_cls from .executor import Executor from . import _symbol_internal as _internal from .attribute import AttrScope @@ -1234,8 +1236,9 @@ def _get_ndarray_inputs(arg_key, args, arg_names, allow_missing): raise TypeError('Only accept list of NDArrays or dict of str to NDArray') return c_array(NDArrayHandle, arg_handles), arg_arrays - def simple_bind(self, ctx, grad_req='write', type_dict=None, group2ctx=None, - shared_arg_names=None, shared_exec=None, shared_buffer=None, **kwargs): + def simple_bind(self, ctx, grad_req='write', type_dict=None, stype_dict=None, + group2ctx=None, shared_arg_names=None, shared_exec=None, + shared_buffer=None, **kwargs): """Bind current symbol to get an executor, allocate all the arguments needed. Allows specifying data types. @@ -1277,6 +1280,9 @@ def simple_bind(self, ctx, grad_req='write', type_dict=None, group2ctx=None, type_dict : Dict of str->numpy.dtype Input type dictionary, name->dtype + stype_dict : Dict of str->str + Input storage type dictionary, name->storage_type + group2ctx : Dict of string to mx.Context The dict mapping the `ctx_group` attribute to the context assignment. @@ -1291,7 +1297,8 @@ def simple_bind(self, ctx, grad_req='write', type_dict=None, group2ctx=None, shared_buffer : Dict of string to `NDArray` The dict mapping argument names to the `NDArray` that can be reused for initializing the current executor. This buffer will be checked for reuse if one argument name - of the current executor is not found in `shared_arg_names`. + of the current executor is not found in `shared_arg_names`. The `NDArray`s are + expected have default storage type. kwargs : Dict of str->shape Input shape dictionary, name->shape @@ -1301,6 +1308,7 @@ def simple_bind(self, ctx, grad_req='write', type_dict=None, group2ctx=None, executor : mxnet.Executor The generated executor """ + # data types num_provided_arg_types = 0 provided_arg_type_names = ctypes.POINTER(ctypes.c_char_p)() # provided type argument names provided_arg_type_data = ctypes.POINTER(mx_uint)() # provided types @@ -1316,6 +1324,22 @@ def simple_bind(self, ctx, grad_req='write', type_dict=None, group2ctx=None, provided_arg_type_names = c_array(ctypes.c_char_p, provided_arg_type_names) provided_arg_type_data = c_array(ctypes.c_int, provided_arg_type_data) + # storage types + num_provided_arg_stypes = 0 + # provided storage type argument names + provided_arg_stype_names = ctypes.POINTER(ctypes.c_char_p)() + provided_arg_stype_data = ctypes.POINTER(mx_uint)() # provided storage types + if stype_dict is not None: + provided_arg_stype_names = [] + provided_arg_stype_data = [] + for k, v in stype_dict.items(): + if v in _STORAGE_TYPE_STR_TO_ID: + provided_arg_stype_names.append(c_str(k)) + provided_arg_stype_data.append(ctypes.c_int(_STORAGE_TYPE_STR_TO_ID[v])) + num_provided_arg_stypes = mx_uint(len(provided_arg_stype_names)) + provided_arg_stype_names = c_array(ctypes.c_char_p, provided_arg_stype_names) + provided_arg_stype_data = c_array(ctypes.c_int, provided_arg_stype_data) + provided_arg_shape_data = [] # shape data # argument shape index in sdata, # e.g. [sdata[indptr[0]], sdata[indptr[1]]) is the shape of the first arg @@ -1389,6 +1413,8 @@ def simple_bind(self, ctx, grad_req='write', type_dict=None, group2ctx=None, shared_buffer_names = [] shared_buffer_handles = [] for k, v in shared_buffer.items(): + assert(v.stype == 'default'), \ + "shared_buffer is expected to only contain NDArrays with default storage" shared_buffer_names.append(c_str(k)) shared_buffer_handles.append(v.handle) shared_buffer_names = c_array(ctypes.c_char_p, shared_buffer_names) @@ -1428,6 +1454,9 @@ def simple_bind(self, ctx, grad_req='write', type_dict=None, group2ctx=None, num_provided_arg_types, provided_arg_type_names, provided_arg_type_data, + num_provided_arg_stypes, + provided_arg_stype_names, + provided_arg_stype_data, mx_uint(len(shared_arg_name_list)), c_array(ctypes.c_char_p, shared_arg_name_list), ctypes.byref(shared_buffer_len), @@ -1457,11 +1486,12 @@ def simple_bind(self, ctx, grad_req='write', type_dict=None, group2ctx=None, shared_buffer[k] = v # create in_args, arg_grads, and aux_states for the current executor - arg_arrays = [NDArray(NDArrayHandle(in_arg_handles[i])) for i in range(num_in_args.value)] - grad_arrays = [NDArray(NDArrayHandle(arg_grad_handles[i])) + arg_arrays = [_ndarray_cls(NDArrayHandle(in_arg_handles[i])) \ + for i in range(num_in_args.value)] + grad_arrays = [_ndarray_cls(NDArrayHandle(arg_grad_handles[i])) if arg_grad_handles[i] is not None else None for i in range(num_in_args.value)] - aux_arrays = [NDArray(NDArrayHandle(aux_state_handles[i])) + aux_arrays = [_ndarray_cls(NDArrayHandle(aux_state_handles[i])) for i in range(num_aux_states.value)] executor = Executor(exe_handle, self, ctx, grad_req, group2ctx) @@ -1738,7 +1768,8 @@ def detach(self): def backward(self): raise NotImplementedForSymbol(self.backward, None) -def var(name, attr=None, shape=None, lr_mult=None, wd_mult=None, dtype=None, init=None, **kwargs): +def var(name, attr=None, shape=None, lr_mult=None, wd_mult=None, dtype=None, + init=None, stype=None, **kwargs): """Creates a symbolic variable with specified name. Example usage: @@ -1765,6 +1796,8 @@ def var(name, attr=None, shape=None, lr_mult=None, wd_mult=None, dtype=None, ini The dtype for input variable. If not specified, this value will be inferred. init : initializer (mxnet.init.*) Initializer for this variable to (optionally) override the default initializer. + stype : str + The storage type of the variable. kwargs : Additional attribute variables Additional attributes must start and end with double underscores. @@ -1792,6 +1825,8 @@ def var(name, attr=None, shape=None, lr_mult=None, wd_mult=None, dtype=None, ini if not isinstance(init, string_types): init = init.dumps() attr['__init__'] = init + if stype is not None: + attr['__storage_type__'] = str(_STORAGE_TYPE_STR_TO_ID[stype]) for k, v in kwargs.items(): if k.startswith('__') and k.endswith('__'): attr[k] = str(v) diff --git a/python/mxnet/test_utils.py b/python/mxnet/test_utils.py index 0666e46d930f..542a11ad4e22 100644 --- a/python/mxnet/test_utils.py +++ b/python/mxnet/test_utils.py @@ -10,17 +10,19 @@ import os import errno import logging +import scipy.sparse as sp import numpy as np import numpy.testing as npt -import mxnet as mx -from .context import Context -from .ndarray import array -from .symbol import Symbol +import numpy.random as rnd try: import requests except ImportError: # in rare cases requests may be not installed pass +import mxnet as mx +from .context import Context +from .ndarray.ndarray import array, _STORAGE_TYPE_STR_TO_ID +from .symbol import Symbol _rng = np.random.RandomState(1234) @@ -66,6 +68,53 @@ def random_arrays(*shapes): return arrays +def random_sample(population, k): + """Return a k length list of the elements chosen from the population sequence.""" + assert 0 <= k <= len(population) + population_copy = population[:] + np.random.shuffle(population_copy) + return population_copy[0:k] + + +def rand_sparse_ndarray(shape, stype, density=None): + """Generate a random sparse ndarray. Returns the ndarray, value(np) and indices(np) """ + density = rnd.rand() if density is None else density + if stype == 'row_sparse': + # TODO(haibin) support high dim sparse ndarray + assert(len(shape) < 3) + prod = np.prod(shape) + num_cols = int(prod / shape[0]) + # sample index + idx_sample = rnd.rand(shape[0]) + indices = np.argwhere(idx_sample < density).flatten() + if indices.shape[0] == 0: + result = mx.nd.zeros(shape, stype='row_sparse') + return result, (np.array([], dtype='int64'), np.array([], dtype='int64')) + # generate random values + val = rnd.rand(indices.shape[0], num_cols) + arr = mx.nd.row_sparse(val, indices, shape, indices_type=np.int64) + return arr, (val, indices) + elif stype == 'csr': + assert(len(shape) == 2) + csr = sp.rand(shape[0], shape[1], density=density, format='csr') + result = mx.nd.csr(csr.data, csr.indptr, csr.indices, shape) + return result, (csr.indptr, csr.indices, csr.data) + else: + assert(False), "unknown storage type" + + +def rand_ndarray(shape, stype, density=None): + if stype == 'default': + arr = mx.nd.array(random_arrays(shape)) + else: + arr, _ = rand_sparse_ndarray(shape, stype, density=density) + return arr + + +def rand_shape_2d(dim0=10, dim1=10): + return rnd.randint(1, dim0), rnd.randint(1, dim1) + + def np_reduce(dat, axis, keepdims, numpy_reduce_func): """Compatible reduce for old version of NumPy. @@ -297,7 +346,8 @@ def _parse_location(sym, location, ctx): % (str(set(sym.list_arguments())), str(set(location.keys())))) else: location = {k: v for k, v in zip(sym.list_arguments(), location)} - location = {k: mx.nd.array(v, ctx=ctx) for k, v in location.items()} + location = {k: mx.nd.array(v, ctx=ctx) if isinstance(v, np.ndarray) \ + else v for k, v in location.items()} return location @@ -418,7 +468,8 @@ def numeric_grad(executor, location, aux_states=None, eps=1e-4, use_forward_trai def check_numeric_gradient(sym, location, aux_states=None, numeric_eps=1e-3, rtol=1e-2, - atol=None, grad_nodes=None, use_forward_train=True, ctx=None): + atol=None, grad_nodes=None, use_forward_train=True, ctx=None, + grad_stype_dict=None): """Verify an operation by checking backward pass via finite difference method. Based on Theano's `theano.gradient.verify_grad` [1] @@ -435,7 +486,7 @@ def check_numeric_gradient(sym, location, aux_states=None, numeric_eps=1e-3, rto - if type is dict of str -> numpy.ndarray maps the name of arguments to the corresponding numpy.ndarray. *In either case, value of all the arguments must be provided.* - aux_states : ist or tuple or dict, optional + aux_states : list or tuple or dict, optional The auxiliary states required when generating the executor for the symbol. numeric_eps : float, optional Delta for the finite difference method that approximates the gradient. @@ -447,6 +498,8 @@ def check_numeric_gradient(sym, location, aux_states=None, numeric_eps=1e-3, rto Whether to use is_train=True when computing the finite-difference. ctx : Context, optional Check the gradient computation on the specified device. + grad_stype_dict : dict of str->str, optional + Storage type dictionary for gradient ndarrays. References --------- ..[1] https://github.com/Theano/Theano/blob/master/theano/gradient.py @@ -470,7 +523,7 @@ def random_projection(shape): location_npy = {k:v.asnumpy() for k, v in location.items()} aux_states = _parse_aux_states(sym=sym, aux_states=aux_states, ctx=ctx) if aux_states is not None: - aux_states_npy = {k:v.asnumpy() for k, v in aux_states.items()} + aux_states_npy = {k: v.asnumpy() for k, v in aux_states.items()} else: aux_states_npy = None if grad_nodes is None: @@ -497,6 +550,11 @@ def random_projection(shape): + [("__random_proj", _rng.normal(0, 0.01, size=out_shape[0]))]) args_grad = {k: mx.nd.array(v, ctx=ctx) for k, v in args_grad_npy.items()} + if grad_stype_dict is not None: + assert isinstance(grad_stype_dict, dict), "grad_stype_dict must be a dict" + for k, v in grad_stype_dict.items(): + if k in args_grad and v in _STORAGE_TYPE_STR_TO_ID and v != 'default': + args_grad[k] = mx.nd.cast_storage(args_grad[k], stype=v) executor = out.bind(ctx, grad_req=grad_req, args=location, args_grad=args_grad, aux_states=aux_states) @@ -588,15 +646,15 @@ def check_symbolic_forward(sym, location, expected, rtol=1E-4, atol=None, g[:] = 0 executor.forward(is_train=False) - outputs = [x.asnumpy() for x in executor.outputs] + outputs = [x.asnumpy() for x in executor.outputs] for output_name, expect, output in zip(sym.list_outputs(), expected, outputs): assert_almost_equal(expect, output, rtol, atol, ("EXPECTED_%s"%output_name, "FORWARD_%s"%output_name)) def check_symbolic_backward(sym, location, out_grads, expected, rtol=1e-5, atol=None, - aux_states=None, grad_req='write', ctx=None): + aux_states=None, grad_req='write', ctx=None, grad_stypes=None): """Compares a symbol's backward results with the expected ones. Prints error messages if the backward results are not the same as the expected results. @@ -632,6 +690,8 @@ def check_symbolic_backward(sym, location, out_grads, expected, rtol=1e-5, atol= Gradient requirements. 'write', 'add' or 'null'. ctx : Context, optional Running context. + grad_stypes: dict of str->str + dictionary of mapping argument name to stype for the gradient Example ------- @@ -657,14 +717,24 @@ def check_symbolic_backward(sym, location, out_grads, expected, rtol=1e-5, atol= if isinstance(expected, (list, tuple)): expected = {k:v for k, v in zip(sym.list_arguments(), expected)} args_grad_npy = {k:_rng.normal(size=v.shape) for k, v in expected.items()} - args_grad_data = {k: mx.nd.array(v, ctx=ctx) for k, v in args_grad_npy.items()} + args_grad_data = {} + for k, v in args_grad_npy.items(): + nd = mx.nd.array(v, ctx=ctx) + if grad_stypes is not None and k in grad_stypes: + out = mx.nd.cast_storage(nd, stype=grad_stypes[k]) + args_grad_data[k] = out + else: + args_grad_data[k] = nd + if isinstance(grad_req, str): grad_req = {k:grad_req for k in sym.list_arguments()} elif isinstance(grad_req, (list, tuple)): grad_req = {k:v for k, v in zip(sym.list_arguments(), grad_req)} - executor = sym.bind(ctx=ctx, args=location, args_grad=args_grad_data, aux_states=aux_states) + executor = sym.bind(ctx=ctx, args=location, args_grad=args_grad_data, + aux_states=aux_states, grad_req=grad_req) executor.forward(is_train=True) + if isinstance(out_grads, (tuple, list)): out_grads = [mx.nd.array(v, ctx=ctx) for v in out_grads] elif isinstance(out_grads, (dict)): diff --git a/python/setup.py b/python/setup.py index 8a8693038b3c..1f3abf536f86 100644 --- a/python/setup.py +++ b/python/setup.py @@ -74,7 +74,7 @@ def config_cython(): version=__version__, description=open(os.path.join(CURRENT_DIR, 'README.md')).read(), packages=[ - 'mxnet', 'mxnet.module', 'mxnet._ctypes', 'mxnet.rnn', + 'mxnet', 'mxnet.module', 'mxnet._ctypes', 'mxnet.rnn', 'mxnet.ndarray', 'mxnet._cy2', 'mxnet._cy3', 'mxnet.notebook', 'mxnet.contrib' ], data_files=[('mxnet', [LIB_PATH[0]])], diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index 214e6ede5292..20f3e60cd371 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -154,6 +154,39 @@ int MXNDArrayCreateEx(const mx_uint *shape, API_END(); } +int MXNDArrayCreateSparseEx(int storage_type, + const mx_uint *shape, + mx_uint ndim, + int dev_type, + int dev_id, + int delay_alloc, + int dtype, + mx_uint num_aux, + int *aux_type, + mx_uint *aux_ndims, + const mx_uint *aux_shape, + NDArrayHandle *out) { + API_BEGIN(); + std::vector aux_types; + std::vector aux_shapes; + auto shape_start = aux_shape; + for (size_t i = 0; i < num_aux; i++) { + // types + aux_types.push_back(aux_type[i]); + // shapes + aux_shapes.emplace_back(shape_start, shape_start + aux_ndims[i]); + shape_start += aux_ndims[i]; + } + *out = new NDArray( + NDArrayStorageType(storage_type), + TShape(shape, shape + ndim), + Context::Create(static_cast(dev_type), dev_id), + delay_alloc != 0, + dtype, aux_types, aux_shapes); + API_END(); +} + + int MXNDArrayLoadFromRawBytes(const void *buf, size_t size, NDArrayHandle *out) { @@ -333,6 +366,18 @@ MXNET_DLL int MXNDArrayReshape(NDArrayHandle handle, API_END_HANDLE_ERROR(delete ptr); } +int MXNDArrayGetStorageType(NDArrayHandle handle, + int *out_storage_type) { + API_BEGIN(); + NDArray *arr = static_cast(handle); + if (!arr->is_none()) { + *out_storage_type = arr->storage_type(); + } else { + *out_storage_type = kUndefinedStorage; + } + API_END(); +} + int MXNDArrayGetShape(NDArrayHandle handle, mx_uint *out_dim, const mx_uint **out_pdata) { @@ -382,6 +427,32 @@ int MXNDArrayGetDType(NDArrayHandle handle, API_END(); } +int MXNDArrayGetAuxType(NDArrayHandle handle, + mx_uint i, + int *out_type) { + API_BEGIN(); + NDArray *arr = static_cast(handle); + *out_type = arr->aux_type(i); + API_END(); +} + +int MXNDArrayGetAuxNDArray(NDArrayHandle handle, + mx_uint i, + NDArrayHandle *out) { + API_BEGIN(); + NDArray *arr = static_cast(handle); + *out = new NDArray(arr->aux_ndarray(i)); + API_END(); +} + +int MXNDArrayGetDataNDArray(NDArrayHandle handle, + NDArrayHandle *out) { + API_BEGIN(); + NDArray *arr = static_cast(handle); + *out = new NDArray(arr->data_ndarray()); + API_END(); +} + int MXNDArrayGetContext(NDArrayHandle handle, int *out_dev_type, int *out_dev_id) { diff --git a/src/c_api/c_api_common.h b/src/c_api/c_api_common.h index d8857f80635d..f2cad238a71b 100644 --- a/src/c_api/c_api_common.h +++ b/src/c_api/c_api_common.h @@ -58,6 +58,8 @@ struct MXAPIThreadLocalEntry { std::vector arg_shapes, out_shapes, aux_shapes; /*! \brief result holder for returning type flags */ std::vector arg_types, out_types, aux_types; + /*! \brief result holder for returning storage types */ + std::vector arg_storage_types, out_storage_types, aux_storage_types; /*! \brief result holder for returning shape dimensions */ std::vector arg_shape_ndim, out_shape_ndim, aux_shape_ndim; /*! \brief result holder for returning shape pointer */ diff --git a/src/c_api/c_api_executor.cc b/src/c_api/c_api_executor.cc index ca49402ecf7e..d9beb410e929 100644 --- a/src/c_api/c_api_executor.cc +++ b/src/c_api/c_api_executor.cc @@ -173,6 +173,9 @@ int MXExecutorBindEX(SymbolHandle symbol_handle, * \param num_provided_arg_dtypes number of user provided in_arg and axu_state dtypes * \param provided_arg_dtype_names argument name list of provided dtypes * \param provided_arg_dtypes data of provided dtypes + * \param num_provided_arg_stypes number of user provided in_arg and axu_state storage types + * \param provided_arg_stype_names argument name list of provided storage types + * \param provided_arg_stypes data of provided storage types * \param num_shared_arg_names number of parameter names passed from _bind_ith_exec * \param shared_arg_name_list parameter name list passed from _bind_ith_exec * \param shared_buffer_len number of shared data arrays passed from _bind_ith_exec @@ -205,6 +208,9 @@ int MXExecutorSimpleBind(SymbolHandle symbol_handle, const mx_uint num_provided_arg_dtypes, const char** provided_arg_dtype_names, const int* provided_arg_dtypes, + const mx_uint num_provided_arg_stypes, + const char** provided_arg_stype_names, + const int* provided_arg_stypes, const mx_uint num_shared_arg_names, const char** shared_arg_name_list, int* shared_buffer_len, @@ -229,7 +235,7 @@ int MXExecutorSimpleBind(SymbolHandle symbol_handle, // attr_dict for setting up type_dict and arg/aux ctx std::unordered_map> attr_dict; - if (nullptr == provided_arg_dtypes || nullptr != g2c_keys) { + if (nullptr == provided_arg_dtypes || nullptr != g2c_keys || nullptr == provided_arg_stypes) { std::vector> attrs = sym->ListAttrsRecursive(); attr_dict.reserve(attrs.size()); @@ -255,6 +261,23 @@ int MXExecutorSimpleBind(SymbolHandle symbol_handle, } } + // setup arg_stype_map + std::unordered_map arg_stype_map; + if (nullptr == provided_arg_stypes) { // use attr_dict + for (const auto& arg_name : in_arg_names) { + const auto it = attr_dict.find(arg_name); + if (it == attr_dict.end() || !it->second.count("__storage_type__")) { + arg_stype_map[arg_name] = kDefaultStorage; + } + } + } else { // use user input type_dict + // create stype map for in_args and aux_states + arg_stype_map.reserve(num_provided_arg_stypes); + for (mx_uint i = 0; i < num_provided_arg_stypes; ++i) { + arg_stype_map[provided_arg_stype_names[i]] = provided_arg_stypes[i]; + } + } + // create default ctx Context ctx = Context::Create(static_cast(dev_type), dev_id); // create ctx map @@ -395,9 +418,10 @@ int MXExecutorSimpleBind(SymbolHandle symbol_handle, std::vector aux_state_vec; *out = Executor::SimpleBind(*sym, ctx, ctx_map, in_arg_ctx_vec, arg_grad_ctx_vec, - aux_state_ctx_vec, arg_shape_map, arg_dtype_map, grad_req_type_vec, - shared_arg_name_set, &in_arg_vec, &arg_grad_vec, &aux_state_vec, - use_shared_buffer? &shared_buffer_map : nullptr, + aux_state_ctx_vec, arg_shape_map, arg_dtype_map, arg_stype_map, + grad_req_type_vec, shared_arg_name_set, &in_arg_vec, + &arg_grad_vec, &aux_state_vec, + use_shared_buffer ? &shared_buffer_map : nullptr, reinterpret_cast(shared_exec_handle)); // copy ndarray ptrs to ret->handles so that front end diff --git a/src/c_api/c_api_ndarray.cc b/src/c_api/c_api_ndarray.cc index 818f263cb3b7..452b06fe0634 100644 --- a/src/c_api/c_api_ndarray.cc +++ b/src/c_api/c_api_ndarray.cc @@ -1,6 +1,6 @@ /*! * Copyright (c) 2016 by Contributors - * \file c_api_symbolic.cc + * \file c_api_ndarray.cc * \brief C API of mxnet */ @@ -132,14 +132,17 @@ void SetContext(Context* p_ctx, #endif // MXNET_USE_CUDA } +// Set the shape, dtype and storage type void SetShapeType(const nnvm::Op* op, const nnvm::NodeAttrs& attrs, const Context& ctx, const std::vector& ndinputs, - std::vector* p_ndoutputs) { + std::vector* p_ndoutputs, + int* dispatch_stype) { std::vector& ndoutputs = *p_ndoutputs; static auto& infershape = nnvm::Op::GetAttr("FInferShape"); static auto& infertype = nnvm::Op::GetAttr("FInferType"); + static auto& inferstorage = nnvm::Op::GetAttr("FInferStorageType"); MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get(); // infer shape std::vector& in_shapes = ret->arg_shapes; @@ -175,9 +178,34 @@ void SetShapeType(const nnvm::Op* op, CHECK(infertype[op](attrs, &in_types, &out_types)); CHECK_EQ(out_types.size(), ndoutputs.size()); + // infer storage type + auto& in_storage_types = ret->arg_storage_types; + auto& out_storage_types = ret->out_storage_types; + in_storage_types.clear(); + out_storage_types.clear(); + for (auto& i : ndinputs) { + in_storage_types.push_back(i.storage_type()); + } + for (auto& i : ndoutputs) { + out_storage_types.push_back(i.storage_type()); + } + if (inferstorage.count(op)) { + CHECK(inferstorage[op](attrs, ctx, &in_storage_types, &out_storage_types)); + CHECK_EQ(out_storage_types.size(), ndoutputs.size()); + } + + bool contains_non_default = common::ContainsNonDefaultStorage(in_storage_types); + contains_non_default |= common::ContainsNonDefaultStorage(out_storage_types); + int kNonDefaultStorage = -2; + *dispatch_stype = contains_non_default ? kNonDefaultStorage : kDefaultStorage; for (size_t i = 0; i < ndoutputs.size(); ++i) { if (ndoutputs[i].is_none()) { - ndoutputs[i] = NDArray(out_shapes[i], ctx, true, out_types[i]); + // if failed to infer the storage type, assume the output storage is dense + if (storage_type == kDefaultStorage || out_storage_types[i] == kUndefinedStorage) { + ndoutputs[i] = NDArray(out_shapes[i], ctx, true, out_types[i]); + } else { + ndoutputs[i] = NDArray(storage_type, out_shapes[i], ctx, true, out_types[i]); + } } else { CHECK_EQ(ndoutputs[i].shape(), out_shapes[i]) << i << "th output has invalid shape. " @@ -250,23 +278,60 @@ void PushFCompute(const FCompute& fn, const std::vector& requested, const std::vector& ndinputs, const std::vector& ndoutputs) { + using namespace common; bool is_train = AutogradRuntime::Get()->IsTraining(); Engine::Get()->PushAsync( [ctx, attrs, fn, ndinputs, ndoutputs, requested, is_train]( RunContext rctx, engine::CallbackOnComplete on_complete) { std::vector input_blobs, output_blobs; - for (auto& i : ndinputs) { - input_blobs.push_back(i.data()); - } - for (auto& i : ndoutputs) { - output_blobs.push_back(i.data()); - } + std::vector temp_in, temp_out; OpContext opctx{is_train, rctx, engine::CallbackOnComplete(), requested}; - std::vector req(output_blobs.size(), kWriteTo); - fn(attrs, opctx, input_blobs, req, output_blobs); + if (ctx.dev_mask() == gpu::kDevMask) { +#if MXNET_USE_CUDA + GetDefaultBlobs(ndinputs, &input_blobs, &temp_in, opctx); + GetDefaultBlobs(ndoutputs, &output_blobs, &temp_out, opctx); + std::vector req(output_blobs.size(), kWriteTo); + fn(attrs, opctx, input_blobs, req, output_blobs); + // cast to original storage type, if necessary + CastNonDefaultStorage(ndoutputs, temp_out, opctx); + rctx.get_stream()->Wait(); +#else + LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR; +#endif + } else { + GetDefaultBlobs(ndinputs, &input_blobs, &temp_in, opctx); + GetDefaultBlobs(ndoutputs, &output_blobs, &temp_out, opctx); + std::vector req(output_blobs.size(), kWriteTo); + fn(attrs, opctx, input_blobs, req, output_blobs); + CastNonDefaultStorage(ndoutputs, temp_out, opctx); + } + on_complete(); + }, ctx, read_vars, write_vars, FnProperty::kNormal, + 0, PROFILER_MESSAGE(op->name.c_str())); +} + +void PushFComputeEx(const FComputeEx& fn, + const nnvm::Op* op, + const nnvm::NodeAttrs& attrs, + const Context& ctx, + const std::vector& read_vars, + const std::vector& write_vars, + const std::vector& requested, + const std::vector& ndinputs, + const std::vector& ndoutputs) { + Engine::Get()->PushAsync( + [ctx, attrs, fn, ndinputs, ndoutputs, requested]( + RunContext rctx, + engine::CallbackOnComplete on_complete) { + std::vector input_blobs, output_blobs; + OpContext opctx{false, rctx, + engine::CallbackOnComplete(), + requested}; + std::vector req(ndoutputs.size(), kWriteTo); + fn(attrs, opctx, ndinputs, req, ndoutputs); if (ctx.dev_mask() == gpu::kDevMask) { rctx.get_stream()->Wait(); } @@ -301,10 +366,25 @@ void PushOperator(const OpStatePtr& state, engine::CallbackOnComplete on_complete) { OpContext opctx{is_train, rctx, on_complete, requested}; std::vector input_blobs, output_blobs; - for (const auto& i : ndinputs) input_blobs.push_back(i.data()); - for (const auto& i : ndoutputs) output_blobs.push_back(i.data()); - std::vector req(output_blobs.size(), kWriteTo); - fcompute(state, opctx, input_blobs, req, output_blobs); + std::vector temp_in, temp_out; + if (ctx.dev_mask() == gpu::kDevMask) { +#if MXNET_USE_CUDA + GetDefaultBlobs(ndinputs, &input_blobs, &temp_in, opctx); + GetDefaultBlobs(ndoutputs, &output_blobs, &temp_out, opctx); + std::vector req(output_blobs.size(), kWriteTo); + fcompute(state, opctx, input_blobs, req, output_blobs); + // cast to original storage type, if necessary + CastNonDefaultStorage(ndoutputs, temp_out, opctx); +#else + LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR; +#endif + } else { + GetDefaultBlobs(ndinputs, &input_blobs, &temp_in, opctx); + GetDefaultBlobs(ndoutputs, &output_blobs, &temp_out, opctx); + std::vector req(output_blobs.size(), kWriteTo); + fcompute(state, opctx, input_blobs, req, output_blobs); + CastNonDefaultStorage(ndoutputs, temp_out, opctx); + } if (exec_type == ExecType::kSync) { if (rctx.get_ctx().dev_mask() == gpu::kDevMask) { rctx.get_stream()->Wait(); @@ -443,6 +523,28 @@ int MXImperativeInvoke(AtomicSymbolCreator creator, API_END(); } +int MXImperativeInvokeEx(AtomicSymbolCreator creator, + int num_inputs, + NDArrayHandle *inputs, + int *num_outputs, + NDArrayHandle **outputs, + int num_params, + const char **param_keys, + const char **param_vals, + const int **out_stypes) { // outputs storage types + API_BEGIN(); + MXImperativeInvoke(creator, num_inputs, inputs, num_outputs, outputs, + num_params, param_keys, param_vals); + MXAPIThreadLocalEntry *ret = MXAPIThreadLocalStore::Get(); + NDArray** output_nds = reinterpret_cast(*outputs); + ret->out_types.resize(*num_outputs); + for (int i = 0; i < *num_outputs; ++i) { + ret->out_types[i] = output_nds[i]->storage_type(); + } + *out_stypes = dmlc::BeginPtr(ret->out_types); + API_END(); +} + int MXCreateCachedOp(SymbolHandle handle, CachedOpHandle *out) { nnvm::Symbol* sym = static_cast(handle); diff --git a/src/c_api/c_api_symbolic.cc b/src/c_api/c_api_symbolic.cc index d3603e94b2a1..9cc12225ca2a 100644 --- a/src/c_api/c_api_symbolic.cc +++ b/src/c_api/c_api_symbolic.cc @@ -11,6 +11,7 @@ #include #include "./c_api_common.h" #include "../operator/operator_common.h" +#include "../executor/exec_pass.h" namespace mxnet { namespace op { @@ -441,7 +442,7 @@ int MXSymbolInferShape(SymbolHandle sym, } try { - g = nnvm::pass::InferShape(std::move(g), arg_shapes, "__shape__"); + g = mxnet::exec::InferShape(std::move(g), arg_shapes, "__shape__"); } catch (const mxnet::op::InferShapeError &err) { throw dmlc::Error(err.msg); } @@ -526,7 +527,7 @@ int MXSymbolInferType(SymbolHandle sym, mxnet::MatchArguments(g.indexed_graph(), kwargs, &arg_types, "InferType"); } - g = nnvm::pass::InferType(std::move(g), arg_types, "__dtype__"); + g = mxnet::exec::InferType(std::move(g), arg_types, "__dtype__"); // copy back CopyAttr(g.indexed_graph(), g.GetAttr("dtype"), &(ret->arg_types), &(ret->out_types), &(ret->aux_types)); diff --git a/src/c_api/c_predict_api.cc b/src/c_api/c_predict_api.cc index 1dd784ba2249..0bee6cf9f838 100644 --- a/src/c_api/c_predict_api.cc +++ b/src/c_api/c_predict_api.cc @@ -14,6 +14,7 @@ #include #include "./c_api_common.h" #include "../operator/operator_common.h" +#include "../executor/exec_pass.h" using namespace mxnet; @@ -176,7 +177,7 @@ int MXPredCreatePartialOut(const char* symbol_json_str, } } nnvm::Graph g; g.outputs = sym.outputs; - g = nnvm::pass::InferShape(std::move(g), in_shapes, "__shape__"); + g = mxnet::exec::InferShape(std::move(g), in_shapes, "__shape__"); bool infer_complete = (g.GetAttr("shape_num_unknown_nodes") == 0); CHECK(infer_complete) << "The shape information of is not enough to get the shapes"; diff --git a/src/common/utils.cc b/src/common/utils.cc new file mode 100644 index 000000000000..4bcae02e990c --- /dev/null +++ b/src/common/utils.cc @@ -0,0 +1,23 @@ +/*! + * Copyright (c) 2017 by Contributors + * \file utils.cc + * \brief cpu implementation of util functions + */ + +#include "./utils.h" +#include "../operator/tensor/cast_storage-inl.h" + +namespace mxnet { +namespace common { + + +template<> +void CastStorageDispatch(mshadow::Stream* s, + const NDArray& input, + const NDArray& output) { + mxnet::op::CastStorageComputeImpl(s, input, output); +} + + +} // namespace common +} // namespace mxnet diff --git a/src/common/utils.cu b/src/common/utils.cu new file mode 100644 index 000000000000..7221a2b6ec6c --- /dev/null +++ b/src/common/utils.cu @@ -0,0 +1,21 @@ +/*! + * Copyright (c) 2017 by Contributors + * \file utils.cu + * \brief gpu implementation of util functions + */ + +#include "./utils.h" +#include "../operator/tensor/cast_storage-inl.h" + +namespace mxnet { +namespace common { + +template<> +void CastStorageDispatch(mshadow::Stream* s, + const NDArray& input, + const NDArray& output) { + mxnet::op::CastStorageComputeImpl(s, input, output); +} + +} // namespace common +} // namespace mxnet diff --git a/src/common/utils.h b/src/common/utils.h index 5f50aab4781f..95ddc240cbf6 100644 --- a/src/common/utils.h +++ b/src/common/utils.h @@ -6,7 +6,13 @@ #ifndef MXNET_COMMON_UTILS_H_ #define MXNET_COMMON_UTILS_H_ -#if DMLC_USE_CXX11 +#include +#include +#include +#include +#include +#include + #include #include #include @@ -15,15 +21,99 @@ #include #include #include -#endif // DMLC_USE_CXX11 - -#include -#include +#include namespace mxnet { + namespace common { -#if DMLC_USE_CXX11 +template +void CastStorageDispatch(mshadow::Stream* s, const NDArray& input, const NDArray& output); + +/* + * \brief Get the corresponding tensor blobs from default storage NDArrays. + * If any NDArray is of non-default storage, it is casted to default storage and + * the temporary NDArrays are stored in `temps`. When storage_fallback is false, + * and `MXNET_EXEC_STORAGE_FALLBACK` == 0, storage fallback is disallowed. + * \return true if any input is casted + */ +template +inline bool GetDefaultBlobs(const std::vector& nds, + std::vector *blobs, + std::vector *temps, + const OpContext& ctx, + bool storage_fallback = false) { + bool casted = false; + if (storage_fallback == false) { + storage_fallback = dmlc::GetEnv("MXNET_EXEC_STORAGE_FALLBACK", true); + } + for (auto& nd : nds) { + if (nd.storage_type() != kDefaultStorage) { + if (storage_fallback == false) { + LOG(FATAL) << "Storage type conversion detected during execution. " + << "You are probably executing an operator which " + << "doesn't support NDArray inputs with non-default storage."; + } + NDArray temp(nd.shape(), nd.ctx(), false); + CastStorageDispatch(ctx.get_stream(), nd, temp); + temps->push_back(temp); + blobs->push_back(temp.data()); + casted = true; + } else { + blobs->push_back(nd.data()); + } + } + return casted; +} + +/* + * \brief Cast the NDArrays in `src` according to the storage types of the NDArrays + * in `dst`. The ones with default storage in `dst` are ignored. + * When storage_fallback is false, and `MXNET_EXEC_STORAGE_FALLBACK` == 0, + * storage fallback is disallowed. + */ +template +inline void CastNonDefaultStorage(const std::vector& dst, + const std::vector& src, + const OpContext& ctx, + bool storage_fallback = false) { + CHECK_GE(dst.size(), src.size()); + if (src.size() == 0) return; + if (storage_fallback == false) { + storage_fallback = dmlc::GetEnv("MXNET_EXEC_STORAGE_FALLBACK", true); + } + size_t src_idx = 0; + for (size_t i = 0; i < dst.size(); i++) { + auto stype = dst[i].storage_type(); + if (stype != kDefaultStorage) { + if (storage_fallback == false) { + LOG(FATAL) << "Storage type conversion detected during execution. " + << "You are probably executing an operator which " + << "doesn't support NDArray inputs with non-default storage."; + } + CastStorageDispatch(ctx.get_stream(), src[src_idx++], dst[i]); + } + } + CHECK_EQ(src_idx, src.size()) << "Not all src NDArrays are casted"; +} + +// Check if any storage type is not default storage +inline bool ContainsNonDefaultStorage(const nnvm::StorageTypeVector& vstorage) { + for (auto& i : vstorage) { + if (i != kUndefinedStorage && i != kDefaultStorage) return true; + } + return false; +} + +inline bool ContainsDefaultStorage(const std::vector& ndarrays) { + for (auto &nd : ndarrays) { + if (nd.storage_type() == kDefaultStorage) { + return true; + } + } + return false; +} + // heuristic to dermine number of threads per GPU inline int GetNumThreadPerGPU() { // This is resource efficient option. @@ -38,6 +128,67 @@ inline int GetExecNumMatchColor() { return std::min(num_match_color, GetNumThreadPerGPU()); } +template +V ParallelAccumulate(const T* a, const int n, V start) { + V sum = start; +#pragma omp parallel for reduction(+:sum) + for (int i = 0; i < n; ++i) { + sum += a[i]; + } + return sum; +} + +/*! + * \brief + * Helper function for ParallelSort. + * DO NOT call this function directly. + * Use the interface ParallelSort instead. + * Ref: https://github.com/dmlc/difacto/blob/master/src/common/parallel_sort.h + */ +template +void ParallelSortHelper(RandomIt first, size_t len, + size_t grainsize, const Compare& comp) { + if (len < grainsize) { + std::sort(first, first+len, comp); + } else { + std::thread thr(ParallelSortHelper, first, len/2, grainsize, comp); + ParallelSortHelper(first+len/2, len - len/2, grainsize, comp); + thr.join(); + std::inplace_merge(first, first+len/2, first+len, comp); + } +} + +/*! + * \brief + * Sort the elements in the range [first, last) into the ascending order defined by + * the comparator comp. + * If the length of the range [first, last) is greater than a certain threshold, + * the range will be recursively divided into two and assign two threads + * to sort each half range. + * Ref: https://github.com/dmlc/difacto/blob/master/src/common/parallel_sort.h + */ +template +void ParallelSort(RandomIt first, RandomIt last, size_t num_threads, Compare comp) { + const auto num = std::distance(first, last); + size_t grainsize = std::max(num / num_threads + 5, static_cast(1024*16)); + ParallelSortHelper(first, num, grainsize, comp); +} + +/*! + * \brief + * Sort the elements in the range [first, last) into ascending order. + * The elements are compared using the default < operator. + * If the length of the range [first, last) is greater than a certain threshold, + * the range will be recursively divided into two and assign two threads + * to sort each half range. + * Ref: https://github.com/dmlc/difacto/blob/master/src/common/parallel_sort.h + */ +template +void ParallelSort(RandomIt first, RandomIt last, size_t num_threads) { + ParallelSort(first, last, num_threads, + std::less::value_type>()); +} + /*! * \brief Random Engine */ @@ -141,8 +292,6 @@ FCompType GetFCompute(const nnvm::Op* op, const std::string& name, } } -#endif // DMLC_USE_CXX11 - } // namespace common } // namespace mxnet #endif // MXNET_COMMON_UTILS_H_ diff --git a/src/executor/attach_op_execs_pass.cc b/src/executor/attach_op_execs_pass.cc index 6a0c489a1ec5..af9a6fcbe36e 100644 --- a/src/executor/attach_op_execs_pass.cc +++ b/src/executor/attach_op_execs_pass.cc @@ -22,28 +22,40 @@ const OperatorProperty* OpPropGetOpProperty(const NodeAttrs& attrs); namespace exec { -// forward executor +// stateful compute executor class StatefulComputeExecutor : public OpExecutor { public: - void Run(RunContext rctx) override { + void Run(RunContext rctx, bool is_gpu) override { + using namespace common; + // TODO(haibin) avoid repeating this if all inputs are already in default-storage op_ctx.run_ctx = rctx; - fcompute_(state_, op_ctx, in_data_, req, out_data_); + in_data_.clear(); + out_data_.clear(); + temp_in_.clear(); + temp_out_.clear(); + if (is_gpu) { +#if MXNET_USE_CUDA + GetDefaultBlobs(in_array, &in_data_, &temp_in_, op_ctx); + GetDefaultBlobs(out_array, &out_data_, &temp_out_, op_ctx); + fcompute_(state_, op_ctx, in_data_, req, out_data_); + CastNonDefaultStorage(out_array, temp_out_, op_ctx); +#else + LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR; +#endif + } else { + GetDefaultBlobs(in_array, &in_data_, &temp_in_, op_ctx); + GetDefaultBlobs(out_array, &out_data_, &temp_out_, op_ctx); + fcompute_(state_, op_ctx, in_data_, req, out_data_); + CastNonDefaultStorage(out_array, temp_out_, op_ctx); + } #if MKL_EXPERIMENTAL == 1 + //TODO(haibin) handle MKL mem with non-default NDArray mkl_tblobs_prv_to_cpu(in_data_); mkl_tblobs_prv_to_cpu(out_data_); #endif } - void Setup() override { - in_data_.clear(); - for (size_t i = 0; i < in_array.size(); ++i) { - in_data_.push_back(in_array[i].data()); - } - out_data_.clear(); - for (size_t i = 0; i < out_array.size(); ++i) { - out_data_.push_back(out_array[i].data()); - } - } + void Setup() override {} ExecType exec_type() const override { return exec_type_; @@ -64,10 +76,11 @@ class StatefulComputeExecutor : public OpExecutor { FStatefulCompute fcompute_; ExecType exec_type_; std::vector in_data_, out_data_; + std::vector temp_in_, temp_out_; }; -// forward executor +// stateful compute_ex executor class StatefulComputeExExecutor : public OpExecutor { public: void Run(RunContext rctx) override { @@ -98,27 +111,40 @@ class StatefulComputeExExecutor : public OpExecutor { }; -// fcompute executor executor +// fcompute executor class FComputeExecutor : public OpExecutor { public: void Run(RunContext rctx) override { + using namespace common; + // TODO(haibin) avoid repeating this if all inputs are already in default-storage op_ctx.run_ctx = rctx; - fcompute_(attrs_, op_ctx, in_data_, req, out_data_); + in_data_.clear(); + out_data_.clear(); + temp_in_.clear(); + temp_out_.clear(); + if (is_gpu) { +#if MXNET_USE_CUDA + GetDefaultBlobs(in_array, &in_data_, &temp_in_, op_ctx); + GetDefaultBlobs(out_array, &out_data_, &temp_out_, op_ctx); + fcompute_(attrs_, op_ctx, in_data_, req, out_data_); + CastNonDefaultStorage(out_array, temp_out_, op_ctx); +#else + LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR; +#endif + } else { + GetDefaultBlobs(in_array, &in_data_, &temp_in_, op_ctx); + GetDefaultBlobs(out_array, &out_data_, &temp_out_, op_ctx); + fcompute_(attrs_, op_ctx, in_data_, req, out_data_); + CastNonDefaultStorage(out_array, temp_out_, op_ctx); + } #if MKL_EXPERIMENTAL == 1 + //TODO(haibin) handle MKL mem with non-default NDArray mkl_tblobs_prv_to_cpu(in_data_); mkl_tblobs_prv_to_cpu(out_data_); #endif } - void Setup() override { - in_data_.resize(in_array.size()); - out_data_.resize(out_array.size()); - auto get_blob = [](const NDArray& nd) { - return nd.data(); - }; - std::transform(in_array.begin(), in_array.end(), in_data_.begin(), get_blob); - std::transform(out_array.begin(), out_array.end(), out_data_.begin(), get_blob); - } + void Setup() override {} ExecType exec_type() const override { return exec_type_; @@ -134,6 +160,32 @@ class FComputeExecutor : public OpExecutor { FCompute fcompute_; ExecType exec_type_; std::vector in_data_, out_data_; + std::vector temp_in_, temp_out_; +}; + +// fcompute_ex executor +class FComputeExExecutor : public OpExecutor { + public: + void Run(RunContext rctx) override { + op_ctx.run_ctx = rctx; + fcompute_(attrs_, op_ctx, in_array, req, out_array); + } + + void Setup() override {} + + ExecType exec_type() const override { + return exec_type_; + } + + explicit FComputeExExecutor(const NodeAttrs& attrs, FComputeEx fcompute, + ExecType exec_type) + : attrs_(attrs), fcompute_(fcompute), exec_type_(exec_type) { + } + + private: + NodeAttrs attrs_; + FComputeEx fcompute_; + ExecType exec_type_; }; // pass to attach operator executors @@ -152,6 +204,8 @@ Graph AttachOpExecs(Graph g) { const auto& vctx = g.GetAttr("context"); const auto& saved_states = g.GetAttr< std::unordered_map >("saved_states"); + const auto& dispatch_stypes = g.GetAttr("dispatch_stypes"); + // get the graph const auto& idx = g.indexed_graph(); @@ -221,11 +275,15 @@ Graph AttachOpExecs(Graph g) { } } else { FCompute fcompute = common::GetFCompute(op, "FCompute", vctx[i]); - if (fcompute != nullptr) { + FComputeEx fcomp_ex = common::GetFCompute(op, "FComputeEx", vctx[i]); + if (fcomp_ex != nullptr && dispatch_stypes[i] != kDefaultStorage) { + ret[i] = std::make_shared( + inode.source->attrs, fcomp_ex, exec_type); + } else if (fcompute != nullptr) { ret[i] = std::make_shared( inode.source->attrs, fcompute, exec_type); } else { - LOG(INFO) << "FCompute not registered " << op->name; + LOG(INFO) << "Neither FCompute nor FComputeEx registered " << op->name; } } } diff --git a/src/executor/exec_pass.h b/src/executor/exec_pass.h index 76b02de736e9..c51123214a98 100644 --- a/src/executor/exec_pass.h +++ b/src/executor/exec_pass.h @@ -10,8 +10,10 @@ #include #include #include +#include #include #include +#include namespace mxnet { namespace exec { @@ -19,6 +21,12 @@ namespace exec { /*! \brief reuse graph definition */ using nnvm::Graph; +const int kBadStorageID = -1; +const int kExternalStorageID = -2; +const int kDynamicStorageID = -3; + +const int kNonDefaultStorage = -2; + /*! * \brief executor to execute an operator * This is a graph executor dependent interface @@ -26,7 +34,7 @@ using nnvm::Graph; */ class OpExecutor { public: - /*! \brief input arrays */ + /*! \brief input data arrays, which may be either input or aux */ std::vector in_array; /*! \brief output data arrays */ std::vector out_array; @@ -47,7 +55,7 @@ class OpExecutor { * This function call do not synchronize the stream. * \param rctx The runtime context passed in by environment. */ - virtual void Run(RunContext rctx) = 0; + virtual void Run(RunContext rctx, bool is_gpu) = 0; /*! \return the execution type */ virtual ExecType exec_type() const = 0; /*! \return return engine variable for operator states */ @@ -105,6 +113,45 @@ Graph AttachOpResources(Graph g); */ Graph DetectInplaceAddTo(Graph g); +/*! + * \brief Infer shapes in the graph given the information. + * \param graph The input graph. + * \param shape_inputs The shapes of input symbols to the graph. + * \param shape_attr_key The key to the node attribute that can indicate shape. This is + * the place where manual hint for shapes could be injected. + * \return A graph with new attribute "shape" containing inferred shape of each NodeEntry. + * The index of ShapeVector is given by graph.indexed_graph().entry_id. + */ +Graph InferShape(Graph graph, + nnvm::ShapeVector shape_inputs, + const std::string& shape_attr_key = ""); + +/*! + * \brief Infer types in the graph given the information. + * \param graph The input graph. + * \param dtype_inputs The types of input symbols to the graph. + * \param dtype_attr_key The key to the node attribute that can indicate types. This is + * the place where manual hint for types could be injected. + * \return A graph with new attribute "dtype" containing inferred type of each NodeEntry. + * The index of ShapeVector is given by graph.indexed_graph().entry_id. + */ +Graph InferType(Graph graph, + nnvm::DTypeVector dtype_inputs, + const std::string& dtype_attr_key = ""); + +/*! + * \brief Infer storage types in the graph given the information. + * \param graph The input graph. + * \param storage_type_inputs The storage types of input symbols to the graph. + * \param storage_type_attr_key The key to the node attribute that can indicate storage types. + This is the place where manual hint for types could be injected. + * \return A graph with new attribute "storage_type" containing inferred type of each NodeEntry. + * The index of StorageTypeVector is given by graph.indexed_graph().entry_id. + */ +Graph InferStorageType(Graph graph, + nnvm::StorageTypeVector storage_type_inputs, + const std::string& storage_type_attr_key = ""); + } // namespace exec } // namespace mxnet diff --git a/src/executor/graph_executor.cc b/src/executor/graph_executor.cc index af5ec7f492dd..be1c0c5f2eb4 100644 --- a/src/executor/graph_executor.cc +++ b/src/executor/graph_executor.cc @@ -12,6 +12,7 @@ #include "./exec_pass.h" #include "./graph_executor.h" #include "../engine/profiler.h" +#include "../common/utils.h" namespace mxnet { namespace exec { @@ -29,6 +30,30 @@ GraphExecutor::~GraphExecutor() { } } +inline NDArray InitZeros(const NDArrayStorageType stype, const TShape &shape, + const Context &ctx, const int dtype) { + // NDArray with default storage + if (stype == kDefaultStorage) { + NDArray ret(shape, ctx, false, dtype); + ret = 0; + return ret; + } + // NDArray with non-default storage. Storage allocation is always delayed. + return NDArray(stype, shape, ctx, true, dtype); +} + +inline void EmplaceBackZeros(const NDArrayStorageType stype, const TShape &shape, + const Context &ctx, const int dtype, + std::vector *vec) { + // NDArray with default storage + if (stype == kDefaultStorage) { + vec->emplace_back(shape, ctx, false, dtype); + vec->back() = 0; + } else { + // NDArray with non-default storage. Storage allocation is always delayed. + vec->emplace_back(stype, shape, ctx, true, dtype); + } +} void GraphExecutor::Forward(bool is_train) { RunOps(is_train, 0, num_forward_nodes_); } @@ -420,6 +445,29 @@ void HandleInferTypeError(const size_t num_forward_inputs, << oss.str(); } +void HandleInferStorageTypeError(const size_t num_forward_inputs, + const nnvm::IndexedGraph& idx, + const nnvm::StorageTypeVector& inferred_stypes) { + int cnt = 10; + std::ostringstream oss; + for (size_t i = 0; i < num_forward_inputs; ++i) { + const uint32_t nid = idx.input_nodes().at(i); + const uint32_t eid = idx.entry_id(nid, 0); + const int inferred_stype = inferred_stypes[eid]; + if (inferred_stype == -1) { + const std::string& arg_name = idx[nid].source->attrs.name; + oss << arg_name << ": " << inferred_stype << ", "; + if (--cnt == 0) { + oss << "..."; + break; + } + } + } + LOG(FATAL) << "InferStoragetType pass cannot decide storage type for the following arguments " + "(-1 means unknown stype). Please consider providing them as inputs:\n" + << oss.str(); +} + /*! * \brief GraphExecutor initializer for regular bind flow in which * input arguments and gradients are provided by users. This initializer @@ -457,21 +505,25 @@ void GraphExecutor::Init(nnvm::Symbol symbol, data_entry_.resize(idx.num_node_entries()); nnvm::ShapeVector arg_shapes; nnvm::DTypeVector arg_dtypes; + nnvm::StorageTypeVector arg_stypes; for (size_t i = 0; i < num_forward_inputs_; ++i) { const uint32_t nid = idx.input_nodes().at(i); const std::string& arg_name = idx[nid].source->attrs.name; + size_t eid = idx.entry_id(nid, 0); if (mutable_nodes.count(nid)) { CHECK_LT(aux_top, aux_states.size()); - data_entry_[idx.entry_id(nid, 0)] = aux_states[aux_top]; + data_entry_[eid] = aux_states[aux_top]; arg_shapes.push_back(aux_states[aux_top].shape()); arg_dtypes.push_back(aux_states[aux_top].dtype()); + arg_stypes.push_back(aux_states[aux_top].storage_type()); aux_state_map_.emplace(arg_name, aux_states[aux_top]); ++aux_top; } else { CHECK_LT(arg_top, in_args.size()); - data_entry_[idx.entry_id(nid, 0)] = in_args[arg_top]; + data_entry_[eid] = in_args[arg_top]; arg_shapes.push_back(in_args[arg_top].shape()); arg_dtypes.push_back(in_args[arg_top].dtype()); + arg_stypes.push_back(in_args[arg_top].storage_type()); in_arg_map_.emplace(arg_name, in_args[arg_top]); if (kNullOp != grad_req_types[arg_top]) { grad_store_.emplace_back(grad_req_types[arg_top], arg_grad_store[arg_top]); @@ -479,23 +531,33 @@ void GraphExecutor::Init(nnvm::Symbol symbol, } ++arg_top; } +#if EXECUTOR_DEBUG + LOG(INFO) << "\tassign data entry\t" << eid << " as stype " + << data_entry_[eid].storage_type() << " (input)"; +#endif } // expand arg_shapes and arg_dtypes to contain backward inputs arg_shapes.resize(idx.input_nodes().size(), TShape()); - g = nnvm::pass::InferShape(g, arg_shapes, "__shape__"); + g = InferShape(std::move(g), arg_shapes, "__shape__"); if (g.GetAttr("shape_num_unknown_nodes") != 0U) { HandleInferShapeError(num_forward_inputs_, g.indexed_graph(), g.GetAttr("shape")); } arg_dtypes.resize(idx.input_nodes().size(), -1); - g = nnvm::pass::InferType(g, arg_dtypes, "__dtype__"); + g = InferType(std::move(g), arg_dtypes, "__dtype__"); if (g.GetAttr("dtype_num_unknown_nodes") != 0U) { HandleInferTypeError(num_forward_inputs_, g.indexed_graph(), g.GetAttr("dtype")); } + g = InferStorageType(std::move(g), arg_stypes, "__storage_type__"); + if (g.GetAttr("storage_type_num_unknown_nodes") != 0U) { + HandleInferStorageTypeError(num_forward_inputs_, g.indexed_graph(), + g.GetAttr("storage_type")); + } + // Initialize the rest attributes of the graph. // This function can be called by regular bind // operation flow as well. @@ -511,6 +573,7 @@ void GraphExecutor::Init(nnvm::Symbol symbol, void GraphExecutor::InitArguments(const nnvm::IndexedGraph& idx, const nnvm::ShapeVector& inferred_shapes, const nnvm::DTypeVector& inferred_dtypes, + const nnvm::StorageTypeVector& inferred_stypes, const std::vector& in_arg_ctxes, const std::vector& arg_grad_ctxes, const std::vector& aux_state_ctxes, @@ -528,22 +591,37 @@ void GraphExecutor::InitArguments(const nnvm::IndexedGraph& idx, const uint32_t eid = idx.entry_id(nid, 0); const TShape& inferred_shape = inferred_shapes[eid]; const int inferred_dtype = inferred_dtypes[eid]; + const NDArrayStorageType inferred_stype = (NDArrayStorageType) inferred_stypes[eid]; const std::string& arg_name = idx[nid].source->attrs.name; if (mutable_nodes.count(nid)) { // aux_states - aux_state_vec->emplace_back(inferred_shape, aux_state_ctxes[aux_top], false, inferred_dtype); - aux_state_vec->back() = 0; + EmplaceBackZeros(inferred_stype, inferred_shape, aux_state_ctxes[aux_top], + inferred_dtype, aux_state_vec); data_entry_[eid] = aux_state_vec->back(); aux_state_map_.emplace(arg_name, aux_state_vec->back()); ++aux_top; +#if EXECUTOR_DEBUG + LOG(INFO) << "\tassign aux entry\t" << eid << "\t as stype " << inferred_stype; +#endif } else { // in_args - in_arg_vec->emplace_back(inferred_shape, in_arg_ctxes[arg_top], false, inferred_dtype); - in_arg_vec->back() = 0; + EmplaceBackZeros(inferred_stype, inferred_shape, in_arg_ctxes[arg_top], + inferred_dtype, in_arg_vec); data_entry_[eid] = in_arg_vec->back(); +#if EXECUTOR_DEBUG + LOG(INFO) << "\tassign data entry\t" << eid << "\tas stype " << inferred_stype; +#endif + // Get the storage type for grad if (kNullOp == grad_req_types[arg_top]) { arg_grad_vec->emplace_back(); } else { - arg_grad_vec->emplace_back(inferred_shape, arg_grad_ctxes[arg_top], false, inferred_dtype); - arg_grad_vec->back() = 0; + // Init based on storage type + auto grad_oid = grad_store_.size() + num_forward_outputs_; + auto grad_eid = idx.entry_id(idx.outputs()[grad_oid]); + auto grad_stype = (NDArrayStorageType) inferred_stypes[grad_eid]; + EmplaceBackZeros(grad_stype, inferred_shape, arg_grad_ctxes[arg_top], + inferred_dtype, arg_grad_vec); +#if EXECUTOR_DEBUG + LOG(INFO) << "\tassign grad entry\t" << grad_eid << "\tas stype " << grad_stype; +#endif grad_store_.emplace_back(grad_req_types[arg_top], arg_grad_vec->back()); arg_grad_map_.emplace(arg_name, arg_grad_vec->back()); } @@ -555,33 +633,40 @@ void GraphExecutor::InitArguments(const nnvm::IndexedGraph& idx, /*! * \brief If the requested ndarray's shape size is less than - * the corresponding shared_data_array's shape size, reuse - * the memory allocation; otherwise, create a zero ndarray. + * the corresponding shared_data_array's shape size and the + * storage type is default storage, reuse the memory allocation + * in shared_buffer; otherwise, create a zero ndarray. */ NDArray ReshapeOrCreate(const std::string& name, const TShape& dest_arg_shape, const int dest_arg_dtype, + const NDArrayStorageType dest_arg_stype, const Context& ctx, std::unordered_map* shared_buffer) { + if (dest_arg_dtype != kDefaultStorage) { + return InitZeros(dest_arg_stype, dest_arg_shape, ctx, dest_arg_dtype); + } auto it = shared_buffer->find(name); if (it != shared_buffer->end()) { if (it->second.shape().Size() >= dest_arg_shape.Size()) { // memory can be reused CHECK_EQ(it->second.dtype(), dest_arg_dtype) << "Requested arg array's dtype does not match the reusable ndarray"; + CHECK_EQ(it->second.storage_type(), kDefaultStorage) + << "shared_buffer should only contain NDArrays with default storage type."; return it->second.Reshape(dest_arg_shape); } else { LOG(WARNING) << "Bucketing: data " << name << " has a shape " << dest_arg_shape << ", which is larger than already allocated shape " << it->second.shape() << ". Need to re-allocate. Consider putting default bucket key to be " << "the bucket taking the largest input for better memory sharing."; - it->second = NDArray(dest_arg_shape, ctx, false, dest_arg_dtype); - it->second = 0; + // the NDArrays in shared_buffer are guaranteed to be of default storage + it->second = InitZeros(dest_arg_stype, dest_arg_shape, ctx, dest_arg_dtype); return it->second; } // arg_array.shape().Size() >= arg_shape.Size() } else { - auto p = shared_buffer->emplace(name, NDArray(dest_arg_shape, ctx, false, dest_arg_dtype)); - p.first->second = 0; - return p.first->second; + auto ret = InitZeros(dest_arg_stype, dest_arg_shape, ctx, dest_arg_dtype); + shared_buffer->emplace(name, ret); + return ret; } // if (it != shared_buffer->end()) } @@ -594,6 +679,7 @@ NDArray ReshapeOrCreate(const std::string& name, void GraphExecutor::InitArguments(const nnvm::IndexedGraph& idx, const nnvm::ShapeVector& inferred_shapes, const nnvm::DTypeVector& inferred_dtypes, + const nnvm::StorageTypeVector& inferred_stypes, const std::vector& in_arg_ctxes, const std::vector& arg_grad_ctxes, const std::vector& aux_state_ctxes, @@ -613,9 +699,12 @@ void GraphExecutor::InitArguments(const nnvm::IndexedGraph& idx, const uint32_t eid = idx.entry_id(nid, 0); const TShape& inferred_shape = inferred_shapes[eid]; const int inferred_dtype = inferred_dtypes[eid]; + const NDArrayStorageType inferred_stype = (NDArrayStorageType) inferred_stypes[eid]; const std::string& arg_name = idx[nid].source->attrs.name; - if (mutable_nodes.count(nid)) { // aux_states - if (nullptr != shared_exec) { + // aux_states + if (mutable_nodes.count(nid)) { + if (nullptr != shared_exec && inferred_stype == kDefaultStorage && + shared_exec->aux_state_map().at(arg_name).storage_type() == kDefaultStorage) { const NDArray& aux_nd = shared_exec->aux_state_map().at(arg_name); CHECK_EQ(inferred_shape, aux_nd.shape()) << "Inferred shape does not match shared_exec.aux_array's shape." @@ -629,16 +718,18 @@ void GraphExecutor::InitArguments(const nnvm::IndexedGraph& idx, << arg_name << " for the current executor"; aux_state_vec->emplace_back(aux_nd); } else { - aux_state_vec->emplace_back(inferred_shape, aux_state_ctxes[aux_top], - false, inferred_dtype); - aux_state_vec->back() = 0; + EmplaceBackZeros(inferred_stype, inferred_shape, aux_state_ctxes[aux_top], + inferred_dtype, aux_state_vec); } // if (has_shared_exec) data_entry_[eid] = aux_state_vec->back(); aux_state_map_.emplace(arg_name, aux_state_vec->back()); ++aux_top; - } else { // in_args + } else { // in_args and grad for in_args if (shared_arg_names.count(arg_name)) { // model parameter - if (nullptr != shared_exec) { + // model parameter + if (nullptr != shared_exec && inferred_stype == kDefaultStorage && + shared_exec->in_arg_map().at(arg_name).storage_type() == kDefaultStorage) { + // try to reuse memory from shared_exec const NDArray& in_arg_nd = shared_exec->in_arg_map().at(arg_name); CHECK_EQ(inferred_shape, in_arg_nd.shape()) << "Inferred shape does not match shared_exec.arg_array's shape" @@ -651,33 +742,43 @@ void GraphExecutor::InitArguments(const nnvm::IndexedGraph& idx, " be resued for creating NDArray of the argument" << arg_name << " for the current executor"; in_arg_vec->emplace_back(in_arg_nd); - if (kNullOp == grad_req_types[arg_top]) { - arg_grad_vec->emplace_back(); - } else { + } else { + // doesn't have shared_exec, or non-default storage + EmplaceBackZeros(inferred_stype, inferred_shape, in_arg_ctxes[arg_top], + inferred_dtype, in_arg_vec); + } + // gradient for model parameter + if (kNullOp == grad_req_types[arg_top]) { + arg_grad_vec->emplace_back(); + } else { + auto grad_oid = grad_store_.size() + num_forward_outputs_; + auto grad_eid = idx.entry_id(idx.outputs()[grad_oid]); + auto grad_stype = (NDArrayStorageType) inferred_stypes[grad_eid]; + if (nullptr != shared_exec && grad_stype == kDefaultStorage && + shared_exec->arg_grad_map().at(arg_name).storage_type() == kDefaultStorage) { + // try to reuse memory from shared_exec arg_grad_vec->emplace_back(shared_exec->arg_grad_map().at(arg_name)); - grad_store_.emplace_back(grad_req_types[arg_top], arg_grad_vec->back()); - } // if (kNullOp == grad_req_types[arg_top]) - } else { // !has shared_exec - in_arg_vec->emplace_back(inferred_shape, in_arg_ctxes[arg_top], false, inferred_dtype); - in_arg_vec->back() = 0; - if (kNullOp == grad_req_types[arg_top]) { - arg_grad_vec->emplace_back(); } else { - arg_grad_vec->emplace_back(inferred_shape, arg_grad_ctxes[arg_top], - false, inferred_dtype); - arg_grad_vec->back() = 0; - grad_store_.emplace_back(grad_req_types[arg_top], arg_grad_vec->back()); - } // if (kNullOp == grad_req_types[arg_top]) - } // if (has_shared_exec) + EmplaceBackZeros(grad_stype, inferred_shape, arg_grad_ctxes[arg_top], + inferred_dtype, arg_grad_vec); + } + grad_store_.emplace_back(grad_req_types[arg_top], arg_grad_vec->back()); + } } else { // !shared_arg_names.count(arg_name) + // model parameter in_arg_vec->emplace_back(ReshapeOrCreate(arg_name, inferred_shape, inferred_dtype, - in_arg_ctxes[arg_top], shared_buffer)); + inferred_stype, in_arg_ctxes[arg_top], + shared_buffer)); + // gradient for model parameter if (kNullOp == grad_req_types[arg_top]) { arg_grad_vec->emplace_back(); } else { + auto grad_oid = grad_store_.size() + num_forward_outputs_; + auto grad_eid = idx.entry_id(idx.outputs()[grad_oid]); + auto grad_stype = (NDArrayStorageType) inferred_stypes[grad_eid]; arg_grad_vec->emplace_back(ReshapeOrCreate("grad of " + arg_name, inferred_shape, - inferred_dtype, arg_grad_ctxes[arg_top], - shared_buffer)); + inferred_dtype, grad_stype, + arg_grad_ctxes[arg_top], shared_buffer)); grad_store_.emplace_back(grad_req_types[arg_top], arg_grad_vec->back()); } // if (kNullOp == grad_req_types[arg_top]) } // if (shared_arg_names.count(arg_name)) @@ -700,14 +801,35 @@ void GraphExecutor::FinishInitGraph(nnvm::Symbol symbol, Executor* shared_exec, const nnvm::NodeEntryMap& feed_dict) { const auto& idx = g.indexed_graph(); + // dispatch based on stype per operator + const auto& vstorage_type = g.GetAttr("storage_type"); + nnvm::StorageTypeVector dispatch_stypes(idx.num_nodes(), kUndefinedStorage); + for (size_t nid = 0; nid < idx.num_nodes(); nid++) { + const auto& inode = idx[nid]; + auto num_outputs = inode.source->num_outputs(); + auto num_inputs = inode.inputs.size(); + nnvm::StorageTypeVector vs(num_inputs + num_outputs, kUndefinedStorage); + for (size_t i = 0; i < num_inputs; i++) { + auto e = inode.inputs[i]; + vs[i] = vstorage_type[idx.entry_id(e)]; + CHECK_NE(vs[i], kUndefinedStorage); + } + for (uint32_t i = 0; i < num_outputs; ++i) { + uint32_t eid = idx.entry_id(nid, i); + vs[i + num_inputs] = vstorage_type[eid]; + } + bool contains_non_default = common::ContainsNonDefaultStorage(vs); + dispatch_stypes[nid] = contains_non_default ? kNonDefaultStorage : kDefaultStorage; + } + g.attrs["dispatch_stypes"] = std::make_shared(std::move(dispatch_stypes)); + + // data entries for output gradients for (size_t j = num_forward_outputs_; j < idx.outputs().size(); ++j) { data_entry_[idx.entry_id(idx.outputs()[j])] = grad_store_[j - num_forward_outputs_].second; } { // memory allocator - const int kBadStorageID = -1; - const int kExternalStorageID = -2; nnvm::StorageVector arg_storage_id(idx.num_node_entries(), kBadStorageID); for (size_t j = num_forward_outputs_; j < idx.outputs().size(); ++j) { arg_storage_id[idx.entry_id(idx.outputs()[j])] = kExternalStorageID; @@ -717,6 +839,9 @@ void GraphExecutor::FinishInitGraph(nnvm::Symbol symbol, data_entry_[eid] = kv.second; arg_storage_id[eid] = kExternalStorageID; } + for (size_t i = 0; i < idx.num_node_entries(); i++) { + if (vstorage_type[i] != kDefaultStorage) arg_storage_id[i] = kDynamicStorageID; + } g.attrs["storage"] = std::make_shared(std::move(arg_storage_id)); g = nnvm::ApplyPass(g, "PlanMemory"); } @@ -774,6 +899,7 @@ void GraphExecutor::Init(nnvm::Symbol symbol, const std::vector& aux_state_ctxes, const std::unordered_map& arg_shape_map, const std::unordered_map& arg_dtype_map, + const std::unordered_map& arg_stype_map, const std::vector& grad_req_types, const std::unordered_set& shared_arg_names, std::vector* in_arg_vec, @@ -793,6 +919,7 @@ void GraphExecutor::Init(nnvm::Symbol symbol, const nnvm::IndexedGraph& idx = g.indexed_graph(); nnvm::ShapeVector arg_shapes(idx.input_nodes().size(), TShape()); nnvm::DTypeVector arg_dtypes(idx.input_nodes().size(), -1); + nnvm::StorageTypeVector arg_stypes(idx.input_nodes().size(), kUndefinedStorage); for (size_t i = 0; i < num_forward_inputs_; ++i) { const uint32_t nid = idx.input_nodes().at(i); const std::string& name = idx[nid].source->attrs.name; @@ -804,29 +931,41 @@ void GraphExecutor::Init(nnvm::Symbol symbol, if (arg_dtype_map.end() != it2) { arg_dtypes[i] = it2->second; } + auto it3 = arg_stype_map.find(name); + if (arg_stype_map.end() != it3) { + arg_stypes[i] = it3->second; + } } - g = nnvm::pass::InferShape(g, arg_shapes, "__shape__"); + g = InferShape(std::move(g), arg_shapes, "__shape__"); if (g.GetAttr("shape_num_unknown_nodes") != 0U) { HandleInferShapeError(num_forward_inputs_, g.indexed_graph(), g.GetAttr("shape")); } - g = nnvm::pass::InferType(g, arg_dtypes, "__dtype__"); + g = InferType(std::move(g), arg_dtypes, "__dtype__"); if (g.GetAttr("dtype_num_unknown_nodes") != 0U) { HandleInferTypeError(num_forward_inputs_, g.indexed_graph(), g.GetAttr("dtype")); } + g = InferStorageType(std::move(g), arg_stypes, "__storage_type__"); + if (g.GetAttr("storage_type_num_unknown_nodes") != 0U) { + HandleInferStorageTypeError(num_forward_inputs_, g.indexed_graph(), + g.GetAttr("storage_type")); + } + // Create in_args, arg_grads, and aux_states using // the inferred shapes and dtypes. if (nullptr == shared_buffer) { // regular simple bind InitArguments(idx, g.GetAttr("shape"), g.GetAttr("dtype"), + g.GetAttr("storage_type"), in_arg_ctxes, arg_grad_ctxes, aux_state_ctxes, grad_req_types, in_arg_vec, arg_grad_vec, aux_state_vec); } else { // simple bind using shared data arrays and shared_exec InitArguments(idx, g.GetAttr("shape"), g.GetAttr("dtype"), + g.GetAttr("storage_type"), in_arg_ctxes, arg_grad_ctxes, aux_state_ctxes, grad_req_types, shared_arg_names, shared_exec, shared_buffer, in_arg_vec, arg_grad_vec, aux_state_vec); @@ -879,6 +1018,7 @@ Graph GraphExecutor::InitGraph(nnvm::Symbol symbol, // initialize the memory of each entries void GraphExecutor::InitDataEntryMemory(std::vector* shared_pool) { using nnvm::DTypeVector; + using nnvm::StorageTypeVector; using nnvm::ShapeVector; using nnvm::StorageVector; // get the graph @@ -887,20 +1027,29 @@ void GraphExecutor::InitDataEntryMemory(std::vector* shared_pool) { const auto& vdtype = graph_.GetAttr("dtype"); const auto& vshape = graph_.GetAttr("shape"); const auto& vstorage = graph_.GetAttr("storage_id"); + const auto& vstorage_type = graph_.GetAttr("storage_type"); const auto& vctx = graph_.GetAttr("context"); CHECK_EQ(idx.num_node_entries(), vshape.size()); CHECK_EQ(idx.num_node_entries(), vdtype.size()); CHECK_EQ(idx.num_node_entries(), vstorage.size()); CHECK_EQ(data_entry_.size(), vshape.size()); std::vector data_context(idx.num_node_entries()); + std::vector data_storage_type(idx.num_node_entries(), kUndefinedStorage); for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) { for (uint32_t i = 0; i < idx[nid].source->num_outputs(); ++i) { - data_context[idx.entry_id(nid, i)] = vctx[nid]; + auto eid = idx.entry_id(nid, i); + data_context[eid] = vctx[nid]; + CHECK_NE(vstorage_type[nid], kUndefinedStorage); + data_storage_type[eid] = (NDArrayStorageType) vstorage_type[nid]; } } // information about the pool - using PoolEntry = std::pair; + struct PoolEntry { + Context ctx; + size_t bytes; + NDArrayStorageType stype; + }; std::vector pool_info; // assign array to head gradient @@ -908,26 +1057,36 @@ void GraphExecutor::InitDataEntryMemory(std::vector* shared_pool) { uint32_t nid = idx.input_nodes().at(i); uint32_t oid = head_grad_map_.at(idx[nid].source); uint32_t eid = idx.entry_id(idx.outputs()[oid]); + NDArrayStorageType stype = (NDArrayStorageType) vstorage_type[eid]; CHECK_NE(vshape[eid].ndim(), 0U); CHECK_NE(vdtype[eid], -1); - data_entry_[idx.entry_id(nid, 0)] = - NDArray(vshape[eid], data_context[eid], false, vdtype[eid]); + auto data_eid = idx.entry_id(nid, 0); + // initialize based on storage_type + if (stype != kDefaultStorage) { + data_entry_[data_eid] = NDArray(stype, vshape[eid], data_context[eid], true, vdtype[eid]); + } else { + data_entry_[data_eid] = NDArray(vshape[eid], data_context[eid], false, vdtype[eid]); + } +#if EXECUTOR_DEBUG + LOG(INFO) << "\tinit head_g entry\t" << data_eid << "\tas stype " << stype; +#endif } // get maximum bytes in each pool for (size_t i = 0; i < vshape.size(); ++i) { if (!data_entry_[i].is_none()) continue; size_t bytes = vshape[i].Size() * mshadow::mshadow_sizeof(vdtype[i]); int storage_id = vstorage[i]; + // skip pool allocation for kBadStorageID, kExternalStorageID and kDynamicStorageID if (storage_id < 0) continue; size_t sid = static_cast(storage_id); if (sid >= pool_info.size()) { - pool_info.resize(sid + 1, PoolEntry{Context::CPU(), size_t(0)}); + pool_info.resize(sid + 1, PoolEntry{Context::CPU(), size_t(0), kUndefinedStorage}); } PoolEntry& info = pool_info[sid]; - if (info.second == 0) { - info = PoolEntry{data_context[i], bytes}; + if (info.bytes == 0) { + info = PoolEntry{data_context[i], bytes, data_storage_type[i]}; } else { - info.second = std::max(info.second, bytes); + info.bytes = std::max(info.bytes, bytes); } } // construct the re-use pool, if needed @@ -948,13 +1107,14 @@ void GraphExecutor::InitDataEntryMemory(std::vector* shared_pool) { sorted_pool_index.push_back(i); } auto pool_comparator = [&pool_info](int lhs, int rhs){ - return pool_info[lhs].second > pool_info[rhs].second; + return pool_info[lhs].bytes > pool_info[rhs].bytes; }; std::sort(sorted_pool_index.begin(), sorted_pool_index.end(), pool_comparator); for (size_t i : sorted_pool_index) { - const Context& ctx = pool_info[i].first; - size_t bytes = pool_info[i].second; + const Context& ctx = pool_info[i].ctx; + size_t bytes = pool_info[i].bytes; + NDArrayStorageType storage_type = pool_info[i].stype; bool allocated = false; for (auto it = free_pool.lower_bound(bytes); it != free_pool.end(); ++it) { if (it->second.ctx() == ctx && it->first >= bytes) { @@ -979,15 +1139,22 @@ void GraphExecutor::InitDataEntryMemory(std::vector* shared_pool) { } CHECK_EQ(data_pool_.size(), pool_info.size()); // assign the data entries - for (size_t i = 0; i < data_entry_.size(); ++i) { // avoid pre-allocated arrays if (!data_entry_[i].is_none()) continue; // assign allocated array by storage id int storage_id = vstorage[i]; - CHECK_GE(storage_id, 0) << "Do not support runtime shape op yet"; - const NDArray& src = data_pool_.at(storage_id); - data_entry_[i] = src.AsArray(vshape[i], vdtype[i]); + auto storage_type = (NDArrayStorageType) vstorage_type[i]; + if (storage_type == kDefaultStorage) { + CHECK_GE(storage_id, 0) << "Do not support runtime shape op yet"; + const NDArray& src = data_pool_.at(storage_id); + data_entry_[i] = src.AsArray(vshape[i], vdtype[i]); + } else { + data_entry_[i] = NDArray(storage_type, vshape[i], data_context[i]); + } +#if EXECUTOR_DEBUG + LOG(INFO) << "\tinit data entry\t" << i << "\tas stype " << storage_type; +#endif } } @@ -1002,11 +1169,28 @@ void GraphExecutor::InitCachedOps() { const auto& vctx = graph_.GetAttr("context"); const auto& addto_entry = graph_.GetAttr >("addto_entry"); const auto& skip_plus_node = graph_.GetAttr >("skip_plus_node"); + const auto& vstorage_type = graph_.GetAttr("storage_type"); op_nodes_.resize(idx.num_nodes()); // setup the array and requirements. for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) { const auto& inode = idx[nid]; +#if EXECUTOR_DEBUG + if (inode.source->is_variable()) { + LOG(INFO) << "node " << nid << " var"; + } else { + LOG(INFO) << "node " << nid << " " << inode.source->attrs.op->name; + auto exec = op_execs[nid]; + for (const auto& e : inode.inputs) { + auto eid = idx.entry_id(e); + LOG(INFO) << "\t\tinput " << eid << " stype: " << vstorage_type[eid]; + } + for (uint32_t index = 0; index < inode.source->num_outputs(); ++index) { + uint32_t eid = idx.entry_id(nid, index); + LOG(INFO) << "\t\toutput " << eid << " stype: " << vstorage_type[eid]; + } + } +#endif if (inode.source->is_variable()) continue; #if MXNET_USE_PROFILER op_nodes_[nid].opr_name = inode.source->op()->name.c_str(); @@ -1086,7 +1270,7 @@ void GraphExecutor::InitCachedOps() { if (is_async) { exec->op_ctx.async_on_complete = on_complete; } - exec->Run(ctx); + exec->Run(ctx, is_gpu); // call on complete only if it is async op if (!is_async) { if (is_gpu) { @@ -1230,6 +1414,9 @@ void GraphExecutor::RunOps(bool is_train, size_t topo_start, size_t topo_end) { bool profiling = engine::Profiler::Get()->GetState() == engine::Profiler::kRunning; #else bool profiling = false; +#endif +#if EXECUTOR_DEBUG + LOG(INFO) << "Run node " << nid << " - " << seg_op.topo_end - 1; #endif Engine::Get()->Push(seg_op.opr, seg_op.ctx, 0, profiling); nid = seg_op.topo_end - 1; @@ -1253,6 +1440,9 @@ void GraphExecutor::RunOps(bool is_train, size_t topo_start, size_t topo_end) { bool profiling = engine::Profiler::Get()->GetState() == engine::Profiler::kRunning; #else bool profiling = false; +#endif +#if EXECUTOR_DEBUG + LOG(INFO) << "Run node " << nid; #endif Engine::Get()->Push(opnode.cached_opr, opnode.ctx, 0, profiling); } else { @@ -1317,7 +1507,7 @@ GraphExecutor::CachedSegOpr GraphExecutor::CreateCachedSegOpr(size_t topo_start, RunContext ctx, Engine::CallbackOnComplete on_complete) { // Run all opr in the sub-graph for (auto &exec : exec_list) { - exec->Run(ctx); + exec->Run(ctx, is_gpu); } if (is_gpu) { #if MXNET_USE_CUDA @@ -1352,6 +1542,7 @@ Executor *Executor::SimpleBind(nnvm::Symbol symbol, const std::vector& aux_state_ctxes, const std::unordered_map& arg_shape_map, const std::unordered_map& arg_dtype_map, + const std::unordered_map& arg_stype_map, const std::vector& grad_req_types, const std::unordered_set& shared_arg_names, std::vector* in_args, @@ -1362,7 +1553,7 @@ Executor *Executor::SimpleBind(nnvm::Symbol symbol, auto exec = new exec::GraphExecutor(); exec->Init(symbol, default_ctx, group2ctx, in_arg_ctxes, arg_grad_ctxes, aux_state_ctxes, - arg_shape_map, arg_dtype_map, + arg_shape_map, arg_dtype_map, arg_stype_map, grad_req_types, shared_arg_names, in_args, arg_grads, aux_states, shared_buffer, shared_exec); diff --git a/src/executor/graph_executor.h b/src/executor/graph_executor.h index 5b6fa395b242..6c9d8350774b 100644 --- a/src/executor/graph_executor.h +++ b/src/executor/graph_executor.h @@ -19,6 +19,8 @@ #include #include "./exec_pass.h" +#define EXECUTOR_DEBUG 0 + namespace mxnet { // forward declaration @@ -78,6 +80,7 @@ class GraphExecutor : public Executor { const std::vector& aux_state_ctxes, const std::unordered_map& arg_shape_map, const std::unordered_map& arg_dtype_map, + const std::unordered_map& arg_stype_map, const std::vector& grad_req_types, const std::unordered_set& shared_arg_names, std::vector* in_arg_vec, @@ -123,6 +126,7 @@ class GraphExecutor : public Executor { void InitArguments(const nnvm::IndexedGraph& idx, const nnvm::ShapeVector& inferred_shapes, const nnvm::DTypeVector& inferred_dtypes, + const nnvm::StorageTypeVector& inferred_stypes, const std::vector& in_arg_ctxes, const std::vector& arg_grad_ctxes, const std::vector& aux_state_ctxes, @@ -135,6 +139,7 @@ class GraphExecutor : public Executor { void InitArguments(const nnvm::IndexedGraph& idx, const nnvm::ShapeVector& inferred_shapes, const nnvm::DTypeVector& inferred_dtypes, + const nnvm::StorageTypeVector& inferred_stypes, const std::vector& in_arg_ctxes, const std::vector& arg_grad_ctxes, const std::vector& aux_state_ctxes, @@ -183,7 +188,8 @@ class GraphExecutor : public Executor { std::vector op_nodes_; // internal data entry of each node std::vector data_entry_; - // internal data pool of allocated entries + // internal data pool of allocated entries. + // these allocated entries can be used for static memory sharing between executors. std::vector data_pool_; // output arrays std::vector output_arrays_; diff --git a/src/executor/infer_graph_attr_pass.cc b/src/executor/infer_graph_attr_pass.cc new file mode 100644 index 000000000000..3789c313bf18 --- /dev/null +++ b/src/executor/infer_graph_attr_pass.cc @@ -0,0 +1,337 @@ +/*! + * Copyright (c) 2017 by Contributors + * \file infer_graph_attr_pass.cc + * \brief infer graph shape, dtype, and storage type + */ + +#include +#include "./exec_pass.h" + +namespace mxnet { +namespace exec { + +template +bool ApplyOpInferAttr(const nnvm::Graph& g, + const FInfer& finfer, + const NodeAttrs& attrs, + const uint32_t nid, + std::vector* in_attrs, + std::vector* out_attrs) { + return finfer(attrs, in_attrs, out_attrs); +} + +template<> +bool ApplyOpInferAttr(const nnvm::Graph& g, + const FInferStorageType& finfer, + const NodeAttrs& attrs, + const uint32_t nid, + std::vector* in_attrs, + std::vector* out_attrs) { + const ContextVector& ctxes = g.GetAttr("context"); + return finfer(attrs, ctxes[nid], in_attrs, out_attrs); +} + +/*!\brief + * This is a duplicate of the InferAttr function in nnvm with minor modification + * to support inferring storage type whose function signature is different from + * shape/type inference functions'. The nnvm InferAttr will be deprecated + * in the future. Please use interfaces InferShape, InferType, and InferStorageType + * to call this function. + */ +template +nnvm::Graph InferAttr(nnvm::Graph &&ret, + const AttrType empty_val, + const char* infer_name, + const char* input_name, + const char* attr_key_name, + const char* attr_name, + const char* unknown_name, + IsNone fis_none, + FDefault fdefault, + bool backward_identity_assign) { + using nnvm::IndexedGraph; + using nnvm::Op; + using AttrVector = std::vector; + using dmlc::any; + + const IndexedGraph& idx = ret.indexed_graph(); + static auto& finfer_shape = + Op::GetAttr(infer_name); + static auto& is_backward = + Op::GetAttr("TIsBackward"); + // gradient function, used to get node correspondence. + static auto& fgrad = + Op::GetAttr("FGradient"); + // reshape shape vector + AttrVector rshape; + if (ret.attrs.count(attr_name) != 0) { + rshape = ret.MoveCopyAttr(attr_name); + } else { + rshape.resize(idx.num_node_entries(), empty_val); + } + + if (ret.attrs.count(input_name) != 0) { + const AttrVector& shape_args = ret.GetAttr(input_name); + CHECK_LE(shape_args.size(), idx.input_nodes().size()) + << "More provided " << attr_name << "s than number of arguments."; + for (size_t i = 0; i < shape_args.size(); ++i) { + rshape[idx.entry_id(idx.input_nodes()[i], 0)] = shape_args[i]; + } + // erase the provided arguments + ret.attrs.erase(input_name); + } + + // get the shape hints + std::string shape_hints_key = std::string(attr_name) + "_hints"; + if (ret.attrs.count(shape_hints_key)) { + nnvm::NodeEntryMap shape_hints = + ret.GetAttr>(shape_hints_key); + for (const auto& kv : shape_hints) { + nnvm::NodeEntry e = kv.first; + if (idx.exist(e.node.get())) { + rshape[idx.entry_id(kv.first)] = kv.second; + } + } + } + + std::string shape_attr_key; + if (ret.attrs.count(attr_key_name) != 0) { + shape_attr_key = ret.GetAttr(attr_key_name); + // erase the provided arguments + ret.attrs.erase(attr_key_name); + } + // Temp space for shape inference. + std::vector ishape, oshape; + + // inference step function for nid + auto infer_step = [&](uint32_t nid, bool last_iter) { + const auto& inode = idx[nid]; + const uint32_t num_inputs = inode.inputs.size(); + const uint32_t num_outputs = inode.source->num_outputs(); + if (inode.source->is_variable()) { + // Variable node. No operator. Only one output entry. + CHECK(inode.source->op() == nullptr); + CHECK_EQ(num_outputs, 1U); + const uint32_t out_ent_id = idx.entry_id(nid, 0); + if (shape_attr_key.length() != 0 && fis_none(rshape[out_ent_id])) { + auto it = inode.source->attrs.dict.find(shape_attr_key); + if (it != inode.source->attrs.dict.end()) { + std::istringstream is(it->second); + CHECK(is >> rshape[out_ent_id]) << "Invalid attribute"; + } + } + } else if (is_backward.get(inode.source->op(), false) && + inode.control_deps.size() && backward_identity_assign) { + CHECK_GE(inode.control_deps.size(), 1U) + << "BackwardOp need to have control_deps to its forward op"; + const IndexedGraph::Node& fnode = idx[inode.control_deps[0]]; + nnvm::NodePtr fwd_ptr = inode.source->control_deps[0]; + CHECK(fwd_ptr->op() != nullptr) << "Forward op cannot be a variable"; + // use gradient function to find out the correspondence. + std::vector ograd(fwd_ptr->num_outputs()); + for (size_t i = 0; i < ograd.size(); ++i) { + ograd[i].index = static_cast(i); + } + // input gradient list + auto igrad = fgrad[fwd_ptr->op()](fwd_ptr, ograd); + const nnvm::Node* igrad_node = nullptr; + // Input gradient assignement + for (size_t i = 0; i < igrad.size(); ++i) { + if (igrad[i].node->op() == inode.source->op()) { + uint32_t eid = idx.entry_id(nid, igrad[i].index); + if (fis_none(rshape[eid])) { + rshape[eid] = rshape[idx.entry_id(fnode.inputs[i])]; + } else { + CHECK_EQ(rshape[eid], rshape[idx.entry_id(fnode.inputs[i])]) + << "Backward shape inconsistent with the forward shape"; + } + if (igrad_node == nullptr) { + igrad_node = igrad[i].node.get(); + } else { + CHECK(igrad_node == igrad[i].node.get()); + } + } + } + // out grad entries + CHECK(igrad_node != nullptr) + << "Cannot find matching backward op for " << inode.source->attrs.name; + for (size_t i = 0; i < igrad_node->inputs.size(); ++i) { + const nnvm::NodeEntry& e = igrad_node->inputs[i]; + if (e.node == nullptr) { + uint32_t eid = idx.entry_id(inode.inputs[i]); + if (fis_none(rshape[eid])) { + rshape[eid] = rshape[idx.entry_id(inode.control_deps[0], e.index)]; + } + } + } + } else { + bool forward_known = true; + // Forward operator inference. + ishape.resize(num_inputs, empty_val); + for (uint32_t i = 0; i < ishape.size(); ++i) { + ishape[i] = rshape[idx.entry_id(inode.inputs[i])]; + if (fis_none(ishape[i])) forward_known = false; + } + oshape.resize(num_outputs, empty_val); + for (uint32_t i = 0; i < oshape.size(); ++i) { + oshape[i] = rshape[idx.entry_id(nid, i)]; + if (fis_none(oshape[i])) forward_known = false; + } + auto finfer = finfer_shape.get(inode.source->op(), fdefault); + if (!forward_known) { + if (finfer != nullptr) { + // Call inference function of the operator. + try { + forward_known = ApplyOpInferAttr(ret, finfer, inode.source->attrs, + nid, &ishape, &oshape); + } catch (const std::exception& e) { + throw dmlc::Error("Error in operator " + inode.source->attrs.name + ": " + e.what()); + } + } else { + CHECK(!last_iter) + << "Attribute " << infer_name + << " is not registed by op " << inode.source->op()->name + << " we are not able to complete the inference because of this"; + } + } + // Save to the result map. + for (uint32_t i = 0; i < num_inputs; ++i) { + rshape[idx.entry_id(inode.inputs[i])] = ishape[i]; + } + for (uint32_t i = 0; i < num_outputs; ++i) { + rshape[idx.entry_id(nid, i)] = oshape[i]; + } + } + }; + + size_t last_num_unknown; + size_t num_unknown = rshape.size(); + int i = 0; + do { + if (i % 2 == 0) { + for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) { + infer_step(nid, false); + } + } else { + // backward inference + for (uint32_t i = idx.num_nodes(); i != 0; --i) { + infer_step(i - 1, false); + } + } + last_num_unknown = num_unknown; + num_unknown = 0; + for (size_t j = 0; j < idx.num_node_entries(); ++j) { + if (fis_none(rshape[j])) { + ++num_unknown; + } + } + ++i; + } while (num_unknown > 0 && last_num_unknown > num_unknown); + // set the shapes + ret.attrs[attr_name] = std::make_shared(std::move(rshape)); + // number of nodes who knows the shape. + ret.attrs[unknown_name] = std::make_shared(num_unknown); + return ret; +} + +// inference fucntion for same type +inline bool SameType(const nnvm::NodeAttrs& attrs, + std::vector *iattr, + std::vector *oattr) { + int def_v = -1; + for (int v : *oattr) { + if (v != -1) { + def_v = v; break; + } + } + if (def_v == -1) { + for (int v : *iattr) { + if (v != -1) { + def_v = v; break; + } + } + } + if (def_v == -1) return false; + for (int& v : *oattr) { + v = def_v; + } + for (int& v : *iattr) { + v = def_v; + } + return true; +} + +// assigning default type N to both input and output attrs with value -1 +template +inline bool DefaultType(const nnvm::NodeAttrs& attrs, + const Context& ctx, + std::vector *iattr, + std::vector *oattr) { + // TODO(junwu): check whether need to use ctx + for (int& v : *oattr) { + if (v == none) v = default_val; + } + for (int& v : *iattr) { + if (v == none) v = default_val; + } + return true; +} + +nnvm::Graph InferShape(nnvm::Graph graph, + nnvm::ShapeVector shape_inputs, + const std::string& shape_attr_key) { + using dmlc::any; + if (shape_inputs.size() != 0) { + graph.attrs["shape_inputs"] = std::make_shared(std::move(shape_inputs)); + } + if (shape_attr_key.length() != 0) { + graph.attrs["shape_attr_key"] = std::make_shared(std::move(shape_attr_key)); + } + return InferAttr( + std::move(graph), nnvm::TShape(), + "FInferShape", "shape_inputs", "shape_attr_key", + "shape", "shape_num_unknown_nodes", + [](const nnvm::TShape& s) { return s.ndim() == 0 || s.Size() == 0; }, + nullptr, true); +} + +nnvm::Graph InferType(nnvm::Graph graph, + nnvm::DTypeVector dtype_inputs, + const std::string& dtype_attr_key) { + using dmlc::any; + if (dtype_inputs.size() != 0) { + graph.attrs["dtype_inputs"] = std::make_shared(std::move(dtype_inputs)); + } + if (dtype_attr_key.length() != 0) { + graph.attrs["dtype_attr_key"] = std::make_shared(std::move(dtype_attr_key)); + } + return InferAttr( + std::move(graph), -1, + "FInferType", "dtype_inputs", "dtype_attr_key", + "dtype", "dtype_num_unknown_nodes", + [](const int t) { return t == -1; }, + SameType, true); +} + +nnvm::Graph InferStorageType(nnvm::Graph graph, + nnvm::StorageTypeVector storage_type_inputs, + const std::string& storage_type_attr_key) { + using dmlc::any; + if (storage_type_inputs.size() != 0) { + graph.attrs["storage_type_inputs"] = std::make_shared(std::move(storage_type_inputs)); + } + if (storage_type_attr_key.length() != 0) { + graph.attrs["storage_type_attr_key"] = std::make_shared(std::move(storage_type_attr_key)); + } + // for storage type, the backward attr is not necessarily the same as it's correspondence + const int kDefaultStorage = 0; + return InferAttr( + std::move(graph), -1, + "FInferStorageType", "storage_type_inputs", "storage_type_attr_key", + "storage_type", "storage_type_num_unknown_nodes", + [](const int t) { return t == -1; }, + DefaultType, false); +} + +} // namespace exec +} // namespace mxnet diff --git a/src/executor/inplace_addto_detect_pass.cc b/src/executor/inplace_addto_detect_pass.cc index 75a2608313aa..1a0bc9cb40a6 100644 --- a/src/executor/inplace_addto_detect_pass.cc +++ b/src/executor/inplace_addto_detect_pass.cc @@ -44,6 +44,8 @@ Graph DetectInplaceAddTo(Graph g) { uint32_t eid_rhs = idx.entry_id(inode.inputs[1]); if (ref_count[eid_rhs] != 1) continue; if (inode.inputs[0].node_id >= inode.inputs[1].node_id) continue; + // TODO(haibin) support inplace addto for Dynamic Storage + if (storage_id[eid_rhs] == kDynamicStorageID) continue; CHECK_NE(storage_id[eid_rhs], sid); storage_id[eid_rhs] = sid; addto_entry[eid_rhs] = 1; diff --git a/src/io/inst_vector.h b/src/io/inst_vector.h index d82bd48e2fa1..0a665bd6811d 100644 --- a/src/io/inst_vector.h +++ b/src/io/inst_vector.h @@ -12,6 +12,7 @@ #include #include #include +#include #include #include diff --git a/src/io/iter_batchloader.h b/src/io/iter_batchloader.h index a51e24503785..91488c065033 100644 --- a/src/io/iter_batchloader.h +++ b/src/io/iter_batchloader.h @@ -23,7 +23,7 @@ namespace io { class BatchLoader : public IIterator { public: explicit BatchLoader(IIterator *base): - base_(base), head_(1), num_overflow_(0) { + head_(1), num_overflow_(0), base_(base) { } virtual ~BatchLoader(void) { @@ -34,7 +34,7 @@ class BatchLoader : public IIterator { std::vector > kwargs_left; // init batch param, it could have similar param with kwargs_left = param_.InitAllowUnknown(kwargs); - // Init space for out_ + // Init space for out out_.inst_index = new unsigned[param_.batch_size]; out_.batch_size = param_.batch_size; out_.data.clear(); @@ -51,6 +51,7 @@ class BatchLoader : public IIterator { } head_ = 1; } + virtual bool Next(void) { out_.num_batch_padd = 0; out_.batch_size = param_.batch_size; @@ -110,23 +111,25 @@ class BatchLoader : public IIterator { return out_; } - private: + protected: /*! \brief batch parameters */ BatchParam param_; /*! \brief output data */ TBlobBatch out_; - /*! \brief base iterator */ - IIterator *base_; /*! \brief on first */ int head_; /*! \brief number of overflow instances that readed in round_batch mode */ int num_overflow_; + /*! \brief tensor to hold data */ + std::vector data_; + + private: + /*! \brief base iterator */ + IIterator *base_; /*! \brief data shape */ std::vector shape_; /*! \brief unit size */ std::vector unit_size_; - /*! \brief tensor to hold data */ - std::vector data_; // initialize the data holder by using from the first batch. inline void InitData(const DataInst& first_batch) { shape_.resize(first_batch.data.size()); diff --git a/src/io/iter_libsvm.cc b/src/io/iter_libsvm.cc new file mode 100644 index 000000000000..04dcf289a020 --- /dev/null +++ b/src/io/iter_libsvm.cc @@ -0,0 +1,258 @@ +/*! + * Copyright (c) 2015 by Contributors + * \file iter_libsvm.cc + * \brief define a LibSVM Reader to read in arrays + */ +#include +#include +#include +#include +#include +#include "./iter_sparse_prefetcher.h" +#include "./iter_sparse_batchloader.h" + +namespace mxnet { +namespace io { +// LibSVM parameters +struct LibSVMIterParam : public dmlc::Parameter { + /*! \brief path to data libsvm file */ + std::string data_libsvm; + /*! \brief data shape */ + TShape data_shape; + /*! \brief path to label libsvm file */ + std::string label_libsvm; + /*! \brief label shape */ + TShape label_shape; + // declare parameters + DMLC_DECLARE_PARAMETER(LibSVMIterParam) { + DMLC_DECLARE_FIELD(data_libsvm) + .describe("The input LibSVM file or a directory path."); + DMLC_DECLARE_FIELD(data_shape) + .describe("The shape of one example."); + DMLC_DECLARE_FIELD(label_libsvm).set_default("NULL") + .describe("The input LibSVM file or a directory path. " + "If NULL, all labels will be read from ``data_libsvm``."); + index_t shape1[] = {1}; + DMLC_DECLARE_FIELD(label_shape).set_default(TShape(shape1, shape1 + 1)) + .describe("The shape of one label."); + } +}; + +class LibSVMIter: public SparseIIterator { + public: + LibSVMIter() {} + virtual ~LibSVMIter() {} + + // intialize iterator loads data in + virtual void Init(const std::vector >& kwargs) { + param_.InitAllowUnknown(kwargs); + CHECK_EQ(param_.data_shape.ndim(), 1) << "dimension of data_shape is expected to be 1"; + data_parser_.reset(dmlc::Parser::Create(param_.data_libsvm.c_str(), + 0, 1, "libsvm")); + if (param_.label_libsvm != "NULL") { + label_parser_.reset(dmlc::Parser::Create(param_.label_libsvm.c_str(), + 0, 1, "libsvm")); + CHECK_GT(param_.label_shape.Size(), 1) + << "label_shape is not expected to be (1,) when param_.label_libsvm is set."; + } else { + CHECK_EQ(param_.label_shape.Size(), 1) + << "label_shape is expected to be (1,) when param_.label_libsvm is NULL"; + } + // both data and label are of CSRStorage in libsvm format + if (param_.label_shape.Size() > 1) { + out_.data.resize(6); + } else { + // only data is of CSRStorage in libsvm format. + out_.data.resize(4); + } + } + + virtual void BeforeFirst() { + data_parser_->BeforeFirst(); + if (label_parser_.get() != nullptr) { + label_parser_->BeforeFirst(); + } + data_ptr_ = label_ptr_ = 0; + data_size_ = label_size_ = 0; + inst_counter_ = 0; + end_ = false; + } + + virtual bool Next() { + if (end_) return false; + while (data_ptr_ >= data_size_) { + if (!data_parser_->Next()) { + end_ = true; return false; + } + data_ptr_ = 0; + data_size_ = data_parser_->Value().size; + } + out_.index = inst_counter_++; + CHECK_LT(data_ptr_, data_size_); + const auto data_row = data_parser_->Value()[data_ptr_++]; + // data, indices and indptr + out_.data[0] = AsDataBlob(data_row); + out_.data[1] = AsIdxBlob(data_row); + out_.data[2] = AsIndPtrPlaceholder(data_row); + + if (label_parser_.get() != nullptr) { + while (label_ptr_ >= label_size_) { + CHECK(label_parser_->Next()) + << "Data LibSVM's row is smaller than the number of rows in label_libsvm"; + label_ptr_ = 0; + label_size_ = label_parser_->Value().size; + } + CHECK_LT(label_ptr_, label_size_); + const auto label_row = label_parser_->Value()[label_ptr_++]; + // data, indices and indptr + out_.data[3] = AsDataBlob(label_row); + out_.data[4] = AsIdxBlob(label_row); + out_.data[5] = AsIndPtrPlaceholder(label_row); + } else { + out_.data[3] = AsScalarLabelBlob(data_row); + } + return true; + } + + virtual const DataInst &Value(void) const { + return out_; + } + + virtual const NDArrayStorageType GetStorageType(bool is_data) const { + if (is_data) return kCSRStorage; + return param_.label_shape.Size() > 1 ? kCSRStorage : kDefaultStorage; + } + + virtual const TShape GetShape(bool is_data) const { + if (is_data) return param_.data_shape; + return param_.label_shape; + } + + private: + inline TBlob AsDataBlob(const dmlc::Row& row) { + const real_t* ptr = row.value; + TShape shape(mshadow::Shape1(row.length)); + return TBlob((real_t*) ptr, shape, cpu::kDevMask); // NOLINT(*) + } + + inline TBlob AsIdxBlob(const dmlc::Row& row) { + const uint64_t* ptr = row.index; + TShape shape(mshadow::Shape1(row.length)); + return TBlob((int64_t*) ptr, shape, cpu::kDevMask, mshadow::kInt64); // NOLINT(*) + } + + inline TBlob AsIndPtrPlaceholder(const dmlc::Row& row) { + return TBlob(nullptr, mshadow::Shape1(0), cpu::kDevMask, mshadow::kInt64); + } + + inline TBlob AsScalarLabelBlob(const dmlc::Row& row) { + const real_t* ptr = row.label; + return TBlob((real_t*) ptr, mshadow::Shape1(1), cpu::kDevMask); // NOLINT(*) + } + + LibSVMIterParam param_; + // output instance + DataInst out_; + // internal instance counter + unsigned inst_counter_{0}; + // at end + bool end_{false}; + // label parser + size_t label_ptr_{0}, label_size_{0}; + size_t data_ptr_{0}, data_size_{0}; + std::unique_ptr > label_parser_; + std::unique_ptr > data_parser_; +}; + + +DMLC_REGISTER_PARAMETER(LibSVMIterParam); + +MXNET_REGISTER_IO_ITER(LibSVMIter) +.describe(R"code(Returns the LibSVM file iterator. This iterator is experimental and +should be used with care. + +The input data is similar to libsvm file format, except that the indices are expected to be +zero-based instead of one-based. Details of the libsvm format are available at +`https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/` + +In this function, the `data_shape` parameter is used to set the shape of each line of the data. +The dimension of both `data_shape` and `label_shape` are expected to be 1. + +When `label_libsvm` is set to ``NULL``, both data and label are read from the same file specified +by `data_libsvm`. Otherwise, data is read from `data_libsvm` and label from `label_libsvm`, +in this case, if `data_libsvm` contains label, it will ignored. + +The `LibSVMIter` only support `round_batch` parameter set to ``True`` for now. So, if `batch_size` +is 3 and there are 4 total rows in libsvm file, 2 more examples +are consumed at the first round. If `reset` function is called after first round, +the call is ignored and remaining examples are returned in the second round. + +If ``data_libsvm = 'data/'`` is set, then all the files in this directory will be read. + +Examples:: + + // Contents of libsvm file ``data.t``. + 1.0 0:0.5 2:1.2 + -2.0 + -3.0 0:0.6 1:2.4 2:1.2 + 4 2:-1.2 + + // Creates a `LibSVMIter` with `batch_size`=3. + LibSVMIter = mx.io.LibSVMIter(data_libsvm = 'data.t', data_shape = (3,), + batch_size = 3) + + // The first batch (data and label) + [[ 0.5 0. 1.2 ] + [ 0. 0. 0. ] + [ 0.6 2.4 1.2 ]] + + [ 1. -2. -3.] + + // The second batch (data and label) + [[ 0. 0. -1.2 ] + [ 0.5 0. 1.2 ] + [ 0. 0. 0. ]] + + [ 4. 1. -2.] + + // Contents of libsvm file ``label.t`` + 1.0 + -2.0 0:0.125 + -3.0 2:1.2 + 4 1:1.0 2:-1.2 + + // Creates a `LibSVMIter` with specified label file + LibSVMIter = mx.io.LibSVMIter(data_libsvm = 'data.t', data_shape = (3,), + label_libsvm = 'label.t', label_shape = (3,), batch_size = 3) + + // Two batches of data read from the above iterator are as follows(data and label): + // The first batch + [[ 0.5 0. 1.2 ] + [ 0. 0. 0. ] + [ 0.6 2.4 1.2 ]] + + [[ 0. 0. 0. ] + [ 0.125 0. 0. ] + [ 0. 0. 1.2 ]] + + // The second batch + [[ 0. 0. -1.2 ] + [ 0.5 0. 1.2 ] + [ 0. 0. 0. ]] + + [[ 0. 1. -1.2 ] + [ 0. 0. 0. ] + [ 0.125 0. 0. ]] + +)code" ADD_FILELINE) +.add_arguments(LibSVMIterParam::__FIELDS__()) +.add_arguments(BatchParam::__FIELDS__()) +.add_arguments(PrefetcherParam::__FIELDS__()) +.set_body([]() { + return new SparsePrefetcherIter( + new SparseBatchLoader( + new LibSVMIter())); + }); + +} // namespace io +} // namespace mxnet diff --git a/src/io/iter_prefetcher.h b/src/io/iter_prefetcher.h index 9050ef2d1b38..3eb85b12c077 100644 --- a/src/io/iter_prefetcher.h +++ b/src/io/iter_prefetcher.h @@ -28,8 +28,7 @@ namespace io { class PrefetcherIter : public IIterator { public: explicit PrefetcherIter(IIterator* base) - : loader_(base), out_(nullptr) { - } + : loader_(base), out_(nullptr) {} ~PrefetcherIter() { while (recycle_queue_.size() != 0) { @@ -38,21 +37,24 @@ class PrefetcherIter : public IIterator { delete batch; } delete out_; - iter_.Destroy(); + iter.Destroy(); } - virtual void Init(const std::vector >& kwargs) { + void InitParams(const std::vector >& kwargs) { std::vector > kwargs_left; // init image rec param kwargs_left = param_.InitAllowUnknown(kwargs); - // use the kwarg to init batch loader - loader_->Init(kwargs); // maximum prefetch threaded iter internal size const int kMaxPrefetchBuffer = 16; // init thread iter - iter_.set_max_capacity(kMaxPrefetchBuffer); + iter.set_max_capacity(kMaxPrefetchBuffer); + } - iter_.Init([this](DataBatch **dptr) { + virtual void Init(const std::vector >& kwargs) { + InitParams(kwargs); + // use the kwarg to init batch loader + loader_->Init(kwargs); + iter.Init([this](DataBatch **dptr) { if (!loader_->Next()) return false; const TBlobBatch& batch = loader_->Value(); if (*dptr == nullptr) { @@ -91,7 +93,7 @@ class PrefetcherIter : public IIterator { } virtual void BeforeFirst(void) { - iter_.BeforeFirst(); + iter.BeforeFirst(); } virtual bool Next(void) { @@ -106,9 +108,9 @@ class PrefetcherIter : public IIterator { arr.WaitToWrite(); } recycle_queue_.pop(); - iter_.Recycle(&old_batch); + iter.Recycle(&old_batch); } - return iter_.Next(&out_); + return iter.Next(&out_); } virtual const DataBatch &Value(void) const { return *out_; @@ -117,16 +119,16 @@ class PrefetcherIter : public IIterator { protected: /*! \brief prefetcher parameters */ PrefetcherParam param_; - /*! \brief internal batch loader */ - std::unique_ptr > loader_; + /*! \brief backend thread */ + dmlc::ThreadedIter iter; private: + /*! \brief internal batch loader */ + std::unique_ptr > loader_; /*! \brief output data */ DataBatch *out_; /*! \brief queue to be recycled */ std::queue recycle_queue_; - /*! \brief backend thread */ - dmlc::ThreadedIter iter_; }; } // namespace io } // namespace mxnet diff --git a/src/io/iter_sparse.h b/src/io/iter_sparse.h new file mode 100644 index 000000000000..24e3d81ee553 --- /dev/null +++ b/src/io/iter_sparse.h @@ -0,0 +1,27 @@ +/*! + * Copyright (c) 2017 by Contributors + * \file iter_sparse.h + * \brief mxnet sparse data iterator + */ +#ifndef MXNET_IO_ITER_SPARSE_H_ +#define MXNET_IO_ITER_SPARSE_H_ + +#include +#include + +namespace mxnet { +/*! + * \brief iterator type + * \param DType data type + */ +template +class SparseIIterator : public IIterator { + public: + /*! \brief storage type of the data or label */ + virtual const NDArrayStorageType GetStorageType(bool is_data) const = 0; + /*! \brief shape of the data or label */ + virtual const TShape GetShape(bool is_data) const = 0; +}; // class SparseIIterator + +} // namespace mxnet +#endif // MXNET_IO_ITER_SPARSE_H_ diff --git a/src/io/iter_sparse_batchloader.h b/src/io/iter_sparse_batchloader.h new file mode 100644 index 000000000000..e8cddb9e9704 --- /dev/null +++ b/src/io/iter_sparse_batchloader.h @@ -0,0 +1,185 @@ +/*! + * Copyright (c) 2017 by Contributors + * \file iter_sparse_batchloader.h + * \brief define a batch adapter to create sparse tblob batch + */ +#ifndef MXNET_IO_ITER_SPARSE_BATCHLOADER_H_ +#define MXNET_IO_ITER_SPARSE_BATCHLOADER_H_ + +#include +#include +#include +#include +#include +#include +#include +#include "./inst_vector.h" +#include "./image_iter_common.h" +#include "./iter_batchloader.h" +#include "./iter_sparse.h" + +namespace mxnet { +namespace io { + +/*! \brief create a batch iterator from single instance iterator */ +class SparseBatchLoader : public BatchLoader, public SparseIIterator { + public: + explicit SparseBatchLoader(SparseIIterator *base): + BatchLoader(base), sparse_base_(base) { + } + + virtual ~SparseBatchLoader(void) {} + + inline void Init(const std::vector >& kwargs) { + BatchLoader::Init(kwargs); + data_stype_ = sparse_base_->GetStorageType(true); + label_stype_ = sparse_base_->GetStorageType(false); + if (param_.round_batch == 0) { + LOG(FATAL) << "sparse batch loader doesn't support round_batch == false yet"; + } + } + + virtual void BeforeFirst(void) { + BatchLoader::BeforeFirst(); + } + + virtual bool Next(void) { + out_.num_batch_padd = 0; + out_.batch_size = param_.batch_size; + this->head_ = 0; + // if overflown from previous round, directly return false, until before first is called + if (num_overflow_ != 0) return false; + index_t top = 0; + inst_cache_.clear(); + while (sparse_base_->Next()) { + inst_cache_.emplace_back(sparse_base_->Value()); + if (inst_cache_.size() >= param_.batch_size) break; + } + // no more data instance + if (inst_cache_.size() == 0) { + return false; + } + if (inst_cache_.size() < param_.batch_size) { + CHECK_GT(param_.round_batch, 0); + num_overflow_ = 0; + sparse_base_->BeforeFirst(); + for (; inst_cache_.size() < param_.batch_size; ++num_overflow_) { + CHECK(sparse_base_->Next()) << "number of input must be bigger than batch size"; + inst_cache_.emplace_back(sparse_base_->Value()); + } + } + out_.num_batch_padd = num_overflow_; + CHECK_EQ(inst_cache_.size(), param_.batch_size); + this->InitDataFromBatch(); + for (size_t j = 0; j < inst_cache_.size(); j++) { + const auto& d = inst_cache_[j]; + out_.inst_index[top] = d.index; + // TODO(haibin) double check the type? + int64_t unit_size = 0; + for (size_t i = 0; i < d.data.size(); ++i) { + // indptr tensor + if (IsIndPtr(i)) { + auto indptr = data_[i].get(); + if (j == 0) indptr[0] = 0; + indptr[j + 1] = indptr[j] + unit_size; + offsets_[i] = j; + } else { + // indices and values tensor + unit_size = d.data[i].shape_.Size(); + MSHADOW_TYPE_SWITCH(data_[i].type_flag_, DType, { + const auto begin = offsets_[i]; + const auto end = offsets_[i] + unit_size; + mshadow::Copy(data_[i].get().Slice(begin, end), + d.data[i].get_with_shape(mshadow::Shape1(unit_size))); + }); + offsets_[i] += unit_size; + } + } + } + return true; + } + + virtual const TBlobBatch &Value(void) const { + return BatchLoader::Value(); + } + + virtual const NDArrayStorageType GetStorageType(bool is_data) const { + return sparse_base_->GetStorageType(is_data); + } + + virtual const TShape GetShape(bool is_data) const { + TShape inst_shape = sparse_base_->GetShape(is_data); + std::vector shape_vec; + shape_vec.push_back(param_.batch_size); + for (index_t dim = 0; dim < inst_shape.ndim(); ++dim) { + shape_vec.push_back(inst_shape[dim]); + } + return TShape(shape_vec.begin(), shape_vec.end()); + } + + private: + /*! \brief base sparse iterator */ + SparseIIterator *sparse_base_; + /*! \brief data instances */ + std::vector inst_cache_; + /*! \brief data storage type */ + NDArrayStorageType data_stype_; + /*! \brief data label type */ + NDArrayStorageType label_stype_; + /*! \brief tensor offset for slicing */ + std::vector offsets_; + + // check whether ith position is the indptr tensor for a CSR tensor + inline bool IsIndPtr(size_t i) { + auto data_num_aux = num_aux_data(data_stype_); + auto label_num_aux = num_aux_data(label_stype_); + auto label_indptr_offset = data_num_aux + 1 + label_num_aux; + // data indptr + if (i == data_num_aux && data_stype_ == kCSRStorage) { + return true; + } + // label indptr + if (i == label_indptr_offset && label_stype_ == kCSRStorage && data_stype_ == kCSRStorage) { + return true; + } + return false; + } + + // initialize the data holder by using from the batch + inline void InitDataFromBatch() { + CHECK(data_stype_ == kCSRStorage || label_stype_ == kCSRStorage); + CHECK_GT(inst_cache_.size(), 0); + out_.data.clear(); + data_.clear(); + offsets_.clear(); + + size_t total_size = inst_cache_[0].data.size(); + data_.resize(total_size); + offsets_.resize(total_size, 0); + std::vector vec_sizes(total_size, 0); + // accumulate the memory required for a batch + for (size_t i = 0; i < total_size; ++i) { + size_t size = 0; + // vec_size for indptr + if (IsIndPtr(i)) { + size = param_.batch_size + 1; + } else { + for (const auto &d : inst_cache_) size += d.data[i].shape_.Size(); + } + vec_sizes[i] = size; + } + + CHECK_EQ(vec_sizes[0], vec_sizes[1]); + for (size_t i = 0; i < total_size; ++i) { + int src_type_flag = inst_cache_[0].data[i].type_flag_; + // init object attributes + TShape dst_shape(mshadow::Shape1(vec_sizes[i])); + data_[i].resize(mshadow::Shape1(vec_sizes[i]), src_type_flag); + CHECK(data_[i].dptr_ != nullptr); + out_.data.push_back(TBlob(data_[i].dptr_, dst_shape, cpu::kDevMask, src_type_flag)); + } + } +}; // class BatchLoader +} // namespace io +} // namespace mxnet +#endif // MXNET_IO_ITER_SPARSE_BATCHLOADER_H_ diff --git a/src/io/iter_sparse_prefetcher.h b/src/io/iter_sparse_prefetcher.h new file mode 100644 index 000000000000..79b4fa8e2c6c --- /dev/null +++ b/src/io/iter_sparse_prefetcher.h @@ -0,0 +1,135 @@ +/*! + * Copyright (c) 2017 by Contributors + * \file iter_sparse_prefetcher.h + * \brief define a prefetcher using threaditer to keep k batch fetched + */ +#ifndef MXNET_IO_ITER_SPARSE_PREFETCHER_H_ +#define MXNET_IO_ITER_SPARSE_PREFETCHER_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "./inst_vector.h" +#include "./image_iter_common.h" +#include "./iter_prefetcher.h" +#include "./iter_sparse.h" + +namespace mxnet { +namespace io { +// iterator on sparse data +class SparsePrefetcherIter : public PrefetcherIter { + public: + explicit SparsePrefetcherIter(SparseIIterator* base) + : PrefetcherIter(base), sparse_loader_(base) {} + + ~SparsePrefetcherIter() {} + + virtual void Init(const std::vector >& kwargs) { + PrefetcherIter::InitParams(kwargs); + // use the kwarg to init batch loader + sparse_loader_->Init(kwargs); + iter.Init([this](DataBatch **dptr) { + if (!sparse_loader_->Next()) return false; + const TBlobBatch& batch = sparse_loader_->Value(); + if (*dptr == nullptr) { + // allocate databatch + *dptr = new DataBatch(); + (*dptr)->num_batch_padd = batch.num_batch_padd; + // (*dptr)->data.at(0) => data + // (*dptr)->data.at(1) => label + (*dptr)->data.resize(2); + (*dptr)->index.resize(batch.batch_size); + size_t data_iter = 0; + for (size_t i = 0; i < (*dptr)->data.size(); ++i) { + bool is_data = i == 0; + auto stype = this->GetStorageType(is_data); + auto dtype = param_.dtype ? param_.dtype.value() : batch.data[data_iter].type_flag_; + if (stype == kDefaultStorage) { + (*dptr)->data.at(i) = NDArray(batch.data[data_iter].shape_, + Context::CPU(), false, dtype); + } else { + (*dptr)->data.at(i) = NDArray(stype, this->GetShape(is_data), + Context::CPU(), false, dtype); + } + data_iter += num_aux_data(stype) + 1; + } + } + // copy data over + size_t data_iter = 0; + for (size_t i = 0; i < (*dptr)->data.size(); ++i) { + auto& nd = ((*dptr)->data)[i]; + auto stype = nd.storage_type(); + auto& data_i = ((*dptr)->data)[i]; + if (stype == kDefaultStorage) { + CopyFromTo(data_i.data(), batch.data[data_iter]); + } else if (stype == kCSRStorage) { + auto& values = batch.data[data_iter]; + auto& indices = batch.data[data_iter + 1]; + auto& indptr = batch.data[data_iter + 2]; + // allocate memory + CHECK_EQ(indices.shape_.Size(), values.shape_.Size()); + nd.CheckAndAllocAuxData(csr::kIdx, indices.shape_); + nd.CheckAndAllocData(values.shape_); + nd.CheckAndAllocAuxData(csr::kIndPtr, indptr.shape_); + // copy values, indices and indptr + CopyFromTo(data_i.data(), values); + CopyFromTo(data_i.aux_data(csr::kIdx), indices); + CopyFromTo(data_i.aux_data(csr::kIndPtr), indptr); + } else { + LOG(FATAL) << "Storage type not implemented: " << stype; + } + data_iter += num_aux_data(stype) + 1; + (*dptr)->num_batch_padd = batch.num_batch_padd; + } + if (batch.inst_index) { + std::copy(batch.inst_index, + batch.inst_index + batch.batch_size, + (*dptr)->index.begin()); + } + return true; + }, + [this]() { sparse_loader_->BeforeFirst(); }); + } + + virtual void BeforeFirst(void) { + PrefetcherIter::BeforeFirst(); + } + + virtual bool Next(void) { + return PrefetcherIter::Next(); + } + virtual const DataBatch &Value(void) const { + return PrefetcherIter::Value(); + } + + virtual const NDArrayStorageType GetStorageType(bool is_data) const { + return sparse_loader_->GetStorageType(is_data); + } + + virtual const TShape GetShape(bool is_data) const { + return sparse_loader_->GetShape(is_data); + } + + private: + /*! \brief internal sparse batch loader */ + SparseIIterator* sparse_loader_; + + inline void CopyFromTo(TBlob dst, const TBlob src) { + MSHADOW_TYPE_SWITCH(src.type_flag_, DType, { + mshadow::Copy(dst.FlatTo1D(), src.FlatTo1D()); + }); + } +}; +} // namespace io +} // namespace mxnet +#endif // MXNET_IO_ITER_SPARSE_PREFETCHER_H_ diff --git a/src/kvstore/comm.h b/src/kvstore/comm.h index 07f2d24bd223..59308be92ce3 100644 --- a/src/kvstore/comm.h +++ b/src/kvstore/comm.h @@ -3,13 +3,16 @@ */ #ifndef MXNET_KVSTORE_COMM_H_ #define MXNET_KVSTORE_COMM_H_ +#include #include #include #include #include #include #include +#include #include "mxnet/ndarray.h" +#include "../ndarray/ndarray_function.h" namespace mxnet { namespace kvstore { /** @@ -22,9 +25,10 @@ class Comm { } virtual ~Comm() { } /** - * \brief init key with the data shape + * \brief init key with the data shape and storage shape */ - virtual void Init(int key, const TShape& shape, int dtype = mshadow::kFloat32) = 0; + virtual void Init(int key, const NDArrayStorageType stype, + const TShape& shape, int dtype = mshadow::kFloat32) = 0; /** * \brief returns src[0] + .. + src[src.size()-1] */ @@ -57,43 +61,85 @@ class CommCPU : public Comm { CommCPU() { nthread_reduction_ = dmlc::GetEnv("MXNET_KVSTORE_REDUCTION_NTHREADS", 4); bigarray_bound_ = dmlc::GetEnv("MXNET_KVSTORE_BIGARRAY_BOUND", 1000 * 1000); + // TODO(junwu) delete the following data member, now for benchmark only + is_serial_push_ = dmlc::GetEnv("MXNET_KVSTORE_SERIAL_PUSH", 0); } virtual ~CommCPU() { } - void Init(int key, const TShape& shape, int type = mshadow::kFloat32) override { - merge_buf_[key].merged = NDArray(shape, pinned_ctx_, false, type); + void Init(int key, const NDArrayStorageType stype, const TShape& shape, + int type = mshadow::kFloat32) override { + if (stype == kDefaultStorage) { + merge_buf_[key].merged = NDArray(shape, pinned_ctx_, false, type); + } else { + merge_buf_[key].merged = NDArray(stype, shape, pinned_ctx_, true, type); + } } const NDArray& Reduce(int key, const std::vector& src, int priority) override { + auto& buf = merge_buf_[key]; // avoid extra copy for single device, but it may bring problems for // abnormal usage of kvstore if (src.size() == 1) { - return src[0]; + if (src[0].storage_type() == buf.merged.storage_type()) { + return src[0]; + } else { + CopyFromTo(src[0], &buf.merged, priority); + return buf.merged; + } } - std::vector const_vars(src.size() - 1); - std::vector reduce(src.size()); - auto& buf = merge_buf_[key]; - CopyFromTo(src[0], &buf.merged, priority); - reduce[0] = buf.merged; - if (buf.copy_buf.empty()) { - buf.copy_buf.resize(src.size()-1); - for (size_t j = 0; j < src.size() - 1; ++j) { - buf.copy_buf[j] = NDArray( - src[0].shape(), pinned_ctx_, false, src[0].dtype()); + if (buf.merged.storage_type() == kDefaultStorage) { + std::vector const_vars(src.size() - 1); + std::vector reduce(src.size()); + CopyFromTo(src[0], &buf.merged, priority); + reduce[0] = buf.merged; + + if (buf.copy_buf.empty()) { + buf.copy_buf.resize(src.size()-1); + for (size_t j = 0; j < src.size() - 1; ++j) { + // allocate NDArray basd on storage type + buf.copy_buf[j] = NDArray( + src[0].shape(), pinned_ctx_, false, src[0].dtype()); + } } - } - for (size_t i = 1; i < src.size(); ++i) { - CopyFromTo(src[i], &(buf.copy_buf[i-1]), priority); - reduce[i] = buf.copy_buf[i-1]; - const_vars[i-1] = reduce[i].var(); - } + for (size_t i = 1; i < src.size(); ++i) { + CopyFromTo(src[i], &(buf.copy_buf[i-1]), priority); + reduce[i] = buf.copy_buf[i-1]; + const_vars[i-1] = reduce[i].var(); + } + + Engine::Get()->PushSync([reduce, this](RunContext rctx) { + ReduceSumCPU(reduce); + }, Context::CPU(), const_vars, {reduce[0].var()}, + FnProperty::kCPUPrioritized, priority, PROFILER_MESSAGE("KVStoreReduce")); + + } else { + // buf.merged is a sparse ndarray. + std::vector const_vars(src.size()); + std::vector reduce(src.size()); - Engine::Get()->PushSync([reduce, this](RunContext rctx) { - ReduceSumCPU(reduce); - }, Context::CPU(), const_vars, {reduce[0].var()}, - FnProperty::kCPUPrioritized, priority, PROFILER_MESSAGE("KVStoreReduce")); + if (buf.copy_buf.empty()) { + buf.copy_buf.resize(src.size()); + for (size_t j = 0; j < src.size(); ++j) { + buf.copy_buf[j] = NDArray( + src[0].storage_type(), src[0].shape(), pinned_ctx_, true, src[0].dtype()); + } + } + for (size_t i = 0; i < src.size(); ++i) { + CopyFromTo(src[i], &(buf.copy_buf[i]), priority); + reduce[i] = buf.copy_buf[i]; + const_vars[i] = reduce[i].var(); + } + auto result = buf.merged; + Engine::Get()->PushSync([reduce, result, this](RunContext rctx) { + NDArray out = result; + is_serial_push_? + ReduceSumCPUExSerial(reduce, &out) + : mxnet::ndarray::ElementwiseSum(rctx.get_stream(), reduce, &out); + }, Context::CPU(), const_vars, {result.var()}, + FnProperty::kCPUPrioritized, priority, PROFILER_MESSAGE("KVStoreReduce")); + } return buf.merged; } @@ -126,6 +172,79 @@ class CommCPU : public Comm { }); } + // serial implementation of reduce sum for row sparse NDArray. + // TODO(haibin) use openmp kernel to parallelize the summation + inline void ReduceSumCPUExSerial(const std::vector &in, NDArray *out) { + using namespace rowsparse; + using namespace mshadow; + auto stype = out->storage_type(); + CHECK_EQ(stype, kRowSparseStorage) << "Unexpected storage type " << stype; + size_t total_num_rows = 0; + size_t num_in = in.size(); + // skip the ones with empty indices and values + std::vector skip(num_in, false); + // the values tensor of the inputs + MSHADOW_TYPE_SWITCH(out->dtype(), DType, { + MSHADOW_IDX_TYPE_SWITCH(out->aux_type(kIdx), IType, { + std::vector> in_vals(num_in); + std::vector> in_indices(num_in); + // offset to the values tensor of all inputs + std::vector offsets(num_in, 0); + std::vector num_rows(num_in, 0); + for (size_t i = 0; i < num_in; i++) { + if (!in[i].storage_initialized()) { + skip[i] = true; + continue; + } + auto size = in[i].aux_shape(kIdx).Size(); + num_rows[i] = size; + total_num_rows += size; + in_vals[i] = in[i].data().FlatTo2D(); + in_indices[i] = in[i].aux_data(kIdx).FlatTo1D(); + } + std::vector indices; + indices.reserve(total_num_rows); + // gather indices from all inputs + for (size_t i = 0; i < num_in; i++) { + for (size_t j = 0; j < num_rows[i]; j++) { + indices.emplace_back(in_indices[i][j]); + } + } + CHECK_EQ(indices.size(), total_num_rows); + // dedup indices + std::sort(indices.begin(), indices.end()); + indices.resize(std::unique(indices.begin(), indices.end()) - indices.begin()); + // the one left are unique non-zero rows + size_t nnr = indices.size(); + // allocate memory for output + out->CheckAndAlloc({Shape1(nnr)}); + auto idx_data = out->aux_data(kIdx).FlatTo1D(); + auto val_data = out->data().FlatTo2D(); + + for (size_t i = 0; i < nnr; i++) { + // copy indices back + idx_data[i] = indices[i]; + bool zeros = true; + for (size_t j = 0; j < num_in; j++) { + if (skip[j]) continue; + size_t offset = offsets[j]; + if (offset < num_rows[j]) { + if (indices[i] == in_indices[j][offset]) { + if (zeros) { + Copy(val_data[i], in_vals[j][offset], nullptr); + zeros = false; + } else { + val_data[i] += in_vals[j][offset]; + } + offsets[j] += 1; + } + } + } + } + }); + }); + } + template inline static void ReduceSumCPU( const std::vector &dptr, size_t offset, index_t size) { @@ -191,6 +310,7 @@ class CommCPU : public Comm { std::unordered_map merge_buf_; size_t bigarray_bound_; int nthread_reduction_; + bool is_serial_push_; }; /** @@ -209,8 +329,13 @@ class CommDevice : public Comm { virtual ~CommDevice() { } - void Init(int key, const TShape& shape, int dtype = mshadow::kFloat32) override { - sorted_key_attrs_.push_back(std::make_tuple(key, shape, dtype)); + void Init(int key, const NDArrayStorageType stype, const TShape& shape, + int dtype = mshadow::kFloat32) override { + if (stype == kDefaultStorage) { + sorted_key_attrs_.push_back(std::make_tuple(key, shape, dtype)); + } else { + LOG(FATAL) << "storage type " << stype << " not implemented for device yet"; + } } const NDArray& Reduce(int key, const std::vector& src, diff --git a/src/kvstore/kvstore_dist.h b/src/kvstore/kvstore_dist.h index 5f5a0cc67a64..59d9158012ef 100644 --- a/src/kvstore/kvstore_dist.h +++ b/src/kvstore/kvstore_dist.h @@ -11,6 +11,7 @@ #include "mxnet/engine.h" #include "ps/ps.h" #include "./kvstore_dist_server.h" +#include "../operator/tensor/init_op.h" #if MKL_EXPERIMENTAL == 1 #include #include "../operator/mkl/mkl_memory-inl.h" @@ -42,6 +43,7 @@ class KVStoreDist : public KVStoreLocal { } } bigarray_bound_ = dmlc::GetEnv("MXNET_KVSTORE_BIGARRAY_BOUND", 1000 * 1000); + row_sparse_verbose_ = dmlc::GetEnv("MXNET_KVSTORE_DIST_ROW_SPARSE_VERBOSE", false); } virtual ~KVStoreDist() { @@ -63,7 +65,7 @@ class KVStoreDist : public KVStoreLocal { const std::vector& values) override { CheckUnique(keys); for (size_t i = 0; i < keys.size(); ++i) { - comm_->Init(keys[i], values[i].shape(), values[i].dtype()); + comm_->Init(keys[i], values[i].storage_type(), values[i].shape(), values[i].dtype()); } if (get_rank() == 0) { Push_(keys, values, 0, false); @@ -97,36 +99,51 @@ class KVStoreDist : public KVStoreLocal { // use the same array for merging to guarantee that pull always happens // after the previous push on this key auto& recv_buf = comm_buf_[key]; + const auto storage_type = grouped_vals[i][0]->storage_type(); if (recv_buf.is_none()) { // it may happen for the first time a no-rank-0 worker pull the weight. - recv_buf = NDArray( - grouped_vals[i][0]->shape(), pinned_ctx_, false, grouped_vals[i][0]->dtype()); + if (storage_type == kDefaultStorage) { + recv_buf = NDArray(grouped_vals[i][0]->shape(), pinned_ctx_, + false, grouped_vals[i][0]->dtype()); + } else { + recv_buf = NDArray(storage_type, grouped_vals[i][0]->shape(), + pinned_ctx_, true, grouped_vals[i][0]->dtype()); + // initialize the buffer with sufficient memory + op::FillDnsZerosRspImpl(nullptr, &recv_buf); + } } + if (storage_type == kDefaultStorage) { #if MKL_EXPERIMENTAL == 1 - mkl_set_tblob_eager_mode(recv_buf.data()); + mkl_set_tblob_eager_mode(recv_buf.data()); #endif - real_t* data = static_cast(recv_buf.data().dptr_); - size_t size = recv_buf.shape().Size(); - - auto pull_from_servers = [this, key, data, size]( - RunContext rctx, Engine::CallbackOnComplete cb) { - // convert to ps keys - PSKV& pskv = EncodeKey(key, size); - - // issue pull, false means no delete - auto vals = new ps::SArray(data, size, false); - CHECK_NOTNULL(ps_worker_)->ZPull( - pskv.keys, vals, &pskv.lens, 0, [vals, cb](){ delete vals; cb(); }); - }; - - CHECK_NOTNULL(Engine::Get())->PushAsync( - pull_from_servers, - pinned_ctx_, - {}, - {recv_buf.var()}, - FnProperty::kNormal, - priority, - PROFILER_MESSAGE("KVStoreDistPull")); + real_t* data = static_cast(recv_buf.data().dptr_); + size_t size = recv_buf.shape().Size(); + auto pull_from_servers = [this, key, data, size]( + RunContext rctx, Engine::CallbackOnComplete cb) { + // convert to ps keys + PSKV& pskv = EncodeKey(key, size); + + // issue pull, false means no delete + auto vals = new ps::SArray(data, size, false); + CHECK_NOTNULL(ps_worker_)->ZPull( + pskv.keys, vals, &pskv.lens, kDefaultPushPull, [vals, cb](){ delete vals; cb(); }); + }; + + CHECK_NOTNULL(Engine::Get())->PushAsync( + pull_from_servers, + pinned_ctx_, + {}, + {recv_buf.var()}, + FnProperty::kNormal, + priority, + PROFILER_MESSAGE("KVStoreDistDefaultPull")); + } else if (storage_type == kRowSparseStorage) { + recv_buf.WaitToRead(); + grouped_vals[i][0]->WaitToRead(); + PullRowSparse(key, &recv_buf, grouped_vals[i][0]->aux_ndarray(rowsparse::kIdx), priority); + } else { + LOG(FATAL) << "unknown storage type " << storage_type; + } comm_->Broadcast(key, recv_buf, grouped_vals[i], priority); } @@ -204,41 +221,128 @@ class KVStoreDist : public KVStoreLocal { NDArray merged = do_merge ? comm_->Reduce(key, vals, priority) : vals[0]; auto& send_buf = comm_buf_[key]; + const auto storage_type = merged.storage_type(); if (merged.ctx().dev_mask() == cpu::kDevMask) { send_buf = merged; // avoid memory copy } else { if (send_buf.is_none()) { - send_buf = NDArray(merged.shape(), pinned_ctx_, false, merged.dtype()); + if (storage_type == kDefaultStorage) { + send_buf = NDArray(merged.shape(), pinned_ctx_, false, merged.dtype()); + } else { + send_buf = NDArray(storage_type, merged.shape(), pinned_ctx_, true, merged.dtype()); + // initialize the buffer with sufficient memory + op::FillDnsZerosRspImpl(nullptr, &send_buf); + } } CopyFromTo(merged, &send_buf); } // push to servers - send_buf.WaitToRead(); - size_t size = send_buf.shape().Size(); + if (storage_type == kDefaultStorage) { + send_buf.WaitToRead(); + size_t size = send_buf.shape().Size(); +#if MKL_EXPERIMENTAL == 1 + mkl_set_tblob_eager_mode(send_buf.data()); +#endif + real_t* data = static_cast(send_buf.data().dptr_); + auto push_to_servers = + [this, key, data, size](RunContext rctx, Engine::CallbackOnComplete cb) { + // convert to ps keys + PSKV& pskv = EncodeKey(key, size); + // do push. false means no delete + ps::SArray vals(data, size, false); + CHECK_NOTNULL(ps_worker_)->ZPush( + pskv.keys, vals, pskv.lens, 0, [cb]() { cb(); }); + }; + Engine::Get()->PushAsync( + push_to_servers, + pinned_ctx_, + {send_buf.var()}, + {}, + FnProperty::kNormal, + priority, + PROFILER_MESSAGE("KVStoreDistDefaultPush")); + } else if (storage_type == kRowSparseStorage) { + PushRowSparse(key, send_buf, priority); + } else { + LOG(FATAL) << "unknown storage type"; + } + } + } + + // pull row sparse weight into `recv_buf` based on indices given by `indices` + void PullRowSparse(int key, NDArray *recv_buf, const NDArray indices, int priority) { + using namespace rowsparse; + auto pull_from_servers = [this, key, recv_buf, &indices] + (RunContext rctx, Engine::CallbackOnComplete cb) { + // reading aux_shape & aux_data should be inside the engine + size_t num_rows = indices.shape().Size(); + recv_buf->CheckAndAlloc({mshadow::Shape1(num_rows)}); +#if MKL_EXPERIMENTAL == 1 + mkl_set_tblob_eager_mode(recv_buf->data()); +#endif + real_t* data = static_cast(recv_buf->data().dptr_); + const auto offsets = indices.data().dptr(); + const auto unit_len = recv_buf->shape().ProdShape(1, recv_buf->shape().ndim()); + size_t size = num_rows * unit_len; + // convert to ps keys in row sparse format + PSKV& pskv = EncodeRowSparseKey(key, size, num_rows, offsets, unit_len); + if (this->row_sparse_verbose_) { + LOG(INFO) << "pull lens: " << pskv.lens << " keys: " << pskv.keys + << " size: " << size; + } + auto vals = new ps::SArray(data, size, false); + CHECK_NOTNULL(ps_worker_)->ZPull(pskv.keys, vals, &pskv.lens, kRowSparsePushPull, + [vals, cb]() { delete vals; cb(); }); + }; + CHECK_NOTNULL(Engine::Get())->PushAsync( + pull_from_servers, + pinned_ctx_, + {indices.var()}, + {recv_buf->var()}, + FnProperty::kNormal, + priority, + PROFILER_MESSAGE("KVStoreDistRowSparsePull")); + recv_buf->WaitToRead(); + // copy indices pulled + auto recv_buf_idx = recv_buf->aux_ndarray(kIdx); + CopyFromTo(indices, &recv_buf_idx); + } + + // push row sparse gradient + void PushRowSparse(int key, const NDArray &send_buf, int priority) { + using namespace rowsparse; + auto push_to_servers = [this, key, &send_buf] + (RunContext rctx, Engine::CallbackOnComplete cb) { #if MKL_EXPERIMENTAL == 1 mkl_set_tblob_eager_mode(send_buf.data()); #endif real_t* data = static_cast(send_buf.data().dptr_); - auto push_to_servers = - [this, key, data, size](RunContext rctx, Engine::CallbackOnComplete cb) { - // convert to ps keys - PSKV& pskv = EncodeKey(key, size); - - // do push. false means no delete - ps::SArray vals(data, size, false); - CHECK_NOTNULL(ps_worker_)->ZPush( - pskv.keys, vals, pskv.lens, 0, [cb]() { cb(); }); - }; - Engine::Get()->PushAsync( - push_to_servers, - pinned_ctx_, - {send_buf.var()}, - {}, - FnProperty::kNormal, - priority, - PROFILER_MESSAGE("KVStoreDistPush")); - } + if (!send_buf.storage_initialized()) return; + size_t num_rows = send_buf.aux_shape(kIdx).Size(); + const auto offsets = send_buf.aux_data(kIdx).dptr(); + const auto unit_len = send_buf.shape().ProdShape(1, send_buf.shape().ndim()); + const auto size = num_rows * unit_len; + + // convert to ps keys in row sparse format + PSKV& pskv = EncodeRowSparseKey(key, size, num_rows, offsets, unit_len); + if (this->row_sparse_verbose_) { + LOG(INFO) << "push lens: " << pskv.lens << " keys: " << pskv.keys + << " size: " << size; + } + ps::SArray vals(data, size, false); + CHECK_NOTNULL(ps_worker_)->ZPush(pskv.keys, vals, pskv.lens, kRowSparsePushPull, [cb]() { + cb(); + }); + }; + Engine::Get()->PushAsync( + push_to_servers, + pinned_ctx_, + {send_buf.var()}, + {}, + FnProperty::kNormal, + priority, + PROFILER_MESSAGE("KVStoreDistRowSparsePush")); } /** @@ -266,7 +370,7 @@ class KVStoreDist : public KVStoreLocal { std::unordered_map ps_kv_; /** - * \brief serizelize EncodeKey + * \brief serizelize EncodeRowSparseKey and EncodeKey */ std::mutex mu_; @@ -313,6 +417,37 @@ class KVStoreDist : public KVStoreLocal { return pskv; } + inline PSKV& EncodeRowSparseKey(int key, size_t size, int64_t num_rows, + const int64_t *offsets, size_t unit_len) { + mu_.lock(); + PSKV& pskv = ps_kv_[key]; + mu_.unlock(); + pskv.keys.clear(); + pskv.lens.clear(); + // TODO(haibin) cache this information + auto krs = ps::Postoffice::Get()->GetServerKeyRanges(); + int num_servers = krs.size(); + CHECK_GT(num_servers, 0); + + if (size >= bigarray_bound_ && row_sparse_verbose_) { + LOG(INFO) << "WARNING: big row_sparse weight array sharding is not implemented"; + } + // send it to a single random picked server + int server = (key * 9973) % num_servers; + ps::Key master_key = krs[server].begin() + key; + pskv.keys.push_back(master_key); + pskv.lens.push_back(0); + for (int64_t i = 0; i < num_rows; i++) { + ps::Key ps_key = krs[server].begin() + key + offsets[i]; + CHECK_LT(ps_key, krs[server].end()); + pskv.keys.push_back(ps_key); + pskv.lens.push_back(unit_len); + } + pskv.size = size; + return pskv; + } + + /** * \brief for worker to push and pull data */ @@ -327,6 +462,7 @@ class KVStoreDist : public KVStoreLocal { size_t bigarray_bound_; /// \brief send & recver buffer std::unordered_map comm_buf_; + bool row_sparse_verbose_; }; } // namespace kvstore diff --git a/src/kvstore/kvstore_dist_server.h b/src/kvstore/kvstore_dist_server.h index 02d4a38c2b10..59d2cb705654 100644 --- a/src/kvstore/kvstore_dist_server.h +++ b/src/kvstore/kvstore_dist_server.h @@ -19,6 +19,8 @@ namespace mxnet { namespace kvstore { +static const int kRowSparsePushPull = 1; +static const int kDefaultPushPull = 0; static const int kStopServer = -1; static const int kSyncMode = -2; @@ -92,7 +94,7 @@ class KVStoreDistServer { static_cast(ps_server_)->set_request_handle( std::bind(&KVStoreDistServer::CommandHandle, this, _1, _2)); ps_server_->set_request_handle( - std::bind(&KVStoreDistServer::DataHandle, this, _1, _2, _3)); + std::bind(&KVStoreDistServer::DataHandleEx, this, _1, _2, _3)); sync_mode_ = false; } @@ -133,9 +135,162 @@ class KVStoreDistServer { app->Response(recved); } - void DataHandle(const ps::KVMeta& req_meta, - const ps::KVPairs& req_data, - ps::KVServer* server) { + void DataHandleEx(const ps::KVMeta& req_meta, + const ps::KVPairs& req_data, + ps::KVServer* server) { + if (req_meta.cmd == kRowSparsePushPull) { + DataHandleRowSparse(req_meta, req_data, server); + } else { + DataHandleDefault(req_meta, req_data, server); + } + return; + } + + inline void MergeUpdates(const NDArray& recved, int key, + std::unordered_set *change_set) { + auto& merged = merge_buf_[key]; + if (merged.is_none()) { + merged = NDArray(recved.shape(), Context()); + } + if (change_set->find(key) == change_set->end()) { + CopyFromTo(recved, &merged, 0); + } else { + // TODO(haibin) handle row sparse gradient NDArray with `ReduceSumCPUExParallel` + merged += recved; + } + change_set->insert(key); + } + + void DataHandleRowSparse(const ps::KVMeta& req_meta, + const ps::KVPairs& req_data, + ps::KVServer* server) { + int master_key = DecodeKey(req_data.keys[0]); + auto num_rows = req_data.keys.size() - 1; + if (req_meta.push) { + CHECK_EQ(req_data.lens[0], 0); + CHECK_GT(req_data.lens.size(), 0); + auto unit_len = req_data.lens[1]; + CHECK_GT(unit_len, 0); + real_t* data = req_data.vals.data(); + auto& stored = store_[master_key]; + if (stored.is_none()) { + // LOG(INFO) << "initial push: " << master_key << " size = " << num_rows * unit_len; + // initialization + size_t ds[] = {num_rows, (size_t) unit_len}; + TShape dshape(ds, ds + 2); + CHECK_EQ(req_data.vals.size(), num_rows * unit_len); + TBlob recv_blob(data, dshape, cpu::kDevMask); // NOLINT(*) + NDArray recved = NDArray(recv_blob, 0); + stored = NDArray(dshape, Context()); + CopyFromTo(recved, &stored, 0); + stored.WaitToRead(); + server->Response(req_meta); + return; + } + // synced push + if (sync_mode_) { + // LOG(INFO) << "sync push: " << master_key; + size_t offset = 0; + auto& stored = store_[master_key]; + // merge updates + auto& request_buf = request_buf_[master_key]; + for (size_t i = 1; i <= num_rows; i++) { + // TODO(haibin) decode once and cache result + int key = DecodeKey(req_data.keys[i]); + auto len = req_data.lens[i]; + size_t ds[] = {(size_t)len}; + TShape dshape(ds, ds + 1); + TBlob recv_blob(data, // NOLINT(*) + dshape, cpu::kDevMask); + NDArray recved = NDArray(recv_blob, 0); + MergeUpdates(recved, key, &request_buf.change_set); + offset += len; + } + // perform updates + request_buf.requests.push_back(req_meta); + if (request_buf.requests.size() == (size_t) ps::NumWorkers()) { + // let the main thread to execute updater_, which is necessary for python + for (auto key : request_buf.change_set) { + // slice a row + auto row_id = key - master_key; + NDArray slice = stored.At(row_id); + NDArray update = merge_buf_[key]; + if (updater_) { + exec_.Exec([this, key, &update, &slice](){ + CHECK(updater_); + updater_(key, update, &slice); + }); + } else { + // if no updater, just copy + CopyFromTo(update, &slice); + } + slice.WaitToRead(); + } + request_buf.change_set.clear(); + // LOG(INFO) << "RESPONSE SYNC to " << request_buf.requests.size() << " clients"; + for (const auto& req : request_buf.requests) { + server->Response(req); + } + request_buf.requests.clear(); + } else { + for (size_t i = 1; i <= num_rows; i++) { + int key = DecodeKey(req_data.keys[i]); + merge_buf_[key].WaitToRead(); + } + } + } else { + // async push + auto& stored = store_[master_key]; + for (size_t i = 1; i <= num_rows; i++) { + int key = DecodeKey(req_data.keys[i]); + auto row_id = key - master_key; + auto len = req_data.lens[i]; + size_t ds[] = {(size_t)len}; + TShape dshape(ds, ds + 1); + TBlob recv_blob(data, // NOLINT(*) + dshape, cpu::kDevMask); + NDArray recved = NDArray(recv_blob, 0); + NDArray slice = stored.At(row_id); + exec_.Exec([this, key, &recved, &slice](){ + CHECK(updater_); + updater_(key, recved, &slice); + }); + } + server->Response(req_meta); + stored.WaitToRead(); + } + } else { + // pull + ps::KVPairs response; + auto& stored = store_[master_key]; + CHECK(!stored.is_none()) << "init " << master_key << " first"; + auto shape = stored.shape(); + auto unit_len = shape.ProdShape(1, shape.ndim()); + const float* data = stored.data().dptr(); + auto len = unit_len * num_rows; + // LOG(INFO) << "received pull: " << len; + // concat response values + response.vals.resize(len); + for (size_t i = 1; i <= num_rows; i++) { + int key = DecodeKey(req_data.keys[i]); + const auto src = data + key * unit_len; + auto begin = (i - 1) * unit_len; + auto end = i * unit_len; + response.vals.segment(begin, end).CopyFrom(src, unit_len); + } + // setup response + response.keys = req_data.keys; + std::vector lens(req_data.keys.size(), unit_len); + lens[0] = 0; + response.lens.CopyFrom(lens.begin(), lens.end()); + server->Response(req_meta, response); + } + } + + void DataHandleDefault(const ps::KVMeta& req_meta, + const ps::KVPairs &req_data, + ps::KVServer* server) { + CHECK_EQ(req_meta.cmd, kDefaultPushPull); // do some check CHECK_EQ(req_data.keys.size(), (size_t)1); if (req_meta.push) { @@ -164,37 +319,29 @@ class KVStoreDistServer { } else if (sync_mode_) { // synced push auto& merged = merge_buf_[key]; - if (merged.array.is_none()) { - merged.array = NDArray(dshape, Context()); - } - - if (merged.request.size() == 0) { - CopyFromTo(recved, &merged.array, 0); - } else { - merged.array += recved; - } - - merged.request.push_back(req_meta); - - if (merged.request.size() == (size_t)ps::NumWorkers()) { - // let the main thread to execute updater_, which is necessary for - // python + auto& request_buf = request_buf_[key]; + MergeUpdates(recved, key, &request_buf.change_set); + request_buf.requests.push_back(req_meta); + if (request_buf.requests.size() == (size_t) ps::NumWorkers()) { + CHECK_EQ(request_buf.change_set.size(), 1); + // let the main thread to execute updater_, which is necessary for python if (updater_) { exec_.Exec([this, key, &merged, &stored](){ CHECK(updater_); - updater_(key, merged.array, &stored); + updater_(key, merged, &stored); }); } else { // if no updater, just copy - CopyFromTo(merged.array, &stored); + CopyFromTo(merged, &stored); } - for (const auto& req : merged.request) { + request_buf.change_set.clear(); + for (const auto& req : request_buf.requests) { server->Response(req); } - merged.request.clear(); + request_buf.requests.clear(); stored.WaitToRead(); } else { - merged.array.WaitToRead(); + merged.WaitToRead(); } } else { // async push @@ -209,7 +356,7 @@ class KVStoreDistServer { // pull ps::KVPairs response; CHECK(!stored.is_none()) << "init " << key << " first"; - int len = stored.shape()[0]; + auto len = stored.shape().Size(); response.keys = req_data.keys; response.lens = {len}; // TODO(mli) try to remove this CopyFrom @@ -232,11 +379,14 @@ class KVStoreDistServer { std::unordered_map store_; - struct MergeBuf { - std::vector request; - NDArray array; + struct RequestBuf { + std::vector requests; + std::unordered_set change_set; }; - std::unordered_map merge_buf_; + + std::unordered_map merge_buf_; + std::unordered_map request_buf_; + Executor exec_; diff --git a/src/kvstore/kvstore_local.h b/src/kvstore/kvstore_local.h index dc5f7b786244..e159dd42e596 100644 --- a/src/kvstore/kvstore_local.h +++ b/src/kvstore/kvstore_local.h @@ -44,7 +44,7 @@ class KVStoreLocal : public KVStore { CHECK(local_.find(keys[i]) == local_.end()) << "duplicate init of key " << keys[i]; local_[keys[i]] = values[i].Copy(pinned_ctx_); - comm_->Init(keys[i], values[i].shape(), values[i].dtype()); + comm_->Init(keys[i], values[i].storage_type(), values[i].shape(), values[i].dtype()); } } @@ -82,7 +82,11 @@ class KVStoreLocal : public KVStore { } updater_(key, merged, &local); } else { - local = merged; + if (merged.storage_type() != local.storage_type()) { + local = merged.Copy(local.ctx()); + } else { + local = merged; + } } } } diff --git a/src/ndarray/ndarray.cc b/src/ndarray/ndarray.cc index f2e90dd56f31..21fecd2af22e 100644 --- a/src/ndarray/ndarray.cc +++ b/src/ndarray/ndarray.cc @@ -12,6 +12,9 @@ #include #include #include "./ndarray_function.h" +#include "../common/utils.h" +#include "../operator/tensor/matrix_op-inl.h" +#include "../operator/tensor/init_op.h" #include "./autograd.h" #if MXNET_USE_OPENCV @@ -34,6 +37,8 @@ NDArray NDArray::grad() const { NDArray NDArray::Reshape(const TShape &shape) const { using namespace autograd; + CHECK(storage_type() == kDefaultStorage) << "Reshape for storage type " << + storage_type() << " is not implemented yet"; if (AutogradRuntime::Get()->IsTraining()) { CHECK_GE(shape_.Size(), shape.Size()) << "NDArray.Reshape: target shape must have must have the same size as " @@ -64,12 +69,14 @@ NDArray NDArray::Reshape(const TShape &shape) const { } } - NDArray NDArray::Slice(index_t begin, index_t end) const { using namespace autograd; - NDArray ret = *this; + using namespace mshadow; CHECK(!is_none()) << "NDArray is not initialized"; CHECK_GE(shape_[0], end) << "Slice end index out of range"; + CHECK_EQ(storage_type(), kDefaultStorage); + NDArray ret = *this; + auto stype = storage_type(); size_t length = shape_.ProdShape(1, shape_.ndim()); MSHADOW_TYPE_SWITCH(ret.dtype(), DType, { ret.byte_offset_ += begin * length * sizeof(DType); @@ -96,8 +103,9 @@ NDArray NDArray::Slice(index_t begin, index_t end) const { } } - NDArray NDArray::At(index_t idx) const { + CHECK(storage_type() == kDefaultStorage) << "Storage type " + << storage_type() << " doesn't support At()"; NDArray ret = this->Slice(idx, idx+1); if (shape_.ndim() > 1) { return ret.Reshape(TShape(shape_.data()+1, shape_.data()+shape_.ndim())); @@ -220,11 +228,11 @@ void BinaryOp(const NDArray &lhs, // redirect everything to mshadow operations switch (lhs.ctx().dev_mask()) { case cpu::kDevMask: { - Engine::Get()->PushSync([lhs, rhs, ret](RunContext ctx) { - TBlob tmp = ret.data(); - ndarray::Eval(lhs.data(), rhs.data(), &tmp, ctx); - }, lhs.ctx(), const_vars, {ret.var()}, - FnProperty::kNormal, 0, PROFILER_MESSAGE_FUNCNAME); + Engine::Get()->PushSync([lhs, rhs, ret](RunContext ctx) { + TBlob tmp = ret.data(); + ndarray::Eval(lhs.data(), rhs.data(), &tmp, ctx); + }, lhs.ctx(), const_vars, {ret.var()}, + FnProperty::kNormal, 0, PROFILER_MESSAGE_FUNCNAME); break; } #if MXNET_USE_CUDA @@ -250,6 +258,7 @@ void SetValueOp(const real_t &rhs, NDArray *out) { switch (ret.ctx().dev_mask()) { case cpu::kDevMask: { Engine::Get()->PushSync([rhs, ret](RunContext ctx) { + CHECK(ret.storage_type() == kDefaultStorage); TBlob tmp = ret.data(); ndarray::Eval(rhs, &tmp, ctx); }, ret.ctx(), {}, {ret.var()}, @@ -321,6 +330,112 @@ void ScalarOp(const NDArray &lhs, } } +size_t num_aux_data(NDArrayStorageType stype) { + size_t num = 0; + switch (stype) { + case kDefaultStorage: num = 0; break; + case kCSRStorage: num = 2; break; + case kRowSparseStorage: num = 1; break; + default: LOG(FATAL) << "Unknown storage type" << stype; break; + } + return num; +} + +// Make a copy of a CSR NDArray +template +inline void CopyFromToCsrImpl(const NDArray from, NDArray *to, RunContext ctx) { + using namespace mshadow; + CHECK_EQ(from.storage_type(), to->storage_type()) << "Copying with different storage type"; + // if source storage is not initialized, fill destination with zeros + auto s = ctx.get_stream(); + if (!from.storage_initialized()) { + op::FillZerosCsrImpl(s, to); + return; + } + // Allocate storage + to->CheckAndAllocAuxData(csr::kIndPtr, from.aux_shape(csr::kIndPtr)); + to->CheckAndAllocAuxData(csr::kIdx, from.aux_shape(csr::kIdx)); + to->CheckAndAllocData(from.aux_shape(csr::kIdx)); + TBlob val = to->data(); + TBlob indptr = to->aux_data(csr::kIndPtr); + TBlob idx = to->aux_data(csr::kIdx); + ndarray::Copy(from.data(), &val, + from.ctx(), to->ctx(), ctx); + ndarray::Copy(from.aux_data(csr::kIndPtr), &indptr, + from.ctx(), to->ctx(), ctx); + ndarray::Copy(from.aux_data(csr::kIdx), &idx, + from.ctx(), to->ctx(), ctx); +} + +// Make a copy of a row-sparse NDArray +template +inline void CopyFromToRspImpl(const NDArray from, NDArray *to, RunContext ctx) { + using namespace mshadow; + CHECK_EQ(from.storage_type(), to->storage_type()) << "Copying with different storage type"; + // if source is zeros, fill destination with zeros, too + auto s = ctx.get_stream(); + if (!from.storage_initialized()) { + op::FillZerosRspImpl(s, to); + return; + } + auto aux_shape = from.aux_shape(rowsparse::kIdx); + to->CheckAndAlloc({aux_shape}); + TBlob val = to->data(); + TBlob idx = to->aux_data(rowsparse::kIdx); + ndarray::Copy(from.data(), &val, + from.ctx(), to->ctx(), ctx); + ndarray::Copy(from.aux_data(rowsparse::kIdx), &idx, + from.ctx(), to->ctx(), ctx); +} + +// Make a copy of a dense NDArray +template +inline void CopyFromToDnsImpl(const NDArray from, NDArray *to, RunContext ctx) { + using namespace mshadow; + CHECK_EQ(from.storage_type(), to->storage_type()) << "Copying with different storage type"; + TBlob tmp = to->data(); + ndarray::Copy(from.data(), &tmp, + from.ctx(), to->ctx(), ctx); +} + +// Make a copy of an NDArray based on storage type +template +void CopyFromToImpl(const NDArray from, NDArray *to, RunContext ctx) { + using namespace std; + using namespace mshadow; + // if storage type doesn't match, cast the storage first + auto from_stype = from.storage_type(); + auto to_stype = to->storage_type(); + NDArray casted_nd; + if (from_stype != to_stype) { + TShape shape = from.shape(); + auto from_ctx = from.ctx(); + auto s = ctx.get_stream(); + // TODO(haibin) inplace conversion + if (to_stype == kDefaultStorage) { + casted_nd = NDArray(shape, from_ctx); + } else { + casted_nd = NDArray(to_stype, shape, from_ctx); + } + common::CastStorageDispatch(s, from, casted_nd); + } else { + casted_nd = from; + } + if (to_stype == kDefaultStorage) { + CopyFromToDnsImpl(casted_nd, to, ctx); + } else if (to_stype == kRowSparseStorage) { + CopyFromToRspImpl(casted_nd, to, ctx); + } else if (to_stype == kCSRStorage) { + CopyFromToCsrImpl(casted_nd, to, ctx); + } else { + LOG(FATAL) << "unknown storage type" << to_stype; + } + if (is_same::value || is_same::value) { + // Wait GPU kernel to complete + ctx.get_stream()->Wait(); + } +} + void CopyFromTo(const NDArray &from, NDArray *to, int priority) { if (from.var() == to->var()) { // skip to copy to itself @@ -335,44 +450,33 @@ void CopyFromTo(const NDArray &from, NDArray *to, int priority) { NDArray ret = *to; int a = from.ctx().dev_mask(); int b = to->ctx().dev_mask(); - std::vector const_vars; if (from.var() != ret.var()) const_vars.push_back(from.var()); if (a == cpu::kDevMask && b == cpu::kDevMask) { Engine::Get()->PushSync([from, ret](RunContext ctx) { - TBlob tmp = ret.data(); - ndarray::Copy(from.data(), &tmp, - from.ctx(), ret.ctx(), ctx); + NDArray nd(ret); + CopyFromToImpl(from, &nd, ctx); }, from.ctx(), const_vars, {ret.var()}, FnProperty::kNormal, priority, PROFILER_MESSAGE("CopyCPU2CPU")); } else { #if MXNET_USE_CUDA if (a == cpu::kDevMask && b == gpu::kDevMask) { Engine::Get()->PushSync([from, ret](RunContext ctx) { - TBlob tmp = ret.data(); - ndarray::Copy(from.data(), &tmp, - from.ctx(), ret.ctx(), ctx); - // Wait GPU kernel to complete - ctx.get_stream()->Wait(); + NDArray nd(ret); + CopyFromToImpl(from, &nd, ctx); }, ret.ctx(), const_vars, {ret.var()}, FnProperty::kCopyToGPU, priority, PROFILER_MESSAGE("CopyCPU2GPU")); } else if (a == gpu::kDevMask && b == cpu::kDevMask) { Engine::Get()->PushSync([from, ret](RunContext ctx) { - TBlob tmp = ret.data(); - ndarray::Copy(from.data(), &tmp, - from.ctx(), ret.ctx(), ctx); - // Wait GPU kernel to complete - ctx.get_stream()->Wait(); + NDArray nd(ret); + CopyFromToImpl(from, &nd, ctx); }, from.ctx(), const_vars, {ret.var()}, FnProperty::kCopyFromGPU, priority, PROFILER_MESSAGE("CopyGPU2CPU")); } else if (a == gpu::kDevMask && b == gpu::kDevMask) { Engine::Get()->PushSync([from, ret](RunContext ctx) { - TBlob tmp = ret.data(); - ndarray::Copy(from.data(), &tmp, - from.ctx(), ret.ctx(), ctx); - // Wait GPU kernel to complete - ctx.get_stream()->Wait(); + NDArray nd(ret); + CopyFromToImpl(from, &nd, ctx); }, from.ctx(), const_vars, {ret.var()}, from.dtype() != ret.dtype() ? FnProperty::kNormal : FnProperty::kCopyFromGPU, priority, PROFILER_MESSAGE("CopyGPU2GPU")); @@ -646,34 +750,76 @@ NDArray &NDArray::operator/=(const real_t &src) { /* magic number for ndarray version 1, with int64_t TShape */ static const uint32_t NDARRAY_V1_MAGIC = 0xF993fac8; +/* magic number for ndarray version 2, with storage type */ +static const uint32_t NDARRAY_V2_MAGIC = 0xF993fac9; + void NDArray::Save(dmlc::Stream *strm) const { - strm->Write(NDARRAY_V1_MAGIC); + // write magic number to mark this version + // for storage type + strm->Write(NDARRAY_V2_MAGIC); + + // save storage type + int32_t stype = storage_type(); + strm->Write(&stype, sizeof(stype)); + + const int32_t nad = num_aux_data(storage_type()); + // save storage shape if ndarray is sparse + if (nad > 0) { + storage_shape().Save(strm); + } + + // save shape shape_.Save(strm); if (is_none()) return; + // save context Context ctx = this->ctx(); ctx.Save(strm); TBlob save_data; - NDArray temp; + NDArray nd_cpu; // a copy of *this on cpu if (ctx.dev_mask() != cpu::kDevMask) { - temp = this->Copy(Context::CPU()); - temp.WaitToRead(); - save_data = temp.data(); + nd_cpu = this->Copy(Context::CPU()); + nd_cpu.WaitToRead(); + save_data = nd_cpu.data(); } else { this->WaitToRead(); save_data = this->data(); + nd_cpu = *this; } + // save type flag int32_t type_flag = save_data.type_flag_; strm->Write(&type_flag, sizeof(type_flag)); + + // save aux_types and aux_shapes + if (nad > 0) { + for (int i = 0; i < nad; ++i) { + int32_t aux_type_flag = aux_type(i); + strm->Write(&aux_type_flag, sizeof(aux_type_flag)); + aux_shape(i).Save(strm); + } + } + + // save data CHECK(save_data.CheckContiguous()); size_t type_size = mshadow::mshadow_sizeof(type_flag); - strm->Write(save_data.dptr_, type_size * shape_.Size()); + // save data could be values of sparse tensors + // must use save_data.shape_ instead of this->shape_ + strm->Write(save_data.dptr_, type_size * save_data.shape_.Size()); + + // save aux data + if (nad > 0) { + for (int i = 0; i < nad; ++i) { + TBlob save_data = nd_cpu.aux_data(i); + // save aux_data + CHECK(save_data.CheckContiguous()); + size_t aux_type_size = mshadow::mshadow_sizeof(aux_type(i)); + strm->Write(save_data.dptr_, aux_type_size * save_data.Size()); + } + } } -bool LegacyTShapeLoad(dmlc::Stream *strm, TShape *shape) { - uint32_t magic; - if (strm->Read(&magic, sizeof(uint32_t)) != sizeof(uint32_t)) return false; +bool LegacyTShapeLoad(dmlc::Stream *strm, TShape *shape, const uint32_t magic) { switch (magic) { case NDARRAY_V1_MAGIC: return shape->Load(strm); @@ -689,10 +835,10 @@ bool LegacyTShapeLoad(dmlc::Stream *strm, TShape *shape) { } } -bool NDArray::Load(dmlc::Stream *strm) { +bool NDArray::LegacyLoad(dmlc::Stream *strm, const uint32_t magic) { // load shape TShape shape; - if (!LegacyTShapeLoad(strm, &shape)) return false; + if (!LegacyTShapeLoad(strm, &shape, magic)) return false; if (shape.ndim() == 0) { *this = NDArray(); return true; } @@ -720,6 +866,88 @@ bool NDArray::Load(dmlc::Stream *strm) { } } +bool NDArray::Load(dmlc::Stream *strm) { + uint32_t magic; + if (strm->Read(&magic, sizeof(uint32_t)) != sizeof(uint32_t)) return false; + if (magic != NDARRAY_V2_MAGIC) { + return LegacyLoad(strm, magic); + } + + // load storage type + int32_t stype; + if (strm->Read(&stype, sizeof(stype)) != sizeof(stype)) return false; + const int32_t nad = num_aux_data(static_cast(stype)); + + // load storage shape + TShape sshape; + if (nad > 0) { + if (!sshape.Load(strm)) return false; + } + + // load shape + TShape shape; + if (!shape.Load(strm)) return false; + if (shape.ndim() == 0) { + *this = NDArray(); return true; + } + + // load context + Context ctx; + if (!ctx.Load(strm)) return false; + + // load type flag + int32_t type_flag; + if (strm->Read(&type_flag, sizeof(type_flag)) != sizeof(type_flag)) return false; + + // load aux_types and aux_shapes + std::vector aux_types; + std::vector aux_shapes; + if (nad > 0) { + aux_types.resize(nad); + aux_shapes.resize(nad); + for (int i = 0; i < nad; ++i) { + // load aux_type(i) + if (strm->Read(&aux_types[i], sizeof(aux_types[i])) != sizeof(aux_types[i])) return false; + // load aux_shapes(i) + if (!aux_shapes[i].Load(strm)) return false; + } + } + + // load data into CPU + NDArray temp; + if (0 == nad) { + temp = NDArray(shape, Context::CPU(), false, type_flag); + } else { + temp = NDArray(static_cast(stype), shape, + Context::CPU(), false, type_flag, + aux_types, aux_shapes, sshape); + } + // load data + TBlob load_data = temp.data(); + size_t type_size = mshadow::mshadow_sizeof(type_flag); + size_t nread = type_size * load_data.Size(); + if (strm->Read(load_data.dptr_, nread) != nread) return false; + + // load aux_data + if (nad > 0) { + for (int i = 0; i < nad; ++i) { + load_data = temp.aux_data(i); + type_size = mshadow::mshadow_sizeof(load_data.type_flag_); + nread = type_size * load_data.Size(); + if (strm->Read(load_data.dptr_, nread) != nread) return false; + } + } + + if (ctx.dev_mask() == cpu::kDevMask) { + *this = std::move(temp); return true; + } else { +#if MXNET_USE_CUDA + *this = temp.Copy(ctx); return true; +#else + *this = std::move(temp); return true; +#endif + } +} const uint64_t kMXAPINDArrayListMagic = 0x112; @@ -752,7 +980,16 @@ void NDArray::Load(dmlc::Stream* fi, } NDArray NDArray::Copy(Context ctx) const { - NDArray ret(shape(), ctx, true, dtype_); + NDArray ret; + if (kDefaultStorage == storage_type()) { + ret = NDArray(shape(), ctx, true, dtype_); + } else if (kUndefinedStorage != storage_type()) { + ret = NDArray(storage_type(), shape(), ctx, true, dtype_, + ptr_->aux_types, ptr_->aux_shapes, storage_shape()); + } else { + LOG(FATAL) << "NDArray::Copy cannot copy undefined storage-type ndarray to ctx.dev_type=" + << ctx.dev_type << ", ctx.dev_id=" << ctx.dev_id; + } CopyFromTo(*this, &ret); return ret; } diff --git a/src/ndarray/ndarray_function-inl.h b/src/ndarray/ndarray_function-inl.h index 28524b73d0dd..aad80fd4360a 100644 --- a/src/ndarray/ndarray_function-inl.h +++ b/src/ndarray/ndarray_function-inl.h @@ -12,27 +12,28 @@ // macro to help specialize evaluation function #ifndef DECL_TERNARY -#define DECL_TERNARY(XPU, OP, FUN) \ - template<> \ - void Eval(const TBlob &lhs, const TBlob &mhs, \ - const TBlob &rhs, TBlob *ret, RunContext ctx) { \ - FUN(lhs, mhs, rhs, ret, ctx); \ +#define DECL_TERNARY(XPU, OP, FUN) \ + template<> \ + void Eval(const TBlob &lhs, const TBlob &mhs, \ + const TBlob &rhs, TBlob *ret, RunContext ctx) { \ + FUN(lhs, mhs, rhs, ret, ctx); \ } #endif #ifndef DECL_BINARY -#define DECL_BINARY(XPU, OP, FUN) \ - template<> \ +#define DECL_BINARY(XPU, OP, FUN) \ + template<> \ void Eval(const TBlob &lhs, const TBlob &rhs, TBlob *ret, RunContext ctx) { \ - FUN(lhs, rhs, ret, ctx); \ + FUN(lhs, rhs, ret, ctx); \ } #endif #ifndef DECL_SCALAR -#define DECL_SCALAR(XPU, OP, FUN, REVERSE) \ - template<> \ - void Eval(const TBlob &lhs, const real_t &rhs, TBlob *ret, RunContext ctx) { \ - FUN(lhs, rhs, ret, ctx); \ +#define DECL_SCALAR(XPU, OP, FUN, REVERSE) \ + template<> \ + void Eval(const TBlob &lhs, const real_t &rhs, \ + TBlob *ret, RunContext ctx) { \ + FUN(lhs, rhs, ret, ctx); \ } #endif @@ -44,10 +45,11 @@ namespace mxnet { namespace ndarray { + // true implementation template -inline void EvalBinary_(const TBlob &lhs, const TBlob &rhs, - TBlob *ret, RunContext ctx) { +void EvalBinary_(const TBlob &lhs, const TBlob &rhs, + TBlob *ret, RunContext ctx) { using namespace mshadow::expr; mshadow::Stream *s = ctx.get_stream(); CHECK_EQ(ret->type_flag_, lhs.type_flag_) @@ -61,10 +63,9 @@ inline void EvalBinary_(const TBlob &lhs, const TBlob &rhs, }); } - template -inline void EvalOneHot_(const TBlob &index, const TBlob &rhs, - TBlob *ret, RunContext ctx) { +void EvalOneHot_(const TBlob &index, const TBlob &rhs, + TBlob *ret, RunContext ctx) { LOG(INFO) << "The operator onehot_encode is deprecated; use one_hot instead."; using namespace mshadow::expr; mshadow::Stream *s = ctx.get_stream(); @@ -81,8 +82,8 @@ inline void EvalOneHot_(const TBlob &index, const TBlob &rhs, } template -inline void EvalMatChooseRowElem_(const TBlob &lhs, const TBlob &rhs, - TBlob *ret, RunContext ctx) { +void EvalMatChooseRowElem_(const TBlob &lhs, const TBlob &rhs, + TBlob *ret, RunContext ctx) { using namespace mshadow::expr; mshadow::Stream *s = ctx.get_stream(); // TODO(eric): support mixed type choose, i.e. int index and float rhs. @@ -98,8 +99,8 @@ inline void EvalMatChooseRowElem_(const TBlob &lhs, const TBlob &rhs, } template -inline void EvalMatFillRowElem_(const TBlob &lhs, const TBlob &mhs, const TBlob &rhs, - TBlob *ret, RunContext ctx) { +void EvalMatFillRowElem_(const TBlob &lhs, const TBlob &mhs, const TBlob &rhs, + TBlob *ret, RunContext ctx) { using namespace mshadow::expr; mshadow::Stream *s = ctx.get_stream(); ret->get(s) @@ -109,8 +110,8 @@ inline void EvalMatFillRowElem_(const TBlob &lhs, const TBlob &mhs, const TBlob } template -inline void EvalScalar_(const TBlob &lhs, const real_t &rhs, - TBlob *ret, RunContext ctx) { +void EvalScalar_(const TBlob &lhs, const real_t &rhs, + TBlob *ret, RunContext ctx) { using namespace mshadow::expr; mshadow::Stream *s = ctx.get_stream(); CHECK_EQ(ret->type_flag_, lhs.type_flag_) @@ -130,7 +131,7 @@ inline void EvalScalar_(const TBlob &lhs, const real_t &rhs, template<> void EvalClip(const TBlob &src, const real_t &a_min, const real_t &a_max, - TBlob *ret, RunContext ctx) { + TBlob *ret, RunContext ctx) { typedef DEVICE xpu; using namespace mshadow::expr; mshadow::Stream *s = ctx.get_stream(); @@ -145,12 +146,11 @@ void EvalClip(const TBlob &src, const real_t &a_min, const real_t &a_max } template<> -void EvalRandom( - const real_t &a, - const real_t &b, - const Resource &resource, - TBlob *ret, - RunContext ctx) { +void EvalRandom(const real_t &a, + const real_t &b, + const Resource &resource, + TBlob *ret, + RunContext ctx) { typedef DEVICE xpu; mshadow::Stream *s = ctx.get_stream(); switch (ret->type_flag_) { @@ -426,6 +426,7 @@ DECL_SCALAR(DEVICE, Plus, EvalScalar_, true) DECL_SCALAR(DEVICE, Minus, EvalScalar_, true) DECL_SCALAR(DEVICE, Mul, EvalScalar_, true) DECL_SCALAR(DEVICE, Div, EvalScalar_, true) + // for reverse seq DECL_SCALAR(DEVICE, Plus, EvalScalar_, false) DECL_SCALAR(DEVICE, Minus, EvalScalar_, false) diff --git a/src/ndarray/ndarray_function.cc b/src/ndarray/ndarray_function.cc index a5ba2660fd34..b03166f4d834 100644 --- a/src/ndarray/ndarray_function.cc +++ b/src/ndarray/ndarray_function.cc @@ -7,6 +7,7 @@ // this will be invoked by gcc and compile CPU version #include "./ndarray_function.h" #include "./ndarray_function-inl.h" +#include "../common/utils.h" namespace mxnet { namespace ndarray { @@ -26,5 +27,134 @@ void Copy(const TBlob &from, TBlob *to, } }) } + +template +void ElementwiseSumRspImpl(const std::vector& nds, + const std::vector& uniq_row_idx, + NDArray* out, + const int nthreads = 4) { +#pragma omp parallel num_threads(nthreads) + { + const size_t nnr = uniq_row_idx.size(); + const int num_threads = omp_get_num_threads(); + size_t row_block_len = (nnr + num_threads - 1) / num_threads; + const size_t row_block_start = omp_get_thread_num() * row_block_len; + if (row_block_start < nnr) { + const size_t row_block_end = std::min(row_block_start+row_block_len, nnr); + + auto out_values = out->data().FlatTo2D(); + auto out_indices = out->aux_data(rowsparse::kIdx).FlatTo1D(); + for (size_t i = row_block_start; i < row_block_end; ++i) { + out_indices[i] = uniq_row_idx[i]; + } + for (const auto& nd : nds) { + if (nd.storage_initialized()) { + const auto nd_indices = nd.aux_data(rowsparse::kIdx).FlatTo1D(); + const auto nd_values = nd.data().FlatTo2D(); + const auto nd_num_rows = nd.aux_shape(rowsparse::kIdx).Size(); + const IType* nd_indices_start = &nd_indices[0]; + const IType* nd_indices_end = nd_indices_start + nd_num_rows; + const IType* row_idx_ptr = std::lower_bound(nd_indices_start, nd_indices_end, + out_indices[row_block_start]); + // skip this nd if all of its row indices are smaller than out_indices[row_block_start] + // or current row block is not covered by [*row_idx_ptr, nd_indices_end). + if (nd_indices_end == row_idx_ptr || *row_idx_ptr > out_indices[row_block_end-1]) { + continue; + } + for (size_t irow = row_block_start; + irow < row_block_end && row_idx_ptr != nd_indices_end;) { + if (out_indices[irow] == *row_idx_ptr) { + auto out_value_cur_row = out_values[irow]; + const auto offset = row_idx_ptr - nd_indices_start; + auto nd_value_cur_row = nd_values[offset]; + for (size_t j = 0; j < nd_value_cur_row.shape_[0]; ++j) { + out_value_cur_row[j] += nd_value_cur_row[j]; + } + ++irow; + ++row_idx_ptr; + } else if (out_indices[irow] < *row_idx_ptr) { + ++irow; + } else { + ++row_idx_ptr; + } + } + } + } + } + } +} + +/*! + * \brief Given a vector of ndarrays, generate a index vector containing + * all the unique row indices of the ndarrays. + */ +template +void GetUniqueRspRowIdx(const std::vector& nds, + std::vector* uniq_row_idx) { + using namespace rowsparse; + size_t total_num_rows = 0; + for (const auto& nd : nds) { + CHECK_EQ(nd.storage_type(), kRowSparseStorage); + if (nd.storage_initialized()) { + total_num_rows += nd.aux_shape(kIdx).Size(); + } + } + + uniq_row_idx->resize(total_num_rows); + int nthreads = omp_get_max_threads(); + int offset = 0; + for (const auto& nd : nds) { + if (nd.storage_initialized()) { + const IType* nd_row_idx = nd.aux_data(kIdx).dptr(); + const int num_rows = nd.aux_shape(kIdx).Size(); +#pragma omp parallel for num_threads(nthreads) + for (int i = 0; i < num_rows; ++i) { + (*uniq_row_idx)[offset+i] = nd_row_idx[i]; + } + offset += num_rows; + } + } + + common::ParallelSort(uniq_row_idx->begin(), uniq_row_idx->end(), nthreads); + auto it = std::unique(uniq_row_idx->begin(), uniq_row_idx->end()); + uniq_row_idx->resize(it - uniq_row_idx->begin()); +} + +void ElementwiseSumRsp(const std::vector& nds, NDArray* out) { + if (nds.empty()) return; + using namespace rowsparse; + CHECK_EQ(out->storage_type(), kRowSparseStorage) + << "Expected row sparse storage type (" + << out->storage_type() << " given)"; + + MSHADOW_TYPE_SWITCH(out->dtype(), DType, { + MSHADOW_IDX_TYPE_SWITCH(out->aux_type(kIdx), IType, { + std::vector uniq_row_idx; + GetUniqueRspRowIdx(nds, &uniq_row_idx); + out->CheckAndAlloc({mshadow::Shape1(uniq_row_idx.size())}); + out->data().FlatTo2D() = static_cast(0); + ElementwiseSumRspImpl(nds, uniq_row_idx, out, omp_get_max_threads()); + }); + }); +} + +/*! + * \brief Parallel cpu impl of elemwise sum for sparse tensors. + * Currently only support row sparse sum. + */ +template<> +void ElementwiseSum(mshadow::Stream* s, + const std::vector& nds, + NDArray* out) { + if (nds.empty()) return; + + if (nds[0].storage_type() == kRowSparseStorage) { + ElementwiseSumRsp(nds, out); + } else { + LOG(FATAL) << "ElementwiseSum has not been implemented for storage_type = << " + << nds[0].storage_type(); + } +} + } // namespace ndarray } // namespace mxnet diff --git a/src/ndarray/ndarray_function.h b/src/ndarray/ndarray_function.h index 479f6f99f07a..5d992eaf8e53 100644 --- a/src/ndarray/ndarray_function.h +++ b/src/ndarray/ndarray_function.h @@ -10,6 +10,7 @@ #include #include #include +#include #include #include "../operator/mshadow_op.h" @@ -150,6 +151,14 @@ void ElementwiseSum(const std::vector source, TBlob *out, RunContext ctx); +/*! + * \brief Interface for parallel impl of elemwise sum for sparse matrices + */ +template +void ElementwiseSum(mshadow::Stream* s, + const std::vector& nds, + NDArray* out); + // broadcasting template void EvalBroadcast(TBlob const& src, TBlob* ret, int size, RunContext ctx); diff --git a/src/operator/elemwise_op_common.h b/src/operator/elemwise_op_common.h index aa95d2d8696a..c2bc58fb9972 100644 --- a/src/operator/elemwise_op_common.h +++ b/src/operator/elemwise_op_common.h @@ -62,6 +62,42 @@ inline bool ElemwiseAttr(const nnvm::NodeAttrs& attrs, return true; } +// Only inferring output storage types from input for now +template +inline bool ElemwiseStorageAttr(const nnvm::NodeAttrs& attrs, + std::vector *in_attrs, + std::vector *out_attrs) { + auto deduce = [&](std::vector *vec, const char *name, AttrType& result, + bool fallback) { + auto &v = *vec; + for (size_t i = 0; i < vec->size(); ++i) { + if (v[i] == kUndefinedStorage) { + // if input type is unknown, assume it's default storage + CHECK(assign(&v[i], kDefaultStorage)); + } else if (assign(&result, v[i]) == false && fallback) { + result = kDefaultStorage; + } + } + }; + AttrType dattr = kUndefinedStorage; + deduce(in_attrs, "input", dattr, enable_fallback); + if (reverse_infer) { + LOG(FATAL) << "not implemented yet"; + } + auto write = [&](std::vector *vec, const char *name) { + for (size_t i = 0; i < vec->size(); ++i) { + CHECK(assign(&(*vec)[i], dattr)) + << "Incompatible attr in node " << attrs.name << " at " << i << "-th " + << name << ": " << "expected " << dattr << ", got " << (*vec)[i]; + } + }; + if (is_none(dattr)) dattr = kDefaultStorage; + write(out_attrs, "output"); + return true; +} + template inline bool ElemwiseShape(const nnvm::NodeAttrs& attrs, std::vector *in_attrs, @@ -82,6 +118,33 @@ inline bool ElemwiseType(const nnvm::NodeAttrs& attrs, attrs, in_attrs, out_attrs, -1); } +template +inline bool ElemwiseStorageType(const nnvm::NodeAttrs& attrs, + const Context& ctx, + std::vector *in_attrs, + std::vector *out_attrs) { + // TODO(junwu): add ctx info into storage inference logic + CHECK_EQ(in_attrs->size(), static_cast(n_in)) << " in operator " << attrs.name; + CHECK_EQ(out_attrs->size(), static_cast(n_out)) << " in operator " << attrs.name; + return ElemwiseStorageAttr( + attrs, in_attrs, out_attrs); +} + +inline bool IdentityAttrLikeRhsStorageType(const nnvm::NodeAttrs& attrs, + const Context& ctx, + std::vector *in_attrs, + std::vector *out_attrs) { + // TODO(junwu): add ctx info into storage inference logic + CHECK_EQ(in_attrs->size(), static_cast(2)) << " in operator " << attrs.name; + CHECK_EQ(out_attrs->size(), static_cast(1)) << " in operator " << attrs.name; + auto &in = *in_attrs; + auto &out = *out_attrs; + CHECK_NE(in[1], kUndefinedStorage) << "rhs storage type must be known"; + if (in[0] == kUndefinedStorage) in[0] = in[1]; + if (out[0] == kUndefinedStorage) out[0] = in[1]; + return true; +} + // Transfer gradient and input to FGradient function struct ElemwiseGradUseIn { const char *op_name; diff --git a/src/operator/mxnet_op.h b/src/operator/mxnet_op.h index 9b5dcfe3d3b1..d4a473c8be0c 100644 --- a/src/operator/mxnet_op.h +++ b/src/operator/mxnet_op.h @@ -7,6 +7,7 @@ #ifndef MXNET_OPERATOR_MXNET_OP_H_ #define MXNET_OPERATOR_MXNET_OP_H_ +#include #include #include @@ -22,6 +23,8 @@ const float PI = 3.14159265358979323846; using std::isnan; #endif +template +int get_num_threads(const int N); #ifdef __CUDACC__ #define CUDA_KERNEL_LOOP(i, n) \ @@ -37,8 +40,18 @@ inline int cuda_get_num_blocks(const int N) { using namespace mshadow::cuda; return std::min(kMaxGridNum, (N + kBaseThreadNum - 1) / kBaseThreadNum); } + +template<> +inline int get_num_threads(const int N) { + using namespace mshadow::cuda; + return kBaseThreadNum * cuda_get_num_blocks(N); +} #endif // __CUDACC__ +template<> +inline int get_num_threads(const int N) { + return omp_get_max_threads(); +} /*! \brief operator request type switch */ #define MXNET_ASSIGN_REQ_SWITCH(req, ReqType, ...) \ @@ -198,7 +211,6 @@ __global__ void mxnet_generic_kernel(int N, Args... args) { } } - template struct Kernel { template diff --git a/src/operator/operator_common.h b/src/operator/operator_common.h index a43d092bceb6..3d88c9047e3a 100644 --- a/src/operator/operator_common.h +++ b/src/operator/operator_common.h @@ -11,12 +11,15 @@ #include #include #include +#include +#include #include #include #include #include #include #include "../common/cuda_utils.h" +#include "../common/utils.h" namespace mxnet { namespace op { @@ -107,6 +110,19 @@ inline std::string type_string(const int& x) { return "unknown"; } +/*! \brief get string representation of storage_type */ +inline std::string stype_string(const int& x) { + switch (x) { + case kDefaultStorage: + return "default"; + case kCSRStorage: + return "csr"; + case kRowSparseStorage: + return "row_sparse"; + } + return "unknown"; +} + /*! * \brief Assign x to y. Checks for compatiblity when y is not empty. * Allow missing dim in both x and y (as 0). @@ -183,6 +199,24 @@ inline bool type_assign(int *y, const int& x) { } \ } +/*! + * \brief macro assign type to out if out is unknown (-1) otherwise check consistency + * Use macro so we can see the error file more clearly + * \param type_array the storage type array to store the result + * \param index the index of in the array + * \param type the inferred storage type + */ +#define STORAGE_TYPE_ASSIGN_CHECK(type_array, index, type) \ + { \ + if (!type_assign(&(type_array)[index], type)) { \ + std::ostringstream os; \ + os << "Storage type inconsistent, Provided=" \ + << stype_string((type_array)[index]) << ',' \ + << " inferred storage type=" << stype_string(type); \ + throw ::mxnet::op::InferTypeError(os.str(), index); \ + } \ + } + // helper macro to implement bind dispatch #if MXNET_USE_CUDA #define DO_BIND_DISPATCH(Method, ...) \ @@ -315,6 +349,33 @@ inline void ParamParser(nnvm::NodeAttrs* attrs) { attrs->parsed = std::move(param); } +template +void FCompExFallback(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs, + FCompute fcompute, + const std::string& fname) { + using namespace mxnet::common; + std::vector in_blobs, out_blobs; + std::vector temp_in, temp_out; + GetDefaultBlobs(inputs, &in_blobs, &temp_in, ctx, true); + GetDefaultBlobs(outputs, &out_blobs, &temp_out, ctx, true); + fcompute(attrs, ctx, in_blobs, req, out_blobs); + CastNonDefaultStorage(outputs, temp_out, ctx, true); +} + +#define CHECK_RSP_ALL_ROWS_NON_ZERO(rsp, func, param) \ + { \ + CHECK(rsp.storage_shape()[0] == rsp.shape()[0]) << func \ + << " for RowSparse " << param << " is only implemented for " \ + << "RowSparse " << param << " with all rows containing non-zeros. " \ + << "Expects " << param << ".values.shape[0] (" << rsp.storage_shape()[0] \ + << ") == " << param << ".shape[0] (" << rsp.shape()[0] << ")."; \ + } + + } // namespace op } // namespace mxnet #endif // MXNET_OPERATOR_OPERATOR_COMMON_H_ diff --git a/src/operator/optimizer_op-inl.h b/src/operator/optimizer_op-inl.h index 9f4959350362..c73720b55370 100644 --- a/src/operator/optimizer_op-inl.h +++ b/src/operator/optimizer_op-inl.h @@ -18,6 +18,7 @@ #include "./mshadow_op.h" #include "./elemwise_op_common.h" #include "mxnet_op.h" +#include "./tensor/init_op.h" namespace mxnet { namespace op { @@ -84,6 +85,173 @@ inline void SGDUpdate(const nnvm::NodeAttrs& attrs, }); } +/*! \brief kernel for sparse sgd + */ +template +struct SGDDnsRspKernel { + // DType is the output data type + // IType is row sparse idx type + // i is the ith row in row sparse gradient + template + MSHADOW_XINLINE static void Map(int i, size_t width, DType* out, const DType* weight, + const IType* grad_idx, const DType *grad_val, + const DType clip_gradient, const DType lr, + const DType wd, const DType rescale_grad) { + for (size_t j = 0; j < width; j++) { + uint64_t data_i = grad_idx[i] * width + j; + uint64_t grad_i = i * width + j; + if (clip_gradient >= 0.0f) { + KERNEL_ASSIGN(out[data_i], req, (1.f - lr * wd) * weight[data_i] - + (lr) * mshadow_op::clip::Map(rescale_grad * grad_val[grad_i], clip_gradient)); + } else { + KERNEL_ASSIGN(out[data_i], req, (1.f - lr * wd) * weight[data_i] - + (lr * rescale_grad) * grad_val[grad_i]); + } + } + } +}; + +template +inline void SGDUpdateDnsRspImpl(const SGDParam& param, + const OpContext &ctx, + const TBlob& weight, + const NDArray& grad, + const OpReqType& req, + TBlob *out) { + using namespace mshadow; + using namespace mshadow::expr; + using namespace mshadow_op; + using namespace mxnet_op; + Stream* s = ctx.get_stream(); + CHECK_EQ(grad.storage_type(), kRowSparseStorage); + // if gradients are zeros, no weights are updated + if (!grad.storage_initialized() || req == kNullOp) return; + CHECK_GT(weight.shape_.Size(), 0); + + MSHADOW_REAL_TYPE_SWITCH(weight.type_flag_, DType, { + MSHADOW_IDX_TYPE_SWITCH(grad.aux_type(rowsparse::kIdx), IType, { + MXNET_ASSIGN_REQ_SWITCH(req, req_type, { + auto weight_data = weight.dptr(); + auto grad_idx = grad.aux_data(rowsparse::kIdx).dptr(); + auto grad_val = grad.data().dptr(); + auto num_rows = grad.aux_shape(rowsparse::kIdx)[0]; + auto width = weight.shape_.ProdShape(1, weight.ndim()); + Kernel, xpu>::Launch(s, num_rows, width, + out->dptr(), weight_data, grad_idx, grad_val, + static_cast(param.clip_gradient), + static_cast(param.lr), static_cast(param.wd), + static_cast(param.rescale_grad)); + }); + }); + }); +} + +/*! \brief kernel for sparse sgd + */ +template +struct SGDRspDnsKernel { + template + MSHADOW_XINLINE static void Map(int i, size_t num_cols, DType* out, const DType* weight, + const DType *grad, const DType clip_gradient, const DType lr, + const DType wd, const DType rescale_grad) { + bool contains_non_zeros = false; + index_t j = 0; + index_t offset = i * num_cols; + for (; j < num_cols; ++j) { + if (grad[offset + j] != 0) { + contains_non_zeros = true; + break; + } + } + if (!contains_non_zeros) return; + const DType rate = 1.f - lr * wd; + for (index_t j = 0; j < num_cols; j++) { + auto index = offset + j; + if (clip_gradient >= 0.0f) { + KERNEL_ASSIGN(out[index], req, rate * weight[index] - + lr * mshadow_op::clip::Map(rescale_grad * grad[index], clip_gradient)); + } else { + KERNEL_ASSIGN(out[index], req, rate * weight[index] - + lr * rescale_grad * grad[index]); + } + } + } +}; + +template +inline void SGDUpdateRspDnsImpl(const SGDParam& param, + const OpContext &ctx, + const NDArray& weight, + const TBlob& grad, + const OpReqType req, + NDArray *out) { + using namespace mshadow; + using namespace mxnet_op; + using namespace rowsparse; + CHECK_RSP_ALL_ROWS_NON_ZERO(weight, "SGDUpdate", "weights"); + CHECK_EQ(weight.storage_type(), kRowSparseStorage); + if (req == kNullOp) return; + CHECK(weight.storage_initialized()); + Stream* s = ctx.get_stream(); + MSHADOW_REAL_TYPE_SWITCH(weight.dtype(), DType, { + MXNET_ASSIGN_REQ_SWITCH(req, req_type, { + auto weight_data = weight.data().dptr(); + auto grad_data = grad.dptr(); + auto num_rows = weight.aux_shape(kIdx)[0]; + auto num_cols = weight.shape().ProdShape(1, weight.shape().ndim()); + Kernel, xpu>::Launch(s, num_rows, num_cols, + out->data().dptr(), weight_data, grad_data, + static_cast(param.clip_gradient), + static_cast(param.lr), static_cast(param.wd), + static_cast(param.rescale_grad)); + }); + }); +} + +template +inline void SGDUpdateRspRspImpl(const SGDParam& param, + const OpContext& ctx, + const NDArray& weight, + const NDArray& grad, + const OpReqType& req, + NDArray *out) { + CHECK_RSP_ALL_ROWS_NON_ZERO(weight, "SGDUpdate", "weights"); + // TODO(haibin) this is a temporary solution, due to the fact that imperative_invoke only + // feed in kWriteTo as req for all operators. + // For sgd we don't want to assign zeros to the output values when req == kWriteTo + auto out_req = req; + if (out_req == kWriteTo) out_req = kWriteInplace; + // reuse dns rsp implementation when storage_shape == shape + TBlob out_blob = out->data(); + SGDUpdateDnsRspImpl(param, ctx, weight.data(), grad, out_req, &out_blob); +} + +template +inline void SGDUpdateEx(const nnvm::NodeAttrs& attrs, + const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + using namespace mshadow; + using namespace mshadow::expr; + using namespace mshadow_op; + const SGDParam& param = nnvm::get(attrs.parsed); + auto weight_stype = inputs[0].storage_type(); + auto grad_stype = inputs[1].storage_type(); + if (weight_stype == kDefaultStorage && grad_stype == kRowSparseStorage) { + TBlob out = outputs[0].data(); + SGDUpdateDnsRspImpl(param, ctx, inputs[0].data(), inputs[1], req[0], &out); + } else if (weight_stype == kRowSparseStorage && grad_stype == kRowSparseStorage) { + NDArray out = outputs[0]; + SGDUpdateRspRspImpl(param, ctx, inputs[0], inputs[1], req[0], &out); + } else if (weight_stype == kRowSparseStorage && grad_stype == kDefaultStorage) { + NDArray out = outputs[0]; + SGDUpdateRspDnsImpl(param, ctx, inputs[0], inputs[1].data(), req[0], &out); + } else if (weight_stype == kDefaultStorage && grad_stype == kDefaultStorage) { + FCompExFallback(attrs, ctx, inputs, req, outputs, SGDUpdate, "SGDUpdate"); + } +} + struct SGDMomParam : public dmlc::Parameter { float lr; float momentum; @@ -257,6 +425,206 @@ inline void MP_SGDMomUpdate(const nnvm::NodeAttrs& attrs, }); } +template +struct SGDMomDnsRspDnsKernel { + template + MSHADOW_XINLINE static void Map(int i, size_t width, DType* out_data, + DType* mom_data, const DType* weight_data, const IType* grad_idx, + const DType* grad_data, const DType clip_gradient, const DType momentum, + const DType lr, const DType wd, const DType rescale_grad) { + const DType rate = lr * wd; + for (size_t j = 0; j < width; j++) { + uint64_t data_i = grad_idx[i] * width + j; + uint64_t grad_i = i * width + j; + if (clip_gradient >= 0.0f) { + mom_data[data_i] = momentum * mom_data[data_i] + - rate * weight_data[data_i] + - lr * + mshadow_op::clip::Map(rescale_grad * grad_data[grad_i], + clip_gradient); + } else { + mom_data[data_i] = momentum * mom_data[data_i] + - rate * weight_data[data_i] + - lr * rescale_grad * grad_data[grad_i]; + } + KERNEL_ASSIGN(out_data[data_i], req, weight_data[data_i] + mom_data[data_i]); + } + } +}; + +template +inline void SGDMomUpdateDnsRspDnsImpl(const SGDMomParam& param, + const OpContext& ctx, + const TBlob& weight, + const NDArray& grad, + const TBlob& mom, + const OpReqType& req, + TBlob *out) { + using namespace mxnet_op; + using namespace rowsparse; + Stream* s = ctx.get_stream(); + if (!grad.storage_initialized() || req == kNullOp) return; + CHECK_GT(weight.shape_.Size(), 0); + CHECK_GT(mom.shape_.Size(), 0); + + MSHADOW_REAL_TYPE_SWITCH(weight.type_flag_, DType, { + MSHADOW_IDX_TYPE_SWITCH(grad.aux_type(kIdx), IType, { + MXNET_ASSIGN_REQ_SWITCH(req, req_type, { + auto weight_data = weight.dptr(); + auto grad_idx = grad.aux_data(kIdx).dptr(); + auto grad_val = grad.data().dptr(); + auto mom_data = mom.dptr(); + auto out_data = out->dptr(); + auto num_rows = grad.aux_shape(kIdx)[0]; + auto width = weight.shape_.ProdShape(1, weight.ndim()); + Kernel, xpu>::Launch(s, num_rows, width, + out_data, mom_data, weight_data, grad_idx, grad_val, + static_cast(param.clip_gradient), static_cast(param.momentum), + static_cast(param.lr), static_cast(param.wd), + static_cast(param.rescale_grad)); + }); + }); + }); +} + +template +struct SGDMomRspDnsKernel { + template + MSHADOW_XINLINE static void Map(int i, size_t num_cols, DType* out, DType* mom, + const DType* weight, const DType *grad, + const DType clip_gradient, const DType momentum, + const DType lr, const DType wd, const DType rescale_grad) { + bool contains_non_zeros = false; + index_t j = 0; + index_t offset = i * num_cols; + for (; j < num_cols; ++j) { + if (grad[offset + j] != 0) { + contains_non_zeros = true; + break; + } + } + if (!contains_non_zeros) return; + const DType rate = lr * wd; + for (index_t j = 0; j < num_cols; j++) { + auto index = offset + j; + if (clip_gradient >= 0.0f) { + mom[index] = momentum * mom[index] - rate * weight[index] + - lr * mshadow_op::clip::Map(rescale_grad * grad[index], clip_gradient); + } else { + mom[index] = momentum * mom[index] - rate * weight[index] + - lr * rescale_grad * grad[index]; + } + KERNEL_ASSIGN(out[index], req, weight[index] + mom[index]); + } + } +}; + +template +inline void SGDMomUpdateRspDnsImpl(const SGDMomParam& param, + const OpContext &ctx, + const NDArray& weight, + const TBlob& grad, + const NDArray& mom, + const OpReqType req, + NDArray *out) { + using namespace mshadow; + using namespace mxnet_op; + using namespace rowsparse; + CHECK_RSP_ALL_ROWS_NON_ZERO(weight, "SGDMomUpdate", "weights"); + Stream* s = ctx.get_stream(); + CHECK_EQ(weight.storage_type(), kRowSparseStorage); + if (req == kNullOp) return; + CHECK(weight.storage_initialized()); + // fill mom with zero values if not initialized yet + if (!mom.storage_initialized()) { + NDArray mom_zeros = mom; + FillDnsZerosRspImpl(s, &mom_zeros); + } + // TODO(haibin) this is a temporary solution, due to the fact that imperative_invoke only + // feed in kWriteTo as req for all operators. + // For sgd we don't want to assign zeros to the output values when req == kWriteTo + auto out_req = req; + if (out_req == kWriteTo) out_req = kWriteInplace; + MSHADOW_REAL_TYPE_SWITCH(weight.dtype(), DType, { + MXNET_ASSIGN_REQ_SWITCH(out_req, req_type, { + auto weight_data = weight.data().dptr(); + auto grad_data = grad.dptr(); + auto mom_data = mom.data().dptr(); + auto num_rows = weight.aux_shape(kIdx)[0]; + auto num_cols = weight.shape().ProdShape(1, weight.shape().ndim()); + Kernel, xpu>::Launch(s, num_rows, num_cols, + out->data().dptr(), mom_data, weight_data, grad_data, + static_cast(param.clip_gradient), static_cast(param.momentum), + static_cast(param.lr), static_cast(param.wd), + static_cast(param.rescale_grad)); + }); + }); +} + + +template +inline void SGDMomUpdateRspRspRspImpl(const SGDMomParam& param, + const OpContext& ctx, + const NDArray& weight, + const NDArray& grad, + const NDArray& mom, + const OpReqType& req, + NDArray *out) { + using namespace mshadow; + using namespace mshadow::expr; + using namespace mxnet_op; + using namespace rowsparse; + CHECK_RSP_ALL_ROWS_NON_ZERO(weight, "SGDMomUpdate", "weights"); + Stream* s = ctx.get_stream(); + // fill mom with zero values in order to reuse the sgd mom dns impl + if (!mom.storage_initialized()) { + NDArray mom_zeros = mom; + FillDnsZerosRspImpl(s, &mom_zeros); + } + // TODO(haibin) this is a temporary solution, due to the fact that imperative_invoke only + // feed in kWriteTo as req for all operators. + // For sgd we don't want to assign zeros to the output values when req == kWriteTo + auto out_req = req; + if (out_req == kWriteTo) out_req = kWriteInplace; + TBlob out_blob = out->data(); + // reuse dns rsp implementation when storage_shape == shape + SGDMomUpdateDnsRspDnsImpl(param, ctx, weight.data(), grad, + mom.data(), out_req, &out_blob); +} + +template +inline void SGDMomUpdateEx(const nnvm::NodeAttrs& attrs, + const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + using namespace mxnet_op; + const SGDMomParam& param = nnvm::get(attrs.parsed); + auto &weight = inputs[0]; + auto &grad = inputs[1]; + auto &mom = inputs[2]; + auto weight_stype = weight.storage_type(); + auto grad_stype = grad.storage_type(); + auto mom_stype = mom.storage_type(); + if (weight_stype == kDefaultStorage && grad_stype == kRowSparseStorage && + mom_stype == kDefaultStorage) { + TBlob out = outputs[0].data(); + SGDMomUpdateDnsRspDnsImpl(param, ctx, weight.data(), grad, + mom.data(), req[0], &out); + } else if (weight_stype == kRowSparseStorage && grad_stype == kRowSparseStorage && + mom_stype == kRowSparseStorage) { + NDArray out = outputs[0]; + SGDMomUpdateRspRspRspImpl(param, ctx, weight, grad, mom, req[0], &out); + } else if (weight_stype == kRowSparseStorage && grad_stype == kDefaultStorage && + mom_stype == kRowSparseStorage) { + NDArray out = outputs[0]; + SGDMomUpdateRspDnsImpl(param, ctx, weight, grad.data(), mom, req[0], &out); + } else if (weight_stype == kDefaultStorage && grad_stype == kDefaultStorage && + mom_stype == kDefaultStorage) { + FCompExFallback(attrs, ctx, inputs, req, outputs, SGDMomUpdate, "SGDMomUpdate"); + } +} + struct AdamParam : public dmlc::Parameter { float lr; float beta1; diff --git a/src/operator/optimizer_op.cc b/src/operator/optimizer_op.cc index 3fdb9c2498fb..980fd1956448 100644 --- a/src/operator/optimizer_op.cc +++ b/src/operator/optimizer_op.cc @@ -22,6 +22,9 @@ It updates the weights using:: weight = weight - learning_rate * gradient +If weights are stored with `row_sparse` storage, +update is applied only to rows whose gradient has non-zero entries. + )code" ADD_FILELINE) .set_num_inputs(2) .set_num_outputs(1) @@ -29,6 +32,7 @@ It updates the weights using:: .set_attr("FInferShape", ElemwiseShape<2, 1>) .set_attr("FInferType", ElemwiseType<2, 1>) .set_attr("FCompute", SGDUpdate) +.set_attr("FComputeEx", SGDUpdateEx) .add_argument("weight", "NDArray-or-Symbol", "Weight") .add_argument("grad", "NDArray-or-Symbol", "Gradient") .add_arguments(SGDParam::__FIELDS__()); @@ -52,6 +56,9 @@ It updates the weights using:: Where the parameter ``momentum`` is the decay rate of momentum estimates at each epoch. +If weights are stored with `row_sparse` storage, +only rows whose gradients contain non-zero entries are updated (for both weight and momentum). + )code" ADD_FILELINE) .set_num_inputs(3) .set_num_outputs(1) @@ -63,6 +70,7 @@ Where the parameter ``momentum`` is the decay rate of momentum estimates at each return std::vector{2}; }) .set_attr("FCompute", SGDMomUpdate) +.set_attr("FComputeEx", SGDMomUpdateEx) .add_argument("weight", "NDArray-or-Symbol", "Weight") .add_argument("grad", "NDArray-or-Symbol", "Gradient") .add_argument("mom", "NDArray-or-Symbol", "Momentum") diff --git a/src/operator/optimizer_op.cu b/src/operator/optimizer_op.cu index a30584dd183f..fc1f47bef5db 100644 --- a/src/operator/optimizer_op.cu +++ b/src/operator/optimizer_op.cu @@ -10,10 +10,12 @@ namespace mxnet { namespace op { NNVM_REGISTER_OP(sgd_update) -.set_attr("FCompute", SGDUpdate); +.set_attr("FCompute", SGDUpdate) +.set_attr("FComputeEx", SGDUpdateEx); NNVM_REGISTER_OP(sgd_mom_update) -.set_attr("FCompute", SGDMomUpdate); +.set_attr("FCompute", SGDMomUpdate) +.set_attr("FComputeEx", SGDMomUpdateEx); NNVM_REGISTER_OP(mp_sgd_update) .set_attr("FCompute", MP_SGDUpdate); diff --git a/src/operator/tensor/cast_storage-inl.cuh b/src/operator/tensor/cast_storage-inl.cuh new file mode 100644 index 000000000000..0d4e601d0d2e --- /dev/null +++ b/src/operator/tensor/cast_storage-inl.cuh @@ -0,0 +1,26 @@ +/*! + * Copyright (c) 2017 by Contributors + * \file cast_storage-inl.cuh + * \brief implementation of cast_storage op on GPU + */ +#ifndef MXNET_OPERATOR_TENSOR_CAST_STORAGE_INL_CUH_ +#define MXNET_OPERATOR_TENSOR_CAST_STORAGE_INL_CUH_ + +#include +#include + +namespace mxnet { +namespace op { + +inline void CastStorageDnsRspImpl(mshadow::Stream* s, const TBlob& dns, NDArray* rsp) { + LOG(FATAL) << "CastStorageDnsRspImpl gpu version is not implemented."; +} + +inline void CastStorageDnsCsrImpl(mshadow::Stream* s, const TBlob& dns, NDArray* csr) { + LOG(FATAL) << "CastStorageDnsCsrImpl gpu version is not implemented."; +} + +} // namespace op +} // namespace mxnet + +#endif // MXNET_OPERATOR_TENSOR_CAST_STORAGE_INL_CUH_ diff --git a/src/operator/tensor/cast_storage-inl.h b/src/operator/tensor/cast_storage-inl.h new file mode 100644 index 000000000000..da9ed30b998a --- /dev/null +++ b/src/operator/tensor/cast_storage-inl.h @@ -0,0 +1,336 @@ +/*! + * Copyright (c) 2017 by Contributors + * \file cast_storage-inl.h + * \brief cast_storage implementation for dense and sparse tensors + */ +#ifndef MXNET_OPERATOR_TENSOR_CAST_STORAGE_INL_H_ +#define MXNET_OPERATOR_TENSOR_CAST_STORAGE_INL_H_ + +#include +#include +#include +#include "../mxnet_op.h" +#include "../operator_common.h" +#ifdef __CUDACC__ +#include "./cast_storage-inl.cuh" +#endif // __CUDACC__ + + +namespace mxnet { +namespace op { + +/*! + * \brief Kernel for marking row_idx of a RSP matrix per row + */ +struct MarkRspRowIdx { + // i represents the row index of the matrix data + template + MSHADOW_XINLINE static void Map(int i, RType* row_idx, const DType* data, + const index_t num_cols) { + index_t j = 0; + index_t offset = i * num_cols; + for (; j < num_cols; ++j) { + if (data[offset+j] != 0) { + break; + } + } + if (num_cols == j) { + row_idx[i] = 0; // mark as zero for zero row + } else { + row_idx[i] = 1; // mark as one for non-zero row + } + } +}; + +/*! + * \brief + * CPU implementation of casting a dns tensor to rsp type. + */ +inline void CastStorageDnsRspImpl(mshadow::Stream* s, const TBlob& dns, NDArray* rsp) { + CHECK(rsp != nullptr); + CHECK_EQ(rsp->storage_type(), kRowSparseStorage); + CHECK_EQ(dns.shape_, rsp->shape()); + MSHADOW_TYPE_SWITCH(dns.type_flag_, DType, { // data type + MSHADOW_IDX_TYPE_SWITCH(rsp->aux_type(rowsparse::kIdx), RType, { // row idx type + const index_t num_rows = dns.shape_[0]; + const index_t num_cols = dns.shape_[1]; + rsp->CheckAndAllocAuxData(rowsparse::kIdx, mshadow::Shape1(num_rows)); + TBlob row_idx_blob = rsp->aux_data(rowsparse::kIdx); + RType* row_idx = row_idx_blob.dptr(); + mxnet_op::Kernel::Launch(s, num_rows, row_idx, + dns.dptr(), num_cols); + index_t nnr = 0; + nnr = mxnet::common::ParallelAccumulate(row_idx, num_rows, nnr); + rsp->set_aux_shape(rowsparse::kIdx, mshadow::Shape1(nnr)); + if (0 == nnr) return; + rsp->CheckAndAllocData(mshadow::Shape2(nnr, num_cols)); + mshadow::Tensor dns_data = dns.FlatTo2D(s); + mshadow::Tensor rsp_data = rsp->data().FlatTo2D(s); + size_t idx = 0; + for (index_t i = 0; i < num_rows; ++i) { + if (row_idx[i] > 0) { + row_idx[idx] = i; + mshadow::Copy(rsp_data[idx], dns_data[i], s); + ++idx; + } + } + }); + }); +} + +// TODO(haibin) Use memcopy instead will be much faster than assigning each individual element +struct CastStorageRspDnsKernel { + template + MSHADOW_XINLINE static void Map(int i, const index_t width, const IType* idx, const DType *data, + DType* dns) { + auto rid = idx[i]; + auto dns_offset = rid * width; + auto rsp_offset = i * width; + for (size_t col = 0; col < width; col++) { + dns[dns_offset + col] = data[rsp_offset + col]; + } + } +}; + +/*! + * \brief This function assumes that the meomry for dns has been allocated already + * since the shape is known at binding stage. + */ +template +void CastStorageRspDnsImpl(mshadow::Stream* s, const NDArray& rsp, TBlob* dns) { + using namespace mshadow; + using namespace mshadow::expr; + CHECK_EQ(rsp.storage_type(), kRowSparseStorage); + MSHADOW_TYPE_SWITCH(dns->type_flag_, DType, { + MSHADOW_IDX_TYPE_SWITCH(rsp.aux_type(rowsparse::kIdx), IType, { + // assign zeros + mxnet_op::Kernel::Launch(s, dns->Size(), dns->dptr()); + if (rsp.storage_initialized()) { + // copy over row by row + auto in_idx = rsp.aux_data(rowsparse::kIdx).FlatTo1D(s).dptr_; + auto in_data = rsp.data().FlatTo2D(s).dptr_; + auto out_data = dns->FlatTo2D(s).dptr_; + auto num_rows = rsp.aux_shape(rowsparse::kIdx).Size(); + auto rsp_shape = rsp.shape(); + auto width = rsp_shape.ProdShape(1, rsp_shape.ndim()); + mxnet_op::Kernel::Launch(s, num_rows, width, in_idx, + in_data, out_data); + } + }); + }); +} + +/*! + * \brief This is the kernel for initializing the indptr in a csr tensor. + */ +struct FillCsrIndPtr { + /*! + * \brief + * \param i the i-th row of the dns tensor + * \param indptr indptr of the csr tensor + * \param dns the dns tensor + * \param num_rows + * \param num_cols + */ + template + MSHADOW_XINLINE static void Map(int i, IType* indptr, const DType* dns, + const int num_rows, const int num_cols) { + indptr[i+1] = 0; + const int offset = i * num_cols; + for (int j = 0; j < num_cols; ++j) { + if (dns[offset+j] != 0) { + ++indptr[i+1]; + } + } + } +}; + +/*! + * \brief This is the kernel for initializing the col_idx and value array + * of the csr tensor + */ +struct FillCsrColIdxAndVals { + /*! + * \brief + * \param i the i-th row of the dns tensor + * \param val value array of the csr + * \param col_idx column idx array of the csr + * \param indptr indptr array of the csr + * \param dns the dns tensor + * \param num_rows number of rows of the dns + * \param num_cols number of columns of the dns + */ + template + MSHADOW_XINLINE static void Map(int i, DType* val, CType* col_idx, + const IType* indptr, const DType* dns, + const int num_rows, const int num_cols) { + const int offset = i * num_cols; + int k = indptr[i]; + for (int j = 0; j < num_cols; ++j) { + if (dns[offset+j] != 0) { + val[k] = dns[offset+j]; + col_idx[k] = j; + ++k; + } + } + } +}; + +/*! + * \brief + * CPU implementation of casting a dns tensor to csr type. + */ +inline void CastStorageDnsCsrImpl(mshadow::Stream* s, const TBlob& dns, NDArray* csr) { + CHECK(csr != nullptr); + CHECK_EQ(csr->storage_type(), kCSRStorage); + CHECK_EQ(dns.shape_.ndim(), 2); + CHECK_EQ(dns.shape_, csr->shape()); + MSHADOW_TYPE_SWITCH(dns.type_flag_, DType, { // data type + MSHADOW_IDX_TYPE_SWITCH(csr->aux_type(csr::kIndPtr), IType, { // indptr type + MSHADOW_IDX_TYPE_SWITCH(csr->aux_type(csr::kIdx), CType, { // col idx type + const index_t num_rows = dns.shape_[0]; + const index_t num_cols = dns.shape_[1]; + csr->CheckAndAllocAuxData(csr::kIndPtr, mshadow::Shape1(num_rows+1)); + IType* indptr = csr->aux_data(csr::kIndPtr).dptr(); + DType* dns_data = dns.dptr(); + mxnet_op::Kernel::Launch(s, num_rows, indptr, + dns_data, num_rows, num_cols); + // single thread to accumulate indptr + // indptr[num_rows] indicates the number of non-zero elements + indptr[0] = 0; + for (index_t i = 0; i < num_rows; ++i) { + indptr[i+1] += indptr[i]; + } + // allocate column idx array and value array + csr->CheckAndAllocAuxData(csr::kIdx, + mshadow::Shape1(static_cast(indptr[num_rows]))); + csr->CheckAndAllocData(mshadow::Shape1(static_cast(indptr[num_rows]))); + // fill col_idx and value arrays of the csr + mxnet_op::Kernel::Launch(s, num_rows, + csr->data().dptr(), csr->aux_data(csr::kIdx).dptr(), + indptr, dns_data, num_rows, num_cols); + }); + }); + }); +} + +/*! + * \brief This is the kernel for copying csr.data to its corresponding dns tensor. + */ +struct CopyCsrDataToDns { + /*! + * \brief + * \param i the i-th row of the dns tensor + * \param dns_data data blob of the dns tensor + * \param col_idx column idx array of the csr + * \param indptr indptr array of the csr + * \param csr_data data blob of the csr tensor + * \param num_cols number of columns of the dns + */ + template + MSHADOW_XINLINE static void Map(int i, DType* dns_data, const CType* col_idx, + const IType* indptr, const DType* csr_data, + const int num_cols) { + const int offset = i * num_cols; + for (auto j = indptr[i]; j < indptr[i+1]; ++j) { + dns_data[offset+col_idx[j]] = csr_data[j]; + } + } +}; + +/*! + * \brief Casts a csr tensor to dns format. + */ +template +void CastStorageCsrDnsImpl(mshadow::Stream* s, const NDArray& csr, TBlob* dns) { + CHECK(dns != nullptr); + CHECK_EQ(csr.storage_type(), kCSRStorage); + CHECK_EQ(dns->shape_.ndim(), 2); + CHECK_EQ(dns->shape_, csr.shape()); + MSHADOW_TYPE_SWITCH(dns->type_flag_, DType, { // data type + MSHADOW_IDX_TYPE_SWITCH(csr.aux_type(csr::kIndPtr), IType, { // indptr type + MSHADOW_IDX_TYPE_SWITCH(csr.aux_type(csr::kIdx), CType, { // col idx type + const index_t num_rows = dns->shape_[0]; + const index_t num_cols = dns->shape_[1]; + DType* dns_data = dns->dptr(); + mxnet_op::Kernel::Launch(s, dns->shape_.Size(), dns_data); + if (!csr.storage_initialized()) return; + const IType* indptr = csr.aux_data(csr::kIndPtr).dptr(); + const CType* col_idx = csr.aux_data(csr::kIdx).dptr(); + const DType* csr_data = csr.data().dptr(); + mxnet_op::Kernel::Launch(s, num_rows, dns_data, + col_idx, indptr, csr_data, num_cols); + }); + }); + }); +} + +template +void CastStorageComputeImpl(mshadow::Stream* s, + const NDArray& input, + const NDArray& output) { + using namespace mshadow; + using namespace mshadow::expr; + const auto src_stype = input.storage_type(); + const auto dst_stype = output.storage_type(); + if (src_stype == kRowSparseStorage && dst_stype == kDefaultStorage) { + TBlob ret = output.data(); + CastStorageRspDnsImpl(s, input, &ret); + } else if (src_stype == kDefaultStorage && dst_stype == kRowSparseStorage) { + NDArray ret = output; // get rid of the const qualifer + CastStorageDnsRspImpl(s, input.data(), &ret); + } else if (src_stype == kDefaultStorage && dst_stype == kCSRStorage) { + NDArray ret = output; // get rid of the const qualifer + CastStorageDnsCsrImpl(s, input.data(), &ret); + } else if (src_stype == kCSRStorage && dst_stype == kDefaultStorage) { + TBlob ret = output.data(); + CastStorageCsrDnsImpl(s, input, &ret); + } else { + LOG(FATAL) << "Not implemented"; + } +} + +struct CastStorageParam : public dmlc::Parameter { + int stype; + DMLC_DECLARE_PARAMETER(CastStorageParam) { + DMLC_DECLARE_FIELD(stype) + .add_enum("default", kDefaultStorage) + .add_enum("row_sparse", kRowSparseStorage) + .add_enum("csr", kCSRStorage) + .describe("Output storage type."); + } +}; + +inline bool CastStorageInferStorageType(const nnvm::NodeAttrs& attrs, + const Context& ctx, + std::vector *in_attrs, + std::vector *out_attrs) { + CHECK_EQ(in_attrs->size(), 1U); + CHECK_EQ(out_attrs->size(), 1U); + CHECK_NE(in_attrs->at(0), kUndefinedStorage) + << "src ndarray's storage type must be specified"; + const CastStorageParam& param = nnvm::get(attrs.parsed); + CHECK_NE(param.stype, kUndefinedStorage) + << "dst ndarray's storage type must be specified"; + TYPE_ASSIGN_CHECK(*out_attrs, 0, param.stype); + return true; +} + +template +void CastStorageComputeEx(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + using namespace mshadow; + using namespace mshadow::expr; + Stream *s = ctx.get_stream(); + CHECK_EQ(inputs.size(), 1); + CHECK_EQ(outputs.size(), 1); + CastStorageComputeImpl(s, inputs[0], outputs[0]); +} + +} // namespace op +} // namespace mxnet + +#endif // MXNET_OPERATOR_TENSOR_CAST_STORAGE_INL_H_ diff --git a/src/operator/tensor/cast_storage.cc b/src/operator/tensor/cast_storage.cc new file mode 100644 index 000000000000..c435146a730b --- /dev/null +++ b/src/operator/tensor/cast_storage.cc @@ -0,0 +1,31 @@ +/*! + * Copyright (c) 2017 by Contributors + * \file cast_storage.cc + * \brief CPU Implementation of cast_storage operator. + */ + +#include "./cast_storage-inl.h" +#include "../elemwise_op_common.h" +#include "../tensor/elemwise_unary_op.h" + +namespace mxnet { +namespace op { + +// TODO(haibin) declare backward op for cast storage +DMLC_REGISTER_PARAMETER(CastStorageParam); +NNVM_REGISTER_OP(cast_storage) +.describe(R"code(Casts tensor storage type to the new type. +)code" ADD_FILELINE) +.set_num_inputs(1) +.set_num_outputs(1) +.set_attr_parser(ParamParser) +.set_attr("FInferShape", ElemwiseShape<1, 1>) +.set_attr("FInferType", ElemwiseType<1, 1>) +.set_attr("FInferStorageType", CastStorageInferStorageType) +.set_attr("FCompute", IdentityCompute) +.set_attr("FComputeEx", CastStorageComputeEx) +.add_argument("data", "NDArray-or-Symbol", "The input.") +.add_arguments(CastStorageParam::__FIELDS__()); + +} // namespace op +} // namespace mxnet diff --git a/src/operator/tensor/cast_storage.cu b/src/operator/tensor/cast_storage.cu new file mode 100644 index 000000000000..79f369fb2054 --- /dev/null +++ b/src/operator/tensor/cast_storage.cu @@ -0,0 +1,17 @@ +/*! + * Copyright (c) 2017 by Contributors + * \file cast_storage.cu + * \brief GPU Implementation of cast_storage operator. + */ +#include "./cast_storage-inl.h" +#include "../tensor/elemwise_unary_op.h" + +namespace mxnet { +namespace op { + +NNVM_REGISTER_OP(cast_storage) +.set_attr("FCompute", IdentityCompute) +.set_attr("FComputeEx", CastStorageComputeEx); + +} // namespace op +} // namespace mxnet diff --git a/src/operator/tensor/dot-inl.cuh b/src/operator/tensor/dot-inl.cuh new file mode 100644 index 000000000000..8960798c7a0c --- /dev/null +++ b/src/operator/tensor/dot-inl.cuh @@ -0,0 +1,382 @@ +/*! + * Copyright (c) 2017 by Contributors + * \file dot-inl.cuh + * \brief implementation of matrix dot op on GPU + */ +#ifndef MXNET_OPERATOR_TENSOR_DOT_INL_CUH_ +#define MXNET_OPERATOR_TENSOR_DOT_INL_CUH_ + +#include +#include + +namespace mxnet { +namespace op { +using mshadow::cuda::kBaseThreadNum; + +/*! + * \brief Scalar kernel of dot(csr, dns1) = dns2 + * Parallelization by output matrix elements: 1 thread/element + */ +template +struct DotCsrDnsDnsScalarKernel { + /*! + * \brief This function represents performing an inner product between a row of lhs + * and a column of rhs and then assigning the value to out[i]. + * \param i i-th element in out 1D view + * \param out output matrix + * \param data_l csr values of lhs + * \param indptr_l csr indptr of lhs + * \param col_idx_l csr col_idx of lhs + * \param data_r dense data of rhs + * \param num_cols number of columns of output + */ + template + MSHADOW_XINLINE static void Map(int i, DType* out, const DType* data_l, const IType* indptr_l, + const CType* col_idx_l, const DType* data_r, + const int num_cols) { + const int irow = i / num_cols; // row id of the lhs + const int icol = i % num_cols; // col id of the rhs + DType sum = 0; + for (IType j = indptr_l[irow]; j < indptr_l[irow+1]; ++j) { + const CType cur_col = col_idx_l[j]; // corresponding row id of the rhs + sum += data_l[j] * data_r[cur_col*num_cols+icol]; + } + KERNEL_ASSIGN(out[i], req, sum); + } +}; + +/*! + * \brief Vector kernel of dot(csr, dns1) = dns2 + * Parallelization by output matrix elements: 1 warp/element + */ +template +struct DotCsrDnsDnsVectorKernel { + template + __device__ __forceinline__ static void Map(int tid, DType* out, const DType* data_l, const IType* indptr_l, + const CType* col_idx_l, const DType* data_r, + const int num_cols_r) { + __shared__ volatile DType vals[kBaseThreadNum]; + + const int warp_id = tid / 32; // global warp id + const int lane = tid & (32-1); // local thread id within warp + const int irow = warp_id / num_cols_r; // lhs row that this warp computes + const int kcol = warp_id % num_cols_r; // rhs column that this warp computes + + // Range of nnz elements in this row + const int low = static_cast(indptr_l[irow]); + const int high = static_cast(indptr_l[irow+1]); + + // Compute running sum per thread + DType sum = 0; + for (int j = low+lane; j < high; j+=32) { + sum += data_l[j] * data_r[col_idx_l[j]*num_cols_r + kcol]; + } + vals[threadIdx.x] = sum; __syncwarp(); + + // Parallel reduction in shared memory + if (lane < 16) {vals[threadIdx.x] += vals[threadIdx.x+16];} __syncwarp(); + if (lane < 8) {vals[threadIdx.x] += vals[threadIdx.x+ 8];} __syncwarp(); + if (lane < 4) {vals[threadIdx.x] += vals[threadIdx.x+ 4];} __syncwarp(); + if (lane < 2) {vals[threadIdx.x] += vals[threadIdx.x+ 2];} __syncwarp(); + if (lane < 1) {vals[threadIdx.x] += vals[threadIdx.x+ 1];} __syncwarp(); + + if (lane == 0) { + KERNEL_ASSIGN(out[irow*num_cols_r+kcol], req, vals[threadIdx.x]); + } + } +}; + +/*! + * \brief Scalar kernel of dot(csr.T(), dns1) = dns2 + * Parallelization by output matrix elements: 1 thread/element + */ +template +struct DotCsrTransDnsDnsScalarKernel { + /*! + * \brief This function represents performing an inner product between a column of lhs + * and a column of rhs and then assigning the value to out[i]. + * \param i i-th element in out 1D view + * \param out output matrix + * \param data_l csr values of lhs + * \param indptr_l csr indptr of lhs + * \param col_idx_l csr col_idx of lhs + * \param data_r dense data of rhs + * \param num_rows_l number of rows of lhs + * \param num_cols number of columns of outputs + */ + template + MSHADOW_XINLINE static void Map(int i, DType* out, const DType* data_l, const IType* indptr_l, + const CType* col_idx_l, const DType* data_r, const int num_rows_l, + const int num_cols) { + const int irow = i / num_cols; // col id of the lhs + const int icol = i % num_cols; // col id of the rhs + DType sum = 0; + + // Each thread scans each column with binary search to find nnz elements in its row + for (int k = 0; k < num_rows_l; ++k) { + const IType low = indptr_l[k]; + const IType high = indptr_l[k+1]; + if (low == high || irow < col_idx_l[low] || irow > col_idx_l[high-1]) continue; + int j = -1, l = low, r = high - 1; + while (l <= r) { + int m = l + (r - l) / 2; + if (col_idx_l[m] == irow) { + j = m; break; + } + if (col_idx_l[m] < irow) { + l = m + 1; + } else { + r = m - 1; + } + } + if (j >= 0) { + sum += data_l[j] * data_r[k*num_cols+icol]; + } + } + KERNEL_ASSIGN(out[i], req, sum); + } +}; + +/*! + * \brief Warp kernel of dot(csr.T(), dns1) = dns2 + * Parallelization by columns: 1 warp computes one lhs column for one rhs column + */ +template +struct DotCsrTransDnsDnsWarpKernel { + template + __device__ __forceinline__ static void Map(int tid, DType* out, const DType* data_l, const IType* indptr_l, + const CType* col_idx_l, const DType* data_r, + const int num_cols_r) { + const int warp_id = tid / 32; // global warp id + const int lane = tid & (32-1); // local thread id within warp + const int icol = warp_id / num_cols_r; // lhs column that this warp computes + const int kcol = warp_id % num_cols_r; // rhs column that this warp computes + + // Compute range of nnz elements in this column + const int low = static_cast(indptr_l[icol]); + const int high = static_cast(indptr_l[icol+1]); + + // Iterate through the nnz elements in this column + for (int j = low+lane; j < high; j+=32) { + const int irow = static_cast(col_idx_l[j]); + const DType val = data_l[j]*data_r[icol*num_cols_r+kcol]; + atomicAdd(static_cast(&(out[irow*num_cols_r+kcol])), val); + } + } +}; + +/*! + * \brief Thread block kernel of dot(csr.T(), dns1) = dns2 + * Parallelization by columns: 1 thread block computes one lhs column for all rhs columns + */ +template +struct DotCsrTransDnsDnsThreadBlockKernel { + template + __device__ __forceinline__ static void Map(int tid, DType* out, const DType* data_l, const IType* indptr_l, + const CType* col_idx_l, const DType* data_r, + const int num_cols_r) { + const int warps_per_block = blockDim.x / 32; // number of warps in this thread block + const int warp_id = tid / 32; // global warp id + const int lane = tid & (32-1); // local thread id within warp + const int icol = blockIdx.x; // lhs column that this thread block computes + const int kcol = warp_id % warps_per_block; // rhs column where warp starts computing (offset) + + // Compute range of nnz elements in this lhs column + const int low = static_cast(indptr_l[icol]); + const int high = static_cast(indptr_l[icol+1]); + + // Iterate through the nnz elements in this lhs column + for (int j = low+lane; j < high; j+=32) { + const int irow = static_cast(col_idx_l[j]); + const DType datum_l = data_l[j]; + // Iterate over rhs columns that this warp computes + for (int k = kcol; k < num_cols_r; k+=warps_per_block) { + const DType val = datum_l*data_r[icol*num_cols_r+k]; + atomicAdd(static_cast(&(out[irow*num_cols_r+k])), val); + } + } + } +}; + +/*! + * \brief Warp block kernel of dot(csr.T(), dns1) = dns2 + * Parallelization by columns: 1 warp computes one lhs column for all rhs columns + */ +template +struct DotCsrTransDnsDnsWarpBlockKernel { + template + __device__ __forceinline__ static void Map(int tid, DType* out, const DType* data_l, const IType* indptr_l, + const CType* col_idx_l, const DType* data_r, + const int num_cols_r) { + const int warp_id = tid / 32; // global warp id + const int lane = tid & (32-1); // local thread id within warp + const int icol = warp_id; // lhs column that this warp computes + + // Compute range of nnz elements in this column + const int low = static_cast(indptr_l[icol]); + const int high = static_cast(indptr_l[icol+1]); + + // Iterate through the nnz elements in lhs column + for (int j = low+lane; j < high; j+=32) { + const int irow = static_cast(col_idx_l[j]); + const DType datum_l = data_l[j]; + // Iterate over all rhs columns + for (int k = 0; k < num_cols_r; k++) { + const DType val = datum_l*data_r[icol*num_cols_r+k]; + atomicAdd(static_cast(&(out[irow*num_cols_r+k])), val); + } + } + } +}; + +inline void DotCsrDnsDnsImpl(mshadow::Stream* s, + const NDArray& lhs, + const TBlob& rhs, + const OpReqType req, + const bool trans_lhs, + TBlob* ret) { + if (kNullOp == req) return; + CHECK_EQ(lhs.storage_type(), kCSRStorage); + if (!lhs.storage_initialized()) return; + + const TBlob data_l = lhs.data(); + const TBlob indptr_l = lhs.aux_data(csr::kIndPtr); + const TBlob col_idx_l = lhs.aux_data(csr::kIdx); + const TBlob& data_r = rhs; + const TBlob data_out = *ret; + + MSHADOW_SGL_DBL_TYPE_SWITCH(data_l.type_flag_, DType, { // data type + MSHADOW_IDX_TYPE_SWITCH(indptr_l.type_flag_, IType, { // indptr type + MSHADOW_IDX_TYPE_SWITCH(col_idx_l.type_flag_, CType, { // col idx type + if (kWriteTo == req) { + mxnet_op::Kernel::Launch(s, data_out.Size(), data_out.dptr()); + } + int num_threads; + const int threads_per_warp = 32; + const int threads_per_block = kBaseThreadNum; + const int num_rows_l = lhs.shape()[0]; + const int num_cols_r = rhs.shape_[1]; + if (trans_lhs) { + // Different kernel versions are optimized for different matrix instances + // TODO: switch between kernel versions depending on input + // (1) 'Scalar kernel' (one thread computing one output element ) + // (2) 'Warp kernel' (one warp computing one lhs column for one rhs column ) + // (3) 'Thread block kernel' (one thread block computing one lhs column for all rhs columns) + // (4) 'Warp block kernel' (one warp computing one lhs column for all rhs columns) + const int kernel_version = 0; + switch (kernel_version) { + case 1: + num_threads = data_out.Size(); + MXNET_ASSIGN_REQ_SWITCH(req, ReqType, { + mxnet_op::Kernel, gpu>::Launch(s, num_threads, + data_out.dptr(), data_l.dptr(), indptr_l.dptr(), + col_idx_l.dptr(), data_r.dptr(), num_rows_l, num_cols_r); + }); + break; + case 2: + num_threads = threads_per_warp * num_rows_l * num_cols_r; + MXNET_ASSIGN_REQ_SWITCH(req, ReqType, { + mxnet_op::Kernel, gpu>::Launch(s, num_threads, + data_out.dptr(), data_l.dptr(), indptr_l.dptr(), + col_idx_l.dptr(), data_r.dptr(), num_cols_r); + }); + break; + case 3: + num_threads = threads_per_block * num_rows_l; + MXNET_ASSIGN_REQ_SWITCH(req, ReqType, { + mxnet_op::Kernel, gpu>::Launch(s, num_threads, + data_out.dptr(), data_l.dptr(), indptr_l.dptr(), + col_idx_l.dptr(), data_r.dptr(), num_cols_r); + }); + break; + case 4: + num_threads = threads_per_warp * num_rows_l; + MXNET_ASSIGN_REQ_SWITCH(req, ReqType, { + mxnet_op::Kernel, gpu>::Launch(s, num_threads, + data_out.dptr(), data_l.dptr(), indptr_l.dptr(), + col_idx_l.dptr(), data_r.dptr(), num_cols_r); + }); + break; + default: + num_threads = threads_per_warp * num_rows_l * num_cols_r; + MXNET_ASSIGN_REQ_SWITCH(req, ReqType, { + mxnet_op::Kernel, gpu>::Launch(s, num_threads, + data_out.dptr(), data_l.dptr(), indptr_l.dptr(), + col_idx_l.dptr(), data_r.dptr(), num_cols_r); + }); + break; + } + } else { + // Different kernel versions are optimized for different matrix instances + // (1) 'Scalar kernel' (one thread computing one output element) + // (2) 'Vector kernel' (one warp computing one output element) + const int kernel_version = 0; + switch (kernel_version) { + case 1: + num_threads = data_out.Size(); + MXNET_ASSIGN_REQ_SWITCH(req, ReqType, { + mxnet_op::Kernel, gpu>::Launch(s, num_threads, + data_out.dptr(), data_l.dptr(), indptr_l.dptr(), + col_idx_l.dptr(), data_r.dptr(), num_cols_r); + }); + break; + case 2: + num_threads = threads_per_warp * num_rows_l * num_cols_r; + MXNET_ASSIGN_REQ_SWITCH(req, ReqType, { + mxnet_op::Kernel, gpu>::Launch(s, num_threads, + data_out.dptr(), data_l.dptr(), indptr_l.dptr(), + col_idx_l.dptr(), data_r.dptr(), num_cols_r); + }); + break; + default: + if (num_cols_r > 4) { + num_threads = data_out.Size(); + MXNET_ASSIGN_REQ_SWITCH(req, ReqType, { + mxnet_op::Kernel, gpu>::Launch(s, num_threads, + data_out.dptr(), data_l.dptr(), indptr_l.dptr(), + col_idx_l.dptr(), data_r.dptr(), num_cols_r); + }); + } else { + num_threads = threads_per_warp * num_rows_l * num_cols_r; + MXNET_ASSIGN_REQ_SWITCH(req, ReqType, { + mxnet_op::Kernel, gpu>::Launch(s, num_threads, + data_out.dptr(), data_l.dptr(), indptr_l.dptr(), + col_idx_l.dptr(), data_r.dptr(), num_cols_r); + }); + } + break; + } + } + }); + }); + }); +} + +/*! + * \brief Impl of dot(csr.T, dns) = rsp + */ +inline void DotCsrDnsRspImpl(mshadow::Stream* s, + const NDArray& lhs, + const TBlob& rhs, + const OpReqType req, + const bool trans_lhs, + NDArray* ret) { + LOG(FATAL) << "DotCsrDnsRspImpl gpu version is not implemented."; +} + +/*! + * \brief Impl of dot(csr.T, rsp) = rsp2 + */ +inline void DotCsrRspRspImpl(mshadow::Stream* s, + const NDArray& lhs, + const NDArray& rhs, + const OpReqType req, + const bool trans_lhs, + NDArray* ret) { + LOG(FATAL) << "DotCsrRspRspImpl gpu version is not implemented."; +} + +} // namespace op +} // namespace mxnet + +#endif // MXNET_OPERATOR_TENSOR_DOT_INL_CUH_ diff --git a/src/operator/tensor/dot-inl.h b/src/operator/tensor/dot-inl.h new file mode 100644 index 000000000000..7440128dce09 --- /dev/null +++ b/src/operator/tensor/dot-inl.h @@ -0,0 +1,925 @@ +/*! + * Copyright (c) 2017 by Contributors + * \file dot-inl.h + * \brief Function definition of matrix dot operator + */ + +#ifndef MXNET_OPERATOR_TENSOR_DOT_INL_H_ +#define MXNET_OPERATOR_TENSOR_DOT_INL_H_ + +#include +#include +#include +#include +#include +#include "../mshadow_op.h" +#include "../elemwise_op_common.h" +#include "../mxnet_op.h" +#ifdef __CUDACC__ +#include "./dot-inl.cuh" +#endif // __CUDACC__ + +namespace mxnet { +namespace op { + +struct DotParam : public dmlc::Parameter { + bool transpose_a; + bool transpose_b; + DMLC_DECLARE_PARAMETER(DotParam) { + DMLC_DECLARE_FIELD(transpose_a) + .describe("If true then transpose the first input before dot.") + .set_default(false); + DMLC_DECLARE_FIELD(transpose_b) + .describe("If true then transpose the second input before dot.") + .set_default(false); + } +}; + +template +void DotForward_(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + using namespace mshadow; + using namespace mshadow::expr; + const DotParam& param = nnvm::get(attrs.parsed); + Stream *s = ctx.get_stream(); + CHECK_EQ(outputs[0].type_flag_, inputs[0].type_flag_) + << "Binary function only support input/output with the same type"; + CHECK_EQ(outputs[0].type_flag_, inputs[1].type_flag_) + << "Binary function only support input/output with the same type"; + CHECK_EQ(outputs[0].type_flag_, kFloat32) + << "dot only support 32 bit float so far"; + + if (inputs[0].ndim() == 1 && inputs[1].ndim() == 1) { + CHECK_NE(req[0], kAddTo) << "AddTo not yet suported"; + Tensor out = outputs[0].get(s); + VectorDot(out, + inputs[0].get(s), + inputs[1].get(s)); + } else { + int ma, na, mb, nb, m, n; + if (param.transpose_a) { + ma = inputs[0].size(0); + na = inputs[0].Size()/ma; + m = na; + } else { + na = inputs[0].size(inputs[0].ndim()-1); + ma = inputs[0].Size()/na; + m = ma; + } + if (param.transpose_b) { + nb = inputs[1].size(inputs[1].ndim()-1); + mb = inputs[1].Size()/nb; + n = mb; + } else { + mb = inputs[1].size(0); + nb = inputs[1].Size()/mb; + n = nb; + } + + Tensor input0 = + inputs[0].get_with_shape(Shape2(ma, na), s); + Tensor input1 = + inputs[1].get_with_shape(Shape2(mb, nb), s); + Tensor out = + outputs[0].get_with_shape(Shape2(m, n), s); + if (param.transpose_a && param.transpose_b) { + ASSIGN_DISPATCH(out, req[0], dot(input0.T(), input1.T())); + } else if (!param.transpose_a && param.transpose_b) { + ASSIGN_DISPATCH(out, req[0], dot(input0, input1.T())); + } else if (param.transpose_a && !param.transpose_b) { + ASSIGN_DISPATCH(out, req[0], dot(input0.T(), input1)); + } else { + ASSIGN_DISPATCH(out, req[0], dot(input0, input1)); + } + } +} + +template +void DotBackward_(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + using namespace mshadow; + using namespace mshadow::expr; + const DotParam& param = nnvm::get(attrs.parsed); + Stream *s = ctx.get_stream(); + CHECK_NE(req[0], kWriteInplace); + CHECK_NE(req[1], kWriteInplace); + + if (inputs[1].ndim() == 1 && inputs[2].ndim() == 1) { + Tensor mout_grad = inputs[0].get(s); + Tensor mlhs_data = inputs[1].get(s); + Tensor mrhs_data = inputs[2].get(s); + Tensor mlhs_grad = outputs[0].get(s); + Tensor mrhs_grad = outputs[1].get(s); + ASSIGN_DISPATCH(mrhs_grad, req[1], + broadcast_scalar(mout_grad, mlhs_data.shape_) * mlhs_data); + ASSIGN_DISPATCH(mlhs_grad, req[0], + broadcast_scalar(mout_grad, mlhs_data.shape_) * mrhs_data); + } else { + int ma, na, mb, nb, m, n; + if (param.transpose_a) { + ma = outputs[0].size(0); + na = outputs[0].Size()/ma; + m = na; + } else { + na = outputs[0].size(outputs[0].ndim()-1); + ma = outputs[0].Size()/na; + m = ma; + } + if (param.transpose_b) { + nb = outputs[1].size(outputs[1].ndim()-1); + mb = outputs[1].Size()/nb; + n = mb; + } else { + mb = outputs[1].size(0); + nb = outputs[1].Size()/mb; + n = nb; + } + + Tensor mout_grad = + inputs[0].get_with_shape(Shape2(m, n), s); + Tensor mlhs_data = + inputs[1].get_with_shape(Shape2(ma, na), s); + Tensor mrhs_data = + inputs[2].get_with_shape(Shape2(mb, nb), s); + Tensor mlhs_grad = + outputs[0].get_with_shape(Shape2(ma, na), s); + Tensor mrhs_grad = + outputs[1].get_with_shape(Shape2(mb, nb), s); + if (param.transpose_a && param.transpose_b) { + // Gradient of z = dot(x.T, y.T) + // dy = dot(x, dz).T = dot(dz.T, x.T) + // dx = dot(dz, y).T = dot(y.T, dz.T) + ASSIGN_DISPATCH(mrhs_grad, req[1], dot(mout_grad.T(), mlhs_data.T())); + ASSIGN_DISPATCH(mlhs_grad, req[0], dot(mrhs_data.T(), mout_grad.T())); + } else if (!param.transpose_a && param.transpose_b) { + // Gradient of z = dot(x, y.T) + // dy = dot(x.T, dz).T = dot(dz.T, x) + // dx = dot(dz, y) + ASSIGN_DISPATCH(mrhs_grad, req[1], dot(mout_grad.T(), mlhs_data)); + ASSIGN_DISPATCH(mlhs_grad, req[0], dot(mout_grad, mrhs_data)); + } else if (param.transpose_a && !param.transpose_b) { + // Gradient of z = dot(x.T, y) + // dy = dot(x, dz) + // dx = dot(dz, y.T).T = dot(y, dz.T) + ASSIGN_DISPATCH(mrhs_grad, req[1], dot(mlhs_data, mout_grad)); + ASSIGN_DISPATCH(mlhs_grad, req[0], dot(mrhs_data, mout_grad.T())); + } else { + // Gradient of z = dot(x, y) + // dy = dot(x.T, dz) + // dx = dot(dz, y.T) + ASSIGN_DISPATCH(mrhs_grad, req[1], dot(mlhs_data.T(), mout_grad)); + ASSIGN_DISPATCH(mlhs_grad, req[0], dot(mout_grad, mrhs_data.T())); + } + } +} + +inline bool DotForwardInferStorageType(const nnvm::NodeAttrs& attrs, + const Context& ctx, + std::vector *in_attrs, + std::vector *out_attrs) { + CHECK_EQ(in_attrs->size(), 2U); + CHECK_EQ(out_attrs->size(), 1U); + const DotParam& param = nnvm::get(attrs.parsed); + // csr has many zero columns, so the result of dot(csr.T, matrix) should be rsp + // dot(csr.T,dns)=rsp not yet implemented on gpu + if (param.transpose_a && kCSRStorage == (*in_attrs)[0] && ctx.dev_type != Context::kGPU) { + STORAGE_TYPE_ASSIGN_CHECK(*out_attrs, 0, kRowSparseStorage); + } else { + STORAGE_TYPE_ASSIGN_CHECK(*out_attrs, 0, kDefaultStorage); + } + return true; +} + +inline bool DotBackwardInferStorageType(const nnvm::NodeAttrs& attrs, + const Context& ctx, + std::vector *in_attrs, + std::vector *out_attrs) { + CHECK_EQ(in_attrs->size(), 3U); + CHECK_EQ(out_attrs->size(), 2U); + const DotParam& param = nnvm::get(attrs.parsed); + STORAGE_TYPE_ASSIGN_CHECK(*out_attrs, 0, kDefaultStorage); + if (!param.transpose_a && kCSRStorage == (*in_attrs)[1]) { + STORAGE_TYPE_ASSIGN_CHECK(*out_attrs, 1, kRowSparseStorage); + } else { + STORAGE_TYPE_ASSIGN_CHECK(*out_attrs, 1, kDefaultStorage); + } + return true; +} + +/*! + * \brief Kernel of dot(csr, dns1) = dns2 + * Parallelization by row blocks + */ +struct DotCsrDnsDnsByRowBlocks { + /*! + * \brief + * \param i the i-th thread + */ + template + MSHADOW_XINLINE static void Map(int i, DType* out, const DType* data_l, const IType* indptr_l, + const CType* col_idx_l, const DType* data_r, const size_t seg_len, + const size_t num_rows, const size_t num_cols) { + const size_t seg_start = i * seg_len; + if (seg_start >= num_rows) return; + const size_t seg_end = (seg_start+seg_len < num_rows? seg_start+seg_len : num_rows); + for (size_t j = seg_start; j < seg_end; ++j) { + if (indptr_l[j] == indptr_l[j+1]) continue; + const size_t offset_out = j * num_cols; + for (auto k = indptr_l[j]; k < indptr_l[j+1]; ++k) { + const auto val = data_l[k]; + const size_t offset_r = col_idx_l[k] * num_cols; + for (size_t l = 0; l < num_cols; ++l) { + out[offset_out+l] += data_r[offset_r+l] * val; + } + } + } + } +}; + +/*! + * \brief Kernel of dot(csr.T(), dns1) = dns2 + * Parallelization by row blocks + */ +struct DotCsrTransDnsDnsByRowBlocks { + /*! + * \brief + * \param i the i-th thread + */ + template + MSHADOW_XINLINE static void Map(int i, DType* out, const DType* data_l, const IType* indptr_l, + const CType* col_idx_l, const DType* data_r, const size_t seg_len, + const size_t num_rows_l, const size_t num_rows, + const size_t num_cols) { + const size_t seg_start = i * seg_len; + if (seg_start >= num_rows) return; + const size_t seg_end = (i + 1) * seg_len; + for (size_t j = 0; j < num_rows_l; ++j) { + if (indptr_l[j] == indptr_l[j+1]) continue; + const size_t offset_r = j * num_cols; + for (auto k = indptr_l[j]; k < indptr_l[j+1]; ++k) { + const auto col_idx = col_idx_l[k]; + if (col_idx < seg_start || col_idx >= seg_end) continue; + const size_t offset_out = col_idx * num_cols; + const auto val = data_l[k]; + for (size_t l = 0; l < num_cols; ++l) { + out[offset_out+l] += data_r[offset_r+l] * val; + } + } + } + } +}; + +/*! + * \brief Kernel of dot(csr.T(), dns) = rsp + * Parallelization by row blocks. + * This kernel fills up the row_idx array + * of the rsp with 1 for nonzero rows and 0 + * for zero rows. + * The matrix will be compacted after this kernel call. + */ +struct DotCsrTransDnsRspByRowBlocks { + /*! + * \brief + * \param i the i-th thread + */ + template + MSHADOW_XINLINE static void Map(int i, DType* out, RType* row_idx, const DType* data_l, + const IType* indptr_l, const CType* col_idx_l, + const DType* data_r, const size_t seg_len, + const size_t num_rows_l, const size_t num_rows, + const size_t num_cols) { + const size_t seg_start = i * seg_len; + if (seg_start >= num_rows) return; + const size_t seg_end = (i + 1) * seg_len; + for (size_t j = 0; j < num_rows_l; ++j) { + if (indptr_l[j] == indptr_l[j+1]) continue; + const size_t offset_r = j * num_cols; + for (auto k = indptr_l[j]; k < indptr_l[j+1]; ++k) { + const auto col_idx = col_idx_l[k]; + if (col_idx < seg_start || col_idx >= seg_end) continue; + const size_t offset_out = col_idx * num_cols; + row_idx[col_idx] = 1; + const auto val = data_l[k]; + for (size_t l = 0; l < num_cols; ++l) { + out[offset_out+l] += data_r[offset_r+l] * val; + } + } + } + } +}; + +/*! + * \brief Kernel of dot(csr, rsp) = dns + * Parallelization by row blocks + */ +struct DotCsrRspDnsByRowBlocks { + /*! + * \brief + * \param i the i-th thread + * \param nnr_r storage_shape[0] of the rsp + * \param num_rows dns.shape[0] + * \param num_cols dns.shape[1] + */ + template + MSHADOW_XINLINE static void Map(int i, DType* out, const DType* data_l, + const IType* indptr_l, const CType* col_idx_l, + const DType* data_r, const RType* row_idx_r, + const size_t nnr_r, const size_t num_rows, + const size_t num_cols, const size_t seg_len) { + const size_t seg_start = i * seg_len; + if (seg_start >= num_rows) return; + const size_t seg_end = (seg_start+seg_len < num_rows? seg_start+seg_len : num_rows); + for (size_t j = seg_start; j < seg_end; ++j) { + if (indptr_l[j] == indptr_l[j+1]) continue; + const size_t offset_out = j * num_cols; + // Use binary search to find the lower_bound of val in row_idx array + const RType* first = row_idx_r; + const RType* last = row_idx_r + nnr_r; + const auto val = col_idx_l[indptr_l[j]]; + const RType* it; + int count = last - first, step; + while (count > 0) { + it = first; + step = count / 2; + it += step; + if (*it < val) { + first = ++it; + count -= step + 1; + } else { + count = step; + } + } + const RType* row_idx_ptr = first; + // end of binary search + if (row_idx_ptr == row_idx_r+nnr_r || *row_idx_ptr> col_idx_l[indptr_l[j+1]-1]) continue; + for (auto k = indptr_l[j]; k < indptr_l[j+1] && row_idx_ptr != row_idx_r+nnr_r;) { + if (col_idx_l[k] == *row_idx_ptr) { + const size_t offset_r = (row_idx_ptr - row_idx_r) * num_cols; + for (size_t l = 0; l < num_cols; ++l) { + out[offset_out+l] += data_l[k] * data_r[offset_r+l]; + } + ++k; + ++row_idx_ptr; + } else if (col_idx_l[k] < *row_idx_ptr) { + ++k; + } else { + ++row_idx_ptr; + } + } + } + } +}; + +/*! + * \brief Kernel of dot(csr.T(), rsp) = dns with row_idx marked for non-zero rows + * Parallelization by row blocks + */ +struct DotCsrTransRspRspByRowBlocks { + /*! + * \brief + * \param i the i-th thread + * \param num_rows_l number of rows of lhs matrix + * \param nnr_r number of non-zero rows of rhs matrix + * \param num_rows number of rows of out matrix + * \param num_cols number of cols of out matrix + */ + template + MSHADOW_XINLINE static void Map(int i, DType* out, RType* row_idx_out, + const DType* data_l, const IType* indptr_l, + const CType* col_idx_l, const DType* data_r, + const RType* row_idx_r, const size_t num_rows_l, + const size_t nnr_r, const size_t num_rows, + const size_t num_cols, const size_t seg_len) { + const size_t seg_start = i * seg_len; + if (seg_start >= num_rows) return; + const size_t seg_end = (i + 1) * seg_len; + for (size_t rid = 0; rid < nnr_r; ++rid) { + const auto j = row_idx_r[rid]; + if (indptr_l[j] == indptr_l[j+1]) continue; + const size_t offset_r = rid * num_cols; + for (auto k = indptr_l[j]; k < indptr_l[j+1]; ++k) { + const auto col_idx = col_idx_l[k]; + if (col_idx < seg_start || col_idx >= seg_end) continue; + row_idx_out[col_idx] = 1; // mark nonzero row as 1 + const size_t offset_out = col_idx * num_cols; + for (size_t l = 0; l < num_cols; ++l) { + out[offset_out+l] += data_r[offset_r+l] * data_l[k]; + } + } + } + } +}; + +inline void DotCsrDnsDnsImpl(mshadow::Stream* s, + const NDArray& lhs, + const TBlob& rhs, + const OpReqType req, + const bool trans_lhs, + TBlob* ret) { + if (kNullOp == req) return; + CHECK_EQ(lhs.storage_type(), kCSRStorage); + if (!lhs.storage_initialized()) return; + + const TBlob data_l = lhs.data(); + const TBlob indptr_l = lhs.aux_data(csr::kIndPtr); + const TBlob col_idx_l = lhs.aux_data(csr::kIdx); + const TBlob& data_r = rhs; + const TBlob data_out = *ret; + + MSHADOW_TYPE_SWITCH(data_l.type_flag_, DType, { // data type + MSHADOW_IDX_TYPE_SWITCH(indptr_l.type_flag_, IType, { // indptr type + MSHADOW_IDX_TYPE_SWITCH(col_idx_l.type_flag_, CType, { // col idx type + if (kWriteTo == req) { + mxnet_op::Kernel::Launch( + s, data_out.Size(), data_out.dptr()); + } + int num_threads = mxnet_op::get_num_threads(data_out.shape_[0]); + size_t seg_len = (data_out.shape_[0] + num_threads - 1) / num_threads; + if (trans_lhs) { + mxnet_op::Kernel::Launch(s, num_threads, + data_out.dptr(), data_l.dptr(), indptr_l.dptr(), + col_idx_l.dptr(), data_r.dptr(), seg_len, + lhs.shape()[0], data_out.shape_[0], data_out.shape_[1]); + } else { + mxnet_op::Kernel::Launch(s, num_threads, + data_out.dptr(), data_l.dptr(), indptr_l.dptr(), + col_idx_l.dptr(), data_r.dptr(), seg_len, + data_out.shape_[0], data_out.shape_[1]); + } + }); + }); + }); +} + +/*! + * \brief Impl of dot(csr, rsp) + */ +inline void DotCsrDnsRspImpl(mshadow::Stream* s, + const NDArray& lhs, + const TBlob& rhs, + const OpReqType req, + const bool trans_lhs, + NDArray* ret) { + if (kNullOp == req) return; + CHECK_EQ(lhs.storage_type(), kCSRStorage); + CHECK_EQ(ret->storage_type(), kRowSparseStorage); + if (!lhs.storage_initialized()) return; + + const TBlob data_l = lhs.data(); + const TBlob indptr_l = lhs.aux_data(csr::kIndPtr); + const TBlob col_idx_l = lhs.aux_data(csr::kIdx); + const TBlob& data_r = rhs; + + // pre-allocate spaces for ret using the dense dimension size + ret->CheckAndAlloc({mshadow::Shape1(lhs.shape()[1])}); + const TBlob data_out = ret->data(); + const TBlob row_idx_out = ret->aux_data(rowsparse::kIdx); + + MSHADOW_TYPE_SWITCH(data_l.type_flag_, DType, { // data type + MSHADOW_IDX_TYPE_SWITCH(indptr_l.type_flag_, IType, { // indptr type + MSHADOW_IDX_TYPE_SWITCH(col_idx_l.type_flag_, CType, { // col idx type + MSHADOW_IDX_TYPE_SWITCH(row_idx_out.type_flag_, RType, { // col idx type + if (kWriteTo == req) { + mxnet_op::Kernel::Launch( + s, data_out.Size(), data_out.dptr()); + } + RType* row_idx = row_idx_out.dptr(); + mxnet_op::Kernel::Launch( + s, row_idx_out.Size(), row_idx); + int num_threads = mxnet_op::get_num_threads(data_out.shape_[0]); + size_t seg_len = (data_out.shape_[0] + num_threads - 1) / num_threads; + if (trans_lhs) { + mxnet_op::Kernel::Launch(s, num_threads, + data_out.dptr(), row_idx, data_l.dptr(), + indptr_l.dptr(), col_idx_l.dptr(), data_r.dptr(), + seg_len, lhs.shape()[0], data_out.shape_[0], data_out.shape_[1]); + index_t nnr = 0; + nnr = mxnet::common::ParallelAccumulate(row_idx, ret->shape()[0], nnr); + ret->set_aux_shape(rowsparse::kIdx, mshadow::Shape1(nnr)); + ret->set_storage_shape(mshadow::Shape2(nnr, ret->shape()[1])); + if (0 == nnr) return; + mshadow::Tensor rsp_data = data_out.FlatTo2D(s); + size_t idx = 0; + for (index_t i = 0; i < ret->shape()[0]; ++i) { + if (row_idx[i] > 0) { + row_idx[idx] = i; + mshadow::Copy(rsp_data[idx], rsp_data[i], s); + ++idx; + } + } + } else { + LOG(FATAL) << "DotCsrDnsRspImpl has not implemented dot(csr, dns)=rsp yet." + " Only the cpu version of dot(csr.T, dns)=rsp is supported now"; + } + }); + }); + }); + }); +} + +template +void DotCsrRspDnsImpl(mshadow::Stream* s, + const NDArray& lhs, + const NDArray& rhs, + const OpReqType req, + const bool trans_lhs, + TBlob* ret) { + // reuse csr dns implementation when storage_shape == shape for rhs + if (rhs.storage_shape()[0] == rhs.shape()[0]) { // if rsp is actually dense + DotCsrDnsDnsImpl(s, lhs, rhs.data(), req, trans_lhs, ret); + return; + } + + if (kNullOp == req) return; + CHECK_EQ(lhs.storage_type(), kCSRStorage); + CHECK_EQ(rhs.storage_type(), kRowSparseStorage); + if (!lhs.storage_initialized() || !rhs.storage_initialized()) { + if (kWriteTo == req) { + MSHADOW_TYPE_SWITCH(ret->type_flag_, DType, { // data type + mxnet_op::Kernel::Launch( + s, ret->Size(), ret->dptr()); + }); + } + return; + } + + const TBlob data_l = lhs.data(); + const TBlob indptr_l = lhs.aux_data(csr::kIndPtr); + const TBlob col_idx_l = lhs.aux_data(csr::kIdx); + const TBlob data_r = rhs.data(); + const TBlob row_idx_r = rhs.aux_data(rowsparse::kIdx); + + MSHADOW_TYPE_SWITCH(data_l.type_flag_, DType, { // data type + MSHADOW_IDX_TYPE_SWITCH(indptr_l.type_flag_, IType, { // indptr type + MSHADOW_IDX_TYPE_SWITCH(col_idx_l.type_flag_, CType, { // col idx type + MSHADOW_IDX_TYPE_SWITCH(row_idx_r.type_flag_, RType, { // col idx type + if (kWriteTo == req) { + mxnet_op::Kernel::Launch( + s, ret->Size(), ret->dptr()); + } + int num_threads = mxnet_op::get_num_threads(ret->shape_[0]); + size_t seg_len = (ret->shape_[0] + num_threads - 1) / num_threads; + if (trans_lhs) { + LOG(FATAL) << "DotCsrRspDnsImpl has not implemented dot(csr.T, rsp) = dns yet"; + } else { + mxnet_op::Kernel::Launch(s, num_threads, + ret->dptr(), data_l.dptr(), + indptr_l.dptr(), col_idx_l.dptr(), data_r.dptr(), + row_idx_r.dptr(), rhs.storage_shape()[0], + ret->shape_[0], ret->shape_[1], seg_len); + } + }); + }); + }); + }); +} + +/*! + * \brief Impl of dot(csr.T, rsp) = rsp2 + */ +inline void DotCsrRspRspImpl(mshadow::Stream* s, + const NDArray& lhs, + const NDArray& rhs, + const OpReqType req, + const bool trans_lhs, + NDArray* ret) { + // reuse csr dns implementation when storage_shape == shape for rhs + if (rhs.storage_shape()[0] == rhs.shape()[0]) { // if rsp is actually dense + DotCsrDnsRspImpl(s, lhs, rhs.data(), req, trans_lhs, ret); + return; + } + + if (kNullOp == req) return; + CHECK_EQ(lhs.storage_type(), kCSRStorage); + CHECK_EQ(rhs.storage_type(), kRowSparseStorage); + CHECK_EQ(ret->storage_type(), kRowSparseStorage); + if (!lhs.storage_initialized() || !rhs.storage_initialized()) return; + + const TBlob data_l = lhs.data(); + const TBlob indptr_l = lhs.aux_data(csr::kIndPtr); + const TBlob col_idx_l = lhs.aux_data(csr::kIdx); + const TBlob data_r = rhs.data(); + const TBlob row_idx_r = rhs.aux_data(rowsparse::kIdx); + + // pre-allocate spaces for ret using the dense dimension size + if (ret->storage_type() == kRowSparseStorage) { + ret->CheckAndAlloc({mshadow::Shape1(lhs.shape()[1])}); + } + const TBlob data_out = ret->data(); + const TBlob row_idx_out = ret->aux_data(rowsparse::kIdx); + + MSHADOW_TYPE_SWITCH(data_l.type_flag_, DType, { // data type + MSHADOW_IDX_TYPE_SWITCH(indptr_l.type_flag_, IType, { // indptr type + MSHADOW_IDX_TYPE_SWITCH(col_idx_l.type_flag_, CType, { // col idx type + MSHADOW_IDX_TYPE_SWITCH(row_idx_r.type_flag_, RType, { // col idx type + if (kWriteTo == req) { + mxnet_op::Kernel::Launch( + s, data_out.Size(), data_out.dptr()); + } + int num_threads = mxnet_op::get_num_threads(data_out.shape_[0]); + size_t seg_len = (data_out.shape_[0] + num_threads - 1) / num_threads; + if (trans_lhs) { + RType* row_idx = row_idx_out.dptr(); + mxnet_op::Kernel::Launch( + s, row_idx_out.Size(), row_idx); + mxnet_op::Kernel::Launch(s, num_threads, + data_out.dptr(), row_idx, data_l.dptr(), + indptr_l.dptr(), col_idx_l.dptr(), data_r.dptr(), + row_idx_r.dptr(), lhs.shape()[0], rhs.storage_shape()[0], + ret->shape()[0], ret->shape()[1], seg_len); + index_t nnr = 0; + nnr = mxnet::common::ParallelAccumulate(row_idx, ret->shape()[0], nnr); + ret->set_aux_shape(rowsparse::kIdx, mshadow::Shape1(nnr)); + ret->set_storage_shape(mshadow::Shape2(nnr, ret->shape()[1])); + if (0 == nnr) return; + mshadow::Tensor rsp_data = data_out.FlatTo2D(s); + size_t idx = 0; + for (index_t i = 0; i < ret->shape()[0]; ++i) { + if (row_idx[i] > 0) { + row_idx[idx] = i; + mshadow::Copy(rsp_data[idx], rsp_data[i], s); + ++idx; + } + } + } else { + LOG(FATAL) << "DotCsrRspRspImpl has not implemented dot(csr.T, rsp) = rsp2 yet"; + } + }); + }); + }); + }); +} + +inline bool DotShape(const nnvm::NodeAttrs& attrs, + std::vector *in_attrs, + std::vector *out_attrs) { + const DotParam& param = nnvm::get(attrs.parsed); + CHECK_EQ(in_attrs->size(), 2U); + CHECK_EQ(out_attrs->size(), 1U); + TShape& lshape = (*in_attrs)[0]; + TShape& rshape = (*in_attrs)[1]; + if (lshape.ndim() == 1 && rshape.ndim() == 1) { + CHECK(!param.transpose_a && !param.transpose_b) << "Cannot transpose vectors"; + CHECK_EQ(lshape[0], rshape[0]) << "dot shape error: " << lshape << " X " << rshape; + SHAPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::Shape1(1)); + } else { + bool Ta = param.transpose_a, Tb = param.transpose_b; + TShape L[2], R[2]; + if (Ta) { + L[0] = mshadow::Shape1(lshape[0]); + L[1] = lshape.ndim() > 1 ? TShape(&lshape[1], &lshape[lshape.ndim()]) : TShape(1); + } else { + L[0] = lshape.ndim() > 1 ? TShape(&lshape[0], &lshape[lshape.ndim()-1]) : TShape(1); + L[1] = mshadow::Shape1(lshape[lshape.ndim()-1]); + } + if (Tb) { + R[0] = rshape.ndim() > 1 ? TShape(&rshape[0], &rshape[rshape.ndim()-1]) : TShape(1); + R[1] = mshadow::Shape1(rshape[rshape.ndim()-1]); + } else { + R[0] = mshadow::Shape1(rshape[0]); + R[1] = rshape.ndim() > 1 ? TShape(&rshape[1], &rshape[rshape.ndim()]) : TShape(1); + } + + if (L[!Ta].Size() != 0 && R[Tb].Size() != 0) { + CHECK_EQ(L[!Ta].Size(), R[Tb].Size()) + << "dot shape error: " << lshape << " X " << rshape; + } + std::vector buf; + if (lshape.ndim() > 1) buf.insert(buf.end(), &L[Ta][0], &L[Ta][L[Ta].ndim()]); + if (rshape.ndim() > 1) buf.insert(buf.end(), &R[!Tb][0], &R[!Tb][R[!Tb].ndim()]); + TShape oshape(buf.begin(), buf.end()); + SHAPE_ASSIGN_CHECK(*out_attrs, 0, oshape); + } + return true; +} + +template +void DotForwardEx(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + CHECK_EQ(inputs.size(), 2U); + CHECK_EQ(outputs.size(), 1U); + CHECK_EQ(req.size(), 1U); + const DotParam& param = nnvm::get(attrs.parsed); + CHECK(!param.transpose_b) << "tranposing rhs of the op dot is not supported"; + auto lhs_stype = inputs[0].storage_type(); + auto rhs_stype = inputs[1].storage_type(); + auto out_stype = outputs[0].storage_type(); + mshadow::Stream* s = ctx.get_stream(); + if (lhs_stype == kCSRStorage && rhs_stype == kDefaultStorage && out_stype == kDefaultStorage) { + TBlob ret = outputs[0].data(); + DotCsrDnsDnsImpl(s, inputs[0], inputs[1].data(), req[0], param.transpose_a, &ret); + } else if (lhs_stype == kCSRStorage && rhs_stype == kRowSparseStorage + && out_stype == kDefaultStorage) { + TBlob ret = outputs[0].data(); + DotCsrRspDnsImpl(s, inputs[0], inputs[1], req[0], param.transpose_a, &ret); + } else if (lhs_stype == kCSRStorage && rhs_stype == kDefaultStorage + && out_stype == kRowSparseStorage) { + NDArray out = outputs[0]; + DotCsrDnsRspImpl(s, inputs[0], inputs[1].data(), req[0], param.transpose_a, &out); + } else if (lhs_stype == kCSRStorage && rhs_stype == kRowSparseStorage + && out_stype == kRowSparseStorage) { + NDArray ret = outputs[0]; + DotCsrRspRspImpl(s, inputs[0], inputs[1], req[0], param.transpose_a, &ret); + } else { + FCompExFallback(attrs, ctx, inputs, req, outputs, DotForward_, "DotForward_"); + } +} + +template +void DotBackwardEx(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + CHECK_EQ(inputs.size(), 3U); + CHECK_EQ(outputs.size(), 2U); + CHECK_EQ(req.size(), 2U); + CHECK_EQ(kNullOp, req[0]) + << "sparse dot does not support computing the gradient of the csr/lhs"; + CHECK_NE(req[1], kWriteInplace) << "DotBackwardEx does not support WriteInplace"; + + const DotParam& param = nnvm::get(attrs.parsed); + CHECK(!param.transpose_b) << "sparse dot only supports dot(A, X) and dot(A.T(), X)"; + const auto ograd_stype = inputs[0].storage_type(); + const auto lhs_stype = inputs[1].storage_type(); + const auto rhs_stype = inputs[2].storage_type(); + const auto grad_rhs_stype = outputs[1].storage_type(); + mshadow::Stream* s = ctx.get_stream(); + if (ograd_stype == kDefaultStorage // ograd dns format + && lhs_stype == kCSRStorage // csr input lhs of the op + && grad_rhs_stype == kDefaultStorage) { // grad(rhs) dns format + TBlob ret = outputs[1].data(); + DotCsrDnsDnsImpl(s, inputs[1], inputs[0].data(), req[1], !param.transpose_a, &ret); + } else if (ograd_stype == kDefaultStorage + && lhs_stype == kCSRStorage + && grad_rhs_stype == kRowSparseStorage) { + NDArray ret = outputs[1]; + DotCsrDnsRspImpl(s, inputs[1], inputs[0].data(), req[1], !param.transpose_a, &ret); + } else { + FCompExFallback(attrs, ctx, inputs, req, outputs, DotBackward_, "DotBackward_"); + } +} + +template +void BatchDotForward_(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + using namespace mshadow::expr; + mshadow::Stream *s = ctx.get_stream(); + const DotParam& param = nnvm::get(attrs.parsed); + CHECK_EQ(outputs[0].type_flag_, inputs[0].type_flag_) + << "Binary function only support input/output with the same type"; + CHECK_EQ(outputs[0].type_flag_, inputs[1].type_flag_) + << "Binary function only support input/output with the same type"; + CHECK_EQ(outputs[0].type_flag_, mshadow::kFloat32) + << "dot only support 32 bit float so far"; + + mshadow::Tensor out = outputs[0].get(s); + mshadow::Tensor mlhs = inputs[0].get(s); + mshadow::Tensor mrhs = inputs[1].get(s); + mshadow::Tensor workspace = + ctx.requested[0].get_space_typed(mshadow::Shape1(3 * out.size(0)), s); + if (kNullOp != req[0]) { + if (param.transpose_a && param.transpose_b) { + mshadow::BatchGEMM(out, mlhs, mrhs, 1.0f, + (kAddTo == req[0]) ? 1.0f : 0.0f, + workspace); + } else if (!param.transpose_a && param.transpose_b) { + mshadow::BatchGEMM(out, mlhs, mrhs, 1.0f, + (kAddTo == req[0]) ? 1.0f : 0.0f, + workspace); + } else if (param.transpose_a && !param.transpose_b) { + mshadow::BatchGEMM(out, mlhs, mrhs, 1.0f, + (kAddTo == req[0]) ? 1.0f : 0.0f, + workspace); + } else { + mshadow::BatchGEMM(out, mlhs, mrhs, 1.0f, + (kAddTo == req[0]) ? 1.0f : 0.0f, + workspace); + } + } +} + +template +void BatchDotBackward_(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + using namespace mshadow::expr; + mshadow::Stream *s = ctx.get_stream(); + const DotParam& param = nnvm::get(attrs.parsed); + CHECK_NE(req[1], kWriteInplace); + CHECK_NE(req[0], kWriteInplace); + + mshadow::Tensor mout_grad = inputs[0].get(s); + mshadow::Tensor mlhs_data = inputs[1].get(s); + mshadow::Tensor mrhs_data = inputs[2].get(s); + mshadow::Tensor mlhs_grad = outputs[0].get(s); + mshadow::Tensor mrhs_grad = outputs[1].get(s); + mshadow::Tensor workspace = + ctx.requested[0].get_space_typed( + mshadow::Shape2(2, 3 * mout_grad.size(0)), s); + mshadow::Tensor rhs_workspace = workspace[0]; + mshadow::Tensor lhs_workspace = workspace[1]; + if (param.transpose_a && param.transpose_b) { + // Gradient of z = dot(x.T, y.T) + // dy = dot(x, dz).T = dot(dz.T, x.T) + // dx = dot(dz, y).T = dot(y.T, dz.T) + if (kNullOp != req[1]) { + mshadow::BatchGEMM(mrhs_grad, mout_grad, mlhs_data, 1.0f, + (kAddTo == req[1]) ? 1.0f : 0.0f, + rhs_workspace); + } + if (kNullOp != req[0]) { + mshadow::BatchGEMM(mlhs_grad, mrhs_data, mout_grad, 1.0f, + (kAddTo == req[0]) ? 1.0f : 0.0f, + lhs_workspace); + } + } else if (!param.transpose_a && param.transpose_b) { + // Gradient of z = dot(x, y.T) + // dy = dot(x.T, dz).T = dot(dz.T, x) + // dx = dot(dz, y) + if (kNullOp != req[1]) { + mshadow::BatchGEMM(mrhs_grad, mout_grad, mlhs_data, 1.0f, + (kAddTo == req[1]) ? 1.0f : 0.0f, + rhs_workspace); + } + if (kNullOp != req[0]) { + mshadow::BatchGEMM(mlhs_grad, mout_grad, mrhs_data, 1.0f, + (kAddTo == req[0]) ? 1.0f : 0.0f, + lhs_workspace); + } + } else if (param.transpose_a && !param.transpose_b) { + // Gradient of z = dot(x.T, y) + // dy = dot(x, dz) + // dx = dot(dz, y.T).T = dot(y, dz.T) + if (kNullOp != req[1]) { + mshadow::BatchGEMM(mrhs_grad, mlhs_data, mout_grad, 1.0f, + (kAddTo == req[1]) ? 1.0f : 0.0f, + rhs_workspace); + } + if (kNullOp != req[0]) { + mshadow::BatchGEMM(mlhs_grad, mrhs_data, mout_grad, 1.0f, + (kAddTo == req[0]) ? 1.0f : 0.0f, + lhs_workspace); + } + } else { + // Gradient of z = dot(x, y) + // dy = dot(x.T, dz) + // dx = dot(dz, y.T) + if (kNullOp != req[1]) { + mshadow::BatchGEMM(mrhs_grad, mlhs_data, mout_grad, 1.0f, + (kAddTo == req[1]) ? 1.0f : 0.0f, + rhs_workspace); + } + if (kNullOp != req[0]) { + mshadow::BatchGEMM(mlhs_grad, mout_grad, mrhs_data, 1.0f, + (kAddTo == req[0]) ? 1.0f : 0.0f, + lhs_workspace); + } + } +} + +inline bool BatchDotShape(const nnvm::NodeAttrs& attrs, + std::vector *in_attrs, + std::vector *out_attrs) { + CHECK_EQ(in_attrs->size(), 2U); + CHECK_EQ(out_attrs->size(), 1U); + const DotParam& param = nnvm::get(attrs.parsed); + TShape& lshape = (*in_attrs)[0]; + TShape& rshape = (*in_attrs)[1]; + if (lshape.ndim() == 3 && rshape.ndim() == 3) { + CHECK(lshape[0] == rshape[0]) + << "batch_dot shape error(batch_size must be equal): " << lshape << " X " << rshape + << " trans_a=" << param.transpose_a << " trans_b=" << param.transpose_b; + index_t out_m = param.transpose_a ? lshape[2] : lshape[1]; + index_t lshape_k = param.transpose_a ? lshape[1] : lshape[2]; + index_t out_n = param.transpose_b ? rshape[1] : rshape[2]; + index_t rshape_k = param.transpose_b ? rshape[2] : rshape[1]; + CHECK(lshape_k == rshape_k) + << "batch_dot shape error(shape mismatch): " << lshape << " X " << rshape + << " trans_a=" << param.transpose_a << " trans_b=" << param.transpose_b; + SHAPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::Shape3(lshape[0], out_m, out_n)); + } else { + LOG(FATAL) << "batch_dot currently only support 3D*3D array" + << lshape << " v.s. " << rshape; + } + return true; +} + +} // namespace op +} // namespace mxnet + +#endif // MXNET_OPERATOR_TENSOR_DOT_INL_H_ diff --git a/src/operator/tensor/dot.cc b/src/operator/tensor/dot.cc new file mode 100644 index 000000000000..fc476a75eec8 --- /dev/null +++ b/src/operator/tensor/dot.cc @@ -0,0 +1,114 @@ +/*! + * Copyright (c) 2017 by Contributors + * \file dot.cc + * \brief CPU Implementation of matrix dot + */ + +#include "./dot-inl.h" + +namespace mxnet { +namespace op { +DMLC_REGISTER_PARAMETER(DotParam); + +NNVM_REGISTER_OP(dot) +.describe(R"doc(Dot product of two arrays. + +``dot``'s behavior depends on the input array dimensions: + +- 1-D arrays: inner product of vectors +- 2-D arrays: matrix multiplication +- N-D arrays: a sum product over the last axis of the first input and the first + axis of the second input + + For example, given 3-D ``x`` with shape `(n,m,k)` and ``y`` with shape `(k,r,s)`, the + result array will have shape `(n,m,r,s)`. It is computed by:: + + dot(x,y)[i,j,a,b] = sum(x[i,j,:]*y[:,a,b]) + + Example:: + + x = reshape([0,1,2,3,4,5,6,7], shape=(2,2,2)) + y = reshape([7,6,5,4,3,2,1,0], shape=(2,2,2)) + dot(x,y)[0,0,1,1] = 0 + sum(x[0,0,:]*y[:,1,1]) = 0 +)doc" ADD_FILELINE) +.set_num_inputs(2) +.set_num_outputs(1) +.set_attr_parser(ParamParser) +.set_attr("FListInputNames", + [](const NodeAttrs& attrs) { + return std::vector{"lhs", "rhs"}; + }) +.set_attr("FInferShape", DotShape) +.set_attr("FInferType", ElemwiseType<2, 1>) +.set_attr("FInferStorageType", DotForwardInferStorageType) +.set_attr("FResourceRequest", + [](const NodeAttrs& attrs) { + return std::vector{ResourceRequest::kTempSpace}; + }) +.set_attr("FCompute", DotForward_) +.set_attr("FComputeEx", DotForwardEx) +.set_attr("FGradient", ElemwiseGradUseIn{"_backward_dot"}) +.add_argument("lhs", "NDArray-or-Symbol", "The first input") +.add_argument("rhs", "NDArray-or-Symbol", "The second input") +.add_arguments(DotParam::__FIELDS__()); + +NNVM_REGISTER_OP(_backward_dot) +.set_num_inputs(3) +.set_num_outputs(2) +.set_attr_parser(ParamParser) +.set_attr("TIsBackward", true) +.set_attr("FInferStorageType", DotBackwardInferStorageType) +.set_attr("FResourceRequest", + [](const NodeAttrs& attrs) { + return std::vector{ResourceRequest::kTempSpace}; + }) +.set_attr("FCompute", DotBackward_) +.set_attr("FComputeEx", DotBackwardEx) +.add_arguments(DotParam::__FIELDS__()); + +NNVM_REGISTER_OP(batch_dot) +.describe(R"doc(Batchwise dot product. + +``batch_dot`` is used to compute dot product of ``x`` and ``y`` when ``x`` and +``y`` are data in batch, namely 3D arrays in shape of `(batch_size, :, :)`. + +For example, given ``x`` with shape `(batch_size, n, m)` and ``y`` with shape +`(batch_size, m, k)`, the result array will have shape `(batch_size, n, k)`, +which is computed by:: + + batch_dot(x,y)[i,:,:] = dot(x[i,:,:], y[i,:,:]) + +)doc" ADD_FILELINE) +.set_num_inputs(2) +.set_num_outputs(1) +.set_attr_parser(ParamParser) +.set_attr("FListInputNames", + [](const NodeAttrs& attrs) { + return std::vector{"lhs", "rhs"}; + }) +.set_attr("FInferShape", BatchDotShape) +.set_attr("FInferType", ElemwiseType<2, 1>) +.set_attr("FResourceRequest", + [](const NodeAttrs& attrs) { + return std::vector{ResourceRequest::kTempSpace}; + }) +.set_attr("FCompute", BatchDotForward_) +.set_attr("FGradient", ElemwiseGradUseIn{"_backward_batch_dot"}) +.add_argument("lhs", "NDArray-or-Symbol", "The first input") +.add_argument("rhs", "NDArray-or-Symbol", "The second input") +.add_arguments(DotParam::__FIELDS__()); + +NNVM_REGISTER_OP(_backward_batch_dot) +.set_num_inputs(3) +.set_num_outputs(2) +.set_attr_parser(ParamParser) +.set_attr("FResourceRequest", + [](const NodeAttrs& attrs) { + return std::vector{ResourceRequest::kTempSpace}; + }) +.set_attr("TIsBackward", true) +.set_attr("FCompute", BatchDotBackward_); + +} // namespace op +} // namespace mxnet diff --git a/src/operator/tensor/dot.cu b/src/operator/tensor/dot.cu new file mode 100644 index 000000000000..ae00566d5d45 --- /dev/null +++ b/src/operator/tensor/dot.cu @@ -0,0 +1,27 @@ +/*! + * Copyright (c) 2017 by Contributors + * \file dot.cu + * \brief GPU Implementation of matrix dot + */ + +#include "./dot-inl.h" + +namespace mxnet { +namespace op { + +NNVM_REGISTER_OP(dot) +.set_attr("FCompute", DotForward_) +.set_attr("FComputeEx", DotForwardEx); + +NNVM_REGISTER_OP(_backward_dot) +.set_attr("FCompute", DotBackward_) +.set_attr("FComputeEx", DotBackwardEx); + +NNVM_REGISTER_OP(batch_dot) +.set_attr("FCompute", BatchDotForward_); + +NNVM_REGISTER_OP(_backward_batch_dot) +.set_attr("FCompute", BatchDotBackward_); + +} // namespace op +} // namespace mxnet diff --git a/src/operator/tensor/elemwise_binary_broadcast_op_basic.cc b/src/operator/tensor/elemwise_binary_broadcast_op_basic.cc index 27a4b5f25c82..95df985266ff 100644 --- a/src/operator/tensor/elemwise_binary_broadcast_op_basic.cc +++ b/src/operator/tensor/elemwise_binary_broadcast_op_basic.cc @@ -105,6 +105,7 @@ Example:: .set_attr("FCompute", BinaryBroadcastCompute) .set_attr("FGradient", ElemwiseGradUseIn{"_backward_broadcast_mul"}); + NNVM_REGISTER_OP(_backward_broadcast_mul) .set_num_inputs(3) .set_num_outputs(2) diff --git a/src/operator/tensor/elemwise_binary_op.h b/src/operator/tensor/elemwise_binary_op.h index 6062febe2d9e..222b0d1ffc31 100644 --- a/src/operator/tensor/elemwise_binary_op.h +++ b/src/operator/tensor/elemwise_binary_op.h @@ -10,10 +10,11 @@ #include #include #include +#include #include "../mxnet_op.h" #include "../mshadow_op.h" #include "../elemwise_op_common.h" -#include "../mxnet_op.h" +#include "../../common/utils.h" namespace mxnet { namespace op { @@ -123,6 +124,109 @@ void BinaryBackwardUseNone_(const nnvm::NodeAttrs& attrs, } } +// TODO(haibin) This is a single-thread inefficient implementation +// Binary Compute between two row-sparse ndarray +// This implementation only works on CPU +template +void BinaryComputeRspRsp(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + auto &lhs = inputs[0]; + auto &rhs = inputs[1]; + auto &output = outputs[0]; + + bool init_l = lhs.storage_initialized(); + bool init_r = rhs.storage_initialized(); + // both inputs are zeros + if (!init_l && !init_r) return; + // Memory Estimation: This is (roughly) the number of result rows. We still + // need to subtract the number of common rows + unsigned int num_rows_l = lhs.aux_shape(rowsparse::kIdx).Size(); + unsigned int num_rows_r = rhs.aux_shape(rowsparse::kIdx).Size(); + output.CheckAndAlloc({mshadow::Shape1(num_rows_l + num_rows_r)}); + mshadow::Stream *s = ctx.get_stream(); + MSHADOW_TYPE_SWITCH(output.dtype(), DType, { + MSHADOW_TYPE_SWITCH(lhs.aux_type(rowsparse::kIdx), IType, { + // Indices + auto indices_l = lhs.aux_data(rowsparse::kIdx).FlatTo1D(s); + auto indices_r = rhs.aux_data(rowsparse::kIdx).FlatTo1D(s); + auto indices_out = output.aux_data(rowsparse::kIdx).FlatTo1D(s); + // Data + auto data_l = lhs.data().FlatTo2D(s); + auto data_r = rhs.data().FlatTo2D(s); + auto out = output.data().FlatTo2D(s); + + // TODO(haibin) A more appropriate way: Copy to output, then apply ops + size_t iter_l = 0; + size_t iter_r = 0; + size_t iter_out = 0; + int32_t num_common_rows = 0; + while (iter_l < num_rows_l && iter_r < num_rows_r) { + auto idx_l = indices_l[iter_l]; + auto idx_r = indices_r[iter_r]; + if (idx_l == idx_r) { + // Same row + indices_out[iter_out] = idx_l; + mshadow::Copy(out[iter_out], data_l[iter_l++], s); + out[iter_out] += data_r[iter_r++]; + num_common_rows++; + } else if (idx_l < idx_r) { + // Left only + indices_out[iter_out] = idx_l; + mshadow::Copy(out[iter_out], data_l[iter_l++], s); + } else { + // Right only + indices_out[iter_out] = idx_r; + mshadow::Copy(out[iter_out], data_r[iter_r++], s); + } + iter_out++; + } + // Copying over the rest of the rows + while (iter_l < num_rows_l) { + indices_out[iter_out] = indices_l[iter_l]; + mshadow::Copy(out[iter_out++], data_l[iter_l++], s); + } + while (iter_r < num_rows_r) { + indices_out[iter_out] = indices_r[iter_r]; + mshadow::Copy(out[iter_out++], data_r[iter_r++], s); + } + auto new_shape = output.aux_shape(rowsparse::kIdx); + new_shape[0] -= num_common_rows; + output.set_aux_shape(rowsparse::kIdx, new_shape); + }); + }); +} + +template +void BinaryComputeEx(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + using namespace mshadow; + using namespace mshadow::expr; + Stream *s = ctx.get_stream(); + CHECK_EQ(inputs.size(), 2); + CHECK_EQ(outputs.size(), 1); + if (typeid(OP) == typeid(mshadow::op::plus)) { + // If any input is dense, fallback to FCompute + // TODO(haibin) implement dns + rsp in a separate kernel + if (mxnet::common::ContainsDefaultStorage(inputs)) { + FCompExFallback(attrs, ctx, inputs, req, outputs, + BinaryCompute, "BinaryCompute"); + return; + } + CHECK_EQ(inputs[0].storage_type(), kRowSparseStorage) << "Sparse type not supported yet"; + CHECK_EQ(inputs[1].storage_type(), kRowSparseStorage) << "Sparse type not supported yet"; + BinaryComputeRspRsp(attrs, ctx, inputs, req, outputs); + return; + } else { + LOG(FATAL) << "Not implemented"; + } +} + template void BinaryBackwardUseNone(const nnvm::NodeAttrs& attrs, const OpContext& ctx, @@ -134,6 +238,55 @@ void BinaryBackwardUseNone(const nnvm::NodeAttrs& attrs, }); } +// Only implemented for _backward_add for now +template +void BinaryBackwardUseNoneRsp(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + using namespace mshadow; + using namespace mshadow::expr; + Stream *s = ctx.get_stream(); + CHECK_EQ(inputs[0].storage_type(), kRowSparseStorage); + CHECK_EQ(outputs[0].storage_type(), kRowSparseStorage); + CHECK_EQ(outputs[1].storage_type(), kRowSparseStorage); + CHECK(typeid(LOP) == typeid(mshadow_op::identity)); + CHECK(typeid(ROP) == typeid(mshadow_op::identity)); + TShape shape = inputs[0].aux_shape(rowsparse::kIdx); + outputs[0].CheckAndAlloc({shape}); + outputs[1].CheckAndAlloc({shape}); + MSHADOW_TYPE_SWITCH(outputs[0].dtype(), DType, { + MSHADOW_TYPE_SWITCH(outputs[0].aux_type(rowsparse::kIdx), IType, { + auto lgrad_idx = outputs[0].aux_data(rowsparse::kIdx).FlatTo1D(s); + auto rgrad_idx = outputs[1].aux_data(rowsparse::kIdx).FlatTo1D(s); + auto ograd_idx = inputs[0].aux_data(rowsparse::kIdx).FlatTo1D(s); + auto lgrad = outputs[0].data().FlatTo1D(s); + Tensor rgrad = outputs[1].data().FlatTo1D(s); + Tensor ograd = inputs[0].data().FlatTo1D(s); + ASSIGN_DISPATCH(lgrad, req[0], F(ograd)); + ASSIGN_DISPATCH(rgrad, req[1], F(ograd)); + ASSIGN_DISPATCH(lgrad_idx, req[0], F(ograd_idx)); + ASSIGN_DISPATCH(rgrad_idx, req[1], F(ograd_idx)); + }); + }); +} +// Only implemented for _backward_add for now +template +void BinaryBackwardUseNoneEx(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + using namespace mshadow; + using namespace mshadow::expr; + Stream *s = ctx.get_stream(); + auto stype = inputs[0].storage_type(); + CHECK_EQ(stype, kRowSparseStorage) << "Not implemented yet"; + BinaryBackwardUseNoneRsp(attrs, ctx, inputs, req, outputs); + // TODO(haibin) fallback for kDefaultStorage +} + template void BinaryBackwardUseNoneWithHalf2(const nnvm::NodeAttrs& attrs, const OpContext& ctx, @@ -214,7 +367,7 @@ void BinaryBackwardUseInWithHalf2(const nnvm::NodeAttrs& attrs, [](const NodeAttrs& attrs){ \ return std::vector >{{0, 0}, {1, 0}}; \ }) \ - .add_argument("lhs", "NDArray-or-Symbol", "first input") \ + .add_argument("lhs", "NDArray-or-Symbol", "first input") \ .add_argument("rhs", "NDArray-or-Symbol", "second input") } // namespace op diff --git a/src/operator/tensor/elemwise_binary_op_basic.cc b/src/operator/tensor/elemwise_binary_op_basic.cc index 635f2a8692aa..6df538152698 100644 --- a/src/operator/tensor/elemwise_binary_op_basic.cc +++ b/src/operator/tensor/elemwise_binary_op_basic.cc @@ -13,6 +13,8 @@ MXNET_OPERATOR_REGISTER_BINARY(elemwise_add) .describe("Adds arguments element-wise.") .set_attr("FCompute", BinaryCompute) .set_attr("FGradient", CloneGradient{"_backward_add"}); +.set_attr("FComputeEx", BinaryComputeEx) +.set_attr("FInferStorageType", ElemwiseStorageType<2, 1>); // specialized gradient add function to do add to optimization // this must differ from elemwise_add to prevent add to optimization in forward pass. @@ -28,7 +30,10 @@ NNVM_REGISTER_OP(_backward_add) return std::vector >{{0, 0}, {0, 1}}; }) .set_attr("FCompute", BinaryBackwardUseNone); + mshadow_op::identity>) +.set_attr("FComputeEx", + BinaryBackwardUseNoneEx) +.set_attr("FInferStorageType", ElemwiseStorageType<1, 2>); MXNET_OPERATOR_REGISTER_BINARY(_sub) .add_alias("_minus").add_alias("_Minus") diff --git a/src/operator/tensor/elemwise_binary_op_basic.cu b/src/operator/tensor/elemwise_binary_op_basic.cu index 6355c4e5cf01..ffdb57a1f8b0 100644 --- a/src/operator/tensor/elemwise_binary_op_basic.cu +++ b/src/operator/tensor/elemwise_binary_op_basic.cu @@ -9,7 +9,8 @@ namespace mxnet { namespace op { NNVM_REGISTER_OP(elemwise_add) -.set_attr("FCompute", BinaryComputeWithHalf2); +.set_attr("FCompute", BinaryComputeWithHalf2) +.set_attr("FComputeEx", BinaryComputeEx); NNVM_REGISTER_OP(_grad_add) .set_attr("FCompute", BinaryComputeWithHalf2); @@ -17,7 +18,9 @@ NNVM_REGISTER_OP(_grad_add) NNVM_REGISTER_OP(_backward_add) .set_attr("FCompute", BinaryBackwardUseNoneWithHalf2); + mshadow_op::identity, mshadow_op::identity>) +.set_attr("FComputeEx", + BinaryBackwardUseNoneEx); NNVM_REGISTER_OP(_sub) .set_attr("FCompute", BinaryComputeWithHalf2); diff --git a/src/operator/tensor/elemwise_unary_op.cc b/src/operator/tensor/elemwise_unary_op.cc index ff03846ab5b3..d3749fb7a2ec 100644 --- a/src/operator/tensor/elemwise_unary_op.cc +++ b/src/operator/tensor/elemwise_unary_op.cc @@ -144,7 +144,9 @@ NNVM_REGISTER_OP(_identity_with_attr_like_rhs) .set_attr("FIgnoreInputs", [](const NodeAttrs& attrs) { return std::vector(1, 1); }) .set_attr("FCompute", IdentityCompute) +.set_attr("FComputeEx", IdentityLikeRhsComputeEx) .set_attr("FInferShape", ElemwiseShape<2, 1>) +.set_attr("FInferStorageType", IdentityAttrLikeRhsStorageType) .set_attr( "FGradient", [](const nnvm::NodePtr& n, const std::vector& ograds) { @@ -201,6 +203,7 @@ NNVM_REGISTER_OP(_backward_cast) }) .set_attr("FCompute", CastCompute); + // negative MXNET_OPERATOR_REGISTER_UNARY(negative) .MXNET_DESCRIBE("Numerical negative of the argument, element-wise.") diff --git a/src/operator/tensor/elemwise_unary_op.cu b/src/operator/tensor/elemwise_unary_op.cu index 67ceb1ce5093..6da7ceff16ac 100644 --- a/src/operator/tensor/elemwise_unary_op.cu +++ b/src/operator/tensor/elemwise_unary_op.cu @@ -35,7 +35,9 @@ NNVM_REGISTER_OP(make_loss) // identity output as first input, but attributes are constrainted to be like rhs NNVM_REGISTER_OP(_identity_with_attr_like_rhs) -.set_attr("FCompute", IdentityCompute); +.set_attr("FCompute", IdentityCompute) +.set_attr("FComputeEx", IdentityLikeRhsComputeEx); + NNVM_REGISTER_OP(Cast) .set_attr("FCompute", CastCompute); diff --git a/src/operator/tensor/elemwise_unary_op.h b/src/operator/tensor/elemwise_unary_op.h index 97a7e36535f0..f3aab781eddb 100644 --- a/src/operator/tensor/elemwise_unary_op.h +++ b/src/operator/tensor/elemwise_unary_op.h @@ -13,15 +13,16 @@ #include "../mshadow_op.h" #include "../elemwise_op_common.h" #include "../special_functions-inl.h" +#include "./broadcast_reduce-inl.h" namespace mxnet { namespace op { template void UnaryLaunch(const nnvm::NodeAttrs& attrs, - const OpContext& ctx, - const std::vector& inputs, - const std::vector& req, - const std::vector& outputs) { + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { using namespace mshadow; using namespace mxnet_op; Stream *s = ctx.get_stream(); @@ -77,6 +78,54 @@ void IdentityCompute(const nnvm::NodeAttrs& attrs, }); } +template +void IdentityComputeRsp(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + using namespace mshadow; + using namespace mshadow::expr; + Stream *s = ctx.get_stream(); + auto &input = inputs[0]; + auto &output = outputs[0]; + CHECK_NE(req[0], kNullOp) << "kNullOp in IdentityComputeEx not supported yet"; + CHECK_NE(req[0], kWriteInplace) << "kWriteInplace in IdentityComputeEx not supported yet"; + if (!input.storage_initialized()) return; + TShape shape = input.aux_shape(rowsparse::kIdx); + output.CheckAndAlloc({shape}); + MSHADOW_TYPE_SWITCH(output.dtype(), DType, { + MSHADOW_TYPE_SWITCH(output.aux_type(rowsparse::kIdx), AuxType, { + auto out_d = output.data().FlatTo1D(s); + auto out_aux = output.aux_data(rowsparse::kIdx).FlatTo1D(s); + auto in_aux = input.aux_data(rowsparse::kIdx).FlatTo1D(s); + ASSIGN_DISPATCH(out_d, req[0], + F(input.data().FlatTo1D(s))); + ASSIGN_DISPATCH(out_aux, req[0], F(in_aux)); + }); + }); +} + +template +void IdentityLikeRhsComputeEx(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + using namespace mshadow; + using namespace mshadow::expr; + CHECK_EQ(inputs.size(), 2); + CHECK_EQ(outputs.size(), 1); + Stream *s = ctx.get_stream(); + size_t rhs_idx = 1; + NDArrayStorageType stype = inputs[rhs_idx].storage_type(); + if (stype == kRowSparseStorage) { + IdentityComputeRsp(attrs, ctx, inputs, req, outputs); + } else { + LOG(FATAL) << "Not implemented yet"; + } +} + struct CastParam : public dmlc::Parameter { // use int for enumeration int dtype; @@ -168,4 +217,5 @@ struct relu_grad { } // namespace op } // namespace mxnet + #endif // MXNET_OPERATOR_TENSOR_ELEMWISE_UNARY_OP_H_ diff --git a/src/operator/tensor/indexing_op.cc b/src/operator/tensor/indexing_op.cc index 5f010fdfc62c..f55f7d8cf563 100644 --- a/src/operator/tensor/indexing_op.cc +++ b/src/operator/tensor/indexing_op.cc @@ -86,6 +86,48 @@ NNVM_REGISTER_OP(_backward_Embedding) .set_attr("TIsBackward", true) .set_attr("FCompute", EmbeddingOpBackward); +NNVM_REGISTER_OP(SparseEmbedding) +.describe(R"doc(Represents words or other sparse inputs by dense continuous vectors. +It assumes that the input is in one-hot form. E.g., for a vocabulary size of 10,000, + each input vector is expected to have dimension 10,000. +The index of the non-zero entry is the index of the word or item it represents. + +The corresponding embedding vectors are stored as rows of a matrix. +Hence, mapping an input word to its embedding is implemented as a matrix product. + +The gradient of an embedding matrix has the form of gradient vectors that are only + non-zero for words seen in a minibatch. +)doc" ADD_FILELINE) +.set_num_inputs(2) +.set_num_outputs(1) +.set_attr_parser(ParamParser) +.set_attr("FListInputNames", + [](const NodeAttrs& attrs) { + return std::vector{"data", "weight"}; + }) +.set_attr("FInferShape", SparseEmbeddingShape) +.set_attr("FInferType", EmbeddingOpType) +.set_attr("FInferStorageType", SparseEmbeddingForwardStorageType) +.set_attr("FResourceRequest", + [](const NodeAttrs& attrs) { + return std::vector{ResourceRequest::kTempSpace}; + }) +.set_attr("FComputeEx", SparseEmbeddingForwardEx) +.set_attr("FGradient", + [](const nnvm::NodePtr& n, const std::vector& ograds) { + return MakeNonlossGradNode("_backward_SparseEmbedding", n, ograds, + {n->inputs[0]}, n->attrs.dict); + }) +.add_argument("data", "NDArray-or-Symbol", + "The input array to the sparse embedding operator.") +.add_argument("weight", "NDArray-or-Symbol", "The embedding weight matrix.") +.add_arguments(EmbeddingParam::__FIELDS__()); + +NNVM_REGISTER_OP(_backward_SparseEmbedding) +.set_num_inputs(2) +.set_num_outputs(2) +.set_attr("TIsBackward", true) +.set_attr("FComputeEx", SparseEmbeddingBackwardEx); NNVM_REGISTER_OP(take) .describe(R"code(Takes elements from an input array along the given axis. @@ -230,5 +272,46 @@ Examples:: .add_argument("indices", "NDArray-or-Symbol", "array of locations where to set on_value") .add_arguments(OneHotParam::__FIELDS__()); +NNVM_REGISTER_OP(sparse_retain) +.describe(R"code(pick rows specified by user input index array from a row sparse matrix +and save them in the output sparse matrix. + +Example:: + + data = [[1, 2], [3, 4], [5, 6]] + indices = [0, 1, 3] + shape = (4, 2) + rsp_in = row_sparse(data, indices) + to_retain = [0, 3] + rsp_out = sparse_retain(rsp_in, to_retain) + rsp_out.values = [[1, 2], [5, 6]] + rsp_out.indices = [0, 3] + +)code" ADD_FILELINE) +.set_num_inputs(2) +.set_num_outputs(1) +.set_attr("FListInputNames", + [](const NodeAttrs& attrs) { + return std::vector{"data", "indices"}; + }) +.set_attr("FInferShape", SparseRetainOpShape) +.set_attr("FInferType", SparseRetainOpType) +.set_attr("FInferStorageType", SparseRetainForwardInferStorageType) +.set_attr("FComputeEx", SparseRetainOpForwardEx) +.set_attr("FGradient", + [](const nnvm::NodePtr& n, const std::vector& ograds) { + return MakeNonlossGradNode("_backward_sparse_retain", n, ograds, + {n->inputs[sr::kIdx]}, n->attrs.dict); + }) +.add_argument("data", "NDArray-or-Symbol", "The input array for sparse_retain operator.") +.add_argument("indices", "NDArray-or-Symbol", "The index array of rows ids that will be retained."); + +NNVM_REGISTER_OP(_backward_sparse_retain) +.set_num_inputs(2) +.set_num_outputs(2) +.set_attr("TIsBackward", true) +.set_attr("FInferStorageType", SparseRetainBackwardInferStorageType) +.set_attr("FComputeEx", SparseRetainOpBackwardEx); + } // namespace op } // namespace mxnet diff --git a/src/operator/tensor/indexing_op.cu b/src/operator/tensor/indexing_op.cu index 287ec25d70be..4378bd574932 100644 --- a/src/operator/tensor/indexing_op.cu +++ b/src/operator/tensor/indexing_op.cu @@ -26,6 +26,12 @@ NNVM_REGISTER_OP(batch_take) NNVM_REGISTER_OP(one_hot) .set_attr("FCompute", OneHotOpForward); +NNVM_REGISTER_OP(sparse_retain) +.set_attr("FComputeEx", SparseRetainOpForwardEx); + +NNVM_REGISTER_OP(_backward_sparse_retain) +.set_attr("FComputeEx", SparseRetainOpBackwardEx); + } // namespace op } // namespace mxnet diff --git a/src/operator/tensor/indexing_op.h b/src/operator/tensor/indexing_op.h index 5fd6e81d0b2f..6e4b380f893b 100644 --- a/src/operator/tensor/indexing_op.h +++ b/src/operator/tensor/indexing_op.h @@ -22,6 +22,7 @@ #include "../elemwise_op_common.h" #include "../mxnet_op.h" #include "./sort_op.h" +#include "./dot-inl.h" namespace mxnet { namespace op { @@ -203,6 +204,79 @@ void EmbeddingOpForward(const nnvm::NodeAttrs& attrs, }); } +template +void SparseEmbeddingForwardRspImpl(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const NDArray& data, + const NDArray& weight, + const OpReqType req, + NDArray *out) { + CHECK_RSP_ALL_ROWS_NON_ZERO(weight, "SparseEmbedding", "weight"); + TBlob out_blob = out->data(); + // forward to dns implementation when storage_shape equals shape + bool transpose_a = false; + DotCsrRspDnsImpl(ctx.get_stream(), data, weight, req, transpose_a, &out_blob); +} + +template +void SparseEmbeddingForwardEx(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + CHECK_EQ(req[embedding::kOut], kWriteTo); + CHECK_EQ(inputs.size(), 2U); + CHECK_EQ(outputs.size(), 1U); + CHECK_EQ(req.size(), 1U); + + NDArray output = outputs[embedding::kOut]; + auto data_stype = inputs[embedding::kData].storage_type(); + auto weight_stype = inputs[embedding::kWeight].storage_type(); + auto out_stype = outputs[embedding::kOut].storage_type(); + if (data_stype == kCSRStorage && weight_stype == kRowSparseStorage && + out_stype == kDefaultStorage) { + NDArray ret = outputs[embedding::kOut]; + SparseEmbeddingForwardRspImpl(attrs, ctx, inputs[embedding::kData], + inputs[embedding::kWeight], + req[embedding::kOut], &ret); + } else { + LOG(FATAL) << "Not supported SparseEmbedding operation for data.storage_type = " + << data_stype << ", weight.storage_type = " << weight_stype + << ", out.storage_type = " << out_stype; + } +} + +inline bool SparseEmbeddingForwardStorageType(const nnvm::NodeAttrs& attrs, + const Context& ctx, + std::vector *in_attrs, + std::vector *out_attrs) { + CHECK_EQ(in_attrs->size(), 2U); + CHECK_EQ(out_attrs->size(), 1U); + STORAGE_TYPE_ASSIGN_CHECK(*in_attrs, embedding::kData, kCSRStorage); + STORAGE_TYPE_ASSIGN_CHECK(*out_attrs, embedding::kOut, kDefaultStorage); + // override the default storage type generated in nnvm + in_attrs->at(embedding::kWeight) = kRowSparseStorage; + return true; +} + +inline bool SparseEmbeddingShape(const nnvm::NodeAttrs& attrs, + std::vector *in_attrs, + std::vector *out_attrs) { + using namespace mshadow; + const EmbeddingParam& param = nnvm::get(attrs.parsed); + const TShape &dshape = (*in_attrs)[embedding::kData]; + CHECK_EQ(dshape.ndim(), 2) + << "SparseEmbedding shape error: data is expected to be 2D."; + SHAPE_ASSIGN_CHECK(*in_attrs, embedding::kWeight, + Shape2(param.input_dim, param.output_dim)); + out_attrs->clear(); + std::vector buf(2); + buf[0] = dshape[0]; + buf[1] = param.output_dim; + out_attrs->emplace_back(buf.begin(), buf.end()); + return true; +} + // Returns integer log2(a) rounded up inline int ilog2(unsigned int a) { int k = 1; @@ -315,6 +389,31 @@ void EmbeddingOpBackward(const nnvm::NodeAttrs& attrs, }); } +template +void SparseEmbeddingBackwardEx(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + CHECK_EQ(inputs.size(), 2U); + CHECK_EQ(outputs.size(), 2U); + CHECK_EQ(req.size(), 2U); + // CHECK_EQ(req[embedding::kData], kNullOp) + // << "Embedding layer doesn't support calculate data gradient" << req[0] << " " << req[1]; + // CHECK_NE(req[1], kWriteInplace) << "DotBackwardEx does not support WriteInplace"; + + auto data_stype = inputs[1].storage_type(); + auto grad_stype = inputs[0].storage_type(); + auto output_stype = outputs[1].storage_type(); + if (data_stype == kCSRStorage && grad_stype == kDefaultStorage && + output_stype == kDefaultStorage) { + TBlob ret = outputs[1].data(); + DotCsrDnsDnsImpl(ctx.get_stream(), inputs[1], inputs[0].data(), req[1], true, &ret); + } else { + LOG(FATAL) << "Not supported dot backward for sparse input(s) with sparse gradients"; + } +} + namespace take_ { // to avoid name conflict enum TakeOpInputs {kArr, kIdx}; enum TakeOpOutputs {kOut}; @@ -667,6 +766,202 @@ void OneHotOpForward(const nnvm::NodeAttrs& attrs, }); } +/*! + * \brief sparse retain namespace + */ +namespace sr { +enum SparseRetainOpInputs {kArr, kIdx}; +enum SparseRetainOpOutputs {kOut}; +} // namespace sr + +inline bool SparseRetainOpShape(const nnvm::NodeAttrs& attrs, + std::vector *in_attrs, + std::vector *out_attrs) { + CHECK_EQ(in_attrs->size(), 2U) + << "sparse_retain operator takes 2 arguments (" << in_attrs->size() << " given)"; + CHECK_EQ(out_attrs->size(), 1U); + + TShape tshape((*in_attrs)[sr::kArr]); + shape_assign(&tshape, (*out_attrs)[sr::kOut]); + SHAPE_ASSIGN_CHECK(*in_attrs, sr::kArr, tshape); + SHAPE_ASSIGN_CHECK(*out_attrs, sr::kOut, tshape); + return true; +} + +inline bool SparseRetainOpType(const nnvm::NodeAttrs& attrs, + std::vector *in_attrs, + std::vector *out_attrs) { + CHECK_EQ(in_attrs->size(), 2U); + CHECK_EQ(out_attrs->size(), 1U); + CHECK_NE((*in_attrs)[sr::kIdx], -1) << "Index type must be set for sparse_retain operator"; + + TYPE_ASSIGN_CHECK(*out_attrs, 0, (*in_attrs)[sr::kArr]); + TYPE_ASSIGN_CHECK(*in_attrs, 0, (*out_attrs)[sr::kOut]); + return (*in_attrs)[0] != -1; +} + +inline bool SparseRetainForwardInferStorageType(const nnvm::NodeAttrs& attrs, + const Context& ctx, + std::vector *in_attrs, + std::vector *out_attrs) { + CHECK_EQ(in_attrs->size(), 2U); + CHECK_EQ(out_attrs->size(), 1U); + STORAGE_TYPE_ASSIGN_CHECK(*in_attrs, sr::kArr, kRowSparseStorage); + STORAGE_TYPE_ASSIGN_CHECK(*out_attrs, sr::kOut, kRowSparseStorage); + return true; +} + +inline bool SparseRetainBackwardInferStorageType(const nnvm::NodeAttrs& attrs, + const Context& ctx, + std::vector *in_attrs, + std::vector *out_attrs) { + CHECK_EQ(in_attrs->size(), 2U); + CHECK_EQ(out_attrs->size(), 2U); + STORAGE_TYPE_ASSIGN_CHECK(*in_attrs, sr::kOut, kDefaultStorage); + STORAGE_TYPE_ASSIGN_CHECK(*in_attrs, sr::kIdx, kDefaultStorage); + STORAGE_TYPE_ASSIGN_CHECK(*out_attrs, sr::kArr, kRowSparseStorage); + STORAGE_TYPE_ASSIGN_CHECK(*out_attrs, sr::kIdx, kDefaultStorage); + return true; +} + +struct SparseRetainRspForward { + template + MSHADOW_XINLINE static void Map(int i, DType* out_data, RType* out_idx, + const DType* in_data, const RType* in_idx, + const IType* idx, const size_t nnr, + const size_t num_cols) { + const RType irow = idx[i]; + int j = -1, left = 0, right = nnr - 1; + while (left <= right) { + int m = left + (right - left) / 2; + const auto in_idx_m = in_idx[m]; + if (in_idx_m == irow) { + j = m; + break; + } else if (in_idx_m < irow) { + left = m + 1; + } else { + right = m - 1; + } + } + out_idx[i] = idx[i]; + if (j >= 0) { + const size_t in_offset = j * num_cols; + const size_t out_offset = i * num_cols; + for (size_t k = 0; k < num_cols; ++k) { + out_data[out_offset+k] = in_data[in_offset+k]; + } + } + } +}; + +template +void SparseRetainOpForwardEx(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + CHECK_EQ(inputs.size(), 2U); + CHECK_EQ(outputs.size(), 1U); + CHECK_EQ(req.size(), 1U); + CHECK_EQ(req[sr::kOut], kWriteTo) << "sparse_retain only supports req=\'write\'"; + + CHECK_EQ(inputs[sr::kArr].storage_type(), kRowSparseStorage) + << "sparse_retain operator only takes row sparse NDArray as input"; + CHECK_EQ(inputs[sr::kIdx].storage_type(), kDefaultStorage) + << "sparse_retain operator only takes default NDArray as its index array"; + CHECK_EQ(outputs[sr::kOut].storage_type(), kRowSparseStorage) + << "sparse_retain operator only outputs row sparse NDArray"; + + const NDArray& input_nd = inputs[sr::kArr]; + const TBlob idx_data = inputs[sr::kIdx].data(); + + if (req[sr::kOut] == kNullOp + || !input_nd.storage_initialized() + || idx_data.Size() == 0U) return; + + const TBlob input_data = input_nd.data(); + if (input_data.shape_[0] == 0) return; + const TBlob input_idx = input_nd.aux_data(rowsparse::kIdx); + + NDArray output_nd = outputs[sr::kOut]; + output_nd.CheckAndAlloc({mshadow::Shape1(idx_data.Size())}); + TBlob output_data = output_nd.data(); + TBlob output_idx = output_nd.aux_data(rowsparse::kIdx); + + using namespace mxnet_op; + Stream *s = ctx.get_stream(); + MSHADOW_TYPE_SWITCH(output_data.type_flag_, DType, { // output data type + MSHADOW_IDX_TYPE_SWITCH(output_idx.type_flag_, RType, { // row index data type + MSHADOW_TYPE_SWITCH(idx_data.type_flag_, IType, { // index array data type + Kernel::Launch(s, output_data.Size(), output_data.dptr()); + Kernel::Launch(s, idx_data.Size(), output_data.dptr(), + output_idx.dptr(), input_data.dptr(), input_idx.dptr(), + idx_data.dptr(), input_data.shape_[0], input_data.shape_[1]); + }); + }); + }); +} + +template +struct SparseRetainRspBackward { + template + MSHADOW_XINLINE static void Map(int i, DType* in_grad, RType* in_grad_idx, + const DType* out_grad, const IType* idx, + const size_t num_cols) { + const RType irow = idx[i]; + in_grad_idx[i] = irow; + const size_t out_offset = irow * num_cols; + const size_t in_offset = i * num_cols; + for (size_t j = 0; j < num_cols; ++j) { + KERNEL_ASSIGN(in_grad[in_offset+j], req, out_grad[out_offset+j]); + } + } +}; + +template +void SparseRetainOpBackwardEx(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + CHECK_EQ(inputs.size(), 2U); + CHECK_EQ(outputs.size(), 2U); + CHECK_EQ(req.size(), 2U); + CHECK_NE(req[sr::kArr], kWriteInplace); + CHECK_EQ(req[sr::kIdx], kNullOp) + << "sparse_retain does not support calculating gradients of indices"; + + CHECK_EQ(inputs[sr::kOut].storage_type(), kDefaultStorage) + << "sparse_retain backward only takes default NDArray as ograd"; + CHECK_EQ(inputs[sr::kIdx].storage_type(), kDefaultStorage) + << "sparse_retain backward only takes default NDArray as its index array"; + CHECK_EQ(outputs[sr::kArr].storage_type(), kRowSparseStorage) + << "sparse_retain backward only outputs row sparse NDArray as grad of input"; + + const TBlob out_grad_data = inputs[sr::kOut].data(); + const TBlob idx_data = inputs[sr::kIdx].data(); + + NDArray in_grad_nd = outputs[sr::kArr]; + in_grad_nd.CheckAndAlloc({mshadow::Shape1(idx_data.Size())}); + TBlob in_grad_data = in_grad_nd.data(); + TBlob in_grad_idx = in_grad_nd.aux_data(rowsparse::kIdx); + + using namespace mxnet_op; + Stream *s = ctx.get_stream(); + MSHADOW_TYPE_SWITCH(out_grad_data.type_flag_, DType, { // output data type + MSHADOW_IDX_TYPE_SWITCH(in_grad_idx.type_flag_, RType, { // row index data type + MSHADOW_TYPE_SWITCH(idx_data.type_flag_, IType, { // index array data type + MXNET_ASSIGN_REQ_SWITCH(req[sr::kArr], req_type, { + Kernel, xpu>::Launch( + s, in_grad_idx.Size(), in_grad_data.dptr(), in_grad_idx.dptr(), + out_grad_data.dptr(), idx_data.dptr(), out_grad_data.shape_[1]); + }); + }); + }); + }); +} + } // namespace op } // namespace mxnet #ifdef __CUDACC__ diff --git a/src/operator/tensor/init_op.cc b/src/operator/tensor/init_op.cc index 16f71fc7e4e3..679d1fb55bab 100644 --- a/src/operator/tensor/init_op.cc +++ b/src/operator/tensor/init_op.cc @@ -21,6 +21,7 @@ NNVM_REGISTER_OP(_zeros) .set_attr("FInferShape", InitShape) .set_attr("FInferType", InitType) .set_attr("FCompute", FillCompute) +.set_attr("FComputeEx", FillComputeZerosEx) .add_arguments(InitOpParam::__FIELDS__()); NNVM_REGISTER_OP(_ones) diff --git a/src/operator/tensor/init_op.cu b/src/operator/tensor/init_op.cu index a798f26db60d..7c643ee00129 100644 --- a/src/operator/tensor/init_op.cu +++ b/src/operator/tensor/init_op.cu @@ -9,7 +9,8 @@ namespace mxnet { namespace op { NNVM_REGISTER_OP(_zeros) -.set_attr("FCompute", FillCompute); +.set_attr("FCompute", FillCompute) +.set_attr("FComputeEx", FillComputeZerosEx); NNVM_REGISTER_OP(_ones) .set_attr("FCompute", FillCompute); diff --git a/src/operator/tensor/init_op.h b/src/operator/tensor/init_op.h index 5ce132d4bebf..7cbf986cbf9c 100644 --- a/src/operator/tensor/init_op.h +++ b/src/operator/tensor/init_op.h @@ -15,6 +15,8 @@ #include #include #include "../elemwise_op_common.h" +#include "../mxnet_op.h" + namespace mxnet { namespace op { @@ -111,7 +113,6 @@ inline bool InitType(const nnvm::NodeAttrs& attrs, return true; } - template void FillCompute(const nnvm::NodeAttrs& attrs, const OpContext& ctx, @@ -127,6 +128,72 @@ void FillCompute(const nnvm::NodeAttrs& attrs, }); } +// Fill in the indices and values of a RowSparse NDArray to represent a zeros NDArray, +// instead of the usual compact representation. +template +inline void FillDnsZerosRspImpl(mshadow::Stream *s, NDArray *dst) { + using namespace rowsparse; + using namespace mshadow::expr; + using namespace mshadow; + using namespace mxnet_op; + CHECK_EQ(dst->storage_type(), kRowSparseStorage); + MSHADOW_REAL_TYPE_SWITCH(dst->dtype(), DType, { + MSHADOW_IDX_TYPE_SWITCH(dst->aux_type(kIdx), IType, { + auto num_rows = dst->shape()[0]; + dst->CheckAndAlloc({Shape1(num_rows)}); + auto idx = dst->aux_data(kIdx).FlatTo1D(s); + auto val = dst->data(); + Kernel::Launch(s, val.Size(), val.dptr()); + ASSIGN_DISPATCH(idx, kWriteTo, range(0, num_rows, 1, 1)) + }); + }); +} + +// Fill a rsp NDArray with zeros by updating the aux shape. +template +void FillZerosRspImpl(mshadow::Stream *s, NDArray *dst) { + if (!dst->storage_initialized()) return; + // reset the shapes if it's not zeros + auto storage_shape = dst->storage_shape(); + storage_shape[0] = 0; + dst->set_aux_shape(rowsparse::kIdx, TShape(mshadow::Shape1(0))); + dst->set_storage_shape(storage_shape); +} + +// Fill a CSR NDArray with zeros by updating the aux shape. +template +void FillZerosCsrImpl(mshadow::Stream *s, NDArray *dst) { + if (!dst->storage_initialized()) return; + // reset the shapes if it's not zeros + TShape new_shape(mshadow::Shape1(0)); + dst->set_aux_shape(csr::kIndPtr, new_shape); + dst->set_aux_shape(csr::kIdx, new_shape); + dst->set_storage_shape(new_shape); +} + +// This operator never needs to fall back, since there's no input NDArray +template +void FillComputeZerosEx(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + using namespace mshadow; + using namespace mshadow::expr; + Stream *s = ctx.get_stream(); + CHECK_EQ(outputs.size(), 1); + CHECK_EQ(inputs.size(), 0); + auto stype = outputs[0].storage_type(); + if (stype == kRowSparseStorage) { + NDArray nd(outputs[0]); + FillZerosRspImpl(s, &nd); + } else if (stype == kCSRStorage) { + NDArray nd(outputs[0]); + FillZerosCsrImpl(s, &nd); + } else { + LOG(FATAL) << "storage type not implemented."; + } +} template void RangeCompute(const nnvm::NodeAttrs& attrs, diff --git a/src/operator/tensor/matrix_op-inl.h b/src/operator/tensor/matrix_op-inl.h index 72fd2773c8f8..3ae6938bf82c 100644 --- a/src/operator/tensor/matrix_op-inl.h +++ b/src/operator/tensor/matrix_op-inl.h @@ -10,6 +10,7 @@ #include #include #include +#include #include "../mshadow_op.h" #include "../elemwise_op_common.h" #include "../mxnet_op.h" @@ -349,364 +350,6 @@ inline bool ExpandDimShape(const nnvm::NodeAttrs& attrs, return true; } -struct DotParam : public dmlc::Parameter { - bool transpose_a; - bool transpose_b; - DMLC_DECLARE_PARAMETER(DotParam) { - DMLC_DECLARE_FIELD(transpose_a) - .describe("If true then transpose the first input before dot.") - .set_default(false); - DMLC_DECLARE_FIELD(transpose_b) - .describe("If true then transpose the second input before dot.") - .set_default(false); - } -}; - -template -void DotForward_(const nnvm::NodeAttrs& attrs, - const OpContext& ctx, - const std::vector& inputs, - const std::vector& req, - const std::vector& outputs) { - using namespace mshadow; - using namespace mshadow::expr; - const DotParam& param = nnvm::get(attrs.parsed); - Stream *s = ctx.get_stream(); - CHECK_EQ(outputs[0].type_flag_, inputs[0].type_flag_) - << "Binary function only support input/output with the same type"; - CHECK_EQ(outputs[0].type_flag_, inputs[1].type_flag_) - << "Binary function only support input/output with the same type"; - CHECK(outputs[0].type_flag_ == kFloat32 || outputs[0].type_flag_ == kFloat64) - << "dot only supports float32 and float64"; - MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, { - if (inputs[0].ndim() == 1 && inputs[1].ndim() == 1) { - CHECK_NE(req[0], kAddTo) << "AddTo not yet suported"; - Tensor out = outputs[0].get(s); - VectorDot(out, - inputs[0].get(s), - inputs[1].get(s)); - } else { - int ma, na, mb, nb, m, n; - if (param.transpose_a) { - ma = inputs[0].size(0); - na = inputs[0].Size()/ma; - m = na; - } else { - na = inputs[0].size(inputs[0].ndim()-1); - ma = inputs[0].Size()/na; - m = ma; - } - if (param.transpose_b) { - nb = inputs[1].size(inputs[1].ndim()-1); - mb = inputs[1].Size()/nb; - n = mb; - } else { - mb = inputs[1].size(0); - nb = inputs[1].Size()/mb; - n = nb; - } - Tensor input0 = - inputs[0].get_with_shape(Shape2(ma, na), s); - Tensor input1 = - inputs[1].get_with_shape(Shape2(mb, nb), s); - Tensor out = - outputs[0].get_with_shape(Shape2(m, n), s); - if (param.transpose_a && param.transpose_b) { - ASSIGN_DISPATCH(out, req[0], dot(input0.T(), input1.T())); - } else if (!param.transpose_a && param.transpose_b) { - ASSIGN_DISPATCH(out, req[0], dot(input0, input1.T())); - } else if (param.transpose_a && !param.transpose_b) { - ASSIGN_DISPATCH(out, req[0], dot(input0.T(), input1)); - } else { - ASSIGN_DISPATCH(out, req[0], dot(input0, input1)); - } - } - }); -} - -template -void DotBackward_(const nnvm::NodeAttrs& attrs, - const OpContext& ctx, - const std::vector& inputs, - const std::vector& req, - const std::vector& outputs) { - using namespace mshadow; - using namespace mshadow::expr; - const DotParam& param = nnvm::get(attrs.parsed); - Stream *s = ctx.get_stream(); - CHECK_NE(req[0], kWriteInplace); - CHECK_NE(req[1], kWriteInplace); - CHECK(outputs[0].type_flag_ == kFloat32 || outputs[0].type_flag_ == kFloat64) - << "dot only supports float32 and float64"; - MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, { - if (inputs[1].ndim() == 1 && inputs[2].ndim() == 1) { - Tensor mout_grad = inputs[0].get(s); - Tensor mlhs_data = inputs[1].get(s); - Tensor mrhs_data = inputs[2].get(s); - Tensor mlhs_grad = outputs[0].get(s); - Tensor mrhs_grad = outputs[1].get(s); - ASSIGN_DISPATCH(mrhs_grad, req[1], - broadcast_scalar(mout_grad, mlhs_data.shape_) * mlhs_data); - ASSIGN_DISPATCH(mlhs_grad, req[0], - broadcast_scalar(mout_grad, mlhs_data.shape_) * mrhs_data); - } else { - int ma, na, mb, nb, m, n; - if (param.transpose_a) { - ma = outputs[0].size(0); - na = outputs[0].Size()/ma; - m = na; - } else { - na = outputs[0].size(outputs[0].ndim()-1); - ma = outputs[0].Size()/na; - m = ma; - } - if (param.transpose_b) { - nb = outputs[1].size(outputs[1].ndim()-1); - mb = outputs[1].Size()/nb; - n = mb; - } else { - mb = outputs[1].size(0); - nb = outputs[1].Size()/mb; - n = nb; - } - Tensor mout_grad = - inputs[0].get_with_shape(Shape2(m, n), s); - Tensor mlhs_data = - inputs[1].get_with_shape(Shape2(ma, na), s); - Tensor mrhs_data = - inputs[2].get_with_shape(Shape2(mb, nb), s); - Tensor mlhs_grad = - outputs[0].get_with_shape(Shape2(ma, na), s); - Tensor mrhs_grad = - outputs[1].get_with_shape(Shape2(mb, nb), s); - if (param.transpose_a && param.transpose_b) { - // Gradient of z = dot(x.T, y.T) - // dy = dot(x, dz).T = dot(dz.T, x.T) - // dx = dot(dz, y).T = dot(y.T, dz.T) - ASSIGN_DISPATCH(mrhs_grad, req[1], dot(mout_grad.T(), mlhs_data.T())); - ASSIGN_DISPATCH(mlhs_grad, req[0], dot(mrhs_data.T(), mout_grad.T())); - } else if (!param.transpose_a && param.transpose_b) { - // Gradient of z = dot(x, y.T) - // dy = dot(x.T, dz).T = dot(dz.T, x) - // dx = dot(dz, y) - ASSIGN_DISPATCH(mrhs_grad, req[1], dot(mout_grad.T(), mlhs_data)); - ASSIGN_DISPATCH(mlhs_grad, req[0], dot(mout_grad, mrhs_data)); - } else if (param.transpose_a && !param.transpose_b) { - // Gradient of z = dot(x.T, y) - // dy = dot(x, dz) - // dx = dot(dz, y.T).T = dot(y, dz.T) - ASSIGN_DISPATCH(mrhs_grad, req[1], dot(mlhs_data, mout_grad)); - ASSIGN_DISPATCH(mlhs_grad, req[0], dot(mrhs_data, mout_grad.T())); - } else { - // Gradient of z = dot(x, y) - // dy = dot(x.T, dz) - // dx = dot(dz, y.T) - ASSIGN_DISPATCH(mrhs_grad, req[1], dot(mlhs_data.T(), mout_grad)); - ASSIGN_DISPATCH(mlhs_grad, req[0], dot(mout_grad, mrhs_data.T())); - } - } - }); -} - -inline bool DotShape(const nnvm::NodeAttrs& attrs, - std::vector *in_attrs, - std::vector *out_attrs) { - const DotParam& param = nnvm::get(attrs.parsed); - CHECK_EQ(in_attrs->size(), 2U); - CHECK_EQ(out_attrs->size(), 1U); - TShape& lshape = (*in_attrs)[0]; - TShape& rshape = (*in_attrs)[1]; - if (lshape.ndim() == 1 && rshape.ndim() == 1) { - CHECK(!param.transpose_a && !param.transpose_b) << "Cannot transpose vectors"; - CHECK_EQ(lshape[0], rshape[0]) << "dot shape error: " << lshape << " X " << rshape; - SHAPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::Shape1(1)); - } else { - bool Ta = param.transpose_a, Tb = param.transpose_b; - TShape L[2], R[2]; - if (Ta) { - L[0] = mshadow::Shape1(lshape[0]); - L[1] = lshape.ndim() > 1 ? TShape(&lshape[1], &lshape[lshape.ndim()]) : TShape(1); - } else { - L[0] = lshape.ndim() > 1 ? TShape(&lshape[0], &lshape[lshape.ndim()-1]) : TShape(1); - L[1] = mshadow::Shape1(lshape[lshape.ndim()-1]); - } - if (Tb) { - R[0] = rshape.ndim() > 1 ? TShape(&rshape[0], &rshape[rshape.ndim()-1]) : TShape(1); - R[1] = mshadow::Shape1(rshape[rshape.ndim()-1]); - } else { - R[0] = mshadow::Shape1(rshape[0]); - R[1] = rshape.ndim() > 1 ? TShape(&rshape[1], &rshape[rshape.ndim()]) : TShape(1); - } - - if (L[!Ta].Size() != 0 && R[Tb].Size() != 0) { - CHECK_EQ(L[!Ta].Size(), R[Tb].Size()) - << "dot shape error: " << lshape << " X " << rshape; - } - std::vector buf; - if (lshape.ndim() > 1) buf.insert(buf.end(), &L[Ta][0], &L[Ta][L[Ta].ndim()]); - if (rshape.ndim() > 1) buf.insert(buf.end(), &R[!Tb][0], &R[!Tb][R[!Tb].ndim()]); - TShape oshape(buf.begin(), buf.end()); - SHAPE_ASSIGN_CHECK(*out_attrs, 0, oshape); - } - return true; -} - -template -void BatchDotForward_(const nnvm::NodeAttrs& attrs, - const OpContext& ctx, - const std::vector& inputs, - const std::vector& req, - const std::vector& outputs) { - using namespace mshadow; - using namespace mshadow::expr; - mshadow::Stream *s = ctx.get_stream(); - const DotParam& param = nnvm::get(attrs.parsed); - CHECK_EQ(outputs[0].type_flag_, inputs[0].type_flag_) - << "Binary function only support input/output with the same type"; - CHECK_EQ(outputs[0].type_flag_, inputs[1].type_flag_) - << "Binary function only support input/output with the same type"; - CHECK(outputs[0].type_flag_ == kFloat32 || outputs[0].type_flag_ == kFloat64) - << "dot only supports float32 and float64"; - MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, { - mshadow::Tensor out = outputs[0].get(s); - mshadow::Tensor mlhs = inputs[0].get(s); - mshadow::Tensor mrhs = inputs[1].get(s); - mshadow::Tensor workspace = - ctx.requested[0].get_space_typed(mshadow::Shape1(3 * out.size(0)), s); - if (kNullOp != req[0]) { - if (param.transpose_a && param.transpose_b) { - mshadow::BatchGEMM(out, mlhs, mrhs, (DType)1.0f, - (kAddTo == req[0]) ? (DType)1.0f : (DType)0.0f, - workspace); - } else if (!param.transpose_a && param.transpose_b) { - mshadow::BatchGEMM(out, mlhs, mrhs, (DType)1.0f, - (kAddTo == req[0]) ? (DType)1.0f : (DType)0.0f, - workspace); - } else if (param.transpose_a && !param.transpose_b) { - mshadow::BatchGEMM(out, mlhs, mrhs, (DType)1.0f, - (kAddTo == req[0]) ? (DType)1.0f : (DType)0.0f, - workspace); - } else { - mshadow::BatchGEMM(out, mlhs, mrhs, (DType)1.0f, - (kAddTo == req[0]) ? (DType)1.0f : (DType)0.0f, - workspace); - } - } - }); -} - -template -void BatchDotBackward_(const nnvm::NodeAttrs& attrs, - const OpContext& ctx, - const std::vector& inputs, - const std::vector& req, - const std::vector& outputs) { - using namespace mshadow; - using namespace mshadow::expr; - mshadow::Stream *s = ctx.get_stream(); - const DotParam& param = nnvm::get(attrs.parsed); - CHECK_NE(req[1], kWriteInplace); - CHECK_NE(req[0], kWriteInplace); - CHECK(outputs[0].type_flag_ == kFloat32 || outputs[0].type_flag_ == kFloat64) - << "dot only supports float32 and float64"; - MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, { - mshadow::Tensor mout_grad = inputs[0].get(s); - mshadow::Tensor mlhs_data = inputs[1].get(s); - mshadow::Tensor mrhs_data = inputs[2].get(s); - mshadow::Tensor mlhs_grad = outputs[0].get(s); - mshadow::Tensor mrhs_grad = outputs[1].get(s); - mshadow::Tensor workspace = - ctx.requested[0].get_space_typed( - mshadow::Shape2(2, 3 * mout_grad.size(0)), s); - mshadow::Tensor rhs_workspace = workspace[0]; - mshadow::Tensor lhs_workspace = workspace[1]; - if (param.transpose_a && param.transpose_b) { - // Gradient of z = dot(x.T, y.T) - // dy = dot(x, dz).T = dot(dz.T, x.T) - // dx = dot(dz, y).T = dot(y.T, dz.T) - if (kNullOp != req[1]) { - mshadow::BatchGEMM(mrhs_grad, mout_grad, mlhs_data, (DType)1.0f, - (kAddTo == req[1]) ? (DType)1.0f : (DType)0.0f, - rhs_workspace); - } - if (kNullOp != req[0]) { - mshadow::BatchGEMM(mlhs_grad, mrhs_data, mout_grad, (DType)1.0f, - (kAddTo == req[0]) ? (DType)1.0f : (DType)0.0f, - lhs_workspace); - } - } else if (!param.transpose_a && param.transpose_b) { - // Gradient of z = dot(x, y.T) - // dy = dot(x.T, dz).T = dot(dz.T, x) - // dx = dot(dz, y) - if (kNullOp != req[1]) { - mshadow::BatchGEMM(mrhs_grad, mout_grad, mlhs_data, (DType)1.0f, - (kAddTo == req[1]) ? (DType)1.0f : (DType)0.0f, - rhs_workspace); - } - if (kNullOp != req[0]) { - mshadow::BatchGEMM(mlhs_grad, mout_grad, mrhs_data, (DType)1.0f, - (kAddTo == req[0]) ? (DType)1.0f : (DType)0.0f, - lhs_workspace); - } - } else if (param.transpose_a && !param.transpose_b) { - // Gradient of z = dot(x.T, y) - // dy = dot(x, dz) - // dx = dot(dz, y.T).T = dot(y, dz.T) - if (kNullOp != req[1]) { - mshadow::BatchGEMM(mrhs_grad, mlhs_data, mout_grad, (DType)1.0f, - (kAddTo == req[1]) ? (DType)1.0f : (DType)0.0f, - rhs_workspace); - } - if (kNullOp != req[0]) { - mshadow::BatchGEMM(mlhs_grad, mrhs_data, mout_grad, (DType)1.0f, - (kAddTo == req[0]) ? (DType)1.0f : (DType)0.0f, - lhs_workspace); - } - } else { - // Gradient of z = dot(x, y) - // dy = dot(x.T, dz) - // dx = dot(dz, y.T) - if (kNullOp != req[1]) { - mshadow::BatchGEMM(mrhs_grad, mlhs_data, mout_grad, (DType)1.0f, - (kAddTo == req[1]) ? (DType)1.0f : (DType)0.0f, - rhs_workspace); - } - if (kNullOp != req[0]) { - mshadow::BatchGEMM(mlhs_grad, mout_grad, mrhs_data, (DType)1.0f, - (kAddTo == req[0]) ? (DType)1.0f : (DType)0.0f, - lhs_workspace); - } - } - }); -} - -inline bool BatchDotShape(const nnvm::NodeAttrs& attrs, - std::vector *in_attrs, - std::vector *out_attrs) { - CHECK_EQ(in_attrs->size(), 2U); - CHECK_EQ(out_attrs->size(), 1U); - const DotParam& param = nnvm::get(attrs.parsed); - TShape& lshape = (*in_attrs)[0]; - TShape& rshape = (*in_attrs)[1]; - if (lshape.ndim() == 3 && rshape.ndim() == 3) { - CHECK(lshape[0] == rshape[0]) - << "batch_dot shape error(batch_size must be equal): " << lshape << " X " << rshape - << " trans_a=" << param.transpose_a << " trans_b=" << param.transpose_b; - index_t out_m = param.transpose_a ? lshape[2] : lshape[1]; - index_t lshape_k = param.transpose_a ? lshape[1] : lshape[2]; - index_t out_n = param.transpose_b ? rshape[1] : rshape[2]; - index_t rshape_k = param.transpose_b ? rshape[2] : rshape[1]; - CHECK(lshape_k == rshape_k) - << "batch_dot shape error(shape mismatch): " << lshape << " X " << rshape - << " trans_a=" << param.transpose_a << " trans_b=" << param.transpose_b; - SHAPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::Shape3(lshape[0], out_m, out_n)); - } else { - LOG(FATAL) << "batch_dot currently only support 3D*3D array" - << lshape << " v.s. " << rshape; - } - return true; -} - struct SliceParam : public dmlc::Parameter { nnvm::Tuple > begin, end; DMLC_DECLARE_PARAMETER(SliceParam) { @@ -826,6 +469,96 @@ void Slice(const nnvm::NodeAttrs& attrs, }); } +// slice the indptr of a csr +struct SliceCsrIndPtr { + template + MSHADOW_XINLINE static void Map(int i, IType* out, const IType* in, const IType* base) { + KERNEL_ASSIGN(out[i], kWriteTo, in[i] - *base); + } +}; + +/* + * a wrapper to launch SliceCsrIndPtr kernel. + * slice [src[begin] .. src[end]) and store in dst[0, end - begin) + */ +template +void SliceCsrIndPtrImpl(const int begin, const int end, RunContext ctx, + const IType* src, IType* dst) { + using namespace mshadow; + using namespace mxnet_op; + Stream *s = ctx.get_stream(); + int indptr_len = end - begin + 1; + Kernel::Launch(s, indptr_len, dst, src + begin, src + begin); +} + +/* + * Slice a CSR NDArray + * Only implemented for CPU + */ +template +void SliceCsrImpl(const SliceParam ¶m, const OpContext& ctx, + const NDArray &in, OpReqType req, const NDArray &out) { + using namespace mshadow; + using namespace mxnet_op; + using namespace csr; + CHECK((std::is_same::value)) << "Slice for CSR input only implemented for CPU"; + if (req == kNullOp) return; + CHECK_NE(req, kAddTo) << "kAddTo for Slice on CSR input is not supported"; + CHECK_NE(req, kWriteInplace) << "kWriteInplace for Slice on CSR input is not supported"; + Stream *s = ctx.get_stream(); + int begin = *param.begin[0]; + int end = *param.end[0]; + int indptr_len = end - begin + 1; + out.CheckAndAllocAuxData(kIndPtr, Shape1(indptr_len)); + if (!in.storage_initialized()) { + out.set_aux_shape(kIndPtr, Shape1(0)); + return; + } + // assume idx indptr share the same type + MSHADOW_IDX_TYPE_SWITCH(in.aux_type(kIndPtr), RType, { + MSHADOW_IDX_TYPE_SWITCH(in.aux_type(kIdx), IType, { + MSHADOW_TYPE_SWITCH(in.dtype(), DType, { + auto in_indptr = in.aux_data(kIndPtr).dptr(); + auto out_indptr = out.aux_data(kIndPtr).dptr(); + SliceCsrIndPtrImpl(begin, end, ctx.run_ctx, in_indptr, out_indptr); + + // retrieve nnz (CPU implementation) + int nnz = out_indptr[indptr_len - 1]; + // copy indices and values + out.CheckAndAllocAuxData(kIdx, Shape1(nnz)); + out.CheckAndAllocData(Shape1(nnz)); + auto in_idx = in.aux_data(kIdx).dptr(); + auto out_idx = out.aux_data(kIdx).dptr(); + auto in_data = in.data().dptr(); + auto out_data = out.data().dptr(); + int offset = in_indptr[begin]; + // this is also a CPU-only implementation + memcpy(out_idx, in_idx + offset, nnz * sizeof(IType)); + memcpy(out_data, in_data + offset, nnz * sizeof(DType)); + }); + }); + }); +} + +template +void SliceEx(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + CHECK_EQ(inputs.size(), 1); + CHECK_EQ(outputs.size(), 1); + const SliceParam& param = nnvm::get(attrs.parsed); + auto in_stype = inputs[0].storage_type(); + CHECK_NE(in_stype, kDefaultStorage) + << "SliceEx is not expected to execute for input with default storage type"; + if (in_stype == kCSRStorage) { + SliceCsrImpl(param, ctx, inputs[0], req[0], outputs[0]); + } else { + LOG(FATAL) << "Slice not implemented for storage type" << in_stype; + } +} + inline bool SliceAssignShape(const nnvm::NodeAttrs& attrs, std::vector *in_attrs, std::vector *out_attrs) { diff --git a/src/operator/tensor/matrix_op.cc b/src/operator/tensor/matrix_op.cc index 6a51d46db25c..e80e8463000e 100644 --- a/src/operator/tensor/matrix_op.cc +++ b/src/operator/tensor/matrix_op.cc @@ -16,7 +16,6 @@ DMLC_REGISTER_PARAMETER(ClipParam); DMLC_REGISTER_PARAMETER(SimpleCropAssignScalarParam); DMLC_REGISTER_PARAMETER(SliceParam); DMLC_REGISTER_PARAMETER(SliceAxisParam); -DMLC_REGISTER_PARAMETER(DotParam); DMLC_REGISTER_PARAMETER(RepeatParam); DMLC_REGISTER_PARAMETER(TileParam); DMLC_REGISTER_PARAMETER(ReverseParam); @@ -244,6 +243,9 @@ and ``end=(e_1, e_2, ... e_n)`` indices will result in an array with the shape The resulting array's *k*-th dimension contains elements from the *k*-th dimension of the input array with the open range ``[b_k, e_k)``. +For an input array of non-default storage type(e.g. `csr` or `row_sparse`), it only supports +slicing on the first dimension. + Example:: x = [[ 1., 2., 3., 4.], @@ -257,8 +259,10 @@ Example:: .set_attr_parser(ParamParser) .set_attr("FInferShape", SliceShape) .set_attr("FInferType", ElemwiseType<1, 1>) +.set_attr("FInferStorageType", ElemwiseStorageType<1, 1>) .set_attr("FGradient", ElemwiseGradUseNone{"_backward_slice"}) .set_attr("FCompute", Slice) +.set_attr("FComputeEx", SliceEx) .add_argument("data", "NDArray-or-Symbol", "Source input") .add_arguments(SliceParam::__FIELDS__()); @@ -351,94 +355,6 @@ NNVM_REGISTER_OP(_backward_slice_axis) .set_attr("TIsBackward", true) .set_attr("FCompute", SliceAxisGrad_); -NNVM_REGISTER_OP(dot) -.describe(R"doc(Dot product of two arrays. - -``dot``'s behavior depends on the input array dimensions: - -- 1-D arrays: inner product of vectors -- 2-D arrays: matrix multiplication -- N-D arrays: a sum product over the last axis of the first input and the first - axis of the second input - - For example, given 3-D ``x`` with shape `(n,m,k)` and ``y`` with shape `(k,r,s)`, the - result array will have shape `(n,m,r,s)`. It is computed by:: - - dot(x,y)[i,j,a,b] = sum(x[i,j,:]*y[:,a,b]) - - Example:: - - x = reshape([0,1,2,3,4,5,6,7], shape=(2,2,2)) - y = reshape([7,6,5,4,3,2,1,0], shape=(2,2,2)) - dot(x,y)[0,0,1,1] = 0 - sum(x[0,0,:]*y[:,1,1]) = 0 -)doc" ADD_FILELINE) -.set_num_inputs(2) -.set_num_outputs(1) -.set_attr_parser(ParamParser) -.set_attr("FListInputNames", - [](const NodeAttrs& attrs) { - return std::vector{"lhs", "rhs"}; - }) -.set_attr("FInferShape", DotShape) -.set_attr("FInferType", ElemwiseType<2, 1>) -.set_attr("FCompute", DotForward_) -.set_attr("FGradient", ElemwiseGradUseIn{"_backward_dot"}) -.add_argument("lhs", "NDArray-or-Symbol", "The first input") -.add_argument("rhs", "NDArray-or-Symbol", "The second input") -.add_arguments(DotParam::__FIELDS__()); - -NNVM_REGISTER_OP(_backward_dot) -.set_num_inputs(3) -.set_num_outputs(2) -.set_attr_parser(ParamParser) -.set_attr("TIsBackward", true) -.set_attr("FCompute", DotBackward_) -.add_arguments(DotParam::__FIELDS__()); - -NNVM_REGISTER_OP(batch_dot) -.describe(R"doc(Batchwise dot product. - -``batch_dot`` is used to compute dot product of ``x`` and ``y`` when ``x`` and -``y`` are data in batch, namely 3D arrays in shape of `(batch_size, :, :)`. - -For example, given ``x`` with shape `(batch_size, n, m)` and ``y`` with shape -`(batch_size, m, k)`, the result array will have shape `(batch_size, n, k)`, -which is computed by:: - - batch_dot(x,y)[i,:,:] = dot(x[i,:,:], y[i,:,:]) - -)doc" ADD_FILELINE) -.set_num_inputs(2) -.set_num_outputs(1) -.set_attr_parser(ParamParser) -.set_attr("FListInputNames", - [](const NodeAttrs& attrs) { - return std::vector{"lhs", "rhs"}; - }) -.set_attr("FInferShape", BatchDotShape) -.set_attr("FInferType", ElemwiseType<2, 1>) -.set_attr("FResourceRequest", - [](const NodeAttrs& attrs) { - return std::vector{ResourceRequest::kTempSpace}; - }) -.set_attr("FCompute", BatchDotForward_) -.set_attr("FGradient", ElemwiseGradUseIn{"_backward_batch_dot"}) -.add_argument("lhs", "NDArray-or-Symbol", "The first input") -.add_argument("rhs", "NDArray-or-Symbol", "The second input") -.add_arguments(DotParam::__FIELDS__()); - -NNVM_REGISTER_OP(_backward_batch_dot) -.set_num_inputs(3) -.set_num_outputs(2) -.set_attr_parser(ParamParser) -.set_attr("FResourceRequest", - [](const NodeAttrs& attrs) { - return std::vector{ResourceRequest::kTempSpace}; - }) -.set_attr("TIsBackward", true) -.set_attr("FCompute", BatchDotBackward_); - NNVM_REGISTER_OP(clip) .describe(R"code(Clips (limits) the values in an array. diff --git a/src/operator/tensor/matrix_op.cu b/src/operator/tensor/matrix_op.cu index 96c075a7d483..91a6757b962c 100644 --- a/src/operator/tensor/matrix_op.cu +++ b/src/operator/tensor/matrix_op.cu @@ -39,18 +39,6 @@ NNVM_REGISTER_OP(slice_axis) NNVM_REGISTER_OP(_backward_slice_axis) .set_attr("FCompute", SliceAxisGrad_); -NNVM_REGISTER_OP(dot) -.set_attr("FCompute", DotForward_); - -NNVM_REGISTER_OP(_backward_dot) -.set_attr("FCompute", DotBackward_); - -NNVM_REGISTER_OP(batch_dot) -.set_attr("FCompute", BatchDotForward_); - -NNVM_REGISTER_OP(_backward_batch_dot) -.set_attr("FCompute", BatchDotBackward_); - NNVM_REGISTER_OP(clip) .set_attr("FCompute", Clip); diff --git a/tests/ci_build/install/ubuntu_install_python.sh b/tests/ci_build/install/ubuntu_install_python.sh index 973523d0c8f3..357ac9e9c3b2 100755 --- a/tests/ci_build/install/ubuntu_install_python.sh +++ b/tests/ci_build/install/ubuntu_install_python.sh @@ -6,5 +6,5 @@ apt-get update && apt-get install -y python-dev python3-dev # the version of the pip shipped with ubuntu may be too lower, install a recent version here cd /tmp && wget https://bootstrap.pypa.io/get-pip.py && python3 get-pip.py && python2 get-pip.py -pip2 install nose pylint numpy nose-timer requests h5py -pip3 install nose pylint numpy nose-timer requests h5py +pip2 install nose pylint numpy nose-timer requests h5py scipy +pip3 install nose pylint numpy nose-timer requests h5py scipy diff --git a/tests/cpp/include/test_ndarray_utils.h b/tests/cpp/include/test_ndarray_utils.h new file mode 100644 index 000000000000..4a99d2759c3b --- /dev/null +++ b/tests/cpp/include/test_ndarray_utils.h @@ -0,0 +1,115 @@ +/*! + * Copyright (c) 2017 by Contributors + * \file test_utils.h + * \brief operator unit test utility functions + * \author Haibin Lin +*/ +#ifndef TESTS_CPP_INCLUDE_TEST_NDARRAY_UTILS_H_ +#define TESTS_CPP_INCLUDE_TEST_NDARRAY_UTILS_H_ + +/*#include +#include +#include +#include +#include +#include +#include +#include + +#include "../src/operator/tensor/elemwise_binary_op.h" +#include "../src/operator/tensor/elemwise_unary_op.h" +#include "../src/operator/optimizer_op-inl.h" +#include "../src/operator/tensor/init_op.h" + +using namespace mxnet; +#define TEST_DTYPE float +#define TEST_ITYPE int32_t + +void CheckDataRegion(const TBlob &src, const TBlob &dst) { + auto size = src.shape_.Size() * mshadow::mshadow_sizeof(src.type_flag_); + auto equals = memcmp(src.dptr_, dst.dptr_, size); + EXPECT_EQ(equals, 0); +} + +float RandFloat() { + float v = rand() * 1.0 / RAND_MAX; + return v; +} + +// Get an NDArray with provided indices, prepared for a RowSparse NDArray. +NDArray RspIdxND(const TShape shape, const Context ctx, const std::vector &values) { + NDArray nd(shape, ctx, false, ROW_SPARSE_IDX_TYPE); + size_t num_val = values.size(); + MSHADOW_TYPE_SWITCH(nd.dtype(), DType, { + auto tensor = nd.data().FlatTo1D(); + for (size_t i = 0; i < num_val; i++) { + tensor[i] = values[i]; + } + }); + return nd; +} + +// Get a dense NDArray with provided values. +NDArray DnsND(const TShape shape, const Context ctx, std::vector vs) { + NDArray nd(shape, ctx, false); + size_t num_val = shape.Size(); + // generate random values + while (vs.size() < num_val) { + auto v = RandFloat(); + vs.push_back(v); + } + CHECK_EQ(vs.size(), nd.shape().Size()); + MSHADOW_TYPE_SWITCH(nd.dtype(), DType, { + auto tensor = nd.data().FlatTo1D(); + for (size_t i = 0; i < num_val; i++) { + tensor[i] = vs[i]; + } + }); + return nd; +} + +// Get a RowSparse NDArray with provided indices and values +NDArray RspND(const TShape shape, const Context ctx, const std::vector idx, + std::vector vals) { + CHECK(shape.ndim() <= 2) << "High dimensional row sparse not implemented yet"; + index_t num_rows = idx.size(); + index_t num_cols = vals.size() / idx.size(); + // create index NDArray + NDArray index = RspIdxND(mshadow::Shape1(num_rows), ctx, idx); + CHECK_EQ(vals.size() % idx.size(), 0); + // create value NDArray + NDArray data = DnsND(mshadow::Shape2(num_rows, num_cols), ctx, vals); + // create result nd + NDArray nd(kRowSparseStorage, shape, ctx, false, mshadow::default_type_flag, + {}, {mshadow::Shape1(num_rows)}); + // assign values + NDArray nd_aux = nd.aux_ndarray(0); + NDArray nd_data = nd.data_ndarray(); + CopyFromTo(index, &nd_aux); + CopyFromTo(data, &nd_data); + return nd; +} + +// TODO(haibin) support other types +NDArray Convert(NDArrayStorageType type, NDArray src) { + CHECK_EQ(type, kDefaultStorage); + NDArray converted(src.shape(), src.ctx(), false); + Engine::Get()->PushSync([src, converted](RunContext ctx) { + // TODO provide type in attrs, which is empty now + OpContext op_ctx; + op_ctx.run_ctx = ctx; + if (src.storage_type() == kRowSparseStorage) { + std::vector inputs({src}), outputs({converted}); + op::CastStorageComputeEx({}, op_ctx, inputs, {}, outputs); + } else if (src.storage_type() == kDefaultStorage) { + std::vector inputs({src.data()}), outputs({converted.data()}); + op::IdentityCompute({}, op_ctx, inputs, {kWriteTo}, outputs); + } else { + LOG(FATAL) << "unsupported storage type"; + } + }, src.ctx(), {src.var()}, {converted.var()}, + FnProperty::kNormal, 0, PROFILER_MESSAGE_FUNCNAME); + converted.WaitToRead(); + return converted; +}*/ +#endif // TESTS_CPP_INCLUDE_TEST_NDARRAY_UTILS_H_ diff --git a/tests/cpp/operator/batchnorm_test.cc b/tests/cpp/operator/batchnorm_test.cc index 719980b5d4f5..32d60cf3e4e4 100644 --- a/tests/cpp/operator/batchnorm_test.cc +++ b/tests/cpp/operator/batchnorm_test.cc @@ -1,7 +1,7 @@ /*! * Copyright (c) 2017 by Contributors * \file batchnorm_test.cc - * \brief operator unit test utility functions + * \brief batchnorm operator unit test utility functions * \author Chris Olivier */ @@ -874,8 +874,8 @@ TEST(BATCH_NORM, TestIterAll) { kwargs.push_back({ "cudnn_off", "True" }); } for (TShape shape : shapes) { - for (int g1 = 0; g1 < 2U; ++g1) { - for (int g2 = 0; g2 < 2U; ++g2) { + for (int g1 = 0; g1 < 2; ++g1) { + for (int g2 = 0; g2 < 2; ++g2) { for (int type : v2_types) { MSHADOW_REAL_TYPE_SWITCH_EX( type, DType, AccReal, diff --git a/tests/cpp/operator/ndarray_test.cc b/tests/cpp/operator/ndarray_test.cc new file mode 100644 index 000000000000..f2ed30793881 --- /dev/null +++ b/tests/cpp/operator/ndarray_test.cc @@ -0,0 +1,6 @@ +/*! + * Copyright (c) 2017 by Contributors + * \file ndarray_test.cc + * \brief ndarray unit test utility functions + * \author Haibin Lin +*/ diff --git a/tests/cpp/unittest.mk b/tests/cpp/unittest.mk index 808b655e9dba..ec7bb55ec983 100644 --- a/tests/cpp/unittest.mk +++ b/tests/cpp/unittest.mk @@ -47,4 +47,4 @@ testclean: -include build/tests/cpp/*.d -include build/tests/cpp/operator/*.d -include build/tests/cpp/storage/*.d --include build/tests/cpp/engine/*.d \ No newline at end of file +-include build/tests/cpp/engine/*.d diff --git a/tests/nightly/dist_sync_kvstore.py b/tests/nightly/dist_sync_kvstore.py index ebed6c57586d..c30aaed13a7a 100644 --- a/tests/nightly/dist_sync_kvstore.py +++ b/tests/nightly/dist_sync_kvstore.py @@ -11,38 +11,89 @@ def check_diff_to_scalar(A, x): assert(np.sum(np.abs((A - x).asnumpy())) == 0), A.asnumpy() # setup -keys = [3, 5, 7] +keys = ['3', '5', '7'] +rsp_keys = ['9', '11', '13'] + rate = 2 shape = (2, 2) big_shape = (1200, 1200) # big than BIGARRAY_BOUND -kv = mx.kv.create('dist_sync') - -# init kv -kv.init(keys, [mx.nd.ones(shape)] * len(keys)) -kv.init(99, mx.nd.ones(big_shape)) -# init updater on servers -kv.set_optimizer(mx.optimizer.create('test', rate)) +def init_kv(): + kv = mx.kv.create('dist_sync') + # init kv + kv.init(keys, [mx.nd.ones(shape)] * len(keys)) + kv.init('99', mx.nd.ones(big_shape)) + my_rank = kv.rank + nworker = kv.num_workers + # init updater on servers + kv.set_optimizer(mx.optimizer.create('test', rescale_grad=rate)) + return kv, my_rank, nworker -my_rank = kv.rank -nworker = kv.num_workers +def init_kv_rsp(): + kv = mx.kv.create('dist_sync') + # init kv + kv.init(rsp_keys, [mx.nd.ones(shape)._to_rsp()] * len(rsp_keys)) + # kv.init(99, mx.nd.ones(big_shape)) + my_rank = kv.rank + nworker = kv.num_workers + # init updater on servers + kv.set_optimizer(mx.optimizer.create('test', rescale_grad=rate)) + return kv, my_rank, nworker def test_sync_push_pull(): + kv, my_rank, nworker = init_kv() nrepeat = 3 for i in range(nrepeat): - kv.push(3, mx.nd.ones(shape)*(my_rank+1)) - kv.push(99, mx.nd.ones(big_shape)*(my_rank+1)) + kv.push('3', mx.nd.ones(shape)*(my_rank+1)) + kv.push('99', mx.nd.ones(big_shape)*(my_rank+1)) num = (nworker + 1 ) * nworker * rate / 2 * nrepeat + 1 val = mx.nd.zeros(shape) - kv.pull(3, out = val) + kv.pull('3', out = val) check_diff_to_scalar(val, num) - # print val.asnumpy() val2 = mx.nd.zeros(big_shape) - kv.pull(99, out = val2) + kv.pull('99', out = val2) check_diff_to_scalar(val2, num) + print('done') + +def test_sync_push_pull_row_sparse(): + kv, my_rank, nworker = init_kv_rsp() + nrepeat = 2 + + v = mx.nd.zeros(shape) + my_row = my_rank % shape[0] + for col in range(shape[1]): + v[my_row][col] = my_rank + 1 + + for i in range(nrepeat): + kv.push('9', v._to_rsp()) + # kv.push(99, mx.nd.ones(big_shape)*(my_rank+1)) + + # pull a subset of rows this worker is interested in + val = v.copyto(mx.cpu())._to_rsp() + kv.pull('9', out = val) + + expected = mx.nd.zeros(shape) + # initial value + for col in range(shape[1]): + expected[my_row][col] = 1 + # apply updates from workers + for rank in range(nworker): + row = rank % shape[0] + if row != my_row: + continue + for col in range(shape[1]): + expected[my_row][col] += (rank + 1) * rate * nrepeat + #print("expect ", expected.asnumpy()) + + check_diff_to_scalar(val, expected) + # print('done') + #val2 = mx.nd.zeros(big_shape) + #kv.pull(99, out = val2) + #check_diff_to_scalar(val2, num) if __name__ == "__main__": test_sync_push_pull() + test_sync_push_pull_row_sparse() diff --git a/tests/python/gpu/test_operator_gpu.py b/tests/python/gpu/test_operator_gpu.py index 487197f2ad7e..e37e665a3ef6 100644 --- a/tests/python/gpu/test_operator_gpu.py +++ b/tests/python/gpu/test_operator_gpu.py @@ -14,6 +14,7 @@ from test_nn import * #from test_rnn import * from test_gluon_rnn import * +from test_sparse_operator import test_sparse_dot set_default_context(mx.gpu(0)) del test_support_vector_machine_l1_svm diff --git a/tests/python/unittest/test_infer_shape.py b/tests/python/unittest/test_infer_shape.py index 35598bc55be8..ceb965b43a72 100644 --- a/tests/python/unittest/test_infer_shape.py +++ b/tests/python/unittest/test_infer_shape.py @@ -112,6 +112,24 @@ def test_incomplete_infer_concat(): assert arg_shapes['b'] == (2, 5) assert arg_shapes['d'] == (2, 15) +def test_fc_infer_type(): + mx_real_t = mx.base.mx_real_t + data = mx.symbol.Variable('data') + out = mx.symbol.FullyConnected(data=data, name='fc1', num_hidden=1000) + + # infer type + data_type = mx_real_t + arg_types, out_types, aux_types = out.infer_type(data=data_type) + arg_type_dict = dict(zip(out.list_arguments(), arg_types)) + assert len(out_types) == 1 + assert out_types[0] == mx_real_t + true_types = { + 'fc1_bias' : mx_real_t, + 'fc1_weight' : mx_real_t } + for k, v in true_types.items(): + assert arg_type_dict[k] == v + + if __name__ == "__main__": test_mlp2_infer_shape() test_mlp2_infer_error() diff --git a/tests/python/unittest/test_io.py b/tests/python/unittest/test_io.py index 18326754c851..b39d2ddf3b54 100644 --- a/tests/python/unittest/test_io.py +++ b/tests/python/unittest/test_io.py @@ -1,5 +1,6 @@ # pylint: skip-file import mxnet as mx +from mxnet.test_utils import * import numpy as np import os, gzip import pickle as pickle @@ -135,6 +136,101 @@ def test_NDArrayIter_h5py(): else: assert(labelcount[i] == 100) +def test_NDArrayIter_csr(): + import scipy.sparse as sp + # creating toy data + num_rows = rnd.randint(5, 15) + num_cols = rnd.randint(1, 20) + batch_size = rnd.randint(1, num_rows) + shape = (num_rows, num_cols) + csr, _ = rand_sparse_ndarray(shape, 'csr') + dns = csr.asnumpy() + + # make iterators + csr_iter = iter(mx.io.NDArrayIter(csr, csr, batch_size)) + begin = 0 + for batch in csr_iter: + expected = np.zeros((batch_size, num_cols)) + end = begin + batch_size + expected[:num_rows - begin] = dns[begin:end] + if end > num_rows: + expected[num_rows - begin:] = dns[0:end - num_rows] + assert_almost_equal(batch.data[0].asnumpy(), expected) + begin += batch_size + +def test_LibSVMIter(): + def get_data(data_dir, data_name, url, data_origin_name): + if not os.path.isdir(data_dir): + os.system("mkdir " + data_dir) + os.chdir(data_dir) + if (not os.path.exists(data_name)): + if sys.version_info[0] >= 3: + from urllib.request import urlretrieve + else: + from urllib import urlretrieve + zippath = os.path.join(data_dir, data_origin_name) + urlretrieve(url, zippath) + import bz2 + bz_file = bz2.BZ2File(data_origin_name, 'rb') + with open(data_name, 'wb') as fout: + try: + content = bz_file.read() + fout.write(content) + finally: + bz_file.close() + os.chdir("..") + + def check_libSVMIter_synthetic(): + cwd = os.getcwd() + data_path = os.path.join(cwd, 'data.t') + label_path = os.path.join(cwd, 'label.t') + with open(data_path, 'w') as fout: + fout.write('1.0 0:0.5 2:1.2\n') + fout.write('-2.0\n') + fout.write('-3.0 0:0.6 1:2.4 2:1.2\n') + fout.write('4 2:-1.2\n') + + with open(label_path, 'w') as fout: + fout.write('1.0\n') + fout.write('-2.0 0:0.125\n') + fout.write('-3.0 2:1.2\n') + fout.write('4 1:1.0 2:-1.2\n') + + data_dir = os.path.join(cwd, 'data') + data_train = mx.io.LibSVMIter(data_libsvm=data_path, label_libsvm=label_path, + data_shape=(3, ), label_shape=(3, ), batch_size=3) + + first = mx.nd.array([[ 0.5, 0., 1.2], [ 0., 0., 0.], [ 0.6, 2.4, 1.2]]) + second = mx.nd.array([[ 0., 0., -1.2], [ 0.5, 0., 1.2], [ 0., 0., 0.]]) + i = 0 + for batch in iter(data_train): + expected = first.asnumpy() if i == 0 else second.asnumpy() + assert_almost_equal(data_train.getdata().asnumpy(), expected) + i += 1 + + def check_libSVMIter_news_metadata(): + news_metadata = { + 'name': 'news20.t', + 'origin_name': 'news20.t.bz2', + 'url': "http://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/news20.t.bz2", + 'shape': 62060, + 'num_classes': 20, + } + data_dir = os.path.join(os.getcwd(), 'data') + get_data(data_dir, news_metadata['name'], news_metadata['url'], + news_metadata['origin_name']) + path = os.path.join(data_dir, news_metadata['name']) + data_train = mx.io.LibSVMIter(data_libsvm=path, + data_shape=(news_metadata['shape'], ), + batch_size=512) + iterator = iter(data_train) + for batch in iterator: + # check the range of labels + assert(np.sum(batch.label[0].asnumpy() > 20) == 0) + assert(np.sum(batch.label[0].asnumpy() <= 0) == 0) + + check_libSVMIter_synthetic() + check_libSVMIter_news_metadata() if __name__ == "__main__": test_NDArrayIter() @@ -142,3 +238,5 @@ def test_NDArrayIter_h5py(): test_NDArrayIter_h5py() test_MNISTIter() test_Cifar10Rec() + test_LibSVMIter() + test_NDArrayIter_csr() diff --git a/tests/python/unittest/test_kvstore.py b/tests/python/unittest/test_kvstore.py index 87e5e0027241..e839cb869bc8 100644 --- a/tests/python/unittest/test_kvstore.py +++ b/tests/python/unittest/test_kvstore.py @@ -1,18 +1,19 @@ # pylint: skip-file import mxnet as mx import numpy as np +from mxnet.test_utils import rand_ndarray, assert_almost_equal shape = (4, 4) keys = [5, 7, 11] str_keys = ['b', 'c', 'd'] -def init_kv(): +def init_kv(stype='default'): """init kv """ kv = mx.kv.create() # single - kv.init(3, mx.nd.zeros(shape)) + kv.init(3, mx.nd.zeros(shape=shape, stype=stype)) # list - kv.init(keys, [mx.nd.zeros(shape)] * len(keys)) + kv.init(keys, [mx.nd.zeros(shape=shape, stype=stype)] * len(keys)) return kv def init_kv_with_str(): @@ -28,6 +29,7 @@ def check_diff_to_scalar(A, x): """ assert A == x""" assert(np.sum(np.abs((A - x).asnumpy())) == 0) + def test_single_kv_pair(): """single key-value pair push & pull""" def check_single_kv_pair(kv, key): @@ -93,10 +95,49 @@ def check_aggregator(kv, key, key_list): check_aggregator(init_kv_with_str(), 'a', str_keys) +def test_sparse_aggregator(): + """aggregate sparse ndarray on muliple devices""" + + stype = 'row_sparse' + kv = init_kv(stype) + + # devices + num_devs = 4 + devs = [mx.Context('cpu', i) for i in range(num_devs)] + + # single + vals = [rand_ndarray(shape, stype).copyto(devs[i]) for i in range(num_devs)] + expected_sum = np.zeros(shape) + for v in vals: + expected_sum += v.asnumpy() + + kv.push(3, vals) + kv.pull(3, out = vals) + result_sum = np.zeros(shape) + for v in vals: + result_sum += v.asnumpy() + assert_almost_equal(result_sum, expected_sum * num_devs) + + # list + vals = [[rand_ndarray(shape, stype).copyto(devs[i]) for i in range(num_devs)]] * len(keys) + expected_sum = np.zeros(shape) + for v in vals[0]: + expected_sum += v.asnumpy() + + kv.push(keys, vals) + kv.pull(keys, out = vals) + for vv in vals: + result_sum = np.zeros(shape) + for v in vv: + result_sum += v.asnumpy() + assert_almost_equal(result_sum, expected_sum * num_devs) + + def updater(key, recv, local): """use updater: +=""" local += recv + def test_updater(dev = 'cpu'): """updater""" @@ -135,7 +176,6 @@ def check_updater(kv, key, key_list): str_kv._set_updater(updater) check_updater(str_kv, 'a', str_keys) - def test_get_type(): kvtype = 'local_allreduce_cpu' kv = mx.kv.create(kvtype) @@ -146,5 +186,6 @@ def test_get_type(): test_get_type() test_single_kv_pair() test_list_kv_pair() + test_sparse_aggregator() test_aggregator() test_updater() diff --git a/tests/python/unittest/test_module.py b/tests/python/unittest/test_module.py index 766995dd2ac9..ad0c2083117c 100644 --- a/tests/python/unittest/test_module.py +++ b/tests/python/unittest/test_module.py @@ -1,11 +1,13 @@ -import mxnet as mx import mxnet.ndarray as nd +from mxnet.test_utils import * import numpy as np from functools import reduce from mxnet.module.executor_group import DataParallelExecutorGroup from common import assertRaises from collections import namedtuple +import numpy.random as rnd + def test_module_dtype(): dtype = np.float16 @@ -328,7 +330,6 @@ def mean_abs(x): break assert(mon_result_counts == [2, 2, 1, 6, 6, 4]) - def test_executor_group(): def get_rnn_sym(num_layers, num_words, num_hidden, num_embed, seq_len): stack = mx.rnn.SequentialRNNCell() @@ -440,6 +441,96 @@ def test_shared_exec_group(exec_grp_shared, exec_grp_created, shared_arg_names=N test_shared_exec_group(exec_grp_shared=exec_group1, exec_grp_created=exec_group2, shared_arg_names=shared_arg_names, extra_args=extra_args) +def test_module_fm(): + mx.random.seed(11) + rnd.seed(11) + def fm_model(k, feature_dim): + norm = mx.initializer.Normal(sigma=0.01) + x = mx.symbol.Variable("data", stype='csr') + v = mx.symbol.Variable("v", shape=(feature_dim, k), init=norm, stype='row_sparse') + + w1_weight = mx.symbol.var('w1_weight', shape=(feature_dim, 1), init=norm, stype='row_sparse') + w1 = mx.symbol.dot(x, w1_weight) + + v_s = mx.symbol.sum(data=mx.symbol.square(data=v), axis=1) + x_s = mx.symbol.square(data=x) + bd = 0.5 * mx.symbol.negative(data=mx.symbol.broadcast_mul(x_s, v_s)) + + w2 = mx.symbol.dot(x, v) + w2_squared = 0.5 * mx.symbol.square(data=w2) + + w_all = mx.symbol.Concat(w1, w2_squared, bd, dim=1) + model = mx.symbol.sum(data=w_all, axis=1, keepdims=True) + y = mx.symbol.Variable("out_label") + model = mx.symbol.LinearRegressionOutput(data=model, label=y, name="out") + return model + + # model + ctx = default_context() + k = 5 + feature_dim = 20 + model = fm_model(k, feature_dim) + + # data iter + num_batches = 8 + batch_size = 25 + num_samples = batch_size * num_batches + import scipy.sparse as sp + # generate some random scipy csr data + csr_sp = sp.rand(num_samples, feature_dim, density=0.5, format='csr') + csr_nd = mx.nd.csr(csr_sp.data, csr_sp.indptr, csr_sp.indices, + (num_samples, feature_dim)) + label = mx.nd.ones((num_samples,1)) + # the alternative is to use LibSVMIter + train_iter = mx.io.NDArrayIter(data=csr_nd, + label={'out_label':label}, + batch_size=batch_size) + # create module + mod = mx.mod.Module(symbol=model, data_names=['data'], label_names=['out_label']) + # allocate memory by given the input data and lable shapes + mod.bind(data_shapes=train_iter.provide_data, label_shapes=train_iter.provide_label) + # initialize parameters by uniform random numbers + mod.init_params(initializer=mx.init.Uniform(scale=.1)) + # use Sparse SGD with learning rate 0.1 to train + mod.init_optimizer(optimizer='sgd') + # use accuracy as the metric + metric = mx.metric.create('MSE') + # train 10 epoch + for epoch in range(10): + train_iter.reset() + metric.reset() + for batch in train_iter: + mod.forward(batch, is_train=True) # compute predictions + mod.update_metric(metric, batch.label) # accumulate prediction accuracy + mod.backward() # compute gradients + mod.update() # update parameters + # print('Epoch %d, Training %s' % (epoch, metric.get())) + assert(metric.get()[1] < 0.2) + +def test_module_initializer(): + def regression_model(m): + x = mx.symbol.var("data", stype='csr') + v = mx.symbol.var("v", shape=(m, 1), init=mx.init.Uniform(scale=.1), + stype='row_sparse') + model = mx.symbol.dot(lhs=x, rhs=v) + y = mx.symbol.Variable("label") + model = mx.symbol.LinearRegressionOutput(data=model, label=y, name="out") + return model + + n, m = 128, 100 + model = regression_model(m) + + data = mx.nd.zeros(shape=(n, m), stype='csr') + label = mx.nd.zeros((n, 1)) + iterator = mx.io.NDArrayIter(data=data, label={'label':label}, batch_size=n) + + # create module + mod = mx.mod.Module(symbol=model, data_names=['data'], label_names=['label']) + mod.bind(data_shapes=iterator.provide_data, label_shapes=iterator.provide_label) + mod.init_params() + v = mod._arg_params['v'] + assert(v.stype == 'row_sparse') + assert(np.sum(v.asnumpy()) != 0) def test_forward_reshape(): num_class=10 diff --git a/tests/python/unittest/test_multi_device_exec.py b/tests/python/unittest/test_multi_device_exec.py index 8956c4edebac..9823036867d6 100644 --- a/tests/python/unittest/test_multi_device_exec.py +++ b/tests/python/unittest/test_multi_device_exec.py @@ -1,4 +1,5 @@ import os +import numpy as np import mxnet as mx def test_ctx_group(): @@ -32,5 +33,35 @@ def test_ctx_group(): else: assert arr.context == group2ctx['stage2'] +def check_ctx_group_sparse(lhs_stype, rhs_stype): + with mx.AttrScope(ctx_group='stage1'): + lhs = mx.symbol.Variable('lhs', stype=lhs_stype) + rhs = mx.symbol.Variable('rhs', stype=rhs_stype) + plus = mx.symbol.elemwise_add(lhs, rhs, name='plus') + + set_stage1 = set(plus.list_arguments()) + with mx.AttrScope(ctx_group='stage2'): + softmax = mx.symbol.SoftmaxOutput(data = plus, name = 'softmax') + + set_stage2 = set(softmax.list_arguments()) - set_stage1 + + group2ctx = { + 'stage1' : mx.cpu(1), + 'stage2' : mx.cpu(2) + } + texec = softmax.simple_bind(mx.cpu(0), group2ctx=group2ctx, lhs=(1,200), rhs=(1,200)) + + for arr, name in zip(texec.arg_arrays, softmax.list_arguments()): + if name in set_stage1: + assert arr.context == group2ctx['stage1'] + else: + assert arr.context == group2ctx['stage2'] + +def test_ctx_group_sparse(): + check_ctx_group_sparse('default', 'default') + check_ctx_group_sparse('default', 'row_sparse') + check_ctx_group_sparse('row_sparse', 'row_sparse') + if __name__ == '__main__': test_ctx_group() + test_ctx_group_sparse() diff --git a/tests/python/unittest/test_ndarray.py b/tests/python/unittest/test_ndarray.py index 8b7f8d6d7bf3..c0ed1aef2ba3 100644 --- a/tests/python/unittest/test_ndarray.py +++ b/tests/python/unittest/test_ndarray.py @@ -338,6 +338,7 @@ def test_dot(): assert_almost_equal(c, C.asnumpy()) + def test_reduce(): sample_num = 200 def test_reduce_inner(numpy_reduce_func, nd_reduce_func, multi_axes): diff --git a/tests/python/unittest/test_optimizer.py b/tests/python/unittest/test_optimizer.py index cf7b82eaaa88..2b66eed4eaf2 100644 --- a/tests/python/unittest/test_optimizer.py +++ b/tests/python/unittest/test_optimizer.py @@ -30,26 +30,45 @@ def test_lr_wd_mult(): assert not mx.test_utils.almost_equal(args1['fc2_weight'], args2['fc2_weight'], 1e-1) -def compare_optimizer(opt1, opt2, shape, dtype): - w1 = mx.random.uniform(shape=shape, ctx=default_context(), dtype=dtype) - g1 = mx.random.uniform(shape=shape, ctx=default_context(), dtype=dtype) - - w2 = w1.copyto(default_context()) - g2 = g1.copyto(default_context()) +def compare_optimizer(opt1, opt2, shape, dtype, w_stype='default', g_stype='default'): + if w_stype == 'default': + w2 = mx.random.uniform(shape=shape, ctx=default_context(), dtype=dtype) + w1 = w2.copyto(default_context()) + elif w_stype == 'row_sparse': + w2 = rand_ndarray(shape, w_stype, density=1) + w1 = w2.copyto(default_context()).todense() + else: + raise Exception("type not supported yet") + if g_stype == 'default': + g2 = mx.random.uniform(shape=shape, ctx=default_context(), dtype=dtype) + g1 = g2.copyto(default_context()) + elif g_stype == 'row_sparse': + g2 = rand_ndarray(shape, g_stype) + g1 = g2.copyto(default_context()).todense() + else: + raise Exception("type not supported yet") state1 = opt1.create_state(0, w1) state2 = opt2.create_state(0, w2) if state1 is not None and state2 is not None: for s1, s2, in zip(state1, state2): - if s1 is not None or s2 is not None: assert(same(s1.asnumpy(), s2.asnumpy())) + if isinstance(state1, tuple): + for s1, s2, in zip(state1, state2): + if s1 is not None or s2 is not None: + assert(same(s1.asnumpy(), s2.asnumpy())) + else: + assert_almost_equal(state1.asnumpy(), state2.asnumpy()) opt1.update(0, w1, g1, state1) opt2.update(0, w2, g2, state2) if state1 is not None and state2 is not None: - for s1, s2, in zip(state1, state2): - if s1 is not None or s2 is not None: - assert_almost_equal(s1.asnumpy(), s2.asnumpy(), rtol=1e-4, atol=1e-5) + if isinstance(state1, tuple): + for s1, s2, in zip(state1, state2): + if s1 is not None or s2 is not None: + assert_almost_equal(s1.asnumpy(), s2.asnumpy(), rtol=1e-4, atol=1e-5) + else: + assert_almost_equal(state1.asnumpy(), state2.asnumpy()) assert_almost_equal(w1.asnumpy(), w2.asnumpy(), rtol=1e-4, atol=1e-5) # SGD @@ -170,6 +189,98 @@ def test_sgd(): continue compare_optimizer(opt1(**kwarg), opt2(**kwarg), shape, dtype) +class PySparseSGD(mx.optimizer.Optimizer): + """python reference implemenation of sgd""" + def __init__(self, learning_rate=0.01, momentum=0.0, **kwargs): + super(PySparseSGD, self).__init__(learning_rate=learning_rate, **kwargs) + self.momentum = momentum + + def create_state(self, index, weight): + """Create additional optimizer state: momentum + + Parameters + ---------- + weight : NDArray + The weight data + + """ + if self.momentum == 0.0: + return None + else: + return mx.nd.zeros(weight.shape, weight.context, dtype=weight.dtype) + + def update(self, index, weight, grad, state): + """Update the parameters. + + Parameters + ---------- + index : int + An unique integer key used to index the parameters + + weight : NDArray + weight ndarray + + grad : NDArray + grad ndarray + + state : NDArray or other objects returned by init_state + The auxiliary state used in optimization. + """ + lr = self._get_lr(index) + wd = self._get_wd(index) + self._update_count(index) + num_rows = weight.shape[0] + if self.momentum == 0.0: + # Update on a per row basis, skip all-zero rows + for row in range(num_rows): + grad_row = grad[row].asnumpy() + all_zeros = mx.test_utils.almost_equal(grad_row, np.zeros_like(grad_row)) + if all_zeros: + continue + if self.clip_gradient is not None: + weight[row] = ((1 - lr*wd)*weight[row] - + lr*mx.nd.clip(grad[row]*self.rescale_grad, + -self.clip_gradient, self.clip_gradient)) + else: + weight[row] = (1 - lr*wd)*weight[row] - lr*self.rescale_grad*grad[row] + else: + mom = state + for row in range(num_rows): + grad_row = grad[row].asnumpy() + all_zeros = mx.test_utils.almost_equal(grad_row, np.zeros_like(grad_row)) + if all_zeros: + continue + if self.clip_gradient is not None: + mom[row] = (self.momentum*mom[row] - lr*wd*weight[row] - + lr*mx.nd.clip(grad[row]*self.rescale_grad, -self.clip_gradient, self.clip_gradient)) + weight[row] += mom[row] + else: + mom[row] = self.momentum*mom[row] - lr*wd*weight[row] - lr*self.rescale_grad*grad[row] + weight[row] += mom[row] + +def test_sparse_sgd(): + mx.random.seed(0) + opt1 = PySparseSGD + opt2 = mx.optimizer.SGD + shape = (3, 4) + kwargs = [{}, + {'momentum': 0.9}, + {'clip_gradient': 0.5}, + {'clip_gradient': 0.4, 'rescale_grad': 0.14}, + {'rescale_grad': 0.8}, + {'clip_gradient': 0.5, 'wd': 0.07}, + {'clip_gradient': 0.4, 'rescale_grad': 0.14, 'wd': 0.03}, + {'rescale_grad': 0.8, 'wd': 0.05}, + {'clip_gradient': 0.5, 'momentum': 0.9}, + {'clip_gradient': 0.4, 'rescale_grad': 0.14, 'momentum': 0.9}, + {'rescale_grad': 0.8, 'momentum': 0.9}, + {'clip_gradient': 0.5, 'wd': 0.07, 'momentum': 0.9}, + {'clip_gradient': 0.4, 'rescale_grad': 0.14, 'wd': 0.03, 'momentum': 0.9}, + {'rescale_grad': 0.8, 'wd': 0.05, 'momentum': 0.9}] + for kwarg in kwargs: + compare_optimizer(opt1(**kwarg), opt2(**kwarg), shape, w_stype='row_sparse', g_stype='row_sparse') + compare_optimizer(opt1(**kwarg), opt2(**kwarg), shape, w_stype='row_sparse', g_stype='default') + # ADAM class PyAdam(mx.optimizer.Optimizer): @@ -394,3 +505,4 @@ def test_rms(): test_adam() test_rms() test_sgd() + test_sparse_sgd() diff --git a/tests/python/unittest/test_sparse_ndarray.py b/tests/python/unittest/test_sparse_ndarray.py new file mode 100644 index 000000000000..1ba219c97aae --- /dev/null +++ b/tests/python/unittest/test_sparse_ndarray.py @@ -0,0 +1,395 @@ +import pickle as pkl + +from mxnet.ndarray import NDArray +from mxnet.test_utils import * +from numpy.testing import assert_allclose +import numpy.random as rnd + +from mxnet.ndarray import RowSparseNDArray, CSRNDArray + + +def assert_fcompex(f, *args, **kwargs): + prev_val = mx.test_utils.set_env_var("MXNET_EXEC_STORAGE_FALLBACK", "0", "1") + f(*args, **kwargs) + mx.test_utils.set_env_var("MXNET_EXEC_STORAGE_FALLBACK", prev_val) + + +def sparse_nd_ones(shape, stype): + return mx.nd.cast_storage(mx.nd.ones(shape), stype=stype) + + +def check_sparse_nd_elemwise_binary(shapes, stypes, f, g): + # generate inputs + nds = [] + for i, storage_type in enumerate(stypes): + if storage_type == 'row_sparse': + nd, _ = rand_sparse_ndarray(shapes[i], storage_type) + elif storage_type == 'default': + nd = mx.nd.array(random_arrays(shapes[i]), dtype = np.float32) + else: + assert(False) + nds.append(nd) + # check result + test = f(nds[0], nds[1]) + assert_almost_equal(test.asnumpy(), g(nds[0].asnumpy(), nds[1].asnumpy())) + + +def test_sparse_nd_elemwise_add(): + num_repeats = 10 + g = lambda x,y: x + y + op = mx.nd.elemwise_add + for i in range(num_repeats): + shape = [rand_shape_2d()] * 2 + assert_fcompex(check_sparse_nd_elemwise_binary, + shape, ['default'] * 2, op, g) + assert_fcompex(check_sparse_nd_elemwise_binary, + shape, ['default', 'row_sparse'], op, g) + assert_fcompex(check_sparse_nd_elemwise_binary, + shape, ['row_sparse', 'row_sparse'], op, g) + + +# Test a operator which doesn't implement FComputeEx +def test_sparse_nd_elementwise_fallback(): + num_repeats = 10 + g = lambda x,y: x + y + op = mx.nd.add_n + for i in range(num_repeats): + shape = [rand_shape_2d()] * 2 + check_sparse_nd_elemwise_binary(shape, ['default'] * 2, op, g) + check_sparse_nd_elemwise_binary(shape, ['default', 'row_sparse'], op, g) + check_sparse_nd_elemwise_binary(shape, ['row_sparse', 'row_sparse'], op, g) + + +def test_sparse_nd_zeros(): + def check_sparse_nd_zeros(stype, shape): + zero = mx.nd.zeros(shape) + sparse_zero = mx.nd.zeros(shape=shape, stype=stype) + assert_almost_equal(sparse_zero.asnumpy(), zero.asnumpy()) + + shape = rand_shape_2d() + check_sparse_nd_zeros('row_sparse', shape) + check_sparse_nd_zeros('csr', shape) + check_sparse_nd_zeros('default', shape) + + +def test_sparse_nd_copy(): + def check_sparse_nd_copy(from_stype, to_stype): + shape = rand_shape_2d() + from_nd = rand_ndarray(shape, from_stype) + # copy to ctx + to_ctx = from_nd.copyto(default_context()) + # copy to stype + to_nd = rand_ndarray(shape, to_stype) + to_nd = from_nd.copyto(to_nd) + assert np.sum(np.abs(from_nd.asnumpy() != to_ctx.asnumpy())) == 0.0 + assert np.sum(np.abs(from_nd.asnumpy() != to_nd.asnumpy())) == 0.0 + + check_sparse_nd_copy('row_sparse', 'row_sparse') + check_sparse_nd_copy('row_sparse', 'default') + check_sparse_nd_copy('default', 'row_sparse') + check_sparse_nd_copy('default', 'csr') + + +def check_sparse_nd_prop_rsp(): + storage_type = 'row_sparse' + shape = rand_shape_2d() + nd, (v, idx) = rand_sparse_ndarray(shape, storage_type) + assert(nd._num_aux == 1) + assert(nd.indices.dtype == np.int64) + assert(nd.stype == 'row_sparse') + assert_almost_equal(nd.indices.asnumpy(), idx) + + +def test_sparse_nd_basic(): + def check_rsp_creation(values, indices, shape): + rsp = mx.nd.row_sparse(values, indices, shape) + dns = mx.nd.zeros(shape) + dns[1] = mx.nd.array(values[0]) + dns[3] = mx.nd.array(values[1]) + indices_np = mx.nd.array(indices, dtype='int64').asnumpy() + assert_almost_equal(rsp.indices.asnumpy(), indices_np) + + def check_csr_creation(shape): + csr, (indptr, indices, values) = rand_sparse_ndarray(shape, 'csr') + assert_almost_equal(csr.indptr.asnumpy(), indptr) + assert_almost_equal(csr.indices.asnumpy(), indices) + assert_almost_equal(csr.data.asnumpy(), values) + + shape = (4,2) + values = np.random.rand(2,2) + indices = np.array([1,3], dtype='int64') + check_rsp_creation(values, indices, shape) + + values = mx.nd.array(np.random.rand(2,2)) + indices = mx.nd.array([1,3], dtype='int64') + check_rsp_creation(values, indices, shape) + + values = [[0.1, 0.2], [0.3, 0.4]] + indices = [1,3] + check_rsp_creation(values, indices, shape) + + check_csr_creation(shape) + check_sparse_nd_prop_rsp() + + +def test_sparse_nd_setitem(): + def check_sparse_nd_setitem(stype, shape, dst): + x = mx.nd.zeros(shape=shape, stype=stype) + x[:] = dst + dst_nd = mx.nd.array(dst) if isinstance(dst, (np.ndarray, np.generic)) else dst + assert same(x.asnumpy(), dst_nd.asnumpy()) + + shape = rand_shape_2d() + for stype in ['row_sparse', 'csr']: + # ndarray assignment + check_sparse_nd_setitem(stype, shape, rand_ndarray(shape, 'default')) + check_sparse_nd_setitem(stype, shape, rand_ndarray(shape, stype)) + # numpy assignment + check_sparse_nd_setitem(stype, shape, np.ones(shape)) + + +def test_sparse_nd_slice(): + def check_sparse_nd_csr_slice(shape): + storage_type = 'csr' + A, _ = rand_sparse_ndarray(shape, storage_type) + A2 = A.asnumpy() + start = rnd.randint(0, shape[0] - 1) + end = rnd.randint(start + 1, shape[0]) + assert same(A[start:end].asnumpy(), A2[start:end]) + assert same(A[start:].asnumpy(), A2[start:]) + assert same(A[:end].asnumpy(), A2[:end]) + + shape = (rnd.randint(2, 10), rnd.randint(1, 10)) + check_sparse_nd_csr_slice(shape) + + +def test_sparse_nd_equal(): + for stype in ['row_sparse', 'csr']: + shape = rand_shape_2d() + x = mx.nd.zeros(shape=shape, stype=stype) + y = sparse_nd_ones(shape, stype) + z = x == y + assert (z.asnumpy() == np.zeros(shape)).all() + z = 0 == x + assert (z.asnumpy() == np.ones(shape)).all() + + +def test_sparse_nd_not_equal(): + for stype in ['row_sparse', 'csr']: + shape = rand_shape_2d() + x = mx.nd.zeros(shape=shape, stype=stype) + y = sparse_nd_ones(shape, stype) + z = x != y + assert (z.asnumpy() == np.ones(shape)).all() + z = 0 != x + assert (z.asnumpy() == np.zeros(shape)).all() + + +def test_sparse_nd_greater(): + for stype in ['row_sparse', 'csr']: + shape = rand_shape_2d() + x = mx.nd.zeros(shape=shape, stype=stype) + y = sparse_nd_ones(shape, stype) + z = x > y + assert (z.asnumpy() == np.zeros(shape)).all() + z = y > 0 + assert (z.asnumpy() == np.ones(shape)).all() + z = 0 > y + assert (z.asnumpy() == np.zeros(shape)).all() + + +def test_sparse_nd_greater_equal(): + for stype in ['row_sparse', 'csr']: + shape = rand_shape_2d() + x = mx.nd.zeros(shape=shape, stype=stype) + y = sparse_nd_ones(shape, stype) + z = x >= y + assert (z.asnumpy() == np.zeros(shape)).all() + z = y >= 0 + assert (z.asnumpy() == np.ones(shape)).all() + z = 0 >= y + assert (z.asnumpy() == np.zeros(shape)).all() + z = y >= 1 + assert (z.asnumpy() == np.ones(shape)).all() + + +def test_sparse_nd_lesser(): + for stype in ['row_sparse', 'csr']: + shape = rand_shape_2d() + x = mx.nd.zeros(shape=shape, stype=stype) + y = sparse_nd_ones(shape, stype) + z = y < x + assert (z.asnumpy() == np.zeros(shape)).all() + z = 0 < y + assert (z.asnumpy() == np.ones(shape)).all() + z = y < 0 + assert (z.asnumpy() == np.zeros(shape)).all() + + +def test_sparse_nd_lesser_equal(): + for stype in ['row_sparse', 'csr']: + shape = rand_shape_2d() + x = mx.nd.zeros(shape=shape, stype=stype) + y = sparse_nd_ones(shape, stype) + z = y <= x + assert (z.asnumpy() == np.zeros(shape)).all() + z = 0 <= y + assert (z.asnumpy() == np.ones(shape)).all() + z = y <= 0 + assert (z.asnumpy() == np.zeros(shape)).all() + z = 1 <= y + assert (z.asnumpy() == np.ones(shape)).all() + + +def test_sparse_nd_binary(): + N = 100 + def check_binary(fn): + for _ in range(N): + ndim = 2 + oshape = np.random.randint(1, 6, size=(ndim,)) + bdim = 2 + lshape = list(oshape) + rshape = list(oshape[ndim-bdim:]) + for i in range(bdim): + sep = np.random.uniform(0, 1) + if sep < 0.33: + lshape[ndim-i-1] = 1 + elif sep < 0.66: + rshape[bdim-i-1] = 1 + lhs = np.random.uniform(0, 1, size=lshape) + rhs = np.random.uniform(0, 1, size=rshape) + lhs_nd_csr = mx.nd.array(lhs)._to_csr() + rhs_nd_csr = mx.nd.array(rhs)._to_csr() + lhs_nd_rsp = mx.nd.array(lhs)._to_rsp() + rhs_nd_rsp = mx.nd.array(rhs)._to_rsp() + for lhs_nd, rhs_nd in [(lhs_nd_csr, rhs_nd_csr), (lhs_nd_rsp, rhs_nd_rsp)]: + assert_allclose(fn(lhs, rhs), + fn(lhs_nd, rhs_nd).asnumpy(), + rtol=1e-4, atol=1e-4) + + check_binary(lambda x, y: x + y) + check_binary(lambda x, y: x - y) + check_binary(lambda x, y: x * y) + check_binary(lambda x, y: x / y) + check_binary(lambda x, y: x ** y) + check_binary(lambda x, y: x > y) + check_binary(lambda x, y: x < y) + check_binary(lambda x, y: x >= y) + check_binary(lambda x, y: x <= y) + check_binary(lambda x, y: x == y) + + +def test_sparse_nd_binary_rop(): + N = 100 + def check(fn): + for _ in range(N): + ndim = 2 + shape = np.random.randint(1, 6, size=(ndim,)) + npy_nd = np.random.normal(0, 1, size=shape) + csr_nd = mx.nd.array(npy_nd)._to_csr() + rsp_nd = mx.nd.array(npy_nd)._to_rsp() + for sparse_nd in [csr_nd, rsp_nd]: + assert_allclose( + fn(npy_nd), + fn(sparse_nd).asnumpy(), + rtol=1e-4, + atol=1e-4 + ) + check(lambda x: 1 + x) + check(lambda x: 1 - x) + check(lambda x: 1 * x) + check(lambda x: 1 / x) + check(lambda x: 2 ** x) + check(lambda x: 1 > x) + check(lambda x: 0.5 > x) + check(lambda x: 0.5 < x) + check(lambda x: 0.5 >= x) + check(lambda x: 0.5 <= x) + check(lambda x: 0.5 == x) + + +def test_sparse_nd_negate(): + npy = np.random.uniform(-10, 10, rand_shape_2d()) + arr_csr = mx.nd.array(npy)._to_csr() + arr_rsp = mx.nd.array(npy)._to_rsp() + for arr in [arr_csr, arr_rsp]: + assert_almost_equal(npy, arr.asnumpy()) + assert_almost_equal(-npy, (-arr).asnumpy()) + + # a final check to make sure the negation (-) is not implemented + # as inplace operation, so the contents of arr does not change after + # we compute (-arr) + assert_almost_equal(npy, arr.asnumpy()) + + +def test_sparse_nd_output_fallback(): + shape = (10, 10) + out = mx.nd.zeros(shape=shape, stype='row_sparse') + mx.nd.random_normal(shape=shape, out=out) + assert(np.sum(out.asnumpy()) != 0) + + +def test_sparse_nd_astype(): + stypes = ['row_sparse', 'csr'] + for stype in stypes: + x = mx.nd.zeros(shape=rand_shape_2d(), stype=stype, dtype='float32') + y = x.astype('int32') + assert(y.dtype == np.int32), y.dtype + + +def test_sparse_ndarray_pickle(): + np.random.seed(0) + repeat = 10 + dim0 = 40 + dim1 = 40 + stypes = ['row_sparse', 'csr'] + densities = [0, 0.01, 0.1, 0.2, 0.5] + stype_dict = {'row_sparse': RowSparseNDArray, 'csr': CSRNDArray} + for _ in range(repeat): + shape = rand_shape_2d(dim0, dim1) + for stype in stypes: + for density in densities: + a, _ = rand_sparse_ndarray(shape, stype, density) + assert isinstance(a, stype_dict[stype]) + data = pkl.dumps(a) + b = pkl.loads(data) + assert isinstance(b, stype_dict[stype]) + assert same(a.asnumpy(), b.asnumpy()) + + +def test_sparse_ndarray_save_load(): + np.random.seed(0) + repeat = 1 + stypes = ['default', 'row_sparse', 'csr'] + stype_dict = {'default': NDArray, 'row_sparse': RowSparseNDArray, 'csr': CSRNDArray} + num_data = 20 + densities = [0, 0.01, 0.1, 0.2, 0.5] + fname = 'tmp_list.bin' + for _ in range(repeat): + data_list1 = [] + for i in range(num_data): + stype = stypes[np.random.randint(0, len(stypes))] + shape = rand_shape_2d(dim0=40, dim1=40) + density = densities[np.random.randint(0, len(densities))] + data_list1.append(rand_ndarray(shape, stype, density)) + assert isinstance(data_list1[-1], stype_dict[stype]) + mx.nd.save(fname, data_list1) + + data_list2 = mx.nd.load(fname) + assert len(data_list1) == len(data_list2) + for x, y in zip(data_list1, data_list2): + assert same(x.asnumpy(), y.asnumpy()) + + data_map1 = {'ndarray xx %s' % i: x for i, x in enumerate(data_list1)} + mx.nd.save(fname, data_map1) + data_map2 = mx.nd.load(fname) + assert len(data_map1) == len(data_map2) + for k, x in data_map1.items(): + y = data_map2[k] + assert same(x.asnumpy(), y.asnumpy()) + os.remove(fname) + + +if __name__ == '__main__': + import nose + nose.runmodule() diff --git a/tests/python/unittest/test_sparse_operator.py b/tests/python/unittest/test_sparse_operator.py new file mode 100644 index 000000000000..d0064a9265f8 --- /dev/null +++ b/tests/python/unittest/test_sparse_operator.py @@ -0,0 +1,214 @@ +from mxnet.test_utils import * + + +def check_elemwise_add_ex(lhs_stype, rhs_stype, shape, lhs_grad_stype=None, rhs_grad_stype=None): + lhs = mx.symbol.Variable('lhs', stype=lhs_stype) + rhs = mx.symbol.Variable('rhs', stype=rhs_stype) + lhs_nd = rand_ndarray(shape, lhs_stype) + rhs_nd = rand_ndarray(shape, rhs_stype) + lhs_np = lhs_nd.asnumpy() + rhs_np = rhs_nd.asnumpy() + + out_np = lhs_np + rhs_np + test = mx.symbol.elemwise_add(lhs, rhs) + location = {'lhs': lhs_nd, 'rhs': rhs_nd} + check_symbolic_forward(test, location, [out_np]) + check_numeric_gradient(test, location) + grad_stypes = {} + if lhs_grad_stype is not None and lhs_grad_stype != 'default': + grad_stypes['lhs'] = lhs_grad_stype + if rhs_grad_stype is not None and rhs_grad_stype != 'default': + grad_stypes['rhs'] = rhs_grad_stype + check_symbolic_backward(test, location, [out_np], [out_np, out_np], + grad_stypes=grad_stypes) + + +def test_elemwise_add_ex(): + shape = rand_shape_2d() + check_elemwise_add_ex('default', 'default', shape) + check_elemwise_add_ex('default', 'row_sparse', shape) + check_elemwise_add_ex('row_sparse', 'default', shape) + check_elemwise_add_ex('row_sparse', 'row_sparse', shape, + lhs_grad_stype='row_sparse', rhs_grad_stype='row_sparse') + + +# TODO(haibin) randomize this test +def test_elemwise_add_ex_multiple_stages(): + # prep data + shape = (4, 2) + ds_np = np.array([[1, 2], [3, 4], [5, 6], [7, 8]]) + sp_np1 = np.array([[5, 10], [0, 0], [0, 0], [0, 0]]) + sp_np2 = np.array([[0, 0], [5, 10], [0, 0], [0, 0]]) + + val1 = mx.nd.array([[5, 10]]); + val2 = mx.nd.array([[5, 10]]); + idx1 = mx.nd.array([0], dtype=np.int64); + idx2 = mx.nd.array([1], dtype=np.int64); + sp_nd1 = mx.nd.row_sparse(val1, idx1, shape) + sp_nd2 = mx.nd.row_sparse(val2, idx2, shape) + ds_nd = mx.nd.array(ds_np) + + # sparse + sparse = sparse + sp_data1 = mx.symbol.Variable('sp_data1', stype='row_sparse') + sp_data2 = mx.symbol.Variable('sp_data2', stype='row_sparse') + ds_data = mx.symbol.Variable('ds_data') + plus = mx.symbol.elemwise_add(sp_data1, sp_data2, name='plus') + # sparse + dense = dense + test = mx.symbol.elemwise_add(plus, ds_data) + check_symbolic_forward(test, {'sp_data1': sp_nd1, 'sp_data2': sp_nd2, + 'ds_data': ds_nd}, [sp_np1 + sp_np2 + ds_np]) + + arr_grads = [mx.nd.zeros(shape) for i in range(3)] + exec_test = test.bind(default_context(), args={'sp_data1': sp_nd1, 'sp_data2': sp_nd2, + 'ds_data': ds_nd}, args_grad=arr_grads) + exec_test.forward(is_train=True) + assert_almost_equal(exec_test.outputs[0].asnumpy(), sp_np1 + sp_np2 + ds_np) + exec_test.backward(out_grads=exec_test.outputs) + assert_almost_equal(arr_grads[0].asnumpy(), arr_grads[1].asnumpy()) + +# TODO(haibin) also add test for backward pass. +def test_cast_storage_ex(): + def test_rsp_to_dns(shape): + rsp, (data, row_idx) = rand_sparse_ndarray(shape, 'row_sparse') + dns_out = mx.nd.cast_storage(rsp, stype='default') + dns_expected = np.zeros(shape, dtype=default_dtype()) + if row_idx is not None: + for k, v in enumerate(row_idx): + dns_expected[v, :] = data[k] + assert same(dns_out.asnumpy(), dns_expected) + + def test_dns_to_rsp(shape): + dns_in = rand_ndarray(shape, 'default') + rsp_out = mx.nd.cast_storage(mx.nd.array(dns_in, dtype=default_dtype()), stype='row_sparse') + ret = mx.nd.cast_storage(rsp_out, stype='default') + assert same(ret.asnumpy(), dns_in.asnumpy()) + + def test_csr_to_dns(shape): + csr, (indptr, indices, values) = rand_sparse_ndarray(shape, 'csr') + mx_dns = csr.todense() + np_dns = sp.csr_matrix((values, indices, indptr), shape).todense() + assert_almost_equal(mx_dns.asnumpy(), np_dns) + + def test_dns_to_csr(dns_in): + dns_in = np.array(dns_in) + csr_out = mx.nd.cast_storage(mx.nd.array(dns_in, dtype=default_dtype()), stype='csr') + ret = mx.nd.cast_storage(csr_out, stype='default') + assert same(ret.asnumpy(), dns_in) + + shape = rand_shape_2d() + test_rsp_to_dns(shape) + test_dns_to_rsp(shape) + test_csr_to_dns((4, 4)) + test_dns_to_csr([[0, 1, 0], [0, 2, 0], [3, 0, 0], [0, 0, 4], [5, 6, 0], [0, 0, 7]]) + + +def test_sparse_dot(): + def test_dot_csr(lhs_shape, rhs_shape, rhs_stype, trans_lhs, density=1): + lhs_nd = rand_ndarray(lhs_shape, 'csr', 1) + lhs_dns = lhs_nd.todense() + rhs_nd = rand_ndarray(rhs_shape, rhs_stype, density=density) + rhs_dns = rhs_nd if rhs_stype == 'default' else rhs_nd.todense() + out = mx.nd.dot(lhs_nd, rhs_dns, transpose_a=trans_lhs) + if trans_lhs and default_context().device_type is 'cpu': + assert out.stype == 'row_sparse' + else: + assert out.stype == 'default' + out_expected = mx.nd.dot(lhs_dns, rhs_dns, transpose_a=trans_lhs) + out_np = out_expected.asnumpy() + backward_trans = not trans_lhs + rhs_backward_grad = mx.nd.dot(lhs_dns, out_expected, transpose_a=backward_trans).asnumpy() + assert_almost_equal(out.asnumpy(), out_np, rtol=1e-4, atol=1e-5) + + # test symbolic forward + lhs = mx.symbol.Variable('lhs', stype='csr') + rhs = mx.symbol.Variable('rhs', stype=rhs_stype) + test = mx.symbol.dot(lhs, rhs, transpose_a=trans_lhs) + location = {'lhs': lhs_nd, 'rhs': rhs_nd} + expected = {'rhs': rhs_backward_grad} + check_symbolic_forward(test, location, [out_np], rtol=1e-3, atol=1e-4) + # test symbolic backward + check_symbolic_backward(test, location, [out_np], expected, + grad_req={'lhs': 'null', 'rhs': 'write'}, + rtol=1e-3, atol=1e-4) + + lhs_shape = rand_shape_2d(50, 200) + test_dot_csr(lhs_shape, (lhs_shape[1], 1), 'default', False) + test_dot_csr(lhs_shape, (lhs_shape[0], 1), 'default', True) + test_dot_csr(lhs_shape, (lhs_shape[1], rnd.randint(1, 10)), 'default', False) + test_dot_csr(lhs_shape, (lhs_shape[0], rnd.randint(1, 10)), 'default', True) + test_dot_csr(lhs_shape, (lhs_shape[1], rnd.randint(1, 10)), 'row_sparse', False) + test_dot_csr(lhs_shape, (lhs_shape[0], rnd.randint(1, 10)), 'row_sparse', True) + test_dot_csr(lhs_shape, (lhs_shape[1], rnd.randint(1, 10)), 'row_sparse', False, 0.05) + test_dot_csr(lhs_shape, (lhs_shape[0], rnd.randint(1, 10)), 'row_sparse', True, 0.05) + + +def test_sparse_embedding(): + in_dim = 10 + out_dim = 4 + batch = 24 + + data = mx.sym.Variable("data", stype='csr') + embed = mx.sym.SparseEmbedding(data=data, input_dim=in_dim, output_dim=out_dim, name="embed") + exe_test = embed.simple_bind(default_context(), grad_req={'data': 'null', 'embed_weight': 'write'}, + data=(batch, in_dim)) + + arg_map = dict(zip(embed.list_arguments(), exe_test.arg_arrays)) + grad_map = dict(zip(embed.list_arguments(), exe_test.grad_arrays)) + np_data = np.random.randint(low=0, high=in_dim, size=batch) + np_weight = np.random.uniform(-0.01, 0.01, arg_map["embed_weight"].shape) + np_onehot = np.zeros((batch, in_dim)) + np_onehot[np.arange(batch), np_data] = 1.0 + nd_onehot = mx.nd.array(np_onehot)._to_csr() + # forward + arg_map["data"][:] = nd_onehot + arg_map["embed_weight"][:] = np_weight + exe_test.forward(is_train=True) + assert_almost_equal(exe_test.outputs[0].asnumpy(), np.dot(np_onehot, np_weight)) + # backward + np_grad = np.random.uniform(-1, 1, exe_test.outputs[0].shape) + grad = mx.nd.zeros(np_grad.shape) + grad[:] = np_grad + exe_test.backward([grad]) + assert_almost_equal(grad_map["embed_weight"].asnumpy(), np.dot(np_onehot.T, np_grad), atol=1e-5) + + +def test_sparse_slice(): + def check_csr_slice(shape, slice_input): + storage_type = 'csr' + B, _ = rand_sparse_ndarray(shape, storage_type) + np = B.asnumpy() + begin = rnd.randint(0, B.shape[0] - 1) + end = rnd.randint(begin + 1, B.shape[0]) + nd_slice = mx.nd.crop(B, begin=begin, end=end) + assert same(nd_slice.asnumpy(), np[begin:end]), (nd_slice.asnumpy(), np[begin:end]) + + shape = (rnd.randint(7, 15), rnd.randint(1, 10)) + check_csr_slice(shape, True) + check_csr_slice(shape, False) + + +def test_sparse_retain(): + for _ in range(10): + shape = rand_shape_2d() + num_rows = shape[0] + rsp, _ = rand_sparse_ndarray(shape=shape, stype='row_sparse', density=0.5) + length = np.random.randint(1, num_rows + 1) + idx = random_sample(list(range(0, num_rows)), length) + idx.sort() + dns = rsp.asnumpy() + tensor_retained_expected = np.zeros(shape) + for i in idx: + tensor_retained_expected[i][:] = dns[i] + indices = mx.nd.array(idx) + rsp_retained = mx.nd.sparse_retain(rsp, indices=indices) + assert same(tensor_retained_expected, rsp_retained.asnumpy()) + + # check numeric gradient + data = mx.symbol.Variable('data') + idx = mx.symbol.Variable('indices') + sym = mx.sym.sparse_retain(data=data, indices=idx) + check_numeric_gradient(sym, [rsp, indices], grad_nodes=['data'], grad_stype_dict={'data': 'row_sparse'}) + +if __name__ == '__main__': + import nose + nose.runmodule() diff --git a/tests/travis/run_test.sh b/tests/travis/run_test.sh index cff4196b6043..6b8f778e29ab 100755 --- a/tests/travis/run_test.sh +++ b/tests/travis/run_test.sh @@ -99,21 +99,21 @@ if [ ${TASK} == "python_test" ]; then mkdir -p ${PWD}/data if [ ${TRAVIS_OS_NAME} == "osx" ]; then - python -m nose tests/python/unittest || exit -1 - python3 -m nose tests/python/unittest || exit -1 + python -m nose -v tests/python/unittest || exit -1 + python3 -m nose -v tests/python/unittest || exit -1 # make cython3 # cython tests # export MXNET_ENFORCE_CYTHON=1 # python3 -m nose tests/python/unittest || exit -1 - python3 -m nose tests/python/train || exit -1 - python -m nose tests/python/doctest || exit -1 - python3 -m nose tests/python/doctest || exit -1 + python3 -m nose -v tests/python/train || exit -1 + python -m nose -v tests/python/doctest || exit -1 + python3 -m nose -v tests/python/doctest || exit -1 else - nosetests tests/python/unittest || exit -1 - nosetests3 tests/python/unittest || exit -1 - nosetests3 tests/python/train || exit -1 - nosetests tests/python/doctest || exit -1 - nosetests3 tests/python/doctest || exit -1 + nosetests -v tests/python/unittest || exit -1 + nosetests3 -v tests/python/unittest || exit -1 + nosetests3 -v tests/python/train || exit -1 + nosetests -v tests/python/doctest || exit -1 + nosetests3 -v tests/python/doctest || exit -1 fi exit 0 fi diff --git a/tests/travis/setup.sh b/tests/travis/setup.sh index ec071009bda5..7c9d137b8269 100755 --- a/tests/travis/setup.sh +++ b/tests/travis/setup.sh @@ -15,8 +15,8 @@ if [ ${TRAVIS_OS_NAME} == "osx" ]; then brew install ImageMagick brew install swig if [ ${TASK} == "python_test" ]; then - python -m pip install --user nose numpy cython - python3 -m pip install --user nose numpy cython + python -m pip install --user nose numpy cython scipy + python3 -m pip install --user nose numpy cython scipy fi fi