diff --git a/benchmark/python/sparse_op.py b/benchmark/python/sparse_op.py
index 42ec6e5eabaf..0aef3bc3ae31 100644
--- a/benchmark/python/sparse_op.py
+++ b/benchmark/python/sparse_op.py
@@ -1,19 +1,25 @@
-# pylint: skip-file
-import mxnet as mx
+import ctypes
+
 from mxnet.test_utils import *
-import numpy as np
 import scipy.sparse as sp
-import os, gzip
-import pickle as pickle
+import os
 import time
-import sys
+import argparse
+
+from mxnet.base import check_call, _LIB
+
+parser = argparse.ArgumentParser(description="Benchmark sparse operators",
+                                 formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+parser.add_argument('--num-omp-threads', type=int, default=1, help='number of omp threads to set in MXNet')
+args = parser.parse_args()
+
 
 def get_avazu(data_dir):
     if not os.path.isdir(data_dir):
         os.system("mkdir " + data_dir)
     os.chdir(data_dir)
     if (not os.path.exists('avazu-app.t')):
-        import urllib, zipfile
+        import urllib
         zippath = os.path.join(data_dir, "avazu-app.t.bz2")
         url = "https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary/avazu-app.t.bz2"
         urllib.urlretrieve(url, zippath)
@@ -21,6 +27,7 @@ def get_avazu(data_dir):
         os.system("bzip2 -d avazu-app.t.bz2")
     os.chdir("..")
 
+
 def test_dot_real():
     def get_iter(path, data_shape, batch_size):
         data_train = mx.io.LibSVMIter(data_libsvm=path,
@@ -59,22 +66,39 @@ def get_iter(path, data_shape, batch_size):
     cost = end - start
     print(size / cost, cost, num_batch, num_batch / cost)
 
+
 def test_dot_synthetic():
     """benchmark mx.nd.dot(sparse_ndarray, dense_ndarray) with given density.
     `t_sparse` is the time cost of dot(csr, dns), while `t_dense` is the time cost
     of dot(dns, dns), with the same matrix except that it is in default storage type.
     """
+    def measure_cost_forward_baseline(repeat, dot, lhs, rhs):
+        start = time.time()
+        for i in range(repeat):
+            dot(lhs, rhs)
+        end = time.time()
+        diff = end - start
+        return diff / repeat
+
+    def measure_cost_backward_baseline(repeat, dot, transpose, lhs, rhs):
+        start = time.time()
+        for i in range(repeat):
+            dot(transpose(lhs), rhs)
+        end = time.time()
+        diff = end - start
+        return diff / repeat
+
     def measure_cost(repeat, f, *args, **kwargs):
-        # start bench
-        start = time.time()
-        results = []
-        for i in range(repeat):
-            results.append(f(*args, **kwargs))
-        for result in results:
-            result.wait_to_read()
-        end = time.time()
-        diff = end - start
-        return diff / repeat
+        # start bench
+        start = time.time()
+        results = []
+        for i in range(repeat):
+            results.append(f(*args, **kwargs))
+        for result in results:
+            result.wait_to_read()
+        end = time.time()
+        diff = end - start
+        return diff / repeat
 
     def bench_dot_forward(m, k, n, density, ctx, repeat):
         set_default_context(ctx)
@@ -82,6 +106,9 @@ def bench_dot_forward(m, k, n, density, ctx, repeat):
         data_shape = (m, k)
         csr_data = rand_ndarray(data_shape, 'csr', density)
         dns_data = csr_data.to_dense()
+        rhs_dns_np = dns.asnumpy()
+        lhs_csr_sp = sp.csr_matrix(dns_data.asnumpy())  # csr in scipy
+        lhs_dns_np = lhs_csr_sp.todense()
 
         data = [dns_data, csr_data]
         costs = []
@@ -91,8 +118,16 @@ def bench_dot_forward(m, k, n, density, ctx, repeat):
             cost = measure_cost(repeat, mx.nd.dot, d, dns)
             costs.append(cost / repeat)
         ratio = costs[1] / costs[0]
-        fmt = "%0.1f\t\t%s\t%d\t%d\t%d\t%0.6f\t%0.5f\t%0.2f"
-        print(fmt % (density * 100, str(ctx), n, m, k, costs[1], costs[0], ratio))
+
+        costs_baseline = []
+        cost = measure_cost_forward_baseline(repeat, np.dot, lhs_dns_np, rhs_dns_np)
+        costs_baseline.append(cost)
+        cost = measure_cost_forward_baseline(repeat, sp.spmatrix.dot, lhs_csr_sp, rhs_dns_np)
+        costs_baseline.append(cost)
+        ratio_baseline = costs_baseline[1] / costs_baseline[0]
+        fmt = "%0.1f\t\t%s\t%d\t%d\t%d\t%0.6f\t%0.5f\t%0.2f\t\t\t%0.6f\t%0.5f\t\t%0.2f"
+        print(fmt % (density * 100, str(ctx), n, m, k, costs[1], costs[0], ratio,
+                     costs_baseline[1], costs_baseline[0], ratio_baseline))
 
     def bench_dot_backward(m, k, n, density, ctx, repeat):
         set_default_context(ctx)
@@ -100,6 +135,9 @@ def bench_dot_backward(m, k, n, density, ctx, repeat):
         data_shape = (m, k)
         csr_data = rand_ndarray(data_shape, 'csr', density)
         dns_data = csr_data.to_dense()
+        rhs_dns_np = dns.asnumpy()
+        lhs_csr_sp = sp.csr_matrix(dns_data.asnumpy())
+        lhs_dns_np = lhs_csr_sp.todense()
 
         data = [dns_data, csr_data]
         costs = []
@@ -109,15 +147,24 @@ def bench_dot_backward(m, k, n, density, ctx, repeat):
             cost = measure_cost(repeat, mx.nd.dot, d, dns, transpose_a=True)
             costs.append(cost)
         ratio = costs[1] / costs[0]
-        fmt = "%0.1f\t\t%s\t%d\t%d\t%d\t%0.6f\t%0.5f\t%0.2f"
-        print(fmt % (density * 100, str(ctx), n, m, k, costs[1], costs[0], ratio))
+        costs_baseline = []
+        cost = measure_cost_backward_baseline(repeat, np.dot, np.transpose, lhs_dns_np, rhs_dns_np)
+        costs_baseline.append(cost)
+        cost = measure_cost_backward_baseline(repeat, sp.spmatrix.dot, sp.spmatrix.transpose, lhs_csr_sp, rhs_dns_np)
+        costs_baseline.append(cost)
+        ratio_baseline = costs_baseline[1] / costs_baseline[0]
+        fmt = "%0.1f\t\t%s\t%d\t%d\t%d\t%0.6f\t%0.5f\t%0.2f\t\t\t%0.6f\t%0.5f\t\t%0.2f"
+        print(fmt % (density * 100, str(ctx), n, m, k, costs[1], costs[0], ratio,
+                     costs_baseline[1], costs_baseline[0], ratio_baseline))
 
     print("A = sparse NDArray of shape(m, k)")
     print("B = dense NDArray of shape(k, n)")
 
     print("dot_forward\tdot(csr, dns)")
-    print('density(%)\tcontext\tn\tm\tk\tt_sparse\tt_dense\tt_sparse/t_dense')
+    print('density(%)\tcontext\tn\tm\tk\tt_sparse\tt_dense\tt_sparse/t_dense'
+          '\tt_scipy_sparse\tt_scipy_dense\tt_scipy_sparse/t_scipy_dense')
+    check_call(_LIB.MXSetNumOMPThreads(ctypes.c_int(args.num_omp_threads)))
 
     # TODO(haibin) make these runtime options
     m = 512
     k = [50000, 100000]
@@ -132,7 +179,8 @@ def bench_dot_backward(m, k, n, density, ctx, repeat):
                 bench_dot_forward(m, k[i], n[i], den, ctx, num_repeat)
 
     print("dot_backward\tdot(csr.T, dns)")
-    print('density(%)\tcontext\tn\tm\tk\tt_sparse\tt_dense\tt_sparse/t_dense')
+    print('density(%)\tcontext\tn\tm\tk\tt_sparse\tt_dense\tt_sparse/t_dense'
+          '\tt_scipy_sparse\tt_scipy_dense\tt_scipy_sparse/t_scipy_dense')
     for i in range(2):
         for ctx in contexts:
             for den in density:
diff --git a/src/operator/mxnet_op.h b/src/operator/mxnet_op.h
index 9b5dcfe3d3b1..6a9ee30f1b04 100644
--- a/src/operator/mxnet_op.h
+++ b/src/operator/mxnet_op.h
@@ -7,6 +7,7 @@
 #ifndef MXNET_OPERATOR_MXNET_OP_H_
 #define MXNET_OPERATOR_MXNET_OP_H_
 
+#include <dmlc/omp.h>
 #include <mxnet/base.h>
 #include <algorithm>
 
@@ -22,6 +23,8 @@ const float PI = 3.14159265358979323846;
 using std::isnan;
 #endif
 
+template<typename xpu>
+int get_num_threads(const int N);
 
 #ifdef __CUDACC__
 #define CUDA_KERNEL_LOOP(i, n) \
@@ -37,8 +40,18 @@ inline int cuda_get_num_blocks(const int N) {
   using namespace mshadow::cuda;
   return std::min(kMaxGridNum, (N + kBaseThreadNum - 1) / kBaseThreadNum);
 }
+
+template<>
+inline int get_num_threads<gpu>(const int N) {
+  using namespace mshadow::cuda;
+  return kBaseThreadNum * cuda_get_num_blocks(N);
+}
 #endif  // __CUDACC__
 
+template<>
+inline int get_num_threads<cpu>(const int N) {
+  return omp_get_max_threads();
+}
 
 /*! \brief operator request type switch */
 #define MXNET_ASSIGN_REQ_SWITCH(req, ReqType, ...) \
diff --git a/src/operator/tensor/matrix_op-inl.h b/src/operator/tensor/matrix_op-inl.h
index 3b54bf240447..135389685b8b 100644
--- a/src/operator/tensor/matrix_op-inl.h
+++ b/src/operator/tensor/matrix_op-inl.h
@@ -10,6 +10,7 @@
 #include <vector>
 #include <algorithm>
 #include <utility>
+#include <type_traits>
 #include "../mshadow_op.h"
 #include "../elemwise_op_common.h"
 #include "../mxnet_op.h"
@@ -495,20 +496,12 @@ inline bool DotBackwardInferStorageType(const nnvm::NodeAttrs& attrs,
   return true;
 }
 
-/*!
- * \brief Tempalte declaration of dot(csr, dns1) = dns2.
- * Whether csr and dns1 are transposed before dot operation
- * is determined by trans_csr and trans_dns, respectively.
- * For now we only implemented the case when trans_dns = false.
- */
-template<int req, bool trans_csr, bool trans_dns>
-struct DotCsrDnsDns;
-
 /*!
  * \brief Kernel of dot(csr, dns1) = dns2
+ * Parallelization by output matrix elements
  */
 template<int req>
-struct DotCsrDnsDns<req, false, false> {
+struct DotCsrDnsDns {
   /*!
    * \brief This function represents performing an inner product between a row of lhs
    * and a column of rhs and then assigning the value to out[i].
@@ -537,9 +530,10 @@ struct DotCsrDnsDns<req, false, false> {
 
 /*!
  * \brief Kernel of dot(csr.T(), dns1) = dns2
+ * Parallelization by output matrix elements
  */
 template<int req>
-struct DotCsrDnsDns<req, true, false> {
+struct DotCsrTransDnsDns {
   /*!
    * \brief This function represents performing an inner product between a column of lhs
    * and a column of rhs and then assigning the value to out[i].
@@ -583,6 +577,69 @@ struct DotCsrDnsDns<req, true, false> {
   }
 };
 
+/*!
+ * \brief Kernel of dot(csr, dns1) = dns2
+ * Parallelization by row blocks
+ */
+struct DotCsrDnsDnsByRowBlocks {
+  /*!
+   * \brief
+   * \param i the i-th thread
+   */
+  template<typename DType, typename IType, typename CType>
+  MSHADOW_XINLINE static void Map(int i, DType* out, const DType* data_l, const IType* indptr_l,
+                                  const CType* col_idx_l, const DType* data_r, const size_t seg_len,
+                                  const size_t num_rows, const size_t num_cols) {
+    const size_t seg_start = i * seg_len;
+    if (seg_start >= num_rows) return;
+    const size_t seg_end = (seg_start+seg_len < num_rows? seg_start+seg_len : num_rows);
+    for (size_t j = seg_start; j < seg_end; ++j) {
+      if (indptr_l[j] == indptr_l[j+1]) continue;
+      const size_t offset_out = j * num_cols;
+      for (auto k = indptr_l[j]; k < indptr_l[j+1]; ++k) {
+        const auto val = data_l[k];
+        const size_t offset_r = col_idx_l[k] * num_cols;
+        for (size_t l = 0; l < num_cols; ++l) {
+          out[offset_out+l] += data_r[offset_r+l] * val;
+        }
+      }
+    }
+  }
+};
+
+/*!
+ * \brief Kernel of dot(csr.T(), dns1) = dns2
+ * Parallelization by row blocks
+ */
+struct DotCsrTransDnsDnsByRowBlocks {
+  /*!
+   * \brief
+   * \param i the i-th thread
+   */
+  template<typename DType, typename IType, typename CType>
+  MSHADOW_XINLINE static void Map(int i, DType* out, const DType* data_l, const IType* indptr_l,
+                                  const CType* col_idx_l, const DType* data_r, const size_t seg_len,
+                                  const size_t num_rows_l, const size_t num_rows,
+                                  const size_t num_cols) {
+    const size_t seg_start = i * seg_len;
+    if (seg_start >= num_rows) return;
+    const size_t seg_end = (i + 1) * seg_len;
+    for (size_t j = 0; j < num_rows_l; ++j) {
+      if (indptr_l[j] == indptr_l[j+1]) continue;
+      const size_t offset_r = j * num_cols;
+      for (auto k = indptr_l[j]; k < indptr_l[j+1]; ++k) {
+        const auto col_idx = col_idx_l[k];
+        if (col_idx < seg_start || col_idx >= seg_end) continue;
+        const size_t offset_out = col_idx * num_cols;
+        const auto val = data_l[k];
+        for (size_t l = 0; l < num_cols; ++l) {
+          out[offset_out+l] += data_r[offset_r+l] * val;
+        }
+      }
+    }
+  }
+};
+
 template<typename xpu>
 void DotCsrDnsDnsImpl(const OpContext& ctx,
                       const NDArray& lhs,
@@ -594,6 +651,7 @@ void DotCsrDnsDnsImpl(const OpContext& ctx,
   CHECK_EQ(lhs.storage_type(), kCSRStorage);
   CHECK_EQ(rhs.storage_type(), kDefaultStorage);
   CHECK_EQ(ret->storage_type(), kDefaultStorage);
+  if (!lhs.storage_initialized()) return;
 
   mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
   const TBlob data_l = lhs.data();
@@ -602,22 +660,43 @@ void DotCsrDnsDnsImpl(const OpContext& ctx,
   const TBlob data_r = rhs.data();
   const TBlob data_out = ret->data();
 
-  MXNET_ASSIGN_REQ_SWITCH(req, ReqType, {
-    MSHADOW_TYPE_SWITCH(data_l.type_flag_, DType, {  // data type
-      MSHADOW_INT_TYPE_SWITCH(indptr_l.type_flag_, IType, {  // indptr type
-        MSHADOW_INT_TYPE_SWITCH(col_idx_l.type_flag_, CType, {  // col idx type
-          if (!lhs.storage_initialized()) return;
+  MSHADOW_TYPE_SWITCH(data_l.type_flag_, DType, {  // data type
+    MSHADOW_INT_TYPE_SWITCH(indptr_l.type_flag_, IType, {  // indptr type
+      MSHADOW_INT_TYPE_SWITCH(col_idx_l.type_flag_, CType, {  // col idx type
+        if (std::is_same<xpu, cpu>::value) {  // cpu parallelization by row blocks
+          if (kWriteTo == req) {
+            mxnet_op::Kernel<mxnet_op::set_zero, xpu>::Launch(
+                s, data_out.Size(), data_out.dptr<DType>());
+          }
+          int num_threads = mxnet_op::get_num_threads<xpu>(data_out.shape_[0]);
+          size_t seg_len = (data_out.shape_[0] + num_threads - 1) / num_threads;
           if (trans_lhs) {
-            mxnet_op::Kernel<DotCsrDnsDns<ReqType, true, false>, xpu>::Launch(s, data_out.Size(),
+            mxnet_op::Kernel<DotCsrTransDnsDnsByRowBlocks, xpu>::Launch(s, num_threads,
                data_out.dptr<DType>(), data_l.dptr<DType>(), indptr_l.dptr<IType>(),
-               col_idx_l.dptr<CType>(), data_r.dptr<DType>(), seg_len,
-               rhs.shape()[1]);
+               col_idx_l.dptr<CType>(), data_r.dptr<DType>(), seg_len,
+               lhs.shape()[0], data_out.shape_[0], data_out.shape_[1]);
          } else {
-            mxnet_op::Kernel<DotCsrDnsDns<ReqType, false, false>, xpu>::Launch(s, data_out.Size(),
+            mxnet_op::Kernel<DotCsrDnsDnsByRowBlocks, xpu>::Launch(s, num_threads,
               data_out.dptr<DType>(), data_l.dptr<DType>(), indptr_l.dptr<IType>(),
-              col_idx_l.dptr<CType>(), data_r.dptr<DType>(), rhs.shape()[1]);
+              col_idx_l.dptr<CType>(), data_r.dptr<DType>(), seg_len,
+              data_out.shape_[0], data_out.shape_[1]);
           }
-        });
+        } else {  // gpu parallelization by output elements
+          if (trans_lhs) {
+            MXNET_ASSIGN_REQ_SWITCH(req, ReqType, {
+              mxnet_op::Kernel<DotCsrTransDnsDns<ReqType>, xpu>::Launch(s, data_out.Size(),
+                  data_out.dptr<DType>(), data_l.dptr<DType>(), indptr_l.dptr<IType>(),
+                  col_idx_l.dptr<CType>(), data_r.dptr<DType>(), lhs.shape()[0],
+                  data_out.shape_[1]);
+            });
+          } else {
+            MXNET_ASSIGN_REQ_SWITCH(req, ReqType, {
+              mxnet_op::Kernel<DotCsrDnsDns<ReqType>, xpu>::Launch(s, data_out.Size(),
+                  data_out.dptr<DType>(), data_l.dptr<DType>(), indptr_l.dptr<IType>(),
+                  col_idx_l.dptr<CType>(), data_r.dptr<DType>(), rhs.shape()[1]);
+            });
+          }
+        }
       });
    });
  });
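Note (illustration only, not part of the patch): the row-block partitioning used by DotCsrDnsDnsByRowBlocks above can be sketched in NumPy/SciPy terms as follows. The helper name csr_dot_by_row_blocks and the shapes are made up for the example; it assumes a scipy.sparse CSR lhs and a dense rhs, and only mirrors the work split, not the OMP parallelism.

    import numpy as np
    import scipy.sparse as sp

    def csr_dot_by_row_blocks(lhs_csr, rhs_dns, num_threads):
        # Each worker i owns a contiguous block of output rows of length seg_len,
        # so workers write disjoint slices of `out` and need no synchronization.
        num_rows, num_cols = lhs_csr.shape[0], rhs_dns.shape[1]
        out = np.zeros((num_rows, num_cols), dtype=rhs_dns.dtype)
        seg_len = (num_rows + num_threads - 1) // num_threads
        indptr, col_idx, data = lhs_csr.indptr, lhs_csr.indices, lhs_csr.data
        for i in range(num_threads):  # this loop is what runs in parallel in the C++ kernel
            seg_start = i * seg_len
            seg_end = min(seg_start + seg_len, num_rows)
            for j in range(seg_start, seg_end):
                for k in range(indptr[j], indptr[j + 1]):
                    out[j, :] += data[k] * rhs_dns[col_idx[k], :]
        return out

    lhs = sp.random(512, 1000, density=0.05, format='csr')
    rhs = np.random.rand(1000, 64)
    assert np.allclose(csr_dot_by_row_blocks(lhs, rhs, num_threads=4), lhs.dot(rhs))

The transposed kernel (DotCsrTransDnsDnsByRowBlocks) follows the same idea: each worker still owns a disjoint range of output rows (which correspond to column indices of the CSR input), scans all of indptr_l, and skips nonzeros whose column index falls outside its range, so writes never conflict.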