Saving a checkpoint when training with NVMe offloading? #2082

Closed
aciborowska opened this issue Jul 8, 2022 · 23 comments · Fixed by #4707
@aciborowska

I am testing NVMe offloading while training a model. When I try to save a checkpoint, I get the following error (full stack trace below):

NotImplementedError: ZeRO-3 does not yet support checkpointing with NVMe offloading, please disable for now.

Is that correct? Is checkpointing simply not supported with NVMe offloading, or am I missing something in my setup/config file? And if checkpointing is not supported, how can I save the model?

Full stack trace (from rank 0):

Traceback (most recent call last):
  File "/opt/conda/lib/python3.8/runpy.py", line 194, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/opt/conda/lib/python3.8/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/run/determined/pythonuserbase/lib/python3.8/site-packages/determined/exec/harness.py", line 132, in <module>
    sys.exit(main(args.train_entrypoint))
  File "/run/determined/pythonuserbase/lib/python3.8/site-packages/determined/exec/harness.py", line 123, in main
    controller.run()
  File "/run/determined/pythonuserbase/lib/python3.8/site-packages/determined/pytorch/deepspeed/_deepspeed_trial.py", line 296, in run
    self._run()
  File "/run/determined/pythonuserbase/lib/python3.8/site-packages/determined/pytorch/deepspeed/_deepspeed_trial.py", line 338, in _run
    self._save(path)
  File "/run/determined/pythonuserbase/lib/python3.8/site-packages/determined/pytorch/deepspeed/_deepspeed_trial.py", line 732, in _save
    self.trial.save(self.context, path)
  File "/run/determined/pythonuserbase/lib/python3.8/site-packages/determined/pytorch/deepspeed/_deepspeed_trial.py", line 934, in save
    m.save_checkpoint(path, tag=f"model{i}")
  File "/opt/conda/lib/python3.8/site-packages/deepspeed/runtime/engine.py", line 2842, in save_checkpoint
    self._save_zero_checkpoint(save_dir, tag)
  File "/opt/conda/lib/python3.8/site-packages/deepspeed/runtime/engine.py", line 3106, in _save_zero_checkpoint
    zero_sd = dict(optimizer_state_dict=self.optimizer.state_dict(),
  File "/opt/conda/lib/python3.8/site-packages/deepspeed/runtime/zero/stage3.py", line 2699, in state_dict
    raise NotImplementedError(
NotImplementedError: ZeRO-3 does not yet support checkpointing with NVMe offloading, please disable for now.

Config file:

{
  "train_batch_size": 256,
  "steps_per_print": 2000,
  "optimizer": {
    "type": "Adam",
    "params": {
      "lr": 0.001,
      "betas": [
        0.8,
        0.999
      ],
      "eps": 1e-8,
      "weight_decay": 3e-7
    }
  },
  "scheduler": {
    "type": "WarmupLR",
    "params": {
      "warmup_min_lr": 0,
      "warmup_max_lr": 0.001,
      "warmup_num_steps": 1000
    }
  },
  "gradient_clipping": 1.0,
  "prescale_gradients": false,
  "fp16": {
      "enabled": true,
      "fp16_master_weights_and_grads": false,
      "loss_scale": 0,
      "loss_scale_window": 500,
      "hysteresis": 2,
      "min_loss_scale": 1,
      "initial_scale_power": 15
  },
  "wall_clock_breakdown": false,
  "zero_optimization": {
      "stage": 3,
      "contiguous_gradients": true,
      "stage3_max_live_parameters": 1e9,
      "stage3_max_reuse_distance": 1e9,
      "stage3_prefetch_bucket_size": 1e7,
      "stage3_param_persistence_threshold": 1e5,
      "reduce_bucket_size": 1e7,
      "sub_group_size": 1e9,
      "offload_param": {
        "device": "nvme",
        "nvme_path": "nvme0n1",
        "pin_memory": false
      },
      "offload_optimizer": {
        "device": "nvme",
        "nvme_path": "nvme0n1",
        "pin_memory": false
    }
  }
}
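For anyone hitting the same error on a DeepSpeed version without NVMe checkpoint support, the workaround suggested by the error message is to disable NVMe offload for checkpointed runs, for example by switching both offload sections to CPU (assuming the host has enough RAM; "pin_memory": true is optional):

  "offload_param": {
    "device": "cpu",
    "pin_memory": true
  },
  "offload_optimizer": {
    "device": "cpu",
    "pin_memory": true
  }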
@tjruwase
Contributor

tjruwase commented Jul 8, 2022

@aciborowska, sorry for the inconvenience. Model checkpointing for training with NVMe offloading is not yet available.

@aciborowska
Author

Okay. In that case, after I complete the training, can I still save the model with, e.g., stage3_gather_16bit_weights_on_model_save?

Also, any plans to add checkpointing for NVMe in the near future?
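For reference, exporting only the final weights (no optimizer state) is typically still possible, because the NotImplementedError above is raised from the optimizer state_dict path. A minimal sketch, assuming the engine comes from deepspeed.initialize() and that "stage3_gather_16bit_weights_on_model_save": true is added under "zero_optimization" in the config:

import deepspeed
import torch

# Placeholder model purely for illustration; the real model and
# ds_config.json are whatever your training job already uses.
model = torch.nn.Linear(1024, 1024)
model_engine, optimizer, _, _ = deepspeed.initialize(
    model=model,
    model_parameters=model.parameters(),
    config="ds_config.json",
)

# ... training loop ...

# Gathers the ZeRO-3 parameter partitions and writes one consolidated
# fp16 state dict. This exports weights only (no optimizer state), so it
# is a final export rather than a resumable checkpoint.
model_engine.save_16bit_model("final_model", "pytorch_model.bin")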

@tjruwase
Contributor

tjruwase commented Jul 9, 2022

Yes, we do plan to add checkpointing for NVMe. In fact, you are the first user, to my knowledge, to request this. Can you please explain your training scenario a bit and why CPU offloading is insufficient?

@aciborowska
Author

I was mostly curious about NVMe offloading in terms of its performance trade-offs compared to CPU offloading and with different parameters. It is very surprising to me that NVMe offloading does not support checkpointing (unlike CPU offloading), and that this is not documented or even mentioned in the tutorials/blogs, since to me it feels like a real limitation.

Is there any reason why you decided not to implement checkpointing for NVMe? Is NVMe offloading mainly intended to be an inference-related feature?

One more question: when I was testing NVMe/CPU offloading (AWS, 1x NVIDIA T4 Tensor Core GPU with 125 GB NVMe), I noticed that NVMe offloading is about 3-4 times slower than CPU offloading. Is that generally to be expected? Can the gap get significantly larger or smaller?

@tjruwase
Contributor

Thanks for the clarification. We have not yet implemented checkpointing for NVMe due to a lack of bandwidth and interest. NVMe offloading is meant for training, fine-tuning, and inference, but we have yet to see much interest in training models at the scales that require it.

NVMe offloading performance depends on the NVMe device read/write speeds. Please see #998 and here for tips on benchmarking your system.
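For a rough sanity check of the drive itself, independent of DeepSpeed, the sketch below times buffered sequential writes and reads on the NVMe mount. It does not reproduce DeepSpeed's async, O_DIRECT access pattern, and the read number is inflated by the page cache, so treat the output as a ballpark rather than a substitute for the linked benchmarks; NVME_TEST_FILE is a placeholder path.

import os
import time

NVME_TEST_FILE = "/local_nvme/ds_io_test.bin"  # placeholder: a file on your NVMe mount
SIZE_GB = 4
CHUNK = 64 * 1024 * 1024  # 64 MiB per write
data = os.urandom(CHUNK)

# Sequential write, flushed to disk before stopping the clock.
start = time.perf_counter()
with open(NVME_TEST_FILE, "wb") as f:
    for _ in range(SIZE_GB * 1024**3 // CHUNK):
        f.write(data)
    f.flush()
    os.fsync(f.fileno())
write_s = time.perf_counter() - start

# Sequential read; largely served from the page cache right after a write.
start = time.perf_counter()
with open(NVME_TEST_FILE, "rb") as f:
    while f.read(CHUNK):
        pass
read_s = time.perf_counter() - start

os.remove(NVME_TEST_FILE)
print(f"write ~{SIZE_GB / write_s:.1f} GiB/s, read ~{SIZE_GB / read_s:.1f} GiB/s")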

@aciborowska
Author

Thanks!

@timohear

> Yes, we do plan to add checkpointing for NVMe. In fact, you are the first user, to my knowledge, to request this. Can you please explain your training scenario a bit and why CPU offloading is insufficient?

FYI, I've run into this problem as well using DeepSpeed 0.7.
The scenario is fine-tuning gpt-neox-20b on a 2x RTX A6000 machine with 128 GB of RAM. On this setup I've only been able to get DeepSpeed fine-tuning to work with both optimizer and parameter offloading to NVMe.

@zyfedward

I've run into this problem as well while fine-tuning BLOOM.

@tjruwase tjruwase reopened this Dec 22, 2022
@tjruwase
Contributor

Revisiting given the recent interest.

@tjruwase tjruwase added enhancement New feature or request training labels Dec 22, 2022
@tjruwase tjruwase assigned loadams and unassigned samadejacobs Feb 1, 2023
@loadams
Contributor

loadams commented Feb 6, 2023

Taking a look at this now.

@StevenArzt

We are also running fine-tuning on BLOOM and need NVMe offloading due to memory constraints (apparently two nodes with 2 TB of memory each aren't enough). I would really appreciate snapshot support on NVMe.

@loadams
Contributor

loadams commented Feb 17, 2023

Good to know, I'll post an update here shortly.

@loadams
Contributor

loadams commented Feb 22, 2023

@aciborowska - what model were you trying to train when you first hit this?

@StevenArzt thanks, starting work on this now.

@eisene
Contributor

eisene commented Mar 30, 2023

I would find this useful as well. My use case is that I'm working on a side project to make a machine translation system for a specific low-resource language. I'm experimenting with large-ish decoders, on the order of 3-7B params. I want to do this for as little money as possible, so I decided to use my home machine: an RTX 3080 Ti with 32 GB of RAM.

The training works, but only with NVMe offload. It takes about two weeks to fine-tune one of these models, but I'm fine with that.

I'm happy to help with either testing or implementation.

@Entropy-xcy

I believe supporting this feature is super important! I am training LoRA (from the PEFT library) using DeepSpeed. Everything else works like magic, except for checkpointing. The communication and gather overhead of NVMe devices becomes less of a problem when fine-tuning with LoRA, since the trainable weights are only a small fraction of the parameters.

I am happy to assist with testing, benchmarking, or implementation.

@dblakely

dblakely commented Aug 14, 2023

I'd also greatly appreciate this feature! 🙏

In the meantime, it would be nice if DeepSpeed raised a ValueError, or at least issued a warning at the start of training, when checkpointing won't work, as sketched below.
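A hypothetical user-side pre-flight check for now (assert_checkpointable is not a DeepSpeed API, just an illustration of the idea): fail fast if the config requests ZeRO-3 NVMe offload on a DeepSpeed version that cannot checkpoint it, instead of crashing at the first save_checkpoint() call.

import json

def assert_checkpointable(ds_config_path: str) -> None:
    """Raise early if this config cannot be checkpointed (hypothetical helper)."""
    with open(ds_config_path) as f:
        cfg = json.load(f)
    zero = cfg.get("zero_optimization", {})
    offload_devices = {
        zero.get("offload_param", {}).get("device"),
        zero.get("offload_optimizer", {}).get("device"),
    }
    if zero.get("stage") == 3 and "nvme" in offload_devices:
        raise ValueError(
            "ZeRO-3 with NVMe offload cannot save checkpoints on this "
            "DeepSpeed version; switch the offload device to 'cpu' or upgrade."
        )

assert_checkpointable("ds_config.json")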

@PaulScotti

Likewise adding my support here: I'd very much appreciate this feature :)

@loadams
Contributor

loadams commented Aug 17, 2023

I'll prioritize this work. Thanks @dblakely and @PaulScotti for your feedback.

@chongxiaoc

chongxiaoc commented Sep 22, 2023

+1. Would like this feature to be supported.

@gary-young

+1. We really need this feature because LLMs keep getting larger and larger...

@loadams
Contributor

loadams commented Jan 2, 2024

@gary-young and @chongxiaoc - work on this is continuing here; please see that for status and to help test it.

github-merge-queue bot pushed a commit that referenced this issue Jan 6, 2024
Previous PR #4416 had too many issues, closing that one and re-opening.
This PR includes a passing test.

This is a proposal for an implementation of checkpointing models when
training with ZeRO-3 with NVMe offload:

1. Currently, the names of the files used in the checkpoint are based on
the Python id of the parameter object, which is just the parameter's
address in memory. This is not stable across runs, which has two
disadvantages:
- The NVMe offloading files grow with every run of the model even if the
architecture did not change. This wastes disk space and, at least for
me, was a surprise when I first saw it. It is not related to
checkpointing.
- Without a way to match the file to the offloaded tensor we can't
reload the checkpoint.

We propose an alternative naming scheme. The parameters are named after
their ds_id instead of their Python id, and the tensors are named after
their state_name and (new) parameter id.

2. A model checkpoint now has to include all the offloaded tensor files.
During checkpoint save/load we copy all the tensor files to/from the
"offloaded_tensors" subdirectory of the checkpoint. We provide some
logging on the remaining space on the file system due to the potential
size of these files, especially as they accumulate in each checkpoint.
We do not copy the gradient files.

3. When loading the checkpoint, the optimizer already has prepared
buffers for swapping. We need to purge them so that they are replaced
with the freshly copied on-disk buffers from the checkpoint.

The key differences between this PR and the previous one:

- There's a test for a simple model with parameter/optimizer offload set
to cpu/cpu, cpu/nvme and nvme/nvme.
 - Gradient files are not copied.
 - FP16 and FP32 parameter buffers are handled correctly during load.

Fixes #2082.

---------

Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
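To make the naming change concrete, a tiny illustration; the file-name pattern below is invented for this sketch, and the PR defines the actual format.

from dataclasses import dataclass

@dataclass
class FakeParam:
    ds_id: int  # stable id that ZeRO-3 assigns to each parameter

p = FakeParam(ds_id=17)

# Old scheme: key the swap file by id(p), i.e. the object's memory
# address. This changes on every run, so offload files accumulate on disk
# and a checkpoint cannot be matched back to its tensors on reload.
unstable_name = f"param_{id(p)}_exp_avg.tensor.swp"

# New scheme (per the commit message): key by ds_id plus the optimizer
# state name, which is reproducible across runs and therefore reloadable.
stable_name = f"param_{p.ds_id}_exp_avg.tensor.swp"
print(unstable_name, stable_name)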
@0781532

0781532 commented Jan 27, 2024

> I would find this useful as well. My use case is that I'm working on a side project to make a machine translation system for a specific low-resource language. I'm experimenting with large-ish decoders, on the order of 3-7B params. I want to do this for as little money as possible, so I decided to use my home machine: an RTX 3080 Ti with 32 GB of RAM.
>
> The training works, but only with NVMe offload. It takes about two weeks to fine-tune one of these models, but I'm fine with that.
>
> I'm happy to help with either testing or implementation.

Hi @eisene,

I am trying to fine-tune a 3-7B LLM with ZeRO-3, completely offloading to NVMe, on a single RTX 3090 24 GB GPU + 2 TB SSD, but I always hit "kill subprocess" before the training process starts. Could you please share your experience and your ds_config.json with me?

@loadams
Contributor

loadams commented Jan 29, 2024

@0781532 - I'd recommend starting a new issue to share your error output and a simple repro case if possible.

mauryaavinash95 pushed a commit to mauryaavinash95/DeepSpeed that referenced this issue Feb 17, 2024