Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve dot #61

Merged
merged 8 commits into from
Jun 2, 2017
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions benchmark/python/sparse_op.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
# pylint: skip-file
import ctypes

import mxnet as mx
from mxnet.test_utils import *
import numpy as np
Expand All @@ -7,6 +9,16 @@
import pickle as pickle
import time
import sys
import argparse

from mxnet.base import check_call, _LIB

parser = argparse.ArgumentParser(description="Benchmark sparse operators",
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--num-omp-threads', type=int, default=1, help='number of omp threads to set in MXNet')
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does it pass lint?


args = parser.parse_args()


def get_avazu(data_dir):
if not os.path.isdir(data_dir):
Expand Down Expand Up @@ -118,6 +130,7 @@ def bench_dot_backward(m, k, n, density, ctx, repeat):
print("dot_forward\tdot(csr, dns)")
print('density(%)\tcontext\tn\tm\tk\tt_sparse\tt_dense\tt_sparse/t_dense')

check_call(_LIB.MXSetNumOMPThreads(ctypes.c_int(args.num_omp_threads)))
# TODO(haibin) make these runtime options
m = 512
k = [50000, 100000]
Expand Down
13 changes: 13 additions & 0 deletions src/operator/mxnet_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#ifndef MXNET_OPERATOR_MXNET_OP_H_
#define MXNET_OPERATOR_MXNET_OP_H_

#include <dmlc/omp.h>
#include <mxnet/base.h>
#include <algorithm>

Expand All @@ -22,6 +23,8 @@ const float PI = 3.14159265358979323846;
using std::isnan;
#endif

/*!
 * \brief Number of parallel work items to launch for a problem of size N on
 * device xpu. Per the specializations below: on cpu this is
 * omp_get_max_threads() (N is not consulted); on gpu it is kBaseThreadNum
 * times the number of CUDA blocks computed by cuda_get_num_blocks(N).
 */
template<typename xpu>
int get_num_threads(const int N);

#ifdef __CUDACC__
#define CUDA_KERNEL_LOOP(i, n) \
Expand All @@ -37,8 +40,18 @@ inline int cuda_get_num_blocks(const int N) {
using namespace mshadow::cuda;
return std::min(kMaxGridNum, (N + kBaseThreadNum - 1) / kBaseThreadNum);
}

/*!
 * \brief GPU specialization: total CUDA thread count for a problem of size N,
 * i.e. threads-per-block times the block count chosen for N.
 */
template<>
inline int get_num_threads<gpu>(const int N) {
  const int num_blocks = cuda_get_num_blocks(N);
  return mshadow::cuda::kBaseThreadNum * num_blocks;
}
#endif // __CUDACC__

/*!
 * \brief CPU specialization: the maximum number of OpenMP threads available.
 * NOTE(review): N is intentionally unused here — callers (e.g. DotCsrDnsDnsImpl)
 * divide the work into this many segments regardless of problem size, and the
 * kernels themselves skip segments that fall past the end of the data.
 */
template<>
inline int get_num_threads<cpu>(const int N) {
return omp_get_max_threads();
}

/*! \brief operator request type switch */
#define MXNET_ASSIGN_REQ_SWITCH(req, ReqType, ...) \
Expand Down
123 changes: 101 additions & 22 deletions src/operator/tensor/matrix_op-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include <vector>
#include <algorithm>
#include <utility>
#include <type_traits>
#include "../mshadow_op.h"
#include "../elemwise_op_common.h"
#include "../mxnet_op.h"
Expand Down Expand Up @@ -495,20 +496,12 @@ inline bool DotBackwardInferStorageType(const nnvm::NodeAttrs& attrs,
return true;
}

/*!
* \brief Template declaration of dot(csr, dns1) = dns2.
* Whether csr and dns1 are transposed before dot operation
* is determined by trans_csr and trans_dns, respectively.
* For now we only implemented the case when trans_dns = false.
*/
template<bool trans_csr, bool trans_dns, int req>
struct DotCsrDnsDns;

/*!
* \brief Kernel of dot(csr, dns1) = dns2
* Parallelization by output matrix elements
*/
template<int req>
struct DotCsrDnsDns<false, false, req> {
struct DotCsrDnsDns {
/*!
* \brief This function represents performing an inner product between a row of lhs
* and a column of rhs and then assigning the value to out[i].
Expand Down Expand Up @@ -537,9 +530,10 @@ struct DotCsrDnsDns<false, false, req> {

/*!
* \brief Kernel of dot(csr.T(), dns1) = dns2
* Parallelization by output matrix elements
*/
template<int req>
struct DotCsrDnsDns<true, false, req> {
struct DotCsrTransDnsDns {
/*!
* \brief This function represents performing an inner product between a column of lhs
* and a column of rhs and then assigning the value to out[i].
Expand Down Expand Up @@ -583,6 +577,69 @@ struct DotCsrDnsDns<true, false, req> {
}
};

/*!
 * \brief Kernel of dot(csr, dns1) = dns2
 * Parallelization by row blocks: thread i owns the contiguous output rows
 * [i*seg_len, min((i+1)*seg_len, num_rows)), so threads never write the
 * same output element and no atomics are required.
 */
struct DotCsrDnsDnsByRowBlocks {
  /*!
   * \brief For every csr row in this thread's block, accumulate
   * out[row, :] += data_l[k] * data_r[col_idx_l[k], :] over the row's non-zeros.
   * \param i the i-th thread / row block
   * \param out output dense matrix (num_rows x num_cols), accumulated into
   * \param data_l csr values of lhs
   * \param indptr_l csr row pointers of lhs (length num_rows + 1)
   * \param col_idx_l csr column indices of lhs
   * \param data_r dense rhs matrix
   * \param seg_len number of output rows per thread block
   * \param num_rows number of rows of the output / csr lhs
   * \param num_cols number of columns of the output / rhs
   */
  template<typename DType, typename IType, typename CType>
  MSHADOW_XINLINE static void Map(int i, DType* out, const DType* data_l, const IType* indptr_l,
                                  const CType* col_idx_l, const DType* data_r, const size_t seg_len,
                                  const size_t num_rows, const size_t num_cols) {
    const size_t row_begin = i * seg_len;
    // Extra threads past the last row block (seg_len rounds the work up) do nothing.
    if (row_begin >= num_rows) return;
    const size_t row_end = std::min(row_begin + seg_len, num_rows);
    for (size_t row = row_begin; row < row_end; ++row) {
      const auto nnz_begin = indptr_l[row];
      const auto nnz_end = indptr_l[row+1];
      if (nnz_begin == nnz_end) continue;  // empty csr row
      const size_t offset_out = row * num_cols;
      for (auto k = nnz_begin; k < nnz_end; ++k) {
        const auto val = data_l[k];
        const size_t offset_r = col_idx_l[k] * num_cols;
        for (size_t col = 0; col < num_cols; ++col) {
          out[offset_out+col] += data_r[offset_r+col] * val;
        }
      }
    }
  }
};

/*!
 * \brief Kernel of dot(csr.T(), dns1) = dns2
 * Parallelization by row blocks
 */
struct DotCsrTransDnsDnsByRowBlocks {
/*!
 * \brief Thread i owns output rows [i*seg_len, (i+1)*seg_len). It scans every
 * csr row of the (untransposed) lhs and, for each non-zero whose column index
 * lands in its block, accumulates out[col_idx, :] += val * data_r[row, :].
 * Blocks are disjoint, so threads never write the same output element.
 * \param i the i-th thread / row block
 * \param out output dense matrix (num_rows x num_cols), accumulated into
 * \param data_l csr values of lhs
 * \param indptr_l csr row pointers of lhs (length num_rows_l + 1)
 * \param col_idx_l csr column indices of lhs
 * \param data_r dense rhs matrix (num_rows_l x num_cols)
 * \param seg_len number of output rows per thread block
 * \param num_rows_l number of rows of the untransposed csr lhs
 * \param num_rows number of rows of the output
 * \param num_cols number of columns of the output / rhs
 */
template<typename DType, typename IType, typename CType>
MSHADOW_XINLINE static void Map(int i, DType* out, const DType* data_l, const IType* indptr_l,
const CType* col_idx_l, const DType* data_r, const size_t seg_len,
const size_t num_rows_l, const size_t num_rows,
const size_t num_cols) {
const size_t seg_start = i * seg_len;
// seg_len rounds the work up, so when more threads than output rows are
// launched the trailing threads own an empty block and return immediately.
if (seg_start >= num_rows) return;
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why is this check necessary? since you did rounding at line 599

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

num_threads could be so big that seg_len=1, and more threads than num_rows are launched. For those extra threads, we don't need to do anything, right?

// seg_end is only compared against column indices (presumably < num_rows —
// true when col_idx_l is a valid csr index array), so clamping it to
// num_rows is unnecessary.
const size_t seg_end = (i + 1) * seg_len;
// Every thread scans all csr rows; ownership is decided per non-zero below.
for (size_t j = 0; j < num_rows_l; ++j) {
if (indptr_l[j] == indptr_l[j+1]) continue;
const size_t offset_r = j * num_cols;
for (auto k = indptr_l[j]; k < indptr_l[j+1]; ++k) {
const auto col_idx = col_idx_l[k];
if (col_idx < seg_start || col_idx >= seg_end) continue;
const size_t offset_out = col_idx * num_cols;
const auto val = data_l[k];
for (size_t l = 0; l < num_cols; ++l) {
out[offset_out+l] += data_r[offset_r+l] * val;
}
}
}
}
};

template<typename xpu>
void DotCsrDnsDnsImpl(const OpContext& ctx,
const NDArray& lhs,
Expand All @@ -594,6 +651,7 @@ void DotCsrDnsDnsImpl(const OpContext& ctx,
CHECK_EQ(lhs.storage_type(), kCSRStorage);
CHECK_EQ(rhs.storage_type(), kDefaultStorage);
CHECK_EQ(ret->storage_type(), kDefaultStorage);
if (!lhs.storage_initialized()) return;

mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
const TBlob data_l = lhs.data();
Expand All @@ -602,22 +660,43 @@ void DotCsrDnsDnsImpl(const OpContext& ctx,
const TBlob data_r = rhs.data();
const TBlob data_out = ret->data();

MXNET_ASSIGN_REQ_SWITCH(req, ReqType, {
MSHADOW_TYPE_SWITCH(data_l.type_flag_, DType, { // data type
MSHADOW_INT_TYPE_SWITCH(indptr_l.type_flag_, IType, { // indptr type
MSHADOW_INT_TYPE_SWITCH(col_idx_l.type_flag_, CType, { // col idx type
if (!lhs.storage_initialized()) return;
MSHADOW_TYPE_SWITCH(data_l.type_flag_, DType, { // data type
MSHADOW_INT_TYPE_SWITCH(indptr_l.type_flag_, IType, { // indptr type
MSHADOW_INT_TYPE_SWITCH(col_idx_l.type_flag_, CType, { // col idx type
if (std::is_same<xpu, cpu>::value) { // cpu parallelization by row blocks
if (kWriteTo == req) {
mxnet_op::Kernel<mxnet_op::set_zero, xpu>::Launch(
s, data_out.Size(), data_out.dptr<DType>());
}
int num_threads = mxnet_op::get_num_threads<xpu>(data_out.shape_[0]);
size_t seg_len = (data_out.shape_[0] + num_threads - 1) / num_threads;
if (trans_lhs) {
mxnet_op::Kernel<DotCsrDnsDns<true, false, ReqType>, xpu>::Launch(s, data_out.Size(),
mxnet_op::Kernel<DotCsrTransDnsDnsByRowBlocks, xpu>::Launch(s, num_threads,
data_out.dptr<DType>(), data_l.dptr<DType>(), indptr_l.dptr<IType>(),
col_idx_l.dptr<CType>(), data_r.dptr<DType>(), lhs.shape()[0],
rhs.shape()[1]);
col_idx_l.dptr<CType>(), data_r.dptr<DType>(), seg_len,
lhs.shape()[0], data_out.shape_[0], data_out.shape_[1]);
} else {
mxnet_op::Kernel<DotCsrDnsDns<false, false, ReqType>, xpu>::Launch(s, data_out.Size(),
mxnet_op::Kernel<DotCsrDnsDnsByRowBlocks, xpu>::Launch(s, num_threads,
data_out.dptr<DType>(), data_l.dptr<DType>(), indptr_l.dptr<IType>(),
col_idx_l.dptr<CType>(), data_r.dptr<DType>(), rhs.shape()[1]);
col_idx_l.dptr<CType>(), data_r.dptr<DType>(), seg_len,
data_out.shape_[0], data_out.shape_[1]);
}
});
} else { // gpu parallelization by output elements
if (trans_lhs) {
MXNET_ASSIGN_REQ_SWITCH(req, ReqType, {
mxnet_op::Kernel<DotCsrTransDnsDns<ReqType>, xpu>::Launch(s, data_out.Size(),
data_out.dptr<DType>(), data_l.dptr<DType>(), indptr_l.dptr<IType>(),
col_idx_l.dptr<CType>(), data_r.dptr<DType>(), lhs.shape()[0],
data_out.shape_[1]);
});
} else {
MXNET_ASSIGN_REQ_SWITCH(req, ReqType, {
mxnet_op::Kernel<DotCsrDnsDns<ReqType>, xpu>::Launch(s, data_out.Size(),
data_out.dptr<DType>(), data_l.dptr<DType>(), indptr_l.dptr<IType>(),
col_idx_l.dptr<CType>(), data_r.dptr<DType>(), rhs.shape()[1]);
});
}
}
});
});
});
Expand Down