Skip to content

Commit

Permalink
[FA-2] Fix fa-2 issue when passing config to from_pretrained (#…
Browse files Browse the repository at this point in the history
…28043)

* fix fa-2 issue

* fix test

* Update src/transformers/modeling_utils.py

Co-authored-by: fxmarty <9808326+fxmarty@users.noreply.github.com>

* clenaer fix

* up

* add more robust tests

* Update src/transformers/modeling_utils.py

Co-authored-by: fxmarty <9808326+fxmarty@users.noreply.github.com>

* fixup

* Update src/transformers/modeling_utils.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* pop

* add test

---------

Co-authored-by: fxmarty <9808326+fxmarty@users.noreply.github.com>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
  • Loading branch information
3 people committed Dec 18, 2023
1 parent f33b061 commit 6e4429f
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 0 deletions.
12 changes: 12 additions & 0 deletions src/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2955,6 +2955,18 @@ def from_pretrained(
**kwargs,
)
else:
# In case one passes a config to `from_pretrained` + "attn_implementation"
# override the `_attn_implementation` attribute to `attn_implementation` of the kwargs
# Please see: https://github.com/huggingface/transformers/issues/28038

# Overwrite `config._attn_implementation` by the one from the kwargs --> in auto-factory
# we pop attn_implementation from the kwargs but this handles the case where users
# passes manually the config to `from_pretrained`.
config = copy.deepcopy(config)

kwarg_attn_imp = kwargs.pop("attn_implementation", None)
if kwarg_attn_imp is not None and config._attn_implementation != kwarg_attn_imp:
config._attn_implementation = kwarg_attn_imp
model_kwargs = kwargs

quantizer = None
Expand Down
25 changes: 25 additions & 0 deletions tests/test_modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1823,6 +1823,16 @@ def test_error_no_flash_available(self):

self.assertTrue("does not support Flash Attention 2.0" in str(cm.exception))

def test_error_no_flash_available_with_config(self):
with self.assertRaises(ValueError) as cm:
config = AutoConfig.from_pretrained("hf-tiny-model-private/tiny-random-MCTCTModel")

_ = AutoModel.from_pretrained(
"hf-tiny-model-private/tiny-random-MCTCTModel", config=config, attn_implementation="flash_attention_2"
)

self.assertTrue("does not support Flash Attention 2.0" in str(cm.exception))

def test_error_wrong_attn_implementation(self):
with self.assertRaises(ValueError) as cm:
_ = AutoModel.from_pretrained("hf-tiny-model-private/tiny-random-MCTCTModel", attn_implementation="foo")
Expand All @@ -1840,6 +1850,21 @@ def test_not_available_flash(self):

self.assertTrue("the package flash_attn seems to be not installed" in str(cm.exception))

def test_not_available_flash_with_config(self):
if is_flash_attn_2_available():
self.skipTest("Please uninstall flash-attn package to run test_not_available_flash")

config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-GPTBigCodeModel")

with self.assertRaises(ImportError) as cm:
_ = AutoModel.from_pretrained(
"hf-internal-testing/tiny-random-GPTBigCodeModel",
config=config,
attn_implementation="flash_attention_2",
)

self.assertTrue("the package flash_attn seems to be not installed" in str(cm.exception))

def test_not_available_sdpa(self):
if is_torch_sdpa_available():
self.skipTest("This test requires torch<=2.0")
Expand Down

0 comments on commit 6e4429f

Please sign in to comment.