[CPU][Kernel] Single socket spmm (#3024)
* Optimizations of SpMM for CPU

* Added names of contributors

* Minor code cleanup

* Moved the SpMM optimization code to a new header file

* Moved to DGL's logging method

* Removed duplicate code between SpMMSumCsr and SpMMCmpCsr

* Changes made to follow the Google coding style

* Fixed lint errors in spmm.h

* Fixed some lint errors in spmm_blocking_libxsmm.h

* Fixed lint errors in spmm_blocking_libxsmm.h

* Added comments to SpMMCreateLibxsmmKernel

* Enabled building of tests, and other cosmetic changes

* Disabled libxsmm on Windows

* Added a condition to avoid the optimized implementation for FP64, as libxsmm does not yet have FP64 support

* Cosmetic changes and documentation

* Cosmetic changes

* Changes to pass lint tests

* Replaced multiple allocations for the index and edge buffers with a single allocation

Co-authored-by: Minjie Wang <wmjlyjemaine@gmail.com>
sanchit-misra and jermainewang committed Jul 13, 2021
1 parent 186ef59 commit fac75e1
Showing 5 changed files with 760 additions and 44 deletions.
3 changes: 3 additions & 0 deletions .gitmodules
@@ -35,3 +35,6 @@
[submodule "third_party/nccl"]
path = third_party/nccl
url = https://github.com/nvidia/nccl
[submodule "third_party/libxsmm"]
path = third_party/libxsmm
url = https://github.com/hfp/libxsmm.git
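Note: with the new submodule, a fresh checkout needs third_party/libxsmm populated before building; the standard git submodule update --init --recursive covers it (DGL's build docs may already describe this step).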
30 changes: 24 additions & 6 deletions CMakeLists.txt
@@ -27,6 +27,7 @@ dgl_option(USE_NCCL "Build with NCCL support" OFF)
dgl_option(USE_SYSTEM_NCCL "Build using system's NCCL library" OFF)
dgl_option(USE_OPENMP "Build with OpenMP" ON)
dgl_option(USE_AVX "Build with AVX optimization" ON)
dgl_option(USE_LIBXSMM "Build with LIBXSMM library optimization" ON)
dgl_option(USE_FP16 "Build with fp16 support to enable mixed precision training" OFF)
dgl_option(USE_TVM "Build with TVM kernels" OFF)
dgl_option(BUILD_CPP_TEST "Build cpp unittest executables" OFF)
@@ -36,7 +37,7 @@ dgl_option(USE_HDFS "Build with HDFS support" OFF) # Set env HADOOP_HDFS_HOME if

# Set debug compile option for gdb, only happens when -DCMAKE_BUILD_TYPE=DEBUG
if (NOT MSVC)
set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -O0 -g3 -ggdb")
set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -DDEBUG -O0 -g3 -ggdb")
endif(NOT MSVC)

if(USE_CUDA)
@@ -89,10 +90,10 @@ if(MSVC)
else(MSVC)
include(CheckCXXCompilerFlag)
check_cxx_compiler_flag("-std=c++11" SUPPORT_CXX11)
set(CMAKE_C_FLAGS "-O2 -Wall -fPIC ${CMAKE_C_FLAGS}")
set(CMAKE_C_FLAGS "-O2 -Wall -fPIC -march=native ${CMAKE_C_FLAGS}")
# We still use c++11 flag in CPU build because gcc5.4 (our default compiler) is
# not fully compatible with c++14 feature.
set(CMAKE_CXX_FLAGS "-O2 -Wall -fPIC -std=c++11 ${CMAKE_CXX_FLAGS}")
set(CMAKE_CXX_FLAGS "-O2 -Wall -fPIC -std=c++11 -march=native ${CMAKE_CXX_FLAGS}")
if(NOT APPLE)
set(CMAKE_SHARED_LINKER_FLAGS "-Wl,--warn-common ${CMAKE_SHARED_LINKER_FLAGS}")
endif(NOT APPLE)
@@ -108,9 +109,15 @@ if(USE_OPENMP)
endif(USE_OPENMP)

if(USE_AVX)
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -DUSE_AVX")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DUSE_AVX")
message(STATUS "Build with AVX optimization.")
if(USE_LIBXSMM)
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -DUSE_AVX -DUSE_LIBXSMM -DDGL_CPU_LLC_SIZE=40000000")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DUSE_AVX -DUSE_LIBXSMM -DDGL_CPU_LLC_SIZE=40000000")
message(STATUS "Build with LIBXSMM optimization.")
else(USE_LIBXSMM)
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -DUSE_AVX")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DUSE_AVX")
message(STATUS "Build with AVX optimization.")
endif(USE_LIBXSMM)
endif(USE_AVX)
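The DGL_CPU_LLC_SIZE define above bakes in an assumed last-level-cache size of 40 MB for the blocking logic. As a rough, hypothetical sketch of how such a constant can drive block-size selection (the names below are illustrative, not the actual code in spmm_blocking_libxsmm.h):

// Illustrative only: derive a feature-column block width from an assumed LLC
// size so that one block of output rows stays cache-resident.
// Assumes rows_per_block > 0.
#include <algorithm>
#include <cstdint>

#ifndef DGL_CPU_LLC_SIZE
#define DGL_CPU_LLC_SIZE 40000000  // bytes; matches the CMake define above
#endif

int64_t PickColumnBlock(int64_t dim, int64_t rows_per_block) {
  const int64_t budget = DGL_CPU_LLC_SIZE / 2;  // reserve half for inputs
  const int64_t bytes_per_col =
      rows_per_block * static_cast<int64_t>(sizeof(float));
  return std::min(dim, std::max<int64_t>(1, budget / bytes_per_col));
}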

# Build with fp16 to support mixed precision training.
@@ -194,6 +201,7 @@ target_include_directories(dgl PRIVATE "third_party/xbyak/")
target_include_directories(dgl PRIVATE "third_party/METIS/include/")
target_include_directories(dgl PRIVATE "tensoradapter/include")
target_include_directories(dgl PRIVATE "third_party/nanoflann/include")
target_include_directories(dgl PRIVATE "third_party/libxsmm/include")

# For serialization
if (USE_HDFS)
@@ -213,6 +221,15 @@ if(NOT MSVC)
list(APPEND DGL_LINKER_LIBS metis)
endif(NOT MSVC)

# Compile LIBXSMM
if((NOT MSVC) AND USE_LIBXSMM)
add_custom_target(libxsmm COMMAND make realclean COMMAND make -j BLAS=0
WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}/third_party/libxsmm
)
add_dependencies(dgl libxsmm)
list(APPEND DGL_LINKER_LIBS -L${CMAKE_SOURCE_DIR}/third_party/libxsmm/lib/ xsmm)
endif((NOT MSVC) AND USE_LIBXSMM)
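With USE_LIBXSMM left at its default of ON, a regular non-MSVC build now compiles libxsmm first (make realclean, then make BLAS=0 inside third_party/libxsmm) and links dgl against the resulting library; passing -DUSE_LIBXSMM=OFF at configure time opts out.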

# Compile TVM Runtime and Featgraph
# (NOTE) We compile a dynamic library called featgraph_runtime, which the DGL library links to.
# Kernels are packed in a separate dynamic library called featgraph_kernels, which DGL
@@ -287,6 +304,7 @@ if(BUILD_CPP_TEST)
include_directories("third_party/xbyak")
include_directories("third_party/dmlc-core/include")
include_directories("third_party/phmap")
include_directories("third_party/libxsmm/include")
file(GLOB_RECURSE TEST_SRC_FILES ${PROJECT_SOURCE_DIR}/tests/cpp/*.cc)
add_executable(runUnitTests ${TEST_SRC_FILES})
target_link_libraries(runUnitTests gtest gtest_main)
181 changes: 143 additions & 38 deletions src/array/cpu/spmm.h
@@ -15,12 +15,90 @@
#if !defined(_WIN32)
#ifdef USE_AVX
#include "intel/cpu_support.h"
#ifdef USE_LIBXSMM
#include "spmm_blocking_libxsmm.h"
#endif // USE_LIBXSMM
#endif // USE_AVX
#endif // _WIN32
namespace dgl {
namespace aten {
namespace cpu {

#if !defined(_WIN32)
#ifdef USE_AVX
/*!
* \brief CPU kernel of SpMM on Csr format using Xbyak.
* \param cpu_spec JIT'ed kernel
* \param bcast Broadcast information.
* \param csr The Csr matrix.
* \param X The feature on source nodes.
* \param W The feature on edges.
* \param O The result feature on destination nodes.
* \note It uses a node-parallel strategy: different threads are responsible
* for different nodes. Each edge is processed with the JIT'ed kernel.
*/
template <typename IdType, typename DType, typename Op>
void SpMMSumCsrXbyak(dgl::ElemWiseAddUpdate<Op>* cpu_spec, const BcastOff& bcast,
const CSRMatrix& csr, const DType* X, const DType* W, DType* O) {
const bool has_idx = !IsNullArray(csr.data);
const IdType* indptr = csr.indptr.Ptr<IdType>();
const IdType* indices = csr.indices.Ptr<IdType>();
const IdType* edges = csr.data.Ptr<IdType>();
int64_t dim = bcast.out_len, lhs_dim = bcast.lhs_len, rhs_dim = bcast.rhs_len;
#pragma omp parallel for
for (IdType rid = 0; rid < csr.num_rows; ++rid) {
const IdType row_start = indptr[rid], row_end = indptr[rid + 1];
DType* out_off = O + rid * dim;
for (IdType j = row_start; j < row_end; ++j) {
const IdType cid = indices[j];
const IdType eid = has_idx ? edges[j] : j;
cpu_spec->run(out_off, X + cid * lhs_dim, W + eid * rhs_dim, dim);
}
}
}
#endif // USE_AVX
#endif // _WIN32
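This helper simply factors the pre-existing xbyak path out of SpMMSumCsr; as before, it is only selected for non-broadcast cases with dim > 16 (see the dispatch in SpMMSumCsr below).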

/*!
* \brief Naive CPU kernel of SpMM on Csr format.
* \param bcast Broadcast information.
* \param csr The Csr matrix.
* \param X The feature on source nodes.
* \param W The feature on edges.
* \param O The result feature on destination nodes.
* \note It uses a node-parallel strategy: different threads are responsible
* for different nodes.
*/
template <typename IdType, typename DType, typename Op>
void SpMMSumCsrNaive(const BcastOff& bcast, const CSRMatrix& csr, const DType* X,
const DType* W, DType* O) {
const bool has_idx = !IsNullArray(csr.data);
const IdType* indptr = csr.indptr.Ptr<IdType>();
const IdType* indices = csr.indices.Ptr<IdType>();
const IdType* edges = csr.data.Ptr<IdType>();
int64_t dim = bcast.out_len, lhs_dim = bcast.lhs_len, rhs_dim = bcast.rhs_len;
#pragma omp parallel for
for (IdType rid = 0; rid < csr.num_rows; ++rid) {
const IdType row_start = indptr[rid], row_end = indptr[rid + 1];
DType* out_off = O + rid * dim;
for (IdType j = row_start; j < row_end; ++j) {
const IdType cid = indices[j];
const IdType eid = has_idx ? edges[j] : j;
for (int64_t k = 0; k < dim; ++k) {
const int64_t lhs_add = bcast.use_bcast ? bcast.lhs_offset[k] : k;
const int64_t rhs_add = bcast.use_bcast ? bcast.rhs_offset[k] : k;
const DType* lhs_off =
Op::use_lhs ? X + cid * lhs_dim + lhs_add : nullptr;
const DType* rhs_off =
Op::use_rhs ? W + eid * rhs_dim + rhs_add : nullptr;
out_off[k] += Op::Call(lhs_off, rhs_off);
}
}
}
}
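For readers new to this file, the node-parallel pattern above is easier to see in a self-contained toy version (plain vectors, sum reducer, u_mul_e message function, no broadcasting). This is a sketch for illustration only, not DGL API code:

// Toy node-parallel CSR SpMM: sum reducer with u_mul_e messages.
// Each OpenMP thread owns whole output rows, so no write conflicts arise.
#include <cstdint>
#include <vector>

void SpmmSumToy(const std::vector<int64_t>& indptr,   // size num_rows + 1
                const std::vector<int64_t>& indices,  // source node per edge
                const std::vector<float>& X,          // [num_src, dim] node feats
                const std::vector<float>& W,          // [num_edges, dim] edge feats
                std::vector<float>* O,                // [num_rows, dim] zero-filled
                int64_t dim) {
  const int64_t num_rows = static_cast<int64_t>(indptr.size()) - 1;
#pragma omp parallel for
  for (int64_t rid = 0; rid < num_rows; ++rid) {
    float* out = O->data() + rid * dim;
    for (int64_t j = indptr[rid]; j < indptr[rid + 1]; ++j) {
      const int64_t cid = indices[j];
      for (int64_t k = 0; k < dim; ++k)
        out[k] += X[cid * dim + k] * W[j * dim + k];
    }
  }
}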

/*!
* \brief CPU kernel of SpMM on Csr format.
* \param bcast Broadcast information.
@@ -42,52 +120,46 @@ void SpMMSumCsr(const BcastOff& bcast, const CSRMatrix& csr, NDArray ufeat,
const DType* W = efeat.Ptr<DType>();
int64_t dim = bcast.out_len, lhs_dim = bcast.lhs_len, rhs_dim = bcast.rhs_len;
DType* O = out.Ptr<DType>();
CHECK_NOTNULL(indptr);
CHECK_NOTNULL(O);
if (Op::use_lhs) {
CHECK_NOTNULL(indices);
CHECK_NOTNULL(X);
}
if (Op::use_rhs) {
if (has_idx)
CHECK_NOTNULL(edges);
CHECK_NOTNULL(W);
}
#if !defined(_WIN32)
#ifdef USE_AVX
typedef dgl::ElemWiseAddUpdate<Op> ElemWiseUpd;
/* Prepare an assembler kernel */
static std::unique_ptr<ElemWiseUpd> asm_kernel_ptr(
(dgl::IntelKernel<>::IsEnabled()) ? new ElemWiseUpd() : nullptr);
/* Distribute the kernel among OMP threads */
ElemWiseUpd* cpu_spec = (asm_kernel_ptr && asm_kernel_ptr->applicable())
? asm_kernel_ptr.get()
: nullptr;
if (cpu_spec && dim > 16 && !bcast.use_bcast) {
#pragma omp parallel for
for (IdType rid = 0; rid < csr.num_rows; ++rid) {
const IdType row_start = indptr[rid], row_end = indptr[rid + 1];
DType* out_off = O + rid * dim;
for (IdType j = row_start; j < row_end; ++j) {
const IdType cid = indices[j];
const IdType eid = has_idx ? edges[j] : j;
cpu_spec->run(out_off, X + cid * lhs_dim, W + eid * rhs_dim, dim);
}
}
#ifdef USE_LIBXSMM
const bool no_libxsmm =
bcast.use_bcast || std::is_same<DType, double>::value;
if (!no_libxsmm) {
SpMMSumCsrLibxsmm<IdType, DType, Op>(bcast, csr, ufeat, efeat, out);
} else {
#endif // USE_LIBXSMM
typedef dgl::ElemWiseAddUpdate<Op> ElemWiseUpd;
/* Prepare an assembler kernel */
static std::unique_ptr<ElemWiseUpd> asm_kernel_ptr(
(dgl::IntelKernel<>::IsEnabled()) ? new ElemWiseUpd() : nullptr);
/* Distribute the kernel among OMP threads */
ElemWiseUpd* cpu_spec = (asm_kernel_ptr && asm_kernel_ptr->applicable())
? asm_kernel_ptr.get()
: nullptr;
if (cpu_spec && dim > 16 && !bcast.use_bcast) {
SpMMSumCsrXbyak<IdType, DType, Op>(cpu_spec, bcast, csr, X, W, O);
} else {
#endif // USE_AVX
#endif // _WIN32

#pragma omp parallel for
for (IdType rid = 0; rid < csr.num_rows; ++rid) {
const IdType row_start = indptr[rid], row_end = indptr[rid + 1];
DType* out_off = O + rid * dim;
for (IdType j = row_start; j < row_end; ++j) {
const IdType cid = indices[j];
const IdType eid = has_idx ? edges[j] : j;
for (int64_t k = 0; k < dim; ++k) {
const int64_t lhs_add = bcast.use_bcast ? bcast.lhs_offset[k] : k;
const int64_t rhs_add = bcast.use_bcast ? bcast.rhs_offset[k] : k;
const DType* lhs_off =
Op::use_lhs ? X + cid * lhs_dim + lhs_add : nullptr;
const DType* rhs_off =
Op::use_rhs ? W + eid * rhs_dim + rhs_add : nullptr;
out_off[k] += Op::Call(lhs_off, rhs_off);
}
}
}
SpMMSumCsrNaive<IdType, DType, Op>(bcast, csr, X, W, O);
#if !defined(_WIN32)
#ifdef USE_AVX
}
#ifdef USE_LIBXSMM
}
#endif // USE_LIBXSMM
#endif // USE_AVX
#endif // _WIN32
}
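Net effect: on non-Windows builds, SpMMSumCsr now dispatches in priority order. It uses (1) the libxsmm blocked kernel when built with USE_LIBXSMM, there is no broadcasting, and DType is not double (libxsmm lacks FP64 support here); (2) the xbyak JIT'ed elementwise kernel when it is applicable, dim > 16, and there is no broadcasting; and (3) the naive OpenMP loop otherwise.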
@@ -172,6 +244,32 @@ void SpMMCmpCsr(const BcastOff& bcast, const CSRMatrix& csr, NDArray ufeat,
DType* O = static_cast<DType*>(out->data);
IdType* argX = Op::use_lhs ? static_cast<IdType*>(argu->data) : nullptr;
IdType* argW = Op::use_rhs ? static_cast<IdType*>(arge->data) : nullptr;
CHECK_NOTNULL(indptr);
CHECK_NOTNULL(O);
if (Op::use_lhs) {
CHECK_NOTNULL(indices);
CHECK_NOTNULL(X);
CHECK_NOTNULL(argX);
}
if (Op::use_rhs) {
if (has_idx)
CHECK_NOTNULL(edges);
CHECK_NOTNULL(W);
CHECK_NOTNULL(argW);
}
#if !defined(_WIN32)
#ifdef USE_AVX
#ifdef USE_LIBXSMM

const bool no_libxsmm =
bcast.use_bcast || std::is_same<DType, double>::value;
if (!no_libxsmm) {
SpMMCmpCsrLibxsmm<IdType, DType, Op, Cmp>(bcast, csr, ufeat, efeat, out, argu, arge);
} else {
#endif // USE_LIBXSMM
#endif // USE_AVX
#endif // _WIN32

#pragma omp parallel for
for (IdType rid = 0; rid < csr.num_rows; ++rid) {
const IdType row_start = indptr[rid], row_end = indptr[rid + 1];
@@ -197,6 +295,13 @@ void SpMMCmpCsr(const BcastOff& bcast, const CSRMatrix& csr, NDArray ufeat,
}
}
}
#if !defined(_WIN32)
#ifdef USE_AVX
#ifdef USE_LIBXSMM
}
#endif // USE_LIBXSMM
#endif // USE_AVX
#endif // _WIN32
}
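The compare-reducer path follows the same row-parallel shape while additionally tracking which source contributed the winning value. Below is a self-contained sketch of that pattern (max reducer over source features, with argmax bookkeeping); the names and types are illustrative, not DGL's:

// Illustrative compare reducer (copy_u + max) with argmax bookkeeping,
// mirroring the row-parallel structure of SpMMCmpCsr.
// Precondition: *O is pre-filled with -infinity and *ArgX with zeros.
#include <cstdint>
#include <vector>

void SpmmMaxToy(const std::vector<int64_t>& indptr,   // size num_rows + 1
                const std::vector<int64_t>& indices,  // source node per edge
                const std::vector<float>& X,          // [num_src, dim] features
                std::vector<float>* O,                // [num_rows, dim] output
                std::vector<int64_t>* ArgX,           // winning source per slot
                int64_t dim) {
  const int64_t num_rows = static_cast<int64_t>(indptr.size()) - 1;
#pragma omp parallel for
  for (int64_t rid = 0; rid < num_rows; ++rid) {
    float* out = O->data() + rid * dim;
    int64_t* arg = ArgX->data() + rid * dim;
    for (int64_t j = indptr[rid]; j < indptr[rid + 1]; ++j) {
      const int64_t cid = indices[j];
      for (int64_t k = 0; k < dim; ++k) {
        const float val = X[cid * dim + k];
        if (val > out[k]) {   // Cmp::Call(out[k], val) in the real kernel
          out[k] = val;
          arg[k] = cid;       // recorded so the backward pass can route grads
        }
      }
    }
  }
}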

