We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 82b13e6 commit 0ddd079Copy full SHA for 0ddd079
aten/src/THC/THCTensorTopK.cuh
@@ -117,7 +117,7 @@ struct TopKTypeConfig<at::Half> {
117
typedef uint32_t RadixType;
118
119
static inline __device__ RadixType convert(at::Half v) {
120
-#if CUDA_VERSION >= 8000
+#if CUDA_VERSION >= 8000 || defined __HIP_PLATFORM_HCC__
121
RadixType x = __half_as_ushort(v);
122
RadixType mask = -((x >> 15)) | 0x8000;
123
return (x ^ mask);
@@ -128,7 +128,7 @@ struct TopKTypeConfig<at::Half> {
128
}
129
130
static inline __device__ at::Half deconvert(RadixType v) {
131
132
RadixType mask = ((v >> 15) - 1) | 0x8000;
133
return __ushort_as_half(v ^ mask);
134
#else
0 commit comments