 import math
 from dataclasses import dataclass
 from enum import Enum
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union

 import torch

@@ -160,16 +160,13 @@ def wrap(func):
 # - CP with sage attention, flex, xformers, other missing backends
 # - Add support for normal and CP training with backends that don't support it yet
-_SAGE_ATTENTION_PV_ACCUM_DTYPE = Literal["fp32", "fp32+fp32"]
-_SAGE_ATTENTION_QK_QUANT_GRAN = Literal["per_thread", "per_warp"]
-_SAGE_ATTENTION_QUANTIZATION_BACKEND = Literal["cuda", "triton"]
-

 class AttentionBackendName(str, Enum):
     # EAGER = "eager"

     # `flash-attn`
     FLASH = "flash"
+    FLASH_HUB = "flash_hub"
     FLASH_VARLEN = "flash_varlen"
     _FLASH_3 = "_flash_3"
     _FLASH_VARLEN_3 = "_flash_varlen_3"
@@ -191,6 +188,7 @@ class AttentionBackendName(str, Enum):

     # `sageattention`
     SAGE = "sage"
+    SAGE_HUB = "sage_hub"
     SAGE_VARLEN = "sage_varlen"
     _SAGE_QK_INT8_PV_FP8_CUDA = "_sage_qk_int8_pv_fp8_cuda"
     _SAGE_QK_INT8_PV_FP8_CUDA_SM90 = "_sage_qk_int8_pv_fp8_cuda_sm90"
@@ -264,7 +262,13 @@ class _HubKernelConfig:
     # TODO: temporary revision for now. Remove when merged upstream into `main`.
     AttentionBackendName._FLASH_3_HUB: _HubKernelConfig(
         repo_id="kernels-community/flash-attn3", function_attr="flash_attn_func", revision="fake-ops-return-probs"
-    )
+    ),
+    AttentionBackendName.FLASH_HUB: _HubKernelConfig(
+        repo_id="kernels-community/flash-attn2", function_attr="flash_attn_func", revision=None
+    ),
+    AttentionBackendName.SAGE_HUB: _HubKernelConfig(
+        repo_id="kernels-community/sage_attention", function_attr="sageattn", revision=None
+    ),
 }

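For reference, resolving one of these `_HubKernelConfig` entries with the `kernels` package looks roughly like the sketch below. The helper name `_resolve_hub_kernel` is hypothetical and not part of this diff; the actual loading logic lives elsewhere in this module and may differ.

from typing import Optional

from kernels import get_kernel


def _resolve_hub_kernel(repo_id: str, function_attr: str, revision: Optional[str] = None):
    # Download (or reuse the locally cached) kernel repo from the Hugging Face Hub,
    # then pull out the requested callable, e.g. `flash_attn_func` or `sageattn`.
    kernel_module = get_kernel(repo_id, revision=revision)
    return getattr(kernel_module, function_attr)


# e.g. flash_attn_func = _resolve_hub_kernel("kernels-community/flash-attn2", "flash_attn_func")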
@@ -420,8 +424,8 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None
             f"Flash Attention 3 backend '{backend.value}' is not usable because of missing package or the version is too old. Please build FA3 beta release from source."
         )

-    # TODO: add support for the Hub variant of FA3 varlen later
-    elif backend in [AttentionBackendName._FLASH_3_HUB]:
+    # TODO: add support for Hub variants of varlen later
+    elif backend in [AttentionBackendName._FLASH_3_HUB, AttentionBackendName.FLASH_HUB, AttentionBackendName.SAGE_HUB]:
         if not is_kernels_available():
             raise RuntimeError(
                 f"Backend '{backend.value}' is not usable because the `kernels` package isn't available. Please install it with `pip install kernels`."
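For context, the `is_kernels_available()` guard above amounts to a simple import check. A minimal sketch, assuming the usual find_spec pattern (the real helper lives in diffusers' import utilities and may be implemented differently):

import importlib.util


def is_kernels_available() -> bool:
    # True if the `kernels` package (https://github.com/huggingface/kernels) is importable.
    return importlib.util.find_spec("kernels") is not None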
@@ -1350,6 +1354,38 @@ def _flash_attention(
     return (out, lse) if return_lse else out


+@_AttentionBackendRegistry.register(
+    AttentionBackendName.FLASH_HUB,
+    constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
+    supports_context_parallel=False,
+)
+def _flash_attention_hub(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    dropout_p: float = 0.0,
+    is_causal: bool = False,
+    scale: Optional[float] = None,
+    return_lse: bool = False,
+    _parallel_config: Optional["ParallelConfig"] = None,
+) -> torch.Tensor:
+    lse = None
+    func = _HUB_KERNELS_REGISTRY[AttentionBackendName.FLASH_HUB].kernel_fn
+    out = func(
+        q=query,
+        k=key,
+        v=value,
+        dropout_p=dropout_p,
+        softmax_scale=scale,
+        causal=is_causal,
+        return_attn_probs=return_lse,
+    )
+    if return_lse:
+        out, lse, *_ = out
+
+    return (out, lse) if return_lse else out
+
+
 @_AttentionBackendRegistry.register(
     AttentionBackendName.FLASH_VARLEN,
     constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
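Assuming the module's existing dispatch entry points (the `attention_backend` context manager; the import path below is an assumption, not shown in this diff), the new backend is selected by the enum's string value, roughly like:

from diffusers.models.attention_dispatch import attention_backend

with attention_backend("flash_hub"):
    # `model` and `model_inputs` are placeholders; any attention dispatched inside
    # this block routes to the kernels-community/flash-attn2 kernel fetched from the Hub.
    output = model(**model_inputs)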
@@ -1431,6 +1467,7 @@ def _flash_attention_3(
 @_AttentionBackendRegistry.register(
     AttentionBackendName._FLASH_3_HUB,
     constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
+    supports_context_parallel=False,
 )
 def _flash_attention_3_hub(
     query: torch.Tensor,
@@ -1444,6 +1481,9 @@ def _flash_attention_3_hub(
     return_attn_probs: bool = False,
     _parallel_config: Optional["ParallelConfig"] = None,
 ) -> torch.Tensor:
+    if _parallel_config:
+        raise NotImplementedError(f"{AttentionBackendName._FLASH_3_HUB.value} is not implemented for parallelism yet.")
+
     func = _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_3_HUB].kernel_fn
     out = func(
         q=query,
@@ -1938,6 +1978,38 @@ def _sage_attention(
     return (out, lse) if return_lse else out


+@_AttentionBackendRegistry.register(
+    AttentionBackendName.SAGE_HUB,
+    constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16, _check_shape],
+    supports_context_parallel=False,
+)
+def _sage_attention_hub(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    is_causal: bool = False,
+    scale: Optional[float] = None,
+    return_lse: bool = False,
+    _parallel_config: Optional["ParallelConfig"] = None,
+) -> torch.Tensor:
+    lse = None
+    func = _HUB_KERNELS_REGISTRY[AttentionBackendName.SAGE_HUB].kernel_fn
+    if _parallel_config is None:
+        out = func(
+            q=query,
+            k=key,
+            v=value,
+            tensor_layout="NHD",
+            is_causal=is_causal,
+            sm_scale=scale,
+            return_lse=return_lse,
+        )
+        if return_lse:
+            out, lse, *_ = out
+
+    return (out, lse) if return_lse else out
+
+
 @_AttentionBackendRegistry.register(
     AttentionBackendName.SAGE_VARLEN,
     constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16, _check_shape],
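A minimal standalone sketch of what the SAGE_HUB path resolves to, assuming a CUDA device and the `kernels` package installed; `tensor_layout="NHD"` denotes the (batch, seq_len, num_heads, head_dim) layout used by the other backends in this file:

import torch
from kernels import get_kernel

sage = get_kernel("kernels-community/sage_attention")  # repo from the registry above
q = torch.randn(2, 1024, 8, 64, device="cuda", dtype=torch.float16)
k = torch.randn(2, 1024, 8, 64, device="cuda", dtype=torch.float16)
v = torch.randn(2, 1024, 8, 64, device="cuda", dtype=torch.float16)
out = sage.sageattn(q=q, k=k, v=v, tensor_layout="NHD", is_causal=False)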