Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

move all kernels into a dedicated cuda stream #448

Draft
wants to merge 4 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 16 additions & 6 deletions cudnn_att.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,14 @@ static void cudaCheck(cudaError_t error, const char *file, int line) {
};
#define cudaCheck(err) (cudaCheck(err, __FILE__, __LINE__))

static void cuDNNCheck(cudnnStatus_t error, const char *file, int line) {
if (error != CUDNN_STATUS_SUCCESS) {
printf("[CUDNN ERROR] at file %s:%d:\n%s\n", file, line, cudnnGetErrorString(error));
exit(EXIT_FAILURE);
}
};
#define cuDNNCheck(err) (cuDNNCheck(err, __FILE__, __LINE__))

// Profiler utils
namespace {
class NvtxRange {
Expand All @@ -50,9 +58,8 @@ namespace {
static cudnnHandle_t cudnn_handle;
static size_t cudnn_workspace_size = 0; // dynamically allocated as needed (up to 256MiB!)
static void* cudnn_workspace = NULL;
#define checkCudnnErr(err) assert((int)err == 0);

static void checkCudnnFE(fe::error_object e, const char *file, int line) {
static void checkCudnnFE(const fe::error_object& e, const char *file, int line) {
if(!e.is_good()) {
printf("[CUDNN ERROR] at file %s:%d:\n%s\n", file, line, e.err_msg.c_str());
exit(EXIT_FAILURE);
Expand Down Expand Up @@ -240,11 +247,13 @@ auto lookup_cache_or_build_graph_bwd(int B, int NH, int T, int HS) {
void attention_forward_cudnn(floatX* out, // output: (B, T, NH, HS)
float* stats, // output for backward pass: (B, NH, T)
floatX* inp, // input: (B, T, 3, NH, HS) QKV
int B, int T, int NH, int C) {
int B, int T, int NH, int C, cudaStream_t stream) {
NVTX_RANGE_FN();
int HS = C / NH; // number of features per head
bool is_inference_only = (stats == nullptr);

cuDNNCheck(cudnnSetStream(cudnn_handle, stream));

// Get graph and tensors from cache (or generate it on first use)
auto graph = lookup_cache_or_build_graph_fwd(B, NH, T, HS, is_inference_only);

Expand All @@ -271,7 +280,7 @@ void attention_forward_cudnn(floatX* out, // output: (B, T, NH, HS)

void attention_backward_cudnn(floatX* dqkvr, // output
floatX* dout, floatX* qkvr, floatX* o, float* stats, // inputs
int B, int T, int NH, int C) {
int B, int T, int NH, int C, cudaStream_t stream) {
NVTX_RANGE_FN();
int HS = C / NH; // number of features per head

Expand All @@ -298,15 +307,16 @@ void attention_backward_cudnn(floatX* dqkvr,
{Attn_scale_UID, &attn_scale_cpu}};

// Execute graph
cuDNNCheck(cudnnSetStream(cudnn_handle, stream));
checkCudnnFE(graph->execute(cudnn_handle, variant_pack, cudnn_workspace));
cudaCheck(cudaGetLastError());
}

void create_cudnn() {
checkCudnnErr(cudnnCreate(&cudnn_handle));
cuDNNCheck(cudnnCreate(&cudnn_handle));
}

void destroy_cudnn() {
if (cudnn_workspace != NULL) { cudaCheck(cudaFree(cudnn_workspace)); }
checkCudnnErr(cudnnDestroy(cudnn_handle));
cuDNNCheck(cudnnDestroy(cudnn_handle));
}
12 changes: 6 additions & 6 deletions test_gpt2.cu
Original file line number Diff line number Diff line change
Expand Up @@ -203,19 +203,19 @@ int main(int argc, char *argv[]) {
clock_gettime(CLOCK_MONOTONIC, &start);
gpt2_forward(&model, x, y, B, T);
gpt2_zero_grad(&model);
gpt2_backward(&model);
float mean_loss = gpt2_backward(&model);
clock_gettime(CLOCK_MONOTONIC, &end);
double time_elapsed_s = (end.tv_sec - start.tv_sec) + (end.tv_nsec - start.tv_nsec) / 1e9;

if (step == 0) {
// error checking at step 0 for reference activations

// compare the achieved loss
if (fabsf(model.mean_loss - *expected_loss) >= loss_diff_threshold) {
printf("LOSS MISMATCH: %f %f\n", model.mean_loss, *expected_loss);
if (fabsf(mean_loss - *expected_loss) >= loss_diff_threshold) {
printf("LOSS MISMATCH: %f %f\n", mean_loss, *expected_loss);
allok = 0;
} else {
printf("LOSS OK: %f %f\n", model.mean_loss, *expected_loss);
printf("LOSS OK: %f %f\n", mean_loss, *expected_loss);
}

// move the (mixed precision) grads from GPU to CPU
Expand Down Expand Up @@ -277,8 +277,8 @@ int main(int argc, char *argv[]) {
gpt2_update(&model, 1e-4f, 0.9f, 0.95f, 1e-8f, 0.0f, 1.0f, step+1, &multi_gpu_config);

// print the timing information at the end
printf("step %d: loss %f (took %f ms)\n", step+1, model.mean_loss, time_elapsed_s * 1000);
losses[step] = model.mean_loss;
printf("step %d: loss %f (took %f ms)\n", step+1, mean_loss, time_elapsed_s * 1000);
losses[step] = mean_loss;
}

// expected losses are as follows, from Python
Expand Down
Loading
Loading