
FP16 overflow with GPT-Neo when using sequence lengths of 2048. #11076

Closed

LouisCastricato opened this issue Apr 6, 2021 · 62 comments

@LouisCastricato
Environment info

  • transformers version: 4.5.0.dev0
  • Platform: Linux-5.4.0-54-generic-x86_64-with-glibc2.29
  • Python version: 3.8.5
  • PyTorch version (GPU?): 1.8.0+cu111
  • Tensorflow version (GPU?): N/A
  • Using GPU in script?: Yes
  • Using distributed or parallel set-up in script?: No

Who can help

@stas00

Models:

  • GPT-Neo 1.3b

Library:

Information

Model I am using (Bert, XLNet ...): GPT-Neo 1.3B

The problem arises when using:

  • the official example scripts: (give details below)
  • my own modified scripts: (give details below)

The task I am working on is:

  • an official GLUE/SQuAD task: (give the name)
  • my own task or dataset: (give details below)

To reproduce

Steps to reproduce the behavior:

  1. Use GPT-Neo 1.3b with The Pile dataset and the built-in trainer. Artificial data also suffices. It does not matter what the data is, as long as the attention mask spans all 2048 tokens.
  2. Enable FP16 and set max_length to 2048.
  3. Observe that all losses reported are NaN.

Also reproducible using AMP or DeepSpeed. There appears to be code intended to circumvent this in the GPT-Neo implementation, where q, k, v are cast to fp32 in the attention block.

When the max_length is shorter (512) this overflow does not occur.
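A minimal sketch of the failing configuration (illustrative only, not our actual training code; assumes a single GPU with enough memory for the 1.3B checkpoint):

import torch
from transformers import GPTNeoForCausalLM

# random token ids with a full 2048-token attention mask, run through the fp16 model
model = GPTNeoForCausalLM.from_pretrained("EleutherAI/gpt-neo-1.3B").half().to("cuda")
input_ids = torch.randint(0, model.config.vocab_size, (1, 2048), device="cuda")
attention_mask = torch.ones_like(input_ids)

out = model(input_ids=input_ids, attention_mask=attention_mask, labels=input_ids)
print(out.loss)  # at this sequence length, losses like this reportedly come out as NaN under fp16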

Expected behavior

I expected no overflows.

Aside

I'm reaching out on behalf of EleutherAI, Lysandre told us to create an issue about this.

@stas00
Contributor

stas00 commented Apr 6, 2021

Thank you for the report, @LouisCastricato

I think it's pretty safe to take DeepSpeed out of the equation for now, since as you're saying the problem is due to mixed precision, so let's deal with AMP first.

How was GPT-Neo pre-trained?

Is it by chance the case that GPT-Neo was pre-trained with bfloat16 like t5/mt5 (#10956) or was it pre-trained in fp32?

@LouisCastricato
Author

The 1.3b model was pretrained on TPUs in mesh-tf using fp16.

@stas00
Contributor

stas00 commented Apr 6, 2021

You mean mixed precision fp16, correct?

As I haven't used mesh-tf, what would be the equivalent of this setup in PyTorch land? If we find the exact equivalent, and the model was ported correctly and is used under the same setup, this problem shouldn't exist. Does that make sense?

So let's find out what is different here (assuming the porting was done correctly).

@stas00
Contributor

stas00 commented Apr 6, 2021

OK, so bf16 and not fp16 - a very important difference. Thank you for this correction, @leogao2

I just wrote about it today: https://discuss.huggingface.co/t/mixed-precision-for-bfloat16-pretrained-models/5315

I will try to look at it tomorrow, this is probably the same story as t5/mt5 then.
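(A one-line illustration of the range difference, just for reference - bf16 keeps fp32's exponent range, while fp16 overflows past ~65504:)

import torch
print(torch.tensor(1e5).to(torch.bfloat16))  # tensor(99840., dtype=torch.bfloat16) - representable
print(torch.tensor(1e5).to(torch.float16))   # tensor(inf, dtype=torch.float16) - overflow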

@stas00
Contributor

stas00 commented Apr 6, 2021

It'd help to save time if you had a ready way to reproduce the problem, I tried:

export BS=1; rm -rf /tmp/test-clm; PYTHONPATH=src USE_TF=0 CUDA_VISIBLE_DEVICES=0 python examples/language-modeling/run_clm.py \
    --model_name_or_path EleutherAI/gpt-neo-1.3B \
    --dataset_name  wikitext \
    --dataset_config_name wikitext-2-raw-v1 \
    --do_train \
    --max_train_samples 1 \
    --per_device_train_batch_size $BS \
    --output_dir /tmp/test-clm \
    --block_size 128 \
    --logging_steps 1 

It hardly fits onto a 24GB card with a tiny block size, and fp16 OOMs right away.

I don't suppose you have a smaller model to experiment with?

Straightforward generate in full fp16 seems to work fine on a single sample to a full max_length, so this is good.

from transformers import GPTNeoForCausalLM, GPT2Tokenizer
model = GPTNeoForCausalLM.from_pretrained("EleutherAI/gpt-neo-1.3B")
tokenizer = GPT2Tokenizer.from_pretrained("EleutherAI/gpt-neo-1.3B")

prompt = "In a shocking finding, scientists discovered a herd of unicorns living in a remote, " \
         "previously unexplored valley, in the Andes Mountains. Even more surprising to the " \
         "researchers was the fact that the unicorns spoke perfect English."

inputs = tokenizer(prompt, return_tensors="pt")
input_ids = inputs.input_ids.to("cuda:0")
model = model.half().to("cuda:0")

gen_tokens = model.generate(input_ids, do_sample=True, temperature=0.9, max_length=2048)
gen_text = tokenizer.batch_decode(gen_tokens)[0]
print(gen_text)

Thanks.

@LouisCastricato
Author

We are working on producing a minimal example for you currently. After checking our internal documents we realized that 1.3b is bfp16 where as 2.7b is fp32

@LouisCastricato
Author

If you need an A100 to test on, let us know.

@LysandreJik
Member

LysandreJik commented Apr 6, 2021

Hi! As we're doing a few changes to the implementation to make it cleaner over in #10985, we ran a quick training to ensure that the model could still train.

We leveraged @Xirider's script detailed in https://github.com/Xirider/finetune-gpt2xl in order to fine-tune the 1.3B checkpoint, and we did see a decrease in the loss over this small sample:
[screenshot: training loss decreasing over the small sample]

We didn't investigate further, but this allows fine-tuning the 1.3B variant on a single V100 GPU.

cc @patil-suraj

@LouisCastricato
Author

That was sequence length 2048?

@patil-suraj
Contributor

That was sequence length 2048?

It's 1024 on wikitext

@LysandreJik
Member

Thanks for pointing that out, it was 1024. The tokenizer configurations on the Hub were misconfigured with model_max_length set to 1024; I've updated them to the correct 2048.

I added a --block_size=2048 parameter, see below the training loss:

[screenshot: training loss curve with --block_size 2048]

It is slightly higher, but isn't a NaN!

@LouisCastricato
Author

Hm... Maybe our project is just cursed then. Thanks for the pointer, I'll go through installations and see if anything is weird.

@LysandreJik
Member

I ran the fine-tuning on the recent branch so I thought this might be it; but I just tested on master and I don't get any NaNs either.
Don't hesitate to tell us if we can help further.

@stas00
Contributor

stas00 commented Apr 6, 2021

I'm running this on 24GB rtx-3090 and while it's not converging it's not getting NaNs:

git clone https://github.com/huggingface/transformers
cd transformers
git clone https://github.com/Xirider/finetune-gpt2xl
rm -rf output_dir; PYTHONPATH=src USE_TF=0 deepspeed --num_gpus=1 examples/language-modeling/run_clm.py \
--deepspeed finetune-gpt2xl/ds_config_gptneo.json \
--model_name_or_path EleutherAI/gpt-neo-1.3B \
--train_file finetune-gpt2xl/train.csv \
--validation_file finetune-gpt2xl/validation.csv \
--do_train \
--do_eval \
--fp16 \
--overwrite_cache \
--evaluation_strategy="steps" \
--output_dir output_dir \
--num_train_epochs 1 \
--eval_steps 15 \
--gradient_accumulation_steps 2 \
--per_device_train_batch_size 1 \
--use_fast_tokenizer False \
--learning_rate 5e-06 \
--warmup_steps 10 --logging_steps 5 --block_size 2048

@EricHallahan
Contributor

It looks like the version of DeepSpeed we are running (0.3.11) prevents us from running that example on our hardware. We are in the process of updating DeepSpeed to a newer version (>0.3.12) so that it is not caught by line 287 of integrations.py.

@stas00
Contributor

stas00 commented Apr 6, 2021

I'm able to reproduce loss=Nan while testing deepspeed zero-3 with this:

git clone https://github.com/huggingface/transformers
cd transformers
git clone https://github.com/Xirider/finetune-gpt2xl
# create finetune-gpt2xl/ds_config_gptneo_zero3.json as shown below
BS=2; rm -rf output_dir; PYTHONPATH=src USE_TF=0 deepspeed --num_gpus=2 examples/language-modeling/run_clm.py \
--deepspeed finetune-gpt2xl/ds_config_gptneo_zero3.json --model_name_or_path EleutherAI/gpt-neo-1.3B \
--train_file finetune-gpt2xl/train.csv --validation_file finetune-gpt2xl/validation.csv --do_train --do_eval --fp16 \
--overwrite_cache --evaluation_strategy="steps" --output_dir output_dir --num_train_epochs 1 --eval_steps 15 \
--gradient_accumulation_steps 2 --per_device_train_batch_size $BS --per_device_train_batch_size $BS \
--use_fast_tokenizer False --learning_rate 9e-06 --warmup_steps 100 --logging_steps 5 --block_size 1048

around step 24/174.

except I'm using 2 uncommitted branches mentioned in #11044

I will try to reduce it to something smaller.

P.S. For reproducibility purposes, here is the config I used: finetune-gpt2xl/ds_config_gptneo_zero3.json

{
    "fp16": {
        "enabled": true,
        "loss_scale": 0,
        "loss_scale_window": 1000,
        "initial_scale_power": 16,
        "hysteresis": 2,
        "min_loss_scale": 1
    },

    "zero_optimization": {
        "stage": 3,
        "cpu_offload": true,
        "cpu_offload_params": true,
        "cpu_offload_use_pin_memory" : true,
        "overlap_comm": true,
        "contiguous_gradients": true,
        "sub_group_size": 1e14,
        "reduce_bucket_size": 0,
        "stage3_prefetch_bucket_size": 0,
        "stage3_param_persistence_threshold": 0,
        "stage3_max_live_parameters": 1e9,
        "stage3_max_reuse_distance": 1e9,
        "stage3_gather_fp16_weights_on_model_save": true
    },

    "optimizer": {
        "type": "AdamW",
        "params": {
            "lr": 3e-5,
            "betas": [0.8, 0.999],
            "eps": 1e-8,
            "weight_decay": 3e-7
        }
    },

    "scheduler": {
        "type": "WarmupLR",
        "params": {
            "warmup_min_lr": 0,
            "warmup_max_lr": 3e-5,
            "warmup_num_steps": 500
        }
    },

    "steps_per_print": 2000,
    "wall_clock_breakdown": false
}

@LouisCastricato
Author

Thank you!

@LouisCastricato
Author

Does that happen with zero-2?

@stas00
Contributor

stas00 commented Apr 7, 2021

Oddly enough it's fine with zero-2 in this particular setup, but the configurations aren't the same so we aren't comparing the same things.

But also, if I do the same zero-3 training on one GPU, there's no NaN either.

But that doesn't matter, as long as we have a way to reproduce nans it's good enough to start working on understanding the cause and then fixing it.

@samyam from DeepSpeed suggested an idea to try, so I'm going to go back to the mt5 which gets a NaN on the very first step and experiment with it first, since it's much faster than dealing with step 21 of this really heavy model. And then if it works will come back to gpt-neo.

If meanwhile you find a much faster way to get to NaNs that would be helpful.

@stas00
Contributor

stas00 commented Apr 7, 2021

Also I don't know if this is somehow related, but this message looks alarming:

[WARNING|tokenization_utils_base.py:3136] 2021-04-07 00:05:14,268 >> Token indices sequence length is longer than the specified maximum sequence length for this model (1462828 > 2048). Running this sequence through the model will result in indexing errors

I think there is a bug somewhere, but it might be unrelated.

edit: I verified this is just a misplaced warning, not a problem

@StellaAthena
Contributor

@stas00 an update:

We were able to run your code on both 125M and 1.3B models without issue. The loss goes down, we get Shakespearean language, all is good.

Unfortunately, we cannot use your code for our task. We are seeking to train a dual-objective model on two completely different datasets, which we are mashing together and training via a contrastive loss. It appears that using the HF Trainer class makes that more or less impossible.

Is there a better way to do the pipelining, so we can evade whatever the bug we are running into is? We tried to run it on sequence length 1024, but it ended up eventually going to NaN anyway after a thousand steps or so.

@stas00
Contributor

stas00 commented Apr 7, 2021

... we can evade whatever the bug we are running into is?

The NaNs appearance in this particular situation is not caused by a bug in either transformers or deepspeed.

The model was trained in one numerical range and you're trying to run it in a different range that it wasn't trained for - there is not too much that can be done here.

It's the same problem for any bfloat16-pretrained model. Which includes t5/mt5/pegasus to name a few.

The fine-tuning/inference should be done in the same environment it was trained in or an environment that the model can numerically translate to. This is not the case with bf16 vs fp16 - please refer to my commentary at https://discuss.huggingface.co/t/mixed-precision-for-bfloat16-pretrained-models/5315

What we are trying to do now is to find a workaround that will not provide a full mixed precision regime, but a partial one. For that we need to find which operations are safe to run in fp16 and which aren't. And unfortunately as you can see some of these "runaways" happen after thousands of steps.

We were able to run your code on both 125M and 1.3B models without issue.

Oh, fantastic! Just discovered https://huggingface.co/EleutherAI/gpt-neo-125M after you mentioned it - it'd be much easier to debug with. Thank you for that!

Which "your code" are you referring to? Trainer + run_clm.py?

I hear you that the HF Trainer is not suitable for your task. But if you have your own Trainer that works, why not use that instead? In other words, how can we support you in this situation?

@stas00
Contributor

stas00 commented Apr 7, 2021

What I'm going to do next is:

  • Try to see if deepspeed can be run in fp32 - apparently nobody has tried this, since until now mixed fp16 just worked.

(This proved not to be possible at the moment.)

  • Try to use the suggestion from samyam to only do the fp16 matmul in the FF layers, with pre-scaling and post-unscaling, since that's where the bulk of the processing happens.
  • Try to find a short example I can reproduce GPT-Neo NaNs with, because if I have to wait 10+ minutes before it NaNs, it will be very difficult going.

@LouisCastricato
Author

The dual objective code we refer to can be found here: https://github.com/EleutherAI/visual-grounding

And ok sounds good. The offer for A100s still stands btw, fp32 might be a nightmare on an RTX 3090.

@stas00
Contributor

stas00 commented Apr 7, 2021

And ok sounds good. The offer for A100s still stands btw, fp32 might be a nightmare on an RTX 3090.

Thank you for your offer, @LouisCastricato - You're very likely correct - I may take you up on that offer at a later time. I'm not planning on finetuning gpt-neo in fp32 on rtx-3090, but just to test that deepspeed can even run in fp32 on a small model. Because if it works you could at least do that.

The dual objective code we refer to can be found here: https://github.com/EleutherAI/visual-grounding

Yes, but I'm not sure what to do with this information. My guess is that you developed your own trainer and you're trying to integrate deepspeed into it and are running into issues? What is it specifically that you need to move forward with your project or what is blocking you?

@LouisCastricato
Author

Oh, apologies. I shared it so that you could see the configuration we're using. I think I might have accidentally deleted that part though (big thumbs and touchscreens).

Yes, we're trying to integrate DeepSpeed directly with our training code. Both ds_config.json and amp_config.json produce the same NaN error, strictly on autoregressive batches - before the forward step. We have not seen the NaN error on the backward step.

We do not see it on the other component of our dual objective (in this case Google's WIT dataset), which has sequence lengths of at most 128 tokens. We see NaNs beginning to appear at sequence length 768, and once we get to 2048 every batch has NaNs.

@stas00
Contributor

stas00 commented Apr 7, 2021

Thank you for clarifying that, @LouisCastricato

Understood. I will have to get to know this model, as I have never worked with it. I will comment once I've had a chance to sit with it and install all kinds of debug hooks into it.

wrt your config, it looks good.

    "allgather_bucket_size": 50000000,
    "reduce_bucket_size": 50000000,

These might be too small for efficient operation. You want these to be in the 2e8 to 5e8 range, according to Samyam.

I also recommend you switch to the e-notation - it's too easy to miss a zero. In zero3 they have a param with 14 zeros!

You may want to enable cpu-offload if you have extra RAM.

Otherwise there isn't that much to configure in zero-2. There is a lot more to tune up in zero-3.

As I mentioned a few comments up, DeepSpeed makes efficient use of hardware, but if the model itself is the issue, there is not much that changing the DeepSpeed configuration can do.

@LouisCastricato
Author

Hi,

I was curious if there was any update on this?

@stas00
Contributor

stas00 commented Apr 9, 2021

I was busy working on the DeepSpeed ZeRO-3 integration with transformers release, so I haven't had a chance to research this issue yet.

If I knew it was a quick fix I'd have done it right away, but this kind of problem is a long process, so I need uninterrupted time to work on it. Moreover, fixing it in AMP won't necessarily fix it in DeepSpeed (but it'd surely help).

I started working on the checklist. I'm aware that this is a big problem for you, and I thought that perhaps at least you could run DeepSpeed in fp32, but, alas, currently that's not possible - you can disable fp16, but there are a lot of hardcoded half() calls in the ZeRO-3 code, so it basically ignores the fp16 setting at the moment and just does everything in fp16.

I doubt the DeepSpeed developers will take care of this any time soon, as they have no resources to do so. So if you want to help, one task that would move things forward a bit is making DeepSpeed work with fp32. The next stage would be to optimize the parts that can be done in fp16 without overflow, leaving most of it in fp32. Samyam suggested the matmuls in the FF layers would be the best part to do in fp16, as I mentioned a few comments earlier.

Just to give you an idea, the process goes like this: I find something that doesn't work or is missing for transformers' needs, I file an issue, nothing happens for a while, and since I need it for the integration I just go and implement it, with the DeepSpeed team's guidance. If we don't do the work it will eventually happen, but that "eventually" might be a very long time.

Let's hope they manage to expand their team with the recent job openings they posted and have more resources to support the projects that integrate their work.

I also asked them, and all the models they have been working with were trained in mixed fp16 precision, so they had no reason to sort out bfloat16 (yet).

So priorities-wise, will having full fp32-support be useful to you or not really?

@stas00
Contributor

stas00 commented Apr 16, 2021

Yes, but large logits are a potential symptom of what's going on in the network.

I've just created a new debug tool that helps diagnose the activation overflow issue. It's waiting for review to complete, but if you want to try it sooner, please grab this branch: #11274

and add --debug activation_overflow to the training command line. It will abort and dump the trace of the last 40 inputs/outputs of the forward calls preceding the inf/nan, which should hopefully give an indication of where the problem is.
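For the curious, the gist of the tool is roughly the following - a simplified sketch of the idea, not the actual implementation in that branch:

import torch

def attach_overflow_detector(model, last_n=40):
    # register a forward hook on every module: record the abs-max of its inputs/outputs
    # and abort as soon as an inf/nan shows up, dumping the most recent frames
    trace = []

    def hook(module, inputs, output):
        outputs = output if isinstance(output, (tuple, list)) else (output,)
        tensors = [t for t in list(inputs) + list(outputs)
                   if torch.is_tensor(t) and t.numel() > 0 and t.is_floating_point()]
        for t in tensors:
            trace.append((module.__class__.__name__, t.abs().max().item()))
        del trace[:-last_n]
        if any(not torch.isfinite(t).all() for t in tensors):
            for name, amax in trace:
                print(f"abs_max={amax:9.2e}  {name}")
            raise RuntimeError("inf/nan detected during forward")

    for module in model.modules():
        module.register_forward_hook(hook)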

@stas00
Contributor

stas00 commented Apr 17, 2021

I noticed a correction for that was already coded in the attention block (casting the query, key, value to fp32)

This could indicate an issue in the model design, I think. If the original model can't even do the math in bfloat16 and needs to go a "level up" not to overflow, that means the nearby activation values are huge. Of course, these same corrections were ported to the transformers version, except we are now running them under fp16. So I'd look really closely around those "hot" spots where the upcasting to fp32 was done and check that activations along that forward path aren't very big numbers.

Alternatively, perhaps those were high-precision numbers that bfloat16 couldn't handle; it'd be good to understand why the original was coded this way.

@LouisCastricato
Author

LouisCastricato commented Apr 17, 2021

last 40 frames:
abs_max= 1.55e+01 < [2] module.lm.transformer.h.0.attn.attention.k_proj: Linear: output
abs_max= 2.64e+00 > [2] module.lm.transformer.h.0.attn.attention.v_proj: Linear: input[0]
abs_max= 1.05e+01 < [2] module.lm.transformer.h.0.attn.attention.v_proj: Linear: output
abs_max= 1.00e+00 > [2] module.lm.transformer.h.0.attn.attention.attn_dropout: Dropout: input[0]
abs_max= 1.00e+00 < [2] module.lm.transformer.h.0.attn.attention.attn_dropout: Dropout: output
abs_max= 9.81e+00 > [2] module.lm.transformer.h.0.attn.attention.out_proj: Linear: input[0]
abs_max= 2.18e+02 < [2] module.lm.transformer.h.0.attn.attention.out_proj: Linear: output
abs_max= 2.18e+02 > [2] module.lm.transformer.h.0.attn.attention.resid_dropout: Dropout: input[0]
abs_max= 2.18e+02 < [2] module.lm.transformer.h.0.attn.attention.resid_dropout: Dropout: output
abs_max= 2.64e+00 > [2] module.lm.transformer.h.0.attn.attention: GPTNeoSelfAttention: input[0]
abs_max= 2.18e+02 < [2] module.lm.transformer.h.0.attn.attention: GPTNeoSelfAttention: output[0]
abs_max= 1.55e+01 < [2] module.lm.transformer.h.0.attn.attention: GPTNeoSelfAttention: output[1][0]
abs_max= 1.05e+01 < [2] module.lm.transformer.h.0.attn.attention: GPTNeoSelfAttention: output[1][1]
abs_max= 2.64e+00 > [2] module.lm.transformer.h.0.attn: GPTNeoAttention: input[0]
abs_max= 2.18e+02 < [2] module.lm.transformer.h.0.attn: GPTNeoAttention: output[0]
abs_max= 1.55e+01 < [2] module.lm.transformer.h.0.attn: GPTNeoAttention: output[1][0]
abs_max= 1.05e+01 < [2] module.lm.transformer.h.0.attn: GPTNeoAttention: output[1][1]
abs_max= 2.18e+02 > [2] module.lm.transformer.h.0.ln_2: LayerNorm: input[0]
abs_max= 2.34e+00 < [2] module.lm.transformer.h.0.ln_2: LayerNorm: output
abs_max= 2.34e+00 > [2] module.lm.transformer.h.0.mlp.c_fc: Linear: input[0]
abs_max= 9.97e+00 < [2] module.lm.transformer.h.0.mlp.c_fc: Linear: output
abs_max= 9.63e+00 > [2] module.lm.transformer.h.0.mlp.c_proj: Linear: input[0]
abs_max= 2.16e+02 < [2] module.lm.transformer.h.0.mlp.c_proj: Linear: output
abs_max= 2.16e+02 > [2] module.lm.transformer.h.0.mlp.dropout: Dropout: input[0]
abs_max= 2.16e+02 < [2] module.lm.transformer.h.0.mlp.dropout: Dropout: output
abs_max= 2.34e+00 > [2] module.lm.transformer.h.0.mlp: GPTNeoMLP: input[0]
abs_max= 2.16e+02 < [2] module.lm.transformer.h.0.mlp: GPTNeoMLP: output
abs_max= 9.73e-01 > [2] module.lm.transformer.h.0: GPTNeoBlock: input[0]
abs_max= 3.94e+02 < [2] module.lm.transformer.h.0: GPTNeoBlock: output[0]
abs_max= 1.55e+01 < [2] module.lm.transformer.h.0: GPTNeoBlock: output[1][0]
abs_max= 1.05e+01 < [2] module.lm.transformer.h.0: GPTNeoBlock: output[1][1]
abs_max= 3.94e+02 > [2] module.lm.transformer.h.1.ln_1: LayerNorm: input[0]
abs_max= 1.86e+00 < [2] module.lm.transformer.h.1.ln_1: LayerNorm: output
abs_max= 1.86e+00 > [2] module.lm.transformer.h.1.attn.attention.q_proj: Linear: input[0]
abs_max= 1.78e+00 < [2] module.lm.transformer.h.1.attn.attention.q_proj: Linear: output
abs_max= 1.86e+00 > [2] module.lm.transformer.h.1.attn.attention.k_proj: Linear: input[0]
abs_max= 3.69e+00 < [2] module.lm.transformer.h.1.attn.attention.k_proj: Linear: output
abs_max= 1.86e+00 > [2] module.lm.transformer.h.1.attn.attention.v_proj: Linear: input[0]
abs_max= 3.21e+00 < [2] module.lm.transformer.h.1.attn.attention.v_proj: Linear: output
abs_max=      nan > [2] module.lm.transformer.h.1.attn.attention.attn_dropout: Dropout: input[0]

@LouisCastricato
Author

Let me know if I can provide any more traces for you.

@stas00
Contributor

stas00 commented Apr 18, 2021

So this looks more like an underflow rather than an overflow, as the activations are tiny and you got NaN rather than inf - and already at batch 2! So this is totally different from the t5/mt5 problem.

In this case I will modify the code to print abs_min as well - we are probably going to see tiny-tiny numbers there.

How do I reproduce this?

@LouisCastricato
Author

We are trying to find a minimal example for you that can be run on a 3090.

@LouisCastricato
Author

LouisCastricato commented Apr 19, 2021

The lowest we could make the memory requirement was 32GB. We sent you a login for an instance with 6x A100s.

The command you need to run, under ~/visual-grounding/Training/ is

deepspeed --num_gpus 1 distill.py --deepspeed_config ds_config.json --debug

It should output NaN information after the first batch. It is a (semi) minimal example that uses a custom AR trainer, but it crashes before the first optimizer step. The code is (relatively) easy to follow without reading through any of the custom data loaders. We've already confirmed it works with GPT2-XL.

Transformers was not installed from source as editable, but I assume you'll want to use a custom branch for this, so I just installed it from PyPI for you.

@stas00
Contributor

stas00 commented Apr 19, 2021

Thank you, @LouisCastricato!

I needed to install my own branch, but I was able to reproduce with the updated detector, which now gives a much better picture. So with your custom code I'm getting:

Detected inf/nan during batch_number=2
last 21 frames:
abs min  abs max  metadata
                  module.lm.transformer.drop Dropout
0.00e+00 9.73e-01 input[0]
0.00e+00 9.73e-01 output
                  module.lm.transformer.h.0.ln_1 LayerNorm
1.05e-01 8.80e-01 weight
6.50e-06 5.97e-01 bias
0.00e+00 9.73e-01 input[0]
5.96e-08 3.12e+00 output
                  module.lm.transformer.h.0.attn.attention.q_proj Linear
0.00e+00 2.51e-01 weight
5.96e-08 3.12e+00 input[0]
3.58e-07 1.25e+01 output
                  module.lm.transformer.h.0.attn.attention.k_proj Linear
0.00e+00 2.60e-01 weight
5.96e-08 3.12e+00 input[0]
5.96e-07 1.59e+01 output
                  module.lm.transformer.h.0.attn.attention.v_proj Linear
0.00e+00 2.82e-01 weight
5.96e-08 3.12e+00 input[0]
1.79e-07 1.05e+01 output
                  module.lm.transformer.h.0.attn.attention.attn_dropout Dropout
0.00e+00 1.00e+00 input[0]
0.00e+00 1.00e+00 output
                  module.lm.transformer.h.0.attn.attention.out_proj Linear
0.00e+00 1.61e+00 weight
4.35e-06 1.80e+00 bias
1.79e-07 9.81e+00 input[0]
0.00e+00 2.13e+02 output
                  module.lm.transformer.h.0.attn.attention.resid_dropout Dropout
0.00e+00 2.13e+02 input[0]
0.00e+00 2.13e+02 output
                  module.lm.transformer.h.0.attn.attention GPTNeoSelfAttention
5.96e-08 3.12e+00 input[0]
0.00e+00 2.13e+02 output[0]
5.96e-07 1.59e+01 output[1][0]
1.79e-07 1.05e+01 output[1][1]
                  module.lm.transformer.h.0.attn GPTNeoAttention
5.96e-08 3.12e+00 input[0]
0.00e+00 2.13e+02 output[0]
5.96e-07 1.59e+01 output[1][0]
1.79e-07 1.05e+01 output[1][1]
                  module.lm.transformer.h.0.ln_2 LayerNorm
3.46e-02 1.28e+00 weight
1.33e-05 9.21e-01 bias
0.00e+00 2.14e+02 input[0]
4.17e-07 2.34e+00 output
                  module.lm.transformer.h.0.mlp.c_fc Linear
0.00e+00 4.42e-01 weight
5.96e-08 2.65e-01 bias
4.17e-07 2.34e+00 input[0]
0.00e+00 9.97e+00 output
                  module.lm.transformer.h.0.mlp.c_proj Linear
0.00e+00 1.20e+00 weight
1.29e-05 1.27e+00 bias
0.00e+00 9.63e+00 input[0]
0.00e+00 2.16e+02 output
                  module.lm.transformer.h.0.mlp.dropout Dropout
0.00e+00 2.16e+02 input[0]
0.00e+00 2.16e+02 output
                  module.lm.transformer.h.0.mlp GPTNeoMLP
4.17e-07 2.34e+00 input[0]
0.00e+00 2.16e+02 output
                  module.lm.transformer.h.0 GPTNeoBlock
0.00e+00 9.73e-01 input[0]
0.00e+00 3.94e+02 output[0]
5.96e-07 1.59e+01 output[1][0]
1.79e-07 1.05e+01 output[1][1]
                  module.lm.transformer.h.1.ln_1 LayerNorm
5.66e-02 1.36e+00 weight
4.65e-06 9.51e-01 bias
0.00e+00 3.94e+02 input[0]
0.00e+00 1.86e+00 output
                  module.lm.transformer.h.1.attn.attention.q_proj Linear
0.00e+00 5.24e-01 weight
0.00e+00 1.86e+00 input[0]
0.00e+00 1.78e+00 output
                  module.lm.transformer.h.1.attn.attention.k_proj Linear
0.00e+00 8.80e-01 weight
0.00e+00 1.86e+00 input[0]
1.19e-07 3.69e+00 output
                  module.lm.transformer.h.1.attn.attention.v_proj Linear
0.00e+00 7.60e-01 weight
0.00e+00 1.86e+00 input[0]
0.00e+00 3.21e+00 output
                  module.lm.transformer.h.1.attn.attention.attn_dropout Dropout
     nan      nan input[0]
     nan      nan output

@stas00
Contributor

stas00 commented Apr 20, 2021

I ran the detector under fp32 and under the deepspeed/fp16 mode, and as I suspected we are having an underflow here - a serious underflow.

Attached 2 traces:

frames_overflow.txt
frames_normal.txt

From the very first forward pass, we have a problem with the embeddings:

fp32:

                  lm.transformer.wte Embedding
2.32e-11 1.02e+00 weight
0.00e+00 5.00e+04 input[0]
1.64e-07 6.96e-01 output

fp16:

                  module.lm.transformer.wte Embedding
0.00e+00 1.02e+00 weight
0.00e+00 5.00e+04 input[0]
1.79e-07 6.96e-01 output

As you can see, some weights immediately become 0.0. Of course, it only gets worse as things progress.

As shown here: https://github.com/stas00/ml-ways/blob/master/numbers/bfloat16-vs-float16-study.ipynb fp16 can barely handle 1e-08, and we have 1e-11 here.

torch.tensor(2.32e-11).to(dtype=torch.bfloat16)
# tensor(2.3192e-11, dtype=torch.bfloat16)
torch.tensor(2.32e-11).to(dtype=torch.float16)
# tensor(0., dtype=torch.float16)

DeepSpeed runs in model.half() mode; that's why this happens from the get-go.
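A quick way to see how widespread this is (just a diagnostic I'm sketching here, not part of the tool) is to count the weights that a plain .half() would flush to zero:

import torch
from transformers import GPTNeoForCausalLM

model = GPTNeoForCausalLM.from_pretrained("EleutherAI/gpt-neo-1.3B")
fp16_tiny = 2.0 ** -24  # ~5.96e-08, the smallest positive fp16 subnormal
for name, p in model.named_parameters():
    flushed = ((p != 0) & (p.abs() < fp16_tiny)).sum().item()
    if flushed:
        print(f"{name}: {flushed} weights would underflow to 0 under .half()")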

@stas00
Contributor

stas00 commented Apr 20, 2021

OK, trying to force fp32 mode in deepspeed by editing its engine to skip model.half(),

--- a/deepspeed/runtime/engine.py
+++ b/deepspeed/runtime/engine.py
@@ -572,7 +572,7 @@ class DeepSpeedEngine(Module):
     def _configure_distributed_model(self, model):
         self.module = model
         if self.fp16_enabled():
-            self.module.half()
+            pass

         if not self.dont_change_device:
             self.module.to(self.device)

Your distill.py script doesn't print any stats, so it's hard to tell if it's doing well, but the training is chugging along.

I stopped it after 5482 steps (2h).

@LouisCastricato
Author

Yeah the minimal example removes all evaluation. In FP32 it does work though. I tested a checkpoint the other day.

@stas00
Contributor

stas00 commented Apr 20, 2021

I'm asking Deepspeed devs if they have some ideas on how to overcome this, I will keep you posted if we find a good intermediary solution.

But at the very least we now know why the model fails under fp16.

I wonder if pre-training processes targeted for mixed precision use should have a loss penalty component that forces the model to remain within fp16 dynamic range, both upper and lower.
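Something along these lines, purely hypothetical - an auxiliary term that penalizes activations straying outside fp16's comfortable range:

import torch

FP16_MAX = 65504.0   # largest finite fp16 value
FP16_TINY = 6.1e-5   # smallest normal fp16 value

def fp16_range_penalty(activations, weight=1e-4):
    a = activations.abs()
    over = (a - FP16_MAX).clamp(min=0.0)              # discourage values that would overflow
    under = (FP16_TINY - a).clamp(min=0.0) * (a > 0)  # discourage tiny non-zero values that would underflow
    return weight * (over.mean() + under.mean())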

@LouisCastricato
Author

microsoft/DeepSpeed#974 (comment)

This could be relevant.

@stas00
Contributor

stas00 commented Apr 20, 2021

OP is asking to support bf16 training, but you're asking for fp16 training. These two are significantly different issues.

It'd be awesome for deepspeed to support bf16, but this is not going to help users w/o hardware natively supporting bf16.

@LouisCastricato
Author

OP is asking to support bf16 training, but you're asking for fp16 training. These two are significantly different issues.

It'd be awesome for deepspeed to support bf16, but this is not going to help users w/o hardware natively supporting bf16.

I meant the changes they recommended making could also help resolve our FP16 issues. They outlined what would need to be changed for bf16

@stas00
Contributor

stas00 commented Apr 20, 2021

Currently the very first problem is deepspeed calling model.half(), which leads to immediate underflow in model weights. As I have shown above:

torch.tensor(2.32e-11).to(dtype=torch.float16)
# tensor(0., dtype=torch.float16)

Therefore I can't see how any of the suggestions directed to support bf16 training would help in this case.

Chances are that DeepSpeed will need a new mode which is not all-fp16: one that only does the fp16 conversion when it's safe to do so, and scales the weights and activations up/down when they are in a range that's unsafe for fp16. So it won't be as slow or memory-demanding as a full fp32 mode, but it won't be normal fp16 mixed precision either.
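Conceptually something like this - my rough sketch, not DeepSpeed code - where the expensive matmul runs in fp16 on rescaled operands and the result is unscaled back in fp32:

import torch

def scaled_fp16_matmul(a, b):
    # scale operands into a safe fp16 range, multiply in fp16, unscale the fp32 result
    a_scale = a.abs().max().clamp(min=1e-8)
    b_scale = b.abs().max().clamp(min=1e-8)
    out = (a / a_scale).half() @ (b / b_scale).half()
    return out.float() * (a_scale * b_scale)

# fp16 matmul is a GPU-side optimization, so this assumes a CUDA device
a = torch.randn(512, 1024, device="cuda") * 1e-9     # well below fp16's tiniest value (~6e-8)
w = torch.randn(1024, 1024, device="cuda") * 1e-3
print((a.half() @ w.half()).abs().max())     # 0. - the straight cast flushes `a` to zero
print(scaled_fp16_matmul(a, w).abs().max())  # non-zero, close to the fp32 result of a @ w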

@stas00
Contributor

stas00 commented Apr 21, 2021

OK, please have a look at the current setup on your instance, try:

PYTHONPATH=~/DeepSpeed deepspeed --num_gpus 1 distill.py --deepspeed_config ds_config_zero3.json --debug

~/DeepSpeed currently contains an experimental branch by @samyam https://github.com/microsoft/DeepSpeed/tree/samyamr/full-precision-for-stage3
who created a semi-fp32 DeepSpeed mode that, according to him, should be only 2x slower than normal fp16 mixed precision for bfloat16-pretrained models, but much faster than fp32.

It also currently requires a hardcoded change:

diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 82a0a9917..9a23bc55b 100755
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -1085,7 +1085,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):

             logger.info("Detected DeepSpeed ZeRO-3: activating zero.init() for this model")
             # this immediately partitions the model to avoid the overhead in time and memory copying it on CPU or each GPU first
-            with deepspeed.zero.Init():
+            with deepspeed.zero.Init(dtype=torch.float):
                 model = cls(config, *model_args, **model_kwargs)
         else:
             model = cls(config, *model_args, **model_kwargs)

which is already applied under ~/transformers-stas/, and fp16 is set to false in the ds config files, i.e. it is all already set up for you.

so this is zero3. zero2 still needs some work.

@stas00
Contributor

stas00 commented Apr 21, 2021

But looking closer at your code, I see now that we have been trying to solve the wrong problem all along.

Why is your code using "EleutherAI/gpt-neo-2.7B", which one of you said earlier was pre-trained in full fp32? How could you possibly expect it to train or eval in fp16? Or did you just want deepspeed in fp32 mode? Please clarify.

One of you said it's 1.3B checkpoint that was trained in bf16.

@stas00
Contributor

stas00 commented Apr 21, 2021

OK, zero2 now works too.

PYTHONPATH=~/DeepSpeed deepspeed --num_gpus 1 distill.py --deepspeed_config ds_config.json --debug

So Samyam explained that this new deepspeed branch enables full FP32 mode.

But since your setup is running on A100s, PyTorch uses TF32, so you're getting speed equivalent to fp16 on a V100.

RTX-3090 should also be able to get this performance.

All kudos go to @samyam.
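For reference, a quick way to check the TF32 setting (note: later PyTorch releases changed the matmul default):

import torch
# On Ampere (A100 / RTX-3090), PyTorch 1.8 routes fp32 matmuls through TF32 tensor cores
# by default, which is why "full fp32" here is much cheaper than classic fp32.
print(torch.backends.cuda.matmul.allow_tf32)  # True by default in this PyTorch version
print(torch.backends.cudnn.allow_tf32)        # True by default as well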

@leogao2
Contributor

leogao2 commented Apr 21, 2021

But looking closer at your code, I see now that we have been trying to solve the wrong problem all along.

Why is your code using "EleutherAI/gpt-neo-2.7B", which one of you said earlier was pre-trained in full fp32? How could you possibly expect it to train or eval in fp16? Or did you just want deepspeed in fp32 mode? Please clarify.

One of you said it's 1.3B checkpoint that was trained in bf16.

We've been having the nan issue with both the bf16 1.3B checkpoint and the fp32 2.7B checkpoint; we were under the assumption that as both have the same dynamic range, both would have the same under/overflow problems. I'm also pretty sure that the bf16 1.3B checkpoint was trained with bf16 activations with fp32 master weights quantized to bf16 (the quantization was a mistake by one of our devs).

Our main problem is that with fp32, 1.3B, and no deepspeed, we can't even fit a single full batch without OOM, and we can't turn on any deepspeed optimizations without fp16 being on (interestingly, it seems the OOM doesn't happen with Samyam's branch). Of course, we would like to train our model using mixed-precision (using fp32 for the parts that are underflowing) for the obvious memory savings, so we thought it would be much easier to just make our model work with mixed-precision and also get those memory savings than to make deepspeed work with fp32. We would also be fine with making deepspeed work with fp32 or bf16 if it's significantly easier.

Thanks for all your time in helping us with this issue.

@stas00
Contributor

stas00 commented Apr 21, 2021

In general, if you want users to be able to use fp16 mixed precision for fine-tuning and inference, you need to pre-train the model in this mode. For some models we find workarounds that locally switch to fp32 for the specific submodules that lead to underflow/overflow under fp16, but often users still get NaNs during long training.

Bottom line: if you pre-train in bf16, be prepared to tell users to use fp32 or bf16 in their fine-tuning/inference processes. As new hardware supporting the bf16/tf32 formats emerges (RTX-3090 and A100), this will become the simple go-to solution in the future.

Now that deepspeed will have a full-fp32 mode this is great.

So to summarize, at this moment with Samyam's branch if you use:

  • zero2: you just need to set fp16.enabled=false in the ds config
  • zero3, same as above, plus zero.Init(dtype=torch.float) is needed in modeling_utils.py (instead of just zero.Init()) - I need to think how to make that configurable.

@LouisCastricato
Author

How would one use this special fp32 mode without zero?

@stas00
Contributor

stas00 commented Apr 23, 2021

You mean w/o deepspeed (or fairscale)?

Just don't enable mixed precision in the training. i.e. in transformers don't use --fp16 in train and don't use --fp16_full_eval in eval.

Unless you're asking how to use DeepSpeed without ZeRO - why would you want to do that? ZeRO is the core of DeepSpeed, and if you are not using it, you don't really need DeepSpeed.

If I misunderstood your question please clarify.

@github-actions

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@stas00
Contributor

stas00 commented Dec 24, 2021

@LouisCastricato

After checking our internal documents we realized that 1.3b is bfp16 where as 2.7b is fp32

You wrote: bfp16

Did you mean to write fp16 or bf16?

According to the detector tool I'm working on it is most likely fp16. It'd be super helpful if you could check on how it was trained. Thank you!

If you have other published model checkpoints and their dtype that would be very helpful too, as I'm trying to gather that information.

@stas00
Contributor

stas00 commented Feb 11, 2022

Talked to Stella and she confirmed Louis meant to write bf16 for 1.3B model.
