Skip to content

Commit

Permalink
moved loss calculation to backward part
Browse files Browse the repository at this point in the history
  • Loading branch information
ngc92 committed May 23, 2024
1 parent 8ccc1fd commit c728993
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 46 deletions.
12 changes: 6 additions & 6 deletions test_gpt2.cu
Original file line number Diff line number Diff line change
Expand Up @@ -203,19 +203,19 @@ int main(int argc, char *argv[]) {
clock_gettime(CLOCK_MONOTONIC, &start);
gpt2_forward(&model, x, y, B, T);
gpt2_zero_grad(&model);
gpt2_backward(&model);
float mean_loss = gpt2_backward(&model);
clock_gettime(CLOCK_MONOTONIC, &end);
double time_elapsed_s = (end.tv_sec - start.tv_sec) + (end.tv_nsec - start.tv_nsec) / 1e9;

if (step == 0) {
// error checking at step 0 for reference activations

// compare the achieved loss
if (fabsf(model.mean_loss - *expected_loss) >= loss_diff_threshold) {
printf("LOSS MISMATCH: %f %f\n", model.mean_loss, *expected_loss);
if (fabsf(mean_loss - *expected_loss) >= loss_diff_threshold) {
printf("LOSS MISMATCH: %f %f\n", mean_loss, *expected_loss);
allok = 0;
} else {
printf("LOSS OK: %f %f\n", model.mean_loss, *expected_loss);
printf("LOSS OK: %f %f\n", mean_loss, *expected_loss);
}

// move the (mixed precision) grads from GPU to CPU
Expand Down Expand Up @@ -277,8 +277,8 @@ int main(int argc, char *argv[]) {
gpt2_update(&model, 1e-4f, 0.9f, 0.999f, 1e-8f, 0.01f, 1.f, step+1, &multi_gpu_config);

// print the timing information at the end
printf("step %d: loss %f (took %f ms)\n", step+1, model.mean_loss, time_elapsed_s * 1000);
losses[step] = model.mean_loss;
printf("step %d: loss %f (took %f ms)\n", step+1, mean_loss, time_elapsed_s * 1000);
losses[step] = mean_loss;
}

// expected losses are as follows, from Python
Expand Down
119 changes: 79 additions & 40 deletions train_gpt2.cu
Original file line number Diff line number Diff line change
Expand Up @@ -1260,7 +1260,7 @@ template <bool WriteLogits = true, bool WriteProbs = false>
__global__ void __launch_bounds__(1024, MAX_1024_THREADS_BLOCKS)
fused_classifier_kernel5(floatX* logits, floatX* losses, floatX* probs,
const float dloss, const int* targets,
int B, int T, int V, int P) {
int B, int T, int V, int P, std::bool_constant<WriteLogits>) {
int idx = gridDim.x - (blockIdx.x+1); // reverse order for cache hits on matmul data
int ix = targets[idx];

Expand Down Expand Up @@ -1672,15 +1672,18 @@ void attention_backward(floatX* dinp, floatX* dqkvr, floatX* dpreatt, floatX* da
}

// replaces logits with logit gradients
template <typename Type>
template <typename Type, bool WriteLogits>
void fused_classifier(Type* logits, Type* losses,
const float dloss, const int* targets,
int B, int T, int V, int P, cudaStream_t stream) {
int B, int T, int V, int P,
std::bool_constant<WriteLogits> write_logits,
cudaStream_t stream) {
NVTX_RANGE_FN();
const int block_size = 1024;
const int N = B * T;
const int grid_size = N;
fused_classifier_kernel5<<<grid_size, block_size, 512, stream>>>(logits, losses, (floatX*)NULL, dloss, targets, B, T, V, P);
fused_classifier_kernel5<<<grid_size, block_size, 512, stream>>>(logits, losses, (floatX*) NULL, dloss, targets,
B, T, V, P, write_logits);
cudaCheck(cudaGetLastError());
}

Expand Down Expand Up @@ -1962,7 +1965,7 @@ typedef struct {
int seq_len; // the sequence length (T) of current forward pass
int* inputs; // the input tokens for the current forward pass
int* targets; // the target tokens for the current forward pass
float mean_loss; // after a forward pass with targets, will be populated with the mean loss
bool has_targets; // set to true if the forward pass populated targets
float accumulated_mean_loss; // Mean loss after aggregating it on all GPUs
floatX* cpu_losses; // CPU buffer to copy the losses to, allocated with cudaMallocHost
float* cpu_losses_fp32; // same but fp32
Expand Down Expand Up @@ -2044,11 +2047,11 @@ void gpt2_build_from_checkpoint(GPT2 *model, const char* checkpoint_path) {
model->grads_acts_memory = NULL;
model->inputs = NULL;
model->targets = NULL;
model->has_targets = false;
model->cpu_losses = NULL;
model->cpu_losses_fp32 = NULL;
model->batch_size = 0;
model->seq_len = 0;
model->mean_loss = -1.0f; // -1.0f will designate no loss
model->rng_state = 13371337;
model->use_master_weights = 1; // keep master weights copy in float for optim update?
model->recompute = 1; // default to recompute gelu during backward
Expand All @@ -2057,7 +2060,7 @@ void gpt2_build_from_checkpoint(GPT2 *model, const char* checkpoint_path) {
cudaCheck(cudaDeviceSynchronize());
}

void gpt2_forward(GPT2 *model, const int* inputs, const int* targets, size_t B, size_t T, int grad_accum_steps=1) {
void gpt2_forward(GPT2 *model, const int* inputs, const int* targets, size_t B, size_t T) {
// right now, this function is fully synchronous with the host
NVTX_RANGE_FN();
// targets are optional and could be NULL
Expand Down Expand Up @@ -2111,6 +2114,9 @@ void gpt2_forward(GPT2 *model, const int* inputs, const int* targets, size_t B,
cudaCheck(cudaMemcpyAsync(model->inputs, inputs, B * T * sizeof(int), cudaMemcpyHostToDevice, main_stream));
if (targets != NULL) {
cudaCheck(cudaMemcpyAsync(model->targets, targets, B * T * sizeof(int), cudaMemcpyHostToDevice, main_stream));
model->has_targets = true;
} else {
model->has_targets = false;
}

// validate inputs, all indices must be in the range [0, V)
Expand Down Expand Up @@ -2200,28 +2206,38 @@ void gpt2_forward(GPT2 *model, const int* inputs, const int* targets, size_t B,
}

matmul_forward_cublaslt(acts.output, acts.lnf, params.wte, NULL, B, T, C, Vp, main_stream);
cudaCheck(cudaDeviceSynchronize());
}

// also forward the cross-entropy loss function if we have the targets
if (targets != NULL) {
float gpt2_validate(GPT2 *model) {
// convenience shortcuts, size_t instead of int so that pointer arithmetics don't overflow
const size_t B = model->batch_size;
const size_t T = model->seq_len;
const size_t V = model->config.vocab_size;
const size_t Vp = model->config.padded_vocab_size;

ActivationTensors acts = model->acts;

float mean_loss = 0.0f;
if (model->has_targets) {
NvtxRange classifier_and_loss_range("classifier_and_loss");
// fused classifier: does the forward pass and first part of the backward pass
const float dloss = 1.0f / (B * T * grad_accum_steps); // results in the uniform average loss over all elements
fused_classifier(acts.output, acts.losses, dloss, model->targets, B, T, V, Vp, main_stream);
// for convenience also evaluate the mean loss (TODO re-think this compute+sync point)
const float dloss = 1.0f / (B * T); // results in the uniform average loss over all elements
// note: we don't need to generate dlogits here
fused_classifier(acts.output, acts.losses, dloss, model->targets, B, T, V, Vp, std::bool_constant<false>{}, model->main_stream);
cudaCheck(cudaMemcpy(model->cpu_losses, acts.losses, B * T * sizeof(floatX), cudaMemcpyDeviceToHost));
float mean_loss = 0.0f;
for (int i = 0; i < B*T; i++) {
float loss = (float)(model->cpu_losses[i]);
model->cpu_losses_fp32[i] = loss;
mean_loss += loss;
}
mean_loss /= B*T*grad_accum_steps;
model->mean_loss = mean_loss;
mean_loss /= B*T;
} else {
// if we don't have targets, we don't have loss
model->mean_loss = -1.0f;
printf("Error: must forward with targets before validate\n");
exit(EXIT_FAILURE);
}
cudaCheck(cudaDeviceSynchronize());
return mean_loss;
}

void gpt2_zero_grad(GPT2 *model) {
Expand All @@ -2232,10 +2248,36 @@ void gpt2_zero_grad(GPT2 *model) {
cudaCheck(cudaDeviceSynchronize());
}

void gpt2_backward(GPT2 *model) {
float gpt2_backward(GPT2 *model, int grad_accum_steps=1) {
NVTX_RANGE_FN();

// convenience shortcuts, size_t instead of int so that pointer arithmetics don't overflow
const size_t B = model->batch_size;
const size_t T = model->seq_len;
const size_t V = model->config.vocab_size;
const size_t Vp = model->config.padded_vocab_size;
const size_t L = model->config.num_layers;
const size_t NH = model->config.num_heads;
const size_t C = model->config.channels;

ActivationTensors acts = model->acts;
cudaStream_t main_stream = model->main_stream;

cudaEvent_t losses_ready;

// double check we forwarded previously, with targets
if (model->mean_loss == -1.0f) {
// also forward the cross-entropy loss function if we have the targets
float mean_loss = 0.0f;
if (model->has_targets) {
NvtxRange classifier_and_loss_range("classifier_and_loss");
// fused classifier: does the forward pass and first part of the backward pass
const float dloss = 1.0f / (B * T * grad_accum_steps); // results in the uniform average loss over all elements
fused_classifier(acts.output, acts.losses, dloss, model->targets, B, T, V, Vp, std::bool_constant<true>{}, main_stream);

cudaCheck(cudaEventCreateWithFlags(&losses_ready, cudaEventDisableTiming | cudaEventBlockingSync));
cudaCheck(cudaMemcpyAsync(model->cpu_losses, acts.losses, B * T * sizeof(floatX), cudaMemcpyDeviceToHost, main_stream));
cudaCheck(cudaEventRecord(losses_ready, main_stream));
} else {
printf("Error: must forward with targets before backward\n");
exit(EXIT_FAILURE);
}
Expand All @@ -2261,22 +2303,11 @@ void gpt2_backward(GPT2 *model) {
gpt2_zero_grad(model);
}

// convenience shortcuts, size_t instead of int so that pointer arithmetics don't overflow
const size_t B = model->batch_size;
const size_t T = model->seq_len;
const size_t Vp = model->config.padded_vocab_size;
const size_t L = model->config.num_layers;
const size_t NH = model->config.num_heads;
const size_t C = model->config.channels;

// backward pass: go in the reverse order of the forward pass, and call backward() functions
ParameterTensors params = model->params; // for brevity
ParameterTensors grads = model->grads;
ActivationTensors acts = model->acts;
GradActTensors grads_acts = model->grads_acts;

cudaStream_t main_stream = model->main_stream;

// reset residual stream gradients (put here to work with gradient accumulation)
cudaCheck(cudaMemsetAsync(model->grads_acts.residual3, 0, B * T * C * sizeof(floatX), main_stream));

Expand Down Expand Up @@ -2374,7 +2405,16 @@ void gpt2_backward(GPT2 *model) {
}
encoder_backward(grads.wte, grads.wpe, dresidual, model->inputs, B, T, C, random_u32(&model->rng_state), main_stream);

// now we have enqueued the entire backward pass on the GPU. wait and relax until we have
// the losses ready, then sum them up concurrently to the GPU work.
cudaCheck(cudaEventSynchronize(losses_ready));
cudaCheck(cudaEventDestroy(losses_ready));
for (int i = 0; i < B*T; i++) { mean_loss += (float)(model->cpu_losses[i]); }
mean_loss /= B*T*grad_accum_steps;

cudaCheck(cudaDeviceSynchronize());

return mean_loss;
}

// Compute sum of a single CPU value across all GPU processes. No-op when multi-GPU is disabled.
Expand All @@ -2391,12 +2431,12 @@ float multi_gpu_cpu_float_sum(float value) {

// Averages out the loss and gradients across all GPUs. No-op when multi-GPU is disabled.
// todo - this version only works if all the parameters are the same size (floatX)
void gpt2_multi_gpu_accumulate(GPT2* model, MultiGpuConfig* multi_gpu_config) {
void gpt2_multi_gpu_accumulate(GPT2* model, MultiGpuConfig* multi_gpu_config, float local_loss) {
#ifdef MULTI_GPU
NVTX_RANGE_FN();
if (multi_gpu_config->num_processes == 1) { return; }
// Average all losses.
model->accumulated_mean_loss = multi_gpu_cpu_float_sum(model->mean_loss) / multi_gpu_config->num_processes;
model->accumulated_mean_loss = multi_gpu_cpu_float_sum(local_loss) / multi_gpu_config->num_processes;
if(multi_gpu_config->zero_stage == 0) {
// no ZERO == standard DDP: Average all gradients.
ncclCheck(ncclAllReduce(model->grads_memory, model->grads_memory,
Expand Down Expand Up @@ -2789,7 +2829,7 @@ int main(int argc, char *argv[]) {
for (int i = 0; i < val_num_batches; i++) {
dataloader_next_batch(&val_loader);
gpt2_forward(&model, val_loader.inputs, val_loader.targets, B, T);
val_loss += model.mean_loss;
val_loss += gpt2_validate(&model);
}
val_loss /= val_num_batches;
val_loss = multi_gpu_cpu_float_sum(val_loss) / multi_gpu_config.num_processes;
Expand All @@ -2807,6 +2847,7 @@ int main(int argc, char *argv[]) {
if (i % 10 == 0) { printf("evaluating HellaSwag: %d/%d\r", i, eval_loader.num_batches); }
evalloader_next_batch(&eval_loader);
gpt2_forward(&model, eval_loader.inputs, eval_loader.targets, B, T);
gpt2_validate(&model);
int correct = evalloader_stat_losses(&eval_loader, model.cpu_losses_fp32);
eval_acc_norm += (float)correct;
}
Expand Down Expand Up @@ -2880,16 +2921,14 @@ int main(int argc, char *argv[]) {
dataloader_next_batch(&train_loader);
}
// forward pass. note that we pass in grad_accum_steps, which scales down the loss
gpt2_forward(&model, train_loader.inputs, train_loader.targets, B, T, grad_accum_steps);
lossf += model.mean_loss; // the mean_loss was normalized by grad_accum_steps inside gpt2_forward
gpt2_forward(&model, train_loader.inputs, train_loader.targets, B, T);
// backward pass. all model params accumulate gradients with += inside this inner loop
gpt2_backward(&model);
lossf += gpt2_backward(&model);
}
// override the mean loss, accounting for the gradient accumulation loop
// this is esp important to do here in multigpu update below, where model.mean_loss gets allreduced
model.mean_loss = lossf;
// update the parameters
gpt2_multi_gpu_accumulate(&model, &multi_gpu_config);
gpt2_multi_gpu_accumulate(&model, &multi_gpu_config, lossf);
float grad_norm = gpt2_update(&model, learning_rate, 0.9f, 0.999f, 1e-8f, 0.0f, grad_clip, step+1, &multi_gpu_config);
gpt2_multi_gpu_gather(&model, &multi_gpu_config);
// zero out the gradients for the next iteration
Expand All @@ -2911,11 +2950,11 @@ int main(int argc, char *argv[]) {
ema_tokens_per_second = 0.95f * ema_tokens_per_second + 0.05f * tokens_per_second;
bias_corrected_ema_tokens_per_second = ema_tokens_per_second / (1.0f - powf(0.95f, step));
}
float accumulated_loss = multi_gpu_config.num_processes == 1 ? model.mean_loss : model.accumulated_mean_loss;
float accumulated_loss = multi_gpu_config.num_processes == 1 ? lossf : model.accumulated_mean_loss;
printf0("step %4d/%d: train loss %f norm %.4f (%.2f ms, %.0f tok/s)\n",
step + 1, train_num_batches, accumulated_loss, grad_norm,
time_elapsed_ms, bias_corrected_ema_tokens_per_second);
logger_log_train(&logger, step, model.mean_loss);
logger_log_train(&logger, step, lossf);

// disable the profiler after 3 steps of optimization
if (step == 3) { cudaProfilerStop(); }
Expand Down

0 comments on commit c728993

Please sign in to comment.