71 changes: 64 additions & 7 deletions src/transformers/integrations/hub_kernels.py
@@ -11,20 +11,26 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Union
import re
from functools import partial
from typing import Optional, Union

from ..modeling_flash_attention_utils import lazy_import_flash_attention
from .flash_attention import flash_attention_forward


try:
from kernels import (
Device,
LayerRepository,
Mode,
get_kernel,
register_kernel_mapping,
replace_kernel_forward_from_hub,
use_kernel_forward_from_hub,
)

_hub_kernels_available = True
_kernels_available = True

_KERNEL_MAPPING: dict[str, dict[Union[Device, str], LayerRepository]] = {
"MultiScaleDeformableAttention": {
@@ -82,8 +88,9 @@

register_kernel_mapping(_KERNEL_MAPPING)


except ImportError:
_kernels_available = False

# Stub to make decorators in transformers work when `kernels`
# is not installed.
def use_kernel_forward_from_hub(*args, **kwargs):
@@ -104,16 +111,66 @@ def replace_kernel_forward_from_hub(*args, **kwargs):
def register_kernel_mapping(*args, **kwargs):
raise RuntimeError("register_kernel_mapping requires `kernels` to be installed. Run `pip install kernels`.")

_hub_kernels_available = False

def is_kernel(attn_implementation: Optional[str]) -> bool:
"""Check whether `attn_implementation` matches a kernel pattern from the hub."""
return (
attn_implementation is not None
and re.search(r"^[^/:]+/[^/:]+(?:@[^/:]+)?(?::[^/:]+)?$", attn_implementation) is not None
)


def is_hub_kernels_available():
return _hub_kernels_available
def load_and_register_kernel(attn_implementation: str) -> None:
"""Load and register the kernel associated to `attn_implementation`."""
if not is_kernel(attn_implementation):
return
if not _kernels_available:
raise ImportError("`kernels` is not installed. Please install it with `pip install kernels`.")

# Need to be imported here as otherwise we have a circular import in `modeling_utils`
from ..masking_utils import ALL_MASK_ATTENTION_FUNCTIONS
from ..modeling_utils import ALL_ATTENTION_FUNCTIONS

attention_wrapper = None
# FIXME: @ArthurZucker this is dirty, did not want to do a lot of extra work
actual_attn_name = attn_implementation
if "|" in attn_implementation:
attention_wrapper, actual_attn_name = attn_implementation.split("|")
# `transformers` has wrappers for sdpa, paged, flash, flex, etc.
attention_wrapper = ALL_ATTENTION_FUNCTIONS.get(attention_wrapper)
# Extract repo_id and kernel_name from the string
if ":" in actual_attn_name:
repo_id, kernel_name = actual_attn_name.split(":")
kernel_name = kernel_name.strip()
else:
repo_id = actual_attn_name
kernel_name = None
repo_id = repo_id.strip()
# extract the rev after the @ if it exists
repo_id, _, rev = repo_id.partition("@")
repo_id = repo_id.strip()
rev = rev.strip() if rev else None

# Load the kernel from hub
try:
kernel = get_kernel(repo_id, revision=rev)
except Exception as e:
raise ValueError(f"An error occured while trying to load from '{repo_id}': {e}.")
# correctly wrap the kernel
if hasattr(kernel, "flash_attn_varlen_func"):
if attention_wrapper is None:
attention_wrapper = flash_attention_forward
kernel_function = partial(attention_wrapper, implementation=kernel)
lazy_import_flash_attention(kernel)
elif kernel_name is not None:
kernel_function = getattr(kernel, kernel_name)
# Register the kernel as a valid attention
ALL_ATTENTION_FUNCTIONS.register(attn_implementation, kernel_function)
ALL_MASK_ATTENTION_FUNCTIONS.register(attn_implementation, ALL_MASK_ATTENTION_FUNCTIONS["flash_attention_2"])


__all__ = [
"LayerRepository",
"is_hub_kernels_available",
"use_kernel_forward_from_hub",
"register_kernel_mapping",
"replace_kernel_forward_from_hub",
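Reviewer note: a minimal sketch of the kernel reference strings that the new `is_kernel` / `load_and_register_kernel` pair accepts; the repo names below are hypothetical placeholders, not published hub kernels.

import re

# Same pattern as is_kernel: "<org>/<repo>", optionally with "@<revision>" and/or ":<kernel_name>".
KERNEL_PATTERN = re.compile(r"^[^/:]+/[^/:]+(?:@[^/:]+)?(?::[^/:]+)?$")

candidates = [
    "some-org/some-kernel",                   # repo only
    "some-org/some-kernel@v1.2.0",            # pinned to a revision
    "some-org/some-kernel:custom_attention",  # a specific kernel function in the repo
    "sdpa|some-org/some-kernel",              # prefixed with a transformers attention wrapper
    "sdpa",                                   # built-in name, not a hub kernel
]
for candidate in candidates:
    print(f"{candidate!r:45} -> {KERNEL_PATTERN.search(candidate) is not None}")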
6 changes: 3 additions & 3 deletions src/transformers/modeling_flash_attention_utils.py
@@ -126,10 +126,10 @@ def _lazy_define_process_function(flash_function):

def lazy_import_flash_attention(implementation: Optional[str]):
"""
Lazy loading flash attention and returning the respective functions + flags back
Lazily import flash attention and return the respective functions + flags.

NOTE: For fullgraph, this needs to be called before compile while no fullgraph can
can work without preloading. See `_check_and_adjust_attn_implementation` in `modeling_utils`.
NOTE: For fullgraph, this needs to be called before compile, while no fullgraph can
work without preloading. See `load_and_register_kernel` in `integrations.hub_kernels`.
"""
global _flash_fn, _flash_varlen_fn, _pad_fn, _unpad_fn
if any(k is None for k in [_flash_fn, _flash_varlen_fn, _pad_fn, _unpad_fn]):
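Reviewer note: as the updated docstring says, fullgraph compilation needs the flash functions resolved before compile runs. A minimal sketch of where that now happens (assumes the `kernels` package is installed and the hub repo is reachable):

from transformers.integrations.hub_kernels import load_and_register_kernel

# Resolves the hub kernel, registers it as an attention implementation, and calls
# lazy_import_flash_attention(kernel) so a later fullgraph compile finds the functions.
load_and_register_kernel("kernels-community/flash-attn")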
107 changes: 27 additions & 80 deletions src/transformers/modeling_utils.py
@@ -44,12 +44,6 @@
from torch.distributions import constraints
from torch.utils.checkpoint import checkpoint

from transformers.utils import is_torchao_available


if is_torchao_available():
from torchao.quantization import Int4WeightOnlyConfig

from .configuration_utils import PretrainedConfig
from .distributed import DistributedConfig
from .dynamic_module_utils import custom_object_save
@@ -61,6 +55,7 @@
from .integrations.flash_attention import flash_attention_forward
from .integrations.flash_paged import paged_attention_forward
from .integrations.flex_attention import flex_attention_forward
from .integrations.hub_kernels import is_kernel, load_and_register_kernel
from .integrations.sdpa_attention import sdpa_attention_forward
from .integrations.sdpa_paged import sdpa_attention_paged_forward
from .integrations.tensor_parallel import (
@@ -73,17 +68,8 @@
verify_tp_plan,
)
from .loss.loss_utils import LOSS_MAPPING
from .masking_utils import ALL_MASK_ATTENTION_FUNCTIONS
from .modeling_flash_attention_utils import lazy_import_flash_attention
from .pytorch_utils import ( # noqa: F401
Conv1D,
apply_chunking_to_forward,
find_pruneable_heads_and_indices,
id_tensor_storage,
prune_conv1d_layer,
prune_layer,
prune_linear_layer,
)
from .pytorch_utils import id_tensor_storage
from .quantizers import HfQuantizer
from .quantizers.auto import get_hf_quantizer
from .quantizers.quantizers_utils import get_module_from_name
Expand Down Expand Up @@ -124,6 +110,7 @@
is_torch_npu_available,
is_torch_xla_available,
is_torch_xpu_available,
is_torchao_available,
logging,
)
from .utils.generic import _CAN_RECORD_REGISTRY, GeneralInterface, OutputRecorder
@@ -138,9 +125,8 @@
from .utils.quantization_config import BitsAndBytesConfig, QuantizationMethod


XLA_USE_BF16 = os.environ.get("XLA_USE_BF16", "0").upper()
XLA_DOWNCAST_BF16 = os.environ.get("XLA_DOWNCAST_BF16", "0").upper()

if is_torchao_available():
from torchao.quantization import Int4WeightOnlyConfig

if is_accelerate_available():
from accelerate import dispatch_model, infer_auto_device_map
@@ -164,32 +150,14 @@
from safetensors.torch import load_file as safe_load_file
from safetensors.torch import save_file as safe_save_file

if is_peft_available():
from .utils import find_adapter_config_file

if is_kernels_available():
from kernels import get_kernel


logger = logging.get_logger(__name__)


_init_weights = True
_is_quantized = False
_is_ds_init_called = False
_torch_distributed_available = torch.distributed.is_available()

_is_dtensor_available = _torch_distributed_available and is_torch_greater_or_equal("2.5")
if _is_dtensor_available:
from torch.distributed.tensor import DTensor


def is_local_dist_rank_0():
return (
torch.distributed.is_available()
and torch.distributed.is_initialized()
and int(os.environ.get("LOCAL_RANK", "-1")) == 0
)


if is_sagemaker_mp_enabled():
import smdistributed.modelparallel.torch as smp
from smdistributed.modelparallel import __version__ as SMP_VERSION
@@ -198,11 +166,24 @@ def is_local_dist_rank_0():
else:
IS_SAGEMAKER_MP_POST_1_10 = False

if is_peft_available():
from .utils import find_adapter_config_file

logger = logging.get_logger(__name__)

XLA_USE_BF16 = os.environ.get("XLA_USE_BF16", "0").upper()
XLA_DOWNCAST_BF16 = os.environ.get("XLA_DOWNCAST_BF16", "0").upper()
SpecificPreTrainedModelType = TypeVar("SpecificPreTrainedModelType", bound="PreTrainedModel")
_init_weights = True
_is_quantized = False
_is_ds_init_called = False


def is_local_dist_rank_0():
return (
torch.distributed.is_available()
and torch.distributed.is_initialized()
and int(os.environ.get("LOCAL_RANK", "-1")) == 0
)


TORCH_INIT_FUNCTIONS = {
"uniform_": nn.init.uniform_,
@@ -2801,44 +2782,10 @@ def _check_and_adjust_attn_implementation(
and is_kernels_available()
):
applicable_attn_implementation = "kernels-community/flash-attn"
if applicable_attn_implementation is not None and re.match(
r"^[^/:]+/[^/:]+(?:@[^/:]+)?(?::[^/:]+)?$", applicable_attn_implementation
):
if not is_kernels_available():
raise ValueError("kernels is not installed. Please install it with `pip install kernels`.")
attention_wrapper = None
# FIXME: @ArthurZucker this is dirty, did not want to do a lof of extra work
actual_attn_name = applicable_attn_implementation
if "|" in applicable_attn_implementation:
attention_wrapper, actual_attn_name = applicable_attn_implementation.split("|")
# `transformers` has wrapper for sdpa, paged, flash, flex etc.
attention_wrapper = ALL_ATTENTION_FUNCTIONS.get(attention_wrapper)
# Extract repo_id and kernel_name from the string
if ":" in actual_attn_name:
repo_id, kernel_name = actual_attn_name.split(":")
kernel_name = kernel_name.strip()
else:
repo_id = actual_attn_name
kernel_name = None
repo_id = repo_id.strip()
# extract the rev after the @ if it exists
repo_id, _, rev = repo_id.partition("@")
repo_id = repo_id.strip()
rev = rev.strip() if rev else None
if is_kernel(applicable_attn_implementation):
try:
kernel = get_kernel(repo_id, revision=rev)
if hasattr(kernel, "flash_attn_varlen_func"):
if attention_wrapper is None:
attention_wrapper = flash_attention_forward
kernel_function = partial(attention_wrapper, implementation=kernel)
lazy_import_flash_attention(kernel)
elif kernel_name is not None:
kernel_function = getattr(kernel, kernel_name)
ALL_ATTENTION_FUNCTIONS.register(applicable_attn_implementation, kernel_function)
ALL_MASK_ATTENTION_FUNCTIONS.register(
applicable_attn_implementation, ALL_MASK_ATTENTION_FUNCTIONS["flash_attention_2"]
)
# log that we used kernel fallback
load_and_register_kernel(applicable_attn_implementation)
# log that we used kernel fallback if successful
if attn_implementation == "flash_attention_2":
logger.warning_once(
"You do not have `flash_attn` installed, using `kernels-community/flash-attn` from the `kernels` "
@@ -2848,8 +2795,8 @@ def _check_and_adjust_attn_implementation(
if attn_implementation == "flash_attention_2":
self._flash_attn_2_can_dispatch() # will fail as fa2 is not available but raise the proper exception
logger.warning_once(
f"Could not find a kernel repository '{repo_id}' compatible with your device in the hub: {e}. Using "
"default attention implementation instead (sdpa if available, eager otherwise)."
f"Could not find a kernel matching `{applicable_attn_implementation}` compatible with your device in the "
f"hub:\n{e}.\nUsing default attention implementation instead (sdpa if available, eager otherwise)."
)
try:
self._sdpa_can_dispatch(is_init_check)
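Reviewer note: the user-facing flow is unchanged; a minimal usage sketch (the checkpoint name is a placeholder, and the hub kernel is only fetched when `kernels` is installed):

from transformers import AutoModelForCausalLM

# _check_and_adjust_attn_implementation now delegates the string below to
# load_and_register_kernel instead of parsing and loading the kernel inline.
model = AutoModelForCausalLM.from_pretrained(
    "some-org/some-causal-lm",  # placeholder checkpoint
    attn_implementation="kernels-community/flash-attn",
)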
8 changes: 2 additions & 6 deletions src/transformers/models/blip/modeling_blip_text.py
@@ -31,12 +31,8 @@
BaseModelOutputWithPoolingAndCrossAttentions,
CausalLMOutputWithCrossAttentions,
)
from ...modeling_utils import (
PreTrainedModel,
apply_chunking_to_forward,
find_pruneable_heads_and_indices,
prune_linear_layer,
)
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import logging
from ...utils.deprecation import deprecate_kwarg
from .configuration_blip import BlipTextConfig
4 changes: 2 additions & 2 deletions src/transformers/models/bridgetower/modeling_bridgetower.py
@@ -34,8 +34,8 @@
ModelOutput,
SequenceClassifierOutput,
)
from ...modeling_utils import PreTrainedModel, apply_chunking_to_forward
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import auto_docstring, logging, torch_int
from ...utils.deprecation import deprecate_kwarg
from .configuration_bridgetower import BridgeTowerConfig, BridgeTowerTextConfig, BridgeTowerVisionConfig
3 changes: 2 additions & 1 deletion src/transformers/models/cvt/modeling_cvt.py
@@ -24,7 +24,8 @@
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ...modeling_outputs import ImageClassifierOutputWithNoAttention, ModelOutput
from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import auto_docstring, logging
from .configuration_cvt import CvtConfig

8 changes: 2 additions & 6 deletions src/transformers/models/deprecated/mctct/modeling_mctct.py
@@ -28,12 +28,8 @@
from ....modeling_attn_mask_utils import _prepare_4d_attention_mask
from ....modeling_layers import GradientCheckpointingLayer
from ....modeling_outputs import BaseModelOutput, CausalLMOutput
from ....modeling_utils import (
PreTrainedModel,
apply_chunking_to_forward,
find_pruneable_heads_and_indices,
prune_linear_layer,
)
from ....modeling_utils import PreTrainedModel
from ....pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
from ....utils import logging
from .configuration_mctct import MCTCTConfig

8 changes: 2 additions & 6 deletions src/transformers/models/esm/modeling_esm.py
@@ -31,13 +31,9 @@
SequenceClassifierOutput,
TokenClassifierOutput,
)
from ...modeling_utils import (
ALL_ATTENTION_FUNCTIONS,
PreTrainedModel,
find_pruneable_heads_and_indices,
prune_linear_layer,
)
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
from ...utils.generic import OutputRecorder, check_model_inputs
from .configuration_esm import EsmConfig
10 changes: 2 additions & 8 deletions src/transformers/models/evolla/modeling_evolla.py
@@ -41,15 +41,9 @@
ModelOutput,
)
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from ...modeling_utils import (
ALL_ATTENTION_FUNCTIONS,
ModuleUtilsMixin,
PreTrainedModel,
find_pruneable_heads_and_indices,
get_parameter_dtype,
prune_linear_layer,
)
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, ModuleUtilsMixin, PreTrainedModel, get_parameter_dtype
from ...processing_utils import Unpack
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
from ...utils.deprecation import deprecate_kwarg
from ...utils.generic import OutputRecorder, check_model_inputs
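Reviewer note: the model-file changes above all apply the same substitution; a minimal sketch of the import downstream code should use now that `modeling_utils` no longer re-exports the pruning/chunking helpers:

from transformers.modeling_utils import PreTrainedModel
from transformers.pytorch_utils import (
    apply_chunking_to_forward,
    find_pruneable_heads_and_indices,
    prune_linear_layer,
)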