-
Notifications
You must be signed in to change notification settings - Fork 2
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Improve dot #61
Improve dot #61
Changes from 6 commits
1768940
27f9166
90c62e3
25cc047
878bd5e
71f5c93
acb17c7
744ce2e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -10,6 +10,7 @@ | |
#include <vector> | ||
#include <algorithm> | ||
#include <utility> | ||
#include <type_traits> | ||
#include "../mshadow_op.h" | ||
#include "../elemwise_op_common.h" | ||
#include "../mxnet_op.h" | ||
|
@@ -495,20 +496,12 @@ inline bool DotBackwardInferStorageType(const nnvm::NodeAttrs& attrs, | |
return true; | ||
} | ||
|
||
/*! | ||
* \brief Template declaration of dot(csr, dns1) = dns2. | ||
* Whether csr and dns1 are transposed before dot operation | ||
* is determined by trans_csr and trans_dns, respectively. | ||
* For now we only implemented the case when trans_dns = false. | ||
*/ | ||
template<bool trans_csr, bool trans_dns, int req> | ||
struct DotCsrDnsDns; | ||
|
||
/*! | ||
* \brief Kernel of dot(csr, dns1) = dns2 | ||
* Parallelization by output matrix elements | ||
*/ | ||
template<int req> | ||
struct DotCsrDnsDns<false, false, req> { | ||
struct DotCsrDnsDns { | ||
/*! | ||
* \brief This function represents performing an inner product between a row of lhs | ||
* and a column of rhs and then assigning the value to out[i]. | ||
|
@@ -537,9 +530,10 @@ struct DotCsrDnsDns<false, false, req> { | |
|
||
/*! | ||
* \brief Kernel of dot(csr.T(), dns1) = dns2 | ||
* Parallelization by output matrix elements | ||
*/ | ||
template<int req> | ||
struct DotCsrDnsDns<true, false, req> { | ||
struct DotCsrTransDnsDns { | ||
/*! | ||
* \brief This function represents performing an inner product between a column of lhs | ||
* and a column of rhs and then assigning the value to out[i]. | ||
|
@@ -583,6 +577,69 @@ struct DotCsrDnsDns<true, false, req> { | |
} | ||
}; | ||
|
||
/*! | ||
* \brief Kernel of dot(csr, dns1) = dns2 | ||
* Parallelization by row blocks | ||
*/ | ||
struct DotCsrDnsDnsByRowBlocks { | ||
/*! | ||
* \brief Accumulates one contiguous block of output rows of dot(csr, dns1). | ||
* Thread i handles rows [i * seg_len, min(i * seg_len + seg_len, num_rows)). | ||
* For every stored entry k of such a row j, the rhs row selected by | ||
* col_idx_l[k] is scaled by data_l[k] and added into output row j. | ||
* The kernel only accumulates (+=) and never clears out, so out must | ||
* already hold the values the result is to be added onto. | ||
* \param i the i-th thread | ||
* \param out output buffer, indexed as out[j * num_cols + l] | ||
* \param data_l stored values of the csr lhs | ||
* \param indptr_l row pointer array of the csr lhs; entries of row j are | ||
* [indptr_l[j], indptr_l[j+1]) | ||
* \param col_idx_l column index of each stored entry of the csr lhs | ||
* \param data_r rhs buffer, indexed as data_r[col_idx * num_cols + l] | ||
* \param seg_len number of rows assigned to each thread | ||
* \param num_rows total number of rows to process (upper bound on j) | ||
* \param num_cols row stride used for both out and data_r | ||
*/ | ||
template<typename DType, typename IType, typename CType> | ||
MSHADOW_XINLINE static void Map(int i, DType* out, const DType* data_l, const IType* indptr_l, | ||
const CType* col_idx_l, const DType* data_r, const size_t seg_len, | ||
const size_t num_rows, const size_t num_cols) { | ||
const size_t seg_start = i * seg_len; | ||
// Threads whose block starts past the last row have nothing to do. | ||
if (seg_start >= num_rows) return; | ||
// Clamp the last block so it does not run past num_rows. | ||
const size_t seg_end = (seg_start+seg_len < num_rows? seg_start+seg_len : num_rows); | ||
for (size_t j = seg_start; j < seg_end; ++j) { | ||
// Equal consecutive row pointers mean row j has no stored entries. | ||
if (indptr_l[j] == indptr_l[j+1]) continue; | ||
const size_t offset_out = j * num_cols; | ||
for (auto k = indptr_l[j]; k < indptr_l[j+1]; ++k) { | ||
const auto val = data_l[k]; | ||
const size_t offset_r = col_idx_l[k] * num_cols; | ||
// out[j, :] += val * rhs[col_idx_l[k], :] | ||
for (size_t l = 0; l < num_cols; ++l) { | ||
out[offset_out+l] += data_r[offset_r+l] * val; | ||
} | ||
} | ||
} | ||
} | ||
}; | ||
|
||
/*! | ||
* \brief Kernel of dot(csr.T(), dns1) = dns2 | ||
* Parallelization by row blocks | ||
*/ | ||
struct DotCsrTransDnsDnsByRowBlocks { | ||
/*! | ||
* \brief Accumulates one contiguous block of output rows of dot(csr.T(), dns1). | ||
* Thread i owns output rows [i * seg_len, (i + 1) * seg_len). It walks | ||
* every row j of the csr lhs and, for each stored entry whose column | ||
* index falls inside the owned range, adds data_l[k] times rhs row j | ||
* into the output row selected by that column index. Entries outside | ||
* the owned range are skipped, so each thread scans all stored entries. | ||
* The kernel only accumulates (+=) and never clears out. | ||
* \param i the i-th thread | ||
* \param out output buffer, indexed as out[col_idx * num_cols + l] | ||
* \param data_l stored values of the csr lhs | ||
* \param indptr_l row pointer array of the csr lhs; entries of row j are | ||
* [indptr_l[j], indptr_l[j+1]) | ||
* \param col_idx_l column index of each stored entry of the csr lhs | ||
* \param data_r rhs buffer, indexed as data_r[j * num_cols + l] | ||
* \param seg_len number of output rows assigned to each thread | ||
* \param num_rows_l number of rows of the csr lhs that are scanned | ||
* \param num_rows bound used to retire threads whose block starts past it | ||
* \param num_cols row stride used for both out and data_r | ||
*/ | ||
template<typename DType, typename IType, typename CType> | ||
MSHADOW_XINLINE static void Map(int i, DType* out, const DType* data_l, const IType* indptr_l, | ||
const CType* col_idx_l, const DType* data_r, const size_t seg_len, | ||
const size_t num_rows_l, const size_t num_rows, | ||
const size_t num_cols) { | ||
const size_t seg_start = i * seg_len; | ||
if (seg_start >= num_rows) return; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why is this check necessary? since you did rounding at line 599 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. num_threads could be so big that seg_len=1, and more threads than num_rows are launched. For those extra threads, we don't need to do anything, right? |
||
// NOTE(review): seg_end is not clamped to num_rows here; this looks safe only if | ||
// every column index is below num_rows - confirm against the caller. | ||
const size_t seg_end = (i + 1) * seg_len; | ||
for (size_t j = 0; j < num_rows_l; ++j) { | ||
// Equal consecutive row pointers mean row j has no stored entries. | ||
if (indptr_l[j] == indptr_l[j+1]) continue; | ||
const size_t offset_r = j * num_cols; | ||
for (auto k = indptr_l[j]; k < indptr_l[j+1]; ++k) { | ||
const auto col_idx = col_idx_l[k]; | ||
// Only entries landing in this thread's output rows are accumulated. | ||
if (col_idx < seg_start || col_idx >= seg_end) continue; | ||
const size_t offset_out = col_idx * num_cols; | ||
const auto val = data_l[k]; | ||
// out[col_idx, :] += val * rhs[j, :] | ||
for (size_t l = 0; l < num_cols; ++l) { | ||
out[offset_out+l] += data_r[offset_r+l] * val; | ||
} | ||
} | ||
} | ||
} | ||
}; | ||
|
||
template<typename xpu> | ||
void DotCsrDnsDnsImpl(const OpContext& ctx, | ||
const NDArray& lhs, | ||
|
@@ -594,6 +651,7 @@ void DotCsrDnsDnsImpl(const OpContext& ctx, | |
CHECK_EQ(lhs.storage_type(), kCSRStorage); | ||
CHECK_EQ(rhs.storage_type(), kDefaultStorage); | ||
CHECK_EQ(ret->storage_type(), kDefaultStorage); | ||
if (!lhs.storage_initialized()) return; | ||
|
||
mshadow::Stream<xpu> *s = ctx.get_stream<xpu>(); | ||
const TBlob data_l = lhs.data(); | ||
|
@@ -602,22 +660,43 @@ void DotCsrDnsDnsImpl(const OpContext& ctx, | |
const TBlob data_r = rhs.data(); | ||
const TBlob data_out = ret->data(); | ||
|
||
MXNET_ASSIGN_REQ_SWITCH(req, ReqType, { | ||
MSHADOW_TYPE_SWITCH(data_l.type_flag_, DType, { // data type | ||
MSHADOW_INT_TYPE_SWITCH(indptr_l.type_flag_, IType, { // indptr type | ||
MSHADOW_INT_TYPE_SWITCH(col_idx_l.type_flag_, CType, { // col idx type | ||
if (!lhs.storage_initialized()) return; | ||
MSHADOW_TYPE_SWITCH(data_l.type_flag_, DType, { // data type | ||
MSHADOW_INT_TYPE_SWITCH(indptr_l.type_flag_, IType, { // indptr type | ||
MSHADOW_INT_TYPE_SWITCH(col_idx_l.type_flag_, CType, { // col idx type | ||
if (std::is_same<xpu, cpu>::value) { // cpu parallelization by row blocks | ||
if (kWriteTo == req) { | ||
mxnet_op::Kernel<mxnet_op::set_zero, xpu>::Launch( | ||
s, data_out.Size(), data_out.dptr<DType>()); | ||
} | ||
int num_threads = mxnet_op::get_num_threads<xpu>(data_out.shape_[0]); | ||
size_t seg_len = (data_out.shape_[0] + num_threads - 1) / num_threads; | ||
if (trans_lhs) { | ||
mxnet_op::Kernel<DotCsrDnsDns<true, false, ReqType>, xpu>::Launch(s, data_out.Size(), | ||
mxnet_op::Kernel<DotCsrTransDnsDnsByRowBlocks, xpu>::Launch(s, num_threads, | ||
data_out.dptr<DType>(), data_l.dptr<DType>(), indptr_l.dptr<IType>(), | ||
col_idx_l.dptr<CType>(), data_r.dptr<DType>(), lhs.shape()[0], | ||
rhs.shape()[1]); | ||
col_idx_l.dptr<CType>(), data_r.dptr<DType>(), seg_len, | ||
lhs.shape()[0], data_out.shape_[0], data_out.shape_[1]); | ||
} else { | ||
mxnet_op::Kernel<DotCsrDnsDns<false, false, ReqType>, xpu>::Launch(s, data_out.Size(), | ||
mxnet_op::Kernel<DotCsrDnsDnsByRowBlocks, xpu>::Launch(s, num_threads, | ||
data_out.dptr<DType>(), data_l.dptr<DType>(), indptr_l.dptr<IType>(), | ||
col_idx_l.dptr<CType>(), data_r.dptr<DType>(), rhs.shape()[1]); | ||
col_idx_l.dptr<CType>(), data_r.dptr<DType>(), seg_len, | ||
data_out.shape_[0], data_out.shape_[1]); | ||
} | ||
}); | ||
} else { // gpu parallelization by output elements | ||
if (trans_lhs) { | ||
MXNET_ASSIGN_REQ_SWITCH(req, ReqType, { | ||
mxnet_op::Kernel<DotCsrTransDnsDns<ReqType>, xpu>::Launch(s, data_out.Size(), | ||
data_out.dptr<DType>(), data_l.dptr<DType>(), indptr_l.dptr<IType>(), | ||
col_idx_l.dptr<CType>(), data_r.dptr<DType>(), lhs.shape()[0], | ||
data_out.shape_[1]); | ||
}); | ||
} else { | ||
MXNET_ASSIGN_REQ_SWITCH(req, ReqType, { | ||
mxnet_op::Kernel<DotCsrDnsDns<ReqType>, xpu>::Launch(s, data_out.Size(), | ||
data_out.dptr<DType>(), data_l.dptr<DType>(), indptr_l.dptr<IType>(), | ||
col_idx_l.dptr<CType>(), data_r.dptr<DType>(), rhs.shape()[1]); | ||
}); | ||
} | ||
} | ||
}); | ||
}); | ||
}); | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
does it pass lint?