
Model Parallelism and Big Models #8771

Open
alexorona opened this issue Nov 24, 2020 · 68 comments
Assignees
Labels
Model Parallel Model Parallelism Implementations WIP Label your PR/Issue with WIP for some long outstanding Issues/PRs that are work in progress

Comments

@alexorona
Contributor

🚀 Feature request

This is a discussion issue for training/fine-tuning very large transformer models. Recently, model parallelism was added for gpt2 and t5. The current implementation is for PyTorch only and requires manually modifying the model classes for each model. Possible routes (thanks to @stas00 for identifying these):

  • fairscale, to avoid having to implement model parallelism individually for each model
  • deepspeed, to possibly enable training even larger models
@stas00
Contributor

stas00 commented Nov 24, 2020

Thank you, @alexorona!

I'm still in the process of gathering info/reading up and doing some small experimentation, so will post my thoughts once I have something concrete to share.

Here are some resources if someone wants to join in:

Abbreviations:

  • MP = Model Parallelism
  • DP = Data Parallelism
  • PP = Pipeline Parallelism

Resources:

@stas00
Contributor

stas00 commented Dec 30, 2020

Update: so we have

I don't have proper benchmarks yet, but I can definitely see 3-5x less GPU RAM usage! So these would be the first go-to solutions when a model doesn't fit onto a single GPU.

@stas00
Contributor

stas00 commented Dec 30, 2020

OK, so studying @alexorona's t5 MP implementation I think we have a few issues related to how we spread out the models across different devices.

For the purpose of this discussion let's use a simplistic approach of having just 2 GPUs (g1 and g2)

@alexorona's current approach is to assume that the encoder and decoder are the same size and then split half of the encoder layers onto g1 and the other half onto g2, and repeat the same for the decoder.

This approach has 3 issues:

  1. it doesn't work if the encoder and decoder aren't of the same size, which is the case for many models.

  2. it introduces unnecessary copying of data from g1 to g2 in the middle of the encoder and then again in the middle of the decoder, rather than doing just one copy between the end of the encoder and the beginning of the decoder: 3 copies vs. 1 (in our simplistic 2-GPU example).

  3. it leaves all other layers out of the device map and assigns them to the first or the last device in a hardcoded way depending on where they fit better, so the user has no control over where these go.

It does make the implementation relatively simple, since we just need to move half the layers of the encoder to g1 and the other half to g2 and bring the inputs/outputs to the right devices.

  • Issue 1 can be fixed by providing 2 device maps - one for the encoder and a different one for the decoder. They would be the same if len(encoder) == len(decoder), i.e. we are still using @alexorona's split-encoder and split-decoder approach.

  • Issue 2 can be solved again by 2 separate device maps, but the first one maps the encoder and the second the decoder. So there will be no splitting of the layers of the encoder or decoder between separate devices. I think I may try to use this solution for Bart.

encoder_device_map => {0 => [1..6]}
decoder_device_map => {1 => [1..6]}

(note: I'm using a non-python notation of a range here)

It will be trickier to allow overlap if the number of layers is different between the encoder and decoder - say 6:9 or 6:12 - in which case it might be:

encoder_device_map => {0 => [1..6]}               # 6-layer encoder
decoder_device_map => {0 => [1..2], 1 => [3..9]}  # 9-layer decoder
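
For illustration, the same two maps could be written as plain Python dicts (a hypothetical rendering only - device id as the key, 0-indexed layer indices as the values):

encoder_device_map = {0: [0, 1, 2, 3, 4, 5]}                # 6-layer encoder, all on gpu 0
decoder_device_map = {0: [0, 1], 1: [2, 3, 4, 5, 6, 7, 8]}  # 9-layer decoder, split 2 + 7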

So the model will need to be able to transparently handle switching layers and inputs/outputs not only through its encoder/decoder layers but also from encoder to decoder - but it's quite doable.

This uneven situation would also be the case on some weird setups like mine where the gpus are of different sizes. On my setup I have one card of 8GB and another 24GB. This won't be an issue with @alexorona's current implementation.

If any of you have had a chance to think about possible solutions, or some totally different ways of approaching this, please share your insights.

@alexorona
Contributor Author

alexorona commented Dec 30, 2020

I was so full of hope that a simple dictionary could serve as a device_map for everything, but now you have shattered my blissful ignorance @stas00. But thanks so much for pointing this out! Super important! The characterization is not quite right and I think it's because you're using 2 GPUs, but the problem you identified is real. Basically both the decoder and encoder use the same map, so the first attention block of the decoder is located on the same device as the first attention block of the encoder. The performance degradation is trivial because the hand-off between GPUs when you have 8 or fewer is pretty efficient (when you have more, there are problems you have to work around by changing the NCCL environment variables).

I thought about trying to do what you've suggested, but it meant that the device_map would have to get more complicated, which I was trying to avoid. However, if some of the decoder architectures have a different number of layers in the decoder than in the encoder, the generalizability of the implementation will just collapse. Oh well. It was nice while it lasted.

It looks like you've been really busy the last week. Responding to your comments and PRs...

@stas00
Contributor

stas00 commented Dec 30, 2020

Thank you for your follow up, @alexorona.

Since you're saying that from your experience the copying overhead is negligible, your current solution would work perfectly fine in some situations, like the balanced t5, but will need to be altered in others. So very likely it's this and that, rather than not this but that - i.e. no shattered hopes.

And if this doesn't fit in other situations it can be extended with a separate device_map for the encoder and decoder. Perhaps for some models it'd be most efficient to keep the encoder on one set of devices and the decoder on another, and for other models to share devices. So that means we need to come up with a way of accepting a variety of different device maps.

Perhaps we make the device_map have two parts, with the second part (decoder) being optional - if it's not passed, the first one is used for both? Then the simple solution remains mostly unchanged.

May I ask whether you modeled your current implementation after some existing one, and whether you have a list of various MP implementations we could study to find the most suitable approach? So far I have only studied the way you approached it.

Thank you.

p.s. here are some examples of models with different encoder/decoder sizes:

@stas00
Contributor

stas00 commented Dec 30, 2020

I have a few follow-up questions, @alexorona:

  1. on the use of torch.cuda.empty_cache() - I guess as long as it remains in deparallelize it is not really going to interfere with whatever normal caching is going on. I don't think it will do what you intended it to do with an explicit gc.collect(), as I explained in [test_model_parallelization] multiple fixes #9354

  2. when do you think it's better to use this split as you implemented it (again simplifying to 2 GPUs, 6 layers in the encoder and the same in the decoder):

        encoder   decoder
gpu0    1 2 3     1 2 3
gpu1    4 5 6     4 5 6

vs. giving a whole GPU to each of them:

        encoder        decoder
gpu0    1 2 3 4 5 6
gpu1                   1 2 3 4 5 6

Thank you!

@g-karthik

@alexorona I had a chance to briefly look at your approach to model-parallelism via explicit device map construction. What are your thoughts on extending this approach via the construction of a generic Megatron-style mpu object that implements basic methods such as get_{model,data}_parallel_{rank,group,world_size}()? My understanding is that DeepSpeed works with any model-parallelism approach that implements these methods (the mpu object needs to be passed to deepspeed.initialize()); it doesn't necessarily have to be a tensor-slicing approach like Megatron.

Would it make sense to extend/tweak the device map approach to model-parallelism to fit within the mpu setup, as opposed to trying to get deepspeed's memory optimization primitives to work with the MP implementation without leveraging mpu?
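
To make the mpu idea above concrete, here is a rough, hypothetical sketch of the minimal interface such an object would need to expose (assuming torch.distributed is already initialized and no real tensor slicing, i.e. each rank forms a model-parallel group of size 1):

import torch.distributed as dist

class TrivialMPU:
    """Hypothetical minimal mpu: no tensor slicing, every rank is its own MP group."""

    def __init__(self):
        world_size = dist.get_world_size()
        # all ranks together form the single data-parallel group
        self._dp_group = dist.new_group(ranks=list(range(world_size)))
        # new_group() is collective, so every rank must create every MP group,
        # keeping only the one it belongs to
        for rank in range(world_size):
            group = dist.new_group(ranks=[rank])
            if rank == dist.get_rank():
                self._mp_group = group

    def get_model_parallel_rank(self):
        return 0

    def get_model_parallel_world_size(self):
        return 1

    def get_model_parallel_group(self):
        return self._mp_group

    def get_data_parallel_rank(self):
        return dist.get_rank()

    def get_data_parallel_world_size(self):
        return dist.get_world_size()

    def get_data_parallel_group(self):
        return self._dp_group

# e.g. model_engine, optimizer, _, _ = deepspeed.initialize(..., mpu=TrivialMPU())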

@stas00
Contributor

stas00 commented Jan 2, 2021

@alexorona, I think I found at least one culprit for needing torch.cuda.set_device(id) all over the place. There could be more than one culprit, but at least with pytorch-nightly I have to add it in a bunch of places if apex.normalization.FusedLayerNorm is used (NVIDIA/apex#1022). If I remove its use, I don't need any torch.cuda.set_device(id).

On the other hand I don't see apex.normalization.FusedLayerNorm being used in either t5 or gpt2, so perhaps it's something else. I see many bug reports wrt switching devices and some ops failing without torch.cuda.set_device(id) or some solid pytorch op running just before it. It sounds like a bug in some pytorch operations.

@stas00
Contributor

stas00 commented Jan 2, 2021

Meanwhile I've finished porting BartForConditionalGeneration to MP and pretty much adopted a variation of your device_map, so it won't change much from your original design if accepted.

It supports either type of map - your split approach or the one I proposed (flat). Here are some examples:

device_maps_flat = {
    "sshleifer/tinier_bart": {
        "encoder": {0: [0, 1] },
        "decoder": {1: [0] },
    },
    "sshleifer/distilbart-xsum-6-6": {
        "encoder": {0: [0, 1, 2, 3, 4, 5] },
        "decoder": {1: [0, 1, 2, 3, 4, 5] },
    },
}


device_maps_split = {
    "sshleifer/tinier_bart": {
        "encoder": {0: [0],
                    1: [1],
                    },
        "decoder": {1: [0] },
    },
    "sshleifer/distilbart-xsum-6-6": {
        "encoder": {0: [0, 1, 2],
                    1: [3, 4, 5],
                    },
        "decoder": {0: [0, 1, 2],
                    1: [3, 4, 5],
                    },
    },
}

I think down the road we could support other types by simply using different keys for whatever other configuration is desired.

I think eventually we will need to benchmark the different splits and see which one is more efficient. e.g. the flat approach currently suffers from the shared embeddings since they need to be constantly switched back and forth between devices!

I also have much improved magical device switching functions so it should be much faster to port to MP in the future.

One other design change I will propose is to drop first/last devices and instead have self.main_device, so that everything happens on just one device and we only send to other devices whatever needs to be offloaded - layer/block work, that is. So it probably means the main device should have no more layers/blocks assigned to it than the other devices, as it will also use more memory for all the inputs and outputs. I still need to polish this idea.

@stas00 stas00 added the Model Parallel Model Parallelism Implementations label Jan 2, 2021
@stas00
Contributor

stas00 commented Jan 5, 2021

We also may need to take into consideration @osalpekar's suggestion at pytorch/pytorch#49961 (comment) - I haven't studied that side of things yet so can't comment at the moment. On one hand it appears much more complex to set up; on the other hand it might make things much easier model-side. If you're already familiar with that side of things, please share your insights.

@stas00
Contributor

stas00 commented Jan 5, 2021

And another suggestion is to potentially use Pipe Parallelism here: pytorch/pytorch#49961 (comment) by @pritamdamania87

The main issue is that it will only become available in pt-1.8.

But @pritamdamania87 raises a super-important point - namely that the current implementation doesn't take advantage of the multiple GPUs other than for their memory. So all the other GPUs idle while one works, which is probably not what we want.

Unless I'm missing something, this means that the current approach we have been discussing (and released) is really a no-go. Please correct me if I'm wrong.

@g-karthik

g-karthik commented Jan 5, 2021

Pipeline parallelism is already supported in DeepSpeed, although I haven't played around with it.

https://www.deepspeed.ai/tutorials/pipeline/

@stas00
Contributor

stas00 commented Jan 5, 2021

yes, and fairscale too!

@stas00
Contributor

stas00 commented Jan 5, 2021

@alexorona, please have a look at this super-important comment pytorch/pytorch#49961 (comment)
from which I understand that torch.cuda.set_device() is not just for fixing bugs in some pytorch ops - it's actually an essential tool to avoid back-and-forth copying of data, which happens when the current device is not set to the device the ops are happening on. Ouch. I couldn't find any docs covering that culprit.

We were trying to get rid of it. Now it looks like we need to make sure we have it in every place we switch to a new device. So when switching to a new device we need:

  1. torch.cuda.set_device(device)
  2. inputs.to(device)
  3. layer.to(device)
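
A tiny hypothetical sketch of that sequence, assuming `layer` has already been placed on `device` by the device map:

import torch

def forward_on(layer, inputs, device):
    torch.cuda.set_device(device)  # 1. make `device` current so intermediate ops/allocations land on it
    inputs = inputs.to(device)     # 2. bring the inputs to the layer's device
    layer.to(device)               # 3. normally already done once in parallelize(), so this is a no-op here
    return layer(inputs)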

@stas00
Contributor

stas00 commented Jan 6, 2021

I was asked to share a sort of design/explanation of what we have implemented so far, so here you go (@alexorona please correct me if I have missed anything - thank you!)


Here is an example of a sshleifer/distilbart-xsum-6-6 BartForConditionalGeneration model:

 (model): BartModel(
    (shared): Embedding(50264, 1024, padding_idx=1)
    (encoder): BartEncoder(
      (embed_tokens): Embedding(50264, 1024, padding_idx=1)
      (embed_positions): BartLearnedPositionalEmbedding(1026, 1024, padding_idx=1)
      (layers): ModuleList( 6 x BartEncoderLayer)
      (layernorm_embedding): FusedLayerNorm(torch.Size([1024]), eps=1e-05, elementwise_affine=True)
    )
    (decoder): BartDecoder(
      (embed_tokens): Embedding(50264, 1024, padding_idx=1)
      (embed_positions): BartLearnedPositionalEmbedding(1026, 1024, padding_idx=1)
      (layers): ModuleList( 6 x BartDecoderLayer)
      (layernorm_embedding): FusedLayerNorm(torch.Size([1024]), eps=1e-05, elementwise_affine=True)
    )
  )
  (lm_head): Linear(in_features=1024, out_features=50264, bias=False)
)

Note that I collapsed the huge bulk of it into the following 2 lines, which I wrote myself - they were not part of the model dump output.

      (layers): ModuleList( 6 x BartEncoderLayer)
      (layers): ModuleList( 6 x BartDecoderLayer)

This is some 90% of the model, and that's what we want to spread out across multiple GPUs.

So we have the bulk of memory used by 6 x BartEncoderLayer and 6 x BartDecoderLayer, plus some other components.

For the simplicity of the example let's say we have 2 gpus we want to split the model into.

Currently the idea is to put the 6 encoder layers on gpu 0 and the same for decoder layers but on gpu 1:

device_map = {
        "encoder": {0: [0, 1, 2, 3, 4, 5] },
        "decoder": {1: [0, 1, 2, 3, 4, 5] },
    }

or alternatively, splice each group as following:

device_map = {
        "encoder": {0: [0, 1, 2],
                    1: [3, 4, 5],
                    },
        "decoder": {0: [0, 1, 2],
                    1: [3, 4, 5],
                    },
    }

and the remaining non-encoder/decoder layer modules can be all on gpu 0 or grouped closer to where they are needed. We still haven't quite finalized that map.

Of course, other models may have more or fewer layers, and they don't have to have the same number of layers in the encoder and decoder.

Now that we have the map, we can place different layers/blocks on different devices

A simplified explanation would be with the usual drawing of the deep nn (random blocks in this example)

blocks   | [blk] ... [blk 2] | [blk 3] ... [blk 5] | [blk 6] ... [blk 7] | [head]
devices  |         0         |          1          |          2          |   0

Implementation details:

  1. create model
  2. model.parallelize(): run through the model's layers and remap them to specific devices as defined by the device map, by simply running to(device) (see the sketch right after this list)
  3. inside forward we switch inputs to the same device as the layer's params using a handy wrapper I shared here: rfc: automating the switching of inputs to the device of the params pytorch/pytorch#49961 (comment)
  4. some outputs need to be brought back to the device where the logic of the main program happens (e.g. beam search)
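
A rough sketch of what stage 2 amounts to (a hypothetical helper, not the actual implementation), using a Bart-style model and a device map like the ones above:

import torch.nn as nn

def parallelize(model: nn.Module, device_map: dict) -> nn.Module:
    """Move each encoder/decoder layer to the GPU the device map assigns to it."""
    for part in ("encoder", "decoder"):  # model.model.encoder / model.model.decoder for Bart
        layers = getattr(model.model, part).layers
        for device_id, layer_ids in device_map[part].items():
            for layer_id in layer_ids:
                layers[layer_id].to(f"cuda:{device_id}")
    return model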

Complications:

  • shared embeds are a performance issue - we have to switch them back and forth between different devices.
  • because some layers have params on different devices, the developer has to explicitly choose which device to switch the input to
  • looks like we may need to sort out that torch.cuda.set_device() which apparently is needed too - sometimes to cover for bugs in pytorch, other times for performance - I haven't figured it out yet, I opened an issue:
    need a clear guide for when and how to use torch.cuda.set_device() pytorch/pytorch#50112
  • beam search is extremely slow with this approach - a 10x slowdown.

To port a model one needs to apply the device map (stage 2 above) and then gradually deal with wrong-device errors by remapping the inputs to the device of the layer's params. Alex was doing each variable manually, which is a huge pain. I automated this process (it's in 2 PRs that haven't been merged yet; the Bart PR has a smarter function).
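
For illustration, that automation could look something like this hypothetical helper (not the actual function from the PRs): before calling a layer, move any tensor inputs to the device of that layer's params.

import torch

def call_on_layer_device(layer, *args, **kwargs):
    # move every tensor argument to the device of the layer's parameters, then call it
    device = next(layer.parameters()).device
    args = tuple(a.to(device) if torch.is_tensor(a) else a for a in args)
    kwargs = {k: (v.to(device) if torch.is_tensor(v) else v) for k, v in kwargs.items()}
    return layer(*args, **kwargs)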

Transitions:

  • Alex defined first/last devices to work with. In Bart MP I shifted to a different mapping where everything happens on main_device (say 0), and we only ever switch devices for those stacks of encoder/decoder layers that repeat, but all the helping params remain on device 0, which greatly simplifies things.

  • So when we pass data to the parallelized model we .to(main_device) and most of the layers are already on the main_device, so now we only need to switch devices when the stacks end. So if you take the following map:

device_map = {
        "encoder": {0: [0, 1, 2, 3, 4, 5] },
        "decoder": {1: [0, 1, 2, 3, 4, 5] },
    }

Here one only needs to switch devices twice:

  1. once when switching between encoder.5 and decoder.0, and
  2. once more when returning from the forward of decoder.5,

but of course, since the user may choose to split them vertically like so:

device_map = {
        "encoder": {0: [0, 1, 2],
                    1: [3, 4, 5],
                    },
        "decoder": {0: [0, 1, 2],
                    1: [3, 4, 5],
                    },
    }

there will be more switches here.

So with the automation of switching the forward inputs to the desired device, there are only a few surprises one has to resolve, since each model has some unexpected needs.

Overall, with the great foundation @alexorona laid out and with a bit of the automation I added, the implementation is solid and would work just fine for those who can afford idling GPUs.

What we need to figure out next is how these idling gpus will co-operate with all the other great components we have been working on (fairscale/deepspeed/pytorch pipelines/etc.)

@julien-c
Member

julien-c commented Jan 6, 2021

Great recap @stas00

@stas00
Contributor

stas00 commented Jan 7, 2021

update: I made t5 work with HF trainer and --model_parallel in eval mode #9323 - needed to copy the outputs back to the first device - it's more or less fine in the training stage (it worked in the first place), but w/ beam search size 4 it's 10x slower on eval w/ MP than w/o MP - it gets hit badly by the back-n-forth data copying.

@stas00
Contributor

stas00 commented Jan 11, 2021

The more I'm reading on various Parallelization strategies the more I see how confusing the terminology is.

What most call Model Parallel (MP) should probably be called "Model Distributed", since all we are doing here is splitting the model across several GPUs; as such, "Model Distributed" is a much closer-to-reality term.

Next comes Pipeline Parallelism (PP), where we split the mini-batch into micro-batches and feed them into Model Parallel / Model Distributed, so that a GPU that completed its forward pass, instead of idling while the other GPUs compute their chunks of the model's layers and backprop, can start on a new micro-batch. It is a Pipeline for sure, but is it parallel? I have a hard time calling it Parallel, since all the ops are still sequential.

It's much easier to understand this by studying this diagram from the GPipe paper

[figure: naive model parallelism vs. pipeline parallelism schedules, from the GPipe paper]

This diagram makes it very clear why what we have implemented is what it calls a naive MP, and you can see the huge idling with 4 GPUs.

It then shows how it tries to resolve this idling problem with the Pipeline approach. There is still idling, but less of it.

It also misrepresents the length of time the forward and backward passes take. From asking the experts, in general backward is ~2x slower than forward. But as I was corrected on slack, the size of the bubble is about the same regardless of their relative execution speed. (Thanks @deepakn94)

And Deepak also stressed that since PP splits the mini-batch into micro-batches, the effective batch size has to be big enough, otherwise PP will be idling too - so it requires experimentation to find a good batch size.

Bottom line, PP is an improved version of MP, according to my current understanding. I'm still researching.
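
As a toy illustration of the difference (hypothetical stage0/stage1 modules already placed on two GPUs; a real pipeline such as DeepSpeed's keeps the micro-batches in flight concurrently, which this plain loop does not):

import torch

def naive_mp_forward(stage0, stage1, batch):
    # naive MP: the whole mini-batch goes through gpu 0, then gpu 1 - one GPU always idles
    return stage1(stage0(batch.to("cuda:0")).to("cuda:1"))

def pipelined_forward(stage0, stage1, batch, chunks=4):
    # PP: split the mini-batch into micro-batches so the two stages can overlap work
    outputs = []
    for micro_batch in torch.chunk(batch, chunks):
        outputs.append(stage1(stage0(micro_batch.to("cuda:0")).to("cuda:1")))
    return torch.cat(outputs)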

I think the real Parallelization is the ZeRO paper, where Sharding/Partitioning is done and then it's truly parallel processing, but I'm still trying to understand what exactly is going on there. (I need to find a good diagram visually showing what it does.) Grr, I see others use sharding/partitioning interchangeably with parallelism... so confusing.

I updated #8771 (comment) with resources on PP and next need to try to convert perhaps t5 to PP and see how it works in practice. There will be issues to overcome due to BN and tied weights.

@stas00
Contributor

stas00 commented Jan 12, 2021

@deepakn94 helped me to finally grasp ZeRO-powered data parallelism, as it's described on this diagram from this blog post
[figure: ZeRO-DP partitioning of optimizer states, gradients and parameters across GPUs, from the DeepSpeed blog post]

So it's quite simple conceptually, this is just your usual DataParallel (DP), except, instead of replicating the full model params, gradients and optimizer states, each gpu stores only a slice of it. And then at run-time when the full layer params are needed just for the given layer, all gpus sync to give each other parts that they miss - this is it.

Consider this simple model with 3 layers and each layer has 3 params:

La | Lb | Lc
---|----|---
a0 | b0 | c0
a1 | b1 | c1
a2 | b2 | c2

Lx being the layers - we have 3 of them - and ax/bx/cx being their weights, 3 weights per layer.

If we have 3 GPUs, the Sharded DDP (= Zero DP) splits the model onto 3 GPUs like so:

GPU0:
La | Lb | Lc
---|----|---
a0 | b0 | c0

GPU1:
La | Lb | Lc
---|----|---
a1 | b1 | c1

GPU2:
La | Lb | Lc
---|----|---
a2 | b2 | c2

In a way this is horizontal slicing, if you imagine the typical DNN diagram. Vertical slicing is where one puts whole layer-groups on different GPUs. But it's just the starting point.

Now each of these GPUs will get the usual mini-batch as it works in DP:

x0 => GPU0
x1 => GPU1
x2 => GPU2

The inputs are unmodified - they think they are going to be processed by the normal model.

So the inputs first hit the first layer La.

Let's focus just on GPU0: x0 needs the a0, a1, a2 params to do its forward pass, but GPU0 has only a0 - so it gets sent a1 from GPU1 and a2 from GPU2. Now the forward step can happen.

In parallel GPU1 gets mini-batch x1 and it only has a1, but needs a0 and a2 params, so it gets those from GPU0 and GPU2.

Same happens to GPU2 that gets input x2. It gets a0 and a1 from GPU0 and GPU1.

As soon as the calculation is done, the data that is no longer needed gets dropped - it's only used during the calculation.

The same is repeated at every other stage.

And the whole larger thing is repeated for layer Lb, then Lc forward-wise, and then backward Lc -> Lb -> La.
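
Here is a toy, pure-Python illustration of that gather-compute-drop cycle for one layer - it only mirrors the walkthrough above, it is not the DeepSpeed API:

# each "GPU" holds one shard of layer La's params, plus its own mini-batch
la_shards = {"gpu0": {"a0": 1.0}, "gpu1": {"a1": 2.0}, "gpu2": {"a2": 3.0}}
inputs = {"gpu0": "x0", "gpu1": "x1", "gpu2": "x2"}

for gpu, x in inputs.items():
    # all-gather: borrow the missing shards from the other gpus, just for this layer
    full_params = {}
    for shard in la_shards.values():
        full_params.update(shard)
    print(f"{gpu}: forward({x}) with La params {sorted(full_params)}")
    # ... compute La's forward for this gpu's micro-batch here ...
    del full_params  # the borrowed shards are dropped as soon as the layer is done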

To me this sounds like an efficient group backpacking weight distribution strategy:

  1. person A carries the tent
  2. person B carries the stove
  3. person C carries the entertainment system

Now each night they all share what they have with the others and get from the others what they don't have, and in the morning they pack up their allocated type of gear and continue on their way. This is Sharded DDP / Zero DP.

Compare this strategy to the simple one where each person has to carry their own tent, stove and entertainment system, which would be far more inefficient. This is DataParallel in pytorch.

And I think pretty much everywhere I read Sharded == Partitioned, so I think those are synonyms in the context of distributed models.

@stas00
Contributor

stas00 commented Jan 13, 2021

edit: 2021-02-15: Note that finetune_trainer.py was moved to examples/legacy/seq2seq/, and there is a new script run_seq2seq.py that took over finetune_trainer.py, you will find transition notes here

The simplest way to quickly reproduce the following is to switch to the transformers sha of the time this was posted, that is:

git clone https://github.com/huggingface/transformers
cd transformers
git checkout 7e662e6a3be0ece4 

The amazing discovery of the day is DeepSpeed's Zero-Offload. ZeRO-Offload is a ZeRO optimization that offloads the optimizer memory and computation from the GPU to the host CPU.

You can use DeepSpeed with a single GPU and train with huge models that won't normally fit onto a single GPU.

First let's try to finetune the huge t5-3b with a 24GB rtx-3090:

export BS=1; rm -r output_dir; CUDA_VISIBLE_DEVICES=0 PYTHONPATH=../../src USE_TF=0 ./finetune_trainer.py \
--model_name_or_path t5-3b --output_dir output_dir --adam_eps 1e-06 --data_dir wmt_en_ro --do_eval \
--do_predict --do_train --evaluation_strategy=steps --freeze_embeds --label_smoothing 0.1 --learning_rate 3e-5 \
--logging_first_step --logging_steps 1000 --max_source_length 128 --max_target_length 128 --num_train_epochs 1 \
--overwrite_output_dir --per_device_eval_batch_size $BS --per_device_train_batch_size $BS --predict_with_generate \
--eval_steps 25000  --sortish_sampler --task translation_en_to_ro --test_max_target_length 128 \
--val_max_target_length 128 --warmup_steps 5 --n_train 60 --n_val 10 --n_test 10 --fp16

No cookie, even with BS=1

RuntimeError: CUDA out of memory. Tried to allocate 64.00 MiB (GPU 0; 23.70 GiB total capacity; 21.37 GiB already allocated; 45.69 MiB free; 22.05 GiB reserved in total by PyTorch)

Now update your transformers to master, then install deepspeed:

pip install deepspeed

and let's try again:

export BS=20; rm -r output_dir; CUDA_VISIBLE_DEVICES=0 PYTHONPATH=../../src USE_TF=0 deepspeed --num_gpus=1 \
./finetune_trainer.py --model_name_or_path t5-3b --output_dir output_dir --adam_eps 1e-06 --data_dir wmt_en_ro \
--do_eval --do_predict --do_train --evaluation_strategy=steps --freeze_embeds --label_smoothing 0.1 --learning_rate 3e-5 \
--logging_first_step --logging_steps 1000 --max_source_length 128 --max_target_length 128 --num_train_epochs 1 \
--overwrite_output_dir --per_device_eval_batch_size $BS --per_device_train_batch_size $BS --predict_with_generate \
--eval_steps 25000  --sortish_sampler --task translation_en_to_ro --test_max_target_length 128 \
--val_max_target_length 128 --warmup_steps 5 --n_train 60 --n_val 10 --n_test 10 --deepspeed ds_config_1gpu.json --fp16

Et voila! We get BS=20 training just fine. I can probably push BS even further - it OOMed at BS=30.

2021-01-12 19:06:31 | INFO | __main__ |   train_n_objs = 60
2021-01-12 19:06:31 | INFO | __main__ |   train_runtime = 8.8511
2021-01-12 19:06:35 | INFO | __main__ |   val_n_objs = 10
2021-01-12 19:06:35 | INFO | __main__ |   val_runtime = 3.5329
2021-01-12 19:06:39 | INFO | __main__ |   test_n_objs = 10
2021-01-12 19:06:39 | INFO | __main__ |   test_runtime = 4.1123

Amazing!

Important note - I used CUDA_VISIBLE_DEVICES=0 to single out one gpu, but deepspeed has a bug now where it ignores that env var, so it'll be using the first GPU instead. microsoft/DeepSpeed#662 But hoping it will get fixed eventually.

The config file ds_config_1gpu.json is:

{
    "fp16": {
        "enabled": true,
        "loss_scale": 0,
        "loss_scale_window": 1000,
        "hysteresis": 2,
        "min_loss_scale": 1
    },

    "zero_optimization": {
        "stage": 2,
       "allgather_partitions": true,
       "allgather_bucket_size": 2e8,
       "reduce_scatter": true,
       "reduce_bucket_size": 2e8,
        "overlap_comm": true,
        "contiguous_gradients": true,
        "cpu_offload": true
    },

    "optimizer": {
        "type": "Adam",
        "params": {
            "adam_w_mode": true,
            "lr": 3e-5,
            "betas": [ 0.9, 0.999 ],
            "eps": 1e-8,
            "weight_decay": 3e-7
        }
    },

    "scheduler": {
        "type": "WarmupLR",
        "params": {
            "warmup_min_lr": 0,
            "warmup_max_lr": 3e-5,
            "warmup_num_steps": 500
        }
    }
}

I had to lower the ZeRO buffers from the default 5e8 to 2e8, otherwise it was OOM'ing even on BS=1.

important: DeepSpeed made some changes in the non-released version as of this writing and so the above config won't work anymore. It dropped adam_w_mode and added a proper AdamW optimizer (it was always there, but just not exposed normally), so replace that section with:

    "optimizer": {
        "type": "AdamW",
        "params": {
            "lr": 3e-5,
            "betas": [ 0.9, 0.999 ],
            "eps": 1e-8,
            "weight_decay": 3e-7
        }
    },

And it's not optimized yet, I just found at least one config that worked for this simple proof-of-concept test.

Go and check it out!

edit: I was asked about RAM usage for this task - it was 71GB peak. I re-ran the same command as above with /usr/bin/time -v prepended to deepspeed and got:

        User time (seconds): 117.12
        System time (seconds): 53.46
        Percent of CPU this job got: 122%
        Elapsed (wall clock) time (h:mm:ss or m:ss): 2:19.38
        Average shared text size (kbytes): 0
        Average unshared data size (kbytes): 0
        Average stack size (kbytes): 0
        Average total size (kbytes): 0
        Maximum resident set size (kbytes): 70907544
        Average resident set size (kbytes): 0
        Major (requiring I/O) page faults: 3245
        Minor (reclaiming a frame) page faults: 31346864
        Voluntary context switches: 16348
        Involuntary context switches: 52489
        Swaps: 0
        File system inputs: 1402864
        File system outputs: 11143504
        Socket messages sent: 0
        Socket messages received: 0
        Signals delivered: 0
        Page size (bytes): 4096
        Exit status: 0

So the peak RSS entry is 71GB:

        Maximum resident set size (kbytes): 70907544

The doc is here: https://huggingface.co/transformers/master/main_classes/trainer.html#deepspeed
And it's already slightly outdated - I need to modify it to cover that it works with single GPUs too!

@alexorona, I think you'd be super-happy about this one.

p.s. if you need to setup the dir and the data, first do:

git clone https://github.com/huggingface/transformers/
cd transformers/
cd examples/seq2seq
wget https://cdn-datasets.huggingface.co/translation/wmt_en_ro.tar.gz
tar -xzvf wmt_en_ro.tar.gz

before running any of the above scripts.

Oh, and I'm on pytorch-nightly since that's the only version that works at the moment with rtx-3090.

@stas00
Contributor

stas00 commented Jan 13, 2021

edit: 2021-02-15: Note that finetune_trainer.py was moved to examples/legacy/seq2seq/, and there is a new script run_seq2seq.py that took over finetune_trainer.py, you will find the transition notes here

The simplest way to quickly reproduce the following is to switch to the transformers sha of the time this was posted, that is:

git clone https://github.com/huggingface/transformers
cd transformers
git checkout 7e662e6a3be0ece4 

OK and to finish the day here are some benchmarks - thank you @sgugger for letting me run those on your machine with dual titan rtx.

Let's start with the results table:

Method | max BS | train time (s) | eval time (s)
---|---|---|---
baseline | 16 | 30.9458 | 56.3310
fp16 | 20 | 21.4943 | 53.4675
sharded_ddp | 30 | 25.9085 | 47.5589
sharded_ddp+fp16 | 30 | 17.3838 | 45.6593
deepspeed w/o cpu offload | 40 | 10.4007 | 34.9289
deepspeed w/ cpu offload | 50 | 20.9706 | 32.1409

Baseline + data setup was:

git clone https://github.com/huggingface/transformers/
cd transformers/
cd examples/seq2seq
wget https://cdn-datasets.huggingface.co/translation/wmt_en_ro.tar.gz
tar -xzvf wmt_en_ro.tar.gz
export BS=16; rm -r output_dir; PYTHONPATH=../../src USE_TF=0  python -m torch.distributed.launch \
--nproc_per_node=2 ./finetune_trainer.py --model_name_or_path t5-large --output_dir output_dir \
--adam_eps 1e-06 --data_dir wmt_en_ro --do_eval --do_train --evaluation_strategy=steps --freeze_embeds \
--label_smoothing 0.1 --learning_rate 3e-5 --logging_first_step --logging_steps 1000 --max_source_length 128 \
--max_target_length 128 --num_train_epochs 1 --overwrite_output_dir --per_device_eval_batch_size $BS \
--per_device_train_batch_size $BS --predict_with_generate --eval_steps 25000  --sortish_sampler \
--task translation_en_to_ro --test_max_target_length 128 --val_max_target_length 128 --warmup_steps 500 \
--n_train 2000 --n_val 500

Notes:

  • We are doing a small train=2000, eval=500 items to do the comparisons. Eval does beam search size=4 by default, so it's slower than training with the same number of samples - that's why I used 4x fewer eval items
  • task: translation
  • model: t5-large
  • We have 2x 24GB GPUs
  • DeepSpeed wasn't really designed for evaluation according to its developers but you can see it rocks there too.

Results: Well, DeepSpeed beats all the solutions that were compared - it's much faster and can fit much bigger batches into the given hardware. As you can see from the previous post #8771 (comment), cpu offloading, while slower for training, can fit more into your hardware - and it's the winner for eval!

Note: these benchmarks aren't perfect, as they take a lot of time to run - you can see that the BS numbers are pretty rounded. Surely they could be somewhat bigger and the speed somewhat better as a result, so I'm sure both sharded ddp and deepspeed can be tuned further.

But that's a good start. As both sharded ddp and deepspeed are now in master https://huggingface.co/transformers/master/main_classes/trainer.html#trainer-integrations please go ahead and do your own benchmarks.

And now the raw results - sorry it's not markdown'ed:


# setup

conda install -y pytorch==1.7.1 torchvision cudatoolkit=10.2 -c pytorch
pip install deepspeed fairscale

# versions

PyTorch version: 1.7.1
Is debug build: False
CUDA used to build PyTorch: 10.2
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.1 LTS (x86_64)
GCC version: (Ubuntu 9.3.0-17ubuntu1~20.04) 9.3.0
Clang version: 10.0.0-4ubuntu1
CMake version: version 3.16.3

Python version: 3.8 (64-bit runtime)
Is CUDA available: True
CUDA runtime version: 10.0.130
GPU models and configuration:
GPU 0: TITAN RTX
GPU 1: TITAN RTX

Nvidia driver version: 450.102.04
cuDNN version: Probably one of the following:
/usr/local/cuda-10.2/targets/x86_64-linux/lib/libcudnn.so.7.6.5

transformers_version": "4.2.0dev0", (master)

# baseline


max that I could fit was BS=16

export BS=16; rm -r output_dir; PYTHONPATH=../../src USE_TF=0  python -m torch.distributed.launch --nproc_per_node=2 ./finetune_trainer.py --model_name_or_path t5-large --output_dir output_dir --adam_eps 1e-06 --data_dir wmt_en_ro --do_eval --do_train --evaluation_strategy=steps --freeze_embeds --label_smoothing 0.1 --learning_rate 3e-5 --logging_first_step --logging_steps 1000 --max_source_length 128 --max_target_length 128 --num_train_epochs 1 --overwrite_output_dir --per_device_eval_batch_size $BS --per_device_train_batch_size $BS --predict_with_generate --eval_steps 25000  --sortish_sampler --task translation_en_to_ro --test_max_target_length 128 --val_max_target_length 128 --warmup_steps 500 --n_train 2000 --n_val 500


01/13/2021 05:31:19 - INFO - __main__ -     train_runtime = 30.9458
01/13/2021 05:32:15 - INFO - __main__ -     val_bleu = 25.8269
01/13/2021 05:32:15 - INFO - __main__ -     val_runtime = 56.331

# w/ --fp16

could fit BS=20

export BS=20; rm -r output_dir; PYTHONPATH=../../src USE_TF=0  python -m torch.distributed.launch --nproc_per_node=2 ./finetune_trainer.py --model_name_or_path t5-large --output_dir output_dir --adam_eps 1e-06 --data_dir wmt_en_ro --do_eval --do_train --evaluation_strategy=steps --freeze_embeds --label_smoothing 0.1 --learning_rate 3e-5 --logging_first_step --logging_steps 1000 --max_source_length 128 --max_target_length 128 --num_train_epochs 1 --overwrite_output_dir --per_device_eval_batch_size $BS --per_device_train_batch_size $BS --predict_with_generate --eval_steps 25000  --sortish_sampler --task translation_en_to_ro --test_max_target_length 128 --val_max_target_length 128 --warmup_steps 500 --n_train 2000 --n_val 500 --fp16

01/13/2021 05:33:49 - INFO - __main__ -     train_runtime = 21.4943
01/13/2021 05:34:42 - INFO - __main__ -     val_bleu = 25.7895
01/13/2021 05:34:42 - INFO - __main__ -     val_runtime = 53.4675


------------------------------------------------

# w/ --sharded_ddp

to compare with BS=20

export BS=20; rm -r output_dir; PYTHONPATH=../../src USE_TF=0  python -m torch.distributed.launch --nproc_per_node=2 ./finetune_trainer.py --model_name_or_path t5-large --output_dir output_dir --adam_eps 1e-06 --data_dir wmt_en_ro --do_eval --do_train --evaluation_strategy=steps --freeze_embeds --label_smoothing 0.1 --learning_rate 3e-5 --logging_first_step --logging_steps 1000 --max_source_length 128 --max_target_length 128 --num_train_epochs 1 --overwrite_output_dir --per_device_eval_batch_size $BS --per_device_train_batch_size $BS --predict_with_generate --eval_steps 25000  --sortish_sampler --task translation_en_to_ro --test_max_target_length 128 --val_max_target_length 128 --warmup_steps 500 --n_train 2000 --n_val 500 --sharded_ddp


01/13/2021 06:26:11 - INFO - __main__ -     train_runtime = 28.9404
01/13/2021 05:36:16 - INFO - __main__ -     val_bleu = 25.7201
01/13/2021 05:36:16 - INFO - __main__ -     val_runtime = 55.0909

but can fit more now, so same with BS=30

export BS=30; rm -r output_dir; PYTHONPATH=../../src USE_TF=0  python -m torch.distributed.launch --nproc_per_node=2 ./finetune_trainer.py --model_name_or_path t5-large --output_dir output_dir --adam_eps 1e-06 --data_dir wmt_en_ro --do_eval --do_train --evaluation_strategy=steps --freeze_embeds --label_smoothing 0.1 --learning_rate 3e-5 --logging_first_step --logging_steps 1000 --max_source_length 128 --max_target_length 128 --num_train_epochs 1 --overwrite_output_dir --per_device_eval_batch_size $BS --per_device_train_batch_size $BS --predict_with_generate --eval_steps 25000  --sortish_sampler --task translation_en_to_ro --test_max_target_length 128 --val_max_target_length 128 --warmup_steps 500 --n_train 2000 --n_val 500 --sharded_ddp

01/13/2021 06:28:02 - INFO - __main__ -     train_runtime = 25.9085
01/13/2021 05:39:08 - INFO - __main__ -     val_bleu = 25.7178
01/13/2021 05:39:08 - INFO - __main__ -     val_runtime = 47.5589


# w/ --sharded_ddp --fp16

export BS=20; rm -r output_dir; PYTHONPATH=../../src USE_TF=0  python -m torch.distributed.launch --nproc_per_node=2 ./finetune_trainer.py --model_name_or_path t5-large --output_dir output_dir --adam_eps 1e-06 --data_dir wmt_en_ro --do_eval --do_train --evaluation_strategy=steps --freeze_embeds --label_smoothing 0.1 --learning_rate 3e-5 --logging_first_step --logging_steps 1000 --max_source_length 128 --max_target_length 128 --num_train_epochs 1 --overwrite_output_dir --per_device_eval_batch_size $BS --per_device_train_batch_size $BS --predict_with_generate --eval_steps 25000  --sortish_sampler --task translation_en_to_ro --test_max_target_length 128 --val_max_target_length 128 --warmup_steps 500 --n_train 2000 --n_val 500 --sharded_ddp --fp16

01/13/2021 06:29:08 - INFO - __main__ -     train_runtime = 21.4775
01/13/2021 05:41:39 - INFO - __main__ -     val_bleu = 25.7162
01/13/2021 05:41:39 - INFO - __main__ -     val_runtime = 53.2397

but can fit more now, so same with BS=30

01/13/2021 06:30:03 - INFO - __main__ -     train_runtime = 17.3838
01/13/2021 05:43:56 - INFO - __main__ -     val_bleu = 25.7314
01/13/2021 05:43:56 - INFO - __main__ -     val_runtime = 45.6593

# w/ --deepspeed ds_config.json (stage 2 w/o cpu offloading)

I changed the config file to:

       "cpu_offload": false

export BS=40; rm -r output_dir; PYTHONPATH=../../src USE_TF=0  deepspeed ./finetune_trainer.py --model_name_or_path t5-large --output_dir output_dir --adam_eps 1e-06 --data_dir wmt_en_ro --do_eval --do_train --evaluation_strategy=steps --freeze_embeds --label_smoothing 0.1 --learning_rate 3e-5 --logging_first_step --logging_steps 1000 --max_source_length 128 --max_target_length 128 --num_train_epochs 1 --overwrite_output_dir --per_device_eval_batch_size $BS --per_device_train_batch_size $BS --predict_with_generate --eval_steps 25000  --sortish_sampler --task translation_en_to_ro --test_max_target_length 128 --val_max_target_length 128 --warmup_steps 500 --n_train 2000 --n_val 500 --deepspeed ds_config.json

01/13/2021 06:32:35 - INFO - __main__ -     train_runtime = 10.4007
01/13/2021 06:33:10 - INFO - __main__ -     val_bleu = 25.9687
01/13/2021 06:33:10 - INFO - __main__ -     val_runtime = 34.9289


# w/ --deepspeed ds_config.json (stage 2 w/ cpu offloading)

if we lower the buffers to `1.5e8` and enable cpu offloading:

       "allgather_bucket_size": 1.5e8,
       "reduce_bucket_size": 1.5e8,
       "cpu_offload": true

we can get to BS=50!

export BS=50; rm -r output_dir; PYTHONPATH=../../src USE_TF=0  deepspeed ./finetune_trainer.py --model_name_or_path t5-large --output_dir output_dir --adam_eps 1e-06 --data_dir wmt_en_ro --do_eval --do_train --evaluation_strategy=steps --freeze_embeds --label_smoothing 0.1 --learning_rate 3e-5 --logging_first_step --logging_steps 1000 --max_source_length 128 --max_target_length 128 --num_train_epochs 1 --overwrite_output_dir --per_device_eval_batch_size $BS --per_device_train_batch_size $BS --predict_with_generate --eval_steps 25000  --sortish_sampler --task translation_en_to_ro --test_max_target_length 128 --val_max_target_length 128 --warmup_steps 500 --n_train 2000 --n_val 500 --deepspeed ds_config.json

01/13/2021 06:40:51 - INFO - __main__ -     train_runtime = 20.9706
01/13/2021 06:41:23 - INFO - __main__ -     val_bleu = 25.9244
01/13/2021 06:41:23 - INFO - __main__ -     val_runtime = 32.1409

I'm pretty sure if the buffers are even smaller it could do even higher BS. But it's late and I'm going to sleep.

Here is the config file that was used for deepspeed: https://github.com/huggingface/transformers/blob/69ed36063a732c37fdf72c605c65ebb5b2e85f44/examples/seq2seq/ds_config.json

@stas00 stas00 added the WIP Label your PR/Issue with WIP for some long outstanding Issues/PRs that are work in progress label Apr 15, 2021
@stas00 stas00 self-assigned this Apr 15, 2021
@tscholak

tscholak commented Apr 23, 2021

Hi, a bit of time has passed, and it seems some of the information here is outdated. If possible, could someone please describe what is necessary in order to train a T5-3b or T5-11b model on 1 or more 32GB or 40GB GPUs, with an input sequence length of up to 512 and up to 256 for the target? Has this been achieved?

Are additional pieces of configuration necessary for model parallelism or is the deepspeed wrapper somehow triggering model parallelism in the hf trainer?

My observations so far have been that T5 training is very unstable with --fp16 and torch.distributed.launch, and I am not sure that deepspeed can overcome this problem. Could anyone comment on the training stability? So far this conversation has mostly touched on avoiding OOM while the aspect of training results has not been given much attention.

Thank you!

EDIT: I would also be thankful for an explanation for why smaller buffer sizes enable larger batch sizes.

@stas00
Contributor

stas00 commented Apr 23, 2021

Hi, a bit of time has passed, and it seems some information here is outdated. If possible, could someone please describe what is necessary in order to train a T5-3b or T5-11b model on 1 or more 32GB or 40GB GPUs and with a sequence length in the input of up to 512 and up to 256 for the target? Has this been achieved?

I'm pretty sure it should be possible, certainly with t5-3b; with t5-11b I will have to try.
Please let me know what is not working for you (the exact command) and I can try to help tune it up.

And if you have access to NVMe you can train even larger models with DeepSpeed ZeRO-Infinity. Just give me a few more days to finalize the ZeRO-Infinity integration into transformers. This is all very new and their docs are very lacking still, but it will be fixed, so I'm trying to gather the information needed to take advantage of it, as it's not trivial to configure - need to run a benchmark first.

The good news is that you can extend your CPU memory with any storage - it just might be very slow if the storage is slow :)

Are additional pieces of configuration necessary for model parallelism or is the deepspeed wrapper somehow triggering model parallelism in the hf trainer?

We don't use the parallelism from Deepspeed, but mainly its ZeRO features, which more or less allow one not to worry about parallelism and be able to train huge models. Parallelism requires huge changes to the models.

My observations so far have been that T5 training is very unstable with --fp16 and torch.distributed.launch, and I am not sure that deepspeed can overcome this problem. Could anyone comment on the training stability? So far this conversation has mostly touched on avoiding OOM while the aspect of training results has not been given much attention.

Yes, all bf16-pretrained models are affected - please see: https://discuss.huggingface.co/t/compiling-data-on-how-models-were-pre-trained-fp16-fp32-bf16/5671
They weren't meant to be used under fp16 mixed precision.

You will find a handful of issues wrt Nan/Inf in t5 and mt5.

You can try this workaround I experimented with: #10956
It seems to overcome a big part of instability in mt5, but one person reported a problem after an extensive run.

If you have access to Ampere-based cards (rtx-3090/A100), please see: #11076 (comment)
This is not yet in deepspeed master, but soon they will have fp32 mode, which will be equivalent to v100 fp16 since it'd use TF32 on those Ampere cards.

@tscholak

Hi @stas00, thanks for the prompt response.

Am I understanding correctly that deepspeed with T5 is inadvisable at the moment because until deepspeed supports FP32 it will use FP16 which will destroy the T5 model?

@stas00
Contributor

stas00 commented Apr 24, 2021

Most of the recent complaints were about mt5 rather than t5.

@PeterAJansen, could you please comment here since I know at some point you were extensively working with t5-11b w/ deepspeed - did you run into nan/inf problems there?

I asked @samyam to make a PR from his full-fp32 branch https://github.com/microsoft/DeepSpeed/tree/samyamr/full-precision-for-stage3, but you can already use it. gpt-neo folks appear to have successfully started using it to overcome the over/underflow issue.

@PeterAJansen

@stas00 it's a good question. I only became aware of the potential T5 fp16 issue recently, and I haven't noticed anything wonky in the models that I've been training -- but that's not to say that everything I've trained isn't underperforming and couldn't perform vastly better, since I've been training models on new tasks rather than existing ones.

To verify things are running as expected, I should probably run an fp16 version of a common dataset task that (ideally) could be trained and evaluated in less than a day. Any suggestions from the examples section?

@stas00
Contributor

stas00 commented Apr 24, 2021

Thank you for sharing your experience, @PeterAJansen. I mostly encountered reports with mt5 as of recent.

Since you own A100s (and the same goes for those with an RTX-3090), it shouldn't be too long before pytorch and deepspeed support native bf16 mixed precision, as both are actively working on adding this support. Once there, the NaN issue is expected to disappear in all bf16-pretrained models when they are finetuned/eval'ed in the same mode. So if you aren't in a rush and don't have a deadline to meet, I'd say just wait a bit longer and nothing needs to be done.

@Moldoteck

Have you managed to use activation checkpointing?

@stas00
Contributor

stas00 commented May 17, 2021

Have you managed to use activation checkpointing?

I'd be happy to follow up, but this kind of question is impossible to answer. Who is "you"? In what context? What is the problem?

May I suggest opening a new Issue and providing full context and the exact problem you're dealing with or a need you have? Thank you!

@sacombs

sacombs commented Jul 22, 2021

Hi @stas00,

Thanks for all your contributions to the deepspeed ZeRO integration. I find it fascinating and awesome!

According to your comments, it doesn't seem like deepspeed is able to use model parallelism (as opposed to data parallelism). Does this make it impossible to use t5-3b on an 8-GPU nvidia v100 16GB machine? I have tried a couple of different configurations of ZeRO stage 3, including the configuration provided in master; however, I am only able to use a batch size of 1 or 2. I am using a max sequence length of 512 for both input and output. I can achieve these same results if I use model.parallelize() and split t5 across the 8 GPUs.

Thanks!

@stas00
Contributor

stas00 commented Jul 22, 2021

In general:

  1. Deepspeed can do 3D parallelism (PP+TP+DP) no problem - please see https://huggingface.co/transformers/master/parallelism.html
    The problem is that HF transformers currently supports only the naive PP for gpt2/t5, i.e. the limitation is on our side.
    The plan is to implement TP first and then eventually PP. (update: DS doesn't currently do TP itself, it only supports it via an MPU, but they are working on it)

  2. ZeRO is a completely different approach to scaling which, when used with fast interconnects, performs on par with 3D parallelism. The key is that it doesn't require changes to the model (well, sometimes very minor changes). That's why we eagerly adopted Deepspeed as the easy scalability solution.

Now to your specific setup. Offloading some of the memory should do the trick.

Here is a helpful API to estimate the memory needs for params, optim states and gradients: https://deepspeed.readthedocs.io/en/latest/memory.html#api-to-estimate-memory-usage It is still missing the activation and temp memory needs, but it already gives you a pretty good picture of which configuration to pick:

Zero2

python -c 'from transformers import AutoModel; \
from deepspeed.runtime.zero.stage2 import estimate_zero2_model_states_mem_needs_all_live; \
model = AutoModel.from_pretrained("t5-3b"); \
estimate_zero2_model_states_mem_needs_all_live(model, num_gpus_per_node=8, num_nodes=1)'
Estimated memory needed for params, optim states and gradients for a:
HW: Setup with 1 node, 8 GPUs per node.
SW: Model with 2851M total params.
  per CPU  |  per GPU |   Options
  127.48GB |   5.31GB | cpu_offload=1
  127.48GB |  15.93GB | cpu_offload=0

Zero3

python -c 'from transformers import AutoModel; \
from deepspeed.runtime.zero.stage3 import estimate_zero3_model_states_mem_needs_all_live; \
model = AutoModel.from_pretrained("t5-3b"); \
estimate_zero3_model_states_mem_needs_all_live(model, num_gpus_per_node=8, num_nodes=1)'

Estimated memory needed for params, optim states and gradients for a:
HW: Setup with 1 node, 8 GPUs per node.
SW: Model with 2851M total params, 32M largest layer params.
  per CPU  |  per GPU |   Options
   71.71GB |   0.12GB | cpu_offload=1, cpu_offload_params=1, zero_init=1
  127.48GB |   0.12GB | cpu_offload=1, cpu_offload_params=1, zero_init=0
   63.74GB |   0.79GB | cpu_offload=1, cpu_offload_params=0, zero_init=1
  127.48GB |   0.79GB | cpu_offload=1, cpu_offload_params=0, zero_init=0
    1.47GB |   6.10GB | cpu_offload=0, cpu_offload_params=0, zero_init=1
  127.48GB |   6.10GB | cpu_offload=0, cpu_offload_params=0, zero_init=0

So you can see that if you have a nice chunk of CPU memory available, it should be trivial for you to fit a large BS with a large seqlen.

And this was written before the NVMe offload was added, so you have that option too if you don't have much CPU memory - but consider it an extension of CPU memory, so the above numbers will still be the same GPU-memory-wise.

p.s. Megatron-LM has just added t5 to their arsenal, but it lacks PP as of this writing.

@stas00
Contributor

stas00 commented Jul 22, 2021

Yes, specific problem solving is best done in a dedicated thread. So let's continue there. Please tag me so that I see it.

@alexorona
Contributor Author

alexorona commented Jul 25, 2021

@stas00 @sacombs Maybe there are two or three typical use cases we could articulate? After having studied the documentation and your threads on this, Stas, I'm still only able to get models in the range of 1.5B parameters training on a single 16GB GPU. The advantage is that it uses far less GPU memory than it would normally take (about 30%), but it is 5 times slower. That's a very acceptable trade-off in terms of VM cost.

I haven't been able to effectively train large models like GPTNeo-2.7B and T5 using multiple GPUs. It seems like the deepspeed integration automatically creates a number of nodes/workers equal to the number of GPUs, so if you can't train it on one GPU, adding multiple GPUs makes no difference. I've tried with both zero3 and zero3-nvme configurations.

@stas00
Most of the big model use cases are around T5, GPTNeo and less frequently CTRL, DeBERTa and M2M100. T5 has a lot of use cases and GPTNeo is the most in-demand for generative tasks. Let's assume someone has a training script that cleans data, trains and evaluates. Training uses Trainer. Would it be possible to provide something like this:

Example 1: Fine-tuning t5-3B Using zero3 and zero3-nvme with Multiple GPUs

Requirements

  • Install deepspeed with pip install deepspeed, pip install transformers[deepspeed], or from source (see Installation)
  • Use zero3_config.json for zero3 and zero3_nvme_config.json for zero3_nvme
  • You'll need to run on Linux, as the preferred nccl backend that deepspeed uses is not supported on Windows. You cannot use WSL to get around this requirement.
  • You cannot use a Notebook like Google Colab or Jupyter because of how deepspeed initiates processes when multiple GPUs are used.
  • Create a training script that prepares your data and trains your model. To make this example work, the deepspeed configuration file must be passed in via TrainingArguments, e.g. training_args = TrainingArguments(deepspeed = "zero3_config.json", ...), which is then given to Trainer (see the sketch after these two examples)
  • It is best to keep most of the values in zero3_config.json or zero3_nvme_config.json on "auto" and use TrainingArguments to adjust the deepspeed configuration
  • For zero3: You'll need at least x GPU memory and x CPU memory for this example -- you might be able to get away with less GPU memory (see GPU OOM Messages below)
  • For zero3 with nvme: You'll need at least x GPU memory, x CPU memory and NVMe with about x spare GB for this example -- you might be able to get away with less GPU memory (see GPU OOM Messages below)

Running
Here's how to run it:
deepspeed your_training_script.py <normal cl args> --deepspeed zero3_config.json

GPU OOM Messages
If you are running out of memory, here's what you can try tweaking:

  • Reduce batch_size passed to TrainingArguments
  • Reduce gradient_accumulation_steps passed to TrainingArguments
  • In the zero3_config.json or zero3_nvme_config.json file, reduce the size of the "stage3_max_live_parameters" and "stage3_max_reuse_distance"

Example 2: Fine-tuning EleutherAI/gpt-neo-1.3B Using zero3 on a Single GPU

Requirements

  • Install deepspeed with pip install deepspeed, pip install transformers[deepspeed], or from source (see Installation)
  • Use zero3_config.json
  • You'll need to run on Linux, as the preferred nccl backend that deepspeed uses is not supported on Windows. You cannot use WSL to get around this requirement.
  • It is possible to do this in a Notebook when using just one GPU. See Deployment in Notebooks below.
  • Create a training script that prepares your data and trains your model. To make this example work, the deepspeed configuration file must be passed in via TrainingArguments, e.g. training_args = TrainingArguments(deepspeed = "zero3_config.json", ...), which is then given to Trainer
  • It is best to keep most of the values in zero3_config.json on "auto" and use TrainingArguments to adjust the deepspeed configuration
  • For zero3: You'll need at least 16GB GPU memory and x CPU memory for this example -- you might be able to get away with less GPU memory (see GPU OOM Messages below)

Running
Here's how to run it:
deepspeed your_training_script.py <normal cl args> --deepspeed zero3_config.json

GPU OOM Messages
If you are running out of memory, here's what you can try tweaking:

  • Reduce batch_size passed to TrainingArguments
  • Reduce gradient_accumulation_steps passed to TrainingArguments
  • In the zero3_config.json file, reduce the size of the "stage3_max_live_parameters" and "stage3_max_reuse_distance"
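
As referenced above, here is a minimal sketch of the training-script skeleton that both examples assume (the model name, dataset and config file name are placeholders; in current transformers the deepspeed config is passed through TrainingArguments):

from transformers import (AutoModelForSeq2SeqLM, AutoTokenizer,
                          Trainer, TrainingArguments)

model_name = "t5-3b"  # for gpt-neo use AutoModelForCausalLM instead
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

train_dataset = ...  # your prepared/tokenized dataset goes here

training_args = TrainingArguments(
    output_dir="output_dir",
    per_device_train_batch_size=1,
    gradient_accumulation_steps=1,
    deepspeed="zero3_config.json",  # the ZeRO-3 config file from the examples above
)

trainer = Trainer(model=model, args=training_args, train_dataset=train_dataset)
trainer.train()

# launched with: deepspeed your_training_script.py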

@stas00
Contributor

stas00 commented Jul 27, 2021

That's a great idea, @alexorona! These would be super-useful.

Let's do it!

Do you also want to define the actual GPU sizes? It'd be very different if one uses an 80GB A100 compared to a 16GB V100.

Perhaps repasting each of these into a separate issue so that we could work on tuning these up independently?

Let's start with 2-3 and then we can expand it to more.

I'm a bit busy in the next few days with the bigscience first launch, but otherwise I can work on it when I get some free time, and we can of course ask the Deepspeed team to help.

Once polished these would make a great article/blog_post.

@stas00
Contributor

stas00 commented Aug 2, 2021

Just to update: I think we will get the best outcome if one or a few people with an actual need and hardware to match post an issue; then we will work on solving it and, while at it, come up with the settings/guidelines for the models in question.

Also I'm at the moment mostly busy with the bigscience project, which takes the lion's share of my time. So I'd be delighted to support someone with a need, but probably won't have enough incentive to carve out the time to act on both sides.

I hope this makes sense.

@ZeyiLiao

ZeyiLiao commented Jun 2, 2022

Hi, I followed what you said here, but it failed with "TypeError: issubclass() arg 1 must be a class".
And even if I replace finetune_trainer.py with run_seq2seq.py, it still doesn't work.

@stas00
Contributor

stas00 commented Jun 2, 2022

This is a very old thread - could you please open a proper new Issue with full details of what you did, the versions, the full traceback and how we could reproduce the problem, and please tag me. Thank you.

@dswang2011

dswang2011 commented Sep 24, 2023

Based on a working Python model training script, I made the simplest change to Trainer (by adding deepspeed='ds_config.json'), but hit the error below - any tips? I did not set local_rank at all, so I have no idea why the error mentions it:

(pytorch_p39) ubuntu@ip-10-0-3-65:~/python_projects/TestGPT/src$ deepspeed pretrain.py --deepspeed /home/ubuntu/python_projects/TestGPT/src/config/ds_config.json
[2023-09-24 20:52:14,971] [INFO] [real_accelerator.py:158:get_accelerator] Setting ds_accelerator to cuda (auto detect)
[2023-09-24 20:52:17,101] [WARNING] [runner.py:203:fetch_hostfile] Unable to find hostfile, will proceed with training with local resources only.
[2023-09-24 20:52:17,101] [INFO] [runner.py:570:main] cmd = /home/ubuntu/anaconda3/envs/pytorch_p39/bin/python3.9 -u -m deepspeed.launcher.launch --world_info=eyJsb2NhbGhvc3QiOiBbMCwgMSwgMiwgM119 --master_addr=127.0.0.1 --master_port=29500 --enable_each_rank_log=None pretrain.py --deepspeed /home/ubuntu/python_projects/TestGPT/src/config/ds_config.json
[2023-09-24 20:52:18,925] [INFO] [real_accelerator.py:158:get_accelerator] Setting ds_accelerator to cuda (auto detect)
[2023-09-24 20:52:20,642] [INFO] [launch.py:145:main] WORLD INFO DICT: {'localhost': [0, 1, 2, 3]}
[2023-09-24 20:52:20,642] [INFO] [launch.py:151:main] nnodes=1, num_local_procs=4, node_rank=0
[2023-09-24 20:52:20,642] [INFO] [launch.py:162:main] global_rank_mapping=defaultdict(<class 'list'>, {'localhost': [0, 1, 2, 3]})
[2023-09-24 20:52:20,642] [INFO] [launch.py:163:main] dist_world_size=4
[2023-09-24 20:52:20,642] [INFO] [launch.py:165:main] Setting CUDA_VISIBLE_DEVICES=0,1,2,3
usage: pretrain.py [-h] [--config CONFIG_FILE]
pretrain.py: error: unrecognized arguments: --local_rank=2 --deepspeed /home/ubuntu/python_projects/TestGPT/src/config/ds_config.json
usage: pretrain.py [-h] [--config CONFIG_FILE]
pretrain.py: error: unrecognized arguments: --local_rank=0 --deepspeed /home/ubuntu/python_projects/TestGPT/src/config/ds_config.json
usage: pretrain.py [-h] [--config CONFIG_FILE]
pretrain.py: error: unrecognized arguments: --local_rank=1 --deepspeed /home/ubuntu/python_projects/TestGPT/src/config/ds_config.json
usage: pretrain.py [-h] [--config CONFIG_FILE]
pretrain.py: error: unrecognized arguments: --local_rank=3 --deepspeed /home/ubuntu/python_projects/TestGPT/src/config/ds_config.json
[2023-09-24 20:52:24,666] [INFO] [launch.py:315:sigkill_handler] Killing subprocess 20335
[2023-09-24 20:52:24,674] [INFO] [launch.py:315:sigkill_handler] Killing subprocess 20336
[2023-09-24 20:52:24,682] [INFO] [launch.py:315:sigkill_handler] Killing subprocess 20337
[2023-09-24 20:52:24,682] [INFO] [launch.py:315:sigkill_handler] Killing subprocess 20338
[2023-09-24 20:52:24,689] [ERROR] [launch.py:321:sigkill_handler] ['/home/ubuntu/anaconda3/envs/pytorch_p39/bin/python3.9', '-u', 'pretrain.py', '--local_rank=3', '--deepspeed', '/home/ubuntu/python_projects/TestGPT/src/config/ds_config.json'] exits with return code = 2
(pytorch_p39) ubuntu@ip-10-0-3-65:~/python_projects/TestGPT/src$

@stas00
Contributor

stas00 commented Sep 25, 2023

As mentioned earlier, please open a new Issue, and for all deepspeed integration-related issues please tag @pacman100, who is its current maintainer. Thank you!
