Correctly initialize the text model (Mistral) of Idefics2 with Flash Attention #30395
Conversation
Hi @zafstojano, thanks for opening this PR and addressing this issue! At the moment the diff and commit history contain lots of changes which are unrelated to this PR and which should be resolved before merge. It looks like what happens after rebasing and pushing without force pushing. If this is the case, simply force pushing should resolve it.
Force-pushed from bb9c2b4 to 669f7b1
@amyeroberts I have now force-pushed only my changes 👍

(curious) how to push without force after rebasing ... 👀 ?

I've done it before but can't remember exactly the steps I took to achieve it! I think it rejects the push; you can then pull and push again.
Thanks for adding this fix!
We should add tests to make sure the attention implementation and torch dtype properly get set from the configs.
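A hedged sketch of what such a test could look like (the checkpoint and attribute paths mirror the snippets later in this thread; the test name and assertions are illustrative, not the actual tests added in this PR):

```python
# Illustrative only: checks that attn_implementation and torch_dtype passed at load
# time reach the wrapped text model. Loading the real 8B checkpoint needs a GPU with
# flash-attn installed; an actual test would likely use a tiny random config instead.
import torch
from transformers import Idefics2ForConditionalGeneration

def test_text_model_attn_implementation_and_dtype():
    model = Idefics2ForConditionalGeneration.from_pretrained(
        "HuggingFaceM4/idefics2-8b",
        torch_dtype=torch.bfloat16,
        attn_implementation="flash_attention_2",
    )
    assert model.model.text_model.config._attn_implementation == "flash_attention_2"
    assert model.model.text_model.dtype == torch.bfloat16
```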
```diff
@@ -1473,16 +1473,23 @@ def __init__(self, config: Idefics2Config):
         super().__init__(config)
         self.padding_idx = self.config.text_config.pad_token_id
         self.vocab_size = self.config.text_config.vocab_size
         self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
```
If we directly pass the config's attention implementation we don't need this:

```python
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
```
```python
text_model_kwargs = {}
if self._use_flash_attention_2:
    text_model_kwargs["use_flash_attention_2"] = True
torch_dtype = None
if config.text_config.torch_dtype is not None:
    torch_dtype = config.text_config.torch_dtype
elif config.torch_dtype is not None:
    torch_dtype = config.torch_dtype
text_model_kwargs["torch_dtype"] = torch_dtype
self.text_model = AutoModel.from_config(config.text_config, **text_model_kwargs)
```
- Let's use the same pattern as in llava-next. This is more robust to future attention implementations.
- It's a bit funny to have two possible options for `torch_dtype` here, with the one from the text_config taking precedence. If a user specified `model = IdeficsForConditionalGeneration(checkpoint, torch_dtype=torch.float16)` I'd expect `torch.float16` to be used.
Suggested change:

```python
torch_dtype = config.text_config.torch_dtype
if config.torch_dtype is not None:
    torch_dtype = config.torch_dtype
self.text_model = AutoModel.from_config(
    config.text_config,
    attn_implementation=config._attn_implementation,
    torch_dtype=torch_dtype,
)
```
@amyeroberts thank you for the constructive feedback. I am currently experiencing some weird behavior when I integrate those changes; perhaps I am not 100% familiar with the internals of the library.

For the following implementation of the init method in `Idefics2Model`:

```python
class Idefics2Model(Idefics2PreTrainedModel):
    def __init__(self, config: Idefics2Config):
        super().__init__(config)
        self.padding_idx = self.config.text_config.pad_token_id
        self.vocab_size = self.config.text_config.vocab_size

        self.vision_model = Idefics2VisionTransformer(config.vision_config)
        self.connector = Idefics2Connector(config)

        torch_dtype = config.text_config.torch_dtype
        if config.torch_dtype is not None:
            torch_dtype = config.torch_dtype
        attn_implementation = config.text_config._attn_implementation
        if config._attn_implementation is not None:
            attn_implementation = config._attn_implementation

        print("=================")
        print("torch_dtype being passed to text_model in Idefics2Model.__init__():", torch_dtype)
        print("=================")

        self.text_model = AutoModel.from_config(
            config.text_config,
            attn_implementation=attn_implementation,
            torch_dtype=torch_dtype,
        )

        self.image_seq_len = config.perceiver_config.resampler_n_latents
        self.image_token_id = self.config.image_token_id

        self.post_init()
```

and the following code sample:

```python
import torch
from transformers import Idefics2ForConditionalGeneration

model = Idefics2ForConditionalGeneration.from_pretrained(
    "HuggingFaceM4/idefics2-8b",
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
)

print("Perceiver model flash attention: ", model.model.connector.perceiver_resampler._use_flash_attention_2)
print("Vision model flash attention: ", model.model.vision_model._use_flash_attention_2)
print("Text model flash attention: ", model.model.text_model._attn_implementation == "flash_attention_2")
print('-----------------')
print("Model dtype: ", model.dtype)
```

I get the output:

So, flash attention is correctly propagated to all submodules when the user specifies `attn_implementation="flash_attention_2"`, but the reported `torch_dtype` is not what I expected. Do you have any idea why this is happening?

Unrelated to the above issue, I have another suggestion: since the vision model is initialized from ...
Hi @zafstojano, thanks for sharing this script! OK, so the behaviour of torch_dtype is quite complex and not the area of the code I'm most familiar with. In terms of what's happening in the script, I think:

cc @younesbelkada to confirm if this is right and if there's anything else to be aware of. To understand more: if composite models like this and llava have their language model saved in e.g. float16 and their vision tower in float32, what will happen when we use torch_dtype="auto"?
Yes! Good idea
Thanks!

To answer @amyeroberts's question, I think it is fine to not pass any `torch_dtype`, since `cls._set_default_torch_dtype()` is called here:

transformers/src/transformers/modeling_utils.py, line 3518 in a98c417:

```python
dtype_orig = cls._set_default_torch_dtype(torch_dtype)
```

with the target `dtype` (i.e. either the one passed via `torch_dtype`, or the one from the config if one passes `"auto"`). Once the model is initialized, the original dtype is set again here.
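To illustrate the mechanism described above (a minimal sketch in plain PyTorch, not the actual transformers internals): while the default dtype is switched, freshly created parameters pick it up, and the original default is restored afterwards.

```python
import torch

original = torch.get_default_dtype()    # usually torch.float32
torch.set_default_dtype(torch.float16)  # roughly what _set_default_torch_dtype does internally
layer = torch.nn.Linear(4, 4)           # parameters are created in float16
print(layer.weight.dtype)               # torch.float16
torch.set_default_dtype(original)       # the original default is restored once init is done
```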
@zafstojano, note that `torch_dtype` always dictates the dtype of the whole model; even if idefics2 is in fact a combination of models, it should be seen as a standalone model. If one wants to try out complex combinations such as loading the vision part in fp32 and the text model in fp16, they should first load the entire model in fp16 and upcast the vision part in fp32 (or the other way around).

See also the solution in llava: we've been doing this with that architecture and everything looks fine so far. I think the fix I propose here should be sufficient, can you double check that?
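For reference, a minimal sketch of the pattern described above, assuming one wants an fp16 text model with an fp32 vision tower (the module paths follow the earlier snippets in this thread):

```python
import torch
from transformers import Idefics2ForConditionalGeneration

# Load the entire model in float16 first ...
model = Idefics2ForConditionalGeneration.from_pretrained(
    "HuggingFaceM4/idefics2-8b",
    torch_dtype=torch.float16,
)
# ... then upcast just the vision tower back to float32.
model.model.vision_model.to(torch.float32)
```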
```python
torch_dtype = config.text_config.torch_dtype
if config.torch_dtype is not None:
    torch_dtype = config.torch_dtype
attn_implementation = config.text_config._attn_implementation
if config._attn_implementation is not None:
    attn_implementation = config._attn_implementation
self.text_model = AutoModel.from_config(
    config.text_config,
    attn_implementation=attn_implementation,
    torch_dtype=torch_dtype,
)
```
Suggested change:

```python
self.text_model = AutoModel.from_config(
    config.text_config, attn_implementation=config._attn_implementation
)
```
Hi @amyeroberts @younesbelkada

With the implementation you proposed, for the following sample code:

```python
import torch
from transformers import Idefics2ForConditionalGeneration

model = Idefics2ForConditionalGeneration.from_pretrained(
    "HuggingFaceM4/idefics2-8b",
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
)

print("Perceiver model flash attention: ", model.model.connector.perceiver_resampler._use_flash_attention_2)
print("Vision model flash attention: ", model.model.vision_model._use_flash_attention_2)
print("Text model flash attention: ", model.model.text_model._attn_implementation == "flash_attention_2")
print('-----------------')
print("Model dtype: ", model.dtype)
```

I get the following output:

The reason why I wanted to explicitly pass `torch_dtype` ... Is this acceptable?
Moreover, when using the vision tower with Flash Attention, I get this exception:
The above error can be fixed by casting the image hidden states to the same dtype as the input tokens going into the Mistral model:

```python
# Get sequence from the vision encoder
image_hidden_states = self.vision_model(
    pixel_values=pixel_values,
    patch_attention_mask=patch_attention_mask,
).last_hidden_state.to(dtype=self.dtype, device=input_ids.device)

# Modality projection & resampling
image_hidden_states = self.connector(
    image_hidden_states, attention_mask=patch_attention_mask.view(pixel_values.size(0), -1)
).to(dtype=self.dtype, device=input_ids.device)
```

Although, I still get the warning about upcasting:

Weirdly, this happens even if I cast all params to the same dtype (e.g. ...)
@zafstojano I see. So this is an issue, and a tricky one at that. @younesbelkada it doesn't seem to be the case that passing ...
hmm interesting, ok, I will have a deeper look then!
@younesbelkada @zafstojano Just to follow up on the dtype investigation, I suspect there might be a difference between the torch_dtype being passed to the model inits during instantiation, and the torch dtype used when the pretrained weights are loaded in. I just ran a quick test, and the weights do seem to be loaded in as expected:

```python
import torch
from transformers import Idefics2ForConditionalGeneration

print("Loading in as torch.bfloat16")
model = Idefics2ForConditionalGeneration.from_pretrained(
    "HuggingFaceM4/idefics2-8b",
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
)
model.dtype
print(model.model.text_model.dtype)
print(model.model.vision_model.embeddings.position_embedding.weight.dtype)
print(model.model.connector.perceiver_resampler.latents.dtype)

print("\nLoading in as torch.float32")
model = Idefics2ForConditionalGeneration.from_pretrained(
    "HuggingFaceM4/idefics2-8b",
    attn_implementation="flash_attention_2",
)
model.dtype
print(model.model.text_model.dtype)
print(model.model.vision_model.embeddings.position_embedding.weight.dtype)
print(model.model.connector.perceiver_resampler.latents.dtype)
```

Produces output:
Thanks for investigating @amyeroberts, and apologies for not investigating myself! Seems all is good then? 🙏
@younesbelkada Yep! I think so

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.
This PR attempts to resolve the issue of the text model not being loaded with Flash Attention 2
Relevant issue: #30394
Currently, whatever combination of parameters I pass when instantiating the Idefics2 models, the text model is not loaded with Flash Attention 2. Here are several examples:

- Passing `_attn_implementation` to `from_pretrained`. Output:
- Passing `attn_implementation` to `from_pretrained`. Output:
- Passing a `config` object with the property `_attn_implementation` set. Output:
- Passing both `config` and `attn_implementation` to `from_pretrained`. Output:
This PR contains a simple patch which would allow the text model to be loaded with Flash Attention. Here is the output with the changes included:
Output
It is not an ideal fix, since it requires both passing a `config` object and an `attn_implementation` parameter. Moreover, it relies on the `use_flash_attention_2` parameter, which might be deprecated soon.

Criticism, feedback and requests for changes are welcome.
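For illustration, a sketch of how the workaround described above would be invoked; the exact call pattern is inferred from the description (both a config object and an attn_implementation argument are passed), so treat it as an assumption rather than the definitive usage:

```python
import torch
from transformers import Idefics2Config, Idefics2ForConditionalGeneration

# Assumed usage: the config carries the attention implementation and the same
# value is also passed as attn_implementation to from_pretrained.
config = Idefics2Config.from_pretrained("HuggingFaceM4/idefics2-8b")
config._attn_implementation = "flash_attention_2"

model = Idefics2ForConditionalGeneration.from_pretrained(
    "HuggingFaceM4/idefics2-8b",
    config=config,
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,
)
```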
Ping: @amyeroberts