Skip to content

Commit

Permalink
[GPU] Add support for CUDA-based GPU build (#3160)
Browse files Browse the repository at this point in the history
* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* Initial CUDA work

* redirect log to python console (#3090)

* redir log to python console

* fix pylint

* Apply suggestions from code review

* Update basic.py

* Apply suggestions from code review

Co-authored-by: Nikita Titov <nekit94-08@mail.ru>

* Update c_api.h

* Apply suggestions from code review

* Apply suggestions from code review

* super-minor: better wording

Co-authored-by: Nikita Titov <nekit94-08@mail.ru>
Co-authored-by: StrikerRUS <nekit94-12@hotmail.com>

* re-order includes (fixes #3132) (#3133)

* Revert "re-order includes (fixes #3132) (#3133)" (#3153)

This reverts commit 656d267.

* Missing change from previous rebase

* Minor cleanup and removal of development scripts.

* Only set gpu_use_dp on by default for CUDA. Other minor change.

* Fix python lint indentation problem.

* More python lint issues.

* Big lint cleanup - more to come.

* Another large lint cleanup - more to come.

* Even more lint cleanup.

* Minor cleanup so less differences in code.

* Revert is_use_subset changes

* Another rebase from master to fix recent conflicts.

* More lint.

* Simple code cleanup - add & remove blank lines, revert unneccessary format changes, remove added dead code.

* Removed parameters added for CUDA and various bug fix.

* Yet more lint and unneccessary changes.

* Revert another change.

* Removal of unneccessary code.

* temporary appveyor.yml for building and testing

* Remove return value in ReSize

* Removal of unused variables.

* Code cleanup from reviewers suggestions.

* Removal of FIXME comments and unused defines.

* More reviewers comments cleanup.

* More reviewers comments cleanup.

* More reviewers comments cleanup.

* Fix config variables.

* Attempt to fix check-docs failure

* Update Paramster.rst for num_gpu

* Removing test appveyor.yml

* Add �CUDA_RESOLVE_DEVICE_SYMBOLS to libraries to fix linking issue.

* Fixed handling of data elements less than 2K.

* More reviewers comments cleanup.

* Removal of TODO and fix printing of int64_t

* Add cuda change for CI testing and remove cuda from device_type in python.

* Missed one change form previous check-in

* Removal AdditionConfig and fix settings.

* Limit number of GPUs to one for now in CUDA.

* Update Parameters.rst for previous check-in

* Whitespace removal.

* Cleanup unused code.

* Changed uint/ushort/ulong to unsigned int/short/long to help Windows based CUDA compiler work.

* Lint change from previous check-in.

* Changes based on reviewers comments.

* More reviewer comment changes.

* Adding warning for is_sparse. Revert tmp_subset code. Only return FeatureGroupData if not is_multi_val_

* Fix so that CUDA code will compile even if you enable the SCORE_T_USE_DOUBLE define.

* Reviewer comment cleanup.

* Replace warning with Log message. Removal of some of the USE_CUDA. Fix typo and removal of pragma once.

* Remove PRINT debug for CUDA code.

* Allow to use of multiple GPUs for CUDA.

* More multi-GPUs enablement for CUDA.

* More code cleanup based on reviews comments.

* Update docs with latest config changes.

Co-authored-by: Gordon Fossum <fossum@us.ibm.com>
Co-authored-by: ChipKerchner <ckerchne@linux.vnet.ibm.com>
Co-authored-by: Guolin Ke <guolin.ke@outlook.com>
Co-authored-by: Nikita Titov <nekit94-08@mail.ru>
Co-authored-by: StrikerRUS <nekit94-12@hotmail.com>
Co-authored-by: James Lamb <jaylamb20@gmail.com>
  • Loading branch information
7 people committed Sep 20, 2020
1 parent 1fddabb commit f7ad945
Show file tree
Hide file tree
Showing 33 changed files with 2,944 additions and 14 deletions.
10 changes: 10 additions & 0 deletions .ci/test.sh
Expand Up @@ -128,6 +128,16 @@ if [[ $TASK == "gpu" ]]; then
exit 0
fi
cmake -DUSE_GPU=ON -DOpenCL_INCLUDE_DIR=$AMDAPPSDK_PATH/include/ ..
elif [[ $TASK == "cuda" ]]; then
sed -i'.bak' 's/std::string device_type = "cpu";/std::string device_type = "cuda";/' $BUILD_DIRECTORY/include/LightGBM/config.h
grep -q 'std::string device_type = "cuda"' $BUILD_DIRECTORY/include/LightGBM/config.h || exit -1 # make sure that changes were really done
if [[ $METHOD == "pip" ]]; then
cd $BUILD_DIRECTORY/python-package && python setup.py sdist || exit -1
pip install --user $BUILD_DIRECTORY/python-package/dist/lightgbm-$LGB_VER.tar.gz -v --install-option=--cuda || exit -1
pytest $BUILD_DIRECTORY/tests/python_package_test || exit -1
exit 0
fi
cmake -DUSE_CUDA=ON ..
elif [[ $TASK == "mpi" ]]; then
if [[ $METHOD == "pip" ]]; then
cd $BUILD_DIRECTORY/python-package && python setup.py sdist || exit -1
Expand Down
90 changes: 89 additions & 1 deletion CMakeLists.txt
@@ -1,17 +1,24 @@
if(USE_GPU OR APPLE)
cmake_minimum_required(VERSION 3.2)
elseif(USE_CUDA)
cmake_minimum_required(VERSION 3.16)
else()
cmake_minimum_required(VERSION 2.8)
endif()

PROJECT(lightgbm)
if(USE_CUDA)
PROJECT(lightgbm LANGUAGES C CXX CUDA)
else()
PROJECT(lightgbm LANGUAGES C CXX)
endif()

OPTION(USE_MPI "Enable MPI-based parallel learning" OFF)
OPTION(USE_OPENMP "Enable OpenMP" ON)
OPTION(USE_GPU "Enable GPU-accelerated training" OFF)
OPTION(USE_SWIG "Enable SWIG to generate Java API" OFF)
OPTION(USE_HDFS "Enable HDFS support (EXPERIMENTAL)" OFF)
OPTION(USE_TIMETAG "Set to ON to output time costs" OFF)
OPTION(USE_CUDA "Enable CUDA-accelerated training (EXPERIMENTAL)" OFF)
OPTION(USE_DEBUG "Set to ON for Debug mode" OFF)
OPTION(BUILD_STATIC_LIB "Build static library" OFF)
OPTION(__BUILD_FOR_R "Set to ON if building lib_lightgbm for use with the R package" OFF)
Expand Down Expand Up @@ -94,6 +101,10 @@ else()
ADD_DEFINITIONS(-DUSE_SOCKET)
endif(USE_MPI)

if(USE_CUDA)
SET(USE_OPENMP ON CACHE BOOL "CUDA requires OpenMP" FORCE)
endif(USE_CUDA)

if(USE_OPENMP)
find_package(OpenMP REQUIRED)
SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}")
Expand Down Expand Up @@ -123,6 +134,67 @@ if(USE_GPU)
ADD_DEFINITIONS(-DUSE_GPU)
endif(USE_GPU)

if(USE_CUDA)
find_package(CUDA REQUIRED)
include_directories(${CUDA_INCLUDE_DIRS})
LIST(APPEND CMAKE_CUDA_FLAGS -Xcompiler=${OpenMP_CXX_FLAGS} -Xcompiler=-fPIC -Xcompiler=-Wall)
CUDA_SELECT_NVCC_ARCH_FLAGS(CUDA_ARCH_FLAGS 6.0 6.1 6.2 7.0 7.5+PTX)

LIST(APPEND CMAKE_CUDA_FLAGS ${CUDA_ARCH_FLAGS})
if(USE_DEBUG)
SET(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -g")
else()
SET(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -O3 -lineinfo")
endif()
string(REPLACE ";" " " CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS}")
message(STATUS "CMAKE_CUDA_FLAGS: ${CMAKE_CUDA_FLAGS}")

ADD_DEFINITIONS(-DUSE_CUDA)
if (NOT DEFINED CMAKE_CUDA_STANDARD)
set(CMAKE_CUDA_STANDARD 11)
set(CMAKE_CUDA_STANDARD_REQUIRED ON)
endif()

set(BASE_DEFINES
-DPOWER_FEATURE_WORKGROUPS=12
-DUSE_CONSTANT_BUF=0
)
set(ALLFEATS_DEFINES
${BASE_DEFINES}
-DENABLE_ALL_FEATURES
)
set(FULLDATA_DEFINES
${ALLFEATS_DEFINES}
-DIGNORE_INDICES
)

message(STATUS "ALLFEATS_DEFINES: ${ALLFEATS_DEFINES}")
message(STATUS "FULLDATA_DEFINES: ${FULLDATA_DEFINES}")

function(add_histogram hsize hname hadd hconst hdir)
add_library(histo${hsize}${hname} OBJECT src/treelearner/kernels/histogram${hsize}.cu)
set_target_properties(histo${hsize}${hname} PROPERTIES CUDA_SEPARABLE_COMPILATION ON)
if(hadd)
list(APPEND histograms histo${hsize}${hname})
set(histograms ${histograms} PARENT_SCOPE)
endif()
target_compile_definitions(
histo${hsize}${hname} PRIVATE
-DCONST_HESSIAN=${hconst}
${hdir}
)
endfunction()

foreach (hsize _16_64_256)
add_histogram("${hsize}" "_sp_const" "True" "1" "${BASE_DEFINES}")
add_histogram("${hsize}" "_sp" "True" "0" "${BASE_DEFINES}")
add_histogram("${hsize}" "-allfeats_sp_const" "False" "1" "${ALLFEATS_DEFINES}")
add_histogram("${hsize}" "-allfeats_sp" "False" "0" "${ALLFEATS_DEFINES}")
add_histogram("${hsize}" "-fulldata_sp_const" "True" "1" "${FULLDATA_DEFINES}")
add_histogram("${hsize}" "-fulldata_sp" "True" "0" "${FULLDATA_DEFINES}")
endforeach()
endif(USE_CUDA)

if(USE_HDFS)
find_package(JNI REQUIRED)
find_path(HDFS_INCLUDE_DIR hdfs.h REQUIRED)
Expand Down Expand Up @@ -228,6 +300,9 @@ file(GLOB SOURCES
src/objective/*.cpp
src/network/*.cpp
src/treelearner/*.cpp
if(USE_CUDA)
src/treelearner/*.cu
endif(USE_CUDA)
)

add_executable(lightgbm src/main.cpp ${SOURCES})
Expand Down Expand Up @@ -303,6 +378,19 @@ if(USE_GPU)
TARGET_LINK_LIBRARIES(_lightgbm ${OpenCL_LIBRARY} ${Boost_LIBRARIES})
endif(USE_GPU)

if(USE_CUDA)
set_target_properties(lightgbm PROPERTIES CUDA_RESOLVE_DEVICE_SYMBOLS ON)
TARGET_LINK_LIBRARIES(
lightgbm
${histograms}
)
set_target_properties(_lightgbm PROPERTIES CUDA_RESOLVE_DEVICE_SYMBOLS ON)
TARGET_LINK_LIBRARIES(
_lightgbm
${histograms}
)
endif(USE_CUDA)

if(USE_HDFS)
TARGET_LINK_LIBRARIES(lightgbm ${HDFS_CXX_LIBRARIES})
TARGET_LINK_LIBRARIES(_lightgbm ${HDFS_CXX_LIBRARIES})
Expand Down
8 changes: 7 additions & 1 deletion docs/Parameters.rst
Expand Up @@ -1120,7 +1120,13 @@ GPU Parameters

- ``gpu_use_dp`` :raw-html:`<a id="gpu_use_dp" title="Permalink to this parameter" href="#gpu_use_dp">&#x1F517;&#xFE0E;</a>`, default = ``false``, type = bool

- set this to ``true`` to use double precision math on GPU (by default single precision is used)
- set this to ``true`` to use double precision math on GPU (by default single precision is used in OpenCL implementation and double precision is used in CUDA implementation)

- ``num_gpu`` :raw-html:`<a id="num_gpu" title="Permalink to this parameter" href="#num_gpu">&#x1F517;&#xFE0E;</a>`, default = ``1``, type = int, constraints: ``num_gpu > 0``

- number of GPUs

- **Note**: can be used only in CUDA implementation

.. end params list
Expand Down
3 changes: 3 additions & 0 deletions include/LightGBM/bin.h
Expand Up @@ -288,6 +288,9 @@ class Bin {
/*! \brief Number of all data */
virtual data_size_t num_data() const = 0;

/*! \brief Get data pointer */
virtual void* get_data() = 0;

virtual void ReSize(data_size_t num_data) = 0;

/*!
Expand Down
7 changes: 6 additions & 1 deletion include/LightGBM/config.h
Expand Up @@ -965,9 +965,14 @@ struct Config {
// desc = **Note**: refer to `GPU Targets <./GPU-Targets.rst#query-opencl-devices-in-your-system>`__ for more details
int gpu_device_id = -1;

// desc = set this to ``true`` to use double precision math on GPU (by default single precision is used)
// desc = set this to ``true`` to use double precision math on GPU (by default single precision is used in OpenCL implementation and double precision is used in CUDA implementation)
bool gpu_use_dp = false;

// check = >0
// desc = number of GPUs
// desc = **Note**: can be used only in CUDA implementation
int num_gpu = 1;

#pragma endregion

#pragma endregion
Expand Down
24 changes: 24 additions & 0 deletions include/LightGBM/cuda/cuda_utils.h
@@ -0,0 +1,24 @@
/*!
* Copyright (c) 2020 IBM Corporation. All rights reserved.
* Licensed under the MIT License. See LICENSE file in the project root for license information.
*/
#ifndef LIGHTGBM_CUDA_CUDA_UTILS_H_
#define LIGHTGBM_CUDA_CUDA_UTILS_H_

#ifdef USE_CUDA

#include <cuda.h>
#include <cuda_runtime.h>
#include <stdio.h>

#define CUDASUCCESS_OR_FATAL(ans) { gpuAssert((ans), __FILE__, __LINE__); }
inline void gpuAssert(cudaError_t code, const char *file, int line, bool abort = true) {
if (code != cudaSuccess) {
LightGBM::Log::Fatal("[CUDA] %s %s %d\n", cudaGetErrorString(code), file, line);
if (abort) exit(code);
}
}

#endif // USE_CUDA

#endif // LIGHTGBM_CUDA_CUDA_UTILS_H_
86 changes: 86 additions & 0 deletions include/LightGBM/cuda/vector_cudahost.h
@@ -0,0 +1,86 @@
/*!
* Copyright (c) 2020 IBM Corporation. All rights reserved.
* Licensed under the MIT License. See LICENSE file in the project root for license information.
*/
#ifndef LIGHTGBM_CUDA_VECTOR_CUDAHOST_H_
#define LIGHTGBM_CUDA_VECTOR_CUDAHOST_H_

#include <LightGBM/utils/common.h>

#ifdef USE_CUDA
#include <cuda.h>
#include <cuda_runtime.h>
#endif
#include <stdio.h>

enum LGBM_Device {
lgbm_device_cpu,
lgbm_device_gpu,
lgbm_device_cuda
};

enum Use_Learner {
use_cpu_learner,
use_gpu_learner,
use_cuda_learner
};

namespace LightGBM {

class LGBM_config_ {
public:
static int current_device; // Default: lgbm_device_cpu
static int current_learner; // Default: use_cpu_learner
};


template <class T>
struct CHAllocator {
typedef T value_type;
CHAllocator() {}
template <class U> CHAllocator(const CHAllocator<U>& other);
T* allocate(std::size_t n) {
T* ptr;
if (n == 0) return NULL;
#ifdef USE_CUDA
if (LGBM_config_::current_device == lgbm_device_cuda) {
cudaError_t ret = cudaHostAlloc(&ptr, n*sizeof(T), cudaHostAllocPortable);
if (ret != cudaSuccess) {
Log::Warning("Defaulting to malloc in CHAllocator!!!");
ptr = reinterpret_cast<T*>(_mm_malloc(n*sizeof(T), 16));
}
} else {
ptr = reinterpret_cast<T*>(_mm_malloc(n*sizeof(T), 16));
}
#else
ptr = reinterpret_cast<T*>(_mm_malloc(n*sizeof(T), 16));
#endif
return ptr;
}

void deallocate(T* p, std::size_t n) {
(void)n; // UNUSED
if (p == NULL) return;
#ifdef USE_CUDA
if (LGBM_config_::current_device == lgbm_device_cuda) {
cudaPointerAttributes attributes;
cudaPointerGetAttributes(&attributes, p);
if ((attributes.type == cudaMemoryTypeHost) && (attributes.devicePointer != NULL)) {
cudaFreeHost(p);
}
} else {
_mm_free(p);
}
#else
_mm_free(p);
#endif
}
};
template <class T, class U>
bool operator==(const CHAllocator<T>&, const CHAllocator<U>&);
template <class T, class U>
bool operator!=(const CHAllocator<T>&, const CHAllocator<U>&);

} // namespace LightGBM

#endif // LIGHTGBM_CUDA_VECTOR_CUDAHOST_H_
8 changes: 8 additions & 0 deletions include/LightGBM/dataset.h
Expand Up @@ -589,6 +589,14 @@ class Dataset {
return feature_groups_[i]->is_multi_val_;
}

inline size_t FeatureGroupSizesInByte(int group) const {
return feature_groups_[group]->FeatureGroupSizesInByte();
}

inline void* FeatureGroupData(int group) const {
return feature_groups_[group]->FeatureGroupData();
}

inline double RealThreshold(int i, uint32_t threshold) const {
const int group = feature2group_[i];
const int sub_feature = feature2subfeature_[i];
Expand Down
11 changes: 11 additions & 0 deletions include/LightGBM/feature_group.h
Expand Up @@ -228,6 +228,17 @@ class FeatureGroup {
return bin_data_->GetIterator(min_bin, max_bin, most_freq_bin);
}

inline size_t FeatureGroupSizesInByte() {
return bin_data_->SizesInByte();
}

inline void* FeatureGroupData() {
if (is_multi_val_) {
return nullptr;
}
return bin_data_->get_data();
}

inline data_size_t Split(int sub_feature, const uint32_t* threshold,
int num_threshold, bool default_left,
const data_size_t* data_indices, data_size_t cnt,
Expand Down
8 changes: 6 additions & 2 deletions python-package/setup.py
Expand Up @@ -87,7 +87,7 @@ def silent_call(cmd, raise_error=False, error_msg=''):
return 1


def compile_cpp(use_mingw=False, use_gpu=False, use_mpi=False,
def compile_cpp(use_mingw=False, use_gpu=False, use_cuda=False, use_mpi=False,
use_hdfs=False, boost_root=None, boost_dir=None,
boost_include_dir=None, boost_librarydir=None,
opencl_include_dir=None, opencl_library=None,
Expand Down Expand Up @@ -115,6 +115,8 @@ def compile_cpp(use_mingw=False, use_gpu=False, use_mpi=False,
cmake_cmd.append("-DOpenCL_INCLUDE_DIR={0}".format(opencl_include_dir))
if opencl_library:
cmake_cmd.append("-DOpenCL_LIBRARY={0}".format(opencl_library))
elif use_cuda:
cmake_cmd.append("-DUSE_CUDA=ON")
if use_mpi:
cmake_cmd.append("-DUSE_MPI=ON")
if nomp:
Expand Down Expand Up @@ -188,6 +190,7 @@ class CustomInstall(install):
user_options = install.user_options + [
('mingw', 'm', 'Compile with MinGW'),
('gpu', 'g', 'Compile GPU version'),
('cuda', None, 'Compile CUDA version'),
('mpi', None, 'Compile MPI version'),
('nomp', None, 'Compile version without OpenMP support'),
('hdfs', 'h', 'Compile HDFS version'),
Expand All @@ -205,6 +208,7 @@ def initialize_options(self):
install.initialize_options(self)
self.mingw = 0
self.gpu = 0
self.cuda = 0
self.boost_root = None
self.boost_dir = None
self.boost_include_dir = None
Expand All @@ -228,7 +232,7 @@ def run(self):
open(LOG_PATH, 'wb').close()
if not self.precompile:
copy_files(use_gpu=self.gpu)
compile_cpp(use_mingw=self.mingw, use_gpu=self.gpu, use_mpi=self.mpi,
compile_cpp(use_mingw=self.mingw, use_gpu=self.gpu, use_cuda=self.cuda, use_mpi=self.mpi,
use_hdfs=self.hdfs, boost_root=self.boost_root, boost_dir=self.boost_dir,
boost_include_dir=self.boost_include_dir, boost_librarydir=self.boost_librarydir,
opencl_include_dir=self.opencl_include_dir, opencl_library=self.opencl_library,
Expand Down
5 changes: 5 additions & 0 deletions src/application/application.cpp
Expand Up @@ -11,6 +11,7 @@
#include <LightGBM/network.h>
#include <LightGBM/objective_function.h>
#include <LightGBM/prediction_early_stop.h>
#include <LightGBM/cuda/vector_cudahost.h>
#include <LightGBM/utils/common.h>
#include <LightGBM/utils/openmp_wrapper.h>
#include <LightGBM/utils/text_reader.h>
Expand Down Expand Up @@ -38,6 +39,10 @@ Application::Application(int argc, char** argv) {
if (config_.data.size() == 0 && config_.task != TaskType::kConvertModel) {
Log::Fatal("No training/prediction data, application quit");
}

if (config_.device_type == std::string("cuda")) {
LGBM_config_::current_device = lgbm_device_cuda;
}
}

Application::~Application() {
Expand Down

0 comments on commit f7ad945

Please sign in to comment.