Saving a checkpoint when training with NVMe offloading? #2082
I am testing NVMe offloading when training a model. When I try to save a checkpoint, I am getting (full stack trace below):

Is it correct that there is no checkpointing with NVMe offloading, or am I missing something in my setup/config file? If there is no checkpointing, how can I save the model?

Full stacktrace:

Config file:

Comments
@aciborowska, sorry for this inconvenience. Model checkpointing for training with NVMe offloading is not yet available.
Okay. In that case, after I complete the training, I can still save the model with e.g. ? Also, any plans to add checkpointing for NVMe in the near future?
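(For context, a final-model save with ZeRO-3 usually looks something like the sketch below; `model` and `ds_config` are assumed to come from the existing training setup, the output path is illustrative, and whether this works alongside NVMe offload is exactly the open question here.)

```python
import deepspeed

# `model` and `ds_config` are assumed to exist from the training setup.
model_engine, _, _, _ = deepspeed.initialize(model=model, config=ds_config)

# ... training loop ...

# Consolidate the ZeRO-3 partitioned parameters and write a single
# fp16 state dict; the directory and filename are illustrative.
model_engine.save_16bit_model("./final_model", "pytorch_model.bin")
```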
Yes, we do plan to add checkpointing for NVMe. In fact, to my knowledge you are the first user to request this. Can you please explain a bit about your training scenario and why CPU offloading is insufficient?
I was mostly curious about using NVMe in terms of performance trade-offs compared to CPU and with different parameters. It is very surprising to me that NVMe does not support checkpointing (unlike CPU offloading), and this fact is not documented or even mentioned in the tutorials/blogs, since (to me) it feels like a limitation. Is there any reason why you decided not to implement checkpointing for NVMe? Is NVMe offloading mainly intended to be an inference-related feature? One more question: when I was testing NVMe/CPU offloading (AWS, 1 x NVIDIA T4 Tensor Core GPU with 125 GB NVMe), I noticed that offloading with NVMe is about 3-4 times slower than CPU offloading. Is that something that can generally be expected? Can it get significantly larger/smaller?
Thanks for the clarification. We have not yet implemented checkpointing for NVMe due to lack of bandwidth and interest. NVMe offloading is meant for training, finetuning, and inference, but we are yet to see much interest in training models at the scales that require it. NVMe offloading performance depends on the NVMe device read/write speeds. Please see #998 and here for tips on benchmarking your system.
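(For a quick first read on raw device throughput, separate from the DeepSpeed AIO benchmarks linked above, a rough sequential read/write probe can be sketched as follows; the file path and sizes are arbitrary choices:)

```python
import os
import time

path = "/local_nvme/io_probe.bin"  # assumed NVMe mount point
size = 1 << 30                     # 1 GiB test file
chunk = os.urandom(1 << 20)        # write in 1 MiB chunks

# Sequential write, flushed to disk so the number reflects the device.
t0 = time.perf_counter()
with open(path, "wb") as f:
    for _ in range(size // len(chunk)):
        f.write(chunk)
    f.flush()
    os.fsync(f.fileno())
write_gbps = size / (time.perf_counter() - t0) / 1e9

# Sequential read; note the OS page cache may inflate this number.
t0 = time.perf_counter()
with open(path, "rb") as f:
    while f.read(1 << 20):
        pass
read_gbps = size / (time.perf_counter() - t0) / 1e9

os.remove(path)
print(f"write: {write_gbps:.2f} GB/s, read: {read_gbps:.2f} GB/s")
```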
Thanks!
FYI, I've run into this problem as well using DeepSpeed 0.7.
I've run into this problem as well with fine-tuning BLOOM.
Revisiting given the recent interest.
Taking a look at this now.
We are also running fine-tuning on BLOOM and need NVMe offloading due to memory constraints (apparently 2 nodes with 2 TB of memory each isn't enough). I would really appreciate snapshot support on NVMe.
Good to know, I'll post an update here shortly.
@aciborowska - what model were you trying to train when you first hit this? @StevenArzt thanks, starting work on this now.
I would find this useful as well. My use case is that I'm working on a side project to make a machine translation system for a specific low-resource language. I'm experimenting with large-ish decoders, on the order of 3-7B params. I want to do this for as little money as possible, so I decided to use my home machine - an RTX 3080 Ti with 32 GB RAM. The training works, but only with NVMe offload. It takes about two weeks to fine-tune one of these models, but I'm fine with that. I'm happy to help with either testing or implementation.
I believe supporting this feature is super important! I am training LoRA (from the PEFT library) using DeepSpeed. Everything else works like magic, except for the checkpoint. The communication and gather overhead of NVMe devices becomes less of a problem when fine-tuning with LoRA, as it represents only a small fraction of the parameters. I am happy to assist with testing, benchmarking, or implementation.
I'd also greatly appreciate this feature! 🙏 In the meantime, I feel like it would be nice to have DeepSpeed raise a value error or at least give a warning at the start of training that checkpointing won't work.
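(Until such a check exists upstream, a user-side guard is easy to add at the top of a training script; the config keys below follow the standard ZeRO offload schema, while the function itself is a hypothetical sketch:)

```python
import warnings

def warn_if_nvme_offload(ds_config: dict) -> None:
    """Warn before training if NVMe offload is configured, since
    saving checkpoints may not be supported in this setup."""
    zero = ds_config.get("zero_optimization", {})
    for key in ("offload_param", "offload_optimizer"):
        if zero.get(key, {}).get("device") == "nvme":
            warnings.warn(
                f"zero_optimization.{key} uses NVMe offload; "
                "checkpoint saving may fail or be unsupported."
            )

# Usage: call right after building the DeepSpeed config.
# warn_if_nvme_offload(ds_config)
```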
Likewise adding my support here; I'd greatly appreciate this feature :)
I'll prioritize this work, thanks @dblakely and @PaulScotti for your feedback.
+1. Would like this feature to be supported.
+1. We really need this feature because LLMs are getting larger and larger...
@gary-young and @chongxiaoc - work is continuing on this here; please see it for status and to test the work.
Previous PR #4416 had too many issues, so I'm closing that one and re-opening. This PR includes a passing test.

This is a proposal for an implementation of checkpointing models when training with ZeRO-3 with NVMe offload:

1. Currently, the names of the files used in the checkpoint are based on the Python id of the parameter object, which is just the parameter's address in memory. This is not stable across runs, which has two disadvantages:
   - The NVMe offloading files grow with every run of the model even if the architecture did not change. This wastes disk space and, at least for me, was a surprise when I first saw it. It is not related to checkpointing.
   - Without a way to match a file to its offloaded tensor, we cannot reload the checkpoint.

   We propose an alternative naming scheme: parameters are named after their ds_id instead of their Python id, and tensors are named after their state_name and (new) parameter id.

2. A model checkpoint now has to include all the offloaded tensor files. During checkpoint save/load we copy all the tensor files to/from the "offloaded_tensors" subdirectory of the checkpoint. Because these files can be large, especially as they accumulate across checkpoints, we log the remaining space on the file system. We do not copy the gradient files. (A sketch of this copy step follows below.)

3. When loading the checkpoint, the optimizer has already prepared buffers for swapping. We need to purge them so that they are replaced with the freshly copied on-disk buffers from the checkpoint.

The key differences between this PR and the previous one:
- There's a test for a simple model with parameter/optimizer offload set to cpu/cpu, cpu/nvme, and nvme/nvme.
- Gradient files are not copied.
- FP16 and FP32 parameter buffers are handled correctly during load.

Fixes #2082.

Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
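(For illustration, the save-side copy described in point 2 might look roughly like the following sketch. The `offloaded_tensors` directory name comes from the PR description; the flat swap-directory layout, the `gradient` substring check, and the function name are assumptions:)

```python
import os
import shutil
import logging

logger = logging.getLogger(__name__)

def copy_offloaded_tensors(swap_dir: str, checkpoint_dir: str) -> None:
    """Copy NVMe-offloaded tensor files into the checkpoint's
    'offloaded_tensors' subdirectory, skipping gradient files and
    logging remaining disk space, as the PR description outlines."""
    dest = os.path.join(checkpoint_dir, "offloaded_tensors")
    os.makedirs(dest, exist_ok=True)
    for name in os.listdir(swap_dir):
        if "gradient" in name:  # hypothetical gradient-file naming marker
            continue
        shutil.copy2(os.path.join(swap_dir, name), dest)
    free_gb = shutil.disk_usage(dest).free / 1e9
    logger.info(f"Offloaded tensors copied; {free_gb:.1f} GB free on target")
```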
Hi essene, I am trying to fine-tune a 3-7B LLM using ZeRO-3, completely offloading to NVMe, on a single 24 GB RTX 3090 GPU + 2 TB SSD, but I always hit a "kill subprocess" error before the training process starts. Could you please share your experience and your ds_config.json with me?
@0781532 - I'd recommend starting a new issue to share your error code and a simple repro case if possible.
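(For readers landing here with a similar setup, a minimal ZeRO-3 NVMe offload config, written as a Python dict passed to `deepspeed.initialize`, might look like the sketch below. Every value is a placeholder to tune for your own machine, not a known-good recipe for a 3090:)

```python
# Illustrative ZeRO-3 config with parameter and optimizer offload to NVMe.
# The nvme_path must point at a real mount on a fast local SSD.
ds_config = {
    "train_micro_batch_size_per_gpu": 1,
    "fp16": {"enabled": True},
    "zero_optimization": {
        "stage": 3,
        "offload_param": {
            "device": "nvme",
            "nvme_path": "/local_nvme",  # assumed mount point
            "pin_memory": True,
        },
        "offload_optimizer": {
            "device": "nvme",
            "nvme_path": "/local_nvme",
        },
    },
}
```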