add a first, super simple logger
karpathy committed Apr 20, 2024
1 parent acdcfea commit 4e3b7ea
Showing 1 changed file with 44 additions and 2 deletions: train_gpt2.cu
@@ -1798,16 +1798,48 @@ void tokenizer_free(Tokenizer *tokenizer) {
     }
 }

+// ----------------------------------------------------------------------------
+// Logger lite, will probably grow/change some over time
+
+typedef struct {
+    FILE *logfile;
+    int flush_every; // every how many steps to flush the log
+} Logger;
+
+void logger_init(Logger *logger, const char *filename) {
+    logger->flush_every = 20;
+    logger->logfile = NULL;
+    if (filename != NULL) { logger->logfile = fopenCheck(filename, "w"); }
+}
+
+void logger_log_val(Logger *logger, int step, float val_loss) {
+    if (logger->logfile != NULL) {
+        fprintf(logger->logfile, "s:%d tel:%.4f\n", step, val_loss);
+    }
+}
+
+void logger_log_train(Logger *logger, int step, float train_loss) {
+    if (logger->logfile != NULL) {
+        fprintf(logger->logfile, "s:%d trl:%.4f\n", step, train_loss);
+        if (step % logger->flush_every == 0) { fflush(logger->logfile); }
+    }
+}
+
+void logger_free(Logger *logger) {
+    if (logger->logfile != NULL) { fclose(logger->logfile); }
+}
+
 // ----------------------------------------------------------------------------
 // CLI, poor man's argparse

 void error_usage() {
     // default run = debugging run with TinyShakespeare
-    // bigger run = train on TinyStories! e.g. val/sample less often, but sample more tokens
+    // bigger run = train on TinyStories! e.g. val/sample less often, but sample more tokens, write to logfile
     fprintf(stderr, "Usage: ./train_gpt2cu [options]\n");
-    fprintf(stderr, "Example: ./train_gpt2cu -i data/TinyStories -v 100 -s 100 -g 144\n");
+    fprintf(stderr, "Example: ./train_gpt2cu -i data/TinyStories -v 100 -s 100 -g 144 -o stories.log\n");
     fprintf(stderr, "Options:\n");
     fprintf(stderr, " -i <string> input dataset prefix (default = data/tiny_shakespeare)\n");
+    fprintf(stderr, " -o <string> output log file (default = NULL)\n");
     fprintf(stderr, " -b <int> batch size B (default = 4)\n");
     fprintf(stderr, " -t <int> sequence length T (default = 1024)\n");
     fprintf(stderr, " -l <float> learning rate (default = 1e-4f)\n");
@@ -1824,6 +1856,7 @@ int main(int argc, char *argv[]) {

     // read in the (optional) command line arguments
     const char* input_dataset_prefix = "data/tiny_shakespeare"; // or e.g. data/TinyStories
+    const char* output_log_file = NULL;
     int B = 4; // batch size
     int T = 1024; // sequence length max
     float learning_rate = 1e-4f;
@@ -1837,6 +1870,7 @@ int main(int argc, char *argv[]) {
         if (strlen(argv[i]) != 2) { error_usage(); } // must be -x (one dash, one letter)
         // read in the args
         if (argv[i][1] == 'i') { input_dataset_prefix = argv[i+1]; }
+        else if (argv[i][1] == 'o') { output_log_file = argv[i+1]; }
         else if (argv[i][1] == 'b') { B = atoi(argv[i+1]); }
         else if (argv[i][1] == 't') { T = atoi(argv[i+1]); }
         else if (argv[i][1] == 'l') { learning_rate = atof(argv[i+1]); }
@@ -1847,6 +1881,7 @@ int main(int argc, char *argv[]) {
         else { error_usage(); }
     }
     printf("input dataset prefix: %s\n", input_dataset_prefix);
+    printf("output log file: %s\n", output_log_file == NULL ? "NULL" : output_log_file);
     printf("batch size B: %d\n", B);
     printf("sequence length T: %d\n", T);
     printf("learning rate: %f\n", learning_rate);
@@ -1895,6 +1930,10 @@ int main(int argc, char *argv[]) {
     printf("train dataset num_batches: %d\n", train_loader.num_batches);
     printf("val dataset num_batches: %d\n", val_loader.num_batches);

+    // set up the logfile
+    Logger logger;
+    logger_init(&logger, output_log_file);
+
     // build the Tokenizer
     Tokenizer tokenizer;
     tokenizer_init(&tokenizer, "gpt2_tokenizer.bin");
@@ -1921,6 +1960,7 @@ int main(int argc, char *argv[]) {
             }
             val_loss /= val_num_batches;
             printf("val loss %f\n", val_loss);
+            logger_log_val(&logger, step, val_loss);
         }

         // once in a while do model inference to print generated text
@@ -1978,6 +2018,7 @@ int main(int argc, char *argv[]) {
         double time_elapsed_s = (end.tv_sec - start.tv_sec) + (end.tv_nsec - start.tv_nsec) / 1e9;
         total_sum_iteration_time_s += time_elapsed_s;
         printf("step %d/%d: train loss %f (%f ms)\n", step + 1, train_num_batches, model.mean_loss, time_elapsed_s * 1000);
+        logger_log_train(&logger, step, model.mean_loss);
     }
     // add a total average, for optimizations that are only mild improvements
     printf("total average iteration time: %f ms\n", total_sum_iteration_time_s / train_num_batches * 1000);
@@ -1992,6 +2033,7 @@ int main(int argc, char *argv[]) {
     cudaCheck(cudaFree(cublaslt_workspace));
     cublasCheck(cublasDestroy(cublas_handle));
     cublasCheck(cublasLtDestroy(cublaslt_handle));
+    logger_free(&logger);

     return 0;
 }
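The logger added above writes a plain-text file with one record per line: "s:<step> tel:<val_loss>" after each validation pass (logger_log_val) and "s:<step> trl:<train_loss>" after each training step (logger_log_train). Below is a minimal standalone reader for this format, as a sketch in plain C; the filename stories.log comes from the usage example in the diff, and the program itself is illustrative, not part of the commit:

// sketch: parse the logfile written by the Logger and echo the losses.
// assumes a file "stories.log" produced with -o stories.log; not part of llm.c.
#include <stdio.h>

int main(void) {
    FILE *f = fopen("stories.log", "r");
    if (f == NULL) { fprintf(stderr, "could not open stories.log\n"); return 1; }
    char line[256];
    int step;
    float loss;
    while (fgets(line, sizeof(line), f) != NULL) {
        // "tel" lines carry the validation loss, "trl" lines the training loss
        if (sscanf(line, "s:%d tel:%f", &step, &loss) == 2) {
            printf("step %d: val loss %.4f\n", step, loss);
        } else if (sscanf(line, "s:%d trl:%f", &step, &loss) == 2) {
            printf("step %d: train loss %.4f\n", step, loss);
        }
    }
    fclose(f);
    return 0;
}

Keeping the records as terse key:value pairs makes the file easy to grep, tail, or feed to a plotting script while a run is still in progress.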
