Merged
83 changes: 27 additions & 56 deletions flash_dmattn/__init__.py
@@ -2,72 +2,33 @@

from typing import Optional

__version__ = "1.0.0"
__version__ = "1.0.3"


+# Import CUDA functions when available
+try:
+    from flash_dmattn.flash_dmattn_interface import flash_dmattn_func
+    CUDA_AVAILABLE = True
+except ImportError:
+    CUDA_AVAILABLE = False
+    flash_dmattn_func = None

# Import Triton functions when available
try:
-    from .flash_dmattn_triton import triton_dmattn_func
+    from flash_dmattn.flash_dmattn_triton import triton_dmattn_func
    TRITON_AVAILABLE = True
except ImportError:
    TRITON_AVAILABLE = False
    triton_dmattn_func = None

# Import Flex functions when available
try:
-    from .flash_dmattn_flex import flex_dmattn_func
+    from flash_dmattn.flash_dmattn_flex import flex_dmattn_func
    FLEX_AVAILABLE = True
except ImportError:
    FLEX_AVAILABLE = False
    flex_dmattn_func = None

-# Check if CUDA extension is available
-try:
-    import flash_dmattn_cuda  # type: ignore[import]
-    CUDA_AVAILABLE = True
-except ImportError:
-    CUDA_AVAILABLE = False

-# Import CUDA functions when available
-if CUDA_AVAILABLE:
-    try:
-        from .flash_dmattn_interface import (
-            flash_dmattn_func,
-            flash_dmattn_kvpacked_func,
-            flash_dmattn_qkvpacked_func,
-            flash_dmattn_varlen_func,
-            flash_dmattn_varlen_kvpacked_func,
-            flash_dmattn_varlen_qkvpacked_func,
-        )
-    except ImportError:
-        # Fallback if interface module is not available
-        flash_dmattn_func = None
-        flash_dmattn_kvpacked_func = None
-        flash_dmattn_qkvpacked_func = None
-        flash_dmattn_varlen_func = None
-        flash_dmattn_varlen_kvpacked_func = None
-        flash_dmattn_varlen_qkvpacked_func = None
-else:
-    flash_dmattn_func = None
-    flash_dmattn_kvpacked_func = None
-    flash_dmattn_qkvpacked_func = None
-    flash_dmattn_varlen_func = None
-    flash_dmattn_varlen_kvpacked_func = None
-    flash_dmattn_varlen_qkvpacked_func = None

-__all__ = [
-    "triton_dmattn_func",
-    "flex_dmattn_func",
-    "flash_dmattn_func",
-    "flash_dmattn_kvpacked_func",
-    "flash_dmattn_qkvpacked_func",
-    "flash_dmattn_varlen_func",
-    "flash_dmattn_varlen_kvpacked_func",
-    "flash_dmattn_varlen_qkvpacked_func",
-    "flash_dmattn_func_auto",
-    "get_available_backends",
-    "TRITON_AVAILABLE",
-    "FLEX_AVAILABLE",
-    "CUDA_AVAILABLE",
-]


def get_available_backends():
    """Return a list of available backends."""
@@ -87,7 +48,7 @@ def flash_dmattn_func_auto(backend: Optional[str] = None, **kwargs):

    Args:
        backend (str, optional): Backend to use ('cuda', 'triton', 'flex').
-        If None, will use the first available backend in order: cuda, triton, flex.
+            If None, will use the first available backend in order: cuda, triton, flex.
        **kwargs: Arguments to pass to the attention function.

    Returns:
@@ -107,8 +68,6 @@ def flash_dmattn_func_auto(backend: Optional[str] = None, **kwargs):
    if backend == "cuda":
        if not CUDA_AVAILABLE:
            raise RuntimeError("CUDA backend is not available. Please build the CUDA extension.")
-        if flash_dmattn_func is None:
-            raise RuntimeError("CUDA flash_dmattn_func is not available. Please check the installation.")
        return flash_dmattn_func

    elif backend == "triton":
@@ -123,3 +82,15 @@ def flash_dmattn_func_auto(backend: Optional[str] = None, **kwargs):

    else:
        raise ValueError(f"Unknown backend: {backend}. Available backends: {get_available_backends()}")


+__all__ = [
+    "CUDA_AVAILABLE",
+    "TRITON_AVAILABLE",
+    "FLEX_AVAILABLE",
+    "flash_dmattn_func",
+    "triton_dmattn_func",
+    "flex_dmattn_func",
+    "get_available_backends",
+    "flash_dmattn_func_auto",
+]
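
For context, a minimal usage sketch of the public API after this change. get_available_backends, flash_dmattn_func_auto, and the availability flags come straight from this diff; the tensor shapes and the (q, k, v) call signature are illustrative assumptions, not something this PR pins down.

import torch
from flash_dmattn import CUDA_AVAILABLE, get_available_backends, flash_dmattn_func_auto

# Backend availability is now decided entirely by which imports succeeded.
print(get_available_backends())   # e.g. ['cuda', 'triton', 'flex']
print(CUDA_AVAILABLE)             # False if the CUDA extension failed to import

# Auto-selection tries cuda, then triton, then flex; an explicit request
# raises RuntimeError if that backend is unavailable.
attn = flash_dmattn_func_auto()                  # first available backend
# attn = flash_dmattn_func_auto(backend="cuda")  # or pick one explicitly

# Hypothetical call, for illustration only; consult the chosen backend's
# interface for the exact argument list.
q = torch.randn(2, 1024, 8, 64, device="cuda", dtype=torch.bfloat16)
k = torch.randn(2, 1024, 8, 64, device="cuda", dtype=torch.bfloat16)
v = torch.randn(2, 1024, 8, 64, device="cuda", dtype=torch.bfloat16)
out = attn(q, k, v)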
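
The two-line removal in the third hunk is safe because the restructured import block ties flash_dmattn_func and CUDA_AVAILABLE together: both are set in the same try/except, so CUDA_AVAILABLE being True now implies flash_dmattn_func is not None. A sketch of the simpler guard callers can rely on (the names are from this diff; falling back to "triton" is just an example choice):

from flash_dmattn import CUDA_AVAILABLE, flash_dmattn_func, flash_dmattn_func_auto

if CUDA_AVAILABLE:
    # Guaranteed non-None by the new import block; no extra None check needed.
    attn = flash_dmattn_func
else:
    # Explicitly request another backend (raises if Triton is also unavailable).
    attn = flash_dmattn_func_auto(backend="triton")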