From 501f893a5a1edfc2ad69f626111732dca8ce4d4d Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Fri, 29 Aug 2025 12:14:45 +0200 Subject: [PATCH 01/11] clean --- src/transformers/modeling_utils.py | 107 ++++++++++++++++------------- 1 file changed, 60 insertions(+), 47 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 4a96e9e537e2..3e0b0b6f3d74 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -44,7 +44,7 @@ from torch.distributions import constraints from torch.utils.checkpoint import checkpoint -from transformers.utils import is_torchao_available +from .utils import is_torchao_available if is_torchao_available(): @@ -75,14 +75,8 @@ from .loss.loss_utils import LOSS_MAPPING from .masking_utils import ALL_MASK_ATTENTION_FUNCTIONS from .modeling_flash_attention_utils import lazy_import_flash_attention -from .pytorch_utils import ( # noqa: F401 - Conv1D, - apply_chunking_to_forward, - find_pruneable_heads_and_indices, +from .pytorch_utils import ( id_tensor_storage, - prune_conv1d_layer, - prune_layer, - prune_linear_layer, ) from .quantizers import HfQuantizer from .quantizers.auto import get_hf_quantizer @@ -2801,44 +2795,10 @@ def _check_and_adjust_attn_implementation( and is_kernels_available() ): applicable_attn_implementation = "kernels-community/flash-attn" - if applicable_attn_implementation is not None and re.match( - r"^[^/:]+/[^/:]+(?:@[^/:]+)?(?::[^/:]+)?$", applicable_attn_implementation - ): - if not is_kernels_available(): - raise ValueError("kernels is not installed. Please install it with `pip install kernels`.") - attention_wrapper = None - # FIXME: @ArthurZucker this is dirty, did not want to do a lof of extra work - actual_attn_name = applicable_attn_implementation - if "|" in applicable_attn_implementation: - attention_wrapper, actual_attn_name = applicable_attn_implementation.split("|") - # `transformers` has wrapper for sdpa, paged, flash, flex etc. - attention_wrapper = ALL_ATTENTION_FUNCTIONS.get(attention_wrapper) - # Extract repo_id and kernel_name from the string - if ":" in actual_attn_name: - repo_id, kernel_name = actual_attn_name.split(":") - kernel_name = kernel_name.strip() - else: - repo_id = actual_attn_name - kernel_name = None - repo_id = repo_id.strip() - # extract the rev after the @ if it exists - repo_id, _, rev = repo_id.partition("@") - repo_id = repo_id.strip() - rev = rev.strip() if rev else None + if is_kernel(applicable_attn_implementation): try: - kernel = get_kernel(repo_id, revision=rev) - if hasattr(kernel, "flash_attn_varlen_func"): - if attention_wrapper is None: - attention_wrapper = flash_attention_forward - kernel_function = partial(attention_wrapper, implementation=kernel) - lazy_import_flash_attention(kernel) - elif kernel_name is not None: - kernel_function = getattr(kernel, kernel_name) - ALL_ATTENTION_FUNCTIONS.register(applicable_attn_implementation, kernel_function) - ALL_MASK_ATTENTION_FUNCTIONS.register( - applicable_attn_implementation, ALL_MASK_ATTENTION_FUNCTIONS["flash_attention_2"] - ) - # log that we used kernel fallback + load_and_register_kernel(applicable_attn_implementation) + # log that we used kernel fallback if successful if attn_implementation == "flash_attention_2": logger.warning_once( "You do not have `flash_attn` installed, using `kernels-community/flash-attn` from the `kernels` " @@ -2848,8 +2808,8 @@ def _check_and_adjust_attn_implementation( if attn_implementation == "flash_attention_2": self._flash_attn_2_can_dispatch() # will fail as fa2 is not available but raise the proper exception logger.warning_once( - f"Could not find a kernel repository '{repo_id}' compatible with your device in the hub: {e}. Using " - "default attention implementation instead (sdpa if available, eager otherwise)." + f"Could not find a kernel matching `{applicable_attn_implementation}` compatible with your device in the " + f"hub:\n{e}.\nUsing default attention implementation instead (sdpa if available, eager otherwise)." ) try: self._sdpa_can_dispatch(is_init_check) @@ -6297,6 +6257,59 @@ def get_disk_only_shard_files(device_map, weight_map): return [fname for fname, devices in files_content.items() if set(devices) == {"disk"}] +def is_kernel(attn_implementation: Optional[str]) -> bool: + """Check whether `attn_implementation` matches a kernel pattern from the hub.""" + return ( + attn_implementation is not None + and re.search(r"^[^/:]+/[^/:]+(?:@[^/:]+)?(?::[^/:]+)?$", attn_implementation) is not None + ) + + +def load_and_register_kernel(attn_implementation: str) -> None: + """Load and register the kernel associated to `attn_implementation`.""" + if not is_kernel(attn_implementation): + return + if not is_kernels_available(): + raise ImportError("`kernels` is not installed. Please install it with `pip install kernels`.") + + attention_wrapper = None + # FIXME: @ArthurZucker this is dirty, did not want to do a lof of extra work + actual_attn_name = attn_implementation + if "|" in attn_implementation: + attention_wrapper, actual_attn_name = attn_implementation.split("|") + # `transformers` has wrapper for sdpa, paged, flash, flex etc. + attention_wrapper = ALL_ATTENTION_FUNCTIONS.get(attention_wrapper) + # Extract repo_id and kernel_name from the string + if ":" in actual_attn_name: + repo_id, kernel_name = actual_attn_name.split(":") + kernel_name = kernel_name.strip() + else: + repo_id = actual_attn_name + kernel_name = None + repo_id = repo_id.strip() + # extract the rev after the @ if it exists + repo_id, _, rev = repo_id.partition("@") + repo_id = repo_id.strip() + rev = rev.strip() if rev else None + + # Load the kernel from hub + try: + kernel = get_kernel(repo_id, revision=rev) + except Exception as e: + raise ValueError(f"An error occured while trying to load from '{repo_id}': {e}.") + # correctly wrap the kernel + if hasattr(kernel, "flash_attn_varlen_func"): + if attention_wrapper is None: + attention_wrapper = flash_attention_forward + kernel_function = partial(attention_wrapper, implementation=kernel) + lazy_import_flash_attention(kernel) + elif kernel_name is not None: + kernel_function = getattr(kernel, kernel_name) + # Register the kernel as a valid attention + ALL_ATTENTION_FUNCTIONS.register(attn_implementation, kernel_function) + ALL_MASK_ATTENTION_FUNCTIONS.register(attn_implementation, ALL_MASK_ATTENTION_FUNCTIONS["flash_attention_2"]) + + class AttentionInterface(GeneralInterface): """ Dict-like object keeping track of allowed attention functions. You can easily add a new attention function From fd287ae7d10945e2c7d911cec31449164dd010f6 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Fri, 29 Aug 2025 12:25:14 +0200 Subject: [PATCH 02/11] clean imporrts --- src/transformers/modeling_utils.py | 52 ++++++++++++------------------ 1 file changed, 21 insertions(+), 31 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 3e0b0b6f3d74..94b69f36b229 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -44,12 +44,6 @@ from torch.distributions import constraints from torch.utils.checkpoint import checkpoint -from .utils import is_torchao_available - - -if is_torchao_available(): - from torchao.quantization import Int4WeightOnlyConfig - from .configuration_utils import PretrainedConfig from .distributed import DistributedConfig from .dynamic_module_utils import custom_object_save @@ -75,9 +69,7 @@ from .loss.loss_utils import LOSS_MAPPING from .masking_utils import ALL_MASK_ATTENTION_FUNCTIONS from .modeling_flash_attention_utils import lazy_import_flash_attention -from .pytorch_utils import ( - id_tensor_storage, -) +from .pytorch_utils import id_tensor_storage from .quantizers import HfQuantizer from .quantizers.auto import get_hf_quantizer from .quantizers.quantizers_utils import get_module_from_name @@ -118,6 +110,7 @@ is_torch_npu_available, is_torch_xla_available, is_torch_xpu_available, + is_torchao_available, logging, ) from .utils.generic import _CAN_RECORD_REGISTRY, GeneralInterface, OutputRecorder @@ -132,9 +125,8 @@ from .utils.quantization_config import BitsAndBytesConfig, QuantizationMethod -XLA_USE_BF16 = os.environ.get("XLA_USE_BF16", "0").upper() -XLA_DOWNCAST_BF16 = os.environ.get("XLA_DOWNCAST_BF16", "0").upper() - +if is_torchao_available(): + from torchao.quantization import Int4WeightOnlyConfig if is_accelerate_available(): from accelerate import dispatch_model, infer_auto_device_map @@ -158,32 +150,17 @@ from safetensors.torch import load_file as safe_load_file from safetensors.torch import save_file as safe_save_file +if is_peft_available(): + from .utils import find_adapter_config_file if is_kernels_available(): from kernels import get_kernel - -logger = logging.get_logger(__name__) - - -_init_weights = True -_is_quantized = False -_is_ds_init_called = False _torch_distributed_available = torch.distributed.is_available() - _is_dtensor_available = _torch_distributed_available and is_torch_greater_or_equal("2.5") if _is_dtensor_available: from torch.distributed.tensor import DTensor - -def is_local_dist_rank_0(): - return ( - torch.distributed.is_available() - and torch.distributed.is_initialized() - and int(os.environ.get("LOCAL_RANK", "-1")) == 0 - ) - - if is_sagemaker_mp_enabled(): import smdistributed.modelparallel.torch as smp from smdistributed.modelparallel import __version__ as SMP_VERSION @@ -192,11 +169,24 @@ def is_local_dist_rank_0(): else: IS_SAGEMAKER_MP_POST_1_10 = False -if is_peft_available(): - from .utils import find_adapter_config_file +logger = logging.get_logger(__name__) +XLA_USE_BF16 = os.environ.get("XLA_USE_BF16", "0").upper() +XLA_DOWNCAST_BF16 = os.environ.get("XLA_DOWNCAST_BF16", "0").upper() SpecificPreTrainedModelType = TypeVar("SpecificPreTrainedModelType", bound="PreTrainedModel") +_init_weights = True +_is_quantized = False +_is_ds_init_called = False + + +def is_local_dist_rank_0(): + return ( + torch.distributed.is_available() + and torch.distributed.is_initialized() + and int(os.environ.get("LOCAL_RANK", "-1")) == 0 + ) + TORCH_INIT_FUNCTIONS = { "uniform_": nn.init.uniform_, From 8317e5d9765cc807beb12ea4a356d824ccf68970 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Fri, 29 Aug 2025 12:34:02 +0200 Subject: [PATCH 03/11] fix imports --- src/transformers/models/blip/modeling_blip_text.py | 8 ++------ .../models/deprecated/mctct/modeling_mctct.py | 8 ++------ src/transformers/models/esm/modeling_esm.py | 8 ++------ src/transformers/models/evolla/modeling_evolla.py | 10 ++-------- src/transformers/models/markuplm/modeling_markuplm.py | 9 ++------- 5 files changed, 10 insertions(+), 33 deletions(-) diff --git a/src/transformers/models/blip/modeling_blip_text.py b/src/transformers/models/blip/modeling_blip_text.py index 3eff5ac785e2..2eaf655d01be 100644 --- a/src/transformers/models/blip/modeling_blip_text.py +++ b/src/transformers/models/blip/modeling_blip_text.py @@ -31,12 +31,8 @@ BaseModelOutputWithPoolingAndCrossAttentions, CausalLMOutputWithCrossAttentions, ) -from ...modeling_utils import ( - PreTrainedModel, - apply_chunking_to_forward, - find_pruneable_heads_and_indices, - prune_linear_layer, -) +from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import logging from ...utils.deprecation import deprecate_kwarg from .configuration_blip import BlipTextConfig diff --git a/src/transformers/models/deprecated/mctct/modeling_mctct.py b/src/transformers/models/deprecated/mctct/modeling_mctct.py index a6686c1eb2a1..6b4d73045159 100755 --- a/src/transformers/models/deprecated/mctct/modeling_mctct.py +++ b/src/transformers/models/deprecated/mctct/modeling_mctct.py @@ -28,13 +28,9 @@ from ....modeling_attn_mask_utils import _prepare_4d_attention_mask from ....modeling_layers import GradientCheckpointingLayer from ....modeling_outputs import BaseModelOutput, CausalLMOutput -from ....modeling_utils import ( - PreTrainedModel, - apply_chunking_to_forward, - find_pruneable_heads_and_indices, - prune_linear_layer, -) from ....utils import logging +from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from .configuration_mctct import MCTCTConfig diff --git a/src/transformers/models/esm/modeling_esm.py b/src/transformers/models/esm/modeling_esm.py index aaea1614bb75..5db366aa6197 100755 --- a/src/transformers/models/esm/modeling_esm.py +++ b/src/transformers/models/esm/modeling_esm.py @@ -31,13 +31,9 @@ SequenceClassifierOutput, TokenClassifierOutput, ) -from ...modeling_utils import ( - ALL_ATTENTION_FUNCTIONS, - PreTrainedModel, - find_pruneable_heads_and_indices, - prune_linear_layer, -) +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack +from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging from ...utils.generic import OutputRecorder, check_model_inputs from .configuration_esm import EsmConfig diff --git a/src/transformers/models/evolla/modeling_evolla.py b/src/transformers/models/evolla/modeling_evolla.py index e74fadc049b1..f0677d5f600d 100644 --- a/src/transformers/models/evolla/modeling_evolla.py +++ b/src/transformers/models/evolla/modeling_evolla.py @@ -41,15 +41,9 @@ ModelOutput, ) from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update -from ...modeling_utils import ( - ALL_ATTENTION_FUNCTIONS, - ModuleUtilsMixin, - PreTrainedModel, - find_pruneable_heads_and_indices, - get_parameter_dtype, - prune_linear_layer, -) +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, ModuleUtilsMixin, PreTrainedModel, get_parameter_dtype from ...processing_utils import Unpack +from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer from ...utils import TransformersKwargs, auto_docstring, can_return_tuple from ...utils.deprecation import deprecate_kwarg from ...utils.generic import OutputRecorder, check_model_inputs diff --git a/src/transformers/models/markuplm/modeling_markuplm.py b/src/transformers/models/markuplm/modeling_markuplm.py index 0dd845ecff3c..aeb817be7060 100755 --- a/src/transformers/models/markuplm/modeling_markuplm.py +++ b/src/transformers/models/markuplm/modeling_markuplm.py @@ -32,13 +32,8 @@ SequenceClassifierOutput, TokenClassifierOutput, ) -from ...modeling_utils import ( - ALL_ATTENTION_FUNCTIONS, - PreTrainedModel, - apply_chunking_to_forward, - find_pruneable_heads_and_indices, - prune_linear_layer, -) +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import auto_docstring, can_return_tuple, logging from .configuration_markuplm import MarkupLMConfig From f7a259bd36e541464fc2c2b04359c2560290ec8b Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Fri, 29 Aug 2025 12:36:44 +0200 Subject: [PATCH 04/11] oups --- src/transformers/models/deprecated/mctct/modeling_mctct.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/deprecated/mctct/modeling_mctct.py b/src/transformers/models/deprecated/mctct/modeling_mctct.py index 6b4d73045159..06b1efc0d9e8 100755 --- a/src/transformers/models/deprecated/mctct/modeling_mctct.py +++ b/src/transformers/models/deprecated/mctct/modeling_mctct.py @@ -29,8 +29,8 @@ from ....modeling_layers import GradientCheckpointingLayer from ....modeling_outputs import BaseModelOutput, CausalLMOutput from ....utils import logging -from ...modeling_utils import PreTrainedModel -from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer +from ....modeling_utils import PreTrainedModel +from ....pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from .configuration_mctct import MCTCTConfig From ceab19d29a37f26c259835a259fef1a00b8b5d21 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Fri, 29 Aug 2025 12:43:00 +0200 Subject: [PATCH 05/11] more imports --- src/transformers/models/deprecated/mctct/modeling_mctct.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/deprecated/mctct/modeling_mctct.py b/src/transformers/models/deprecated/mctct/modeling_mctct.py index 06b1efc0d9e8..3c9d259e8215 100755 --- a/src/transformers/models/deprecated/mctct/modeling_mctct.py +++ b/src/transformers/models/deprecated/mctct/modeling_mctct.py @@ -28,9 +28,9 @@ from ....modeling_attn_mask_utils import _prepare_4d_attention_mask from ....modeling_layers import GradientCheckpointingLayer from ....modeling_outputs import BaseModelOutput, CausalLMOutput -from ....utils import logging from ....modeling_utils import PreTrainedModel from ....pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer +from ....utils import logging from .configuration_mctct import MCTCTConfig From bd97e7131c349254fab53a7385fa6a81aae98cec Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Fri, 29 Aug 2025 12:45:17 +0200 Subject: [PATCH 06/11] more imports --- src/transformers/models/cvt/modeling_cvt.py | 3 ++- src/transformers/models/flava/modeling_flava.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/cvt/modeling_cvt.py b/src/transformers/models/cvt/modeling_cvt.py index e838ffb3cd41..85e2bde325e2 100644 --- a/src/transformers/models/cvt/modeling_cvt.py +++ b/src/transformers/models/cvt/modeling_cvt.py @@ -24,7 +24,8 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...modeling_outputs import ImageClassifierOutputWithNoAttention, ModelOutput -from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer +from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer from ...utils import auto_docstring, logging from .configuration_cvt import CvtConfig diff --git a/src/transformers/models/flava/modeling_flava.py b/src/transformers/models/flava/modeling_flava.py index 63eadf41c380..cafd6e589adf 100644 --- a/src/transformers/models/flava/modeling_flava.py +++ b/src/transformers/models/flava/modeling_flava.py @@ -27,7 +27,8 @@ from ...activations import ACT2FN from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling -from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer +from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer from ...utils import ModelOutput, auto_docstring, logging, torch_int from .configuration_flava import ( FlavaConfig, From 6b9d3fae680be848de88410a158493f06f960675 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Fri, 29 Aug 2025 13:11:08 +0200 Subject: [PATCH 07/11] more --- src/transformers/models/bridgetower/modeling_bridgetower.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/bridgetower/modeling_bridgetower.py b/src/transformers/models/bridgetower/modeling_bridgetower.py index 385f6b224825..ba29c6fe324d 100644 --- a/src/transformers/models/bridgetower/modeling_bridgetower.py +++ b/src/transformers/models/bridgetower/modeling_bridgetower.py @@ -34,8 +34,8 @@ ModelOutput, SequenceClassifierOutput, ) -from ...modeling_utils import PreTrainedModel, apply_chunking_to_forward -from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer +from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import auto_docstring, logging, torch_int from ...utils.deprecation import deprecate_kwarg from .configuration_bridgetower import BridgeTowerConfig, BridgeTowerTextConfig, BridgeTowerVisionConfig From ff077bec70bf2b432852bdd2a5bf0c0c3ca52a66 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Fri, 29 Aug 2025 13:49:27 +0200 Subject: [PATCH 08/11] move it to integrations --- src/transformers/integrations/hub_kernels.py | 74 +++++++++++++++++--- src/transformers/modeling_utils.py | 58 +-------------- 2 files changed, 65 insertions(+), 67 deletions(-) diff --git a/src/transformers/integrations/hub_kernels.py b/src/transformers/integrations/hub_kernels.py index ad5e08d8da4d..68fd35a7bf17 100644 --- a/src/transformers/integrations/hub_kernels.py +++ b/src/transformers/integrations/hub_kernels.py @@ -11,21 +11,26 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Union +import re +from functools import partial +from typing import Optional, Union +from ..modeling_flash_attention_utils import lazy_import_flash_attention +from ..utils import is_kernels_available +from .flash_attention import flash_attention_forward -try: + +if is_kernels_available(): from kernels import ( Device, LayerRepository, Mode, + get_kernel, register_kernel_mapping, replace_kernel_forward_from_hub, use_kernel_forward_from_hub, ) - _hub_kernels_available = True - _KERNEL_MAPPING: dict[str, dict[Union[Device, str], LayerRepository]] = { "MultiScaleDeformableAttention": { "cuda": LayerRepository( @@ -82,8 +87,7 @@ register_kernel_mapping(_KERNEL_MAPPING) - -except ImportError: +else: # Stub to make decorators int transformers work when `kernels` # is not installed. def use_kernel_forward_from_hub(*args, **kwargs): @@ -104,16 +108,66 @@ def replace_kernel_forward_from_hub(*args, **kwargs): def register_kernel_mapping(*args, **kwargs): raise RuntimeError("register_kernel_mapping requires `kernels` to be installed. Run `pip install kernels`.") - _hub_kernels_available = False + +def is_kernel(attn_implementation: Optional[str]) -> bool: + """Check whether `attn_implementation` matches a kernel pattern from the hub.""" + return ( + attn_implementation is not None + and re.search(r"^[^/:]+/[^/:]+(?:@[^/:]+)?(?::[^/:]+)?$", attn_implementation) is not None + ) -def is_hub_kernels_available(): - return _hub_kernels_available +def load_and_register_kernel(attn_implementation: str) -> None: + """Load and register the kernel associated to `attn_implementation`.""" + if not is_kernel(attn_implementation): + return + if not is_kernels_available(): + raise ImportError("`kernels` is not installed. Please install it with `pip install kernels`.") + + # Need to be imported here as otherwise we have a circular import in `modeling_utils` + from ..masking_utils import ALL_MASK_ATTENTION_FUNCTIONS + from ..modeling_utils import ALL_ATTENTION_FUNCTIONS + + attention_wrapper = None + # FIXME: @ArthurZucker this is dirty, did not want to do a lof of extra work + actual_attn_name = attn_implementation + if "|" in attn_implementation: + attention_wrapper, actual_attn_name = attn_implementation.split("|") + # `transformers` has wrapper for sdpa, paged, flash, flex etc. + attention_wrapper = ALL_ATTENTION_FUNCTIONS.get(attention_wrapper) + # Extract repo_id and kernel_name from the string + if ":" in actual_attn_name: + repo_id, kernel_name = actual_attn_name.split(":") + kernel_name = kernel_name.strip() + else: + repo_id = actual_attn_name + kernel_name = None + repo_id = repo_id.strip() + # extract the rev after the @ if it exists + repo_id, _, rev = repo_id.partition("@") + repo_id = repo_id.strip() + rev = rev.strip() if rev else None + + # Load the kernel from hub + try: + kernel = get_kernel(repo_id, revision=rev) + except Exception as e: + raise ValueError(f"An error occured while trying to load from '{repo_id}': {e}.") + # correctly wrap the kernel + if hasattr(kernel, "flash_attn_varlen_func"): + if attention_wrapper is None: + attention_wrapper = flash_attention_forward + kernel_function = partial(attention_wrapper, implementation=kernel) + lazy_import_flash_attention(kernel) + elif kernel_name is not None: + kernel_function = getattr(kernel, kernel_name) + # Register the kernel as a valid attention + ALL_ATTENTION_FUNCTIONS.register(attn_implementation, kernel_function) + ALL_MASK_ATTENTION_FUNCTIONS.register(attn_implementation, ALL_MASK_ATTENTION_FUNCTIONS["flash_attention_2"]) __all__ = [ "LayerRepository", - "is_hub_kernels_available", "use_kernel_forward_from_hub", "register_kernel_mapping", "replace_kernel_forward_from_hub", diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 94b69f36b229..973ee405cb3a 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -55,6 +55,7 @@ from .integrations.flash_attention import flash_attention_forward from .integrations.flash_paged import paged_attention_forward from .integrations.flex_attention import flex_attention_forward +from .integrations.hub_kernels import is_kernel, load_and_register_kernel from .integrations.sdpa_attention import sdpa_attention_forward from .integrations.sdpa_paged import sdpa_attention_paged_forward from .integrations.tensor_parallel import ( @@ -67,7 +68,6 @@ verify_tp_plan, ) from .loss.loss_utils import LOSS_MAPPING -from .masking_utils import ALL_MASK_ATTENTION_FUNCTIONS from .modeling_flash_attention_utils import lazy_import_flash_attention from .pytorch_utils import id_tensor_storage from .quantizers import HfQuantizer @@ -153,9 +153,6 @@ if is_peft_available(): from .utils import find_adapter_config_file -if is_kernels_available(): - from kernels import get_kernel - _torch_distributed_available = torch.distributed.is_available() _is_dtensor_available = _torch_distributed_available and is_torch_greater_or_equal("2.5") if _is_dtensor_available: @@ -6247,59 +6244,6 @@ def get_disk_only_shard_files(device_map, weight_map): return [fname for fname, devices in files_content.items() if set(devices) == {"disk"}] -def is_kernel(attn_implementation: Optional[str]) -> bool: - """Check whether `attn_implementation` matches a kernel pattern from the hub.""" - return ( - attn_implementation is not None - and re.search(r"^[^/:]+/[^/:]+(?:@[^/:]+)?(?::[^/:]+)?$", attn_implementation) is not None - ) - - -def load_and_register_kernel(attn_implementation: str) -> None: - """Load and register the kernel associated to `attn_implementation`.""" - if not is_kernel(attn_implementation): - return - if not is_kernels_available(): - raise ImportError("`kernels` is not installed. Please install it with `pip install kernels`.") - - attention_wrapper = None - # FIXME: @ArthurZucker this is dirty, did not want to do a lof of extra work - actual_attn_name = attn_implementation - if "|" in attn_implementation: - attention_wrapper, actual_attn_name = attn_implementation.split("|") - # `transformers` has wrapper for sdpa, paged, flash, flex etc. - attention_wrapper = ALL_ATTENTION_FUNCTIONS.get(attention_wrapper) - # Extract repo_id and kernel_name from the string - if ":" in actual_attn_name: - repo_id, kernel_name = actual_attn_name.split(":") - kernel_name = kernel_name.strip() - else: - repo_id = actual_attn_name - kernel_name = None - repo_id = repo_id.strip() - # extract the rev after the @ if it exists - repo_id, _, rev = repo_id.partition("@") - repo_id = repo_id.strip() - rev = rev.strip() if rev else None - - # Load the kernel from hub - try: - kernel = get_kernel(repo_id, revision=rev) - except Exception as e: - raise ValueError(f"An error occured while trying to load from '{repo_id}': {e}.") - # correctly wrap the kernel - if hasattr(kernel, "flash_attn_varlen_func"): - if attention_wrapper is None: - attention_wrapper = flash_attention_forward - kernel_function = partial(attention_wrapper, implementation=kernel) - lazy_import_flash_attention(kernel) - elif kernel_name is not None: - kernel_function = getattr(kernel, kernel_name) - # Register the kernel as a valid attention - ALL_ATTENTION_FUNCTIONS.register(attn_implementation, kernel_function) - ALL_MASK_ATTENTION_FUNCTIONS.register(attn_implementation, ALL_MASK_ATTENTION_FUNCTIONS["flash_attention_2"]) - - class AttentionInterface(GeneralInterface): """ Dict-like object keeping track of allowed attention functions. You can easily add a new attention function From ba4b4aca5d3f6a59c8a74dc9740f844db242bff9 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Fri, 29 Aug 2025 13:59:24 +0200 Subject: [PATCH 09/11] fix --- src/transformers/integrations/hub_kernels.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/transformers/integrations/hub_kernels.py b/src/transformers/integrations/hub_kernels.py index 68fd35a7bf17..3608bedf9df4 100644 --- a/src/transformers/integrations/hub_kernels.py +++ b/src/transformers/integrations/hub_kernels.py @@ -16,11 +16,10 @@ from typing import Optional, Union from ..modeling_flash_attention_utils import lazy_import_flash_attention -from ..utils import is_kernels_available from .flash_attention import flash_attention_forward -if is_kernels_available(): +try: from kernels import ( Device, LayerRepository, @@ -30,6 +29,7 @@ replace_kernel_forward_from_hub, use_kernel_forward_from_hub, ) + _kernels_available = True _KERNEL_MAPPING: dict[str, dict[Union[Device, str], LayerRepository]] = { "MultiScaleDeformableAttention": { @@ -87,7 +87,8 @@ register_kernel_mapping(_KERNEL_MAPPING) -else: +except ImportError: + _kernels_available = False # Stub to make decorators int transformers work when `kernels` # is not installed. def use_kernel_forward_from_hub(*args, **kwargs): @@ -121,7 +122,7 @@ def load_and_register_kernel(attn_implementation: str) -> None: """Load and register the kernel associated to `attn_implementation`.""" if not is_kernel(attn_implementation): return - if not is_kernels_available(): + if not _kernels_available: raise ImportError("`kernels` is not installed. Please install it with `pip install kernels`.") # Need to be imported here as otherwise we have a circular import in `modeling_utils` From 0b3c307e3017dd3acd29e5bf2d67cfede3a17b13 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Fri, 29 Aug 2025 13:59:50 +0200 Subject: [PATCH 10/11] style --- src/transformers/integrations/hub_kernels.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/transformers/integrations/hub_kernels.py b/src/transformers/integrations/hub_kernels.py index 3608bedf9df4..78432215d9e1 100644 --- a/src/transformers/integrations/hub_kernels.py +++ b/src/transformers/integrations/hub_kernels.py @@ -29,6 +29,7 @@ replace_kernel_forward_from_hub, use_kernel_forward_from_hub, ) + _kernels_available = True _KERNEL_MAPPING: dict[str, dict[Union[Device, str], LayerRepository]] = { @@ -89,6 +90,7 @@ except ImportError: _kernels_available = False + # Stub to make decorators int transformers work when `kernels` # is not installed. def use_kernel_forward_from_hub(*args, **kwargs): From 59e3a5288e82e3133bfb2bff8140407dafdf36cc Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Fri, 29 Aug 2025 14:05:12 +0200 Subject: [PATCH 11/11] fix doc --- src/transformers/modeling_flash_attention_utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/transformers/modeling_flash_attention_utils.py b/src/transformers/modeling_flash_attention_utils.py index 47aaedc99fba..37554773a85f 100644 --- a/src/transformers/modeling_flash_attention_utils.py +++ b/src/transformers/modeling_flash_attention_utils.py @@ -126,10 +126,10 @@ def _lazy_define_process_function(flash_function): def lazy_import_flash_attention(implementation: Optional[str]): """ - Lazy loading flash attention and returning the respective functions + flags back + Lazily import flash attention and return the respective functions + flags. - NOTE: For fullgraph, this needs to be called before compile while no fullgraph can - can work without preloading. See `_check_and_adjust_attn_implementation` in `modeling_utils`. + NOTE: For fullgraph, this needs to be called before compile, while no fullgraph can + work without preloading. See `load_and_register_kernel` in `integrations.hub_kernels`. """ global _flash_fn, _flash_varlen_fn, _pad_fn, _unpad_fn if any(k is None for k in [_flash_fn, _flash_varlen_fn, _pad_fn, _unpad_fn]):