
Add fused Triton kernel for permuted-weight FlashHead decode #3

Merged
JonnaMat merged 3 commits into master from feat/fused-triton-kernels on May 4, 2026

Conversation

@WilhelmTr (Contributor) commented Apr 29, 2026

Summary

  • Permute lm_head rows once so cluster k occupies contiguous rows [k*cap, (k+1)*cap); the Triton kernel then does matmul + per-probe argmax + atomic-max reduction in a single pass over probed clusters.
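
A minimal sketch of the one-time row permutation, assuming vocab_maps_tensor is a (K, cap) integer tensor listing the vocab ids assigned to each cluster (illustrative only; the actual FlashHead code may differ):

import torch

def permute_lm_head_rows(weight: torch.Tensor, vocab_maps_tensor: torch.Tensor) -> torch.Tensor:
    """Reorder lm_head rows so cluster k occupies rows [k*cap, (k+1)*cap)."""
    K, cap = vocab_maps_tensor.shape
    V, D = weight.shape
    assert K * cap == V, "balanced clusters required"
    row_order = vocab_maps_tensor.reshape(-1)  # new row i holds vocab id row_order[i]
    return weight[row_order].contiguous()      # (V, D), cluster-contiguous rows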

Performance

End-to-end Cosmos-Reason2-2B-W4A16 (vLLM, enforce_eager=False):

  • A40: 318.91 → 336.73 TPS (+5.6%)
  • Orin: 70.33 → 76.46 TPS (+8.7%)

Isolated lm_head call:

  • A40: 341 µs → 191 µs (1.78×)
  • Orin: 1821 µs → 951 µs (1.91×)

Port block-sparse matmul+argmax kernel from mistral_small_4. Permutes
lm_head rows once so cluster k occupies contiguous rows [k*cap, (k+1)*cap),
then the Triton kernel does matmul + per-probe argmax + atomic-max
reduction in one pass over probed clusters.

End-to-end Cosmos-Reason2-2B-W4A16 (vllm, enforce_eager=False):
  A40:  318.91 -> 336.73 TPS (+5.6%)
  Orin: 70.33 -> 76.46 TPS  (+8.7%)

Isolated lm_head call:
  A40:  341 us -> 191 us  (1.78x)
  Orin: 1821 us -> 951 us (1.91x)
@WilhelmTr (Contributor, Author)

Thank you @yanptang and @daiguano3 for sharing code and ideas on your Triton kernel work. Really appreciate your contribution!

@WilhelmTr requested a review from JonnaMat on April 29, 2026 at 13:23
self.original_lm_head = lm_head
V, D = lm_head.weight.shape
K, cap = vocab_maps_tensor.shape
assert K * cap == V, f"unbalanced: K={K}*cap={cap} != V={V}"

Member:

Unbalanced clusters are officially still supported. Should we drop support for them? If we want to keep the option, the assert should be removed and the Triton kernel should not be used in that case.

Member:

Any model whose vocab size is not perfectly divisible by the number of clusters will fail here.

Member:

I think we should keep the old path as a fallback.

Author (Contributor):

As discussed, we decided to drop unbalanced support entirely. get_flash_head_parameters now raises a ValueError with the offending cluster sizes if they aren't balanced (it was silently padding short clusters with their first vocab id, which is what made the constructor assert fail in the first place). The assert stays as a defense-in-depth check inside __init__.
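
A minimal sketch of that validation, using hypothetical names (the actual get_flash_head_parameters code may differ):

def check_balanced_clusters(cluster_sizes: list[int]) -> None:
    """Raise ValueError listing the offending cluster sizes if they differ."""
    cap = max(cluster_sizes)
    offending = {k: size for k, size in enumerate(cluster_sizes) if size != cap}
    if offending:
        raise ValueError(
            f"FlashHead requires balanced clusters (cap={cap}); "
            f"offending cluster sizes: {offending}"
        )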

if hidden_states.shape[0] > 10:
    logits = self.original_lm_head(hidden_states)
    return self.get_next_token_standard(logits, do_sample, temperature)
logits = torch.nn.functional.linear(

Member:

Why did we change this here? Isn't it safer to just call the original head?

Author (Contributor):

To avoid storing an extra copy of the weights (the head rows are now permuted).
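
A hedged sketch of what that fallback can look like with only the permuted weights, assuming row_order is the flattened vocab_maps permutation (illustrative; the actual code path is partially elided above):

import torch

def fallback_logits(hidden_states, w_perm, row_order):
    """Compute logits against the permuted weight, then scatter back to vocab order."""
    perm_logits = torch.nn.functional.linear(hidden_states, w_perm)  # (B, V) in permuted order
    logits = torch.empty_like(perm_logits)
    logits[:, row_order] = perm_logits  # row_order[i] is the vocab id of permuted row i
    return logits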

Comment thread src/flash_head/flash_head.py Outdated
    and not use_identical_tiebreak
    and self.special_token_ids_tensor.numel() == 0
):
    from .fused import block_sparse_argmax_atomic

Member:

Any reason this import is not at the top level?

Author (Contributor):

No good reason; moved it to the module top.

Comment thread src/flash_head/fused.py
@@ -0,0 +1,187 @@
# Copyright (C) 2026 Embedl AB

Member:

Claude had this comment: Atomic-max tie-break differs from the reference. With vocab_id in the low 32 bits, equal logits resolve to the largest vocab_id, while torch.argmax picks the smallest index. The doc comment in block_sparse_argmax_atomic claims they match because topk is "arbitrary on ties", but greedy decode's argmax over the gathered logits is deterministic. Consider packing ~vocab_id (or cap-1-rel) so atomic_max picks the smaller id, matching the non-fused path bit-for-bit.

Author (Contributor):

Good catch. The kernel now packs ~vocab_id in the low 32 bits, so atomic_max resolves tied logits to the smallest vocab_id; the wrapper flips it back on extraction. Docstring updated.
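
A minimal NumPy sketch of the packing idea (not the Triton kernel itself): the high 32 bits carry an order-preserving integer encoding of the logit and the low 32 bits carry ~vocab_id, so a plain max over packed keys selects the largest logit and, on ties, the smallest vocab_id:

import numpy as np

def float_to_ordered_u32(x: float) -> np.uint32:
    """Map float32 bits to a uint32 whose integer order matches float order."""
    bits = np.array([x], dtype=np.float32).view(np.uint32)[0]
    if bits & np.uint32(0x80000000):                # negative float: flip all bits
        return np.uint32(~bits)
    return np.uint32(bits | np.uint32(0x80000000))  # positive float: set the sign bit

def pack(logit: float, vocab_id: int) -> np.uint64:
    """High 32 bits: ordered logit; low 32 bits: ~vocab_id."""
    hi = np.uint64(float_to_ordered_u32(logit)) << np.uint64(32)
    lo = np.uint64(np.uint32(~np.uint32(vocab_id)))
    return hi | lo

def unpack_vocab_id(key: np.uint64) -> int:
    """Flip the low 32 bits back to recover the winning vocab_id."""
    return int(np.uint32(~np.uint32(key & np.uint64(0xFFFFFFFF))))

# Tied logits: max() keeps the key for id 3, matching torch.argmax's
# first-index tie-break on the non-fused path.
assert unpack_vocab_id(max(pack(1.5, 7), pack(1.5, 3))) == 3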

# Find each special-token's row in w_perm by searching vocab_maps.
# Called rarely (fallback path), so O(V) is acceptable.
vm_flat = self.vocab_maps.view(-1)
sp_rows = (vm_flat.unsqueeze(0) == sp.unsqueeze(1)).int().argmax(dim=1)

Member:

NIT: silent failure if a special_token_id isn't in vocab_maps. _gather_cluster_logits (line 235) does (vm_flat == sp).int().argmax(dim=1); argmax returns 0 when nothing matches, so a missing special token is silently mapped to w_perm row 0. Worth either masking with an any() check or asserting at construction time that all special_token_ids appear in vocab_maps.

Author (Contributor):

Added a constructor-time check in FlashHead.__init__ that scans vocab_maps_tensor for each special_token_ids entry and raises a ValueError listing any missing ids. The runtime argmax-based lookup in _gather_cluster_logits is then guaranteed to find every special token, so the silent-zero case is gone.
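
A minimal sketch of such a constructor-time check (hypothetical helper; the actual FlashHead.__init__ code may differ):

import torch

def validate_special_tokens(vocab_maps_tensor: torch.Tensor, special_token_ids: list[int]) -> None:
    """Raise ValueError listing any special token ids missing from vocab_maps."""
    vm_flat = vocab_maps_tensor.reshape(-1)
    sp = torch.tensor(special_token_ids, dtype=vm_flat.dtype)
    present = (vm_flat.unsqueeze(0) == sp.unsqueeze(1)).any(dim=1)
    missing = [t for t, ok in zip(special_token_ids, present.tolist()) if not ok]
    if missing:
        raise ValueError(f"special_token_ids not found in vocab_maps: {missing}")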

@WilhelmTr requested a review from JonnaMat on April 30, 2026 at 09:13
- get_flash_head_parameters: raise ValueError on unbalanced clusters
  (was silently padding short clusters with their first vocab id,
  producing K*cap > V and tripping the constructor assert).
- FlashHead.__init__: validate every special_token_ids entry appears
  in vocab_maps_tensor; raise ValueError listing any missing ids.
- _block_sparse_atomic_kernel: pack ~vocab_id in the low bits so
  atomic_max picks the smallest vocab_id on tied logits, matching
  torch.argmax's first-index tie-break. Wrapper flips back on extract.
- Hoist `from .fused import block_sparse_argmax_atomic` to module top.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@JonnaMat merged commit 40d6d19 into master on May 4, 2026