Add the bf16 cuda kernels.
LaurentMazare committed Jun 29, 2023
1 parent 018e017 commit ec79fc4
Showing 9 changed files with 67 additions and 1 deletion.
4 changes: 4 additions & 0 deletions candle-kernels/src/affine.cu
@@ -28,6 +28,10 @@ extern "C" __global__ void FN_NAME( \
} \
} \

#if __CUDA_ARCH__ >= 800
AFFINE_OP(__nv_bfloat16, affine_bf16)
#endif

#if __CUDA_ARCH__ >= 530
AFFINE_OP(__half, affine_f16)
#endif
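
Every bf16 kernel in this commit sits behind an #if __CUDA_ARCH__ >= 800 guard, mirroring the existing __CUDA_ARCH__ >= 530 guard for __half: native __nv_bfloat16 arithmetic is only available from compute capability 8.0 (Ampere), so on older targets the symbols are simply not emitted. A build that targets several architectures therefore only gets the bf16 entry points in its sm_80+ output. The standalone kernel below is a minimal sketch of the same pattern for a contiguous affine op (y = x * mul + add); the demo_affine_bf16 name and signature are illustrative assumptions, not the actual expansion of AFFINE_OP.

#include <stddef.h>
#include <cuda_bf16.h>

// Minimal sketch only: one thread per element, contiguous layout.
#if __CUDA_ARCH__ >= 800
extern "C" __global__ void demo_affine_bf16(
    const size_t numel,
    const __nv_bfloat16 *inp,
    __nv_bfloat16 *out,
    const __nv_bfloat16 mul,
    const __nv_bfloat16 add
) {
    const size_t i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < numel) {
        // bf16 operators are available in device code from sm_80 onward,
        // which is exactly what the surrounding guard enforces.
        out[i] = inp[i] * mul + add;
    }
}
#endif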
7 changes: 7 additions & 0 deletions candle-kernels/src/binary.cu
@@ -1,6 +1,13 @@
#include "binary_op_macros.cuh"
#include<stdint.h>

#if __CUDA_ARCH__ >= 800
BINARY_OP(__nv_bfloat16, badd_bf16, x + y)
BINARY_OP(__nv_bfloat16, bdiv_bf16, x / y)
BINARY_OP(__nv_bfloat16, bmul_bf16, x * y)
BINARY_OP(__nv_bfloat16, bsub_bf16, x - y)
#endif

#if __CUDA_ARCH__ >= 530
BINARY_OP(__half, badd_f16, x + y)
BINARY_OP(__half, bdiv_f16, x / y)
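
Each BINARY_OP line stamps out one extern "C" elementwise kernel for a (type, expression) pair, with x and y naming the left and right operands inside the expression argument. The macro below is a reduced sketch of that idea for the contiguous case only; DEMO_BINARY_OP and the *_demo kernel names are made up here, and the crate's real binary_op_macros.cuh is assumed to be more general (strided layouts, etc.).

#include <stddef.h>
#include <cuda_bf16.h>

// Reduced sketch of a BINARY_OP-style macro (contiguous inputs only).
#define DEMO_BINARY_OP(TYPENAME, FN_NAME, FUNC) \
extern "C" __global__ void FN_NAME( \
    const size_t numel, \
    const TYPENAME *lhs, \
    const TYPENAME *rhs, \
    TYPENAME *out \
) { \
    const size_t i = blockIdx.x * blockDim.x + threadIdx.x; \
    if (i < numel) { \
        TYPENAME x = lhs[i]; \
        TYPENAME y = rhs[i]; \
        out[i] = FUNC; \
    } \
}

#if __CUDA_ARCH__ >= 800
DEMO_BINARY_OP(__nv_bfloat16, badd_bf16_demo, x + y)
DEMO_BINARY_OP(__nv_bfloat16, bmul_bf16_demo, x * y)
#endif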
13 changes: 13 additions & 0 deletions candle-kernels/src/cast.cu
@@ -24,6 +24,19 @@ extern "C" __global__ void FN_NAME( \
} \
} \

#if __CUDA_ARCH__ >= 800
CAST_OP(__nv_bfloat16, __nv_bfloat16, cast_bf16_bf16)

CAST_OP(__nv_bfloat16, uint32_t, cast_bf16_u32)
CAST_OP(__nv_bfloat16, __half, cast_bf16_f16)
CAST_OP(__nv_bfloat16, float, cast_bf16_f32)
CAST_OP(__nv_bfloat16, double, cast_bf16_f64)
CAST_OP(uint32_t, __nv_bfloat16, cast_u32_bf16)
CAST_OP(__half, __nv_bfloat16, cast_f16_bf16)
CAST_OP(float, __nv_bfloat16, cast_f32_bf16)
CAST_OP(double, __nv_bfloat16, cast_f64_bf16)
#endif

#if __CUDA_ARCH__ >= 530
CAST_OP(__half, __half, cast_f16_f16)

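
The cast kernels added here cover bf16 to and from bf16, u32, f16, f32 and f64. Whatever the CAST_OP macro expands to, the conversions come down to the cuda_bf16.h / cuda_fp16.h intrinsics; the demo_* device helpers below only illustrate which intrinsics are involved and are an assumption, not the macro's actual output.

#include <stdint.h>
#include <cuda_bf16.h>
#include <cuda_fp16.h>

// Illustrative device helpers only; the real CAST_OP kernels may reach the
// same result through implicit conversions in the macro body.
#if __CUDA_ARCH__ >= 800
__device__ __forceinline__ float demo_bf16_to_f32(__nv_bfloat16 x) { return __bfloat162float(x); }
__device__ __forceinline__ __nv_bfloat16 demo_f32_to_bf16(float x) { return __float2bfloat16(x); }
// bf16 <-> f16 routed through float; f16 has a much narrower exponent range
// than bf16, so large bf16 values overflow to inf in this direction.
__device__ __forceinline__ __half demo_bf16_to_f16(__nv_bfloat16 x) { return __float2half(__bfloat162float(x)); }
__device__ __forceinline__ __nv_bfloat16 demo_f16_to_bf16(__half x) { return __float2bfloat16(__half2float(x)); }
// Value conversion, not a bit reinterpretation: behaves like float -> u32.
__device__ __forceinline__ uint32_t demo_bf16_to_u32(__nv_bfloat16 x) { return (uint32_t)__bfloat162float(x); }
#endif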
1 change: 1 addition & 0 deletions candle-kernels/src/compatibility.cuh
@@ -1,4 +1,5 @@
#include "cuda_fp16.h"
#include "cuda_bf16.h"

// Table showing which features are supported on which compute capability
// https://docs.nvidia.com/cuda/cuda-c-programming-guide/#features-and-technical-specifications
17 changes: 16 additions & 1 deletion candle-kernels/src/cuda_utils.cuh
@@ -1,4 +1,3 @@
#include "cuda_fp16.h"
#include "compatibility.cuh"

// TODO: This is often used to check that the data is contiguous so that
@@ -156,3 +155,19 @@ __device__ __forceinline__ __half expg(__half a) { return hexp(a); }
__device__ __forceinline__ __half absg(__half a) { return __habs(a); }
__device__ __forceinline__ __half copysigng(__half a, __half b) { return __float2half(copysignf(__half2float(a), __half2float(b))); }
#endif

#if __CUDA_ARCH__ >= 800
__device__ __forceinline__ __nv_bfloat16 powg(__nv_bfloat16 a, __nv_bfloat16 b) { return __float2bfloat16(powf(__bfloat162float(a), __bfloat162float(b))); }
__device__ __forceinline__ bool isnang(__nv_bfloat16 a) { return __hisnan(a); }
__device__ __forceinline__ __nv_bfloat16 sqrtg(__nv_bfloat16 a) { return hsqrt(a); }
__device__ __forceinline__ __nv_bfloat16 cosg(__nv_bfloat16 a) { return hcos(a); }
__device__ __forceinline__ __nv_bfloat16 sing(__nv_bfloat16 a) { return hsin(a); }
__device__ __forceinline__ __nv_bfloat16 recipg(__nv_bfloat16 a) { __nv_bfloat16 one = 1.0; return one / a; }
__device__ __forceinline__ __nv_bfloat16 maxg(__nv_bfloat16 a, __nv_bfloat16 b) { return __hmax_nan(a, b); }
__device__ __forceinline__ __nv_bfloat16 tanhg(__nv_bfloat16 a) { return __float2bfloat16(tanhf(__bfloat162float(a))); }
__device__ __forceinline__ __nv_bfloat16 ming(__nv_bfloat16 a, __nv_bfloat16 b) { return __hmin_nan(a, b); }
__device__ __forceinline__ __nv_bfloat16 logg(__nv_bfloat16 a) { return hlog(a); }
__device__ __forceinline__ __nv_bfloat16 expg(__nv_bfloat16 a) { return hexp(a); }
__device__ __forceinline__ __nv_bfloat16 absg(__nv_bfloat16 a) { return __habs(a); }
__device__ __forceinline__ __nv_bfloat16 copysigng(__nv_bfloat16 a, __nv_bfloat16 b) { return __float2bfloat16(copysignf(__bfloat162float(a), __bfloat162float(b))); }
#endif
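
These overloads extend the header's generic math helpers (powg, sqrtg, expg, ...) to __nv_bfloat16, so kernels generated by the *.cu macros can call, say, expg(x) and let C++ overload resolution pick the f16, bf16 or f32 implementation per instantiation. The self-contained sketch below shows that pattern; demo_expg, DEMO_UEXP and the demo_uexp_* kernels are illustrative names, not the crate's.

#include <stddef.h>
#include <cuda_bf16.h>

// The same macro body works for every element type because demo_expg is an
// overload set; the overload to call is resolved at compile time.
__device__ __forceinline__ float demo_expg(float a) { return expf(a); }
#if __CUDA_ARCH__ >= 800
__device__ __forceinline__ __nv_bfloat16 demo_expg(__nv_bfloat16 a) { return hexp(a); }
#endif

#define DEMO_UEXP(TYPENAME, FN_NAME) \
extern "C" __global__ void FN_NAME(const size_t n, const TYPENAME *inp, TYPENAME *out) { \
    const size_t i = blockIdx.x * blockDim.x + threadIdx.x; \
    if (i < n) out[i] = demo_expg(inp[i]); \
}

DEMO_UEXP(float, demo_uexp_f32)
#if __CUDA_ARCH__ >= 800
DEMO_UEXP(__nv_bfloat16, demo_uexp_bf16)
#endif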
4 changes: 4 additions & 0 deletions candle-kernels/src/embeddings.cu
@@ -29,6 +29,10 @@ extern "C" __global__ void FN_NAME( \
} \
} \

#if __CUDA_ARCH__ >= 800
EMB_OP(__nv_bfloat16, emb_bf16)
#endif

#if __CUDA_ARCH__ >= 530
EMB_OP(__half, emb_f16)
#endif
4 changes: 4 additions & 0 deletions candle-kernels/src/reduce.cu
@@ -43,6 +43,10 @@ extern "C" __global__ void FN_NAME( \
} \
} \

#if __CUDA_ARCH__ >= 800
SUM_OP(__nv_bfloat16, sum_bf16)
#endif

#if __CUDA_ARCH__ >= 530
SUM_OP(__half, sum_f16)
#endif
4 changes: 4 additions & 0 deletions candle-kernels/src/ternary.cu
@@ -32,6 +32,10 @@ extern "C" __global__ void FN_NAME( \
} \
} \

#if __CUDA_ARCH__ >= 800
WHERE_OP(__nv_bfloat16, where_bf16)
#endif

#if __CUDA_ARCH__ >= 530
WHERE_OP(__half, where_f16)
#endif
14 changes: 14 additions & 0 deletions candle-kernels/src/unary.cu
@@ -40,6 +40,20 @@ __device__ __forceinline__ T relu_fwd(T x) {
}


#if __CUDA_ARCH__ >= 800
UNARY_OP(__nv_bfloat16, ucopy_bf16, x)
UNARY_OP(__nv_bfloat16, uneg_bf16, -x)
UNARY_OP(__nv_bfloat16, uexp_bf16, expg(x))
UNARY_OP(__nv_bfloat16, ulog_bf16, logg(x))
UNARY_OP(__nv_bfloat16, usin_bf16, sing(x))
UNARY_OP(__nv_bfloat16, ucos_bf16, cosg(x))
UNARY_OP(__nv_bfloat16, uabs_bf16, absg(x))
UNARY_OP(__nv_bfloat16, usqr_bf16, x*x)
UNARY_OP(__nv_bfloat16, usqrt_bf16, sqrtg(x))
UNARY_OP(__nv_bfloat16, ugelu_bf16, gelu_fwd(x))
UNARY_OP(__nv_bfloat16, urelu_bf16, relu_fwd(x))
#endif

#if __CUDA_ARCH__ >= 530
UNARY_OP(__half, ucopy_f16, x)
UNARY_OP(__half, uneg_f16, -x)
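
candle launches these kernels from its Rust side rather than from C++, but a small host harness makes the runtime requirements concrete: an sm_80+ GPU, a matching build flag (e.g. nvcc -arch=sm_80 so the bf16 operators are available in device code), and host-side bf16 conversions for test data. Everything below, including the demo_usqr_bf16 kernel, is an illustrative assumption and not part of the commit.

#include <cstdio>
#include <vector>
#include <cuda_bf16.h>
#include <cuda_runtime.h>

// Same element-wise shape as UNARY_OP(__nv_bfloat16, usqr_bf16, x*x), but a
// hand-written stand-in so the example stays self-contained.
__global__ void demo_usqr_bf16(size_t n, const __nv_bfloat16 *inp, __nv_bfloat16 *out) {
    const size_t i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) out[i] = inp[i] * inp[i];
}

int main() {
    const size_t n = 1024;
    std::vector<__nv_bfloat16> h(n);
    for (size_t i = 0; i < n; ++i) h[i] = __float2bfloat16(0.5f * float(i));

    __nv_bfloat16 *d_in = nullptr, *d_out = nullptr;
    cudaMalloc(&d_in, n * sizeof(__nv_bfloat16));
    cudaMalloc(&d_out, n * sizeof(__nv_bfloat16));
    cudaMemcpy(d_in, h.data(), n * sizeof(__nv_bfloat16), cudaMemcpyHostToDevice);

    const int block = 256;
    const int grid = int((n + block - 1) / block);
    demo_usqr_bf16<<<grid, block>>>(n, d_in, d_out);

    // The device-to-host copy synchronizes with the kernel launch above.
    cudaMemcpy(h.data(), d_out, n * sizeof(__nv_bfloat16), cudaMemcpyDeviceToHost);
    std::printf("out[3] = %f\n", __bfloat162float(h[3])); // 1.5 * 1.5 = 2.25

    cudaFree(d_in);
    cudaFree(d_out);
    return 0;
}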
