fixes to optionally keep a master copy of the weights in fp32
karpathy committed May 1, 2024
1 parent fd474fe commit 795f8b6
Showing 1 changed file with 21 additions and 19 deletions.
40 changes: 21 additions & 19 deletions train_gpt2.cu
@@ -1176,20 +1176,15 @@ __global__ void adamw_kernel3(Tp* params_memory, float* master_params, Tg* grads
m /= beta1_correction; // m_hat
v /= beta2_correction; // v_hat
// update the parameters (weight/bias)
float old_param;
if(master_params) {
old_param = master_params[i];
} else {
old_param = (float)params_memory[i];
}
float old_param = master_params != NULL ? master_params[i] : (float)params_memory[i];
float param = old_param - (learning_rate * (m / (sqrtf(v) + eps) + weight_decay * old_param));
// if we have master parameters, directly update the two weight copies
if(master_params) {
params_memory[i] = (floatX)param;;
master_params[i] = param;
if (master_params != NULL) {
params_memory[i] = (floatX)param; // low-precision copy, for use in the forward pass
master_params[i] = param; // float copy, for use in the next parameter update
} else {
// otherwise, one update with stochastic rounding
// todo - explain stochastic rounding here
// without a master copy of params in float, do a direct update in low precision
// and use stochastic rounding to mitigate loss of training stability
unsigned int random = Get2dNoiseUint(threadIdx.x, blockIdx.x, seed);
stochastic_rounding(param, &params_memory[i], random);
}
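
The stochastic_rounding helper called above is defined elsewhere in train_gpt2.cu. As a rough illustration of the idea (a hedged sketch assuming floatX is __nv_bfloat16, not necessarily the file's exact implementation): a bf16 keeps only the top 16 bits of a float's bit pattern, so the discarded low 16 bits are compared against a random threshold to decide whether to round up or down. The rounding is then unbiased in expectation, so small updates are not systematically lost when writing low-precision weights.

#include <cuda_bf16.h>

// illustrative sketch only: round a float to bf16, using 'random' to pick the rounding direction
__device__ void stochastic_rounding_sketch(float in, __nv_bfloat16* out, unsigned int random) {
    unsigned int threshold = random & 0xFFFFu;          // random 16-bit threshold
    unsigned int float_bits = __float_as_uint(in);
    unsigned int dropped_bits = float_bits & 0xFFFFu;   // the bits bf16 cannot represent
    // bias the subsequent round-to-nearest conversion up or down depending on the threshold
    float_bits = (dropped_bits > threshold) ? (float_bits | 0xFFFFu) : (float_bits & ~0xFFFFu);
    *out = __float2bfloat16_rn(__uint_as_float(float_bits));
}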
@@ -1290,6 +1285,11 @@ __global__ void fused_classifier_kernel3(floatX* logits, floatX* losses, floatX*
}
}

// cast each low-precision floatX parameter to float; used in gpt2_update below to seed the fp32 master copy
__global__ void copy_kernel(float* dst, const floatX* src, size_t n) {
const size_t i = blockIdx.x * blockDim.x + threadIdx.x;
if (i < n) { dst[i] = (float)src[i]; }
}

// ----------------------------------------------------------------------------
// kernel launchers

@@ -1854,7 +1854,7 @@ typedef struct {
float accumulated_mean_loss; // Mean loss after aggregating it on all GPUs
floatX* cpu_losses; // CPU buffer to copy the losses to, allocated with cudaMallocHost
unsigned long long rng_state; // the RNG state for seeding stochastic rounding etc.
bool use_master_weights;
int use_master_weights;
} GPT2;

void gpt2_build_from_checkpoint(GPT2 *model, const char* checkpoint_path) {
@@ -1923,7 +1923,7 @@ void gpt2_build_from_checkpoint(GPT2 *model, const char* checkpoint_path) {
model->seq_len = 0;
model->mean_loss = -1.0f; // -1.0f will designate no loss
model->rng_state = 13371337;
model->use_master_weights = false;
model->use_master_weights = 1; // keep master weights copy in float for optim update?
}

void gpt2_forward(GPT2 *model, int* inputs, int* targets, size_t B, size_t T) {
@@ -2246,9 +2246,10 @@ void gpt2_update(GPT2 *model, float learning_rate, float beta1, float beta2, flo
cudaCheck(cudaMemset(model->v_memory, 0, model->num_parameters * sizeof(float)));
printf0("allocated %zu MiB for AdamW optimizer state m\n", (model->num_parameters * sizeof(float)) >> 20);
printf0("allocated %zu MiB for AdamW optimizer state v\n", (model->num_parameters * sizeof(float)) >> 20);
if(model->use_master_weights) {
// mixed precision, need to allocate one more buffer
if (model->use_master_weights == 1) {
// allocate one more buffer to keep the master copy of weights as float, and copy the weights over
cudaCheck(cudaMalloc((void**)&model->master_weights, model->num_parameters * sizeof(float)));
copy_kernel<<<CEIL_DIV(model->num_parameters, 512), 512>>>(model->master_weights, (floatX*)model->params_memory, model->num_parameters);
}
}
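
To put the new flag in perspective, the only cost of keeping master weights is this one extra float buffer, i.e. four bytes per parameter. A back-of-the-envelope sketch (the ~124M parameter count is an assumption about the default GPT-2 checkpoint; the real value comes from the file loaded in gpt2_build_from_checkpoint):

#include <stdio.h>
#include <stddef.h>

int main(void) {
    size_t num_parameters = 124000000;                      // assumed rough size of GPT-2 124M
    size_t master_bytes = num_parameters * sizeof(float);   // one extra float per parameter
    printf("master weights cost: ~%zu MiB\n", master_bytes >> 20);  // ~473 MiB on top of the low-precision weights
    return 0;
}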

@@ -2431,7 +2432,7 @@ void error_usage() {
fprintf(stderr, " -g <int> genT, how many steps of inference we do (default = 64)\n");
fprintf(stderr, " -a <int> overfit a single batch? 0/1. useful for debugging\n");
fprintf(stderr, " -f <int> enable_tf32 override (default: 1, set to 0 to disable tf32)\n");
fprintf(stderr, " -w <int> keep f32 copy of weights for the optimizer\n");
fprintf(stderr, " -w <int> keep f32 copy of weights for the optimizer? (default: 1)\n");
exit(EXIT_FAILURE);
}
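
Usage note (the build target and binary name here are assumptions from the usual llm.c workflow; they do not appear in this diff): after make train_gpt2cu, launching ./train_gpt2cu -w 0 skips the fp32 master copy to save device memory, while the default -w 1 keeps it for a more precise and stable AdamW update.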

@@ -2453,7 +2454,7 @@ int main(int argc, char *argv[]) {
int overfit_single_batch = 0; // useful for debugging, 1 = only load a single data batch once
int max_steps = -1;
int override_enable_tf32 = 1;
bool master_weights = false;
int use_master_weights = 1;
for (int i = 1; i < argc; i+=2) {
if (i + 1 >= argc) { error_usage(); } // must have arg after flag
if (argv[i][0] != '-') { error_usage(); } // must start with dash
@@ -2471,7 +2472,7 @@
else if (argv[i][1] == 'g') { genT = atoi(argv[i+1]); }
else if (argv[i][1] == 'a') { overfit_single_batch = atoi(argv[i+1]); }
else if (argv[i][1] == 'f') { override_enable_tf32 = atoi(argv[i+1]); }
else if (argv[i][1] == 'w') { master_weights = (bool)atoi(argv[i+1]); }
else if (argv[i][1] == 'w') { use_master_weights = atoi(argv[i+1]); }
else { error_usage(); }
}
printf0("+-----------------------+----------------------------------------------------+\n");
@@ -2488,7 +2489,7 @@
printf0("| sample_every | %-50d |\n", sample_every);
printf0("| genT | %-50d |\n", genT);
printf0("| overfit_single_batch | %-50d |\n", overfit_single_batch);
printf0("| master_weights | %-50b |\n", master_weights);
printf0("| use_master_weights | %-50s |\n", use_master_weights ? "enabled" : "disabled");
printf0("+-----------------------+----------------------------------------------------+\n");

// set up the device
@@ -2525,6 +2526,7 @@
// build the GPT-2 model from a checkpoint
GPT2 model;
gpt2_build_from_checkpoint(&model, load_filename);
model.use_master_weights = use_master_weights;
printf0("| load_filename | %-50s |\n", load_filename);
printf0("| max_sequence_length T | %-50d |\n", model.config.max_seq_len);
printf0("| vocab_size V | %-50d |\n", model.config.vocab_size);
