Skip to content

Commit

Permalink
fix CI error
Browse files Browse the repository at this point in the history
  • Loading branch information
leofang committed Feb 22, 2024
1 parent ff9f09b commit c7667a7
Showing 1 changed file with 27 additions and 4 deletions.
31 changes: 27 additions & 4 deletions cupy/cuda/cupy_cub.cu
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
// numbers as in general the comparison is ill defined.
// - DO NOT USE THIS STUB for supporting CUB sorting!!!!!!
using namespace cub;
#define CUPY_CUB_NAMESPACE cub

template <>
struct FpLimits<complex<float>>
Expand Down Expand Up @@ -99,6 +100,7 @@ class numeric_limits<__half> {
// hipCUB internally uses std::numeric_limits, so we should provide specializations for the complex numbers.
// Note that there's std::complex, so to avoid name collision we must use the full decoration (thrust::complex)!
// TODO(leofang): wrap CuPy's thrust namespace with another one (say, cupy::thrust) for safer scope resolution?
#define CUPY_CUB_NAMESPACE hipcub

namespace std {
template <>
Expand All @@ -111,6 +113,12 @@ class numeric_limits<thrust::complex<float>> {
static __host__ __device__ thrust::complex<float> lowest() noexcept {
return thrust::complex<float>(-std::numeric_limits<float>::max(), -std::numeric_limits<float>::max());
}

static __host__ __device__ thrust::complex<float> infinity() noexcept {
return thrust::complex<float>(std::numeric_limits<float>::infinity(), std::numeric_limits<float>::infinity());
}

static constexpr bool has_infinity = true;
};

template <>
Expand All @@ -123,6 +131,12 @@ class numeric_limits<thrust::complex<double>> {
static __host__ __device__ thrust::complex<double> lowest() noexcept {
return thrust::complex<double>(-std::numeric_limits<double>::max(), -std::numeric_limits<double>::max());
}

static __host__ __device__ thrust::complex<double> infinity() noexcept {
return thrust::complex<double>(std::numeric_limits<double>::infinity(), std::numeric_limits<double>::infinity());
}

static constexpr bool has_infinity = true;
};

// Copied from https://github.com/ROCmSoftwarePlatform/hipCUB/blob/master-rocm-3.5/hipcub/include/hipcub/backend/rocprim/device/device_reduce.hpp
Expand All @@ -142,12 +156,21 @@ class numeric_limits<__half> {
__half lowest_value = *reinterpret_cast<__half*>(&lowest_half);
return lowest_value;
}

static __host__ __device__ __half infinity() noexcept {
unsigned short inf_half = 0x7C00U;
__half inf_value = *reinterpret_cast<__half*>(&inf_half);
return inf_value;
}

static constexpr bool has_infinity = true;
};
} // namespace std

using namespace hipcub;

#endif // ifndef CUPY_USE_HIP

/* ------------------------------------ end of boilerplate ------------------------------------ */


Expand Down Expand Up @@ -711,7 +734,7 @@ struct _cub_reduce_min {
{
DeviceReduce::Reduce(workspace, workspace_size, static_cast<T*>(x),
static_cast<T*>(y), num_items,
cub::Min(), std::numeric_limits<T>::infinity(), s);
CUPY_CUB_NAMESPACE::Min(), std::numeric_limits<T>::infinity(), s);
}
else
{
Expand All @@ -731,7 +754,7 @@ struct _cub_segmented_reduce_min {
DeviceSegmentedReduce::Reduce(workspace, workspace_size,
static_cast<T*>(x), static_cast<T*>(y), num_segments,
offset_start, offset_start+1,
cub::Min(), std::numeric_limits<T>::infinity(), s);
CUPY_CUB_NAMESPACE::Min(), std::numeric_limits<T>::infinity(), s);
}
else
{
Expand All @@ -754,7 +777,7 @@ struct _cub_reduce_max {
{
DeviceReduce::Reduce(workspace, workspace_size, static_cast<T*>(x),
static_cast<T*>(y), num_items,
cub::Max(), -std::numeric_limits<T>::infinity(), s);
CUPY_CUB_NAMESPACE::Max(), -std::numeric_limits<T>::infinity(), s);
}
else
{
Expand All @@ -774,7 +797,7 @@ struct _cub_segmented_reduce_max {
DeviceSegmentedReduce::Reduce(workspace, workspace_size,
static_cast<T*>(x), static_cast<T*>(y), num_segments,
offset_start, offset_start+1,
cub::Max(), -std::numeric_limits<T>::infinity(), s);
CUPY_CUB_NAMESPACE::Max(), -std::numeric_limits<T>::infinity(), s);
}
else
{
Expand Down

0 comments on commit c7667a7

Please sign in to comment.