39 changes: 30 additions & 9 deletions src/transformers/models/falcon_mamba/modeling_falcon_mamba.py
@@ -35,6 +35,7 @@
 from ...utils import ModelOutput, auto_docstring, logging
 from ...utils.import_utils import (
     is_causal_conv1d_available,
+    is_kernels_available,
     is_mamba_ssm_available,
     is_mambapy_available,
 )
@@ -54,11 +55,6 @@
 else:
     selective_state_update, selective_scan_fn, mamba_inner_fn = None, None, None
 
-if is_causal_conv1d_available():
-    from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
-else:
-    causal_conv1d_update, causal_conv1d_fn = None, None
-
 
 
 logger = logging.get_logger(__name__)

@@ -164,6 +160,28 @@ def reset(self):
             self.ssm_states[layer_idx].zero_()
 
 
+def _lazy_load_causal_conv1d():
+    global _causal_conv1d_cache
+    if _causal_conv1d_cache is not None:
+        return _causal_conv1d_cache
+
+    if is_kernels_available():
+        from kernels import get_kernel
+
+        _causal_conv1d_kernel = get_kernel("kernels-community/causal-conv1d")
+        _causal_conv1d_cache = (_causal_conv1d_kernel.causal_conv1d_update, _causal_conv1d_kernel.causal_conv1d_fn)
+    elif is_causal_conv1d_available():
+        from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
+
+        _causal_conv1d_cache = (causal_conv1d_update, causal_conv1d_fn)
+    else:
+        _causal_conv1d_cache = (None, None)
+    return _causal_conv1d_cache
+
+
+_causal_conv1d_cache = None
+
+
 def rms_forward(hidden_states, variance_epsilon=1e-6):
     """
     Calculates simple RMSNorm with no learnable weights. `MambaRMSNorm` will
@@ -243,6 +261,7 @@ def __init__(self, config: FalconMambaConfig, layer_idx: int):
         self.rms_eps = config.mixer_rms_eps
 
     def warn_slow_implementation(self):
+        causal_conv1d_update, causal_conv1d_fn = _lazy_load_causal_conv1d()
         is_fast_path_available = all(
             (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)
         )
@@ -251,8 +270,8 @@ def warn_slow_implementation(self):
             if is_mambapy_available():
                 logger.warning_once(
                     "The fast path is not available because one of `(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)`"
-                    " is None. Falling back to the mamba.py backend. To install follow https://github.com/state-spaces/mamba/#installation and"
-                    " https://github.com/Dao-AILab/causal-conv1d"
+                    " is None. Falling back to the mamba.py backend. To install, follow https://github.com/state-spaces/mamba/#installation for mamba-ssm and"
+                    " https://github.com/Dao-AILab/causal-conv1d or `pip install kernels` for causal-conv1d"
                 )
             else:
                 raise ImportError(
@@ -261,8 +280,8 @@
         else:
             logger.warning_once(
                 "The fast path is not available because one of `(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)`"
-                " is None. Falling back to the sequential implementation of Mamba, as use_mambapy is set to False. To install follow https://github.com/state-spaces/mamba/#installation and"
-                " https://github.com/Dao-AILab/causal-conv1d. For the mamba.py backend, follow https://github.com/alxndrTL/mamba.py."
+                " is None. Falling back to the sequential implementation of Mamba, as use_mambapy is set to False. To install, follow https://github.com/state-spaces/mamba/#installation for mamba-ssm and"
+                " https://github.com/Dao-AILab/causal-conv1d or `pip install kernels` for causal-conv1d. For the mamba.py backend, follow https://github.com/alxndrTL/mamba.py."
             )
 
     def cuda_kernels_forward(
@@ -297,6 +316,7 @@ def cuda_kernels_forward(
             )
 
         else:
+            causal_conv1d_update, causal_conv1d_fn = _lazy_load_causal_conv1d()
             hidden_states, gate = projected_states.chunk(2, dim=1)
 
             if attention_mask is not None:
@@ -491,6 +511,7 @@ def forward(
         cache_position: Optional[torch.LongTensor] = None,
         attention_mask: Optional[torch.LongTensor] = None,
     ):
+        causal_conv1d_update, causal_conv1d_fn = _lazy_load_causal_conv1d()
         is_fast_path_available = all(
             (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)
         )
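
Aside on the pattern above: `_lazy_load_causal_conv1d` is a memoized fallback chain. The first call probes the preferred backends in order (the Hub kernel, then the compiled `causal_conv1d` package) and caches the result, including the `(None, None)` miss, so the probe and any kernel download happen at most once per process and never at import time. A minimal, self-contained sketch of the same pattern, with illustrative names and try/except probes standing in for the `is_*_available()` checks:

_backend_cache = None  # module-level memo: resolved at most once per process


def _lazy_load_backend():
    """Probe for the fastest available backend on first call, then reuse the result."""
    global _backend_cache
    if _backend_cache is not None:  # a (None, None) miss is cached too, since a tuple is never None
        return _backend_cache
    try:
        from kernels import get_kernel  # preferred: prebuilt binary fetched from the Hub

        kernel = get_kernel("kernels-community/causal-conv1d")
        _backend_cache = (kernel.causal_conv1d_update, kernel.causal_conv1d_fn)
    except ImportError:
        try:
            from causal_conv1d import causal_conv1d_fn, causal_conv1d_update  # locally compiled extension

            _backend_cache = (causal_conv1d_update, causal_conv1d_fn)
        except ImportError:
            _backend_cache = (None, None)  # no fast backend: callers take the eager path
    return _backend_cache
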
22 changes: 13 additions & 9 deletions src/transformers/models/falcon_mamba/modular_falcon_mamba.py
@@ -21,7 +21,10 @@
 from torch import nn
 
 from ...utils import auto_docstring, logging
-from ...utils.import_utils import is_causal_conv1d_available, is_mamba_ssm_available, is_mambapy_available
+from ...utils.import_utils import (
+    is_mamba_ssm_available,
+    is_mambapy_available,
+)
 from ..mamba.configuration_mamba import MambaConfig
 from ..mamba.modeling_mamba import (
     MambaBlock,
@@ -33,6 +36,7 @@
     MambaOutput,
     MambaPreTrainedModel,
     MambaRMSNorm,
+    _lazy_load_causal_conv1d,
 )


@@ -51,10 +55,7 @@
 else:
     selective_state_update, selective_scan_fn, mamba_inner_fn = None, None, None
 
-if is_causal_conv1d_available():
-    from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
-else:
-    causal_conv1d_update, causal_conv1d_fn = None, None
+_causal_conv1d_cache = None
 
 
 class FalconMambaConfig(MambaConfig):
@@ -225,6 +226,7 @@ def rms_forward(hidden_states, variance_epsilon=1e-6):
 
 class FalconMambaMixer(MambaMixer):
     def warn_slow_implementation(self):
+        causal_conv1d_update, causal_conv1d_fn = _lazy_load_causal_conv1d()
         is_fast_path_available = all(
             (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)
         )
@@ -233,8 +235,8 @@ def warn_slow_implementation(self):
             if is_mambapy_available():
                 logger.warning_once(
                     "The fast path is not available because one of `(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)`"
-                    " is None. Falling back to the mamba.py backend. To install follow https://github.com/state-spaces/mamba/#installation and"
-                    " https://github.com/Dao-AILab/causal-conv1d"
+                    " is None. Falling back to the mamba.py backend. To install, follow https://github.com/state-spaces/mamba/#installation for mamba-ssm and"
+                    " https://github.com/Dao-AILab/causal-conv1d or `pip install kernels` for causal-conv1d"
                 )
             else:
                 raise ImportError(
@@ -243,8 +245,8 @@
         else:
             logger.warning_once(
                 "The fast path is not available because one of `(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)`"
-                " is None. Falling back to the sequential implementation of Mamba, as use_mambapy is set to False. To install follow https://github.com/state-spaces/mamba/#installation and"
-                " https://github.com/Dao-AILab/causal-conv1d. For the mamba.py backend, follow https://github.com/alxndrTL/mamba.py."
+                " is None. Falling back to the sequential implementation of Mamba, as use_mambapy is set to False. To install, follow https://github.com/state-spaces/mamba/#installation for mamba-ssm and"
+                " https://github.com/Dao-AILab/causal-conv1d or `pip install kernels` for causal-conv1d. For the mamba.py backend, follow https://github.com/alxndrTL/mamba.py."
             )
 
     def __init__(self, config: FalconMambaConfig, layer_idx: int):
@@ -290,6 +292,7 @@ def cuda_kernels_forward(
             )
 
         else:
+            causal_conv1d_update, causal_conv1d_fn = _lazy_load_causal_conv1d()
             hidden_states, gate = projected_states.chunk(2, dim=1)
 
             if attention_mask is not None:
@@ -483,6 +486,7 @@ def forward(
         cache_position: Optional[torch.LongTensor] = None,
         attention_mask: Optional[torch.LongTensor] = None,
     ):
+        causal_conv1d_update, causal_conv1d_fn = _lazy_load_causal_conv1d()
         is_fast_path_available = all(
             (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)
        )
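
For context, this is roughly how the mixer methods above consume the lazily loaded pair: gate on `causal_conv1d_fn` being None and fall back to the eager `nn.Conv1d`. A hedged sketch, simplified from the real forward path rather than copied from it (it assumes this PR is installed, `hidden_states` is `(batch, channels, seq_len)`, and `conv1d_module` is the mixer's depthwise `nn.Conv1d`):

import torch.nn.functional as F

from transformers.models.mamba.modeling_mamba import _lazy_load_causal_conv1d


def causal_conv(hidden_states, conv1d_module):
    # Resolve backends lazily; only the first call probes, later calls hit the cache.
    causal_conv1d_update, causal_conv1d_fn = _lazy_load_causal_conv1d()
    seq_len = hidden_states.shape[-1]
    if causal_conv1d_fn is None:
        # Slow path: nn.Conv1d pads on both sides, so trim the tail back to seq_len.
        return F.silu(conv1d_module(hidden_states)[..., :seq_len])
    # Fast path: the fused kernel takes the depthwise weight as (channels, kernel_size).
    weight = conv1d_module.weight.squeeze(1)
    return causal_conv1d_fn(hidden_states, weight, conv1d_module.bias, activation="silu")
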
42 changes: 33 additions & 9 deletions src/transformers/models/mamba/modeling_mamba.py
@@ -33,7 +33,12 @@
     auto_docstring,
     logging,
 )
-from ...utils.import_utils import is_causal_conv1d_available, is_mamba_ssm_available, is_mambapy_available
+from ...utils.import_utils import (
+    is_causal_conv1d_available,
+    is_kernels_available,
+    is_mamba_ssm_available,
+    is_mambapy_available,
+)
 from .configuration_mamba import MambaConfig


@@ -50,10 +55,26 @@
 else:
     selective_state_update, selective_scan_fn, mamba_inner_fn = None, None, None
 
-if is_causal_conv1d_available():
-    from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
-else:
-    causal_conv1d_update, causal_conv1d_fn = None, None
+_causal_conv1d_cache = None
+
+
+def _lazy_load_causal_conv1d():
+    global _causal_conv1d_cache
+    if _causal_conv1d_cache is not None:
+        return _causal_conv1d_cache
+
+    if is_kernels_available():
+        from kernels import get_kernel
+
+        _causal_conv1d_kernel = get_kernel("kernels-community/causal-conv1d")
+        _causal_conv1d_cache = (_causal_conv1d_kernel.causal_conv1d_update, _causal_conv1d_kernel.causal_conv1d_fn)
+    elif is_causal_conv1d_available():
+        from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
+
+        _causal_conv1d_cache = (causal_conv1d_update, causal_conv1d_fn)
+    else:
+        _causal_conv1d_cache = (None, None)
+    return _causal_conv1d_cache
 
 
 class MambaCache:
@@ -209,6 +230,7 @@ def __init__(self, config: MambaConfig, layer_idx: int):
         self.warn_slow_implementation()
 
     def warn_slow_implementation(self):
+        causal_conv1d_update, causal_conv1d_fn = _lazy_load_causal_conv1d()
         is_fast_path_available = all(
             (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)
         )
@@ -217,8 +239,8 @@ def warn_slow_implementation(self):
             if is_mambapy_available():
                 logger.warning_once(
                     "The fast path is not available because one of `(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)`"
-                    " is None. Falling back to the mamba.py backend. To install follow https://github.com/state-spaces/mamba/#installation and"
-                    " https://github.com/Dao-AILab/causal-conv1d"
+                    " is None. Falling back to the mamba.py backend. To install, follow https://github.com/state-spaces/mamba/#installation for mamba-ssm and"
+                    " install the kernels library using `pip install kernels` or https://github.com/Dao-AILab/causal-conv1d for causal-conv1d"
                 )
             else:
                 raise ImportError(
@@ -227,8 +249,8 @@
         else:
             logger.warning_once(
                 "The fast path is not available because one of `(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)`"
-                " is None. Falling back to the sequential implementation of Mamba, as use_mambapy is set to False. To install follow https://github.com/state-spaces/mamba/#installation and"
-                " https://github.com/Dao-AILab/causal-conv1d. For the mamba.py backend, follow https://github.com/alxndrTL/mamba.py."
+                " is None. Falling back to the sequential implementation of Mamba, as use_mambapy is set to False. To install, follow https://github.com/state-spaces/mamba/#installation for mamba-ssm and"
+                " install the kernels library using `pip install kernels` or https://github.com/Dao-AILab/causal-conv1d for causal-conv1d. For the mamba.py backend, follow https://github.com/alxndrTL/mamba.py."
             )
 
     def cuda_kernels_forward(
@@ -259,6 +281,7 @@ def cuda_kernels_forward(
             )
 
         else:
+            causal_conv1d_update, causal_conv1d_fn = _lazy_load_causal_conv1d()
             hidden_states, gate = projected_states.chunk(2, dim=1)
 
             if attention_mask is not None:
@@ -422,6 +445,7 @@ def forward(
         cache_position: Optional[torch.LongTensor] = None,
         attention_mask: Optional[torch.LongTensor] = None,
     ):
+        causal_conv1d_update, causal_conv1d_fn = _lazy_load_causal_conv1d()
         is_fast_path_available = all(
             (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)
         )
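
A hypothetical smoke test for the merged behavior (the assertions are mine, not from the PR's test suite): importing the module no longer resolves a backend; the first call to the loader does, and the result is memoized:

from transformers.models.mamba import modeling_mamba

assert modeling_mamba._causal_conv1d_cache is None  # nothing resolved at import time

update_fn, conv_fn = modeling_mamba._lazy_load_causal_conv1d()  # first call probes
if conv_fn is None:
    print("no fused causal-conv1d backend found; the sequential path will be used")
else:
    print(f"fused backend resolved from {conv_fn.__module__}")

assert modeling_mamba._lazy_load_causal_conv1d() == (update_fn, conv_fn)  # memoized
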