44 changes: 40 additions & 4 deletions flash_dmattn/__init__.py
@@ -2,6 +2,8 @@
 
 from typing import Optional
 
+__version__ = "0.1.0"
+
 try:
     from .flash_dmattn_triton import triton_dmattn_func
     TRITON_AVAILABLE = True
@@ -23,11 +25,44 @@
 except ImportError:
     CUDA_AVAILABLE = False
 
-__version__ = "0.1.0"
+# Import CUDA functions when available
+if CUDA_AVAILABLE:
+    try:
+        from .flash_dmattn_interface import (
+            flash_dmattn_func,
+            flash_dmattn_kvpacked_func,
+            flash_dmattn_qkvpacked_func,
+            flash_dmattn_varlen_func,
+            flash_dmattn_varlen_kvpacked_func,
+            flash_dmattn_varlen_qkvpacked_func,
+        )
+    except ImportError:
+        # Fallback if interface module is not available
+        flash_dmattn_func = None
+        flash_dmattn_kvpacked_func = None
+        flash_dmattn_qkvpacked_func = None
+        flash_dmattn_varlen_func = None
+        flash_dmattn_varlen_kvpacked_func = None
+        flash_dmattn_varlen_qkvpacked_func = None
+else:
+    flash_dmattn_func = None
+    flash_dmattn_kvpacked_func = None
+    flash_dmattn_qkvpacked_func = None
+    flash_dmattn_varlen_func = None
+    flash_dmattn_varlen_kvpacked_func = None
+    flash_dmattn_varlen_qkvpacked_func = None
 
 __all__ = [
     "triton_dmattn_func",
     "flex_dmattn_func",
+    "flash_dmattn_func",
+    "flash_dmattn_kvpacked_func",
+    "flash_dmattn_qkvpacked_func",
+    "flash_dmattn_varlen_func",
+    "flash_dmattn_varlen_kvpacked_func",
+    "flash_dmattn_varlen_qkvpacked_func",
+    "flash_dmattn_func_auto",
+    "get_available_backends",
     "TRITON_AVAILABLE",
     "FLEX_AVAILABLE",
     "CUDA_AVAILABLE",
@@ -46,7 +81,7 @@ def get_available_backends():
     return backends
 
 
-def flash_dmattn_func(backend: Optional[str] = None, **kwargs):
+def flash_dmattn_func_auto(backend: Optional[str] = None, **kwargs):
     """
     Flash Dynamic Mask Attention function with automatic backend selection.
 
@@ -72,8 +107,9 @@ def flash_dmattn_func(backend: Optional[str] = None, **kwargs):
     if backend == "cuda":
         if not CUDA_AVAILABLE:
             raise RuntimeError("CUDA backend is not available. Please build the CUDA extension.")
-        # Import and return CUDA function
-        raise NotImplementedError("CUDA backend not yet implemented in this version")
+        if flash_dmattn_func is None:
+            raise RuntimeError("CUDA flash_dmattn_func is not available. Please check the installation.")
+        return flash_dmattn_func
 
     elif backend == "triton":
         if not TRITON_AVAILABLE:
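For context, a minimal usage sketch of the selector API introduced by this diff, assuming the package is installed with at least one backend built. The tensor shapes, dtypes, and the call signature of the returned attention callable are illustrative assumptions, not part of this change:

# Sketch only: assumes a CUDA device, half-precision inputs, and that the returned
# callable accepts (q, k, v) tensors like other flash-attention-style functions.
import torch
from flash_dmattn import flash_dmattn_func_auto, get_available_backends

print(get_available_backends())  # e.g. ['cuda', 'triton', 'flex'], depending on the build

attn = flash_dmattn_func_auto()                  # automatic backend selection
# attn = flash_dmattn_func_auto(backend="cuda")  # or request one explicitly;
#                                                # raises RuntimeError if it is not available

# Hypothetical shapes: (batch, seqlen, num_heads, head_dim)
q = torch.randn(1, 128, 8, 64, device="cuda", dtype=torch.float16)
k = torch.randn(1, 128, 8, 64, device="cuda", dtype=torch.float16)
v = torch.randn(1, 128, 8, 64, device="cuda", dtype=torch.float16)
out = attn(q, k, v)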