Skip to content

Gate FA2 attention backend on model-class support#369

Merged
dmiv-helical merged 1 commit into
mainfrom
geneformer-fa2-fix
Apr 20, 2026
Merged

Gate FA2 attention backend on model-class support#369
dmiv-helical merged 1 commit into
mainfrom
geneformer-fa2-fix

Conversation

@dmiv-helical
Copy link
Copy Markdown
Contributor

Rationale

select_attn_backend previously returned "flash_attention_2" whenever flash_attn was installed and the device was CUDA, without checking whether the target model class actually declares FA2 support via HF's dispatcher. For BertForMaskedLM (Geneformer) that silently routed the model down a code path transformers can't actually dispatch, so the "Loading ... in bfloat16 for flash_attention_2 compatibility" warning wasn't just cosmetic noise — it flagged a branch that couldn't work. The helical integration-tests job doesn't install flash_attn, so this gap was invisible in CI.

Plan

  • Add a supports_fa2 parameter to select_attn_backend. Only models whose class declares _supports_flash_attn / _supports_flash_attn_2 can take the FA2 branch; others (Geneformer) fall back to sdpa.
  • Pass supports_fa2=True from HelixmRNA. Leave Geneformer on the default (False) and annotate the call site so callers who want FA2 for BertForMaskedLM know they have to wire flash_attn directly.
  • Drop the now-unreachable bfloat16-for-FA2 warnings from Geneformer; the sdpa fallback path never triggers them.
  • Add a flash-attn-integration CI job that installs flash_attn and smoke-tests both paths: Geneformer (regression guard — must still load on sdpa even with flash_attn present) and HelixmRNA (must actually run on the FA2 branch).

Comment thread .github/workflows/release.yml Outdated
Rationale
---------
select_attn_backend previously returned "flash_attention_2" whenever
flash_attn was installed and the device was CUDA, without checking
whether the target model class actually declares FA2 support via HF's
dispatcher. For BertForMaskedLM (Geneformer) that silently routed the
model down a code path transformers can't actually dispatch, so the
"Loading ... in bfloat16 for flash_attention_2 compatibility" warning
wasn't just cosmetic noise — it flagged a branch that couldn't work.
The helical integration-tests job doesn't install flash_attn, so this
gap was invisible in CI.

Plan
----
* Add a supports_fa2 parameter to select_attn_backend. Only models
  whose class declares _supports_flash_attn / _supports_flash_attn_2
  can take the FA2 branch; others (Geneformer) fall back to sdpa.
* Pass supports_fa2=True from HelixmRNA. Leave Geneformer on the
  default (False) and annotate the call site so callers who want FA2
  for BertForMaskedLM know they have to wire flash_attn directly.
* Drop the now-unreachable bfloat16-for-FA2 warnings from Geneformer;
  the sdpa fallback path never triggers them.
* Add a flash-attn-integration CI job that installs flash_attn and
  smoke-tests both paths: Geneformer (regression guard — must still
  load on sdpa even with flash_attn present) and HelixmRNA (must
  actually run on the FA2 branch).
@dmiv-helical dmiv-helical merged commit d1bc938 into main Apr 20, 2026
4 checks passed
@dmiv-helical dmiv-helical deleted the geneformer-fa2-fix branch April 20, 2026 20:16
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants