
Conversation

@LoserCheems
Collaborator

Description

Streamlines package initialization by consolidating backend imports into consistent try/except blocks and removing redundant CUDA extension checks and nested import logic. Focuses the public API on the core CUDA interface while keeping Triton and Flex backends accessible. Updates version to 1.0.3.

Type of Change

  • Bug fix (non-breaking change which fixes an issue)
  • Code refactoring

Related Issues

  • Installation Error (may be closed by this PR)

Changes Made

Code Changes

  • Modified Python API
  • Updated CUDA kernels
  • Changed build system
  • Updated dependencies

Details:

  • Switched to absolute imports and simplified error handling in flash_dmattn/__init__.py (see the sketch after this list).
  • Unified backend availability flags: CUDA_AVAILABLE, TRITON_AVAILABLE, FLEX_AVAILABLE.
  • Exposed a clean API surface via __all__:
    • flash_dmattn_func (from flash_dmattn/flash_dmattn_interface.py)
    • triton_dmattn_func (from flash_dmattn/flash_dmattn_triton.py)
    • flex_dmattn_func (from flash_dmattn/flash_dmattn_flex.py)
    • get_available_backends, flash_dmattn_func_auto
  • Removed redundant/legacy CUDA function variants from exports.
  • Bumped __version__ to 1.0.3.
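
A minimal sketch of the consolidated pattern described above; the function bodies and error messages are illustrative assumptions, not the exact contents of flash_dmattn/__init__.py:

```python
# flash_dmattn/__init__.py (illustrative sketch, not the exact source)
__version__ = "1.0.3"

# Each backend gets its own try/except block; a failed import marks the
# backend unavailable instead of breaking `import flash_dmattn`.
try:
    from flash_dmattn.flash_dmattn_interface import flash_dmattn_func
    CUDA_AVAILABLE = True
except ImportError:
    flash_dmattn_func = None
    CUDA_AVAILABLE = False

try:
    from flash_dmattn.flash_dmattn_triton import triton_dmattn_func
    TRITON_AVAILABLE = True
except ImportError:
    triton_dmattn_func = None
    TRITON_AVAILABLE = False

try:
    from flash_dmattn.flash_dmattn_flex import flex_dmattn_func
    FLEX_AVAILABLE = True
except ImportError:
    flex_dmattn_func = None
    FLEX_AVAILABLE = False


def get_available_backends():
    """Names of the backends whose imports succeeded."""
    flags = {"cuda": CUDA_AVAILABLE, "triton": TRITON_AVAILABLE, "flex": FLEX_AVAILABLE}
    return [name for name, ok in flags.items() if ok]


def flash_dmattn_func_auto(backend=None):
    """Return the attention callable for the requested (or first available) backend."""
    available = get_available_backends()
    if backend is None:
        if not available:
            raise RuntimeError("No flash_dmattn backend is available")
        backend = available[0]
    if backend not in available:
        raise RuntimeError(f"Backend {backend!r} is not available")
    funcs = {"cuda": flash_dmattn_func, "triton": triton_dmattn_func, "flex": flex_dmattn_func}
    return funcs[backend]


__all__ = [
    "flash_dmattn_func",
    "triton_dmattn_func",
    "flex_dmattn_func",
    "get_available_backends",
    "flash_dmattn_func_auto",
]
```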

Documentation

  • Updated README
  • Updated API documentation
  • Added examples
  • Updated benchmarks

Testing

  • Existing imports work without side effects:
    • from flash_dmattn import flash_dmattn_func_auto, get_available_backends
  • Backend detection reflects installed backends and returns correct callables for "cuda" | "triton" | "flex".
  • Smoke tests on small tensors (forward path) for available backends; see the sketch after this list.
  • Backward pass smoke tests (if gradients are enabled).
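
A smoke test along these lines covers the points above; the tensor layout, fp16 dtype, and the (q, k, v) call signature of the returned function are illustrative assumptions:

```python
import torch
from flash_dmattn import flash_dmattn_func_auto, get_available_backends

print(get_available_backends())  # e.g. ['cuda', 'triton', 'flex']

# Small tensors for the forward-path smoke test; the
# (batch, seqlen, heads, head_dim) layout is an assumption.
q = torch.randn(1, 128, 8, 64, device="cuda", dtype=torch.float16, requires_grad=True)
k = torch.randn(1, 128, 8, 64, device="cuda", dtype=torch.float16)
v = torch.randn(1, 128, 8, 64, device="cuda", dtype=torch.float16)

attn = flash_dmattn_func_auto()  # callable for the first available backend
out = attn(q, k, v)
assert out.shape == q.shape

# Backward-pass smoke test (if gradients are enabled)
out.sum().backward()
assert q.grad is not None
```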

Test Configuration

  • OS: Windows 11 / Ubuntu 22.04
  • Python: 3.10.x
  • PyTorch: 2.4.x
  • CUDA: 12.1
  • GPU: NVIDIA RTX 4090

Performance Impact

  • No runtime performance impact. Changes only affect import-time initialization logic.

Breaking Changes

  • Removed legacy/duplicated CUDA function exports from the package root.

Migration:

  • Import core functions from flash_dmattn directly (shown as code below):
    • from flash_dmattn import flash_dmattn_func, flash_dmattn_func_auto
    • For specific backends: triton_dmattn_func, flex_dmattn_func
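
In code:

```python
# Core entry points from the package root
from flash_dmattn import flash_dmattn_func, flash_dmattn_func_auto

# Backend-specific entry points remain importable
from flash_dmattn import triton_dmattn_func, flex_dmattn_func
```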

Checklist

  • My code follows the project's style guidelines
  • I have performed a self-review of my own code
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes
  • Any dependent changes have been merged and published

CUDA-specific (if applicable)

  • CUDA kernels compile without warnings
  • Tested on SM 8.0+ architectures
  • Memory usage has been profiled
  • No memory leaks detected

Additional Notes

  • The simplified init path reduces import-time failures and clarifies error messages when optional backends are unavailable.

Streamlines the initialization logic by removing redundant CUDA extension checking and complex nested import structures. Consolidates all backend imports into consistent try/except blocks with cleaner error handling.

Removes multiple CUDA function variants that were previously exposed, focusing on the core flash_dmattn_func interface. Updates __all__ exports to reflect the simplified API surface.

Improves code maintainability by using absolute imports and reducing conditional import complexity.
Copilot AI review requested due to automatic review settings on September 6, 2025, 12:32
Contributor

Copilot AI left a comment


Pull Request Overview

Simplifies the package initialization by consolidating backend imports into consistent try/except blocks and removing redundant CUDA extension checks. Updates the public API to focus on core functionality while maintaining access to all backends.

  • Streamlined import logic with direct absolute imports and unified error handling
  • Simplified CUDA availability detection by removing nested import checks
  • Updated public API exports to focus on core functions while removing legacy variants
Comments suppressed due to low confidence (1)

flash_dmattn/__init__.py:1

  • The removed redundant flash_dmattn_func is None check is good, but it leaves a potential gap: if CUDA_AVAILABLE is True while flash_dmattn_func is None (due to an import failure), the function would return None instead of a callable. Consider adding a simple if flash_dmattn_func is None: check after the CUDA_AVAILABLE check to ensure consistency (see the sketch below).
# Copyright (c) 2025, Jingze Shi.
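
A sketch of the guard this comment proposes; the surrounding function shape is an assumption, and only the if flash_dmattn_func is None: check is what the comment suggests adding:

```python
def flash_dmattn_func_auto(backend="cuda"):
    if backend == "cuda":
        if not CUDA_AVAILABLE:
            raise RuntimeError("CUDA backend is not available")
        # The availability flag alone does not prove the callable was bound;
        # fail loudly instead of silently returning None.
        if flash_dmattn_func is None:
            raise RuntimeError("CUDA backend flagged available, but flash_dmattn_func is None")
        return flash_dmattn_func
    ...  # triton / flex branches elided
```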


@LoserCheems
Collaborator Author

LoserCheems commented Sep 6, 2025

Hi @zylwithxy, let's switch to the fix-import-bug branch and see if the import error has been resolved

@LoserCheems merged commit d4d7247 into main on Sep 7, 2025
@LoserCheems deleted the fix-import-bug branch on November 13, 2025, 04:41
