add cast and op_functor for cuda built-in types
Courtesy-Xs committed Apr 2, 2024
1 parent a2878e3 commit f033ada
Showing 5 changed files with 148 additions and 25 deletions.
71 changes: 71 additions & 0 deletions extensions/csrc/cuda/funcs/cast_functor.h
@@ -0,0 +1,71 @@
#pragma once

#include <cuda.h>
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>

#include <functional>

#include "../utils/micros.h"

// Note(LiuYang): This file provides cast functors between data types,
// including POD and CUDA built-in types such as half and __nv_bfloat16

namespace colossalAI {
namespace cuda {
namespace funcs {

// Get type2 from type or vice versa (applied to half and bfloat16)
template <typename T>
struct TypeConverter {
using Type = half2;
}; // keep for generality

template <>
struct TypeConverter<half2> {
using Type = at::Half;
};

template <>
struct TypeConverter<at::Half> {
using Type = half2;
};

template <>
struct TypeConverter<__nv_bfloat162> {
using Type = at::BFloat16;
};

template <>
struct TypeConverter<at::BFloat16> {
using Type = __nv_bfloat162;
};

template <typename From, typename To>
struct CastFunctor : public std::unary_function<From, To> {
HOSTDEVICE To operator()(From val) { return static_cast<To>(val); }
};

#define COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(FROM, TO, STMT, \
FUNCTION_MODIFIER) \
template <> \
struct CastFunctor<FROM, TO> : public std::unary_function<FROM, TO> { \
FUNCTION_MODIFIER TO operator()(FROM val) { return STMT; } \
};

COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(int2, float2, make_float2(val.x, val.y),
DEVICE)
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float, float2, make_float2(val, val),
DEVICE)
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(half2, float2, __half22float2(val), DEVICE)
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float2, half2, __float22half2_rn(val),
DEVICE)
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(float, half2, __float2half2_rn(val),
DEVICE)
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(half, half2, __half2half2(val), DEVICE)
COLOSSAL_CAST_FUNCTOR_SPECIALIZATION(half, float, __half2float(val), DEVICE)

#undef COLOSSAL_CAST_FUNCTOR_SPECIALIZATION
} // namespace funcs
} // namespace cuda
} // namespace colossalAI
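
For orientation, a minimal usage sketch of the new cast functors (a hypothetical kernel, not part of this commit; the include path and kernel name are assumptions):

// Hypothetical kernel: broadcast a float scale factor into both half2 lanes
// via CastFunctor<float, half2> (which lowers to __float2half2_rn), then scale.
#include <cuda_fp16.h>
#include "funcs/cast_functor.h"

using colossalAI::cuda::funcs::CastFunctor;

__global__ void scale_half2(half2* data, float factor, int n) {
  CastFunctor<float, half2> to_half2;
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    data[i] = __hmul2(data[i], to_half2(factor));
  }
}
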
60 changes: 47 additions & 13 deletions extensions/csrc/cuda/funcs/op_functor.h
@@ -1,31 +1,65 @@
#pragma once

#include <cuda.h>
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>

#include <functional>

#include "../utils/micros.h"

namespace colossalAI {
namespace cuda {
namespace funcs {

enum class BinaryOpType { kAdd = 0, kMinus, kMul, KDiv, kMax, KMin };
enum class BinaryOpType { kAdd = 0, kMinus, kMul, kDiv, kMax, kMin };

template <typename T, BinaryOpType Op>
// Note(LiuYang): This file provides base math operations for data types,
// including POD and CUDA built-in types such as half and __nv_bfloat16
template <typename LT, typename RT, typename RET, BinaryOpType Op>
struct BinaryOpFunctor;

template <typename T>
struct BinaryOpFunctor<T, BinaryOpType::kAdd>
: public std::binary_function<T, T, T> {
__host__ __device__ T operator()(T lhs, T rhs) { return lhs + rhs; }
};

template <typename T>
struct BinaryOpFunctor<T, BinaryOpType::kMax>
: public std::binary_function<T, T, T> {
__host__ __device__ T operator()(T lhs, T rhs) { return max(lhs, rhs); }
};
#define COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(T, BINARY_OP_TYPE, STMT, \
FUNCTION_MODIFIER, ARGS...) \
template <ARGS> \
struct BinaryOpFunctor<T, T, T, BINARY_OP_TYPE> \
: public std::binary_function<T, T, T> { \
FUNCTION_MODIFIER T operator()(T lhs, T rhs) { return STMT; } \
};

COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(T, BinaryOpType::kAdd, lhs + rhs,
HOSTDEVICE, typename T)
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(T, BinaryOpType::kMinus, lhs - rhs,
HOSTDEVICE, typename T)
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(T, BinaryOpType::kMul, lhs * rhs,
HOSTDEVICE, typename T)
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(T, BinaryOpType::kDiv, lhs / rhs,
HOSTDEVICE, typename T)
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(T, BinaryOpType::kMax, max(lhs, rhs),
HOSTDEVICE, typename T)
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(T, BinaryOpType::kMin, min(lhs, rhs),
HOSTDEVICE, typename T)

COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(half, BinaryOpType::kAdd,
__hadd(lhs, rhs), DEVICE)
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(half2, BinaryOpType::kAdd,
__hadd2(lhs, rhs), DEVICE)
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(__nv_bfloat16, BinaryOpType::kAdd,
__hadd(lhs, rhs), DEVICE)
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(__nv_bfloat162, BinaryOpType::kAdd,
__hadd2(lhs, rhs), DEVICE)

COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(half, BinaryOpType::kMul,
__hmul(lhs, rhs), DEVICE)
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(half2, BinaryOpType::kMul,
__hmul2(lhs, rhs), DEVICE)
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(__nv_bfloat16, BinaryOpType::kMul,
__hmul(lhs, rhs), DEVICE)
COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION(__nv_bfloat162, BinaryOpType::kMul,
__hmul2(lhs, rhs), DEVICE)

#undef COLOSSAL_BINARY_FUNCTOR_SPECIALIZATION

} // namespace funcs
} // namespace cuda
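
Likewise, a small sketch of how the reworked BinaryOpFunctor is meant to be used (hypothetical device helpers, not part of this commit):

// Hypothetical device helpers: the generic kAdd specialization resolves to
// lhs + rhs for float, while the half2 kMul specialization lowers to __hmul2.
#include <cuda_fp16.h>
#include "funcs/op_functor.h"

using colossalAI::cuda::funcs::BinaryOpFunctor;
using colossalAI::cuda::funcs::BinaryOpType;

__device__ float accumulate(float a, float b, float c) {
  BinaryOpFunctor<float, float, float, BinaryOpType::kAdd> add;
  return add(add(a, b), c);
}

__device__ half2 apply_weight(half2 x, half2 w) {
  BinaryOpFunctor<half2, half2, half2, BinaryOpType::kMul> mul;
  return mul(x, w);
}
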
4 changes: 2 additions & 2 deletions extensions/csrc/cuda/include/block_reduce.h
@@ -22,12 +22,12 @@ struct GetOpForReduceType;

template <typename T>
struct GetOpForReduceType<T, ReduceType::kMax> {
using Op = funcs::BinaryOpFunctor<T, funcs::BinaryOpType::kMax>;
using Op = funcs::BinaryOpFunctor<T, T, T, funcs::BinaryOpType::kMax>;
};

template <typename T>
struct GetOpForReduceType<T, ReduceType::kSum> {
using Op = funcs::BinaryOpFunctor<T, funcs::BinaryOpType::kAdd>;
using Op = funcs::BinaryOpFunctor<T, T, T, funcs::BinaryOpType::kAdd>;
};

#define COLOSSAL_SHFL_FUNCTION(MASK, VAL_PTR, DELTA, WIDTH, OP, LANES) \
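
The only change in block_reduce.h is the extra template arguments on BinaryOpFunctor; a sketch of what the kMax alias now expands to (namespace qualification and helper name are assumptions):

// Sketch only: GetOpForReduceType<T, ReduceType::kMax>::Op is now the
// three-type form of the functor, applied exactly as before.
#include "funcs/op_functor.h"

using MaxOpF32 = colossalAI::cuda::funcs::BinaryOpFunctor<
    float, float, float, colossalAI::cuda::funcs::BinaryOpType::kMax>;

__device__ float lane_max(float a, float b) {
  MaxOpF32 op;
  return op(a, b);  // resolves to max(a, b) via the generic HOSTDEVICE case
}
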
34 changes: 24 additions & 10 deletions extensions/csrc/cuda/rms_layernorm_kernel.cu
@@ -10,10 +10,15 @@

#include "block_reduce.h"
#include "../common/micros.h"
#include "utils/cuda_type_utils.h"
#include "funcs/cast_functor.h"
#include "funcs/op_functor.h"

using colossalAI::cuda::utils::block_reduce;
using colossalAI::cuda::utils::ReduceType;
using colossalAI::cuda::funcs::TypeConverter;
using colossalAI::cuda::funcs::CastFunctor;
using colossalAI::cuda::funcs::BinaryOpFunctor;
using colossalAI::cuda::funcs::BinaryOpType;

#define DISPATCH_RMSNORM_FLOAT_HALF_AND_BFLOAT(DATA_SIZE, TYPE, NAME, ...) \
if (DATA_SIZE == 2) { \
@@ -53,6 +58,9 @@ __global__ void rms_layernorm_kernel(
const int num_tokens,
const int hidden_size) {
using scalar2_t = typename TypeConverter<scalar_t>::Type;
BinaryOpFunctor<scalar2_t, scalar2_t, scalar2_t, BinaryOpType::kMul> mul_scalar2t;
CastFunctor<scalar2_t, float> cast_scalar2t_2_float;
CastFunctor<float, scalar2_t> cast_float_2_scalar2t;
__shared__ float s_variance;

/*
@@ -72,12 +80,13 @@ __global__ void rms_layernorm_kernel(
float variance = 0.0f;
int row_offset = blockIdx.x * hidden_size / 2;


#pragma unroll unroll_factor
for (int idx = threadIdx.x, cnt = 0; idx < hidden_size / 2; idx += blockDim.x, cnt++) {
int id = row_offset + idx;
x_local[cnt] = input_ptr[id];
float v1 = cuda_cast<float>(x_local[cnt].x);
float v2 = cuda_cast<float>(x_local[cnt].y);
float v1 = cast_scalar2t_2_float(x_local[cnt].x);
float v2 = cast_scalar2t_2_float(x_local[cnt].y);
variance += v1 * v1 + v2 * v2;
}
block_reduce<float, ReduceType::kSum,1>(&variance);
@@ -86,11 +95,11 @@ __global__ void rms_layernorm_kernel(
}
__syncthreads();

scalar2_t s_variance_2 = cuda_cast<scalar2_t>(s_variance);
scalar2_t s_variance_2 = cast_float_2_scalar2t(s_variance);
#pragma unroll unroll_factor
for (int idx = threadIdx.x, cnt = 0; idx < hidden_size / 2; idx += blockDim.x, cnt++) {
int id = row_offset + idx;
out_ptr[id] = mul(x_local[cnt], s_variance_2, weight_ptr[idx]);
out_ptr[id] = mul_scalar2t(mul_scalar2t(x_local[cnt], s_variance_2), weight_ptr[idx]);
}
}

@@ -137,6 +146,11 @@ __global__ void fused_add_rms_layernorm_kernel(
const int num_tokens,
const int hidden_size) {
using scalar2_t = typename TypeConverter<scalar_t>::Type;
BinaryOpFunctor<scalar2_t, scalar2_t, scalar2_t, BinaryOpType::kAdd> add_scalar2t;
CastFunctor<scalar2_t, float> cast_scalar2t_2_float;
CastFunctor<float, scalar2_t> cast_float_2_scalar2t;
BinaryOpFunctor<scalar2_t, scalar2_t, scalar2_t, BinaryOpType::kMul> mul_scalar2t;

__shared__ float s_variance;
scalar2_t x_local[4];

@@ -151,9 +165,9 @@ __global__ void fused_add_rms_layernorm_kernel(
for (int idx = threadIdx.x, cnt = 0; idx < hidden_size / 2; idx += blockDim.x, cnt++) {
int id = row_offset + idx;
x_local[cnt] = input_ptr[id];
x_local[cnt] = add(x_local[cnt], residual_ptr[id]);
float v1 = cuda_cast<float>(x_local[cnt].x);
float v2 = cuda_cast<float>(x_local[cnt].y);
x_local[cnt] = add_scalar2t(x_local[cnt], residual_ptr[id]);
float v1 = cast_scalar2t_2_float(x_local[cnt].x);
float v2 = cast_scalar2t_2_float(x_local[cnt].y);
variance += v1 * v1 + v2 * v2;
residual_ptr[id] = x_local[cnt];
}
@@ -163,11 +177,11 @@ __global__ void fused_add_rms_layernorm_kernel(
}
__syncthreads();

scalar2_t s_variance_2 = cuda_cast<scalar2_t>(s_variance);
scalar2_t s_variance_2 = cast_float_2_scalar2t(s_variance);
#pragma unroll unroll_factor
for (int idx = threadIdx.x, cnt = 0; idx < hidden_size / 2; idx += blockDim.x, cnt++) {
int id = row_offset + idx;
input_ptr[id] = mul(x_local[cnt], s_variance_2, weight_ptr[idx]);
input_ptr[id] = mul_scalar2t(mul_scalar2t(x_local[cnt], s_variance_2), weight_ptr[idx]);
}
}

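
For reference, a condensed sketch (a simplification, not the actual kernel) of the per-pair normalization step both kernels now share, with the functor objects standing in for the removed cuda_cast/mul helpers:

// Simplified sketch, assuming scalar2_t is half2 or __nv_bfloat162: broadcast
// rsqrt of the row variance into both lanes, then apply (x * 1/rms) * w
// through the kMul functor instead of the old three-argument mul().
#include "funcs/cast_functor.h"
#include "funcs/op_functor.h"

using colossalAI::cuda::funcs::CastFunctor;
using colossalAI::cuda::funcs::BinaryOpFunctor;
using colossalAI::cuda::funcs::BinaryOpType;

template <typename scalar2_t>
__device__ void normalize_pair(scalar2_t* out, const scalar2_t* in,
                               const scalar2_t* weight, float inv_rms, int id) {
  CastFunctor<float, scalar2_t> cast_float_2_scalar2t;
  BinaryOpFunctor<scalar2_t, scalar2_t, scalar2_t, BinaryOpType::kMul> mul_scalar2t;
  scalar2_t s = cast_float_2_scalar2t(inv_rms);
  out[id] = mul_scalar2t(mul_scalar2t(in[id], s), weight[id]);
}
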
4 changes: 4 additions & 0 deletions extensions/csrc/cuda/utils/micros.h
@@ -12,3 +12,7 @@
throw std::runtime_error(cudaGetErrorString(status)); \
} \
}

#define HOST __host__
#define DEVICE __device__
#define HOSTDEVICE __host__ __device__
