Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support for FP16/BF16 in train_gpt2.cu (1.86x Perf) #218

Merged
merged 12 commits into from
Apr 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
51 changes: 40 additions & 11 deletions test_gpt2.cu
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@
#include "train_gpt2.cu"

// poor man's tensor checker
int check_tensor(float *a, float *b, int n, const char* label) {
int check_tensor(float *a, float *b, int n, const char* label, float threshold=1e-0) {
int print_upto = 5;
int ok = 1;
printf("%s\n", label);
for (int i = 0; i < n; i++) {
if (fabsf(a[i] - b[i]) <= 1e-2) {
if (fabsf(a[i] - b[i]) <= threshold) {
if (i < print_upto) { printf("OK "); }
} else {
if (i < print_upto) { printf("NOT OK "); }
Expand Down Expand Up @@ -70,9 +70,10 @@ int main(int argc, char *argv[]) {

ParameterTensors expected_grads; // will be read from file (from PyTorch)
ParameterTensors calculated_grads; // will be calculated by us
float* expected_grads_memory = malloc_and_point_parameters(&expected_grads, model.param_sizes, 0);
float* calculated_grads_memory = malloc_and_point_parameters(&calculated_grads, model.param_sizes, 0);

float* expected_grads_memory = malloc_and_point_parameters(&expected_grads, model.param_elements, model.param_sizeof, 0);
float* calculated_grads_memory = malloc_and_point_parameters(&calculated_grads, model.param_elements, model.param_sizeof, 0);
float* converted_grads_memory = (float*)mallocCheck(model.num_parameters * sizeof(float));

// inputs and expected outputs, only used for error checking
int* x = (int*)mallocCheck(B * T * sizeof(int));
int* y = (int*)mallocCheck(B * T * sizeof(int));
Expand All @@ -94,14 +95,25 @@ int main(int argc, char *argv[]) {
gpt2_forward(&model, x, NULL, B, T);
// at this point, target should be equal to expected_logits, let's compare
// copy logits to CPU so we can compare them
floatX* logits_cpu_raw = (floatX*)mallocCheck(B * T * V * sizeof(floatX));
float* logits_cpu = (float*)mallocCheck(B * T * V * sizeof(float));
cudaMemcpy(logits_cpu, model.acts.output, B * T * V * sizeof(float), cudaMemcpyDeviceToHost);
cudaMemcpy(logits_cpu_raw, model.acts.output, B * T * V * sizeof(floatX), cudaMemcpyDeviceToHost);
for (int i = 0; i < B * T * V; i++) {
logits_cpu[i] = (float)logits_cpu_raw[i];
}
int logits_ok = 1;

// FP16 and lower require very high tolerances unfortunately
float accuracy_threshold = 1e-2;
#if defined(ENABLE_BF16) || defined(ENABLE_F16)
accuracy_threshold = 20;
#endif

for (int i=0; i<B*T*V; i++) {
if(i < 3) {
printf("%f %f\n", expected_logits[i], logits_cpu[i]);
}
if (fabsf(expected_logits[i] - logits_cpu[i]) >= 1e-2) {
if (fabsf(expected_logits[i] - logits_cpu[i]) >= accuracy_threshold) {
printf("MISMATCH AT INDEX %d: ", i);
printf("%f %f\n", expected_logits[i],logits_cpu[i]);
logits_ok = 0;
Expand All @@ -127,10 +139,11 @@ int main(int argc, char *argv[]) {


allok = allok && logits_ok;
free(logits_cpu_raw);
free(logits_cpu);

// compare the achieved loss
if (fabsf(model.mean_loss - *expected_loss) >= 1e-2) {
if (fabsf(model.mean_loss - *expected_loss) >= accuracy_threshold) {
printf("LOSS MISMATCH: %f %f\n", model.mean_loss, *expected_loss);
allok = 0;
} else {
Expand Down Expand Up @@ -171,9 +184,24 @@ int main(int argc, char *argv[]) {
// check_tensor(calculated_grads.wte, expected_grads.wte, V * C, "wte");
// check_tensor(calculated_grads.wpe, expected_grads.wpe, maxT * C, "wpe");

// get gradients from GPU and convert all non-FP32 gradients back to FP32 for check_tensor
cudaMemcpy(calculated_grads_memory, model.grads_memory, model.num_parameters * sizeof(floatX), cudaMemcpyDeviceToHost);
char* src_iterator = (char*)calculated_grads_memory;
float* dst_iterator = (float*)converted_grads_memory;
for (size_t i = 0; i < NUM_PARAMETER_TENSORS; i++) {
if (model.param_sizeof[i] == sizeof(float)) {
memcpy(dst_iterator, src_iterator, model.param_elements[i] * sizeof(float));
} else {
assert(model.param_sizeof[i] == sizeof(floatX));
for (size_t j = 0; j < model.param_elements[i]; j++) {
dst_iterator[j] = ((floatX*)src_iterator)[j];
}
}
src_iterator += model.param_elements[i] * model.param_sizeof[i];
dst_iterator += model.param_elements[i];
}
// compare the gradients on the parameters all at once
cudaMemcpy(calculated_grads_memory, model.grads_memory, model.num_parameters * sizeof(float), cudaMemcpyDeviceToHost);
check_tensor(calculated_grads_memory, expected_grads_memory, model.num_parameters, "grads");
check_tensor(converted_grads_memory, expected_grads_memory, model.num_parameters, "grads");
}

gpt2_update(&model, 1e-4f, 0.9f, 0.999f, 1e-8f, 0.01f, step+1);
Expand All @@ -199,7 +227,7 @@ int main(int argc, char *argv[]) {

// compare
for (int i = 0; i < 10; i++) {
if (fabsf(losses[i] - expected_losses[i]) >= 1e-2) {
if (fabsf(losses[i] - expected_losses[i]) >= accuracy_threshold) {
printf("LOSS MISMATCH AT STEP %d: %f %f\n", i, losses[i], expected_losses[i]);
allok = 0;
} else {
Expand All @@ -217,6 +245,7 @@ int main(int argc, char *argv[]) {
free(expected_loss);
free(expected_grads_memory);
free(calculated_grads_memory);
free(converted_grads_memory);
gpt2_free(&model);
cudaCheck(cudaFree(cublaslt_workspace));
cublasCheck(cublasDestroy(cublas_handle));
Expand Down