Add SDPA/FlexAttention backends for CuteChronos2 attention#39
Conversation
Replace custom Triton attention with PyTorch SDPA (scale=1.0) as the preferred CUDA backend. SDPA auto-selects FlashAttention2/cuDNN kernels and is ~2x faster than the eager fallback on RTX 5090. New module cutechronos/modules/flex_attention.py provides: - sdpa_unscaled_attention: SDPA with scale=1.0 (recommended) - flex_unscaled_attention: FlexAttention for mask-free case, SDPA fallback for masked - eager_unscaled_attention: delegates to existing _fallbacks implementation - Backend registry with benchmark_backends() and get_best_attention_backend() Integration: FusedTimeSelfAttention and model.py now use SDPA on CUDA, with Triton and eager as fallbacks for non-CUDA paths. 56 new tests, 195 total tests passing. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
|
Codex Infinity Start a task on this PR's branch by commenting:
Tasks and logs: https://codex-infinity.com |
There was a problem hiding this comment.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 4c20cb304e
ℹ️ About Codex in GitHub
Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".
| Returns: | ||
| Dict mapping backend name to mean time in milliseconds. | ||
| """ | ||
| torch.manual_seed(42) |
There was a problem hiding this comment.
Preserve caller RNG state in benchmark_backends
benchmark_backends() unconditionally calls torch.manual_seed(42), and get_best_attention_backend() reaches it on every invocation. In any training/sweep setup that benchmarks attention before initializing weights or samplers, this silently resets the global RNG and changes all later randomized steps (initialization, shuffling, augmentation) to follow seed 42 instead of the caller's seed. Please use a local torch.Generator or save/restore RNG state so backend selection does not perturb experiment results.
Useful? React with 👍 / 👎.
| if on_cuda: | ||
| attn_output = _sdpa_attn(query_states, key_states, value_states, attention_mask) |
There was a problem hiding this comment.
Keep the CUDA attention path deterministic by default
This switches the current CUDA path from the eager fallback to auto-dispatched scaled_dot_product_attention. On CUDA builds where SDPA selects a cuDNN kernel, PyTorch may use a nondeterministic algorithm, so repeated forecasts with identical weights and inputs can now drift run-to-run. That is a regression for experiments and output-cache checks that depended on the previous fallback path being stable; if the faster kernel is desired, it should be opt-in or gated by a determinism setting.
Useful? React with 👍 / 👎.
) Replace custom Triton attention with PyTorch SDPA (scale=1.0) as the preferred CUDA backend. SDPA auto-selects FlashAttention2/cuDNN kernels and is ~2x faster than the eager fallback on RTX 5090. New module cutechronos/modules/flex_attention.py provides: - sdpa_unscaled_attention: SDPA with scale=1.0 (recommended) - flex_unscaled_attention: FlexAttention for mask-free case, SDPA fallback for masked - eager_unscaled_attention: delegates to existing _fallbacks implementation - Backend registry with benchmark_backends() and get_best_attention_backend() Integration: FusedTimeSelfAttention and model.py now use SDPA on CUDA, with Triton and eager as fallbacks for non-CUDA paths. 56 new tests, 195 total tests passing. Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
) Replace custom Triton attention with PyTorch SDPA (scale=1.0) as the preferred CUDA backend. SDPA auto-selects FlashAttention2/cuDNN kernels and is ~2x faster than the eager fallback on RTX 5090. New module cutechronos/modules/flex_attention.py provides: - sdpa_unscaled_attention: SDPA with scale=1.0 (recommended) - flex_unscaled_attention: FlexAttention for mask-free case, SDPA fallback for masked - eager_unscaled_attention: delegates to existing _fallbacks implementation - Backend registry with benchmark_backends() and get_best_attention_backend() Integration: FusedTimeSelfAttention and model.py now use SDPA on CUDA, with Triton and eager as fallbacks for non-CUDA paths. 56 new tests, 195 total tests passing. Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Summary
cutechronos/modules/flex_attention.pywith SDPA, FlexAttention, and eager backends for unscaled attention (scale=1.0)FusedTimeSelfAttentionandmodel.py, replacing the custom Triton kernel pathget_best_attention_backend(),benchmark_backends(), andlist_backends()Test plan
Generated with Claude Code