From 4c1623dd4b67521a2c30b75a621f61a67b35c541 Mon Sep 17 00:00:00 2001 From: LoserCheems <3314685395@qq.com> Date: Sat, 6 Sep 2025 20:26:57 +0800 Subject: [PATCH] Simplifies module imports and updates version to 1.0.3 Streamlines the initialization logic by removing redundant CUDA extension checking and complex nested import structures. Consolidates all backend imports into consistent try-catch blocks with cleaner error handling. Removes multiple CUDA function variants that were previously exposed, focusing on the core flash_dmattn_func interface. Updates __all__ exports to reflect the simplified API surface. Improves code maintainability by using absolute imports and reducing conditional import complexity. --- flash_dmattn/__init__.py | 83 +++++++++++++--------------------------- 1 file changed, 27 insertions(+), 56 deletions(-) diff --git a/flash_dmattn/__init__.py b/flash_dmattn/__init__.py index 825d590..5ecdbe8 100644 --- a/flash_dmattn/__init__.py +++ b/flash_dmattn/__init__.py @@ -2,72 +2,33 @@ from typing import Optional -__version__ = "1.0.0" +__version__ = "1.0.3" + +# Import CUDA functions when available +try: + from flash_dmattn.flash_dmattn_interface import flash_dmattn_func + CUDA_AVAILABLE = True +except ImportError: + CUDA_AVAILABLE = False + flash_dmattn_func = None + +# Import Triton functions when available try: - from .flash_dmattn_triton import triton_dmattn_func + from flash_dmattn.flash_dmattn_triton import triton_dmattn_func TRITON_AVAILABLE = True except ImportError: TRITON_AVAILABLE = False triton_dmattn_func = None +# Import Flex functions when available try: - from .flash_dmattn_flex import flex_dmattn_func + from flash_dmattn.flash_dmattn_flex import flex_dmattn_func FLEX_AVAILABLE = True except ImportError: FLEX_AVAILABLE = False flex_dmattn_func = None -# Check if CUDA extension is available -try: - import flash_dmattn_cuda # type: ignore[import] - CUDA_AVAILABLE = True -except ImportError: - CUDA_AVAILABLE = False - -# Import CUDA functions when available -if CUDA_AVAILABLE: - try: - from .flash_dmattn_interface import ( - flash_dmattn_func, - flash_dmattn_kvpacked_func, - flash_dmattn_qkvpacked_func, - flash_dmattn_varlen_func, - flash_dmattn_varlen_kvpacked_func, - flash_dmattn_varlen_qkvpacked_func, - ) - except ImportError: - # Fallback if interface module is not available - flash_dmattn_func = None - flash_dmattn_kvpacked_func = None - flash_dmattn_qkvpacked_func = None - flash_dmattn_varlen_func = None - flash_dmattn_varlen_kvpacked_func = None - flash_dmattn_varlen_qkvpacked_func = None -else: - flash_dmattn_func = None - flash_dmattn_kvpacked_func = None - flash_dmattn_qkvpacked_func = None - flash_dmattn_varlen_func = None - flash_dmattn_varlen_kvpacked_func = None - flash_dmattn_varlen_qkvpacked_func = None - -__all__ = [ - "triton_dmattn_func", - "flex_dmattn_func", - "flash_dmattn_func", - "flash_dmattn_kvpacked_func", - "flash_dmattn_qkvpacked_func", - "flash_dmattn_varlen_func", - "flash_dmattn_varlen_kvpacked_func", - "flash_dmattn_varlen_qkvpacked_func", - "flash_dmattn_func_auto", - "get_available_backends", - "TRITON_AVAILABLE", - "FLEX_AVAILABLE", - "CUDA_AVAILABLE", -] - def get_available_backends(): """Return a list of available backends.""" @@ -87,7 +48,7 @@ def flash_dmattn_func_auto(backend: Optional[str] = None, **kwargs): Args: backend (str, optional): Backend to use ('cuda', 'triton', 'flex'). - If None, will use the first available backend in order: cuda, triton, flex. + If None, will use the first available backend in order: cuda, triton, flex. **kwargs: Arguments to pass to the attention function. Returns: @@ -107,8 +68,6 @@ def flash_dmattn_func_auto(backend: Optional[str] = None, **kwargs): if backend == "cuda": if not CUDA_AVAILABLE: raise RuntimeError("CUDA backend is not available. Please build the CUDA extension.") - if flash_dmattn_func is None: - raise RuntimeError("CUDA flash_dmattn_func is not available. Please check the installation.") return flash_dmattn_func elif backend == "triton": @@ -123,3 +82,15 @@ def flash_dmattn_func_auto(backend: Optional[str] = None, **kwargs): else: raise ValueError(f"Unknown backend: {backend}. Available backends: {get_available_backends()}") + + +__all__ = [ + "CUDA_AVAILABLE", + "TRITON_AVAILABLE", + "FLEX_AVAILABLE", + "flash_dmattn_func", + "triton_dmattn_func", + "flex_dmattn_func", + "get_available_backends", + "flash_dmattn_func_auto", +]