Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update residual_forward to use packed input #299

Merged
merged 7 commits into from
May 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
86 changes: 64 additions & 22 deletions dev/cuda/residual_forward.cu
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,30 @@ nvcc -O3 --use_fast_math residual_forward.cu -o residual_forward

version 1 is naive port from CPU code to kernel
./residual_forward 1
version 2 packs input into 128 bit memory reads
./residual_forward 2
*/

#include <stdio.h>
#include <stdlib.h>
#include <cuda_runtime.h>
#include "common.h"

// turn on bf16 as default, done up here for now
#define ENABLE_BF16

#if defined(ENABLE_BF16)
typedef __nv_bfloat16 floatX;
typedef __nv_bfloat16 floatN;
#elif defined(ENABLE_FP16)
typedef half floatX;
typedef half floatN;
#else
typedef float floatX;
typedef float floatN;
#endif

typedef Packed128<floatX> x128;
// ----------------------------------------------------------------------------
// CPU code reference lol

Expand All @@ -26,33 +43,56 @@ void residual_forward_cpu(float* out, const float* inp1, const float* inp2, int
// GPU kernels

// elementwise ops are nice and ez
__global__ void residual_forward_kernel(float* out, const float* inp1, const float* inp2, int N) {
__global__ void residual_forward_kernel1(floatX* out, const floatX* inp1, const floatX* inp2, int N) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < N) {
out[idx] = inp1[idx] + inp2[idx];
out[idx] = (floatX)((float)inp1[idx] + (float)inp2[idx]);
}
}

__global__ void residual_forward_kernel2(floatX* out, const floatX* inp1, const floatX* inp2, int N) {
int idx = (blockIdx.x * blockDim.x + threadIdx.x) * x128::size;
if (idx < N) {
x128 packed_out;
x128 packed_inp1 = load128cs(inp1 + idx);
x128 packed_inp2 = load128cs(inp2 + idx);
for (int k = 0; k < packed_inp1.size; ++k)
{
packed_out[k] = (floatX)((float)packed_inp1[k] + (float)packed_inp2[k]);
}
store128(out + idx, packed_out);
}
}

// ----------------------------------------------------------------------------
// kernel launcher

void residual_forward1(float* out, const float* inp1, const float* inp2, int N, const int block_size) {
void residual_forward1(floatX* out, const floatX* inp1, const floatX* inp2, int N, const int block_size) {
const int grid_size = ceil_div(N, block_size);
residual_forward_kernel<<<grid_size, block_size>>>(out, inp1, inp2, N);
residual_forward_kernel1<<<grid_size, block_size>>>(out, inp1, inp2, N);
cudaCheck(cudaGetLastError());
}

void residual_forward2(floatX* out, const floatX* inp1, const floatX* inp2, int N, const int block_size) {
const int grid_size = ceil_div(N, (int)(block_size * x128::size));
residual_forward_kernel2<<<grid_size, block_size>>>(out, inp1, inp2, N);
cudaCheck(cudaGetLastError());
}

// kernel version dispatch
void residual_forward(int kernel_num,
float* out,
const float* inp1,
const float* inp2,
floatX* out,
const floatX* inp1,
const floatX* inp2,
int N,
int block_size) {
switch (kernel_num) {
case 1:
residual_forward1(out, inp1, inp2, N, block_size);
break;
case 2:
residual_forward2(out, inp1, inp2, N, block_size);
break;
default:
printf("Invalid kernel number\n");
exit(1);
Expand All @@ -62,29 +102,26 @@ void residual_forward(int kernel_num,
// ----------------------------------------------------------------------------

int main(int argc, char **argv) {
srand(0);
setup_main();

int B = 8;
int T = 1024;
int C = 768;

int deviceIdx = 0;
cudaCheck(cudaSetDevice(deviceIdx));

// create host memory of random numbers
float* out = (float*)malloc(B * T * C * sizeof(float));
float* inp1 = make_random_float(B * T * C);
float* inp2 = make_random_float(B * T * C);

// move to GPU
float* d_out;
float* d_inp1;
float* d_inp2;
cudaCheck(cudaMalloc(&d_out, B * T * C * sizeof(float)));
cudaCheck(cudaMalloc(&d_inp1, B * T * C * sizeof(float)));
cudaCheck(cudaMalloc(&d_inp2, B * T * C * sizeof(float)));
cudaCheck(cudaMemcpy(d_inp1, inp1, B * T * C * sizeof(float), cudaMemcpyHostToDevice));
cudaCheck(cudaMemcpy(d_inp2, inp2, B * T * C * sizeof(float), cudaMemcpyHostToDevice));
floatX* d_out;
floatX* d_inp1;
floatX* d_inp2;
cudaCheck(cudaMalloc(&d_out, B * T * C * sizeof(floatX)));
cudaCheck(cudaMalloc(&d_inp1, B * T * C * sizeof(floatX)));
cudaCheck(cudaMalloc(&d_inp2, B * T * C * sizeof(floatX)));
cudaCheck(memcpy_convert(d_inp1, inp1, B * T * C));
cudaCheck(memcpy_convert(d_inp2, inp2, B * T * C));

// read kernel_num from command line
int kernel_num = 1;
Expand All @@ -104,7 +141,12 @@ int main(int argc, char **argv) {
int block_size = block_sizes[j];
printf("Checking block size %d.\n", block_size);
residual_forward(kernel_num, d_out, d_inp1, d_inp2, B * T * C, block_size);
validate_result(d_out, out, "out", B * T * C, 1e-5f);
#if !defined(ENABLE_BF16) && !defined(ENABLE_FP16)
float tol = 1e-5;
#else
float tol = 1e-2f;
#endif
validate_result(d_out, out, "out", B * T * C, tol);
}

printf("All results match. Starting benchmarks.\n\n");
Expand Down Expand Up @@ -135,4 +177,4 @@ int main(int argc, char **argv) {
cudaCheck(cudaFree(d_inp2));

return 0;
}
}
14 changes: 11 additions & 3 deletions train_gpt2.cu
Original file line number Diff line number Diff line change
Expand Up @@ -935,9 +935,17 @@ __global__ void softmax_forward_kernel5(floatX* out, float inv_temperature, cons
}

__global__ void residual_forward_kernel(floatX* out, floatX* inp1, floatX* inp2, int N) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
int idx = (blockIdx.x * blockDim.x + threadIdx.x) * x128::size;
if (idx < N) {
out[idx] = (floatX)((float)__ldcs(&inp1[idx]) + (float)__ldcs(&inp2[idx]));
x128 packed_out;
x128 packed_inp1 = load128cs(inp1 + idx);
x128 packed_inp2 = load128cs(inp2 + idx);
#pragma unroll packed_inp1.size
for (int k = 0; k < packed_inp1.size; ++k)
{
packed_out[k] = (floatX)((float)packed_inp1[k] + (float)packed_inp2[k]);
}
store128(out + idx, packed_out);
}
}

Expand Down Expand Up @@ -1481,7 +1489,7 @@ void attention_forward(floatX* out, floatX* qkvr, floatX* att,

void residual_forward(floatX* out, floatX* inp1, floatX* inp2, int N) {
const int block_size = 256;
const int grid_size = CEIL_DIV(N, block_size);
const int grid_size = CEIL_DIV(N, block_size * x128::size);
residual_forward_kernel<<<grid_size, block_size>>>(out, inp1, inp2, N);
cudaCheck(cudaGetLastError());
}
Expand Down