
[DeepSpeed] ZeRO-Infinity integration: getting started and issues #11464

Closed
stas00 opened this issue Apr 27, 2021 · 12 comments


stas00 commented Apr 27, 2021

DeepSpeed ZeRO-Infinity HF Integration is now available in the master branch of transformers. Here is a quick getting started/what's new post.

ZeRO-Infinity extends ZeRO-3 by complementing CPU Offload with NVMe Offload, enabling the training of even bigger models, and it adds various other optimizations and improvements.

Getting started

Install the latest deepspeed version:

pip install git+https://github.com/microsoft/DeepSpeed
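
To verify the install and see which DeepSpeed ops/extensions are usable on your setup, you can run the environment report utility that ships with deepspeed:

ds_report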

You will also want to be on the latest transformers master branch. To run a quick test:


git clone https://github.com/huggingface/transformers
cd transformers
BS=4; PYTHONPATH=src USE_TF=0 deepspeed examples/pytorch/translation/run_translation.py \
--model_name_or_path t5-small --output_dir /tmp/zero3 --overwrite_output_dir --max_train_samples 64 \
--max_eval_samples 64 --max_source_length 128 --max_target_length 128 --val_max_target_length 128 \
--do_train --num_train_epochs 1 --per_device_train_batch_size $BS --per_device_eval_batch_size $BS \
--learning_rate 3e-3 --warmup_steps 500 --predict_with_generate --logging_steps 0 --save_steps 0 \
--eval_steps 1 --group_by_length   --dataset_name wmt16 --dataset_config ro-en --source_lang en \
--target_lang ro --source_prefix "translate English to Romanian: " \
--deepspeed tests/deepspeed/ds_config_zero3.json

You will find very detailed documentation here: https://huggingface.co/transformers/master/main_classes/trainer.html#deepspeed

Your new config file will look like this (for ZeRO-3 as an example):

{
    "fp16": {
        "enabled": "auto",
        "loss_scale": 0,
        "loss_scale_window": 1000,
        "initial_scale_power": 16,
        "hysteresis": 2,
        "min_loss_scale": 1
    },

    "optimizer": {
        "type": "AdamW",
        "params": {
            "lr": "auto",
            "betas": "auto",
            "eps": "auto",
            "weight_decay": "auto"
        }
    },

    "scheduler": {
        "type": "WarmupLR",
        "params": {
            "warmup_min_lr": "auto",
            "warmup_max_lr": "auto",
            "warmup_num_steps": "auto"
        }
    },

    "zero_optimization": {
        "stage": 3,
        "offload_optimizer": {
            "device": "cpu",
            "pin_memory": true
        },
        "offload_param": {
            "device": "cpu",
            "pin_memory": true
        },
        "overlap_comm": true,
        "contiguous_gradients": true,
        "sub_group_size": 1e14,
        "reduce_bucket_size": "auto",
        "stage3_prefetch_bucket_size": "auto",
        "stage3_param_persistence_threshold": "auto",
        "stage3_max_live_parameters": 1e9,
        "stage3_max_reuse_distance": 1e9,
        "stage3_gather_fp16_weights_on_model_save": true
    },

    "gradient_accumulation_steps": "auto",
    "gradient_clipping": "auto",
    "steps_per_print": 2000,
    "train_batch_size": "auto",
    "train_micro_batch_size_per_gpu": "auto",
    "wall_clock_breakdown": false
}

If you want to experiment with NVMe offload, please see: https://huggingface.co/transformers/master/main_classes/trainer.html#nvme-support
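
As a rough, untuned sketch of what that involves: the two offload sections inside zero_optimization in the config above switch their device from cpu to nvme and gain a path to a fast local drive, and a top-level aio section is added for the async I/O settings. The /local_nvme path and the numbers below are placeholders you will need to adapt to your hardware:

    "offload_optimizer": {
        "device": "nvme",
        "nvme_path": "/local_nvme",
        "pin_memory": true
    },
    "offload_param": {
        "device": "nvme",
        "nvme_path": "/local_nvme",
        "pin_memory": true
    },

    "aio": {
        "block_size": 262144,
        "queue_depth": 32,
        "thread_count": 1,
        "single_submit": false,
        "overlap_events": true
    }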

DeepSpeed currently runs only in fp16 mixed precision

While the DeepSpeed devs are working on an fp32 mode, at the moment only fp16/AMP-like train/eval is available. So if your model struggles under fp16/AMP, it will have the same struggles under DeepSpeed.

Moreover, because DeepSpeed does model.half(), forcing all weights to fp16, some models might not be ready for this (under AMP, things are switched dynamically to fp16 only where needed). If you run into this, please post a new issue and we will try to find a solution/workaround for those special cases.

You must use the latest transformers master

If you get DeepSpeed errors like it not knowing what the auto value is, you aren't on the latest transformers master branch: git pull if you already have a clone, and if you have it installed, update your install.
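
For example, assuming you installed from a local clone of the source:

cd transformers
git pull
pip install -e .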

For those who already use the DeepSpeed HF integration

As the integration is evolving, it has gone through a major revamp and various improvements.

There are two important changes you need to be aware of if you're already using the DeepSpeed integration in transformers:

  1. After this release, only config params that are set to auto will get automatically overridden/set to the correct/recommended values; everything else is left as is. This avoids the previously confusing behavior of never being quite sure what got overridden and what didn't, despite the logger reporting what it did override. The new behavior is completely unambiguous.

    See examples

    Full doc: https://huggingface.co/transformers/master/main_classes/trainer.html#shared-configuration

  2. If you are using massive models and aren't using the example scripts, make sure to read the following (a rough sketch of the idea is shown after this list):

    Full doc: https://huggingface.co/transformers/master/main_classes/trainer.html#constructing-massive-models
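
To illustrate what that doc is about, here is a rough sketch (not a complete script) of the key ordering requirement when constructing a massive model yourself with the Trainer: the TrainingArguments object carrying the DeepSpeed config must exist before the model is instantiated, so that from_pretrained can detect that ZeRO-3 is enabled and avoid materializing the full model on every rank. The model name and paths are just examples:

from transformers import AutoModelForSeq2SeqLM, Trainer, TrainingArguments

# 1. Create the TrainingArguments first - this is what tells the integration
#    that a ZeRO-3 DeepSpeed config is in play.
training_args = TrainingArguments(
    output_dir="/tmp/zero3",
    deepspeed="tests/deepspeed/ds_config_zero3.json",  # the config shown earlier
)

# 2. Only now load the model - from_pretrained can then instantiate it in a
#    ZeRO-3-aware (partitioned) fashion instead of a full copy per rank.
model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")

# 3. Pass both to the Trainer as usual (datasets, tokenizer, etc. omitted here).
trainer = Trainer(model=model, args=training_args)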

Everything else should work as before or better.

The docs were revamped a lot too - if you find anything unclear or lacking please let me know.

If you encounter any problems please post an Issue and tag @stas00 to it.

Thank you!


jncasey commented Apr 27, 2021

Hi @stas00, is it normal for zero3 training to take a while to get started?

I haven't put any time into investigating yet, but I updated transformers and deepspeed to the latest masters just to see if I could get them working. My simple training script (derived from the summarization example) works fine with deepspeed and the default zero2 config, but when I run the same script with the default zero3 config, training begins but hangs with the progress bar at step 0. I let it run for about half an hour before I killed the process. The quick zero3 test in your post above seems to run fine, however.

Is there some initial zero3 overhead I just need to be more patient with, or do I possibly have some deeper problem?


stas00 commented Apr 27, 2021

Something is wrong then. DeepSpeed takes a bit longer to start than normal as it pre-allocates some memory, and even more so the first time if it needs to compile some CUDA extensions, but once started it should run at normal speed.

Hanging on zero3 could indicate that you're on multi-GPU and running some code that blocks while trying to sync with the other GPUs. Anything involving forward calls must be performed on all GPUs participating in the process. If one of them skips the call, all the other GPUs will block waiting for that GPU.

For example, code gated by if trainer.is_world_process_zero() could block, depending on what it does. Saving checkpoints, for instance, has to happen on all processes and not just rank 0.
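
To make that concrete, here is a contrived sketch (the save call is just one example of an operation that involves all ranks under ZeRO-3):

# trainer: an instance of transformers.Trainer; output_dir: some path

# Deadlock-prone: under ZeRO-3 saving gathers the sharded weights, which is a
# collective operation, so rank 0 waits for the other ranks that never join in.
if trainer.is_world_process_zero():
    trainer.save_model(output_dir)

# Safe: every rank enters the call; the Trainer itself decides which rank(s)
# actually write to disk.
trainer.save_model(output_dir)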

Could you please open a separate issue and help me reproduce the problem, and then we can look at it together.

To help diagnose, you can add this anywhere to your code:

import faulthandler
faulthandler.dump_traceback_later(20, repeat=True)

and it'll dump a backtrace for all threads every 20 seconds, so you will be able to see where it's hanging.

@thies1006

Hello! I was trying out the command pasted above, but replacing the zero_optimization part from tests/deepspeed/ds_config_zero3.json with the configuration from the NVMe offload example (see link above). The error I get is:
AssertionError: num_elems 7563520 > buffer 7563328
I got this error before as well with the Megatron example from DeepSpeed, and was able to solve it by increasing the aio block_size; however, this time it did not work out. I should add that I used an SSD disk, in case that's important.


stas00 commented Apr 30, 2021

Thank you for trying this new feature.

This looks like a potential bug in Deepspeed. I asked @tjruwase to have a look.

Maybe it's worthwhile to file an issue at https://github.com/microsoft/DeepSpeed/issues if you have a few minutes, as this is definitely not an integration issue.

If you do, please paste the full config you were using.

Thank you, @thies1006


tjruwase commented Apr 30, 2021

@thies1006, thanks for reporting this issue. As @stas00 suggested, could you please report this as a DeepSpeed issue? It would be great if you included the exact ds_config.json in the issue report. Thanks so much!


thies1006 commented May 3, 2021

Just now this issue appeared, which I guess is exactly the same case. Sorry for not posting the exact config right away. Thank you very much!

Edit: Lowering "sub_group_size" from 1e14 to 1e3 solved the issue (however, another one came up; I filed a separate issue at DeepSpeed).


tjruwase commented May 3, 2021

@thies1006, there is now a PR for the assert.

@github-actions

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

github-actions bot closed this as completed Jun 5, 2021
@sinamoeini

@stas00 I am not sure if this is the right forum to ask, so feel free to direct me somewhere else.
Is there a standard way of cloning a partitioned parameter? The examples I have seen usually use gather to reconstruct it into a PyTorch parameter and then clone it.


stas00 commented Jul 25, 2023

Indeed, but you have to do it before you call deepspeed.initialize. If you do it after, DeepSpeed won't know about those new parameters and all kinds of undefined behaviors/breakages will occur.

You can still add/remove params after the zero.Init context was run (if it's used), but the model needs to be complete, with all params in place, before it's passed to deepspeed.initialize.
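
In other words, the ordering looks roughly like this (a sketch; build_model, hidden_size and ds_config are placeholders):

import torch
import deepspeed

model = build_model()                               # placeholder constructor
model.extra_head = torch.nn.Linear(hidden_size, 2)  # fine: not handed over yet

# From here on the engine owns the (partitioned) parameters - adding or
# removing params after this call leads to undefined behavior.
model_engine, optimizer, _, _ = deepspeed.initialize(
    model=model,
    model_parameters=model.parameters(),
    config=ds_config,
)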


sinamoeini commented Jul 25, 2023

@stas00 Thank you for your prompt response. So, before deepspeed.initialize, would this be a correct way of cloning a ds_module?

import copy
import deepspeed

# ds_module is already partitioned
with deepspeed.zero.GatheredParameters(list(ds_module.parameters())):
    new_module = copy.deepcopy(ds_module)

# at this point new_module is a pytorch parameter
# to convert it to a ds module
new_module = deepspeed.zero.Init(new_module)


stas00 commented Jul 25, 2023

I don't think this example can work, since DeepSpeed installs special attributes into the tensor which would get copied and point to the wrong place. You'd have to create a normal torch param and copy the data over from the other param, but perhaps you can simply ask DeepSpeed to add a new util that will do the right thing for you.

But let's stop this discussion here, as it's off-topic for this thread and not really related to transformers. I propose you start a new issue at https://github.com/microsoft/DeepSpeed and discuss it there, where the DeepSpeed team will be able to address your needs better.
