Add fused Triton kernel for permuted-weight FlashHead decode #3
Conversation
Port the block-sparse matmul+argmax kernel from mistral_small_4. The lm_head rows are permuted once so cluster k occupies contiguous rows [k*cap, (k+1)*cap); the Triton kernel then does matmul + per-probe argmax + atomic-max reduction in one pass over the probed clusters.

End-to-end Cosmos-Reason2-2B-W4A16 (vLLM, enforce_eager=False):
- A40: 318.91 -> 336.73 TPS (+5.6%)
- Orin: 70.33 -> 76.46 TPS (+8.7%)

Isolated lm_head call:
- A40: 341 us -> 191 us (1.78x)
- Orin: 1821 us -> 951 us (1.91x)
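For reference, a minimal sketch of the one-time row permutation (helper name `permute_lm_head_rows` is illustrative, not the exact helper in this PR); it assumes `vocab_maps_tensor` has shape (K, cap) and holds the vocab ids belonging to each cluster:

```python
import torch

def permute_lm_head_rows(lm_head_weight: torch.Tensor,
                         vocab_maps_tensor: torch.Tensor) -> torch.Tensor:
    """Reorder lm_head rows so cluster k occupies rows [k*cap, (k+1)*cap)."""
    K, cap = vocab_maps_tensor.shape
    V, _ = lm_head_weight.shape
    assert K * cap == V, "balanced clusters required"
    perm = vocab_maps_tensor.reshape(-1).long()   # row j of w_perm holds vocab id perm[j]
    return lm_head_weight[perm].contiguous()      # (V, D), rows grouped by cluster
```

Once permuted, a probe of cluster k only touches the contiguous slab w_perm[k*cap:(k+1)*cap], which is what allows the kernel to fuse the matmul with the per-probe argmax.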
Thank you @yanptang and @daiguano3 for sharing code and ideas on your Triton kernel work. Really appreciate your contribution!
self.original_lm_head = lm_head
V, D = lm_head.weight.shape
K, cap = vocab_maps_tensor.shape
assert K * cap == V, f"unbalanced: K={K}*cap={cap} != V={V}"
Unbalanced clusters are officially still supported. Should we drop support for them? If we want to keep the option, the assert should be removed and the Triton kernel should not be used in that case.
Any model whose vocab size is not perfectly divisible by the number of clusters will fail.
I think we should keep the old path fallback.
As discussed, we decided to drop unbalanced support entirely. get_flash_head_parameters now raises ValueError with the offending cluster sizes if they aren't balanced (it was silently padding short clusters with their first vocab id, which is what made the constructor assert fail in the first place). The assert stays as a defense-in-depth check inside __init__.
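A minimal sketch of the kind of check this describes, assuming the clusters arrive as a list of vocab-id lists (the actual signature of get_flash_head_parameters may differ):

```python
def validate_balanced_clusters(clusters: list[list[int]]) -> int:
    """Raise ValueError unless every cluster has the same size; return that size (cap)."""
    sizes = [len(c) for c in clusters]
    if len(set(sizes)) != 1:
        raise ValueError(
            f"FlashHead requires balanced clusters (K * cap == V); got cluster sizes {sizes}"
        )
    return sizes[0]
```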
if hidden_states.shape[0] > 10:
    logits = self.original_lm_head(hidden_states)
    return self.get_next_token_standard(logits, do_sample, temperature)
logits = torch.nn.functional.linear(
Why did we change this here? Isn't it safer to just call the og head?
To avoid having to store an extra copy of the weights (the head rows are now permuted).
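One way the fallback can recover vocab-order logits from the permuted weight without keeping a second copy; `inv_perm` here is a hypothetical precomputed inverse-permutation index, not necessarily what the PR stores:

```python
import torch
import torch.nn.functional as F

# Precomputed once: inv_perm[v] = row of w_perm that holds vocab id v, e.g.
# inv_perm = torch.empty(V, dtype=torch.long); inv_perm[perm] = torch.arange(V)
def fallback_logits(hidden_states: torch.Tensor,
                    w_perm: torch.Tensor,
                    inv_perm: torch.Tensor) -> torch.Tensor:
    logits_perm = F.linear(hidden_states, w_perm)  # logits in permuted row order
    return logits_perm[:, inv_perm]                # reorder columns back to vocab order
```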
    and not use_identical_tiebreak
    and self.special_token_ids_tensor.numel() == 0
):
    from .fused import block_sparse_argmax_atomic
Any reason this import is not at the top level?
No good reason. Moved to module top level.
@@ -0,0 +1,187 @@
# Copyright (C) 2026 Embedl AB
Claude had this comment:

- Atomic-max tie-break differs from reference. With vocab_id in the low 32 bits, equal logits resolve to the largest vocab_id, while torch.argmax picks the smallest index. The doc comment in block_sparse_argmax_atomic claims they match because topk is "arbitrary on ties", but greedy decode's argmax over the gathered logits is deterministic. Consider packing ~vocab_id (or cap-1-rel) so atomic_max picks the smaller id, matching the non-fused path bit-for-bit.
Good catch. The kernel now packs ~vocab_id in the low 32 bits, so atomic_max resolves tied logits to the smallest vocab_id; the wrapper flips the bits back on extraction. Docstring updated.
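A host-side illustration of that encoding (the real packing happens inside the Triton kernel, and the conversion of the float logit to an orderable 32-bit key is omitted here); the helper names are made up for the sketch:

```python
import torch

MASK32 = 0xFFFFFFFF

def pack_key(logit_key: torch.Tensor, vocab_id: torch.Tensor) -> torch.Tensor:
    """Pack an orderable 32-bit logit key in the high bits and ~vocab_id in the low bits.

    atomic_max on the packed value prefers the larger logit; on ties it prefers the
    larger ~vocab_id, i.e. the smaller vocab_id, matching torch.argmax's
    first-index tie-break in the non-fused path.
    """
    inv_id = (~vocab_id.to(torch.int64)) & MASK32
    return (logit_key.to(torch.int64) << 32) | inv_id

def unpack_vocab_id(packed: torch.Tensor) -> torch.Tensor:
    """Flip the low 32 bits back to recover the winning vocab id."""
    return (~packed) & MASK32
```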
# Find each special-token's row in w_perm by searching vocab_maps.
# Called rarely (fallback path), so O(V) is acceptable.
vm_flat = self.vocab_maps.view(-1)
sp_rows = (vm_flat.unsqueeze(0) == sp.unsqueeze(1)).int().argmax(dim=1)
NIT - Silent failure if a special_token_id isn't in vocab_maps. _gather_cluster_logits (line 235) does (vm_flat == sp).int().argmax(dim=1), and argmax returns 0 when nothing matches, meaning a missing special token is silently mapped to w_perm row 0. Worth either masking with an any() check or asserting at construction time that all special_token_ids appear in vocab_maps.
Added a constructor-time check in FlashHead.__init__ that scans vocab_maps_tensor for each special_token_ids entry and raises ValueError listing any missing ids. The runtime argmax-based lookup in _gather_cluster_logits is then guaranteed to find every special token, so the silent-zero case is gone.
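Roughly what such a constructor-time check looks like, assuming both inputs are integer tensors (the function name and error wording are illustrative):

```python
import torch

def check_special_tokens_present(special_token_ids_tensor: torch.Tensor,
                                 vocab_maps_tensor: torch.Tensor) -> None:
    """Raise ValueError listing any special token id missing from vocab_maps."""
    if special_token_ids_tensor.numel() == 0:
        return
    present = torch.isin(special_token_ids_tensor, vocab_maps_tensor.reshape(-1))
    if not bool(present.all()):
        missing = special_token_ids_tensor[~present].tolist()
        raise ValueError(f"special_token_ids missing from vocab_maps: {missing}")
```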
- get_flash_head_parameters: raise ValueError on unbalanced clusters (was silently padding short clusters with their first vocab id, producing K*cap > V and tripping the constructor assert).
- FlashHead.__init__: validate every special_token_ids entry appears in vocab_maps_tensor; raise ValueError listing any missing ids.
- _block_sparse_atomic_kernel: pack ~vocab_id in the low bits so atomic_max picks the smallest vocab_id on tied logits, matching torch.argmax's first-index tie-break. Wrapper flips back on extract.
- Hoist `from .fused import block_sparse_argmax_atomic` to module top.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>