Example for the dtype change for gelu kernels #250

Closed
wants to merge 1 commit into from
train_gpt2.cu: 62 changes (44 additions & 18 deletions)
@@ -51,6 +51,21 @@ mpirun -np 4 ./train_gpt2cu -b 8 -v 200 -s 200 -i data/TinyStories
 // ----------------------------------------------------------------------------
 // CUDA precision settings
 
+// 128-bit memory read packing data structure
+template<class ElementType>
+struct alignas(16) Packed128 {
+    __device__ ElementType& operator[](int index) {
+        return reinterpret_cast<ElementType*>(&payload)[index];
+    }
+
+    __device__ const ElementType& operator[](int index) const {
+        return reinterpret_cast<const ElementType*>(&payload)[index];
+    }
+
+    int4 payload;
+    static constexpr const size_t size = sizeof(int4) / sizeof(ElementType);
+};
+
 // turn on bf16 as default, done up here for now
 #define ENABLE_BF16

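Packed128 wraps a single int4 (16 bytes) so that reading or writing one packet is intended to compile to a single 128-bit memory transaction, and size is how many floatX elements fit in one such packet. A minimal sketch, not part of the diff, of what that works out to for the usual element types (assuming the struct above is in scope and the CUDA fp16/bf16 headers are included):

// Sketch only: elements per 16-byte packet for common element types.
#include <cuda_fp16.h>
#include <cuda_bf16.h>

static_assert(Packed128<float>::size == 4, "4 x 4-byte float per int4 packet");
static_assert(Packed128<half>::size == 8, "8 x 2-byte half per int4 packet");
static_assert(Packed128<__nv_bfloat16>::size == 8, "8 x 2-byte bf16 per int4 packet");
static_assert(sizeof(Packed128<__nv_bfloat16>) == 16, "one packet is exactly one aligned int4");

With ENABLE_BF16 defined just below, floatX in this file resolves to a 2-byte bf16 type, so each packet carries 8 elements and each thread of the packed kernels below handles 8 values per 128-bit load.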
@@ -685,26 +700,37 @@ __global__ void residual_forward_kernel(TOut* out, T1* inp1, T2* inp2, int N) {
 }
 
 #define GELU_SCALING_FACTOR sqrtf(2.0f / M_PI)
-__global__ void gelu_forward_kernel(floatX* out, const floatX* inp, int N) {
+__global__ void gelu_forward_kernel(Packed128<floatX>* out, const Packed128<floatX>* inp, int N) {
     int i = blockIdx.x * blockDim.x + threadIdx.x;
-    if (i < N) {
-        float xi = (float)inp[i];
-        float cube = 0.044715f * xi * xi * xi;
-        out[i] = (floatX)(0.5f * xi * (1.0f + tanhf(GELU_SCALING_FACTOR * (xi + cube))));
+    if(i < N){
+        Packed128<floatX> data = inp[i];
+        Packed128<floatX> local_out;
+        for(int j = 0; j < Packed128<floatX>::size; j++){
+            float xi = (float)data[j];
+            float cube = 0.044715f * xi * xi * xi;
+            local_out[j] = (floatX)(0.5f * xi * (1.0f + tanhf(GELU_SCALING_FACTOR * (xi + cube))));
+        }
+        out[i] = local_out;
     }
 }

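In the packed forward kernel, i indexes whole packets, so each thread computes Packed128<floatX>::size GELU outputs from one 128-bit load instead of one output per thread. For checking the results, a plain scalar reference of the same tanh-approximation GELU might look like the sketch below (hypothetical helper, not part of this PR; assumes <math.h> and that M_PI is defined):

// Hypothetical CPU reference for validating the packed forward kernel.
void gelu_forward_cpu_reference(float* out, const float* inp, int N) {
    float s = sqrtf(2.0f / M_PI);  // same value as GELU_SCALING_FACTOR
    for (int i = 0; i < N; i++) {
        float xi = inp[i];
        float cube = 0.044715f * xi * xi * xi;
        out[i] = 0.5f * xi * (1.0f + tanhf(s * (xi + cube)));
    }
}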
-__global__ void gelu_backward_kernel(floatX* dinp, const floatX* inp, const floatX* dout, const int N) {
+__global__ void gelu_backward_kernel(Packed128<floatX>* dinp, const Packed128<floatX>* inp, const Packed128<floatX>* dout, const int N) {
     int i = blockIdx.x * blockDim.x + threadIdx.x;
     if (i < N) {
-        float x = (float)inp[i];
-        float cube = 0.044715f * x * x * x;
-        float tanh_arg = GELU_SCALING_FACTOR * (x + cube);
-        float tanh_out = tanhf(tanh_arg);
-        float coshf_out = coshf(tanh_arg);
-        float sech_out = 1.0f / (coshf_out * coshf_out);
-        float local_grad = 0.5f * (1.0f + tanh_out) + x * 0.5f * sech_out * GELU_SCALING_FACTOR * (1.0f + 3.0f * 0.044715f * x * x);
-        dinp[i] = (floatX)(local_grad * (float)dout[i]);
+        Packed128<floatX> inp_packed = inp[i];
+        Packed128<floatX> dout_packed = dout[i];
+        Packed128<floatX> local_out;
+        for(int j = 0; j < Packed128<floatX>::size; j++){
+            float x = (float)inp_packed[j];
+            float cube = 0.044715f * x * x * x;
+            float tanh_arg = GELU_SCALING_FACTOR * (x + cube);
+            float tanh_out = tanhf(tanh_arg);
+            float coshf_out = coshf(tanh_arg);
+            float sech_out = 1.0f / (coshf_out * coshf_out);
+            float local_grad = 0.5f * (1.0f + tanh_out) + x * 0.5f * sech_out * GELU_SCALING_FACTOR * (1.0f + 3.0f * 0.044715f * x * x);
+            local_out[j] = (floatX)(local_grad * (float)dout_packed[j]);
+        }
+        dinp[i] = local_out;
     }
 }

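The backward kernel applies the same packing to the gradient: local_grad is the derivative of 0.5*x*(1 + tanh(u)) with u = GELU_SCALING_FACTOR*(x + 0.044715*x^3), i.e. 0.5*(1 + tanh(u)) + 0.5*x*sech^2(u)*du/dx. A scalar reference for comparison (again a hypothetical helper, not part of this PR, using 1 - tanh^2(u) in place of the kernel's 1/cosh^2(u)):

// Hypothetical CPU reference for validating the packed backward kernel.
void gelu_backward_cpu_reference(float* dinp, const float* inp, const float* dout, int N) {
    float s = sqrtf(2.0f / M_PI);
    for (int i = 0; i < N; i++) {
        float x = inp[i];
        float u = s * (x + 0.044715f * x * x * x);
        float tanh_u = tanhf(u);
        float sech2_u = 1.0f - tanh_u * tanh_u;  // sech^2(u) == 1/cosh^2(u)
        float dudx = s * (1.0f + 3.0f * 0.044715f * x * x);
        float local_grad = 0.5f * (1.0f + tanh_u) + 0.5f * x * sech2_u * dudx;
        dinp[i] = local_grad * dout[i];
    }
}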
@@ -1329,15 +1355,15 @@ void residual_forward(TOut* out, T1* inp1, T2* inp2, int N) {

 void gelu_forward(floatX* out, const floatX* inp, int N) {
     const int block_size = 128;
-    const int grid_size = CEIL_DIV(N, block_size);
-    gelu_forward_kernel<<<grid_size, block_size>>>(out, inp, N);
+    const int grid_size = CEIL_DIV(N/Packed128<floatX>::size, block_size);
+    gelu_forward_kernel<<<grid_size, block_size>>>((Packed128<floatX>*)out, (Packed128<floatX>*)inp, N);
     cudaCheck(cudaGetLastError());
 }
 
 void gelu_backward(floatX* dinp, const floatX* inp, const floatX* dout, const int N) {
     const int block_size = 128;
-    const int grid_size = CEIL_DIV(N, block_size);
-    gelu_backward_kernel<<<grid_size, block_size>>>(dinp, inp, dout, N);
+    const int grid_size = CEIL_DIV(N/Packed128<floatX>::size, block_size);
+    gelu_backward_kernel<<<grid_size, block_size>>>((Packed128<floatX>*)dinp, (Packed128<floatX>*)inp, (Packed128<floatX>*)dout, N);
     cudaCheck(cudaGetLastError());
 }

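The launchers now size the grid for N/Packed128<floatX>::size packets and reinterpret the floatX buffers as packet arrays. Two preconditions are implicit and not asserted anywhere in this diff: N must be a multiple of Packed128<floatX>::size, and the buffers must be 16-byte aligned for the int4 accesses. Note also that the in-kernel guard compares the packet index i against the element count N, so the rounded-up grid only stays in bounds when the packet count N/size is itself a multiple of block_size. A sketch of a checked wrapper under those assumptions (hypothetical, not part of the PR; assumes <assert.h> and <stdint.h>):

// Sketch of a checked wrapper around the packed launcher (illustrative only).
void gelu_forward_checked(floatX* out, const floatX* inp, int N) {
    const int pack = (int)Packed128<floatX>::size;
    assert(N % pack == 0);                                         // whole packets only
    assert((N / pack) % 128 == 0);                                 // packet count fills whole blocks of 128 threads
    assert((uintptr_t)out % 16 == 0 && (uintptr_t)inp % 16 == 0);  // aligned 16-byte accesses
    gelu_forward(out, inp, N);                                     // dispatch to the launcher above
}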