[webgpu] Add NTC layout support for CausalConvWithState#28504
Open
xiaofeihan1 wants to merge 1 commit into
Adds an optional `data_format` attribute (default "NCT") to com.microsoft.CausalConvWithState. When set to "NTC", the input/output tensor layout is channels-last [B, T, C] instead of channels-first [B, C, T].

Motivation:
- WebGPU coalesced reads favor channel as the innermost (contiguous) dim.
- For Qwen3.5 / Mamba-style models, the conv is wrapped between two Transposes (NTC->NCT before, NCT->NTC after) in the HuggingFace reference. Supporting NTC natively lets the model builder skip both Transposes (48 nodes removed for Qwen3.5-4B).
- Measured +7.4% gen TPS / +5.8% e2e on Qwen3.5-4B int4 (RTX 5080, prefill-1000, max_tokens=100).

Scope:
- WebGPU EP: supports both NCT and NTC. Layout is part of CacheHint so the two paths get separate compiled shaders.
- CPU / CUDA EP: NCT only. Explicitly reject NTC at kernel construction to fail loudly rather than silently mis-compute.
- Default value is "NCT": existing models without the attribute behave unchanged.
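To make the layout difference concrete, here is a minimal C++ sketch of how the flat offset of element (b, c, t) differs between the two layouts, together with a reference depthwise causal convolution that works in either one. This is not the PR's WebGPU kernel: the names `Layout`, `Offset`, and `CausalConvRef` are hypothetical, the weights are assumed to be stored as [C, K], and the carried state is assumed to be all zeros (no history before t = 0).

```cpp
#include <cstddef>
#include <vector>

// Hypothetical names for illustration; not taken from the PR.
enum class Layout { NCT, NTC };

// Flat offset of element (b, c, t) in a dense tensor with C channels and T steps.
inline std::size_t Offset(Layout layout, std::size_t C, std::size_t T,
                          std::size_t b, std::size_t c, std::size_t t) {
  return layout == Layout::NCT ? (b * C + c) * T + t   // [B, C, T]: time is innermost
                               : (b * T + t) * C + c;  // [B, T, C]: channel is innermost
}

// Reference depthwise causal conv (assumption: zero history before t = 0).
// y[b, c, t] = sum_k w[c, k] * x[b, c, t - (K - 1) + k], with w stored as [C, K].
std::vector<float> CausalConvRef(Layout layout, const std::vector<float>& x,
                                 const std::vector<float>& w, std::size_t B,
                                 std::size_t C, std::size_t T, std::size_t K) {
  std::vector<float> y(B * C * T, 0.0f);
  for (std::size_t b = 0; b < B; ++b) {
    for (std::size_t c = 0; c < C; ++c) {
      for (std::size_t t = 0; t < T; ++t) {
        float acc = 0.0f;
        for (std::size_t k = 0; k < K; ++k) {
          // Skip taps whose input time index would fall before t = 0.
          if (t + k < K - 1) continue;
          const std::size_t src_t = t + k - (K - 1);
          acc += w[c * K + k] * x[Offset(layout, C, T, b, c, src_t)];
        }
        y[Offset(layout, C, T, b, c, t)] = acc;
      }
    }
  }
  return y;
}
```

In the NTC branch, neighbouring channels at the same time step are one element apart, whereas in NCT they are T elements apart. That is the property the coalesced-read argument relies on when adjacent GPU invocations handle adjacent channels.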
std::string data_format = info.GetAttrOrDefault<std::string>("data_format", "NCT");
ORT_ENFORCE(data_format == "NCT",
            "CPU CausalConvWithState only supports data_format='NCT' currently. "
            "Got: ", data_format);
Contributor
Suggested change
- "Got: ", data_format);
+ "Got: ",
+ data_format);
std::string data_format = info.GetAttrOrDefault<std::string>("data_format", "NCT");
ORT_ENFORCE(data_format == "NCT",
            "CUDA CausalConvWithState only supports data_format='NCT' currently. "
            "Got: ", data_format);
Contributor
Suggested change
- "Got: ", data_format);
+ "Got: ",
+ data_format);
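The fail-loudly check in the CPU and CUDA kernels can be tried outside ONNX Runtime with a small standalone sketch. The PR itself uses `ORT_ENFORCE`. The helper `RequireNctLayout` below is a hypothetical stand-in that throws `std::invalid_argument` instead, so it illustrates only the behaviour (reject anything other than "NCT" at construction time), not the actual ORT error type or message formatting.

```cpp
#include <stdexcept>
#include <string>

// Hypothetical stand-in for the construction-time check; not the ORT macro.
// Throws if the requested layout is anything other than "NCT".
inline void RequireNctLayout(const std::string& ep_name,
                             const std::string& data_format) {
  if (data_format != "NCT") {
    throw std::invalid_argument(
        ep_name +
        " CausalConvWithState only supports data_format='NCT' currently. Got: " +
        data_format);
  }
}
```

Calling `RequireNctLayout("CPU", "NTC")` throws immediately, while `RequireNctLayout("CPU", "NCT")` returns normally, which mirrors the intent that an unsupported layout is rejected up front rather than silently mis-computed.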