
Llama 2 model divergence with FSDP #28826

Closed
2 of 4 tasks
Teng-xu opened this issue Feb 2, 2024 · 7 comments

Comments

Teng-xu commented Feb 2, 2024

System Info

  • transformers version: 4.37.1
  • Platform: Linux-5.10.199-190.747.amzn2.x86_64-x86_64-with-glibc2.31
  • Python version: 3.10.8
  • Huggingface_hub version: 0.20.2
  • Safetensors version: 0.3.3
  • Accelerate version: 0.26.1
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.1.2 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using GPU in script?: yes
  • Using distributed or parallel set-up in script?: yes

Who can help?

When fine-tuning a Llama 2 model with HF transformers 4.37 and PyTorch FSDP, we observed model divergence compared to HF 4.31. Fine-tuning with 4.31 works fine, but with 4.37 the loss consistently rises instead of stabilizing when attn_implementation="flash_attention_2" is set, while attn_implementation="sdpa" works fine.

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

The model is initialized as:
model = AutoModelForCausalLM.from_pretrained(pretrained_model_weights, attn_implementation="flash_attention_2")
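
For reference, Flash Attention 2 only supports fp16/bf16, so the checkpoint is typically loaded with an explicit dtype. A minimal sketch of loading both backends in bf16 (the torch_dtype argument is an addition for illustration, not part of the original reproduction; pretrained_model_weights is a placeholder for the Llama 2 checkpoint path):

import torch
from transformers import AutoModelForCausalLM

# Sketch only: load the same checkpoint with both attention backends in bf16.
model_fa2 = AutoModelForCausalLM.from_pretrained(
    pretrained_model_weights,                 # placeholder checkpoint path
    torch_dtype=torch.bfloat16,               # FA2 supports only fp16/bf16
    attn_implementation="flash_attention_2",
)
model_sdpa = AutoModelForCausalLM.from_pretrained(
    pretrained_model_weights,
    torch_dtype=torch.bfloat16,
    attn_implementation="sdpa",
)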

Expected behavior

The loss should not go up as training progresses.

ArthurZucker (Collaborator) commented Feb 2, 2024

cc @younesbelkada I think we have seen something similar recently?

younesbelkada (Contributor) commented

@Teng-xu are you correctly enabling mixed precision through bf16=True in TrainingArguments?

Teng-xu (Author) commented Feb 2, 2024

Yeah bf16 was passed into the training args, and I can verify it is being applied correctly.
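
For context, a minimal sketch of what enabling bf16 together with FSDP through the Trainer typically looks like (the output_dir and fsdp_config values are illustrative placeholders, not the reporter's actual configuration):

from transformers import TrainingArguments

# Sketch only: bf16 mixed precision plus PyTorch FSDP via the HF Trainer.
training_args = TrainingArguments(
    output_dir="llama2-fsdp-out",        # placeholder output directory
    bf16=True,                           # bfloat16 mixed precision
    fsdp="full_shard auto_wrap",         # shard params, grads, and optimizer state
    fsdp_config={"transformer_layer_cls_to_wrap": ["LlamaDecoderLayer"]},
)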

rnadimp commented Feb 13, 2024

To provide more context on this issue, I am attaching a simple script that reproduces it, along with its output. Note that I am just using a random tensor as the dataset, and for consistency I saved the labels from another training script and load them from a pickle object.

Script:

import pickle

import torch
from transformers import AutoModelForCausalLM

# model init: attn_implementation can be "flash_attention_2", "sdpa", or "eager"
# pretrained_model_weights should point to the Llama 2 checkpoint (set before running)
model1 = AutoModelForCausalLM.from_pretrained(pretrained_model_weights, attn_implementation="flash_attention_2")
model2 = AutoModelForCausalLM.from_pretrained(pretrained_model_weights, attn_implementation="sdpa")

model1.model.layers = model1.model.layers[:4]
model2.model.layers = model2.model.layers[:4]

model1 = model1.type(torch.bfloat16)
model2 = model2.type(torch.bfloat16)

model1 = model1.to("cuda")
model2 = model2.to("cuda")


# create a dummy input tensor and load the pre-saved labels
tensor = torch.randint(low=0, high=9, size=(1, 4096), dtype=torch.int32).to("cuda")
labels = pickle.load(open("labels.p", "rb")).to("cuda")

# model fwd/bwd pass
out1 = model1(input_ids=tensor, attention_mask=None, labels=labels)
loss1 = out1["loss"]
logits1 = out1["logits"]

out2 = model2(input_ids=tensor, attention_mask=None, labels=labels)
loss2 = out2["loss"]
logits2 = out2["logits"]

# model output cmp
if torch.allclose(logits1, logits2, atol=1e-0):
    print("logits equal~~~~~~~~~")
else:
    print("logits not equal~~~~~~~~~~")

print("logits 1:")
print(logits1)

print("logits 2:")
print(logits2)

print("max diff between logits:")
print(torch.max(torch.abs(logits1 - logits2)))

loss1.backward()
loss2.backward()

print("loss 1:")
print(loss1)

print("loss 2:")
print(loss2)

if torch.allclose(loss1, loss2):
    print("loss equal~~~~~~~~~")
else:
    print("loss not equal~~~~~~~~~~")

Output of script:

You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU after initializing it on CPU with `model.to('cuda')`.                                                                        
Flash Attention 2.0 only supports torch.float16 and torch.bfloat16 dtypes. No dtype was provided, you should run training or inference using Automatic Mixed-Precision via the `with torch.autocast(device_type='torch_device'):` decorator.       
Flash Attention 2.0 only supports torch.float16 and torch.bfloat16 dtypes. No dtype was provided, you should run training or inference using Automatic Mixed-Precision via the `with torch.autocast(device_type='torch_device'):` decorator.       
Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:15<00:00,  7.91s/it]
Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:24<00:00, 12.15s/it]
logits equal~~~~~~~~~                                                                                                                                                                                                                              
logits 1:                                                                                                                                                                                                                                          
tensor([[[-1.3047, -2.2812,  2.2500,  ..., -1.6094, -1.5078,  0.6914],                                                                                                                                                                             
         [-2.1094, -4.6875,  1.2031,  ..., -1.1484, -1.7109, -0.4336],                                                                                                                                                                             
         [-1.1719, -4.6562,  0.3516,  ...,  0.3301, -0.9727,  0.2852],                                                                                                                                                                             
         ...,                                                                                                                                                                                                                                      
         [-2.2188,  8.8125,  1.4219,  ..., -1.3906, -1.7266, -3.6250],                                                                                                                                                                             
         [-0.9844, 11.0625,  0.7617,  ..., -0.4609,  0.0225, -2.7188],                                                                                                                                                                             
         [-1.0234, 10.8750,  0.8125,  ..., -0.4395, -0.1641, -2.7656]]],                                                                                                                                                                           
       device='cuda:0', grad_fn=<ToCopyBackward0>)                                                                                                                                                                                                 
logits 2:                                                                                                                                                                                                                                          
tensor([[[-1.3047e+00, -2.2812e+00,  2.2500e+00,  ..., -1.6094e+00,                                                                                                                                                                                
          -1.5078e+00,  6.9141e-01],                                                                                                                                                                                                               
         [-2.1094e+00, -4.6875e+00,  1.2031e+00,  ..., -1.1484e+00,                                                                                                                                                                                
          -1.7109e+00, -4.3359e-01],                                                                                                                                                                                                               
         [-1.1719e+00, -4.6562e+00,  3.5156e-01,  ...,  3.3008e-01,                                                                                                                                                                                
          -9.7266e-01,  2.8516e-01],                                                                                                                                                                                                               
         ...,                                                                                                                                                                                                                                      
         [-2.2188e+00,  8.8125e+00,  1.4297e+00,  ..., -1.3984e+00,                                                                                                                                                                                
          -1.7344e+00, -3.6562e+00],                                                                                                                                                                                                               
         [-9.8047e-01,  1.1062e+01,  7.5391e-01,  ..., -4.3945e-01,                                                                                                                                                                                
          -3.9673e-04, -2.7188e+00],                                                                                                                                                                                                               
         [-1.0391e+00,  1.0875e+01,  8.1641e-01,  ..., -4.4922e-01,                                                                                                                                                                                
          -1.7188e-01, -2.7812e+00]]], device='cuda:0',                                                                                                                                                                                            
       grad_fn=<ToCopyBackward0>)                                                                                                                                                                                                                  
max diff between logits:                                                                                                                                                                                                                           
tensor(0.2500, device='cuda:0', grad_fn=<MaxBackward1>)                                                                                                                                                                                            
loss 1:                                                                                                                                                                                                                                            
tensor(13.4215, device='cuda:0', grad_fn=<NllLossBackward0>)                                                                                                                                                                                       
loss 2:                                                                                                                                                                                                                                            
tensor(13.4206, device='cuda:0', grad_fn=<NllLossBackward0>)                                                                                                                                                                                       
loss not equal~~~~~~~~~~

goswamig commented
Tagging @pacman100 to take a look.

younesbelkada (Contributor) commented

Hi @rnadimp
Thanks for the snippet!
I am not surprised to see a relatively small difference between SDPA and FA2. The diff you shared is quite small and acceptable IMO. Note that even though FA2 computes the same attention as SDPA mathematically, in practice, because the kernels are different, there will always be a small numerical difference between the two implementations.
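
For what it's worth, a comparison that tolerates this kind of bf16 kernel noise could use torch.testing.assert_close with relaxed tolerances instead of exact equality; a small sketch (the atol/rtol values are illustrative assumptions, not documented thresholds):

import torch

# Illustrative tolerance-based check for the bf16 outputs of the two backends;
# the tolerances are assumptions sized to absorb kernel-level differences.
torch.testing.assert_close(logits1, logits2, atol=3e-1, rtol=3e-2)
torch.testing.assert_close(loss1, loss2, atol=1e-2, rtol=1e-3)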


This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.
