Skip to content

Commit

Permalink
fix CI error
Browse files Browse the repository at this point in the history
  • Loading branch information
leofang committed Feb 22, 2024
1 parent ff9f09b commit a291ab9
Showing 1 changed file with 39 additions and 38 deletions.
77 changes: 39 additions & 38 deletions cupy/cuda/cupy_cub.cu
Original file line number Diff line number Diff line change
Expand Up @@ -56,44 +56,6 @@ struct FpLimits<complex<double>>
template <> struct NumericTraits<complex<float>> : BaseTraits<FLOATING_POINT, true, false, unsigned int, complex<float>> {};
template <> struct NumericTraits<complex<double>> : BaseTraits<FLOATING_POINT, true, false, unsigned long long, complex<double>> {};

// need specializations for initial values
namespace std {

template <>
class numeric_limits<thrust::complex<float>> {
public:
static __host__ __device__ thrust::complex<float> infinity() noexcept {
return thrust::complex<float>(std::numeric_limits<float>::infinity(), std::numeric_limits<float>::infinity());
}

static constexpr bool has_infinity = true;
};

template <>
class numeric_limits<thrust::complex<double>> {
public:
static __host__ __device__ thrust::complex<double> infinity() noexcept {
return thrust::complex<double>(std::numeric_limits<double>::infinity(), std::numeric_limits<double>::infinity());
}

static constexpr bool has_infinity = true;
};

template <>
class numeric_limits<__half> {
public:
static __host__ __device__ constexpr __half infinity() noexcept {
unsigned short inf_half = 0x7C00U;
__half inf_value = *reinterpret_cast<__half*>(&inf_half);
return inf_value;
}

static constexpr bool has_infinity = true;
};

} // namespace std


#else

// hipCUB internally uses std::numeric_limits, so we should provide specializations for the complex numbers.
Expand Down Expand Up @@ -148,6 +110,45 @@ class numeric_limits<__half> {
using namespace hipcub;

#endif // ifndef CUPY_USE_HIP


// need specializations for initial values
namespace std {

template <>
class numeric_limits<thrust::complex<float>> {
public:
static __host__ __device__ thrust::complex<float> infinity() noexcept {
return thrust::complex<float>(std::numeric_limits<float>::infinity(), std::numeric_limits<float>::infinity());
}

static constexpr bool has_infinity = true;
};

template <>
class numeric_limits<thrust::complex<double>> {
public:
static __host__ __device__ thrust::complex<double> infinity() noexcept {
return thrust::complex<double>(std::numeric_limits<double>::infinity(), std::numeric_limits<double>::infinity());
}

static constexpr bool has_infinity = true;
};

template <>
class numeric_limits<__half> {
public:
static __host__ __device__ constexpr __half infinity() noexcept {
unsigned short inf_half = 0x7C00U;
__half inf_value = *reinterpret_cast<__half*>(&inf_half);
return inf_value;
}

static constexpr bool has_infinity = true;
};

} // namespace std

/* ------------------------------------ end of boilerplate ------------------------------------ */


Expand Down

0 comments on commit a291ab9

Please sign in to comment.