
Feature/streams #552

Merged 9 commits into master on Jun 8, 2024
Conversation

karpathy (Owner) commented Jun 5, 2024

Bringing back streams: this PR brings back a single "main stream" to start.

karpathy (Owner, Author) commented Jun 5, 2024

The main difference is that I pulled the main stream out to be a global inside train_gpt2.cu, because I think the stream is not a property of the model itself; it's a property of the trainer run configuration, so it makes more sense to me there.
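A minimal sketch of that layout, with assumed names (the real definitions live in train_gpt2.cu and go through the project's cudaCheck error macro):

```cuda
#include <cuda_runtime.h>

// The "main stream" as a file-scope global owned by the trainer,
// not stored inside the GPT2 model struct (sketch; names assumed).
cudaStream_t main_stream;

int main(void) {
    cudaStreamCreate(&main_stream);     // created once at trainer startup
    // ... training loop: kernels and copies are issued on main_stream ...
    cudaStreamSynchronize(main_stream); // drain outstanding work
    cudaStreamDestroy(main_stream);     // torn down at shutdown
    return 0;
}
```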

That said, we are still not 100% on the "main stream". My nsys (which I run as nsys profile ./train_gpt2cu ...) says that there are actually a few more streams, each with just a tiny amount of memory work: streams 16, 18, 20. One of these I traced to some memcopies during ncclCommInitRank. But I can't find any API to "set" the stream of NCCL...

train_gpt2.cu Outdated
@@ -562,7 +563,7 @@ void gpt2_write_to_checkpoint(GPT2 *model, const char* checkpoint_path) {
fwrite(model_header, sizeof(int), 256, model_file);
// write the parameters
void* params_memory_cpu = (void*)mallocCheck(model->num_parameters_bytes);
- cudaCheck(cudaMemcpy(params_memory_cpu, model->params_memory, model->num_parameters_bytes, cudaMemcpyDeviceToHost));
+ cudaCheck(cudaMemcpyAsync(params_memory_cpu, model->params_memory, model->num_parameters_bytes, cudaMemcpyDeviceToHost, main_stream));
Contributor:

missing sync

Contributor:
this might actually be fine, because the copy involves memory which is not page-locked, so it isn't actually going to run async, I believe, but at the very least this code looks like a time bomb

Owner (Author):
Is it better to use cudaMemcpy here without Async, or with Async and then synchronize? I don't have enough background in C / CUDA / multi-stream here. If it's not the Async function, does it use the default stream, and is it then synchronous?

The docs are not super complete:
https://docs.nvidia.com/cuda/cuda-runtime-api/api-sync-behavior.html#api-sync-behavior__memcpy-sync

Which stream is used for cudaMemcpy? And they don't specify what happens when you copy from device to normal, non-pinned memory, as far as I can see.

Not sure what the correct solution is here.

Contributor:
Yes, things are a bit vague. If you want to look at Nsight Systems and see that nothing is using the legacy stream, then memcpyAsync followed by synchronize is maybe the better solution. It also lets you search for "synchronize" to find the places where we have to wait.

A lot of these would also become at least partially async with 556, where we overlap device<->host and host<->disk transfers.
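The suggested pattern, an async copy followed by an explicit synchronize on the same stream, might look like this sketch (the helper name is illustrative, not from the PR):

```cuda
#include <cuda_runtime.h>

extern cudaStream_t main_stream;  // the trainer's global stream (assumed)

// Async copy + explicit wait: keeps the transfer off the legacy default
// stream and makes every wait point greppable via "Synchronize".
void copy_params_to_host(void* dst_cpu, const void* src_gpu, size_t nbytes) {
    cudaMemcpyAsync(dst_cpu, src_gpu, nbytes, cudaMemcpyDeviceToHost, main_stream);
    cudaStreamSynchronize(main_stream);  // dst_cpu is safe to read after this
}
```

With pageable (non-pinned) host memory the copy will not actually overlap with other work anyway, but the explicit synchronize makes the dependency visible in the code rather than relying on that behavior.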

train_gpt2.cu Outdated
@@ -628,11 +629,13 @@ void gpt2_build_from_checkpoint(GPT2 *model, const char* checkpoint_path) {
// read in all the parameters from file and copy them to device
void* params_memory_cpu = (void*)mallocCheck(model->num_parameters_bytes);
freadCheck(params_memory_cpu, 1, model->num_parameters_bytes, model_file);
- cudaCheck(cudaMemcpy(model->params_memory, params_memory_cpu, model->num_parameters_bytes, cudaMemcpyHostToDevice));
+ cudaCheck(cudaMemcpyAsync(model->params_memory, params_memory_cpu, model->num_parameters_bytes, cudaMemcpyHostToDevice, main_stream));
Contributor:

missing sync

train_gpt2.cu Outdated
@@ -718,13 +721,14 @@ void gpt2_build_from_random(GPT2 *model, int depth) {
}

// copy them to GPU
- cudaCheck(cudaMemcpy(model->params_memory, params_memory_cpu, model->num_parameters_bytes, cudaMemcpyHostToDevice));
+ cudaCheck(cudaMemcpyAsync(model->params_memory, params_memory_cpu, model->num_parameters_bytes, cudaMemcpyHostToDevice, main_stream));
Contributor:

missing sync

train_gpt2.cu Outdated
@@ -1331,9 +1348,9 @@ void save_state(const char* filename, int step, GPT2* model, DataLoader* loader)
// write AdamW m, v, and master_weights here (they are all float)
size_t shard_num_parameters = multi_gpu_config.shard_num_parameters;
float* cpu_buffer = (float*)mallocCheck(shard_num_parameters * sizeof(float));
- cudaCheck(cudaMemcpy(cpu_buffer, model->m_memory, shard_num_parameters * sizeof(float), cudaMemcpyDeviceToHost));
+ cudaCheck(cudaMemcpyAsync(cpu_buffer, model->m_memory, shard_num_parameters * sizeof(float), cudaMemcpyDeviceToHost, main_stream));
Contributor:

more syncs missing

ngc92 mentioned this pull request Jun 7, 2024
karpathy (Owner, Author) commented Jun 8, 2024

I decided that the Async calls look scary and we should minimize dependencies "across lines of code" (e.g. requiring a synchronize right after), so I reverted them. This way we can also easily search for "Async" to look for possible trouble. We'll have memory traffic on the default stream, but that's ok. Last thought: we should minimize use of parallelism outside of the "critical path" that makes the code fast. So anything we do a single time or rarely (e.g. load, store, checkpoint, etc.) remains sync; it just doesn't seem worth it.
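The reverted, synchronous shape for these rare transfers might look like this sketch (not the exact PR code; function name is illustrative):

```cuda
#include <stdio.h>
#include <stdlib.h>
#include <cuda_runtime.h>

// Plain cudaMemcpy for rare, off-critical-path work (checkpointing).
// With a pageable host buffer this call blocks until the copy completes,
// so there is no hidden dependency on a later synchronize.
void write_params_sketch(FILE* f, const void* d_params, size_t nbytes) {
    void* cpu = malloc(nbytes);  // ordinary pageable host memory
    cudaMemcpy(cpu, d_params, nbytes, cudaMemcpyDeviceToHost);  // blocks
    fwrite(cpu, 1, nbytes, f);
    free(cpu);
}
```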

karpathy merged commit 637c1b6 into master on Jun 8, 2024
9 checks passed
2 participants