Enable deepspeed.zero.Init causes very strange spikes in PPO policy_loss #4932
we use these two models to reproduce it:
We found that this problem appeared in the middle of each training epoch of the Actor in PPO (multiple mini-batches).
OK, the whole area of having more than one model is new and not well tested. The HF Accelerate/Deepspeed integration I designed relatively recently already included a flag to turn zero.Init on and off; you might be able to hack the corresponding function to always return False, which would disable it. HF Accelerate, for example, still doesn't support multiple Deepspeed models (I'm not sure about other frameworks), so when we needed more than one model we used the workaround of wrapping 2 models into a super-model and using the super-model's forward to do this or that on each model. This works well if it fits your training mode. So chances are you're running into this problem because you have one DS engine and two models, and I don't think that can work - or at least it shouldn't. So you can do 2 things:
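For illustration, a minimal sketch of that super-model workaround (hypothetical actor/critic modules; not code from this thread):

import torch.nn as nn

class SuperModel(nn.Module):
    """Wrap two models so that a single DeepSpeed engine sees one module tree."""

    def __init__(self, actor: nn.Module, critic: nn.Module):
        super().__init__()
        self.actor = actor
        self.critic = critic

    def forward(self, *args, mode: str = "actor", **kwargs):
        # Dispatch to whichever sub-model this training step needs.
        if mode == "actor":
            return self.actor(*args, **kwargs)
        return self.critic(*args, **kwargs)

# engine, *_ = deepspeed.initialize(model=SuperModel(actor, critic), config=ds_config)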
Please let me know if it's helpful.
We also tried HfDeepSpeedConfig with ZeRO-3 to train the SFT and reward models; all metrics are consistent with ZeRO-2, so we expect that HfDeepSpeedConfig with ZeRO-3 works correctly.
So what are you using HfDeepSpeedConfig for? Do things work correctly if you drop it altogether? Its only use is to:
If you use neither HF Trainer nor Accelerate, then (1) is irrelevant and it's then only an issue of zero.Init.
We want to use HfDeepSpeedConfig to load models >70B (e.g. Llama2-70B) to GPU directly. And yes, we use neither HF Trainer nor Accelerate, so I agree that it's only an issue of zero.Init.
Thank you for clarifying, @wuxibin89. Now that we've narrowed it down to just zero.Init, I'm tagging @tjruwase to take over.
@wuxibin89, can you clarify what you mean by that?
@tjruwase @stas00 I created a reproduction script here with our very first micro-batch data.
we can see a significant logits difference with .max()/.min()
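For reference, a small helper of the kind used to quantify such a gap (assuming logits_a/logits_b are the outputs of the two runs):

import torch

def report_logits_gap(logits_a: torch.Tensor, logits_b: torch.Tensor) -> None:
    # Element-wise comparison of the two runs' logits (e.g. with vs. without zero.Init).
    diff = (logits_a.float() - logits_b.float()).abs()
    print(f"max diff: {diff.max().item():.6f}, mean diff: {diff.mean().item():.6f}")
    print(f"a: min {logits_a.min().item():.4f}, max {logits_a.max().item():.4f}")
    print(f"b: min {logits_b.min().item():.4f}, max {logits_b.max().item():.4f}")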
Thank you for making a super-easy-to-set-up-and-use repro script, @hijkzzz and @wuxibin89. So first I tested that the issue has nothing to do with HfDeepSpeedConfig - I can repro the problem by removing it completely and just wrapping the model loading in deepspeed.zero.Init().
I also reduced it to just one sample and a sub-section of it:
That already shows a largish diff, and it doesn't fail if I make the slice smaller - like 200-250 tokens. It seems to start diverging at ~300 tokens. Finally, I validated that it has to do with half-precision. If I switch to fp32 in the ds config, the problem goes away (tiny difference of 1e-5). Both bf16 and fp16 manifest this. I thought that perhaps it might have had to do with RNG, but resetting the seed after model loading didn't seem to make a difference. Oh, and I had to hack transformers to deal with the removal of HfDeepSpeedConfig, otherwise it'd fail to load the weights.
So it has something to do with a precision issue that gets aggravated with more tokens.
@stas00 Thanks for your help! I also tried removing deepspeed completely and just using transformers alone; the result is the same as deepspeed ZeRO-3 without zero.Init.
I tested that the state_dict is identical in both cases:
This would have failed if they were different. We need to check if there are perhaps some non-parameter buffers that get missed or are invalid.
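A sketch of that kind of check (hypothetical helper; sd_a/sd_b are the state dicts gathered from the two instantiations):

import torch

def assert_state_dicts_equal(sd_a: dict, sd_b: dict) -> None:
    # Raises as soon as a key or tensor value differs between the two state dicts.
    assert sd_a.keys() == sd_b.keys(), "key sets differ"
    for name, tensor_a in sd_a.items():
        torch.testing.assert_close(tensor_a, sd_b[name], rtol=0, atol=0, msg=name)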
It's the buffers! Llama-2 has 3 buffers per layer - inv_freq, cos_cached and sin_cached - e.g. for layer 28:
Here is the proof:
with bf16 there is a 63% difference:
with fp32 there is almost no difference, but it's not 100% exact!
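A sketch of how such per-buffer mismatch percentages can be computed (assuming model_a/model_b are the two instantiations of the same checkpoint):

import torch

def buffer_mismatch_report(model_a: torch.nn.Module, model_b: torch.nn.Module) -> None:
    buffers_b = dict(model_b.named_buffers())
    for name, buf_a in model_a.named_buffers():
        a = buf_a.float().cpu()
        b = buffers_b[name].float().cpu()
        mismatched = (a != b).sum().item()
        print(f"{name}: {100.0 * mismatched / a.numel():.1f}% mismatched "
              f"({mismatched}/{a.numel()})")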
@tjruwase, now that we know the source of the problem, passing this to you. It must be something about how buffers are treated differently than params under zero.Init.
I wanted to see the actual differences, so I traced it down to specific elements. If I print just one element that's mismatched in one buffer:
So you can see the difference:
so in this single buffer 4 out of 64 are mismatched.
The other surprise is that the buffers remain in fp32, despite the half-precision config. I printed:
and got:
But looking at the modeling code, the buffers are created at model init time and they are explicitly converted to float in the Llama code.
ok, I further reduced it to a very simple repro:
These numbers should be the same, but they aren't:
I derived it from the buffer init code:
The problem comes from:
commenting it out removes the mismatch in the repro above.
So in this particular case, the culprit is in the buffer init code,
and the mismatch happens because the buffer creation code runs:
Both cases eventually override it. What's worrying is that there is a mismatch in the first place - it should not be there. Otherwise inference running under deepspeed ZeRO (regardless of zero.Init) clearly isn't mathematically equivalent to running w/o deepspeed ZeRO - which is very worrying! So what needs to be investigated and fixed isn't zero.Init itself.
ok, so the problem is this: the math on CPU and on CUDA isn't the same! The trigger for this problem is this:

DeepSpeed/deepspeed/runtime/zero/partition_parameters.py, lines 241 to 243 in e62a47e
If I add an override that forces the code in the repro to be calculated on the same device, the mismatch goes away. So we are moving to the next level here: the discrepancy comes from the devices, which don't do math in the same way, and the outcomes aren't guaranteed to be identical.
So if math(CUDA) != math(CPU), this is why you get different results. A one-time small math-outcome discrepancy makes little difference at first, but it gets compounded when it goes through dozens of layers, doing hundreds of ops. So if my diagnosis makes sense, IMO there are no bugs in Deepspeed ZeRO wrt this outcome-discrepancy issue. But this nuance should be documented, as it's very important.
Here is the final repro w/o deepspeed involved:
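Something along these lines (a sketch rather than the exact script; the formula follows the Llama buffer-init code discussed above):

import torch

# Same inv_freq math as the Llama buffer init, once on cpu and once on cuda.
def inv_freq(device: str, dim: int = 128, base: int = 10000) -> torch.Tensor:
    return 1.0 / (base ** (torch.arange(0, dim, 2, device=device).float() / dim))

if torch.cuda.is_available():
    # The same float32 math is not guaranteed to be bit-identical across devices.
    print((inv_freq("cpu") - inv_freq("cuda").cpu()).abs().max().item())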
@stas00, amazing debugging as always! I am stumped as to the solution here. Shall we file this as a cuda bug?
That is assuming cpu compute is correct, which is perhaps unfair :)
It's not a cuda bug, it's just how hardware works. There are discrepancies even between 2 NVIDIA gpus; e.g. V100 and A100 won't always produce the same math output. You can find a ton of threads discussing this: https://www.google.ca/search?q=pytorch+math+discrepancies+cuda+cpu This gotcha should just be documented with a pointer to this Issue for those who want to understand it better. Now, which mode is better for the inference user will depend on how the model was instantiated for training. If it was on cpu, then loading it on cpu will lead to the closest outcome. If on gpu, then loading on gpu will be the closest (i.e. with zero.Init).
Please note that this impacts only buffers and not trained parameters - those are loaded from the checkpoint w/o any modification - regardless of zero.Init.
Thank you, Tunji. As promised I started to write down the debug methodologies I developed in this repo:
@hijkzzz and @wuxibin89, so now that we understand well what's going on, how do we support your need? How do you tell which of the 2 behaviours is the preferred (or correct?) one? If you want the cpu-way, as a short-term hack you could update the model's buffers with their values calculated separately on cpu. I think since those buffers are invariant of the model size you could even create a tiny model on cpu and then copy them over to the zero.Init'ed model. I'm also thinking that in this situation perhaps Llama-2 should store its exact buffers, as they were when they were trained, in the distributed model files, so that the user always gets the exact weights - since clearly this leads to ambiguity and possibly invalid inference outcomes. So you may want to consider opening a request at HF Transformers?
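A rough sketch of that short-term hack (hypothetical helper; assumes the buffer names and shapes line up between the cpu-built model and the zero.Init'ed one):

import torch

@torch.no_grad()
def copy_cpu_buffers(cpu_model: torch.nn.Module, ds_model: torch.nn.Module) -> None:
    # Overwrite the zero.Init'ed model's buffers with values computed on cpu.
    cpu_buffers = dict(cpu_model.named_buffers())
    for name, buf in ds_model.named_buffers():
        if name in cpu_buffers and cpu_buffers[name].shape == buf.shape:
            buf.copy_(cpu_buffers[name].to(device=buf.device, dtype=buf.dtype))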
@stas00 Thanks for your amazing work, that's a very impressive debugging process. To be compatible with our previous ZeRO-2 experiments, we prefer the cpu-way, which is to initialize the RoPE buffers on CPU first and then copy them to GPU. For a quick proof, I did a nasty hack with torch.arange:

diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py
index 8ceee2d1d..f6d108986 100644
--- a/src/transformers/models/llama/modeling_llama.py
+++ b/src/transformers/models/llama/modeling_llama.py
@@ -124,6 +124,10 @@ class LlamaRotaryEmbedding(nn.Module):
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
super().__init__()
+ import deepspeed
+ hooked_arange = torch.arange
+ torch.arange = deepspeed.runtime.zero.partition_parameters._orig_torch_arange
+
self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.base = base
@@ -134,6 +138,10 @@ class LlamaRotaryEmbedding(nn.Module):
self._set_cos_sin_cache(
seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
)
+ self.cos_cached = self.cos_cached.to("cuda")
+ self.sin_cached = self.sin_cached.to("cuda")
+
+ torch.arange = hooked_arange
def _set_cos_sin_cache(self, seq_len, device, dtype):
self.max_seq_len_cached = seq_len

This hack forces the RoPE buffers to be initialized on CPU and then copied to GPU. After this hack, the difference in my reproduce script is gone. Moreover, I quickly ran our e2e PPO training script with ZeRO-3 and deepspeed.zero.Init; the first 10 steps show that the spike in policy_loss is gone! I will do some more experiments later.
So yeah, that's a huge difference, and I feel it is definitely the arange issue pointed out above: any use of the zero.Init() context manager with a low-precision dtype will trash the range > 256 for bfloat16 or > 2048 for float16. The dtype for arange should always be torch.long / torch.int, and then cast to a float dtype as late as possible. EDIT: The div issue really should have little to no impact compared to the above.
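The thresholds are easy to check: consecutive integers are exactly representable only up to 256 in bfloat16 and up to 2048 in float16, so an enumeration done directly in those dtypes goes wrong right after that point. For example:

import torch

for dtype, limit in ((torch.bfloat16, 256), (torch.float16, 2048)):
    lossy = torch.arange(limit + 8, dtype=dtype).float()
    exact = torch.arange(limit + 8, dtype=torch.long).float()
    # Past the limit, values get rounded to the nearest representable number.
    first_bad = (lossy != exact).nonzero()[0].item()
    print(dtype, "first wrong index:", first_bad, "values:", lossy[limit:limit + 4].tolist())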
Note that GPT-NeoX and the original Llama do this correctly, and Transformers does not. So I feel Transformers needs some fixes; this isn't the only instance of the pattern where a float dtype has been used in the arange for an int enumeration...
Ross, do you have resources to take this to HF Transformers? It doesn't look like @hijkzzz and @wuxibin89 are planning to take it there. I was helping with the diagnostics, but I'm no longer at HF to take the lead. Or perhaps recommend who should be tagged from HF so that they could take the lead?
I scanned and found ~66 code blocks across many models that are using float dtypes in position-embedding-style enumerations that should be left as torch.long. I've submitted this and will likely create an issue for it; it's already posted in Slack there.
You rock, Ross - thank you very much for taking the lead on this issue. It'll help so many users who don't even know they may have it.
What do you think about my proposal to store the buffers with the model weights in the checkpoint? Then it's even easier because it'll always be correct.
Yeah, it could be silently impacting many people using this sort of init-patching context manager - is DS the only common one? Looks like llama, falcon, phi2, mistral/mixtral, kosmos2, neox, many detr, and others are potentially impacted, depending on the length of the ranges in each use...
And for anyone coming into this thread, the reason why there's the big difference between casting to float after arange vs passing the dtype is this guard in the DS zero.Init() patching: if the dtype for the range is not a floating-point type, it won't cast to the low-precision type:

DeepSpeed/deepspeed/runtime/zero/partition_parameters.py, lines 245 to 246 in 0dd0c61
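Paraphrased, the patched creation op described above behaves roughly like this (an illustrative sketch, not the actual DeepSpeed source):

import torch

_orig_arange = torch.arange

def zero_init_style_arange(*args, target_dtype=torch.bfloat16, **kwargs):
    # Force tensor creation onto the accelerator, as zero.Init does...
    kwargs.setdefault("device", "cuda")
    tensor = _orig_arange(*args, **kwargs)
    # ...but only downcast the result if it is already a floating-point tensor.
    if tensor.is_floating_point():
        tensor = tensor.to(target_dtype)
    return tensor

# zero_init_style_arange(300, dtype=torch.float32)  -> bfloat16 on cuda, lossy past 256
# zero_init_style_arange(300)                       -> int64 on cuda, still exact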
I don't know if FSDP does it, but they probably should, as loading a huge model on FSDP is a huge problem right now. I think that if a user sets the default dtype/device themselves, or if the user or a framework calls something equivalent, the same thing can happen - so I think this problem again isn't DS-specific but a much wider one.
@stas00 on storing buffers, I do believe that would have avoided most of the issues, as after all of the calculations from the int range -> final float pos, casting the final result from float -> float16/bfloat16 results in a much, much smaller difference than doing the int enumeration in the low-precision type. However, I've never been a fan of storing this info in checkpoints myself. Depending on the model, sometimes when wanting to adapt seq len, resolutions, etc., it's better to re-generate the buffers with new params than to interpolate the stored ones...
My thinking is that correctness out of the box is more important than convenience - perhaps HF transformers could add a flag to skip buffer overwriting if the advanced user prefers not to have that overhead? That way both needs will be met.
For code using set_default_dtype(): this is a problem:
This is not:
This is the safest (and also prevents issues with DS zero.Init()) and makes intent explicit:
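The three patterns contrasted above look roughly like this (a sketch, assuming an integer seq_len; not the original snippets):

import torch

seq_len = 4096

# Problem: if the default dtype has been set to bfloat16/float16 (or a patched arange
# honours a floating dtype), the enumeration itself happens in low precision.
t_bad = torch.arange(seq_len, dtype=torch.get_default_dtype())

# Not a problem: arange defaults to int64, so the enumeration stays exact and only
# the final values are cast to float.
t_ok = torch.arange(seq_len).float()

# Safest and most explicit (also sidesteps the DS zero.Init float-only downcast):
t_best = torch.arange(seq_len, dtype=torch.long).float()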
@rwightman
I'm not following, why would the 'other parts' of RoPE be calculated in bfloat16? The code posted above is forcing them to float32, BUT, it does this after it's generated an incorrect sequence via the patched arange.
That's significant! It completely dwarfs any losses from casting the final float32 -> bfloat16 or float16...
Hi, you could try the script to verify whether this bug is solved: https://gist.github.com/wuxibin89/7d4801d62a743d5bbc72b507e8f0e326
If you use bfloat16 as an arange dtype, look at the value progression after 256...
Floating-point numbers have non-uniform coverage of the real numbers; for lower-precision types, especially with a small mantissa, the representation of large integers is going to be awful. But after you've multiplied by the inv_freq and performed sin/cos you're in a much better region, and the losses from float32 -> bfloat16 or float16 will be much less significant.
I replaced the dtype of all the torch.arange calls:
@hijkzzz it's not clear what this script is trying to compare? https://gist.github.com/wuxibin89/7d4801d62a743d5bbc72b507e8f0e326 - the ds_config looks the same in either case.
okay, never mind, it is changing the operation - not familiar with that mechanism, but anyways... I looked at what's happening in the buffer init calculation. With the original code:
If we change the arange to use an integer dtype:
So in the end we do end up with low-precision buffers for the RoPE values. Normally I wouldn't think this is a big issue, not seeing how anyone's use of this model with deepspeed (e.g. other Llama implementations) would be any different at this point. Not super familiar with the deepspeed usage patterns; without deepspeed you could keep those buffers in float32 and have everything else in low precision, as long as you managed any necessary casts in forward yourself. But I don't think this is typical? The mentioned cpu div issue should be negligible next to the final cast from float32 -> bfloat16. The 'hack' to disable the patched arange should also be no different from the proposed dtype fix, as the final cast would still happen. So if you see a difference between those two hacks, I'm not sure what's going on...
Also worth pointing out, those buffers end up in low precision regardless of whether or not the deepspeed config is enabled, presumably due to the dtype, but the logits/probs are definitely different by a non-trivial amount - I'm rather confused.
Had another chance at the end of the day to look. So yeah, the model is very sensitive to the sin/cos values, and it's easy to end up with differing values for those depending on the dtype used or cpu vs gpu calcs - and that's after avoiding the arange dtype problem, which results in even larger errors... There are other issues about this; we're not the first to the party.
I have a variation of the EleutherAI solution that seems to work, a bit less hacky than the one originally proposed here, but still rather unsatisfying. Note, I had to explicitly set the device to 'cpu' in the arange calls. A bit of a mess...

class LlamaRotaryEmbedding(nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        # Build here to make `torch.jit.trace` work.
        self._set_cos_sin_cache(seq_len=max_position_embeddings)

    def _set_cos_sin_cache(self, seq_len):
        self.max_seq_len_cached = seq_len
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, device='cpu').float() / self.dim))
        t = torch.arange(self.max_seq_len_cached, device='cpu').float()
        freqs = torch.outer(t, inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.cos_cached = emb.cos()
        self.sin_cached = emb.sin()

    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        if seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len=seq_len)
        return (
            self.cos_cached[:seq_len].to(x),
            self.sin_cached[:seq_len].to(x),
        )
Yes, the ds_config is the same in either case; this flag controls whether HfDeepSpeedConfig (and thus zero.Init) is enabled.
Yeah, figured that bit of magic out, reproduced the issue; my snippet above seems to work around it. With the arange dtype issue avoided, there's a max difference of ~+/- 7e-6 between the final output after sin/cos on CPU vs GPU. It's a rather remarkable progression of float error propagation through each step in the rope calc: after inv_freq, only a 1.87e-9 diff; 3.8e-6 after the outer product; 7e-6 after sin/cos; then finally 5e-4 or 2e-3 before being added to q/k.
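A sketch of how such a stage-by-stage drift can be measured (hypothetical helper; the exact numbers above are hardware-dependent):

import torch

def rope_stages(device: str, dim: int = 128, base: int = 10000, seq_len: int = 4096):
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device).float() / dim))
    t = torch.arange(seq_len, device=device).float()
    freqs = torch.outer(t, inv_freq)
    return {"inv_freq": inv_freq, "freqs": freqs, "cos": freqs.cos(), "sin": freqs.sin()}

if torch.cuda.is_available():
    cpu, gpu = rope_stages("cpu"), rope_stages("cuda")
    for name in cpu:
        # Watch the tiny device-dependent error grow as it flows through each step.
        diff = (cpu[name] - gpu[name].cpu()).abs().max().item()
        print(f"{name}: max abs diff {diff:.3e}")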
And if these buffers are to be restored from the checkpoint to exactly the values they were set to while training, then none of this would be an issue. Please correct me if I'm wrong.
@stas00 some of the issues, like having the calc move to the GPU or be done in lower precision, might be helped by that, but it doesn't solve the need to recalc on the fly or at init for different seq len. Also, thinking more, the embed should really be applied to q/k in float32; having them in buffers means deepspeed will downcast them, and I'm not sure that can easily be avoided unless you prevent them from being buffers of any sort...
Understood. Thank you for clarifying, Ross!
Hi everyone 👋 FYI, we're having a look at this issue from the transformers side.
Hi, we're the OpenRLHF team. We heavily use deepspeed to build our RLHF framework and really appreciate your great work.
Recently we encountered a problem with deepspeed.zero.Init. We want to enable HfDeepSpeedConfig to support larger models (70B+) with ZeRO-3, but after enabling HfDeepSpeedConfig, we see some very strange spikes in the PPO policy_loss. If HfDeepSpeedConfig is disabled, the policy_loss is normal and the same as with ZeRO-2.
Below are the wandb metrics for Llama-7B PPO training.
The only difference is whether HfDeepSpeedConfig is enabled or not. The reproduce script is below:
https://github.com/OpenLLMAI/OpenRLHF/blob/main/examples/scripts/train_ppo_llama_ray_70b.sh
BTW, I debugged the parameters' sharded tensors with parameter.ds_tensor; they're the same whether HfDeepSpeedConfig is enabled or not! So we really don't know why. Can you offer some clues so that we can keep debugging? cc @stas00