Skip to content

Commit

Permalink
Fixes #437: Respect CUDA_NO_HALF
Browse files Browse the repository at this point in the history
When `CUDA_NO_HALF` is defined, do not include half-precision-related headers nor define functionality involving half-precision types.
  • Loading branch information
georgelyu authored and eyalroz committed Nov 12, 2022
1 parent f992672 commit 1a52be1
Showing 1 changed file with 5 additions and 0 deletions.
5 changes: 5 additions & 0 deletions src/cuda/api/array.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,10 @@

#include <cuda_runtime.h>
#include <cuda.h>

#ifndef CUDA_NO_HALF
#include <cuda_fp16.h>
#endif

namespace cuda {

Expand Down Expand Up @@ -52,7 +55,9 @@ template <> struct format_specifier<uint32_t> { static constexpr const CUarray_f
template <> struct format_specifier<int8_t > { static constexpr const CUarray_format value = CU_AD_FORMAT_SIGNED_INT8; };
template <> struct format_specifier<int16_t > { static constexpr const CUarray_format value = CU_AD_FORMAT_SIGNED_INT16; };
template <> struct format_specifier<int32_t > { static constexpr const CUarray_format value = CU_AD_FORMAT_SIGNED_INT32; };
#ifndef CUDA_NO_HALF
template <> struct format_specifier<half > { static constexpr const CUarray_format value = CU_AD_FORMAT_HALF; };
#endif
template <> struct format_specifier<float > { static constexpr const CUarray_format value = CU_AD_FORMAT_FLOAT; };

template<typename T>
Expand Down

0 comments on commit 1a52be1

Please sign in to comment.