Enabling deepspeed.zero.Init causes very strange spikes in PPO policy_loss #4932

Open
wuxibin89 opened this issue Jan 11, 2024 · 68 comments

@wuxibin89

Hi, we're the OpenRLHF team. We heavily use DeepSpeed to build our RLHF framework and really appreciate your great work.

Recently we encountered a problem with deepspeed.zero.Init. We want to enable HfDeepSpeedConfig to support larger models (70B+) with ZeRO-3. But after enabling HfDeepSpeedConfig, we see some very strange spikes in the PPO policy_loss. If HfDeepSpeedConfig is disabled, the policy_loss is normal and matches ZeRO-2.
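
For reference, enabling HfDeepSpeedConfig is roughly wired up like the sketch below (illustrative only: ds_config is a ZeRO-3 config dict and the model path is a placeholder). The HfDeepSpeedConfig object has to be created before from_pretrained and kept alive; that is what makes transformers wrap model construction in deepspeed.zero.Init.

    # Illustrative sketch, not our full training code: ds_config and args.pretrain
    # are placeholders for a ZeRO-3 config dict and a model path.
    import torch
    from transformers import AutoModelForCausalLM
    from transformers.integrations import HfDeepSpeedConfig

    dschf = HfDeepSpeedConfig(ds_config)  # must be created first and kept alive
    model = AutoModelForCausalLM.from_pretrained(args.pretrain, torch_dtype=torch.bfloat16)
    # from_pretrained now runs under deepspeed.zero.Init, so the weights are sharded
    # across GPUs at load time instead of being fully materialized on CPU.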

Below are the wandb metrics from Llama-7B PPO training:

  • red([ppo_0104T10:37]): ZeRO-3 w/o HfDeepSpeedConfig
  • purple([ppo_0105T13:01]): ZeRO-3 w/ HfDeepSpeedConfig
[wandb screenshots of the policy_loss curves]

The only difference is whether HfDeepSpeedConfig is enabled or not. The reproduction script is below:
https://github.com/OpenLLMAI/OpenRLHF/blob/main/examples/scripts/train_ppo_llama_ray_70b.sh

BTW, I debugged the parameters' sharded tensors via parameter.ds_tensor, and they're the same whether HfDeepSpeedConfig is enabled or not! So we really don't know why. Can you offer some clues so we can keep debugging?

cc @stas00

@hijkzzz

hijkzzz commented Jan 11, 2024

We use these two models to reproduce it:

    --pretrain OpenLLMAI/Llama-2-7b-sft-model-ocra-500k \
    --reward_pretrain OpenLLMAI/Llama-2-7b-rm-anthropic_hh-lmsys-oasst-webgpt \

We found that this problem appeared in the middle of each training epoch of the actor in PPO (multiple mini-batches). The specific cause is that the ratio P_new / P_old for some samples is too large, which reflects instability in the training.

@stas00
Contributor

stas00 commented Jan 13, 2024

OK, the whole area of having more than one model is new and not well tested. When I designed the HfDeepSpeedConfig side that controls whether we wrap zero.Init around model loading in HF Transformers, it was always with just one model in mind, as DeepSpeed didn't support multiple models until around last summer. And back then it was clear that if you wanted to use stage 3 you'd want zero.Init always on.

The HF Accelerate/DeepSpeed integration, which was done relatively recently, already includes a flag to turn zero.Init on/off. It shouldn't be complicated to extend HfDeepSpeedConfig with an optional flag for whether zero.Init should be on or off, but it'd also require tweaks to HF Transformers' modeling_utils.py: currently it just checks is_deepspeed_zero3_enabled, but it would need a new helper, is_deepspeed_zero_init_enabled, which doesn't currently exist.

You might be able to hack this function to always return False, which would disable zero.Init:
https://github.com/huggingface/transformers/blob/29a2b1420633d322140062d7c76b807f41fb90aa/src/transformers/integrations/deepspeed.py#L261
I don't remember whether it's used elsewhere, but you can also just hack https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_utils.py to replace it with False wherever it's called, just in that file.
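
An untested sketch of that monkey-patch (depending on the transformers version, other modules that import the helper separately would need the same treatment):

    import transformers.modeling_utils as modeling_utils

    # Override the name as it is bound inside modeling_utils.py so that
    # from_pretrained skips the zero.Init code path in that file.
    modeling_utils.is_deepspeed_zero3_enabled = lambda: False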

HF Accelerate, for example, still doesn't support multiple DeepSpeed models, and I'm not sure about other frameworks. So when we needed more than one model, we used the workaround of wrapping 2 models into a super-model and using the super-model's forward to do this or that on each model. This works well if it fits your training mode.

So chances are you're running into this problem because you have one DS engine and two models, and I don't think that can work, or at least it shouldn't work.

So you can do 2 things:

  1. Try the super-model approach if it fits the workflow (see the sketch after this list)
  2. If you need 2 separate engines, do not rely on HfDeepSpeedConfig; manually add zero.Init around each model, instantiate 2 deepspeed engines (one for each model), and drive each of them separately.
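
A minimal sketch of the super-model idea from option 1 (the actor/critic names and the forward logic are illustrative, not from any particular framework):

    import torch.nn as nn

    class SuperModel(nn.Module):
        """Wrap two sub-models so a single DeepSpeed engine owns both parameter sets."""

        def __init__(self, actor, critic):
            super().__init__()
            self.actor = actor
            self.critic = critic

        def forward(self, input_ids, attention_mask=None):
            # Drive each sub-model from the one forward and return what the loss needs.
            actor_out = self.actor(input_ids, attention_mask=attention_mask)
            critic_out = self.critic(input_ids, attention_mask=attention_mask)
            return actor_out.logits, critic_out.logits

    # engine, *_ = deepspeed.initialize(model=SuperModel(actor, critic), config=ds_config)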

Please let me know if it's helpful.

@wuxibin89
Author

@stas00 Thanks very much for your reply! Actually, we already separate the actor/critic/reward/reference models into different DeepSpeed engines. We use the distributed framework Ray to launch multiple torch distributed groups, and each group only loads a single model with ZeRO-2/3.

The overall architecture is shown below:
[architecture diagram: RLHF-DP-TP-vLLM]

@wuxibin89
Author

We also tried HfDeepSpeedConfig with ZeRO-3 to train the SFT and reward models; all metrics are consistent with ZeRO-2, so we expect that HfDeepSpeedConfig with ZeRO-3 works correctly there.

@stas00
Contributor

stas00 commented Jan 17, 2024

So what are you using HfDeepSpeedConfig for?

Do things work correctly if you drop it altogether? Its only uses are to:

  1. help HF Trainer or Accelerate set up the deepspeed integration
  2. turn on the zero.Init context during the from_pretrained call

If you use neither HF Trainer nor Accelerate, then (1) is irrelevant and it's only an issue of zero.Init, which you can turn on by yourself inside modeling_utils.py, or not use at all if you don't need the pre-sharding at startup, which shouldn't be needed if you have a lot of CPU memory.

@wuxibin89
Author

We want to use HfDeepSpeedConfig to load models >70B (e.g. Llama2-70B) to GPU directly. And yes, we use neither HF Trainer nor Accelerate, so I agree that it's only an issue of zero.Init.

@stas00
Contributor

stas00 commented Jan 17, 2024

Thank you for clarifying, @wuxibin89

Now that we've narrowed it down to just zero.Init, it's no longer an issue of you using HfDeepSpeedConfig, since the way you use it just triggers zero.Init on all models.

So tagging @tjruwase to take over.

@tjruwase
Contributor

Below are the wandb metrics from Llama-7B PPO training

  • red([ppo_0104T10:37]): ZeRO-3 w/o HfDeepSpeedConfig
  • purple([ppo_0105T13:01]): ZeRO-3 w/ HfDeepSpeedConfig

@wuxibin89, can you clarify what you mean by ZeRO-3 w/o HfDeepSpeedConfig? Do you mean ZeRO-3 w/o zero.Init?

@wuxibin89
Author

wuxibin89 commented Jan 18, 2024

@tjruwase @stas00 I created a reproduction script here with our very first micro-batch data.
https://gist.github.com/wuxibin89/7d4801d62a743d5bbc72b507e8f0e326

@hijkzzz

hijkzzz commented Jan 18, 2024

@tjruwase @stas00 I created a reproduction script here with our very first micro-batch data. https://gist.github.com/wuxibin89/7d4801d62a743d5bbc72b507e8f0e326

We can see a significant difference in the log-probs with .max()/.min():

>>> import torch
>>> s0 = torch.load("result_False.pt", map_location="cpu")
>>> s1 = torch.load("result_True.pt", map_location="cpu")
>>> (s0["log_probs"] - s1["log_probs"]).sum()
tensor(51.7910)
>>> (s0["log_probs"] - s1["log_probs"]).mean()
tensor(0.0020)
>>>
>>>
>>>
>>> (s0["log_probs"] - s1["log_probs"]).max()
tensor(4.3943)
>>> (s0["log_probs"] - s1["log_probs"]).min()
tensor(-2.8594)

@stas00
Contributor

stas00 commented Jan 19, 2024

Thank you for making a repro script that's super easy to set up and use, @hijkzzz and @wuxibin89.

So, first, I verified that the issue has nothing to do with HfDeepSpeedConfig: I can repro the problem by removing it completely and just wrapping from_pretrained in zero.Init:

    kwargs = dict(
        trust_remote_code=True,
        attn_implementation="flash_attention_2",
        torch_dtype=torch.bfloat16,
    )

    if args.enable_hf_deespeed and ds_config["zero_optimization"]["stage"] == 3:
        print("==========enable zero Init ============")
        with deepspeed.zero.Init(config_dict_or_path=ds_config):
            model = AutoModelForCausalLM.from_pretrained(args.pretrain, **kwargs)
    else:
        model = AutoModelForCausalLM.from_pretrained(args.pretrain,**kwargs)

I also reduced it to just one sample and a sub-section of it:

    start, end = (600, 900)
    input_ids = input_ids[0][start:end].unsqueeze(0)
    attention_mask = attention_mask[0][start:end].unsqueeze(0)

This already shows a largish diff, and it doesn't fail if I make the slice smaller, like 200-250 tokens. It seems to start diverging at ~300 tokens.

Finally, I validated that it has to do with half-precision. If I switch to fp32 in the ds config, the problem goes away (a tiny difference of 1e-5). Both bf16 and fp16 manifest this.

I thought that perhaps it might have had to do with RNG, but resetting the seed after model loading didn't seem to make a difference.

Oh, and I had to hack transformers to deal with the removal of HfDeepSpeedConfig, otherwise it would fail to load the weights:

diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 9c4639b47..ce3a3e1c4 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -605,7 +605,7 @@ def _load_state_dict_into_model(model_to_load, state_dict, start_prefix):
         # Parameters of module and children will start with prefix. We can exit early if there are none in this
         # state_dict
         if len([key for key in state_dict if key.startswith(prefix)]) > 0:
-            if is_deepspeed_zero3_enabled():
+            if 1 or is_deepspeed_zero3_enabled():
                 import deepspeed

                 # In sharded models, each shard has only part of the full state_dict, so only gather

So it has something to do with some precision issue that gets aggravated with more tokens.

@wuxibin89
Author

@stas00 Thanks for your help! I also tried removing DeepSpeed completely and just using transformers alone; the result is the same as DeepSpeed ZeRO-3 without zero.Init. So I guess the problem is not ZeRO-3, but how zero.Init initializes the model.

@stas00
Contributor

stas00 commented Jan 19, 2024

I tested that the state_dict is identical in both cases:

    # disable HfDeepSpeedConfig
    # dschf = HfDeepSpeedConfig(ds_config)

    kwargs = dict(
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
    )

    model_a = AutoModelForCausalLM.from_pretrained(args.pretrain, **kwargs)
    sd_a = model_a.state_dict()
    del model_a

    with deepspeed.zero.Init(config_dict_or_path=ds_config):
        model_b = AutoModelForCausalLM.from_pretrained(args.pretrain, **kwargs)

    with deepspeed.zero.GatheredParameters(model_b.parameters(), modifier_rank=None):
        sd_b = {k: v.cpu() for k, v in model_b.state_dict().items()}

    torch.testing.assert_close(sd_a, sd_b, rtol=0.0, atol=0.0)

This would have failed if they were different.

We need to check if there are perhaps some non-parameter buffers that get missed or are invalid.

@stas00
Contributor

stas00 commented Jan 19, 2024

It's the buffers! Llama-2 has 3 buffers per layer, e.g. for layer 28:

model.layers.28.self_attn.rotary_emb.inv_freq
model.layers.28.self_attn.rotary_emb.cos_cached
model.layers.28.self_attn.rotary_emb.sin_cached

Here is the proof:

    kwargs = dict(
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
    )

    model_a = AutoModelForCausalLM.from_pretrained(args.pretrain, **kwargs)
    buffers_a = dict(model_a.named_buffers())
    del model_a

    with deepspeed.zero.Init(config_dict_or_path=ds_config):
        model_b = AutoModelForCausalLM.from_pretrained(args.pretrain, **kwargs)
    buffers_b = {k: v.cpu() for k, v in dict(model_b.named_buffers()).items()}

    torch.testing.assert_close(buffers_a, buffers_b, rtol=0.0, atol=0.0)

with bf16 there is a 63% difference:

Mismatched elements: 329332 / 524288 (62.8%)
Greatest absolute difference: 2.0 at index (1148, 2)
Greatest relative difference: 8320.0 at index (3657, 11)

with fp32 there is almost no difference, but it's not 100% exact!

Mismatched elements: 88 / 524288 (0.0%)
Greatest absolute difference: 0.00390625 at index (2308, 47)
Greatest relative difference: 0.00762939453125 at index (4060, 22)

@tjruwase, now that we know the source of the problem, passing this to you now.

It must be something about how buffers are treated differently than params under zero.Init

@stas00
Contributor

stas00 commented Jan 19, 2024

I wanted to see the actual differences, so I traced it down to specific elements. If I print just one element that's mismatched in one buffer:

    key = "model.layers.0.self_attn.rotary_emb.inv_freq"
    print(f"{buffers_a[key][22]:.20f}")
    print(f"{buffers_b[key][22]:.20f}")

So you can see the difference:

0.04216964915394783020
0.04216965287923812866

So in this single buffer, 4 out of 64 elements are mismatched:

Mismatched elements: 4 / 64 (6.2%)
Greatest absolute difference: 3.725290298461914e-09 at index (22,)
Greatest relative difference: 1.0081151913254871e-07 at index (47,)

The other surprise is that the buffers remain in fp32, despite the torch_dtype=torch.bfloat16 arg in from_pretrained and the bf16 deepspeed config. So perhaps the error happens through some conversion between types?

I printed:

    key = "model.layers.0.self_attn.rotary_emb.inv_freq"
    print(f"{buffers_a[key].dtype}")
    print(f"{buffers_b[key].dtype}")

and got:

torch.float32
torch.float32

But looking at the modeling code, the buffers are created at model init time and they are explicitly converted to float in the Llama code.

@stas00
Contributor

stas00 commented Jan 19, 2024

OK, I further reduced it to a very simple repro:

import torch
import deepspeed

ds_config = {
    "zero_optimization": {
        "stage": 3,
    },
    "bf16": {
        "enabled": True,
    },
    "train_micro_batch_size_per_gpu": 1,
    "train_batch_size": 1,
}

def do_math():
    inv_freq = 1.0 / (10_000 ** (torch.arange(0, 128, 2).float() / 128))
    print(f"{inv_freq[22]:.20f}")

do_math()
with deepspeed.zero.Init(config_dict_or_path=ds_config):
    do_math()

These numbers should be the same, but they aren't:

0.04216964915394783020
0.04216965287923812866

I derived it from the buffer init code:
https://github.com/huggingface/transformers/blob/3f69f415adcbdaedec154ba8eac220ef3276975d/src/transformers/models/llama/modeling_llama.py#L130

@stas00
Contributor

stas00 commented Jan 19, 2024

The problem comes from:

def patch_init_and_builtins(self):

Commenting it out removes the mismatch in the repro above.

@stas00
Contributor

stas00 commented Jan 19, 2024

So in this particular case, it's the torch.* overrides; commenting out this one fixes the specific code in the repro, as it uses torch.arange:

torch.arange = zero_wrapper_for_fp_tensor_constructor(_orig_torch_arange, self.dtype)

@stas00
Contributor

stas00 commented Jan 19, 2024

And the mismatch happens because the buffer creation code runs:

  1. w/o zero.Init: before deepspeed overrides the torch.* API
  2. w/ zero.Init: after deepspeed overrides the torch.* API

Both cases eventually override the torch.* API, so the difference happens because of the timing.

What's worrying is that there is a mismatch in the first place; it should not be there. Otherwise, inference running under DeepSpeed ZeRO (regardless of zero.Init) clearly isn't mathematically equivalent to running without DeepSpeed ZeRO, which is very worrying!

So what needs to be investigated and fixed isn't zero.Init but most likely zero_wrapper_for_fp_tensor_constructor, and tests should be added to ensure the math for all overridden APIs is exact (which is non-trivial, since the discrepancy shows up only in special cases when the tensor is big enough and there is a ** involved, as in this case).

@stas00
Contributor

stas00 commented Jan 19, 2024

OK, so the problem is this: the math on CPU and on CUDA isn't the same!

The trigger for this problem is this:

    def wrapped_fn(*args, **kwargs) -> Tensor:
        if kwargs.get("device", None) is None:
            kwargs['device'] = torch.device(get_accelerator().device_name(os.environ["LOCAL_RANK"]))

If I add:

    kwargs['device'] = None

this forces the code in the repro to be calculated on the same device (CPU) and the math is exact.

So we are moving to the next level here: the discrepancy comes from the devices, which don't do math in the same way, and the outcomes aren't guaranteed to be identical.

  1. w/ zero.Init the model is created on CUDA
  2. w/o zero.Init the model is created on CPU

So since math(CUDA) != math(CPU), you get different results.

A one-time small discrepancy in a math outcome makes little difference at first, but it gets compounded as it goes through dozens of layers doing hundreds of matmuls, spreading and increasing the discrepancy wider and wider.

So if my diagnosis makes sense, IMO there are no bugs in DeepSpeed ZeRO w.r.t. this outcome discrepancy issue.

But this nuance should be documented as it's very important.

@stas00
Contributor

stas00 commented Jan 19, 2024

Here is the final repro w/o deepspeed involved:

import torch

def do_math(device):
    inv_freq = (10 ** (torch.arange(0, 10, device=device) / 10))
    print(f"{inv_freq[9]:.20f}")
    return inv_freq.cpu()

a = do_math(torch.device("cpu"))
b = do_math(torch.device("cuda"))

torch.testing.assert_close(a, b, rtol=0.0, atol=0.0)

which prints the two values and then fails:

7.94328212738037109375
7.94328308105468750000
Traceback (most recent call last):
  File "test.py", line 11, in <module>
    torch.testing.assert_close(a, b, rtol=0.0, atol=0.0)
  File "/home/stas/anaconda3/envs/py39-pt21/lib/python3.9/site-packages/torch/testing/_comparison.py", line 1520, in assert_close
    raise error_metas[0].to_error(msg)
AssertionError: Tensor-likes are not equal!

Mismatched elements: 2 / 10 (20.0%)
Greatest absolute difference: 9.5367431640625e-07 at index (9,)
Greatest relative difference: 1.200604771156577e-07 at index (9,)

@tjruwase
Contributor

@stas00, amazing debugging as always! I am stumped as to the solution here. Shall we file this as a cuda bug?

@tjruwase
Contributor

That is assuming cpu compute is correct, which is perhaps unfair :)

@stas00
Contributor

stas00 commented Jan 19, 2024

It's not a CUDA bug, it's just how hardware works. There are discrepancies even between 2 NVIDIA GPUs; e.g. V100 and A100 won't always produce the same math output.

You can find a ton of threads discussing this: https://www.google.ca/search?q=pytorch+math+discrepancies+cuda+cpu

This gotcha should just be documented with a pointer to this Issue for those who want to understand it better.

Now which mode is better for the inference user will depend on how the model was instantiated for training. If it was on cpu, then loading it on cpu will lead to the closest outcome. If on gpu, then loading on gpu will be the closest (i.e. with zero.Init)

@stas00
Contributor

stas00 commented Jan 19, 2024

Please note that this impacts only buffers and not trained parameters; those are loaded from the checkpoint without any modification, regardless of zero.Init or not. So if the model has very few or no buffers, or the buffers aren't created via complex math, the 2 modes will produce close to identical logits.

@stas00
Contributor

stas00 commented Jan 19, 2024

@stas00, amazing debugging as always!

Thank you, Tunji.

As promised I started to write down the debug methodologies I developed in this repo:

https://github.com/stas00/the-art-of-debugging

@stas00
Contributor

stas00 commented Jan 20, 2024

@hijkzzz and @wuxibin89, now that we understand well what's going on, how do we support your need?

How do you tell which of the 2 behaviours is the preferred (or correct?) one?

If you want the cpu-way as a short-term hack, you could update the model's buffers with their values calculated separately on CPU. I think that since those buffers are invariant of the model size, you could even create a tiny model on CPU and then copy them over to the zero.Init'ed model; a rough sketch follows below.
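
Something along these lines (rough, untested sketch: it loads a reference copy of the same checkpoint on CPU just to read its buffers, and args.pretrain / model are placeholders for your checkpoint path and the zero.Init'ed model):

    import torch
    from transformers import AutoModelForCausalLM

    # Compute the buffers on CPU via a reference copy of the same checkpoint.
    ref = AutoModelForCausalLM.from_pretrained(args.pretrain, torch_dtype=torch.float32)
    cpu_buffers = {name: buf.clone() for name, buf in ref.named_buffers()}
    del ref

    # Overwrite the zero.Init'ed model's buffers in place (ZeRO-3 shards parameters,
    # not buffers, so no gathering is needed here).
    with torch.no_grad():
        for name, buf in model.named_buffers():
            if name in cpu_buffers:
                buf.copy_(cpu_buffers[name].to(device=buf.device, dtype=buf.dtype))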

I'm also thinking that in this situation perhaps Llama-2 should store its exact buffers, as they were when it was trained, in the distributed model files, so that the user always gets the exact weights, since clearly this leads to ambiguity and possibly invalid inference outcomes. So you may want to consider opening a request at HF Transformers?

@wuxibin89
Author

@stas00 Thanks for your amazing work, that's a very impressive debugging process. To stay compatible with our previous experiments with ZeRO-2, we prefer the cpu-way, which is to initialize the RoPE buffers on CPU first and then copy them to GPU.

For a quick proof, I did a nasty hack in modeling_llama.py.

diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py
index 8ceee2d1d..f6d108986 100644
--- a/src/transformers/models/llama/modeling_llama.py
+++ b/src/transformers/models/llama/modeling_llama.py
@@ -124,6 +124,10 @@ class LlamaRotaryEmbedding(nn.Module):
     def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
         super().__init__()
 
+        import deepspeed
+        hooked_arange = torch.arange
+        torch.arange = deepspeed.runtime.zero.partition_parameters._orig_torch_arange
+
         self.dim = dim
         self.max_position_embeddings = max_position_embeddings
         self.base = base
@@ -134,6 +138,10 @@ class LlamaRotaryEmbedding(nn.Module):
         self._set_cos_sin_cache(
             seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
         )
+        self.cos_cached = self.cos_cached.to("cuda")
+        self.sin_cached = self.sin_cached.to("cuda")
+
+        torch.arange = hooked_arange
 
     def _set_cos_sin_cache(self, seq_len, device, dtype):
         self.max_seq_len_cached = seq_len

This hack forces the RoPE buffers to be initialized on CPU and then copied to GPU. After this hack, the difference in my reproduction script is gone.

Moreover, I quickly ran our e2e PPO training script with ZeRO-3 and deepspeed.zero.Init. The first 10 steps show that the spike in policy_loss is gone! I will do some more experiments later.

[wandb screenshot of the first 10 steps]

@rwightman

rwightman commented Jan 24, 2024

If I compare just: "model.layers.0.self_attn.rotary_emb.sin_cached"

Mismatched elements: 324808 / 524288 (62.0%)
Greatest absolute difference: 2.0 at index (1211, 0)
Greatest relative difference: 2128.0 at index (3305, 26)

So yeah, that's a huge difference, and I feel it is definitely the arange issue pointed out above; any use of the zero.Init() context manager with a low-precision dtype will trash the range above 256 for bfloat16 or above 2048 for float16. The dtype for arange should always be torch.long / torch.int, and then cast to the float dtype as late as possible.

EDIT: The div issue really should have little to no impact compared to above.
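
For illustration, the safe pattern applied to the RoPE enumerations would look roughly like this (the dim/base/seq-len values are just examples):

    import torch

    dim, base, max_position_embeddings = 128, 10000, 4096

    # Enumerate in an integer dtype so neither a patched torch.arange nor a
    # low-precision default dtype can corrupt the sequence; cast to float late.
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))
    t = torch.arange(max_position_embeddings, dtype=torch.int64).float()
    freqs = torch.outer(t, inv_freq)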

@rwightman

Note that GPT-NeoX and the original Llama do this correctly:

https://github.com/EleutherAI/gpt-neox/blame/63991555ec082c8f80c475f851d008193b10008c/megatron/model/positional_embeddings.py#L27

https://github.com/facebookresearch/llama/blob/ef351e9cd9496c579bf9f2bb036ef11bdc5ca3d2/llama/model.py#L100-L103

And Transformers does not. So I feel Transformers needs some fixes; this isn't the only instance of the pattern where a float dtype has been used in arange for an int enumeration...

@stas00
Contributor

stas00 commented Jan 24, 2024

Ross, do you have the resources to take this to HF Transformers? It doesn't look like @hijkzzz and @wuxibin89 are planning to take it there.

I was helping with the diagnostics but I'm no longer at HF to take the lead.

Or perhaps recommend who should be tagged from HF so that they can take the lead?

@rwightman

rwightman commented Jan 24, 2024

Ross, do you have resources to take this to HF Transformers? It doesn't look like @hijkzzz and @wuxibin89 are planning to take it there.

I was helping with the diagnostics but I'm no longer at HF to take the lead.

or perhaps please recommend who should be tagged from HF so that they could take the lead?

I scanned and found ~66 code blocks across many models that are using float dtypes in position-embedding-style enumerations that should be left as torch.long. I've submitted this and will likely create an issue for it; it's already posted in Slack there.

@stas00
Contributor

stas00 commented Jan 24, 2024

You rock, Ross - thank you very much for taking the lead on this issue. It'll help so many users who don't even know they may have it.

@stas00
Contributor

stas00 commented Jan 24, 2024

What do you think about my proposal to store the buffers with the model weights in the checkpoint? Then it's even easier because it'll always be correct.

@rwightman

Yeah, it could be silently impacting many people using this sort of init-patching context manager; is DS the only common one? Looks like llama, falcon, phi2, mistral/mixtral, kosmos2, neox, many detr, and others are potentially impacted, depending on the length of the ranges in each use...

@rwightman

And for anyone coming into this thread, the reason there's a big difference between casting to float after arange vs passing the dtype is this guard in the DS zero.Init() patching: if the dtype for the range is not a floating-point type, it won't cast to the low-precision type:

    if tensor.is_floating_point():
        tensor.data = tensor.data.to(target_fp_dtype)
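
To make the effect of that guard concrete, here is a standalone sketch of its logic (not the actual DeepSpeed code path):

    import torch

    def maybe_downcast(tensor, target_fp_dtype=torch.bfloat16):
        # Mirrors the quoted guard: only floating-point tensors are downcast.
        if tensor.is_floating_point():
            tensor.data = tensor.data.to(target_fp_dtype)
        return tensor

    print(maybe_downcast(torch.arange(4096)).dtype)                       # torch.int64, untouched
    print(maybe_downcast(torch.arange(4096, dtype=torch.float32)).dtype)  # torch.bfloat16, downcast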

@stas00
Contributor

stas00 commented Jan 24, 2024

Yeah, it could be silently impacting many people using with this sort of init patching context manager, is DS the only common one?

I don't think FSDP does it, but they probably should, as loading a huge model on FSDP is a huge problem right now.

I think that if a user sets torch.cuda.set_device("cuda") it will have the same effect.

Also if the user or a framework calls torch.set_default_dtype() it could also have a problem unless an explicit dtype is used during the op. And we know HF transformers does that here:

https://github.com/huggingface/transformers/blob/f40b87de0ca234df61f76928956c4a2118c0b548/src/transformers/modeling_utils.py#L1430

so I think this problem again isn't DS-specific but a much wider one.

@rwightman

rwightman commented Jan 24, 2024

@stas00 On storing buffers, I do believe that would have avoided most of the issues; after all of the calculations from the int range to the final float positions, casting the final result from float to float16/bfloat16 results in a much, much smaller difference than doing the int enumeration in the low-precision type.

However, I've never been a fan of storing this info in checkpoints myself. Depending on the model, sometimes when wanting to adapt seq len, resolutions, etc., it's better to re-generate the buffers with new params than to interpolate the stored ones...

@stas00
Contributor

stas00 commented Jan 24, 2024

My thinking is that correctness out of the box is more important than convenience. Perhaps HF transformers could add a flag to skip buffer overwriting if the advanced user prefers not to have that overhead? That way both needs would be met.

@rwightman

rwightman commented Jan 24, 2024

torch.set_default_dtype() with a lower-precision type will only be an issue if dtype=None, in which case the dtype used will be inferred as int64 if the range start/end/step are ints, or the default dtype if any of those are floats...

For using set_default_dtype()

This is a problem

torch.set_default_dtype(torch.bfloat16)
torch.arange(4096.0)

This is not:

torch.set_default_dtype(torch.bfloat16)
torch.arange(4096)

This is the safest (and also prevents issues with DS zero.Init()) and makes intent explicit

torch.set_default_dtype(torch.bfloat16)
torch.arange(4096, dtype=torch.int64)

@hijkzzz

hijkzzz commented Jan 25, 2024

@rwightman torch.arange(4096, dtype=torch.int64) cannot completely solve this problem. The calculation of the other parts of RoPE still running on GPU + BF16 will cause accuracy issues.

@rwightman

@rwightman torch.arange(4096, dtype=torch.int64) cannot completely solve this problem. The calculation of other parts of RoPE still running on the GPU + BF16 will cause accuracy issues.

I'm not following: why would the 'other parts' of RoPE be calculated in bfloat16? The code posted above is forcing them to float32, BUT it does this after it has generated an incorrect sequence via the patched arange.

torch.arange is being used with dtype=float; zero.Init() patches arange and says, oh look, float, I will set this to bfloat16; then the .float() moves it to float32, but the damage has been done: bfloat16 cannot represent integers > 256 without loss, so for a 4096 sequence the input range will have a whole lot of non-trivial differences right off the bat.

aa = torch.arange(4096, dtype=torch.bfloat16, device='cuda').float()
bb = torch.arange(4096, device='cuda').float()
max(aa - bb)
Out[142]: tensor(8., device='cuda:0')
sum(abs(aa - bb))
Out[143]: tensor(10880., device='cuda:0')

That's significant! It completely dwarfs any losses by casting the final float32 -> bfloat16 or float16...

@hijkzzz

hijkzzz commented Jan 25, 2024


Hi, you could try the script to verify whether this bug can be solved: https://gist.github.com/wuxibin89/7d4801d62a743d5bbc72b507e8f0e326

@rwightman

If you use bfloat16 as an arange dtype, look at the value progression after 256...

tensor([254., 255., 256., 256., 258., 260., 260., 260., 262., 264., 264., 264.,
        266., 268., 268., 268., 270., 272., 272., 272., 274., 276., 276., 276.,
        278., 280., 280., 280., 282., 284., 284., 284., 286., 288., 288., 288.,
        290., 292., 292., 292., 294., 296., 296., 296., 298., 300., 300., 300.,
        302., 304., 304., 304., 306., 308., 308., 308., 310., 312., 312., 312.,
        314., 316., 316., 316., 318., 320.],

@rwightman

Floating-point numbers have non-uniform coverage of the real numbers; for lower-precision types, especially with a small mantissa, the representation of large integers is going to be awful. But after you've multiplied with the inv_freq and performed sin/cos, you're in a much better region, and the losses from float32 -> bfloat16 or float16 will be much less significant.

@hijkzzz

hijkzzz commented Jan 25, 2024


I replaced the dtype of all torch.arange calls with torch.long, and there is still an error of about 0.8:

 (s0["log_probs"] - s1["log_probs"]).max()
0.8xxxx

@rwightman

@hijkzzz it's not clear what this script is trying to compare? https://gist.github.com/wuxibin89/7d4801d62a743d5bbc72b507e8f0e326

The --enable_hf_deespeed flag seems to have no impact; the same ds_config is used in either case.

@rwightman

Okay, never mind, it is changing the operation; I'm not familiar with that mechanism, but anyways... I looked at what's happening in the buffer init calculation.

With the original t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype), what I've described is definitely happening: when the deepspeed config is enabled, it uses bfloat16 for the arange and the generated sequence will have issues right away.

Changing to t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.long).type_as(self.inv_freq) definitely leaves the operations in float32 until:

  self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
  self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)

If we change the self._set_cos_sin_cache call in __init__ to the below, it will leave those buffers as float32 in the initialization pass, but they do still get cast to the low-precision type later, presumably by a model-level .to() or something similar.

        self._set_cos_sin_cache(
            seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=self.inv_freq.dtype,
        )

So in the end we do end up with low-precision buffers for the RoPE values. Normally I wouldn't think this is a big issue; I don't see how anyone's use of this model with deepspeed (e.g. other Llama implementations) would be any different at this point. I'm not super familiar with the deepspeed usage patterns; without deepspeed you could keep those buffers in float32 and have everything else in low precision, as long as you managed any necessary casts in forward yourself. But I don't think this is typical?

The mentioned cpu div issue should be negligible next to the final cast from float32 -> bfloat16. The 'hack' to disable the patched arange should also be no different from the proposed dtype fix, as the final cast would still happen. So if you see a difference between those two hacks, I'm not sure what's going on...

@rwightman

Also worth pointing out: those buffers end up in low precision regardless of whether or not the deepspeed config is enabled, presumably due to the dtype, but the logits/probs are definitely different by a non-trivial amount. I'm rather confused.

@rwightman

rwightman commented Jan 26, 2024

Had another chance at the end of the day to look. So yeah, the model is very sensitive to the sin/cos values, and it's easy to end up with differing values for those depending on the dtype used or CPU vs GPU calcs, and that's after avoiding the arange dtype problem, which results in even larger errors...

There are other issues about this; we're not the first to the party.

I have a variation of the EleutherAI solution that seems to work, a bit less hacky than the one originally proposed here, but still rather unsatisfying. Note: I had to explicitly set the device='cpu' arg for arange to prevent deepspeed from overriding it and to keep the calcs on the CPU. Conversion to the final dtype/device is only done on use in forward(); that should allow the on-the-fly resize to work when a bigger seq len is passed, but I guess it does a device transfer every forward... meh :/

A bit of a mess...

class LlamaRotaryEmbedding(nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base

        # Build here to make `torch.jit.trace` work.
        self._set_cos_sin_cache(seq_len=max_position_embeddings)

    def _set_cos_sin_cache(self, seq_len):
        self.max_seq_len_cached = seq_len
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, device='cpu').float() / self.dim))
        t = torch.arange(self.max_seq_len_cached, device='cpu').float()
        freqs = torch.outer(t, inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.cos_cached = emb.cos()
        self.sin_cached = emb.sin()

    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        if seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len=seq_len)

        return (
            self.cos_cached[:seq_len].to(x),
            self.sin_cached[:seq_len].to(x),
        )

@wuxibin89
Author

@hijkzzz it's not clear what this script is trying to compare? https://gist.github.com/wuxibin89/7d4801d62a743d5bbc72b507e8f0e326

the --enable_hf_deespeed seems to have no impact, the same ds_config is used in either case

Yes, the ds_config is the same in either case; this flag controls whether HfDeepSpeedConfig is enabled or not. HfDeepSpeedConfig calls deepspeed.zero.Init and causes the RoPE buffers to be initialized on GPU.

@rwightman

rwightman commented Jan 26, 2024

@hijkzzz it's not clear what this script is trying to compare? https://gist.github.com/wuxibin89/7d4801d62a743d5bbc72b507e8f0e326
the --enable_hf_deespeed seems to have no impact, the same ds_config is used in either case

Yes, the ds_config is same in either case, this flag control whether enable HfDeepSpeedConfig or not. HfDeepSpeedConfig calls deepspeed.zero.Init and causes RoPE buffer init on GPU.

Yeah, I figured that bit of magic out and reproduced the issue; my snippet above seems to work around it.

With the arange dtype issue avoided, there's a max difference of ~±7e-6 between the final output after sin/cos on CPU vs GPU. Once that is then cast to a lower-precision type, we end up with a ±5e-4 difference in float16 and ±2e-3 in bfloat16, and that makes for a big impact on the logits...

It's a rather remarkable progression of float error propagation through each step of the RoPE calc: after inv_freq, only a 1.87e-9 diff; 3.8e-6 after the outer product; then 7e-6; then finally 5e-4 or 2e-3 before being added to q/k.

@stas00
Contributor

stas00 commented Jan 26, 2024

And if these buffers were restored from the checkpoint to exactly the values they were set to during training, then none of this would be an issue. Please correct me if I'm wrong.
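
(For concreteness, "restoring the buffers from the checkpoint" would roughly amount to registering the rotary caches as persistent buffers instead of the current persistent=False, so they travel with the state_dict. A self-contained sketch, not the actual Llama code:)

    import torch
    import torch.nn as nn

    class RotaryCache(nn.Module):
        # Sketch only: persistent buffers are saved in the state_dict, so they would
        # be restored from the checkpoint exactly as computed at training time.
        def __init__(self, dim, max_seq_len, base=10000):
            super().__init__()
            inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))
            t = torch.arange(max_seq_len, dtype=torch.int64).float()
            freqs = torch.outer(t, inv_freq)
            emb = torch.cat((freqs, freqs), dim=-1)
            self.register_buffer("cos_cached", emb.cos(), persistent=True)
            self.register_buffer("sin_cached", emb.sin(), persistent=True)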

@rwightman

@stas00 Some of the issues, like having the calc move to the GPU or be done in lower precision, might be helped by that, but it doesn't solve the need to recalc on the fly or at init for a different seq len. Also, thinking about it more, the embed should really be applied to q/k in float32; having them in buffers means DeepSpeed will downcast them, and I'm not sure that can easily be avoided unless you prevent them from being buffers of any sort...

@stas00
Contributor

stas00 commented Jan 26, 2024

Understood. Thank you for clarifying, Ross!

@gante

gante commented Jan 27, 2024

Hi everyone 👋 FYI, we're having a look at this issue from the transformers side, including potentially forcing the computations of RoPE to be done in FP32. This issue will track it.
