Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

load bf16 directly, and some "quality of life" handling of fp32/fp16/bf16 precisions #265

Merged
merged 12 commits into from
Apr 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
16 changes: 15 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,20 @@ else
endif
endif

# Precision settings, default to bf16 but ability to override
# Usage: `make PRECISION=FP32` (or FP16/BF16). The chosen value is mapped to a
# -DENABLE_* preprocessor define in PFLAGS, which the CUDA build consumes.
PRECISION ?= BF16
VALID_PRECISIONS := FP32 FP16 BF16
# fail fast with a clear message if PRECISION is not one of the supported values
ifeq ($(filter $(PRECISION),$(VALID_PRECISIONS)),)
$(error Invalid precision $(PRECISION), valid precisions are $(VALID_PRECISIONS))
endif
# map the validated precision name onto the matching compile-time define
ifeq ($(PRECISION), FP32)
PFLAGS = -DENABLE_FP32
else ifeq ($(PRECISION), FP16)
PFLAGS = -DENABLE_FP16
else
PFLAGS = -DENABLE_BF16
endif

# PHONY means these targets will always be executed
.PHONY: all train_gpt2 test_gpt2 train_gpt2cu test_gpt2cu train_gpt2fp32cu test_gpt2fp32cu

Expand All @@ -108,7 +122,7 @@ test_gpt2: test_gpt2.c
$(CC) $(CFLAGS) $(INCLUDES) $(LDFLAGS) $< $(LDLIBS) -o $@

train_gpt2cu: train_gpt2.cu
$(NVCC) $(NVCC_FLAGS) $< $(NVCC_LDFLAGS) $(NVCC_INCLUDES) $(NVCC_LDLIBS) $(NVCC_LDFLAGS) -o $@
$(NVCC) $(NVCC_FLAGS) $(PFLAGS) $< $(NVCC_LDFLAGS) $(NVCC_INCLUDES) $(NVCC_LDLIBS) $(NVCC_LDFLAGS) -o $@

train_gpt2fp32cu: train_gpt2_fp32.cu
$(NVCC) $(NVCC_FLAGS) $< $(NVCC_LDFLAGS) $(NVCC_INCLUDES) $(NVCC_LDLIBS) $(NVCC_LDFLAGS) -o $@
Expand Down
3 changes: 2 additions & 1 deletion profile_gpt2.cu
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ For example, I have NVIDIA Nsight Compute installed on my Mac, and I rsync
the profile.ncu-rep from a cloud box to local to pretty view.
*/

#define ENABLE_BF16
#define TESTING
#include "train_gpt2.cu"

