
Utilities for cuda streams + disk IO #556

Merged: 10 commits into karpathy:master from streams-io, Jun 23, 2024
Conversation

@ngc92 (Contributor) commented Jun 5, 2024:

Handling disk IO for checkpointing with CUDA streams is a nontrivial task. If you're not careful, you can easily end up with broken code (you need to wait for the data to be on the CPU before you can start writing the buffer to disk) or with synchronous behaviour (because the memory is not page-locked, so async copies are not possible).
Therefore, this PR introduces two new utility functions that handle the disk <-> device data transfer.

In addition to being less error-prone and reducing code duplication, this also gives us a single point at which we can implement double buffering, so that device transfers actually overlap with disk writes. An added bonus is that we no longer need to allocate giant CPU-side arrays (think: 8 A100s on the same node, each wanting to write 40 GB of model state; we'd attempt to allocate 320 GB of host memory. Maybe the boxes are big enough even under that scenario, but do we really want to do that?)
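To illustrate the scheme, here is a minimal sketch of what a double-buffered file_to_device could look like (simplified, not the exact PR implementation; it reuses cudaCheck and freadCheck from the repo's utils):

void file_to_device(void* dest, FILE* file, size_t num_bytes,
                    size_t buffer_size, cudaStream_t stream) {
    // pinned (page-locked) host memory is required for truly async H2D copies
    char* buffer_space;
    cudaCheck(cudaMallocHost((void**)&buffer_space, 2 * buffer_size));
    char* buffers[2] = {buffer_space, buffer_space + buffer_size};

    char* gpu_write_ptr = (char*)dest;
    int which = 0;
    // prime the pipeline: read the first chunk from disk
    size_t chunk = std::min(buffer_size, num_bytes);
    freadCheck(buffers[which], 1, chunk, file);

    while (num_bytes > 0) {
        // kick off the async host->device copy of the chunk we just read
        cudaCheck(cudaMemcpyAsync(gpu_write_ptr, buffers[which], chunk,
                                  cudaMemcpyHostToDevice, stream));
        gpu_write_ptr += chunk;
        num_bytes -= chunk;
        which ^= 1;
        if (num_bytes > 0) {
            // overlap: read the next chunk from disk while the copy is in flight
            chunk = std::min(buffer_size, num_bytes);
            freadCheck(buffers[which], 1, chunk, file);
        }
        // wait before the next iteration reuses the other buffer
        cudaCheck(cudaStreamSynchronize(stream));
    }
    cudaCheck(cudaFreeHost(buffer_space));
}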

I've also added a unit test for these functions. It's very rough: I am not going to touch the Makefile, so for now you need to compile it yourself, and it leaves behind its temp file. But it has already paid off, because my first implementation had a wrong offset somewhere, and the test caught it :)

There is also the problem, first noticed in #522, that we currently miss master weights in the saved state. This is hacked in here very quickly, but it's not really a good solution; it should be combined at least with #522's addition of a flag in the file that indicates whether to expect master weights or not.

Because we now do file IO from CUDA code, cuda_common.h includes utils.h, which requires us to mark all the functions there as inline. I also had to add a checked write function, for which I've just copied the existing error handling; I'm not 100% sure that makes sense.

For the double-buffered transfers, I've used a buffer size of 32 MiB, but that is not based on any actual data; I just wanted to pick something that is neither super tiny nor super big :)

@ngc92 mentioned this pull request Jun 8, 2024
@ngc92 marked this pull request as ready for review June 17, 2024 12:52
@ngc92 changed the base branch from feature/streams to master June 17, 2024 22:34
train_gpt2.cu (outdated diff)
-cudaCheck(cudaMemcpy(model->params_memory, params_memory_cpu, model->num_parameters_bytes, cudaMemcpyHostToDevice));
-free(params_memory_cpu);
+file_to_device(model->params_memory, model_file, model->num_parameters_bytes,
+               32*1024*1024, main_stream);
@gordicaleksa (Contributor) commented Jun 18, 2024:

Any particular reason why the buffers are hardcoded to 32 MiB?

nit: maybe extract this into a global constant and pass it in everywhere; easier for maintenance if we want to change it later on
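The suggested refactor might look like this (the constant name is hypothetical):

// one named constant instead of a magic 32*1024*1024 at every call site
constexpr size_t IO_BUFFER_SIZE = 32 * 1024 * 1024;  // 32 MiB staging buffer

file_to_device(model->params_memory, model_file, model->num_parameters_bytes,
               IO_BUFFER_SIZE, main_stream);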

// prime the read buffer; first copy means we have to wait
char* gpu_read_ptr = (char*)src;
size_t copy_amount = std::min(buffer_size, num_bytes);
cudaCheck(cudaMemcpyAsync(read_buffer, gpu_read_ptr, copy_amount, cudaMemcpyDeviceToHost, stream));
Contributor commented:

curious: does async matter here given that we call synchronize immediately on the next line?

@ngc92 (Author) replied:

No, it's just for consistency: since the non-async versions of this function don't take a stream argument, if you want this transfer to show up in the "right" stream in Nsight, you need to use the async version.
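For example (with placeholder dst/src/n):

// synchronous copy: takes no stream argument, so the profiler attributes it
// to the default stream
cudaMemcpy(dst, src, n, cudaMemcpyDeviceToHost);

// async copy + immediate sync: blocks just the same, but the transfer is
// recorded on `stream`, so it shows up in the right lane in Nsight
cudaMemcpyAsync(dst, src, n, cudaMemcpyDeviceToHost, stream);
cudaStreamSynchronize(stream);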


// copy the last remaining write buffer to gpu
cudaCheck(cudaMemcpyAsync(gpu_write_ptr, write_buffer, write_buffer_size, cudaMemcpyHostToDevice, stream));
cudaCheck(cudaFreeHost(buffer_space));
Contributor commented:

Is cudaFreeHost blocked until the line above finishes?

@ngc92 (Author) replied:

Good question. Allocation is listed as an implicit synchronization point; I would assume deallocation also needs to be, but I'm not 100% sure, so maybe we should be explicit here.
https://docs.nvidia.com/cuda/cuda-c-programming-guide/#implicit-synchronization
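Being explicit would look something like this (a sketch reusing the names from the snippet above):

// copy the last remaining write buffer to gpu
cudaCheck(cudaMemcpyAsync(gpu_write_ptr, write_buffer, write_buffer_size,
                          cudaMemcpyHostToDevice, stream));
// explicitly wait for the in-flight copy instead of relying on cudaFreeHost
// to synchronize implicitly
cudaCheck(cudaStreamSynchronize(stream));
cudaCheck(cudaFreeHost(buffer_space));  // now safe to release the pinned buffer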

train_gpt2.cu (outdated diff)
@@ -1231,20 +1222,24 @@ void load_state(int* step, GPT2* model, DataLoader* loader, const char* filename
printf0("allocating %zu MiB for AdamW optimizer state v\n", (shard_num_parameters * sizeof(float)) >> 20);
cudaCheck(cudaMalloc((void**)&model->v_memory, shard_num_parameters * sizeof(float)));
}

if(state_header[4] == 1 && !model->use_master_weights) {
Contributor commented:

This will have to be refactored a bit due to recent changes.
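For context, the case analysis such a master-weights flag implies might look like this (a hypothetical sketch, given the refactoring noted above):

int file_has_master = state_header[4];  // 1 if the checkpoint stores master weights
if (file_has_master && !model->use_master_weights) {
    // checkpoint contains master weights we won't use: skip past them in the file
} else if (!file_has_master && model->use_master_weights) {
    // checkpoint lacks master weights: reconstruct them from the parameters instead
}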

@gordicaleksa (Contributor) commented:
Left some comments - lgtm!

@karpathy merged commit 2543b62 into karpathy:master on Jun 23, 2024
11 checks passed
@ngc92 deleted the streams-io branch July 11, 2024 12:42