diff --git a/src/transformers/models/falcon_mamba/modeling_falcon_mamba.py b/src/transformers/models/falcon_mamba/modeling_falcon_mamba.py
index 7b61d2bdefd9..1c832e84932f 100644
--- a/src/transformers/models/falcon_mamba/modeling_falcon_mamba.py
+++ b/src/transformers/models/falcon_mamba/modeling_falcon_mamba.py
@@ -35,6 +35,7 @@
 from ...utils import ModelOutput, auto_docstring, logging
 from ...utils.import_utils import (
     is_causal_conv1d_available,
+    is_kernels_available,
     is_mamba_ssm_available,
     is_mambapy_available,
 )
@@ -54,11 +55,6 @@
 else:
     selective_state_update, selective_scan_fn, mamba_inner_fn = None, None, None

-if is_causal_conv1d_available():
-    from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
-else:
-    causal_conv1d_update, causal_conv1d_fn = None, None
-

 logger = logging.get_logger(__name__)
@@ -164,6 +160,28 @@ def reset(self):
         self.ssm_states[layer_idx].zero_()


+def _lazy_load_causal_conv1d():
+    global _causal_conv1d_cache
+    if _causal_conv1d_cache is not None:
+        return _causal_conv1d_cache
+
+    if is_kernels_available():
+        from kernels import get_kernel
+
+        _causal_conv1d_kernel = get_kernel("kernels-community/causal-conv1d")
+        _causal_conv1d_cache = (_causal_conv1d_kernel.causal_conv1d_update, _causal_conv1d_kernel.causal_conv1d_fn)
+    elif is_causal_conv1d_available():
+        from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
+
+        _causal_conv1d_cache = (causal_conv1d_update, causal_conv1d_fn)
+    else:
+        _causal_conv1d_cache = (None, None)
+    return _causal_conv1d_cache
+
+
+_causal_conv1d_cache = None
+
+
 def rms_forward(hidden_states, variance_epsilon=1e-6):
     """
     Calculates simple RMSNorm with no learnable weights. `MambaRMSNorm` will
@@ -243,6 +261,7 @@ def __init__(self, config: FalconMambaConfig, layer_idx: int):
         self.rms_eps = config.mixer_rms_eps

     def warn_slow_implementation(self):
+        causal_conv1d_update, causal_conv1d_fn = _lazy_load_causal_conv1d()
         is_fast_path_available = all(
             (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)
         )
@@ -251,8 +270,8 @@ def warn_slow_implementation(self):
         if is_mambapy_available():
             logger.warning_once(
                 "The fast path is not available because one of `(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)`"
-                " is None. Falling back to the mamba.py backend. To install follow https://github.com/state-spaces/mamba/#installation and"
-                " https://github.com/Dao-AILab/causal-conv1d"
+                " is None. Falling back to the mamba.py backend. To install follow https://github.com/state-spaces/mamba/#installation for mamba-ssm and"
+                " https://github.com/Dao-AILab/causal-conv1d or `pip install kernels` for causal-conv1d"
             )
         else:
             raise ImportError(
@@ -261,8 +280,8 @@ def warn_slow_implementation(self):
         else:
             logger.warning_once(
                 "The fast path is not available because one of `(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)`"
-                " is None. Falling back to the sequential implementation of Mamba, as use_mambapy is set to False. To install follow https://github.com/state-spaces/mamba/#installation and"
-                " https://github.com/Dao-AILab/causal-conv1d. For the mamba.py backend, follow https://github.com/alxndrTL/mamba.py."
+                " is None. Falling back to the sequential implementation of Mamba, as use_mambapy is set to False. To install follow https://github.com/state-spaces/mamba/#installation for mamba-ssm and"
+                " https://github.com/Dao-AILab/causal-conv1d or `pip install kernels` for causal-conv1d. For the mamba.py backend, follow https://github.com/alxndrTL/mamba.py."
             )

     def cuda_kernels_forward(
@@ -297,6 +316,7 @@
             )

         else:
+            causal_conv1d_update, causal_conv1d_fn = _lazy_load_causal_conv1d()
             hidden_states, gate = projected_states.chunk(2, dim=1)

             if attention_mask is not None:
@@ -491,6 +511,7 @@ def forward(
         cache_position: Optional[torch.LongTensor] = None,
         attention_mask: Optional[torch.LongTensor] = None,
     ):
+        causal_conv1d_update, causal_conv1d_fn = _lazy_load_causal_conv1d()
         is_fast_path_available = all(
             (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)
         )
diff --git a/src/transformers/models/falcon_mamba/modular_falcon_mamba.py b/src/transformers/models/falcon_mamba/modular_falcon_mamba.py
index 090a147d31e2..7534d1b6c68a 100644
--- a/src/transformers/models/falcon_mamba/modular_falcon_mamba.py
+++ b/src/transformers/models/falcon_mamba/modular_falcon_mamba.py
@@ -21,7 +21,10 @@
 from torch import nn

 from ...utils import auto_docstring, logging
-from ...utils.import_utils import is_causal_conv1d_available, is_mamba_ssm_available, is_mambapy_available
+from ...utils.import_utils import (
+    is_mamba_ssm_available,
+    is_mambapy_available,
+)
 from ..mamba.configuration_mamba import MambaConfig
 from ..mamba.modeling_mamba import (
     MambaBlock,
@@ -33,6 +36,7 @@
     MambaOutput,
     MambaPreTrainedModel,
     MambaRMSNorm,
+    _lazy_load_causal_conv1d,
 )
@@ -51,10 +55,7 @@
 else:
     selective_state_update, selective_scan_fn, mamba_inner_fn = None, None, None

-if is_causal_conv1d_available():
-    from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
-else:
-    causal_conv1d_update, causal_conv1d_fn = None, None
+_causal_conv1d_cache = None


 class FalconMambaConfig(MambaConfig):
@@ -225,6 +226,7 @@ def rms_forward(hidden_states, variance_epsilon=1e-6):

 class FalconMambaMixer(MambaMixer):
     def warn_slow_implementation(self):
+        causal_conv1d_update, causal_conv1d_fn = _lazy_load_causal_conv1d()
         is_fast_path_available = all(
             (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)
         )
@@ -233,8 +235,8 @@
         if is_mambapy_available():
             logger.warning_once(
                 "The fast path is not available because one of `(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)`"
-                " is None. Falling back to the mamba.py backend. To install follow https://github.com/state-spaces/mamba/#installation and"
-                " https://github.com/Dao-AILab/causal-conv1d"
+                " is None. Falling back to the mamba.py backend. To install follow https://github.com/state-spaces/mamba/#installation for mamba-ssm and"
+                " https://github.com/Dao-AILab/causal-conv1d or `pip install kernels` for causal-conv1d"
             )
         else:
             raise ImportError(
@@ -243,8 +245,8 @@
         else:
             logger.warning_once(
                 "The fast path is not available because one of `(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)`"
-                " is None. Falling back to the sequential implementation of Mamba, as use_mambapy is set to False. To install follow https://github.com/state-spaces/mamba/#installation and"
-                " https://github.com/Dao-AILab/causal-conv1d. For the mamba.py backend, follow https://github.com/alxndrTL/mamba.py."
+                " is None. Falling back to the sequential implementation of Mamba, as use_mambapy is set to False. To install follow https://github.com/state-spaces/mamba/#installation for mamba-ssm and"
+                " https://github.com/Dao-AILab/causal-conv1d or `pip install kernels` for causal-conv1d. For the mamba.py backend, follow https://github.com/alxndrTL/mamba.py."
             )

     def __init__(self, config: FalconMambaConfig, layer_idx: int):
@@ -290,6 +292,7 @@ def cuda_kernels_forward(
             )

         else:
+            causal_conv1d_update, causal_conv1d_fn = _lazy_load_causal_conv1d()
             hidden_states, gate = projected_states.chunk(2, dim=1)

             if attention_mask is not None:
@@ -483,6 +486,7 @@ def forward(
         cache_position: Optional[torch.LongTensor] = None,
         attention_mask: Optional[torch.LongTensor] = None,
     ):
+        causal_conv1d_update, causal_conv1d_fn = _lazy_load_causal_conv1d()
         is_fast_path_available = all(
             (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)
         )
diff --git a/src/transformers/models/mamba/modeling_mamba.py b/src/transformers/models/mamba/modeling_mamba.py
index 3f39c2d8490b..9cdc63a9943a 100644
--- a/src/transformers/models/mamba/modeling_mamba.py
+++ b/src/transformers/models/mamba/modeling_mamba.py
@@ -33,7 +33,12 @@
     auto_docstring,
     logging,
 )
-from ...utils.import_utils import is_causal_conv1d_available, is_mamba_ssm_available, is_mambapy_available
+from ...utils.import_utils import (
+    is_causal_conv1d_available,
+    is_kernels_available,
+    is_mamba_ssm_available,
+    is_mambapy_available,
+)
 from .configuration_mamba import MambaConfig
@@ -50,10 +55,26 @@
 else:
     selective_state_update, selective_scan_fn, mamba_inner_fn = None, None, None

-if is_causal_conv1d_available():
-    from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
-else:
-    causal_conv1d_update, causal_conv1d_fn = None, None
+_causal_conv1d_cache = None
+
+
+def _lazy_load_causal_conv1d():
+    global _causal_conv1d_cache
+    if _causal_conv1d_cache is not None:
+        return _causal_conv1d_cache
+
+    if is_kernels_available():
+        from kernels import get_kernel
+
+        _causal_conv1d_kernel = get_kernel("kernels-community/causal-conv1d")
+        _causal_conv1d_cache = (_causal_conv1d_kernel.causal_conv1d_update, _causal_conv1d_kernel.causal_conv1d_fn)
+    elif is_causal_conv1d_available():
+        from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
+
+        _causal_conv1d_cache = (causal_conv1d_update, causal_conv1d_fn)
+    else:
+        _causal_conv1d_cache = (None, None)
+    return _causal_conv1d_cache


 class MambaCache:
@@ -209,6 +230,7 @@ def __init__(self, config: MambaConfig, layer_idx: int):
         self.warn_slow_implementation()

     def warn_slow_implementation(self):
+        causal_conv1d_update, causal_conv1d_fn = _lazy_load_causal_conv1d()
         is_fast_path_available = all(
             (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)
         )
@@ -217,8 +239,8 @@
         if is_mambapy_available():
             logger.warning_once(
                 "The fast path is not available because one of `(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)`"
-                " is None. Falling back to the mamba.py backend. To install follow https://github.com/state-spaces/mamba/#installation and"
-                " https://github.com/Dao-AILab/causal-conv1d"
+                " is None. Falling back to the mamba.py backend. To install follow https://github.com/state-spaces/mamba/#installation for mamba-ssm and"
+                " install the kernels library using `pip install kernels` or https://github.com/Dao-AILab/causal-conv1d for causal-conv1d"
             )
         else:
             raise ImportError(
@@ -227,8 +249,8 @@
         else:
             logger.warning_once(
                 "The fast path is not available because one of `(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)`"
-                " is None. Falling back to the sequential implementation of Mamba, as use_mambapy is set to False. To install follow https://github.com/state-spaces/mamba/#installation and"
-                " https://github.com/Dao-AILab/causal-conv1d. For the mamba.py backend, follow https://github.com/alxndrTL/mamba.py."
+                " is None. Falling back to the sequential implementation of Mamba, as use_mambapy is set to False. To install follow https://github.com/state-spaces/mamba/#installation for mamba-ssm and"
+                " install the kernels library using `pip install kernels` or https://github.com/Dao-AILab/causal-conv1d for causal-conv1d. For the mamba.py backend, follow https://github.com/alxndrTL/mamba.py."
             )

     def cuda_kernels_forward(
@@ -259,6 +281,7 @@ def cuda_kernels_forward(
             )

         else:
+            causal_conv1d_update, causal_conv1d_fn = _lazy_load_causal_conv1d()
             hidden_states, gate = projected_states.chunk(2, dim=1)

             if attention_mask is not None:
@@ -422,6 +445,7 @@ def forward(
         cache_position: Optional[torch.LongTensor] = None,
         attention_mask: Optional[torch.LongTensor] = None,
     ):
+        causal_conv1d_update, causal_conv1d_fn = _lazy_load_causal_conv1d()
         is_fast_path_available = all(
             (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)
         )
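
Note on the pattern above: all three copies of `_lazy_load_causal_conv1d` resolve the conv1d kernels in the same order. A prebuilt kernel fetched from the Hub through the `kernels` package wins, then the locally compiled `causal-conv1d` package, then `(None, None)`, which the mixers detect to fall back to the slow path. The sketch below mirrors that resolution order as a standalone script for checking which backend a given environment will pick up; the `resolve_causal_conv1d` name is illustrative and not part of the patch.

```python
# Standalone sketch of the resolution order used by _lazy_load_causal_conv1d.
# Illustrative only; this helper is not part of the patch itself.
import importlib.util


def resolve_causal_conv1d():
    # 1) Prebuilt kernel from the Hugging Face Hub (downloaded on first use).
    if importlib.util.find_spec("kernels") is not None:
        from kernels import get_kernel

        kernel = get_kernel("kernels-community/causal-conv1d")
        return kernel.causal_conv1d_update, kernel.causal_conv1d_fn
    # 2) Locally compiled causal-conv1d package.
    if importlib.util.find_spec("causal_conv1d") is not None:
        from causal_conv1d import causal_conv1d_fn, causal_conv1d_update

        return causal_conv1d_update, causal_conv1d_fn
    # 3) Neither is installed: callers see (None, None) and use the slow path.
    return None, None


if __name__ == "__main__":
    update_fn, conv_fn = resolve_causal_conv1d()
    print("fast causal-conv1d path available:", conv_fn is not None)
```

Because the resolved pair is cached in the module-level `_causal_conv1d_cache`, the Hub lookup happens at most once per process rather than once per layer; every later call, including the per-forward calls added in the patch, is a cheap tuple read.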