Expand Down Expand Up @@ -51,7 +52,7 @@ int main() {

// build the GPT-2 model from a checkpoint
GPT2 model;
gpt2_build_from_checkpoint(&model, "gpt2_124M.bin");
gpt2_build_from_checkpoint(&model, "gpt2_124M_bf16.bin");

int B = 4;
int T = 1024;
Expand Down
220 changes: 140 additions & 80 deletions test_gpt2.cu
Original file line number Diff line number Diff line change
@@ -1,13 +1,27 @@
#define ENABLE_BF16
#define TESTING
#include "train_gpt2.cu"

// poor man's tensor checker
int check_tensor(float *a, float *b, int n, const char* label, float threshold=1e-0) {
int print_upto = 5;
// a is the calculated tensor, b is the reference tensor
int print_upto = 10;
int ok = 1;
float max_diff = 0.0f;
float max_rel_error = 0.0f;
float max_a = 0.0f;
float max_b = 0.0f;
printf("%s\n", label);
for (int i = 0; i < n; i++) {
if (fabsf(a[i] - b[i]) <= threshold) {
float diff = fabsf(a[i] - b[i]);
if (diff > max_diff) {
max_diff = diff;
float denom = fabsf(b[i]);
max_rel_error = (denom == 0.0f) ? 0.0f : diff / denom;
max_a = a[i];
max_b = b[i];
}
if (diff <= threshold) {
if (i < print_upto) { printf("OK "); }
} else {
if (i < print_upto) { printf("NOT OK "); }
Expand All @@ -17,13 +31,58 @@ int check_tensor(float *a, float *b, int n, const char* label, float threshold=1
}
// print the final result
if (ok) {
printf("TENSOR OK\n");
printf("TENSOR OK, max diff: %e, with rel error: %e (calculated=%f, ref=%f)\n",
max_diff, max_rel_error, max_a, max_b);
} else {
printf("TENSOR NOT OK\n");
printf("TENSOR NOT OK, max diff: %e, with rel error: %e (calculated=%f, ref=%f)\n",
max_diff, max_rel_error, max_a, max_b);
}
return ok;
}

// the same tensors as in the train file, but in float, which are used as reference
typedef struct {
float* wte; // (V, C)
float* wpe; // (maxT, C)
float* ln1w; // (L, C)
float* ln1b; // (L, C)
float* qkvw; // (L, 3*C, C)
float* qkvb; // (L, 3*C)
float* attprojw; // (L, C, C)
float* attprojb; // (L, C)
float* ln2w; // (L, C)
float* ln2b; // (L, C)
float* fcw; // (L, 4*C, C)
float* fcb; // (L, 4*C)
float* fcprojw; // (L, C, 4*C)
float* fcprojb; // (L, C)
float* lnfw; // (C)
float* lnfb; // (C)
} FloatParameterTensors;
static_assert(sizeof(FloatParameterTensors) == NUM_PARAMETER_TENSORS * sizeof(void*), "Inconsistent sizes!");

// malloc_and_point, but in float and on CPU, because we use this data to check correctness on CPU
// Allocates one contiguous fp32 buffer sized for all parameter tensors, points
// every field of `params` at its slice of that buffer, and returns the buffer
// (the caller owns it and must free() it).
float* float_cpu_malloc_and_point_parameters(FloatParameterTensors* params, size_t* param_sizes) {
    // total element count across all parameter tensors
    size_t total_elements = 0;
    for (int t = 0; t < NUM_PARAMETER_TENSORS; t++) {
        total_elements += param_sizes[t];
    }
    // everything is float so number of bytes to allocate is a simple multiplication
    float* base = (float*)mallocCheck(total_elements * sizeof(float));
    // addresses of the struct fields, in the same order as param_sizes
    float** slots[] = {
        &params->wte, &params->wpe, &params->ln1w, &params->ln1b, &params->qkvw, &params->qkvb,
        &params->attprojw, &params->attprojb, &params->ln2w, &params->ln2b, &params->fcw, &params->fcb,
        &params->fcprojw, &params->fcprojb, &params->lnfw, &params->lnfb
    };
    // carve the buffer into back-to-back per-tensor slices
    size_t offset = 0;
    for (int t = 0; t < NUM_PARAMETER_TENSORS; t++) {
        *(slots[t]) = base + offset;
        offset += param_sizes[t];
    }
    return base;
}

int main(int argc, char *argv[]) {

// set up the device
Expand All @@ -48,46 +107,46 @@ int main(int argc, char *argv[]) {

// build the GPT-2 model from a checkpoint
GPT2 model;
gpt2_build_from_checkpoint(&model, "gpt2_124M.bin");

// int C = model.config.channels;
int V = model.config.vocab_size;
int maxT = model.config.max_seq_len;
// int L = model.config.num_layers;
gpt2_build_from_checkpoint(&model, "gpt2_124M_bf16.bin");
size_t V = model.config.vocab_size;
size_t maxT = model.config.max_seq_len;
size_t L = model.config.num_layers;
size_t C = model.config.channels;

// load additional information that we will use for debugging and error checking
FILE *state_file = fopenCheck("gpt2_124M_debug_state.bin", "rb");
int state_header[256];
freadCheck(state_header, sizeof(int), 256, state_file);
if (state_header[0] != 20240327) { printf("Bad magic state file"); exit(1); }
if (state_header[1] != 1) { printf("Bad version in state file"); exit(1); }
if (state_header[0] != 20240327) { printf("Bad magic state file"); exit(EXIT_FAILURE); }
if (state_header[1] != 1) { printf("Bad version in state file"); exit(EXIT_FAILURE); }
int B = state_header[2]; // batch size, e.g. 4
int T = state_header[3]; // time / sequence length (e.g. 64, up to maxT)
assert(0 <= T && T <= maxT);
printf("[State]\n");
printf("batch_size: %d\n", B);
printf("seq_len: %d\n", T);

ParameterTensors expected_grads; // will be read from file (from PyTorch)
ParameterTensors calculated_grads; // will be calculated by us
float* expected_grads_memory = malloc_and_point_parameters(&expected_grads, model.param_elements, model.param_sizeof, 0);
float* calculated_grads_memory = malloc_and_point_parameters(&calculated_grads, model.param_elements, model.param_sizeof, 0);
float* converted_grads_memory = (float*)mallocCheck(model.num_parameters * sizeof(float));

// inputs and expected outputs, only used for error checking
// read reference information from the file saved from Python/PyTorch side
// 1) input x and y
int* x = (int*)mallocCheck(B * T * sizeof(int));
int* y = (int*)mallocCheck(B * T * sizeof(int));
float* expected_logits = (float*) mallocCheck(B * T * V * sizeof(float));
float* expected_loss = (float*) mallocCheck(1 * sizeof(float));

// read reference information from Python
freadCheck(x, sizeof(int), B*T, state_file);
freadCheck(y, sizeof(int), B*T, state_file);
// 2) results of forward pass (logits and loss)
float* expected_logits = (float*) mallocCheck(B * T * V * sizeof(float));
float* expected_loss = (float*) mallocCheck(1 * sizeof(float));
freadCheck(expected_logits, sizeof(float), B*T*V, state_file);
freadCheck(expected_loss, sizeof(float), 1, state_file);
// 3) results of backward pass (parameter gradients)
FloatParameterTensors expected_grads; // will be read from file. right now: all in fp32
float* expected_grads_memory = float_cpu_malloc_and_point_parameters(&expected_grads, model.param_elements);
freadCheck(expected_grads_memory, sizeof(float), model.num_parameters, state_file);
fcloseCheck(state_file);

// this memory will be used to do one single copy of all (mixed precision) GPU grads to CPU grads
void* grads_memory_cpu = mallocCheck(model.num_parameters_bytes);
float* grads_memory_cpu_float = (float*)mallocCheck(model.num_parameters * sizeof(float));

// overall OK signal for the test
int allok = 1;

Expand All @@ -103,25 +162,32 @@ int main(int argc, char *argv[]) {
}
int logits_ok = 1;

// FP16 and lower require very high tolerances unfortunately
float accuracy_threshold = 1e-2;
// FP16 and lower require very high tolerances unfortunately. TODO look into more
float logit_accuracy_threshold = 1e-2f;
float loss_diff_threshold = 0.05f;
#if defined(ENABLE_BF16) || defined(ENABLE_F16)
accuracy_threshold = 23;
logit_accuracy_threshold = 15.0f;
#endif


float max_diff = 0.0f;
for (int i=0; i<B*T*V; i++) {
if(i < 3) {
if(i < 10) {
printf("%f %f\n", expected_logits[i], logits_cpu[i]);
}
if (fabsf(expected_logits[i] - logits_cpu[i]) >= accuracy_threshold) {
float diff = fabsf(expected_logits[i] - logits_cpu[i]);
max_diff = fmaxf(max_diff, diff);
if (diff >= logit_accuracy_threshold) {
printf("MISMATCH AT INDEX %d: ", i);
printf("%f %f\n", expected_logits[i],logits_cpu[i]);
logits_ok = 0;
break;
}
}
allok = allok && logits_ok;
if(!logits_ok) { printf("NOT "); }
printf("OK (LOGITS)\n");
printf("logit max diff: %f\n", max_diff);

// let's do 10 training iterations, following the pytorch code
float losses[10];
Expand All @@ -137,71 +203,63 @@ int main(int argc, char *argv[]) {
if (step == 0) {
// error checking at step 0 for reference activations


allok = allok && logits_ok;
free(logits_cpu_raw);
free(logits_cpu);

// compare the achieved loss
if (fabsf(model.mean_loss - *expected_loss) >= accuracy_threshold) {
if (fabsf(model.mean_loss - *expected_loss) >= loss_diff_threshold) {
printf("LOSS MISMATCH: %f %f\n", model.mean_loss, *expected_loss);
allok = 0;
} else {
printf("LOSS OK: %f %f\n", model.mean_loss, *expected_loss);
}

// and now compare the gradients on the parameters
// cudaMemcpy(calculated_grads.lnfw, model.grads.lnfw, C * sizeof(float), cudaMemcpyDeviceToHost);
// cudaMemcpy(calculated_grads.lnfb, model.grads.lnfb, C * sizeof(float), cudaMemcpyDeviceToHost);
// cudaMemcpy(calculated_grads.fcprojw, model.grads.fcprojw, L * C * 4*C * sizeof(float), cudaMemcpyDeviceToHost);
// cudaMemcpy(calculated_grads.fcprojb, model.grads.fcprojb, L * C * sizeof(float), cudaMemcpyDeviceToHost);
// cudaMemcpy(calculated_grads.fcw, model.grads.fcw, L * 4*C * C * sizeof(float), cudaMemcpyDeviceToHost);
// cudaMemcpy(calculated_grads.fcb, model.grads.fcb, L * 4*C * sizeof(float), cudaMemcpyDeviceToHost);
// cudaMemcpy(calculated_grads.ln2w, model.grads.ln2w, L * C * sizeof(float), cudaMemcpyDeviceToHost);
// cudaMemcpy(calculated_grads.ln2b, model.grads.ln2b, L * C * sizeof(float), cudaMemcpyDeviceToHost);
// cudaMemcpy(calculated_grads.attprojw, model.grads.attprojw, L * C * C * sizeof(float), cudaMemcpyDeviceToHost);
// cudaMemcpy(calculated_grads.attprojb, model.grads.attprojb, L * C * sizeof(float), cudaMemcpyDeviceToHost);
// cudaMemcpy(calculated_grads.qkvw, model.grads.qkvw, L * 3*C * C * sizeof(float), cudaMemcpyDeviceToHost);
// cudaMemcpy(calculated_grads.qkvb, model.grads.qkvb, L * 3*C * sizeof(float), cudaMemcpyDeviceToHost);
// cudaMemcpy(calculated_grads.ln1w, model.grads.ln1w, L * C * sizeof(float), cudaMemcpyDeviceToHost);
// cudaMemcpy(calculated_grads.ln1b, model.grads.ln1b, L * C * sizeof(float), cudaMemcpyDeviceToHost);
// cudaMemcpy(calculated_grads.wte, model.grads.wte, V * C * sizeof(float), cudaMemcpyDeviceToHost);
// cudaMemcpy(calculated_grads.wpe, model.grads.wpe, maxT * C * sizeof(float), cudaMemcpyDeviceToHost);
// check_tensor(calculated_grads.lnfb, expected_grads.lnfb, C, "lnfb");
// check_tensor(calculated_grads.lnfw, expected_grads.lnfw, C, "lnfw");
// check_tensor(calculated_grads.fcprojw, expected_grads.fcprojw, L * C * 4*C, "fcprojw");
// check_tensor(calculated_grads.fcprojb, expected_grads.fcprojb, L * C, "fcprojb");
// check_tensor(calculated_grads.fcw, expected_grads.fcw, L * 4*C * C, "fcw");
// check_tensor(calculated_grads.fcb, expected_grads.fcb, L * 4*C, "fcb");
// check_tensor(calculated_grads.ln2w, expected_grads.ln2w, L * C, "ln2w");
// check_tensor(calculated_grads.ln2b, expected_grads.ln2b, L * C, "ln2b");
// check_tensor(calculated_grads.attprojw, expected_grads.attprojw, L * C * C, "attprojw");
// check_tensor(calculated_grads.attprojb, expected_grads.attprojb, L * C, "attprojb");
// check_tensor(calculated_grads.qkvw, expected_grads.qkvw, L * 3*C * C, "qkvw");
// check_tensor(calculated_grads.qkvb, expected_grads.qkvb, L * 3*C, "qkvb");
// check_tensor(calculated_grads.ln1w, expected_grads.ln1w, L * C, "ln1w");
// check_tensor(calculated_grads.ln1b, expected_grads.ln1b, L * C, "ln1b");
// check_tensor(calculated_grads.wte, expected_grads.wte, V * C, "wte");
// check_tensor(calculated_grads.wpe, expected_grads.wpe, maxT * C, "wpe");

// get gradients from GPU and convert all non-FP32 gradients back to FP32 for check_tensor
cudaMemcpy(calculated_grads_memory, model.grads_memory, model.num_parameters * sizeof(floatX), cudaMemcpyDeviceToHost);
char* src_iterator = (char*)calculated_grads_memory;
float* dst_iterator = (float*)converted_grads_memory;
for (size_t i = 0; i < NUM_PARAMETER_TENSORS; i++) {
// move the (mixed precision) grads from GPU to CPU
cudaMemcpy(grads_memory_cpu, model.grads_memory, model.num_parameters_bytes, cudaMemcpyDeviceToHost);

// convert all gradients to float on the CPU
char* src_iterator = (char*)grads_memory_cpu; // can be lower precision, so we use char*
float* dst_iterator = (float*)grads_memory_cpu_float; // float*
float* exp_iterator = expected_grads_memory; // float* of expected gradients from Python
float* tensors1[NUM_PARAMETER_TENSORS];
float* tensors2[NUM_PARAMETER_TENSORS];
for (int i = 0; i < NUM_PARAMETER_TENSORS; i++) {
if (model.param_sizeof[i] == sizeof(float)) {
// float tensor => copy over directly
memcpy(dst_iterator, src_iterator, model.param_elements[i] * sizeof(float));
} else {
assert(model.param_sizeof[i] == sizeof(floatX));
// low-precision tensor => convert to float
assert(model.param_sizeof[i] == sizeof(floatX)); // floatX is the single non-float supported atm
for (size_t j = 0; j < model.param_elements[i]; j++) {
dst_iterator[j] = ((floatX*)src_iterator)[j];
dst_iterator[j] = ((floatX*)src_iterator)[j]; // convert to float
}
}
// for convenience record the position of comparison for reality vs. expectation
tensors1[i] = dst_iterator; // reality
tensors2[i] = exp_iterator; // expectation
// advance the iterators
src_iterator += model.param_elements[i] * model.param_sizeof[i];
dst_iterator += model.param_elements[i];
exp_iterator += model.param_elements[i];
}
// compare the gradients on the parameters all at once
check_tensor(converted_grads_memory, expected_grads_memory, model.num_parameters, "grads");

// compare the gradients on the parameters all at once, in fp32
// I set the tolerances manually by inspecting the gradient differences for
// a few elements of each tensor. bf16 looks ok but not amazing here.
// It's possible we have bugs lurking, or maybe it is bf16. Not 100% sure.
allok = allok & check_tensor(tensors1[0], tensors2[0], V * C, "wte", 6e-1f);
allok = allok & check_tensor(tensors1[1], tensors2[1], maxT * C, "wpe", 1e-2f);
allok = allok & check_tensor(tensors1[2], tensors2[2], L * 3*C * C, "qkvw", 9e-2); // hmm a bit high
allok = allok & check_tensor(tensors1[3], tensors2[3], L * 3*C, "qkvb", 3e-2f);
allok = allok & check_tensor(tensors1[4], tensors2[4], L * C * C, "attprojw", 3e-2f);
allok = allok & check_tensor(tensors1[5], tensors2[5], L * C, "attprojb", 3e-2f);
allok = allok & check_tensor(tensors1[6], tensors2[6], L * 4*C * C, "fcw", 9e-2f); // hmm a bit high
allok = allok & check_tensor(tensors1[7], tensors2[7], L * 4*C, "fcb", 9e-2f); // hmm a bit high
allok = allok & check_tensor(tensors1[8], tensors2[8], L * C * 4*C, "fcprojw", 9e-2f); // hmm a bit high
allok = allok & check_tensor(tensors1[9], tensors2[9], L * C, "fcprojb", 3e-2f);
allok = allok & check_tensor(tensors1[10], tensors2[10], L * C, "ln1w", 0.1f); // hmm bit higher
allok = allok & check_tensor(tensors1[11], tensors2[11], L * C, "ln1b", 3e-2f);
allok = allok & check_tensor(tensors1[12], tensors2[12], L * C, "ln2w", 0.1f); // hmm bit higher
allok = allok & check_tensor(tensors1[13], tensors2[13], L * C, "ln2b", 3e-2f);
allok = allok & check_tensor(tensors1[14], tensors2[14], C, "lnfw", 0.12f); // hmm bit higher
allok = allok & check_tensor(tensors1[15], tensors2[15], C, "lnfb", 3e-2f);
}

gpt2_update(&model, 1e-4f, 0.9f, 0.999f, 1e-8f, 0.01f, step+1);
Expand All @@ -227,7 +285,7 @@ int main(int argc, char *argv[]) {

// compare
for (int i = 0; i < 10; i++) {
if (fabsf(losses[i] - expected_losses[i]) >= accuracy_threshold) {
if (fabsf(losses[i] - expected_losses[i]) >= loss_diff_threshold) {
printf("LOSS MISMATCH AT STEP %d: %f %f\n", i, losses[i], expected_losses[i]);
allok = 0;
} else {
Expand All @@ -241,11 +299,13 @@ int main(int argc, char *argv[]) {
// free everything
free(x);
free(y);
free(logits_cpu_raw);
free(logits_cpu);
free(expected_logits);
free(expected_loss);
free(expected_grads_memory);
free(calculated_grads_memory);
free(converted_grads_memory);
free(grads_memory_cpu);
free(grads_memory_cpu_float);
gpt2_free(&model);
cudaCheck(cudaFree(cublaslt_workspace));
cublasCheck(cublasDestroy(cublas_handle));
Expand Down