
Conversation

bminixhofer
Contributor

@bminixhofer bminixhofer commented Mar 9, 2024

What does this PR do?

Previously, the Flax Llama implementation did not use num_key_value_heads, so it only worked when num_key_value_heads was equal to num_attention_heads.
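
For context, the mismatch between the two head counts is visible directly in the checkpoint's config. A quick check (illustrative only, not part of the PR; it needs network access to the Hub):

from transformers import AutoConfig

# the config exposes both head counts; for TinyLlama they differ,
# which is exactly the case the old Flax code could not handle
config = AutoConfig.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
print(config.num_attention_heads, config.num_key_value_heads)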

As a result, models such as TinyLlama fail to load:

In [1]: from transformers import FlaxAutoModelForCausalLM

In [2]: FlaxAutoModelForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0", from_pt=True)
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[2], line 1
----> 1 FlaxAutoModelForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0", from_pt=True)

File ~/.local/lib/python3.8/site-packages/transformers/models/auto/auto_factory.py:561, in _BaseAutoModelClass.from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs)
    559 elif type(config) in cls._model_mapping.keys():
    560     model_class = _get_model_class(config, cls._model_mapping)
--> 561     return model_class.from_pretrained(
    562         pretrained_model_name_or_path, *model_args, config=config, **hub_kwargs, **kwargs
    563     )
    564 raise ValueError(
    565     f"Unrecognized configuration class {config.__class__} for this kind of AutoModel: {cls.__name__}.\n"
    566     f"Model type should be one of {', '.join(c.__name__ for c in cls._model_mapping.keys())}."
    567 )

File ~/.local/lib/python3.8/site-packages/transformers/modeling_flax_utils.py:903, in FlaxPreTrainedModel.from_pretrained(cls, pretrained_model_name_or_path, dtype, config, cache_dir, ignore_mismatched_sizes, force_download, local_files_only, token, revision, *model_args, **kwargs)
    900 model = cls(config, *model_args, _do_init=_do_init, **model_kwargs)
    902 if from_pt or safetensors_from_pt:
--> 903     state = load_pytorch_checkpoint_in_flax_state_dict(model, resolved_archive_file, is_sharded)
    904 else:
    905     if is_sharded:

File ~/.local/lib/python3.8/site-packages/transformers/modeling_flax_pytorch_utils.py:81, in load_pytorch_checkpoint_in_flax_state_dict(flax_model, pytorch_checkpoint_path, is_sharded, allow_missing_keys)
     78         pt_state_dict = torch.load(pt_path, map_location="cpu", **weights_only_kwarg)
     79         logger.info(f"PyTorch checkpoint contains {sum(t.numel() for t in pt_state_dict.values()):,} parameters.")
---> 81     flax_state_dict = convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model)
     82 else:
     83     # model is sharded and pytorch_checkpoint_path already contains the list of .pt shard files
     84     flax_state_dict = convert_pytorch_sharded_state_dict_to_flax(pytorch_checkpoint_path, flax_model)

File ~/.local/lib/python3.8/site-packages/transformers/modeling_flax_pytorch_utils.py:214, in convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model)
    212 if flax_key in random_flax_state_dict:
    213     if flax_tensor.shape != random_flax_state_dict[flax_key].shape:
--> 214         raise ValueError(
    215             f"PyTorch checkpoint seems to be incorrect. Weight {pt_key} was expected to be of shape "
    216             f"{random_flax_state_dict[flax_key].shape}, but is {flax_tensor.shape}."
    217         )
    219 # add batch stats if the model contains batchnorm layers
    220 if "batch_stats" in flax_model.params:

ValueError: PyTorch checkpoint seems to be incorrect. Weight model.layers.0.self_attn.k_proj.weight was expected to be of shape (2048, 2048), but is (2048, 256).

This PR fixes the issue by adding support for a distinct num_key_value_heads. I used the implementation from the Flax Mistral model and adjusted the variable names; the core idea is sketched after the snippet below. With this PR, the following runs through:

import jax
import torch
from transformers import FlaxAutoModelForCausalLM, AutoModelForCausalLM

model_flax = FlaxAutoModelForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0", from_pt=True)
model_pt = AutoModelForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")

# greedy next-token predictions from the converted Flax model must match the PyTorch reference
assert (model_pt(torch.arange(5)[None]).logits.argmax(-1).numpy() == model_flax(torch.arange(5)[None]).logits.argmax(-1)).all()
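
For reviewers unfamiliar with grouped-query attention: the core of the ported logic is to project keys and values with num_key_value_heads heads and then repeat them to match the number of query heads before computing attention weights. A minimal sketch of that idea (the helper name and the 32/4/64 head shapes are illustrative assumptions, not the exact code in this PR):

import jax.numpy as jnp

def repeat_kv(hidden_states, n_rep):
    # (batch, seq_len, num_key_value_heads, head_dim)
    # -> (batch, seq_len, num_key_value_heads * n_rep, head_dim)
    if n_rep == 1:
        return hidden_states
    return jnp.repeat(hidden_states, n_rep, axis=2)

# e.g. 32 query heads sharing 4 key/value heads of dim 64 (assumed TinyLlama-like shapes)
key = jnp.zeros((1, 5, 4, 64))
assert repeat_kv(key, 32 // 4).shape == (1, 5, 32, 64)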

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline, Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@sanchit-gandhi @vvvm23

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@bminixhofer
Contributor Author

@sanchit-gandhi @vvvm23 @ArthurZucker gentle bump 🤗

@ArthurZucker ArthurZucker left a comment
Collaborator

Sorry I did not get the initial ping!
LGTM, this was actually ported to Gemma first, but good call!
We might be able to use Copied from Llama for Gemma now, no? 🤗

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@bminixhofer
Contributor Author

That's on me, I didn't ping you initially.

We might be able to use Copied from Llama for Gemma now, no?

Sorry, I don't know what you mean here 😅

@ArthurZucker
Collaborator

I mean that now that both attention implementations use GQA, we might be able to squeeze in Copied from, but that can be a follow-up! Let's merge 🔥
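
(For readers unfamiliar with the convention: "Copied from" refers to the comment marker that transformers' consistency checks use to keep one model's code in sync with another's. A hypothetical illustration of what such a marker might look like for the Flax Gemma attention module, assuming the usual Llama->Gemma renaming:)

import flax.linen as nn

# Copied from transformers.models.llama.modeling_flax_llama.FlaxLlamaAttention with Llama->Gemma
class FlaxGemmaAttention(nn.Module):
    ...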

@ArthurZucker ArthurZucker merged commit 8e08aca into huggingface:main Mar 27, 2024