Skip to content

Add SDPA/FlexAttention backends for CuteChronos2 attention#39

Merged
lee101 merged 1 commit into
mainfrom
worktree-agent-a44cb2b8
Mar 22, 2026
Merged

Add SDPA/FlexAttention backends for CuteChronos2 attention#39
lee101 merged 1 commit into
mainfrom
worktree-agent-a44cb2b8

Conversation

@lee101

@lee101 lee101 commented Mar 19, 2026

Copy link
Copy Markdown
Owner

Summary

  • Add cutechronos/modules/flex_attention.py with SDPA, FlexAttention, and eager backends for unscaled attention (scale=1.0)
  • Wire SDPA as the preferred CUDA attention backend in FusedTimeSelfAttention and model.py, replacing the custom Triton kernel path
  • SDPA auto-selects FlashAttention2/cuDNN kernels, benchmarked at ~2x faster than eager on RTX 5090
  • FlexAttention provided for mask-free case; falls back to SDPA for masked attention (FlexAttention's score_mod is incompatible with general additive mask tensor indexing under vmap tracing)
  • Backend registry with get_best_attention_backend(), benchmark_backends(), and list_backends()

Test plan

  • 56 new tests covering all backends (SDPA, Flex, eager) on CPU and CUDA
  • All 195 cutechronos tests pass (no regressions)
  • Cross-backend consistency verified: all backends match reference within max abs error < 1e-4
  • Benchmark tests print timing comparisons (SDPA ~0.068ms vs eager ~0.129ms for B=4,H=12,S=130,D=64)
  • Integration test verifies SDPA is drop-in compatible with existing fallback

Generated with Claude Code

Replace custom Triton attention with PyTorch SDPA (scale=1.0) as the
preferred CUDA backend. SDPA auto-selects FlashAttention2/cuDNN kernels
and is ~2x faster than the eager fallback on RTX 5090.

New module cutechronos/modules/flex_attention.py provides:
- sdpa_unscaled_attention: SDPA with scale=1.0 (recommended)
- flex_unscaled_attention: FlexAttention for mask-free case, SDPA fallback for masked
- eager_unscaled_attention: delegates to existing _fallbacks implementation
- Backend registry with benchmark_backends() and get_best_attention_backend()

Integration: FusedTimeSelfAttention and model.py now use SDPA on CUDA,
with Triton and eager as fallbacks for non-CUDA paths.

56 new tests, 195 total tests passing.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@codex-infinite

Copy link
Copy Markdown

Codex Infinity
Hi! I'm Codex Infinity, your coding agent for this repo.

Start a task on this PR's branch by commenting:

Tasks and logs: https://codex-infinity.com

@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 4c20cb304e

ℹ️ About Codex in GitHub

Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".

Returns:
Dict mapping backend name to mean time in milliseconds.
"""
torch.manual_seed(42)

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Preserve caller RNG state in benchmark_backends

benchmark_backends() unconditionally calls torch.manual_seed(42), and get_best_attention_backend() reaches it on every invocation. In any training/sweep setup that benchmarks attention before initializing weights or samplers, this silently resets the global RNG and changes all later randomized steps (initialization, shuffling, augmentation) to follow seed 42 instead of the caller's seed. Please use a local torch.Generator or save/restore RNG state so backend selection does not perturb experiment results.

Useful? React with 👍 / 👎.

Comment on lines +154 to +155
if on_cuda:
attn_output = _sdpa_attn(query_states, key_states, value_states, attention_mask)

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Keep the CUDA attention path deterministic by default

This switches the current CUDA path from the eager fallback to auto-dispatched scaled_dot_product_attention. On CUDA builds where SDPA selects a cuDNN kernel, PyTorch may use a nondeterministic algorithm, so repeated forecasts with identical weights and inputs can now drift run-to-run. That is a regression for experiments and output-cache checks that depended on the previous fallback path being stable; if the faster kernel is desired, it should be opt-in or gated by a determinism setting.

Useful? React with 👍 / 👎.

@lee101 lee101 merged commit fcf38c2 into main Mar 22, 2026
3 of 4 checks passed
lee101 added a commit that referenced this pull request Mar 24, 2026
)

Replace custom Triton attention with PyTorch SDPA (scale=1.0) as the
preferred CUDA backend. SDPA auto-selects FlashAttention2/cuDNN kernels
and is ~2x faster than the eager fallback on RTX 5090.

New module cutechronos/modules/flex_attention.py provides:
- sdpa_unscaled_attention: SDPA with scale=1.0 (recommended)
- flex_unscaled_attention: FlexAttention for mask-free case, SDPA fallback for masked
- eager_unscaled_attention: delegates to existing _fallbacks implementation
- Backend registry with benchmark_backends() and get_best_attention_backend()

Integration: FusedTimeSelfAttention and model.py now use SDPA on CUDA,
with Triton and eager as fallbacks for non-CUDA paths.

56 new tests, 195 total tests passing.

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
lee101 added a commit that referenced this pull request Apr 13, 2026
)

Replace custom Triton attention with PyTorch SDPA (scale=1.0) as the
preferred CUDA backend. SDPA auto-selects FlashAttention2/cuDNN kernels
and is ~2x faster than the eager fallback on RTX 5090.

New module cutechronos/modules/flex_attention.py provides:
- sdpa_unscaled_attention: SDPA with scale=1.0 (recommended)
- flex_unscaled_attention: FlexAttention for mask-free case, SDPA fallback for masked
- eager_unscaled_attention: delegates to existing _fallbacks implementation
- Backend registry with benchmark_backends() and get_best_attention_backend()

Integration: FusedTimeSelfAttention and model.py now use SDPA on CUDA,
with Triton and eager as fallbacks for non-CUDA paths.

56 new tests, 195 total tests passing.

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant