Commit

resolve merge and small fixes
karpathy committed May 2, 2024
2 parents 50714d2 + 41a0789 commit 79505bc
Showing 1 changed file, train_gpt2.cu, with 65 additions and 9 deletions.
@@ -33,6 +33,8 @@ This reads & runs in fp32, B=4, T=64, LR=1e-4, val/sample never (200),
-a 1 is "overfit single batch", -x 10 is 10 iterations, and -f 0 disables tf32
*/
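// e.g. an illustrative invocation using only the flags described above (binary name assumed, not part of this diff):
//   ./train_gpt2cu -a 1 -x 10 -f 0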

#include <string>

#include <stdio.h>
#include <stdlib.h>
#include <stdarg.h>
@@ -45,11 +47,14 @@ This reads & runs in fp32, B=4, T=64, LR=1e-4, val/sample never (200),
#include <assert.h>
// GPU / CUDA related
#include <cublas_v2.h>
#include <cuda_profiler_api.h>
#include <cuda_runtime.h>
#include <cublasLt.h>
#include <cuda_bf16.h>
#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
#include <nvtx3/nvToolsExt.h>

// Multi-GPU related
#ifdef MULTI_GPU
#include <mpi.h>
@@ -128,6 +133,18 @@ static void* cudnn_workspace = NULL;
// ----------------------------------------------------------------------------
// CUDA utils

// Profiler utils
class NvtxRange {
 public:
    NvtxRange(const char* s) { nvtxRangePush(s); }
    NvtxRange(const std::string& base_str, int number) {
        std::string range_string = base_str + " " + std::to_string(number);
        nvtxRangePush(range_string.c_str());
    }
    ~NvtxRange() { nvtxRangePop(); }
};
#define NVTX_RANGE_FN() NvtxRange nvtx_range(__FUNCTION__)
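// Usage sketch (illustrative only, added as a note and not part of this diff; the
// function name is made up): NvtxRange is RAII, so a range is pushed on construction
// and popped when the scope ends, and annotated scopes show up as named spans in
// NVTX-aware profilers such as Nsight Systems.
//
//   void hypothetical_launcher(int layer_idx) {
//       NVTX_RANGE_FN();                      // span named "hypothetical_launcher"
//       NvtxRange sub("sublayer", layer_idx); // span named "sublayer <layer_idx>"
//       // ... kernel launches here are attributed to these spans ...
//   }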

// cuBLAS workspace. Hardcoding to 32MiB but only Hopper needs 32, for others 4 is OK
static size_t cublaslt_workspace_size = 32 * 1024 * 1024;
static void* cublaslt_workspace = NULL;
@@ -647,6 +664,7 @@ void attention_forward_cudnn(floatX* out, // output: (B, T, NH, HS)
float* stats, // output for backward pass: (B, NH, T)
floatX* inp, // input: (B, T, 3, NH, HS) QKV
int B, int T, int NH, int C) {
NVTX_RANGE_FN();
int HS = C / NH; // number of features per head
bool is_inference_only = (stats == nullptr);

@@ -688,6 +706,7 @@ void attention_forward_cudnn(floatX* out, // output: (B, T, NH, HS)
void attention_backward_cudnn(floatX* dqkvr, // output
floatX* dout, floatX* qkvr, floatX* o, float* stats, // inputs
int B, int T, int NH, int C) {
NVTX_RANGE_FN();
int HS = C / NH; // number of features per head

// Get graph and tensors from cache (or generate it on first use)
@@ -1353,6 +1372,7 @@ __global__ void copy_and_cast_kernel(float* dst, const floatX* src, size_t n) {
void encoder_forward(floatX* out,
const int* inp, const floatX* wte, const floatX* wpe,
int B, int T, int C) {
NVTX_RANGE_FN();
const int block_size = 256;
const int N = B * T * C;
const int grid_size = CEIL_DIV(N, (int)(block_size * x128::size));
@@ -1363,6 +1383,7 @@ void encoder_forward(floatX* out,
void encoder_backward(floatX* dwte, floatX* dwpe,
const floatX* dout, const int* inp,
int B, int T, int C) {
NVTX_RANGE_FN();
const int N = B * T * C;
const int block_size = 256;
const int grid_size = CEIL_DIV(N, block_size);
@@ -1373,6 +1394,7 @@ void encoder_backward(floatX* dwte, floatX* dwpe,
void layernorm_forward(floatX* out, floatX* mean, floatX* rstd,
floatX* inp, floatX* weight, floatX* bias,
int B, int T, int C) {
NVTX_RANGE_FN();
const int block_size = 512;
const int N = B * T;
const int grid_size = CEIL_DIV(N * 32, block_size);
@@ -1386,6 +1408,7 @@ void layernorm_forward(floatX* out, floatX* mean, floatX* rstd,
void matmul_forward_cublaslt(floatX* out,
floatX* inp, floatX* weight, floatX* bias,
int B, int T, int C, int OC) {
NVTX_RANGE_FN();
int has_bias = (bias != NULL);

// check bias alignment
@@ -1465,6 +1488,7 @@ void matmul_forward_cublaslt(floatX* out,
void attention_forward(floatX* out, floatX* qkvr, floatX* att,
floatX* inp,
int B, int T, int C, int NH) {
NVTX_RANGE_FN();
// Note: `inp` is not needed for backward pass, so we re-use it as a scratch buffer.
// Its contents will be overwritten by this function.
const int block_size = 256;
@@ -1536,20 +1560,23 @@ void attention_forward(floatX* out, floatX* qkvr, floatX* att,
}

void residual_forward(floatX* out, floatX* inp1, floatX* inp2, int N) {
NVTX_RANGE_FN();
const int block_size = 256;
const int grid_size = CEIL_DIV(N, block_size * x128::size);
residual_forward_kernel<<<grid_size, block_size>>>(out, inp1, inp2, N);
cudaCheck(cudaGetLastError());
}

void gelu_forward(floatX* out, const floatX* inp, int N) {
NVTX_RANGE_FN();
const int block_size = 512;
const int grid_size = CEIL_DIV(N, block_size * x128::size);
gelu_forward_kernel2<<<grid_size, block_size>>>(out, inp, N);
cudaCheck(cudaGetLastError());
}

void gelu_backward(floatX* dinp, const floatX* inp, const floatX* dout, const int N) {
NVTX_RANGE_FN();
const int block_size = 128;
const int grid_size = CEIL_DIV(N, block_size * x128::size);
gelu_backward_kernel<<<grid_size, block_size>>>(dinp, inp, dout, N);
@@ -1559,6 +1586,7 @@ void gelu_backward(floatX* dinp, const floatX* inp, const floatX* dout, const int N) {
void matmul_backward(floatX* dinp, floatX* dweight, floatX* dbias,
floatX* dout, floatX* inp, floatX* weight,
int B, int T, int C, int OC) {
NVTX_RANGE_FN();
float one = 1.0f;
float zero = 0.0f;
// backward to input, uses = in the backward pass (set the gradient)
@@ -1581,6 +1609,7 @@ void matmul_backward(floatX* dinp, floatX* dweight, floatX* dbias,
void layernorm_backward(floatX* dinp, floatX* dweight, floatX* dbias, float* scratch,
const floatX* dout, const floatX* inp, const floatX* weight, const floatX* mean, const floatX* rstd,
int B, int T, int C) {
NVTX_RANGE_FN();
const int block_size = 1024;
const int grid_size = 1 * cuda_num_SMs;
size_t shared_mem_size = (2 * C + 1) * sizeof(float);
@@ -1595,6 +1624,7 @@ void attention_backward(floatX* dinp, floatX* dqkvr, floatX* dpreatt, floatX* da
const floatX* dout,
const floatX* qkvr, const floatX* att,
int B, int T, int C, int NH) {
NVTX_RANGE_FN();
const int block_size = 256;
int HS = C / NH; // head size

@@ -1655,6 +1685,7 @@ template <typename Type>
void fused_classifier3(Type* logits, Type* losses,
const Type* dlosses, const int* targets,
int B, int T, int V, int P) {
NVTX_RANGE_FN();
const int block_size = 1024;
const int N = B * T;
const int grid_size = N;
@@ -1984,6 +2015,7 @@ void gpt2_build_from_checkpoint(GPT2 *model, const char* checkpoint_path) {
}

void gpt2_forward(GPT2 *model, int* inputs, int* targets, size_t B, size_t T) {
NVTX_RANGE_FN();
// targets are optional and could be NULL
// in this function we must be careful and use size_t instead of int, otherwise
// we could overflow int. E.g. l * B * NH * T * T overflows int at B 16.
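// worked instance of that overflow (note added for illustration, assuming the GPT-2 124M
// shapes L=12, NH=12, T=1024, which are not restated here): at B=16 the last layer index
// l=11 gives 11 * 16 * 12 * 1024 * 1024 = 2,214,592,512 > INT_MAX = 2,147,483,647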
@@ -2049,6 +2081,7 @@ void gpt2_forward(GPT2 *model, int* inputs, int* targets, size_t B, size_t T) {
encoder_forward(acts.encoded, model->inputs, params.wte, params.wpe, B, T, C); // encoding goes into residual[0]

for (int l = 0; l < L; l++) {
NvtxRange layer_range("Layer", l);

residual = l == 0 ? acts.encoded : acts.residual3 + (l-1) * B * T * C;

@@ -2113,6 +2146,7 @@ void gpt2_forward(GPT2 *model, int* inputs, int* targets, size_t B, size_t T) {

// also forward the cross-entropy loss function if we have the targets
if (targets != NULL) {
NvtxRange classifier_and_loss_range("classifier_and_loss");
// fused classifier: does the forward pass and first part of the backward pass
// we're passing dlosses = NULL, which will default them to 1.0f/(B*T), i.e. uniform loss
fused_classifier3(acts.output, acts.losses, (floatX*)NULL, model->targets, B, T, V, Vp);
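// note (added for illustration, not part of the commit): with the debug defaults mentioned
// at the top (B=4, T=64) that uniform default works out to 1/(4*64) = 1/256 = 0.00390625 per position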
@@ -2123,19 +2157,20 @@ void gpt2_forward(GPT2 *model, int* inputs, int* targets, size_t B, size_t T) {
for (int i=0; i<B*T; i++) { mean_loss += (float)(model->cpu_losses[i]); }
mean_loss /= B*T;
model->mean_loss = mean_loss;

} else {
// if we don't have targets, we don't have loss
model->mean_loss = -1.0f;
}
}

void gpt2_zero_grad(GPT2 *model) {
NVTX_RANGE_FN();
if (model->grads_acts_memory != NULL) { cudaCheck(cudaMemset(model->grads_acts_memory, 0, model->num_grad_acts * sizeof(floatX))); }
if (model->grads_memory != NULL) { cudaCheck(cudaMemset(model->grads_memory, 0, model->num_parameters * sizeof(floatX))); }
}

void gpt2_backward(GPT2 *model) {
NVTX_RANGE_FN();
// double check we forwarded previously, with targets
if (model->mean_loss == -1.0f) {
printf("Error: must forward with targets before backward\n");
@@ -2192,6 +2227,8 @@ void gpt2_backward(GPT2 *model) {

// now backward all the layers
for (int l = L-1; l >= 0; l--) {
NvtxRange layer_range("Layer", l);

residual = l == 0 ? acts.encoded : acts.residual3 + (l-1) * B * T * C;

// get the pointers of the weights for this layer
@@ -2279,6 +2316,7 @@ float multi_gpu_cpu_float_mean(float value, const MultiGpuConfig* multi_gpu_conf
// Averages out the loss and gradients across all GPUs. No-op when multi-GPU is disabled.
// todo - this version only works if all the parameters are the same size (floatX)
void gpt2_multi_gpu_accumulate(GPT2* model, MultiGpuConfig* multi_gpu_config) {
NVTX_RANGE_FN();
// Average all losses.
model->accumulated_mean_loss = multi_gpu_cpu_float_mean(model->mean_loss, multi_gpu_config);
#ifdef MULTI_GPU
@@ -2293,6 +2331,7 @@ void gpt2_multi_gpu_accumulate(GPT2* model, MultiGpuConfig* multi_gpu_config) {
}

void gpt2_update(GPT2 *model, float learning_rate, float beta1, float beta2, float eps, float weight_decay, int t) {
NVTX_RANGE_FN();
// reference: https://pytorch.org/docs/stable/generated/torch.optim.AdamW.html
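// schematic AdamW update implemented below (note added for reference; see the link above
// for the exact PyTorch conventions):
//   m <- beta1*m + (1-beta1)*grad            v <- beta2*v + (1-beta2)*grad^2
//   m_hat = m / (1 - beta1^t)                v_hat = v / (1 - beta2^t)
//   param <- param - lr * (m_hat / (sqrt(v_hat) + eps) + weight_decay * param)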

// lazily allocate the memory for m_memory and v_memory
@@ -2398,6 +2437,7 @@ void dataloader_reset(DataLoader *loader) {
}

void dataloader_next_batch(DataLoader *loader) {
NVTX_RANGE_FN();
size_t B = loader->B;
size_t T = loader->T;
// if we are at the end of the file, loop back to the beginning
@@ -2638,13 +2678,19 @@ int main(int argc, char *argv[]) {
float* cpu_logits = (float*)mallocCheck(model.config.vocab_size * sizeof(float));

// train
struct timespec start, end;
cudaEvent_t start, end;
cudaCheck(cudaEventCreate(&start));
cudaCheck(cudaEventCreate(&end));
cudaCheck(cudaProfilerStart());
double total_sum_iteration_time_s = 0.0;
for (int step = 0; step <= train_num_batches; step++) {
NvtxRange step_range("Train step", step);

int last_step = step == train_num_batches;

// once in a while estimate the validation loss
if (step % val_loss_every == 0 || last_step) {
NvtxRange validation_range("validation");
float val_loss = 0.0f;
dataloader_reset(&val_loader);
for (int i = 0; i < val_num_batches; i++) {
@@ -2660,6 +2706,7 @@ int main(int argc, char *argv[]) {

// once in a while do model inference to print generated text
if (multi_gpu_config.process_rank == 0 && (step > 0 && (step % sample_every) == 0 || last_step)) {
NvtxRange generation_range("generation");
// fill up gen_tokens with the <|endoftext|> token, which kicks off the generation
int eot_token = tokenizer.eot_token;
for(int i = 0; i < B * T; ++i) {
@@ -2668,6 +2715,7 @@ int main(int argc, char *argv[]) {
// now sample from the model autoregressively
printf("generating:\n---\n");
for (int t = 1; t < genT; t++) {
NvtxRange generation_range("Generation step", t);
// note that inference is very wasteful here because for each token
// we re-calculate the forward pass for all of (B,T) positions from scratch
// but the inference here is just for sanity checking anyway
@@ -2708,34 +2756,42 @@ int main(int argc, char *argv[]) {
if (last_step) { break; }

// do a training step
clock_gettime(CLOCK_MONOTONIC, &start);
cudaEventRecord(start);
if (overfit_single_batch == 0 || (step == 0 && overfit_single_batch == 1)) {
// if we're overfitting a single batch, we'll only call this at step = 0
dataloader_next_batch(&train_loader);
}
dataloader_next_batch(&train_loader);
gpt2_forward(&model, train_loader.inputs, train_loader.targets, B, T);
gpt2_zero_grad(&model);
gpt2_backward(&model);
if (multi_gpu_config.num_processes > 1) {
gpt2_multi_gpu_accumulate(&model, &multi_gpu_config);
}
gpt2_update(&model, learning_rate, 0.9f, 0.999f, 1e-8f, 0.0f, step+1);
cudaCheck(cudaDeviceSynchronize()); // finish all CUDA work to get correct precise timings
clock_gettime(CLOCK_MONOTONIC, &end);
double time_elapsed_s = (end.tv_sec - start.tv_sec) + (end.tv_nsec - start.tv_nsec) / 1e9;

cudaEventRecord(end);
float time_elapsed_ms;
cudaCheck(cudaEventSynchronize(end)); // wait for the end event to finish to get correct timings
cudaCheck(cudaEventElapsedTime(&time_elapsed_ms, start, end));

if (step > 0) { // consider the first batch to be a warmup (e.g. cuBLAS/cuDNN initialisation)
total_sum_iteration_time_s += time_elapsed_s;
total_sum_iteration_time_s += time_elapsed_ms / 1000.0;
}
int tokens_per_second = multi_gpu_config.num_processes * (B * T) / time_elapsed_s;
int tokens_per_second = multi_gpu_config.num_processes * (B * T) / time_elapsed_ms * 1000.0;
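// worked example of the rate above (note added; illustrative numbers, not from a real run):
// 1 process, B=4, T=64, 5 ms/step  ->  1 * 256 / 5 * 1000 = 51,200 tok/s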
float accumulated_loss = multi_gpu_config.num_processes == 1 ? model.mean_loss : model.accumulated_mean_loss;
printf0("step %4d/%d: train loss %f (acc %f) (%f ms, %d tok/s)\n", step + 1, train_num_batches, model.mean_loss, accumulated_loss, time_elapsed_s * 1000, tokens_per_second);
printf0("step %4d/%d: train loss %f (acc %f) (%f ms, %d tok/s)\n", step + 1, train_num_batches, model.mean_loss, accumulated_loss, time_elapsed_ms, tokens_per_second);
logger_log_train(&logger, step, model.mean_loss);

// disable the profiler after 10 steps of optimization
if (step == 10) { cudaProfilerStop(); }
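// note (added, not part of the commit): cudaProfilerStart/Stop only delimit the capture
// window for an externally attached profiler, e.g. Nsight Systems launched with
// nsys profile --capture-range=cudaProfilerApi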
}
// add a total average, for optimizations that are only mild improvements (excluding 1st batch as warmup)
printf0("total average iteration time: %f ms\n", total_sum_iteration_time_s / (train_num_batches-1) * 1000);

// free and destroy everything
cudaCheck(cudaEventDestroy(end));
cudaCheck(cudaEventDestroy(start));
dataloader_free(&train_loader);
dataloader_free(&val_loader);
tokenizer_free(&tokenizer);
