diff --git a/src/transformers/integrations/hub_kernels.py b/src/transformers/integrations/hub_kernels.py index ad5e08d8da4d..78432215d9e1 100644 --- a/src/transformers/integrations/hub_kernels.py +++ b/src/transformers/integrations/hub_kernels.py @@ -11,7 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Union +import re +from functools import partial +from typing import Optional, Union + +from ..modeling_flash_attention_utils import lazy_import_flash_attention +from .flash_attention import flash_attention_forward try: @@ -19,12 +24,13 @@ Device, LayerRepository, Mode, + get_kernel, register_kernel_mapping, replace_kernel_forward_from_hub, use_kernel_forward_from_hub, ) - _hub_kernels_available = True + _kernels_available = True _KERNEL_MAPPING: dict[str, dict[Union[Device, str], LayerRepository]] = { "MultiScaleDeformableAttention": { @@ -82,8 +88,9 @@ register_kernel_mapping(_KERNEL_MAPPING) - except ImportError: + _kernels_available = False + # Stub to make decorators int transformers work when `kernels` # is not installed. def use_kernel_forward_from_hub(*args, **kwargs): @@ -104,16 +111,66 @@ def replace_kernel_forward_from_hub(*args, **kwargs): def register_kernel_mapping(*args, **kwargs): raise RuntimeError("register_kernel_mapping requires `kernels` to be installed. Run `pip install kernels`.") - _hub_kernels_available = False + +def is_kernel(attn_implementation: Optional[str]) -> bool: + """Check whether `attn_implementation` matches a kernel pattern from the hub.""" + return ( + attn_implementation is not None + and re.search(r"^[^/:]+/[^/:]+(?:@[^/:]+)?(?::[^/:]+)?$", attn_implementation) is not None + ) -def is_hub_kernels_available(): - return _hub_kernels_available +def load_and_register_kernel(attn_implementation: str) -> None: + """Load and register the kernel associated to `attn_implementation`.""" + if not is_kernel(attn_implementation): + return + if not _kernels_available: + raise ImportError("`kernels` is not installed. Please install it with `pip install kernels`.") + + # Need to be imported here as otherwise we have a circular import in `modeling_utils` + from ..masking_utils import ALL_MASK_ATTENTION_FUNCTIONS + from ..modeling_utils import ALL_ATTENTION_FUNCTIONS + + attention_wrapper = None + # FIXME: @ArthurZucker this is dirty, did not want to do a lof of extra work + actual_attn_name = attn_implementation + if "|" in attn_implementation: + attention_wrapper, actual_attn_name = attn_implementation.split("|") + # `transformers` has wrapper for sdpa, paged, flash, flex etc. + attention_wrapper = ALL_ATTENTION_FUNCTIONS.get(attention_wrapper) + # Extract repo_id and kernel_name from the string + if ":" in actual_attn_name: + repo_id, kernel_name = actual_attn_name.split(":") + kernel_name = kernel_name.strip() + else: + repo_id = actual_attn_name + kernel_name = None + repo_id = repo_id.strip() + # extract the rev after the @ if it exists + repo_id, _, rev = repo_id.partition("@") + repo_id = repo_id.strip() + rev = rev.strip() if rev else None + + # Load the kernel from hub + try: + kernel = get_kernel(repo_id, revision=rev) + except Exception as e: + raise ValueError(f"An error occured while trying to load from '{repo_id}': {e}.") + # correctly wrap the kernel + if hasattr(kernel, "flash_attn_varlen_func"): + if attention_wrapper is None: + attention_wrapper = flash_attention_forward + kernel_function = partial(attention_wrapper, implementation=kernel) + lazy_import_flash_attention(kernel) + elif kernel_name is not None: + kernel_function = getattr(kernel, kernel_name) + # Register the kernel as a valid attention + ALL_ATTENTION_FUNCTIONS.register(attn_implementation, kernel_function) + ALL_MASK_ATTENTION_FUNCTIONS.register(attn_implementation, ALL_MASK_ATTENTION_FUNCTIONS["flash_attention_2"]) __all__ = [ "LayerRepository", - "is_hub_kernels_available", "use_kernel_forward_from_hub", "register_kernel_mapping", "replace_kernel_forward_from_hub", diff --git a/src/transformers/modeling_flash_attention_utils.py b/src/transformers/modeling_flash_attention_utils.py index 47aaedc99fba..37554773a85f 100644 --- a/src/transformers/modeling_flash_attention_utils.py +++ b/src/transformers/modeling_flash_attention_utils.py @@ -126,10 +126,10 @@ def _lazy_define_process_function(flash_function): def lazy_import_flash_attention(implementation: Optional[str]): """ - Lazy loading flash attention and returning the respective functions + flags back + Lazily import flash attention and return the respective functions + flags. - NOTE: For fullgraph, this needs to be called before compile while no fullgraph can - can work without preloading. See `_check_and_adjust_attn_implementation` in `modeling_utils`. + NOTE: For fullgraph, this needs to be called before compile, while no fullgraph can + work without preloading. See `load_and_register_kernel` in `integrations.hub_kernels`. """ global _flash_fn, _flash_varlen_fn, _pad_fn, _unpad_fn if any(k is None for k in [_flash_fn, _flash_varlen_fn, _pad_fn, _unpad_fn]): diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 4a96e9e537e2..973ee405cb3a 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -44,12 +44,6 @@ from torch.distributions import constraints from torch.utils.checkpoint import checkpoint -from transformers.utils import is_torchao_available - - -if is_torchao_available(): - from torchao.quantization import Int4WeightOnlyConfig - from .configuration_utils import PretrainedConfig from .distributed import DistributedConfig from .dynamic_module_utils import custom_object_save @@ -61,6 +55,7 @@ from .integrations.flash_attention import flash_attention_forward from .integrations.flash_paged import paged_attention_forward from .integrations.flex_attention import flex_attention_forward +from .integrations.hub_kernels import is_kernel, load_and_register_kernel from .integrations.sdpa_attention import sdpa_attention_forward from .integrations.sdpa_paged import sdpa_attention_paged_forward from .integrations.tensor_parallel import ( @@ -73,17 +68,8 @@ verify_tp_plan, ) from .loss.loss_utils import LOSS_MAPPING -from .masking_utils import ALL_MASK_ATTENTION_FUNCTIONS from .modeling_flash_attention_utils import lazy_import_flash_attention -from .pytorch_utils import ( # noqa: F401 - Conv1D, - apply_chunking_to_forward, - find_pruneable_heads_and_indices, - id_tensor_storage, - prune_conv1d_layer, - prune_layer, - prune_linear_layer, -) +from .pytorch_utils import id_tensor_storage from .quantizers import HfQuantizer from .quantizers.auto import get_hf_quantizer from .quantizers.quantizers_utils import get_module_from_name @@ -124,6 +110,7 @@ is_torch_npu_available, is_torch_xla_available, is_torch_xpu_available, + is_torchao_available, logging, ) from .utils.generic import _CAN_RECORD_REGISTRY, GeneralInterface, OutputRecorder @@ -138,9 +125,8 @@ from .utils.quantization_config import BitsAndBytesConfig, QuantizationMethod -XLA_USE_BF16 = os.environ.get("XLA_USE_BF16", "0").upper() -XLA_DOWNCAST_BF16 = os.environ.get("XLA_DOWNCAST_BF16", "0").upper() - +if is_torchao_available(): + from torchao.quantization import Int4WeightOnlyConfig if is_accelerate_available(): from accelerate import dispatch_model, infer_auto_device_map @@ -164,32 +150,14 @@ from safetensors.torch import load_file as safe_load_file from safetensors.torch import save_file as safe_save_file +if is_peft_available(): + from .utils import find_adapter_config_file -if is_kernels_available(): - from kernels import get_kernel - - -logger = logging.get_logger(__name__) - - -_init_weights = True -_is_quantized = False -_is_ds_init_called = False _torch_distributed_available = torch.distributed.is_available() - _is_dtensor_available = _torch_distributed_available and is_torch_greater_or_equal("2.5") if _is_dtensor_available: from torch.distributed.tensor import DTensor - -def is_local_dist_rank_0(): - return ( - torch.distributed.is_available() - and torch.distributed.is_initialized() - and int(os.environ.get("LOCAL_RANK", "-1")) == 0 - ) - - if is_sagemaker_mp_enabled(): import smdistributed.modelparallel.torch as smp from smdistributed.modelparallel import __version__ as SMP_VERSION @@ -198,11 +166,24 @@ def is_local_dist_rank_0(): else: IS_SAGEMAKER_MP_POST_1_10 = False -if is_peft_available(): - from .utils import find_adapter_config_file +logger = logging.get_logger(__name__) +XLA_USE_BF16 = os.environ.get("XLA_USE_BF16", "0").upper() +XLA_DOWNCAST_BF16 = os.environ.get("XLA_DOWNCAST_BF16", "0").upper() SpecificPreTrainedModelType = TypeVar("SpecificPreTrainedModelType", bound="PreTrainedModel") +_init_weights = True +_is_quantized = False +_is_ds_init_called = False + + +def is_local_dist_rank_0(): + return ( + torch.distributed.is_available() + and torch.distributed.is_initialized() + and int(os.environ.get("LOCAL_RANK", "-1")) == 0 + ) + TORCH_INIT_FUNCTIONS = { "uniform_": nn.init.uniform_, @@ -2801,44 +2782,10 @@ def _check_and_adjust_attn_implementation( and is_kernels_available() ): applicable_attn_implementation = "kernels-community/flash-attn" - if applicable_attn_implementation is not None and re.match( - r"^[^/:]+/[^/:]+(?:@[^/:]+)?(?::[^/:]+)?$", applicable_attn_implementation - ): - if not is_kernels_available(): - raise ValueError("kernels is not installed. Please install it with `pip install kernels`.") - attention_wrapper = None - # FIXME: @ArthurZucker this is dirty, did not want to do a lof of extra work - actual_attn_name = applicable_attn_implementation - if "|" in applicable_attn_implementation: - attention_wrapper, actual_attn_name = applicable_attn_implementation.split("|") - # `transformers` has wrapper for sdpa, paged, flash, flex etc. - attention_wrapper = ALL_ATTENTION_FUNCTIONS.get(attention_wrapper) - # Extract repo_id and kernel_name from the string - if ":" in actual_attn_name: - repo_id, kernel_name = actual_attn_name.split(":") - kernel_name = kernel_name.strip() - else: - repo_id = actual_attn_name - kernel_name = None - repo_id = repo_id.strip() - # extract the rev after the @ if it exists - repo_id, _, rev = repo_id.partition("@") - repo_id = repo_id.strip() - rev = rev.strip() if rev else None + if is_kernel(applicable_attn_implementation): try: - kernel = get_kernel(repo_id, revision=rev) - if hasattr(kernel, "flash_attn_varlen_func"): - if attention_wrapper is None: - attention_wrapper = flash_attention_forward - kernel_function = partial(attention_wrapper, implementation=kernel) - lazy_import_flash_attention(kernel) - elif kernel_name is not None: - kernel_function = getattr(kernel, kernel_name) - ALL_ATTENTION_FUNCTIONS.register(applicable_attn_implementation, kernel_function) - ALL_MASK_ATTENTION_FUNCTIONS.register( - applicable_attn_implementation, ALL_MASK_ATTENTION_FUNCTIONS["flash_attention_2"] - ) - # log that we used kernel fallback + load_and_register_kernel(applicable_attn_implementation) + # log that we used kernel fallback if successful if attn_implementation == "flash_attention_2": logger.warning_once( "You do not have `flash_attn` installed, using `kernels-community/flash-attn` from the `kernels` " @@ -2848,8 +2795,8 @@ def _check_and_adjust_attn_implementation( if attn_implementation == "flash_attention_2": self._flash_attn_2_can_dispatch() # will fail as fa2 is not available but raise the proper exception logger.warning_once( - f"Could not find a kernel repository '{repo_id}' compatible with your device in the hub: {e}. Using " - "default attention implementation instead (sdpa if available, eager otherwise)." + f"Could not find a kernel matching `{applicable_attn_implementation}` compatible with your device in the " + f"hub:\n{e}.\nUsing default attention implementation instead (sdpa if available, eager otherwise)." ) try: self._sdpa_can_dispatch(is_init_check) diff --git a/src/transformers/models/blip/modeling_blip_text.py b/src/transformers/models/blip/modeling_blip_text.py index 3eff5ac785e2..2eaf655d01be 100644 --- a/src/transformers/models/blip/modeling_blip_text.py +++ b/src/transformers/models/blip/modeling_blip_text.py @@ -31,12 +31,8 @@ BaseModelOutputWithPoolingAndCrossAttentions, CausalLMOutputWithCrossAttentions, ) -from ...modeling_utils import ( - PreTrainedModel, - apply_chunking_to_forward, - find_pruneable_heads_and_indices, - prune_linear_layer, -) +from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import logging from ...utils.deprecation import deprecate_kwarg from .configuration_blip import BlipTextConfig diff --git a/src/transformers/models/bridgetower/modeling_bridgetower.py b/src/transformers/models/bridgetower/modeling_bridgetower.py index 385f6b224825..ba29c6fe324d 100644 --- a/src/transformers/models/bridgetower/modeling_bridgetower.py +++ b/src/transformers/models/bridgetower/modeling_bridgetower.py @@ -34,8 +34,8 @@ ModelOutput, SequenceClassifierOutput, ) -from ...modeling_utils import PreTrainedModel, apply_chunking_to_forward -from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer +from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import auto_docstring, logging, torch_int from ...utils.deprecation import deprecate_kwarg from .configuration_bridgetower import BridgeTowerConfig, BridgeTowerTextConfig, BridgeTowerVisionConfig diff --git a/src/transformers/models/cvt/modeling_cvt.py b/src/transformers/models/cvt/modeling_cvt.py index e838ffb3cd41..85e2bde325e2 100644 --- a/src/transformers/models/cvt/modeling_cvt.py +++ b/src/transformers/models/cvt/modeling_cvt.py @@ -24,7 +24,8 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...modeling_outputs import ImageClassifierOutputWithNoAttention, ModelOutput -from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer +from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer from ...utils import auto_docstring, logging from .configuration_cvt import CvtConfig diff --git a/src/transformers/models/deprecated/mctct/modeling_mctct.py b/src/transformers/models/deprecated/mctct/modeling_mctct.py index a6686c1eb2a1..3c9d259e8215 100755 --- a/src/transformers/models/deprecated/mctct/modeling_mctct.py +++ b/src/transformers/models/deprecated/mctct/modeling_mctct.py @@ -28,12 +28,8 @@ from ....modeling_attn_mask_utils import _prepare_4d_attention_mask from ....modeling_layers import GradientCheckpointingLayer from ....modeling_outputs import BaseModelOutput, CausalLMOutput -from ....modeling_utils import ( - PreTrainedModel, - apply_chunking_to_forward, - find_pruneable_heads_and_indices, - prune_linear_layer, -) +from ....modeling_utils import PreTrainedModel +from ....pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from ....utils import logging from .configuration_mctct import MCTCTConfig diff --git a/src/transformers/models/esm/modeling_esm.py b/src/transformers/models/esm/modeling_esm.py index aaea1614bb75..5db366aa6197 100755 --- a/src/transformers/models/esm/modeling_esm.py +++ b/src/transformers/models/esm/modeling_esm.py @@ -31,13 +31,9 @@ SequenceClassifierOutput, TokenClassifierOutput, ) -from ...modeling_utils import ( - ALL_ATTENTION_FUNCTIONS, - PreTrainedModel, - find_pruneable_heads_and_indices, - prune_linear_layer, -) +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack +from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging from ...utils.generic import OutputRecorder, check_model_inputs from .configuration_esm import EsmConfig diff --git a/src/transformers/models/evolla/modeling_evolla.py b/src/transformers/models/evolla/modeling_evolla.py index e74fadc049b1..f0677d5f600d 100644 --- a/src/transformers/models/evolla/modeling_evolla.py +++ b/src/transformers/models/evolla/modeling_evolla.py @@ -41,15 +41,9 @@ ModelOutput, ) from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update -from ...modeling_utils import ( - ALL_ATTENTION_FUNCTIONS, - ModuleUtilsMixin, - PreTrainedModel, - find_pruneable_heads_and_indices, - get_parameter_dtype, - prune_linear_layer, -) +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, ModuleUtilsMixin, PreTrainedModel, get_parameter_dtype from ...processing_utils import Unpack +from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer from ...utils import TransformersKwargs, auto_docstring, can_return_tuple from ...utils.deprecation import deprecate_kwarg from ...utils.generic import OutputRecorder, check_model_inputs diff --git a/src/transformers/models/flava/modeling_flava.py b/src/transformers/models/flava/modeling_flava.py index 63eadf41c380..cafd6e589adf 100644 --- a/src/transformers/models/flava/modeling_flava.py +++ b/src/transformers/models/flava/modeling_flava.py @@ -27,7 +27,8 @@ from ...activations import ACT2FN from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling -from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer +from ...modeling_utils import PreTrainedModel +from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer from ...utils import ModelOutput, auto_docstring, logging, torch_int from .configuration_flava import ( FlavaConfig, diff --git a/src/transformers/models/markuplm/modeling_markuplm.py b/src/transformers/models/markuplm/modeling_markuplm.py index 0dd845ecff3c..aeb817be7060 100755 --- a/src/transformers/models/markuplm/modeling_markuplm.py +++ b/src/transformers/models/markuplm/modeling_markuplm.py @@ -32,13 +32,8 @@ SequenceClassifierOutput, TokenClassifierOutput, ) -from ...modeling_utils import ( - ALL_ATTENTION_FUNCTIONS, - PreTrainedModel, - apply_chunking_to_forward, - find_pruneable_heads_and_indices, - prune_linear_layer, -) +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import auto_docstring, can_return_tuple, logging from .configuration_markuplm import MarkupLMConfig