Merge pull request #8203 from leofang/min_max_inf
Fix CUB `min`/`max` initial values
kmaehashi committed Apr 2, 2024
2 parents 8f70ec9 + 83b3937 commit 2f89464
Showing 5 changed files with 206 additions and 71 deletions.
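For context: CUB's DeviceReduce::Min/Max seed their reductions with the type's largest/lowest *finite* value rather than ±infinity, so a reduction over an array containing only infinities can come back finite (e.g. FLT_MAX). The diff below switches such types to the generic DeviceReduce::Reduce with an explicit ±inf initial value. A minimal sketch of the assumed failure mode (a hypothetical standalone program, not taken from the PR):

    #include <cub/cub.cuh>
    #include <cstdio>
    #include <limits>

    int main() {
        const int n = 4;
        const float inf = std::numeric_limits<float>::infinity();
        float h_in[n] = {inf, inf, inf, inf};

        float *d_in, *d_out;
        cudaMalloc(&d_in, n * sizeof(float));
        cudaMalloc(&d_out, sizeof(float));
        cudaMemcpy(d_in, h_in, n * sizeof(float), cudaMemcpyHostToDevice);

        // First call sizes the temporary storage, second call reduces.
        void* d_tmp = nullptr;
        size_t tmp_bytes = 0;
        cub::DeviceReduce::Min(d_tmp, tmp_bytes, d_in, d_out, n);
        cudaMalloc(&d_tmp, tmp_bytes);
        // Min() seeds with the largest *finite* float, and min(FLT_MAX, +inf)
        // is FLT_MAX, so the result is finite instead of the expected +inf.
        cub::DeviceReduce::Min(d_tmp, tmp_bytes, d_in, d_out, n);

        float result;
        cudaMemcpy(&result, d_out, sizeof(float), cudaMemcpyDeviceToHost);
        printf("min = %g (expected inf)\n", result);
        // The commit's fix seeds explicitly instead:
        //   cub::DeviceReduce::Reduce(d_tmp, tmp_bytes, d_in, d_out, n,
        //                             cub::Min(), inf);
        cudaFree(d_in); cudaFree(d_out); cudaFree(d_tmp);
        return 0;
    }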
2 changes: 1 addition & 1 deletion .gitmodules
@@ -1,6 +1,6 @@
[submodule "third_party/cccl"]
	path = third_party/cccl
-	url = https://github.com/NVIDIA/cccl.git
+	url = https://github.com/cupy/cccl.git
[submodule "third_party/jitify"]
	path = third_party/jitify
	url = https://github.com/NVIDIA/jitify.git
147 changes: 137 additions & 10 deletions cupy/cuda/cupy_cub.cu
@@ -28,6 +28,7 @@
// numbers as in general the comparison is ill defined.
// - DO NOT USE THIS STUB for supporting CUB sorting!!!!!!
using namespace cub;
#define CUPY_CUB_NAMESPACE cub

template <>
struct FpLimits<complex<float>>
@@ -56,11 +57,50 @@ struct FpLimits<complex<double>>
template <> struct NumericTraits<complex<float>> : BaseTraits<FLOATING_POINT, true, false, unsigned int, complex<float>> {};
template <> struct NumericTraits<complex<double>> : BaseTraits<FLOATING_POINT, true, false, unsigned long long, complex<double>> {};

// need specializations for initial values
namespace std {

template <>
class numeric_limits<thrust::complex<float>> {
  public:
    static __host__ __device__ thrust::complex<float> infinity() noexcept {
        return thrust::complex<float>(std::numeric_limits<float>::infinity(), std::numeric_limits<float>::infinity());
    }

    static constexpr bool has_infinity = true;
};

template <>
class numeric_limits<thrust::complex<double>> {
  public:
    static __host__ __device__ thrust::complex<double> infinity() noexcept {
        return thrust::complex<double>(std::numeric_limits<double>::infinity(), std::numeric_limits<double>::infinity());
    }

    static constexpr bool has_infinity = true;
};

template <>
class numeric_limits<__half> {
  public:
    static __host__ __device__ constexpr __half infinity() noexcept {
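        // 0x7C00 is the IEEE 754 binary16 bit pattern for +infinity.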
        unsigned short inf_half = 0x7C00U;
        __half inf_value = *reinterpret_cast<__half*>(&inf_half);
        return inf_value;
    }

    static constexpr bool has_infinity = true;
};

} // namespace std


#else

// hipCUB internally uses std::numeric_limits, so we should provide specializations for the complex numbers.
// Note that there's std::complex, so to avoid name collision we must use the full decoration (thrust::complex)!
// TODO(leofang): wrap CuPy's thrust namespace with another one (say, cupy::thrust) for safer scope resolution?
#define CUPY_CUB_NAMESPACE hipcub

namespace std {
template <>
@@ -73,6 +113,12 @@ class numeric_limits<thrust::complex<float>> {
    static __host__ __device__ thrust::complex<float> lowest() noexcept {
        return thrust::complex<float>(-std::numeric_limits<float>::max(), -std::numeric_limits<float>::max());
    }

    static __host__ __device__ thrust::complex<float> infinity() noexcept {
        return thrust::complex<float>(std::numeric_limits<float>::infinity(), std::numeric_limits<float>::infinity());
    }

    static constexpr bool has_infinity = true;
};

template <>
@@ -85,6 +131,12 @@ class numeric_limits<thrust::complex<double>> {
    static __host__ __device__ thrust::complex<double> lowest() noexcept {
        return thrust::complex<double>(-std::numeric_limits<double>::max(), -std::numeric_limits<double>::max());
    }

    static __host__ __device__ thrust::complex<double> infinity() noexcept {
        return thrust::complex<double>(std::numeric_limits<double>::infinity(), std::numeric_limits<double>::infinity());
    }

    static constexpr bool has_infinity = true;
};

// Copied from https://github.com/ROCmSoftwarePlatform/hipCUB/blob/master-rocm-3.5/hipcub/include/hipcub/backend/rocprim/device/device_reduce.hpp
@@ -104,12 +156,27 @@ class numeric_limits<__half> {
        __half lowest_value = *reinterpret_cast<__half*>(&lowest_half);
        return lowest_value;
    }

    static __host__ __device__ __half infinity() noexcept {
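        // 0x7C00 is the IEEE 754 binary16 bit pattern for +infinity.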
        unsigned short inf_half = 0x7C00U;
        __half inf_value = *reinterpret_cast<__half*>(&inf_half);
        return inf_value;
    }

    static constexpr bool has_infinity = true;
};
} // namespace std

using namespace hipcub;

#endif // ifndef CUPY_USE_HIP

__host__ __device__ __half half_negate_inf() {
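    // 0xFC00 is the IEEE 754 binary16 bit pattern for -infinity; it is built
    // from raw bits because __half lacks unary operator- on HIP (see the
    // comments in the max reducers below).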
    unsigned short minf_half = 0xFC00U;
    __half* minf_value = reinterpret_cast<__half*>(&minf_half);
    return *minf_value;
}

/* ------------------------------------ end of boilerplate ------------------------------------ */


@@ -669,8 +736,17 @@ struct _cub_reduce_min {
    void operator()(void* workspace, size_t& workspace_size, void* x, void* y,
                    int num_items, cudaStream_t s)
    {
-        DeviceReduce::Min(workspace, workspace_size, static_cast<T*>(x),
-            static_cast<T*>(y), num_items, s);
+        if constexpr (std::numeric_limits<T>::has_infinity)
+        {
+            DeviceReduce::Reduce(workspace, workspace_size, static_cast<T*>(x),
+                static_cast<T*>(y), num_items,
+                CUPY_CUB_NAMESPACE::Min(), std::numeric_limits<T>::infinity(), s);
+        }
+        else
+        {
+            DeviceReduce::Min(workspace, workspace_size, static_cast<T*>(x),
+                static_cast<T*>(y), num_items, s);
+        }
    }
};

@@ -679,9 +755,19 @@ struct _cub_segmented_reduce_min {
    void operator()(void* workspace, size_t& workspace_size, void* x, void* y,
                    int num_segments, seg_offset_itr offset_start, cudaStream_t s)
    {
-        DeviceSegmentedReduce::Min(workspace, workspace_size,
-            static_cast<T*>(x), static_cast<T*>(y), num_segments,
-            offset_start, offset_start+1, s);
+        if constexpr (std::numeric_limits<T>::has_infinity)
+        {
+            DeviceSegmentedReduce::Reduce(workspace, workspace_size,
+                static_cast<T*>(x), static_cast<T*>(y), num_segments,
+                offset_start, offset_start+1,
+                CUPY_CUB_NAMESPACE::Min(), std::numeric_limits<T>::infinity(), s);
+        }
+        else
+        {
+            DeviceSegmentedReduce::Min(workspace, workspace_size,
+                static_cast<T*>(x), static_cast<T*>(y), num_segments,
+                offset_start, offset_start+1, s);
+        }
    }
};

@@ -693,8 +779,28 @@ struct _cub_reduce_max {
    void operator()(void* workspace, size_t& workspace_size, void* x, void* y,
                    int num_items, cudaStream_t s)
    {
-        DeviceReduce::Max(workspace, workspace_size, static_cast<T*>(x),
-            static_cast<T*>(y), num_items, s);
+        if constexpr (std::numeric_limits<T>::has_infinity)
+        {
+            // to avoid compiler error: invalid argument type '__half' to unary expression on HIP...
+            if constexpr (std::is_same_v<T, __half>)
+            {
+                DeviceReduce::Reduce(workspace, workspace_size, static_cast<T*>(x),
+                    static_cast<T*>(y), num_items,
+                    CUPY_CUB_NAMESPACE::Max(), half_negate_inf(), s);
+            }
+            else
+            {
+                DeviceReduce::Reduce(workspace, workspace_size, static_cast<T*>(x),
+                    static_cast<T*>(y), num_items,
+                    CUPY_CUB_NAMESPACE::Max(), -std::numeric_limits<T>::infinity(), s);
+            }
+        }
+        else
+        {
+            DeviceReduce::Max(workspace, workspace_size, static_cast<T*>(x),
+                static_cast<T*>(y), num_items, s);
+        }
    }
};

@@ -703,9 +809,30 @@ struct _cub_segmented_reduce_max {
    void operator()(void* workspace, size_t& workspace_size, void* x, void* y,
                    int num_segments, seg_offset_itr offset_start, cudaStream_t s)
    {
-        DeviceSegmentedReduce::Max(workspace, workspace_size,
-            static_cast<T*>(x), static_cast<T*>(y), num_segments,
-            offset_start, offset_start+1, s);
+        if constexpr (std::numeric_limits<T>::has_infinity)
+        {
+            // to avoid compiler error: invalid argument type '__half' to unary expression on HIP...
+            if constexpr (std::is_same_v<T, __half>)
+            {
+                DeviceSegmentedReduce::Reduce(workspace, workspace_size,
+                    static_cast<T*>(x), static_cast<T*>(y), num_segments,
+                    offset_start, offset_start+1,
+                    CUPY_CUB_NAMESPACE::Max(), half_negate_inf(), s);
+            }
+            else
+            {
+                DeviceSegmentedReduce::Reduce(workspace, workspace_size,
+                    static_cast<T*>(x), static_cast<T*>(y), num_segments,
+                    offset_start, offset_start+1,
+                    CUPY_CUB_NAMESPACE::Max(), -std::numeric_limits<T>::infinity(), s);
+            }
+        }
+        else
+        {
+            DeviceSegmentedReduce::Max(workspace, workspace_size,
+                static_cast<T*>(x), static_cast<T*>(y), num_segments,
+                offset_start, offset_start+1, s);
+        }
    }
};

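Why the separate __half branches above: on HIP, __half provides no unary operator-, so spelling the initial value as -std::numeric_limits<__half>::infinity() fails with the "invalid argument type" error quoted in the comments; half_negate_inf() builds -inf from its bit pattern instead. A sketch of the same bit-level construction, using a hypothetical helper name and memcpy rather than the commit's reinterpret_cast:

    #include <cstring>      // memcpy
    #include <cuda_fp16.h>  // __half

    // Hypothetical helper (not part of the commit): build a binary16 value
    // from raw bits. For -infinity the pattern is sign=1, exponent=11111,
    // mantissa=0, i.e. 0xFC00 -- the constant used by half_negate_inf().
    __host__ __device__ inline __half half_from_bits(unsigned short bits) {
        __half h;
        memcpy(&h, &bits, sizeof(h));  // bit-cast without aliasing concerns
        return h;
    }

    // usage: const __half neg_inf = half_from_bits(0xFC00U);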
62 changes: 21 additions & 41 deletions install/cupy_builder/_compiler.py
@@ -133,7 +133,7 @@ def _nvcc_gencode_options(cuda_version: int) -> List[str]:
            ('compute_72', 'sm_72'),  # Jetson (Xavier)
            ('compute_87', 'sm_87'),  # Jetson (Orin)
        ]
-    elif cuda_version >= 11010:
+    elif cuda_version >= 11020:
        arch_list = ['compute_35',
                     'compute_50',
                     ('compute_60', 'sm_60'),
@@ -143,23 +143,6 @@
                     ('compute_80', 'sm_80'),
                     ('compute_86', 'sm_86'),
                     'compute_86']
-    elif cuda_version >= 11000:
-        arch_list = ['compute_35',
-                     'compute_50',
-                     ('compute_60', 'sm_60'),
-                     ('compute_61', 'sm_61'),
-                     ('compute_70', 'sm_70'),
-                     ('compute_75', 'sm_75'),
-                     ('compute_80', 'sm_80'),
-                     'compute_80']
-    elif cuda_version >= 10000:
-        arch_list = ['compute_30',
-                     'compute_50',
-                     ('compute_60', 'sm_60'),
-                     ('compute_61', 'sm_61'),
-                     ('compute_70', 'sm_70'),
-                     ('compute_75', 'sm_75'),
-                     'compute_70']
    else:
        # This should not happen.
        assert False
@@ -215,14 +198,14 @@ def _compile_unix_nvcc(self, obj: str, src: str, ext: Extension) -> None:

        cuda_version = self._context.features['cuda'].get_version()
        postargs = _nvcc_gencode_options(cuda_version) + [
-            '-Xfatbin=-compress-all', '-O2', '--compiler-options="-fPIC"']
-        if cuda_version >= 11020:
-            postargs += ['--std=c++14']
-            num_threads = int(os.environ.get('CUPY_NUM_NVCC_THREADS', '2'))
-            postargs += [f'-t{num_threads}']
-        else:
-            postargs += ['--std=c++11']
-        postargs += ['-Xcompiler=-fno-gnu-unique']
+            '-Xfatbin=-compress-all', '-O2', '--compiler-options="-fPIC"',
+            '--expt-relaxed-constexpr']
+        num_threads = int(os.environ.get('CUPY_NUM_NVCC_THREADS', '2'))
+        # Note: we only support CUDA 11.2+ since CuPy v13.0.0.
+        # Bumping C++ standard from C++14 to C++17 for "if constexpr"
+        postargs += ['--std=c++17',
+                     f'-t{num_threads}',
+                     '-Xcompiler=-fno-gnu-unique']
        print('NVCC options:', postargs)
        self.spawn(compiler_so + base_opts + cc_args + [src, '-o', obj] +
                   postargs)
@@ -235,12 +218,10 @@ def _compile_unix_hipcc(self, obj: str, src: str, ext: Extension) -> None:
        base_opts = build.get_compiler_base_options(rocm_path)
        compiler_so = rocm_path

-        hip_version = build.get_hip_version()
        postargs = ['-O2', '-fPIC', '--include', 'hip_runtime.h']
-        if hip_version >= 402:
-            postargs += ['--std=c++14']
-        else:
-            postargs += ['--std=c++11']
+        # Note: we only support ROCm 4.3+ since CuPy v11.0.0.
+        # Bumping C++ standard from C++14 to C++17 for "if constexpr"
+        postargs += ['--std=c++17']
        print('HIPCC options:', postargs)
        self.spawn(compiler_so + base_opts + cc_args + [src, '-o', obj] +
                   postargs)
@@ -257,17 +238,16 @@ def compile(self, obj: str, src: str, ext: Extension) -> None:
        cuda_version = self._context.features['cuda'].get_version()
        postargs = _nvcc_gencode_options(cuda_version) + [
            '-Xfatbin=-compress-all', '-O2']
-        if cuda_version >= 11020:
-            # MSVC 14.0 (2015) is deprecated for CUDA 11.2 but we need it
-            # to build CuPy because some Python versions were built using it.
-            # REF: https://wiki.python.org/moin/WindowsCompilers
-            postargs += ['-allow-unsupported-compiler']
+        # Note: we only support CUDA 11.2+ since CuPy v13.0.0.
+        # MSVC 14.0 (2015) is deprecated for CUDA 11.2 but we need it
+        # to build CuPy because some Python versions were built using it.
+        # REF: https://wiki.python.org/moin/WindowsCompilers
+        postargs += ['-allow-unsupported-compiler']
        postargs += ['-Xcompiler', '/MD', '-D_USE_MATH_DEFINES']
-        # This is to compile thrust with MSVC2015
-        if cuda_version >= 11020:
-            postargs += ['--std=c++14']
-            num_threads = int(os.environ.get('CUPY_NUM_NVCC_THREADS', '2'))
-            postargs += [f'-t{num_threads}']
+        # Bumping C++ standard from C++14 to C++17 for "if constexpr"
+        num_threads = int(os.environ.get('CUPY_NUM_NVCC_THREADS', '2'))
+        postargs += ['--std=c++17',
+                     f'-t{num_threads}']
        cl_exe_path = self._find_host_compiler_path()
        if cl_exe_path is not None:
            print(f'Using host compiler at {cl_exe_path}')
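All three compile paths above move to --std=c++17 because the new reduction dispatch in cupy_cub.cu relies on "if constexpr", a C++17 feature. A minimal sketch (names hypothetical) of why a runtime "if" would not do:

    #include <limits>
    #include <type_traits>

    struct NoNegate {};  // stand-in for HIP's __half, which lacks unary operator-

    template <typename T>
    T negative_infinity_seed() {
        // if constexpr discards the untaken branch at compile time, so the
        // "-infinity()" expression below is never instantiated for NoNegate.
        // A runtime "if" would type-check both branches for every T and fail.
        if constexpr (std::is_same_v<T, NoNegate>) {
            return NoNegate{};  // built some other way, cf. half_negate_inf()
        } else {
            return -std::numeric_limits<T>::infinity();
        }
    }

    // usage: float seed = negative_infinity_seed<float>();  // -inf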
