Skip to content

Commit 0ddd079

Browse files
lcskrishna authored and iotamudelta committed
fixes for topk fp16 (ROCm#270)
1 parent 82b13e6 commit 0ddd079

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

aten/src/THC/THCTensorTopK.cuh

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -117,7 +117,7 @@ struct TopKTypeConfig<at::Half> {
117117
typedef uint32_t RadixType;
118118

119119
static inline __device__ RadixType convert(at::Half v) {
120-
#if CUDA_VERSION >= 8000
120+
#if CUDA_VERSION >= 8000 || defined __HIP_PLATFORM_HCC__
121121
RadixType x = __half_as_ushort(v);
122122
RadixType mask = -((x >> 15)) | 0x8000;
123123
return (x ^ mask);
@@ -128,7 +128,7 @@ struct TopKTypeConfig<at::Half> {
128128
}
129129

130130
static inline __device__ at::Half deconvert(RadixType v) {
131-
#if CUDA_VERSION >= 8000
131+
#if CUDA_VERSION >= 8000 || defined __HIP_PLATFORM_HCC__
132132
RadixType mask = ((v >> 15) - 1) | 0x8000;
133133
return __ushort_as_half(v ^ mask);
134134
#else

0 commit comments

Comments
 (0)