diff --git a/CHANGELOG.md b/CHANGELOG.md index 8a63afa5b85..ee4a4b5c4b8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -124,6 +124,9 @@ By @SupaMaggie70Incorporated in [#8206](https://github.com/gfx-rs/wgpu/pull/8206 - Corrected documentation of the minimum alignment of the *end* of a mapped range of a buffer (it is 4, not 8). By @kpreid in [#8450](https://github.com/gfx-rs/wgpu/pull/8450). - `util::StagingBelt` now takes a `Device` when it is created instead of when it is used. By @kpreid in [#8462](https://github.com/gfx-rs/wgpu/pull/8462). +#### Metal +- Add support for mesh shaders. By @SupaMaggie70Incorporated in [#8139](https://github.com/gfx-rs/wgpu/pull/8139) + ### Bug Fixes #### naga @@ -141,6 +144,7 @@ By @SupaMaggie70Incorporated in [#8206](https://github.com/gfx-rs/wgpu/pull/8206 - The texture subresources used by the color attachments of a render pass are no longer allowed to overlap when accessed via different texture views. By @andyleiserson in [#8402](https://github.com/gfx-rs/wgpu/pull/8402). - The `STORAGE_READ_ONLY` texture usage is now permitted to coexist with other read-only usages. By @andyleiserson in [#8490](https://github.com/gfx-rs/wgpu/pull/8490). - Validate that buffers are unmapped in `write_buffer` calls. By @ErichDonGubler in [#8454](https://github.com/gfx-rs/wgpu/pull/8454). +- Add WGSL parsing for mesh shaders. By @inner-daemons in [#8370](https://github.com/gfx-rs/wgpu/pull/8370). #### DX12 @@ -154,6 +158,10 @@ By @SupaMaggie70Incorporated in [#8206](https://github.com/gfx-rs/wgpu/pull/8206 - Fixed a bug where the texture aspect was not passed through when calling `copy_texture_to_buffer` in WebGPU, causing the copy to fail for depth/stencil textures. By @Tim-Evans-Seequent in [#8445](https://github.com/gfx-rs/wgpu/pull/8445). +### Metal + +- Complete support for mesh shaders without passthrough shaders. By @inner-daemons in [#8493](https://github.com/gfx-rs/wgpu/pull/8493). + #### hal - `DropCallback`s are now called after dropping all other fields of their parent structs. By @jerzywilczek in [#8353](https://github.com/gfx-rs/wgpu/pull/8353) diff --git a/docs/api-specs/mesh_shading.md b/docs/api-specs/mesh_shading.md index 4b28ec635e7..41720765a55 100644 --- a/docs/api-specs/mesh_shading.md +++ b/docs/api-specs/mesh_shading.md @@ -103,10 +103,8 @@ An example of using mesh shaders to render a single triangle can be seen [here]( * DirectX 12 support is planned. * Metal support is desired but not currently planned. - ## Naga implementation - ### Supported frontends * 🛠️ WGSL * ❌ SPIR-V @@ -114,7 +112,7 @@ An example of using mesh shaders to render a single triangle can be seen [here]( ### Supported backends * 🛠️ SPIR-V -* ❌ HLSL +* 🛠️ HLSL * ❌ MSL * 🚫 GLSL * 🚫 WGSL @@ -130,7 +128,7 @@ The majority of changes relating to mesh shaders will be in WGSL and `naga`. Using any of these features in a `wgsl` program will require adding the `enable mesh_shading` directive to the top of a program. -Two new shader stages will be added to `WGSL`. Fragment shaders are also modified slightly. Both task shaders and mesh shaders are allowed to use any compute-specific functionality, such as subgroup operations. +Two new shader stages will be added to `WGSL`. Fragment shaders are also modified slightly. Both task shaders and mesh shaders are allowed to use any compute-available functionality, including subgroup operations. ### Task shader @@ -145,6 +143,8 @@ A task shader entry point must return a `vec3` value. The return value of e Each task shader workgroup dispatches an independent mesh shader grid: in mesh shader invocations, `@builtin` values like `workgroup_id` and `global_invocation_id` describe the position of the workgroup and invocation within that grid; and `@builtin(num_workgroups)` matches the task shader workgroup's return value. Mesh shaders dispatched for other task shader workgroups are not included in the count. If it is necessary for a mesh shader to know which task shader workgroup dispatched it, the task shader can include its own workgroup id in the task payload. +Task shaders can use compute and subgroup builtin inputs, in addition to `view_index` and `draw_id`. + ### Mesh shader A function with the `@mesh` attribute is a **mesh shader entry point**. Mesh shaders must not return anything. @@ -159,17 +159,19 @@ A mesh shader entry point must have the following attributes: - `@workgroup_size`: this has the same meaning as when it appears on a compute shader entry point. -- `@vertex_output(V, NV)`: This indicates that the mesh shader workgroup will generate at most `NV` vertex values, each of type `V`. +- `@mesh(VAR)`: Here, `VAR` represents a workgroup variable storing the output information. -- `@primitive_output(P, NP)`: This indicates that the mesh shader workgroup will generate at most `NP` primitives, each of type `P`. +All mesh shader outputs are per-workgroup, and taken from the workgroup variable specified above. The type must have exactly 4 fields: +- A field decorated with `@builtin(vertex_count)`, with type `u32`: this field represents the number of vertices that will be drawn +- A field decorated with `@builtin(primitive_count)`, with type `u32`: this field represents the number of primitives that will be drawn +- A field decorated with `@builtin(vertices)`, typed as an array of `V`, where `V` is the vertex output type as specified below +- A field decorated with `@builtin(primitives)`, typed as an array of `P`, where `P` is the primitive output type as specified below -Each mesh shader entry point invocation must call the `setMeshOutputs(numVertices: u32, numPrimitives: u32)` builtin function at least once. The values passed by each workgroup's first invocation (that is, the one whose `local_invocation_index` is `0`) determine how many vertices (values of type `V`) and primitives (values of type `P`) the workgroup must produce. The user can still write past these indices, but they won't be used in the output. +For a vertex count `NV`, the first `NV` elements of the vertex array above are outputted. Therefore, `NV` must be less than or equal to the size of the vertex array. The same is true for primitives with `NP`. -The `numVertices` and `numPrimitives` arguments must be no greater than `NV` and `NP` from the `@vertex_output` and `@primitive_output` attributes. +The vertex output type `V` must meet the same requirements as a struct type returned by a `@vertex` entry point: all members must have either `@builtin` or `@location` attributes, there must be a `@builtin(position)`, and so on. -To produce vertex data, the workgroup as a whole must make `numVertices` calls to the `setVertex(i: u32, vertex: V)` builtin function. This establishes `vertex` as the value of the `i`'th vertex, where `i` is less than the maximum number of output vertices in the `@vertex_output` attribute. `V` is the type given in the `@vertex_output` attribute. `V` must meet the same requirements as a struct type returned by a `@vertex` entry point: all members must have either `@builtin` or `@location` attributes, there must be a `@builtin(position)`, and so on. - -To produce primitives, the workgroup as a whole must make `numPrimitives` calls to the `setPrimitive(i: u32, primitive: P)` builtin function. This establishes `primitive` as the value of the `i`'th primitive, where `i` is less than the maximum number of output primitives in the `@primitive_output` attribute. `P` is the type given in the `@primitive_output` attribute. `P` must be a struct type, every member of which either has a `@location` or `@builtin` attribute. The following `@builtin` attributes are allowed: +The primitive output type `P` must be a struct type, every member of which either has a `@location` or `@builtin` attribute. All members decorated with `@location` must also be decorated with `@per_primitive`, as must the corresponding fragment input. The `@per_primitive` decoration may only be applied to members decorated with `@location`. The following `@builtin` attributes are allowed: - `triangle_indices`, `line_indices`, or `point_index`: The annotated member must be of type `vec3`, `vec2`, or `u32`. @@ -179,15 +181,13 @@ To produce primitives, the workgroup as a whole must make `numPrimitives` calls - `cull_primitive`: The annotated member must be of type `bool`. If it is true, then the primitive is skipped during rendering. -Every member of `P` with a `@location` attribute must either have a `@per_primitive` attribute, or be part of a struct type that appears in the primitive data as a struct member with the `@per_primitive` attribute. - The `@location` attributes of `P` and `V` must not overlap, since they are merged to produce the user-defined inputs to the fragment shader. -It is possible to write to the same vertex or primitive index repeatedly. Since the implicit arrays written by `setVertex` and `setPrimitive` are shared by the workgroup, data races on writes to the same index for a given type are undefined behavior. +Mesh shaders can use compute and mesh shader builtin inputs, in addition to `view_index`, and if no task shader is present, `draw_id`. ### Fragment shader -Fragment shaders can access vertex output data as if it is from a vertex shader. They can also access primitive output data, provided the input is decorated with `@per_primitive`. The `@per_primitive` attribute can be applied to a value directly, such as `@per_primitive @location(1) value: vec4`, to a struct such as `@per_primitive primitive_input: PrimitiveInput` where `PrimitiveInput` is a struct containing fields decorated with `@location` and `@builtin`, or to members of a struct that are themselves decorated with `@location` or `@builtin`. +Fragment shaders can access vertex output data as if it is from a vertex shader. They can also access primitive output data, provided the input is decorated with `@per_primitive`. The `@per_primitive` decoration may only be applied to inputs or struct members decorated with `@location`. The primitive state is part of the fragment input and must match the output of the mesh shader in the pipeline. Using `@per_primitive` also requires enabling the mesh shader extension. Additionally, the locations of vertex and primitive input cannot overlap. @@ -199,72 +199,75 @@ The following is a full example of WGSL shaders that could be used to create a m enable mesh_shading; const positions = array( - vec4(0.,1.,0.,1.), - vec4(-1.,-1.,0.,1.), - vec4(1.,-1.,0.,1.) + vec4(0., 1., 0., 1.), + vec4(-1., -1., 0., 1.), + vec4(1., -1., 0., 1.) ); const colors = array( - vec4(0.,1.,0.,1.), - vec4(0.,0.,1.,1.), - vec4(1.,0.,0.,1.) + vec4(0., 1., 0., 1.), + vec4(0., 0., 1., 1.), + vec4(1., 0., 0., 1.) ); struct TaskPayload { - colorMask: vec4, - visible: bool, + colorMask: vec4, + visible: bool, } var taskPayload: TaskPayload; var workgroupData: f32; struct VertexOutput { - @builtin(position) position: vec4, - @location(0) color: vec4, + @builtin(position) position: vec4, + @location(0) color: vec4, } struct PrimitiveOutput { - @builtin(triangle_indices) index: vec3, - @builtin(cull_primitive) cull: bool, - @per_primitive @location(1) colorMask: vec4, + @builtin(triangle_indices) index: vec3, + @builtin(cull_primitive) cull: bool, + @per_primitive @location(1) colorMask: vec4, } struct PrimitiveInput { - @per_primitive @location(1) colorMask: vec4, + @per_primitive @location(1) colorMask: vec4, } @task @payload(taskPayload) @workgroup_size(1) fn ts_main() -> @builtin(mesh_task_size) vec3 { - workgroupData = 1.0; - taskPayload.colorMask = vec4(1.0, 1.0, 0.0, 1.0); - taskPayload.visible = true; - return vec3(3, 1, 1); + workgroupData = 1.0; + taskPayload.colorMask = vec4(1.0, 1.0, 0.0, 1.0); + taskPayload.visible = true; + return vec3(3, 1, 1); +} + +struct MeshOutput { + @builtin(vertices) vertices: array, + @builtin(primitives) primitives: array, + @builtin(vertex_count) vertex_count: u32, + @builtin(primitive_count) primitive_count: u32, } -@mesh + +var mesh_output: MeshOutput; +@mesh(mesh_output) @payload(taskPayload) -@vertex_output(VertexOutput, 3) @primitive_output(PrimitiveOutput, 1) @workgroup_size(1) fn ms_main(@builtin(local_invocation_index) index: u32, @builtin(global_invocation_id) id: vec3) { - setMeshOutputs(3, 1); - workgroupData = 2.0; - var v: VertexOutput; - - v.position = positions[0]; - v.color = colors[0] * taskPayload.colorMask; - setVertex(0, v); - - v.position = positions[1]; - v.color = colors[1] * taskPayload.colorMask; - setVertex(1, v); - - v.position = positions[2]; - v.color = colors[2] * taskPayload.colorMask; - setVertex(2, v); - - var p: PrimitiveOutput; - p.index = vec3(0, 1, 2); - p.cull = !taskPayload.visible; - p.colorMask = vec4(1.0, 0.0, 1.0, 1.0); - setPrimitive(0, p); + mesh_output.vertex_count = 3; + mesh_output.primitive_count = 1; + workgroupData = 2.0; + + mesh_output.vertices[0].position = positions[0]; + mesh_output.vertices[0].color = colors[0] * taskPayload.colorMask; + + mesh_output.vertices[1].position = positions[1]; + mesh_output.vertices[1].color = colors[1] * taskPayload.colorMask; + + mesh_output.vertices[2].position = positions[2]; + mesh_output.vertices[2].color = colors[2] * taskPayload.colorMask; + + mesh_output.primitives[0].index = vec3(0, 1, 2); + mesh_output.primitives[0].cull = !taskPayload.visible; + mesh_output.primitives[0].colorMask = vec4(1.0, 0.0, 1.0, 1.0); } @fragment fn fs_main(vertex: VertexOutput, primitive: PrimitiveInput) -> @location(0) vec4 { - return vertex.color * primitive.colorMask; + return vertex.color * primitive.colorMask; } ``` diff --git a/examples/features/src/lib.rs b/examples/features/src/lib.rs index 2299dded725..76837e518d7 100644 --- a/examples/features/src/lib.rs +++ b/examples/features/src/lib.rs @@ -49,6 +49,7 @@ fn all_tests() -> Vec { cube::TEST, cube::TEST_LINES, hello_synchronization::tests::SYNC, + mesh_shader::TEST, mipmap::TEST, mipmap::TEST_QUERY, msaa_line::TEST, diff --git a/examples/features/src/mesh_shader/mod.rs b/examples/features/src/mesh_shader/mod.rs index 4b9f735c24e..9a202d19272 100644 --- a/examples/features/src/mesh_shader/mod.rs +++ b/examples/features/src/mesh_shader/mod.rs @@ -1,6 +1,12 @@ use std::process::Stdio; // Same as in mesh shader tests +fn compile_wgsl(device: &wgpu::Device) -> wgpu::ShaderModule { + device.create_shader_module(wgpu::ShaderModuleDescriptor { + label: None, + source: wgpu::ShaderSource::Wgsl(include_str!("shader.wgsl").into()), + }) +} fn compile_glsl(device: &wgpu::Device, shader_stage: &'static str) -> wgpu::ShaderModule { let cmd = std::process::Command::new("glslc") .args([ @@ -71,21 +77,31 @@ impl crate::framework::Example for Example { device: &wgpu::Device, _queue: &wgpu::Queue, ) -> Self { - let (ts, ms, fs) = if adapter.get_info().backend == wgpu::Backend::Vulkan { - ( - compile_glsl(device, "task"), - compile_glsl(device, "mesh"), - compile_glsl(device, "frag"), - ) - } else if adapter.get_info().backend == wgpu::Backend::Dx12 { - ( - compile_hlsl(device, "Task", "as"), - compile_hlsl(device, "Mesh", "ms"), - compile_hlsl(device, "Frag", "ps"), - ) - } else { - panic!("Example can only run on vulkan or dx12"); - }; + let (ts, ms, fs, ts_name, ms_name, fs_name) = + if adapter.get_info().backend == wgpu::Backend::Metal { + let s = compile_wgsl(device); + (s.clone(), s.clone(), s, "ts_main", "ms_main", "fs_main") + } else if adapter.get_info().backend == wgpu::Backend::Vulkan { + ( + compile_glsl(device, "task"), + compile_glsl(device, "mesh"), + compile_glsl(device, "frag"), + "main", + "main", + "main", + ) + } else if adapter.get_info().backend == wgpu::Backend::Dx12 { + ( + compile_hlsl(device, "Task", "as"), + compile_hlsl(device, "Mesh", "ms"), + compile_hlsl(device, "Frag", "ps"), + "main", + "main", + "main", + ) + } else { + panic!("Example can only run on vulkan or dx12"); + }; let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor { label: None, bind_group_layouts: &[], @@ -96,17 +112,17 @@ impl crate::framework::Example for Example { layout: Some(&pipeline_layout), task: Some(wgpu::TaskState { module: &ts, - entry_point: Some("main"), + entry_point: Some(ts_name), compilation_options: Default::default(), }), mesh: wgpu::MeshState { module: &ms, - entry_point: Some("main"), + entry_point: Some(ms_name), compilation_options: Default::default(), }, fragment: Some(wgpu::FragmentState { module: &fs, - entry_point: Some("main"), + entry_point: Some(fs_name), compilation_options: Default::default(), targets: &[Some(config.view_formats[0].into())], }), @@ -179,3 +195,23 @@ impl crate::framework::Example for Example { pub fn main() { crate::framework::run::("mesh_shader"); } + +#[cfg(test)] +#[wgpu_test::gpu_test] +pub static TEST: crate::framework::ExampleTestParams = + crate::framework::ExampleTestParams { + name: "mesh_shader", + // Generated on 1080ti on Vk/Windows + image_path: "/examples/features/src/mesh_shader/screenshot.png", + width: 1024, + height: 768, + optional_features: wgpu::Features::default(), + base_test_parameters: wgpu_test::TestParameters::default() + .features( + wgpu::Features::EXPERIMENTAL_MESH_SHADER + | wgpu::Features::EXPERIMENTAL_PASSTHROUGH_SHADERS, + ) + .limits(wgpu::Limits::defaults().using_recommended_minimum_mesh_shader_values()), + comparisons: &[wgpu_test::ComparisonType::Mean(0.005)], + _phantom: std::marker::PhantomData::, + }; diff --git a/examples/features/src/mesh_shader/screenshot.png b/examples/features/src/mesh_shader/screenshot.png new file mode 100644 index 00000000000..df76e141504 Binary files /dev/null and b/examples/features/src/mesh_shader/screenshot.png differ diff --git a/examples/features/src/mesh_shader/shader.wgsl b/examples/features/src/mesh_shader/shader.wgsl new file mode 100644 index 00000000000..cdc7366b415 --- /dev/null +++ b/examples/features/src/mesh_shader/shader.wgsl @@ -0,0 +1,74 @@ +enable mesh_shading; + +const positions = array( + vec4(0., 1., 0., 1.), + vec4(-1., -1., 0., 1.), + vec4(1., -1., 0., 1.) +); +const colors = array( + vec4(0., 1., 0., 1.), + vec4(0., 0., 1., 1.), + vec4(1., 0., 0., 1.) +); +struct TaskPayload { + colorMask: vec4, + visible: bool, +} +var taskPayload: TaskPayload; +var workgroupData: f32; +struct VertexOutput { + @builtin(position) position: vec4, + @location(0) color: vec4, +} +struct PrimitiveOutput { + @builtin(triangle_indices) index: vec3, + @builtin(cull_primitive) cull: bool, + @per_primitive @location(1) colorMask: vec4, +} +struct PrimitiveInput { + @per_primitive @location(1) colorMask: vec4, +} + +@task +@payload(taskPayload) +@workgroup_size(1) +fn ts_main() -> @builtin(mesh_task_size) vec3 { + workgroupData = 1.0; + taskPayload.colorMask = vec4(1.0, 1.0, 0.0, 1.0); + taskPayload.visible = true; + return vec3(3, 1, 1); +} + +struct MeshOutput { + @builtin(vertices) vertices: array, + @builtin(primitives) primitives: array, + @builtin(vertex_count) vertex_count: u32, + @builtin(primitive_count) primitive_count: u32, +} + +var mesh_output: MeshOutput; +@mesh(mesh_output) +@payload(taskPayload) +@workgroup_size(1) +fn ms_main(@builtin(local_invocation_index) index: u32, @builtin(global_invocation_id) id: vec3) { + mesh_output.vertex_count = 3; + mesh_output.primitive_count = 1; + workgroupData = 2.0; + + mesh_output.vertices[0].position = positions[0]; + mesh_output.vertices[0].color = colors[0] * taskPayload.colorMask; + + mesh_output.vertices[1].position = positions[1]; + mesh_output.vertices[1].color = colors[1] * taskPayload.colorMask; + + mesh_output.vertices[2].position = positions[2]; + mesh_output.vertices[2].color = colors[2] * taskPayload.colorMask; + + mesh_output.primitives[0].index = vec3(0, 1, 2); + mesh_output.primitives[0].cull = !taskPayload.visible; + mesh_output.primitives[0].colorMask = vec4(1.0, 0.0, 1.0, 1.0); +} +@fragment +fn fs_main(vertex: VertexOutput, primitive: PrimitiveInput) -> @location(0) vec4 { + return vertex.color * primitive.colorMask; +} diff --git a/naga/src/back/dot/mod.rs b/naga/src/back/dot/mod.rs index 1f1396eccff..826dad1c219 100644 --- a/naga/src/back/dot/mod.rs +++ b/naga/src/back/dot/mod.rs @@ -307,25 +307,6 @@ impl StatementGraph { crate::RayQueryFunction::Terminate => "RayQueryTerminate", } } - S::MeshFunction(crate::MeshFunction::SetMeshOutputs { - vertex_count, - primitive_count, - }) => { - self.dependencies.push((id, vertex_count, "vertex_count")); - self.dependencies - .push((id, primitive_count, "primitive_count")); - "SetMeshOutputs" - } - S::MeshFunction(crate::MeshFunction::SetVertex { index, value }) => { - self.dependencies.push((id, index, "index")); - self.dependencies.push((id, value, "value")); - "SetVertex" - } - S::MeshFunction(crate::MeshFunction::SetPrimitive { index, value }) => { - self.dependencies.push((id, index, "index")); - self.dependencies.push((id, value, "value")); - "SetPrimitive" - } S::SubgroupBallot { result, predicate } => { if let Some(predicate) = predicate { self.dependencies.push((id, predicate, "predicate")); diff --git a/naga/src/back/glsl/mod.rs b/naga/src/back/glsl/mod.rs index 521e2dcade7..062734b049e 100644 --- a/naga/src/back/glsl/mod.rs +++ b/naga/src/back/glsl/mod.rs @@ -2675,11 +2675,6 @@ impl<'a, W: Write> Writer<'a, W> { self.write_image_atomic(ctx, image, coordinate, array_index, fun, value)? } Statement::RayQuery { .. } => unreachable!(), - Statement::MeshFunction( - crate::MeshFunction::SetMeshOutputs { .. } - | crate::MeshFunction::SetVertex { .. } - | crate::MeshFunction::SetPrimitive { .. }, - ) => unreachable!(), Statement::SubgroupBallot { result, predicate } => { write!(self.out, "{level}")?; let res_name = Baked(result).to_string(); @@ -5270,7 +5265,11 @@ const fn glsl_built_in(built_in: crate::BuiltIn, options: VaryingOptions) -> &'s | Bi::PointIndex | Bi::LineIndices | Bi::TriangleIndices - | Bi::MeshTaskSize => { + | Bi::MeshTaskSize + | Bi::VertexCount + | Bi::PrimitiveCount + | Bi::Vertices + | Bi::Primitives => { unimplemented!() } } diff --git a/naga/src/back/hlsl/conv.rs b/naga/src/back/hlsl/conv.rs index b4d7af86ed6..6cd3679e817 100644 --- a/naga/src/back/hlsl/conv.rs +++ b/naga/src/back/hlsl/conv.rs @@ -187,7 +187,11 @@ impl crate::BuiltIn { } Self::CullPrimitive => "SV_CullPrimitive", Self::PointIndex | Self::LineIndices | Self::TriangleIndices => unimplemented!(), - Self::MeshTaskSize => unreachable!(), + Self::MeshTaskSize + | Self::VertexCount + | Self::PrimitiveCount + | Self::Vertices + | Self::Primitives => unreachable!(), }) } } diff --git a/naga/src/back/hlsl/writer.rs b/naga/src/back/hlsl/writer.rs index f13601bff91..e55472460e4 100644 --- a/naga/src/back/hlsl/writer.rs +++ b/naga/src/back/hlsl/writer.rs @@ -2608,19 +2608,6 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { writeln!(self.out, ".Abort();")?; } }, - Statement::MeshFunction(crate::MeshFunction::SetMeshOutputs { - vertex_count, - primitive_count, - }) => { - write!(self.out, "{level}SetMeshOutputCounts(")?; - self.write_expr(module, vertex_count, func_ctx)?; - write!(self.out, ", ")?; - self.write_expr(module, primitive_count, func_ctx)?; - write!(self.out, ");")?; - } - Statement::MeshFunction( - crate::MeshFunction::SetVertex { .. } | crate::MeshFunction::SetPrimitive { .. }, - ) => unimplemented!(), Statement::SubgroupBallot { result, predicate } => { write!(self.out, "{level}")?; let name = Baked(result).to_string(); diff --git a/naga/src/back/msl/mod.rs b/naga/src/back/msl/mod.rs index f2df74d8c2d..984084b0700 100644 --- a/naga/src/back/msl/mod.rs +++ b/naga/src/back/msl/mod.rs @@ -711,10 +711,15 @@ impl ResolvedBinding { Bi::CullDistance | Bi::DrawID => { return Err(Error::UnsupportedBuiltIn(built_in)) } - Bi::CullPrimitive => "primitive_culled", - // TODO: figure out how to make this written as a function call - Bi::PointIndex | Bi::LineIndices | Bi::TriangleIndices => unimplemented!(), - Bi::MeshTaskSize => unreachable!(), + Bi::CullPrimitive + | Bi::PointIndex + | Bi::LineIndices + | Bi::TriangleIndices + | Bi::MeshTaskSize + | Bi::VertexCount + | Bi::PrimitiveCount + | Bi::Vertices + | Bi::Primitives => "TODO_MESH_BUILTIN", }; write!(out, "{name}")?; } diff --git a/naga/src/back/msl/writer.rs b/naga/src/back/msl/writer.rs index 484142630d2..adb94a55d33 100644 --- a/naga/src/back/msl/writer.rs +++ b/naga/src/back/msl/writer.rs @@ -410,7 +410,7 @@ impl TypedGlobalVariable<'_> { first_time: false, }; - let (space, access, reference) = match var.space.to_msl_name() { + let (space, access, reference, trailing_attribute) = match var.space.to_msl_name() { Some(space) if self.reference => { let access = if var.space.needs_access_qualifier() && !self.usage.intersects(valid::GlobalUse::WRITE) @@ -419,14 +419,19 @@ impl TypedGlobalVariable<'_> { } else { "" }; - (space, access, "&") + let trailing_attribute = if var.space == crate::AddressSpace::TaskPayload { + " [[payload]]" + } else { + "" + }; + (space, access, "&", trailing_attribute) } - _ => ("", "", ""), + _ => ("", "", "", ""), }; Ok(write!( out, - "{}{}{}{}{}{} {}", + "{}{}{}{}{}{} {}{}", space, if space.is_empty() { "" } else { " " }, ty_name, @@ -434,6 +439,7 @@ impl TypedGlobalVariable<'_> { access, reference, name, + trailing_attribute )?) } } @@ -608,7 +614,7 @@ impl crate::AddressSpace { // may end up with "const" even if the binding is read-write, // and that should be OK. Self::Storage { .. } => true, - Self::TaskPayload => unimplemented!(), + Self::TaskPayload => true, // These should always be read-write. Self::Private | Self::WorkGroup => false, // These translate to `constant` address space, no need for qualifiers. @@ -4063,14 +4069,6 @@ impl Writer { } } } - // TODO: write emitters for these - crate::Statement::MeshFunction(crate::MeshFunction::SetMeshOutputs { .. }) => { - unimplemented!() - } - crate::Statement::MeshFunction( - crate::MeshFunction::SetVertex { .. } - | crate::MeshFunction::SetPrimitive { .. }, - ) => unimplemented!(), crate::Statement::SubgroupBallot { result, predicate } => { write!(self.out, "{level}")?; let name = self.namer.call(""); @@ -6611,26 +6609,42 @@ template self.write_wrapped_functions(module, &ctx)?; - let (em_str, in_mode, out_mode, can_vertex_pull) = match ep.stage { + let (em_str, in_mode, out_mode, can_vertex_pull, extra_attribute) = match ep.stage { crate::ShaderStage::Vertex => ( - "vertex", + Some("vertex"), LocationMode::VertexInput, LocationMode::VertexOutput, true, + None, ), crate::ShaderStage::Fragment => ( - "fragment", + Some("fragment"), LocationMode::FragmentInput, LocationMode::FragmentOutput, false, + None, ), crate::ShaderStage::Compute => ( - "kernel", + Some("kernel"), + LocationMode::Uniform, + LocationMode::Uniform, + false, + None, + ), + crate::ShaderStage::Task => ( + None, LocationMode::Uniform, LocationMode::Uniform, false, + Some("task"), + ), + crate::ShaderStage::Mesh => ( + None, + LocationMode::Uniform, + LocationMode::Uniform, + false, + Some("mesh"), ), - crate::ShaderStage::Task | crate::ShaderStage::Mesh => unimplemented!(), }; // Should this entry point be modified to do vertex pulling? @@ -6697,9 +6711,7 @@ template break; } } - crate::AddressSpace::TaskPayload => { - unimplemented!() - } + crate::AddressSpace::TaskPayload => {} crate::AddressSpace::Function | crate::AddressSpace::Private | crate::AddressSpace::WorkGroup => {} @@ -6916,8 +6928,16 @@ template } } + // Mesh/task (object) shaders use `[[mesh]] void ...` syntax instead of `kernel void ...`. + if let Some(extra_attribute) = extra_attribute { + writeln!(self.out, "[[{extra_attribute}]]")?; + } + // Write the entry point function's name, and begin its argument list. - writeln!(self.out, "{em_str} {result_type_name} {fun_name}(")?; + if let Some(em_str) = em_str { + write!(self.out, "{em_str} ")?; + } + writeln!(self.out, "{result_type_name} {fun_name}(")?; let mut is_first_argument = true; let mut separator = || { @@ -7122,7 +7142,7 @@ template // the resolves have already been checked for `!fake_missing_bindings` case let resolved = match var.space { crate::AddressSpace::PushConstant => options.resolve_push_constants(ep).ok(), - crate::AddressSpace::WorkGroup => None, + crate::AddressSpace::WorkGroup | crate::AddressSpace::TaskPayload => None, _ => options .resolve_resource_binding(ep, var.binding.as_ref().unwrap()) .ok(), diff --git a/naga/src/back/pipeline_constants.rs b/naga/src/back/pipeline_constants.rs index 109cc591e74..de643b82fab 100644 --- a/naga/src/back/pipeline_constants.rs +++ b/naga/src/back/pipeline_constants.rs @@ -860,26 +860,6 @@ fn adjust_stmt(new_pos: &HandleVec>, stmt: &mut S crate::RayQueryFunction::Terminate => {} } } - Statement::MeshFunction(crate::MeshFunction::SetMeshOutputs { - ref mut vertex_count, - ref mut primitive_count, - }) => { - adjust(vertex_count); - adjust(primitive_count); - } - Statement::MeshFunction( - crate::MeshFunction::SetVertex { - ref mut index, - ref mut value, - } - | crate::MeshFunction::SetPrimitive { - ref mut index, - ref mut value, - }, - ) => { - adjust(index); - adjust(value); - } Statement::Break | Statement::Continue | Statement::Kill diff --git a/naga/src/back/spv/block.rs b/naga/src/back/spv/block.rs index d0556acdc53..dd9a3811687 100644 --- a/naga/src/back/spv/block.rs +++ b/naga/src/back/spv/block.rs @@ -3655,7 +3655,6 @@ impl BlockContext<'_> { } => { self.write_subgroup_gather(mode, argument, result, &mut block)?; } - Statement::MeshFunction(_) => unreachable!(), } } diff --git a/naga/src/back/spv/writer.rs b/naga/src/back/spv/writer.rs index 1beb86577c8..ee1ea847739 100644 --- a/naga/src/back/spv/writer.rs +++ b/naga/src/back/spv/writer.rs @@ -2156,7 +2156,11 @@ impl Writer { | Bi::CullPrimitive | Bi::PointIndex | Bi::LineIndices - | Bi::TriangleIndices => unreachable!(), + | Bi::TriangleIndices + | Bi::VertexCount + | Bi::PrimitiveCount + | Bi::Vertices + | Bi::Primitives => unreachable!(), }; self.decorate(id, Decoration::BuiltIn, &[built_in as u32]); diff --git a/naga/src/back/wgsl/writer.rs b/naga/src/back/wgsl/writer.rs index d1ebf62e6ee..daf32a7116f 100644 --- a/naga/src/back/wgsl/writer.rs +++ b/naga/src/back/wgsl/writer.rs @@ -856,7 +856,6 @@ impl Writer { } } Statement::RayQuery { .. } => unreachable!(), - Statement::MeshFunction(..) => unreachable!(), Statement::SubgroupBallot { result, predicate } => { write!(self.out, "{level}")?; let res_name = Baked(result).to_string(); diff --git a/naga/src/common/wgsl/to_wgsl.rs b/naga/src/common/wgsl/to_wgsl.rs index 25847a5df7b..5e6178c049c 100644 --- a/naga/src/common/wgsl/to_wgsl.rs +++ b/naga/src/common/wgsl/to_wgsl.rs @@ -194,7 +194,11 @@ impl TryToWgsl for crate::BuiltIn { | Bi::TriangleIndices | Bi::LineIndices | Bi::MeshTaskSize - | Bi::PointIndex => return None, + | Bi::PointIndex + | Bi::VertexCount + | Bi::PrimitiveCount + | Bi::Vertices + | Bi::Primitives => return None, }) } } diff --git a/naga/src/compact/mod.rs b/naga/src/compact/mod.rs index a7d3d463f11..2761c7cfaf8 100644 --- a/naga/src/compact/mod.rs +++ b/naga/src/compact/mod.rs @@ -226,6 +226,9 @@ pub fn compact(module: &mut crate::Module, keep_unused: KeepUnused) { module_tracer.global_variables_used.insert(task_payload); } if let Some(ref mesh_info) = entry.mesh_info { + module_tracer + .global_variables_used + .insert(mesh_info.output_variable); module_tracer .types_used .insert(mesh_info.vertex_output_type); @@ -385,6 +388,7 @@ pub fn compact(module: &mut crate::Module, keep_unused: KeepUnused) { module_map.globals.adjust(task_payload); } if let Some(ref mut mesh_info) = entry.mesh_info { + module_map.globals.adjust(&mut mesh_info.output_variable); module_map.types.adjust(&mut mesh_info.vertex_output_type); module_map .types diff --git a/naga/src/compact/statements.rs b/naga/src/compact/statements.rs index b370501baca..39d6065f5f0 100644 --- a/naga/src/compact/statements.rs +++ b/naga/src/compact/statements.rs @@ -117,20 +117,6 @@ impl FunctionTracer<'_> { self.expressions_used.insert(query); self.trace_ray_query_function(fun); } - St::MeshFunction(crate::MeshFunction::SetMeshOutputs { - vertex_count, - primitive_count, - }) => { - self.expressions_used.insert(vertex_count); - self.expressions_used.insert(primitive_count); - } - St::MeshFunction( - crate::MeshFunction::SetPrimitive { index, value } - | crate::MeshFunction::SetVertex { index, value }, - ) => { - self.expressions_used.insert(index); - self.expressions_used.insert(value); - } St::SubgroupBallot { result, predicate } => { if let Some(predicate) = predicate { self.expressions_used.insert(predicate); @@ -349,26 +335,6 @@ impl FunctionMap { adjust(query); self.adjust_ray_query_function(fun); } - St::MeshFunction(crate::MeshFunction::SetMeshOutputs { - ref mut vertex_count, - ref mut primitive_count, - }) => { - adjust(vertex_count); - adjust(primitive_count); - } - St::MeshFunction( - crate::MeshFunction::SetVertex { - ref mut index, - ref mut value, - } - | crate::MeshFunction::SetPrimitive { - ref mut index, - ref mut value, - }, - ) => { - adjust(index); - adjust(value); - } St::SubgroupBallot { ref mut result, ref mut predicate, diff --git a/naga/src/front/spv/mod.rs b/naga/src/front/spv/mod.rs index 2a3a971a8bf..ac9eaf8306f 100644 --- a/naga/src/front/spv/mod.rs +++ b/naga/src/front/spv/mod.rs @@ -4661,7 +4661,6 @@ impl> Frontend { | S::Atomic { .. } | S::ImageAtomic { .. } | S::RayQuery { .. } - | S::MeshFunction(..) | S::SubgroupBallot { .. } | S::SubgroupCollectiveOperation { .. } | S::SubgroupGather { .. } => {} diff --git a/naga/src/front/wgsl/error.rs b/naga/src/front/wgsl/error.rs index 17dab5cb0ea..0cd7e11c737 100644 --- a/naga/src/front/wgsl/error.rs +++ b/naga/src/front/wgsl/error.rs @@ -406,6 +406,9 @@ pub(crate) enum Error<'a> { accept_span: Span, accept_type: String, }, + ExpectedGlobalVariable { + name_span: Span, + }, StructMemberTooLarge { member_name_span: Span, }, @@ -1370,6 +1373,11 @@ impl<'a> Error<'a> { ], notes: vec![], }, + Error::ExpectedGlobalVariable { name_span } => ParseError { + message: "expected global variable".to_string(), + labels: vec![(name_span, "variable used here".into())], + notes: vec![], + }, Error::StructMemberTooLarge { member_name_span } => ParseError { message: "struct member is too large".into(), labels: vec![(member_name_span, "this member exceeds the maximum size".into())], diff --git a/naga/src/front/wgsl/lower/mod.rs b/naga/src/front/wgsl/lower/mod.rs index 2066d7cf2c8..33a1de6d579 100644 --- a/naga/src/front/wgsl/lower/mod.rs +++ b/naga/src/front/wgsl/lower/mod.rs @@ -1479,47 +1479,93 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { .collect(); if let Some(ref entry) = f.entry_point { - let workgroup_size_info = if let Some(workgroup_size) = entry.workgroup_size { - // TODO: replace with try_map once stabilized - let mut workgroup_size_out = [1; 3]; - let mut workgroup_size_overrides_out = [None; 3]; - for (i, size) in workgroup_size.into_iter().enumerate() { - if let Some(size_expr) = size { - match self.const_u32(size_expr, &mut ctx.as_const()) { - Ok(value) => { - workgroup_size_out[i] = value.0; - } - Err(err) => { - if let Error::ConstantEvaluatorError(ref ty, _) = *err { - match **ty { - proc::ConstantEvaluatorError::OverrideExpr => { - workgroup_size_overrides_out[i] = - Some(self.workgroup_size_override( - size_expr, - &mut ctx.as_override(), - )?); - } - _ => { - return Err(err); + let (workgroup_size, workgroup_size_overrides) = + if let Some(workgroup_size) = entry.workgroup_size { + // TODO: replace with try_map once stabilized + let mut workgroup_size_out = [1; 3]; + let mut workgroup_size_overrides_out = [None; 3]; + for (i, size) in workgroup_size.into_iter().enumerate() { + if let Some(size_expr) = size { + match self.const_u32(size_expr, &mut ctx.as_const()) { + Ok(value) => { + workgroup_size_out[i] = value.0; + } + Err(err) => { + if let Error::ConstantEvaluatorError(ref ty, _) = *err { + match **ty { + proc::ConstantEvaluatorError::OverrideExpr => { + workgroup_size_overrides_out[i] = + Some(self.workgroup_size_override( + size_expr, + &mut ctx.as_override(), + )?); + } + _ => { + return Err(err); + } } + } else { + return Err(err); } - } else { - return Err(err); } } } } - } - if workgroup_size_overrides_out.iter().all(|x| x.is_none()) { - (workgroup_size_out, None) + if workgroup_size_overrides_out.iter().all(|x| x.is_none()) { + (workgroup_size_out, None) + } else { + (workgroup_size_out, Some(workgroup_size_overrides_out)) + } } else { - (workgroup_size_out, Some(workgroup_size_overrides_out)) + ([0; 3], None) + }; + + let mesh_info = if let Some((var_name, var_span)) = entry.mesh_output_variable { + let var = match ctx.globals.get(var_name) { + Some(&LoweredGlobalDecl::Var(handle)) => handle, + Some(_) => { + return Err(Box::new(Error::ExpectedGlobalVariable { + name_span: var_span, + })) + } + None => return Err(Box::new(Error::UnknownIdent(var_span, var_name))), + }; + + let mut info = ctx.module.analyze_mesh_shader_info(var); + if let Some(h) = info.1[0] { + info.0.max_vertices_override = Some( + ctx.module + .global_expressions + .append(crate::Expression::Override(h), Span::UNDEFINED), + ); + } + if let Some(h) = info.1[1] { + info.0.max_primitives_override = Some( + ctx.module + .global_expressions + .append(crate::Expression::Override(h), Span::UNDEFINED), + ); } + + Some(info.0) + } else { + None + }; + + let task_payload = if let Some((var_name, var_span)) = entry.task_payload { + Some(match ctx.globals.get(var_name) { + Some(&LoweredGlobalDecl::Var(handle)) => handle, + Some(_) => { + return Err(Box::new(Error::ExpectedGlobalVariable { + name_span: var_span, + })) + } + None => return Err(Box::new(Error::UnknownIdent(var_span, var_name))), + }) } else { - ([0; 3], None) + None }; - let (workgroup_size, workgroup_size_overrides) = workgroup_size_info; ctx.module.entry_points.push(ir::EntryPoint { name: f.name.name.to_string(), stage: entry.stage, @@ -1527,8 +1573,8 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { workgroup_size, workgroup_size_overrides, function, - mesh_info: None, - task_payload: None, + mesh_info, + task_payload, }); Ok(LoweredGlobalDecl::EntryPoint( ctx.module.entry_points.len() - 1, @@ -4059,6 +4105,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { interpolation, sampling, blend_src, + per_primitive, }) => { let blend_src = if let Some(blend_src) = blend_src { Some(self.const_u32(blend_src, &mut ctx.as_const())?.0) @@ -4071,7 +4118,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { interpolation, sampling, blend_src, - per_primitive: false, + per_primitive, }; binding.apply_default_interpolation(&ctx.module.types[ty].inner); Some(binding) diff --git a/naga/src/front/wgsl/parse/ast.rs b/naga/src/front/wgsl/parse/ast.rs index 345e9c4c486..04964e7ba5f 100644 --- a/naga/src/front/wgsl/parse/ast.rs +++ b/naga/src/front/wgsl/parse/ast.rs @@ -128,6 +128,8 @@ pub struct EntryPoint<'a> { pub stage: crate::ShaderStage, pub early_depth_test: Option, pub workgroup_size: Option<[Option>>; 3]>, + pub mesh_output_variable: Option<(&'a str, Span)>, + pub task_payload: Option<(&'a str, Span)>, } #[cfg(doc)] @@ -152,6 +154,7 @@ pub enum Binding<'a> { interpolation: Option, sampling: Option, blend_src: Option>>, + per_primitive: bool, }, } diff --git a/naga/src/front/wgsl/parse/conv.rs b/naga/src/front/wgsl/parse/conv.rs index de07ba2e391..0303b7ed6bb 100644 --- a/naga/src/front/wgsl/parse/conv.rs +++ b/naga/src/front/wgsl/parse/conv.rs @@ -6,7 +6,11 @@ use crate::Span; use alloc::boxed::Box; -pub fn map_address_space(word: &str, span: Span) -> Result<'_, crate::AddressSpace> { +pub fn map_address_space<'a>( + word: &str, + span: Span, + enable_extensions: &EnableExtensions, +) -> Result<'a, crate::AddressSpace> { match word { "private" => Ok(crate::AddressSpace::Private), "workgroup" => Ok(crate::AddressSpace::WorkGroup), @@ -16,6 +20,16 @@ pub fn map_address_space(word: &str, span: Span) -> Result<'_, crate::AddressSpa }), "push_constant" => Ok(crate::AddressSpace::PushConstant), "function" => Ok(crate::AddressSpace::Function), + "task_payload" => { + if enable_extensions.contains(ImplementedEnableExtension::MeshShader) { + Ok(crate::AddressSpace::TaskPayload) + } else { + Err(Box::new(Error::EnableExtensionNotEnabled { + span, + kind: ImplementedEnableExtension::MeshShader.into(), + })) + } + } _ => Err(Box::new(Error::UnknownAddressSpace(span))), } } @@ -50,6 +64,17 @@ pub fn map_built_in( "subgroup_id" => crate::BuiltIn::SubgroupId, "subgroup_size" => crate::BuiltIn::SubgroupSize, "subgroup_invocation_id" => crate::BuiltIn::SubgroupInvocationId, + // mesh + "cull_primitive" => crate::BuiltIn::CullPrimitive, + "point_index" => crate::BuiltIn::PointIndex, + "line_indices" => crate::BuiltIn::LineIndices, + "triangle_indices" => crate::BuiltIn::TriangleIndices, + "mesh_task_size" => crate::BuiltIn::MeshTaskSize, + // mesh global variable + "vertex_count" => crate::BuiltIn::VertexCount, + "vertices" => crate::BuiltIn::Vertices, + "primitive_count" => crate::BuiltIn::PrimitiveCount, + "primitives" => crate::BuiltIn::Primitives, _ => return Err(Box::new(Error::UnknownBuiltin(span))), }; match built_in { @@ -61,6 +86,21 @@ pub fn map_built_in( })); } } + crate::BuiltIn::CullPrimitive + | crate::BuiltIn::PointIndex + | crate::BuiltIn::LineIndices + | crate::BuiltIn::TriangleIndices + | crate::BuiltIn::VertexCount + | crate::BuiltIn::Vertices + | crate::BuiltIn::PrimitiveCount + | crate::BuiltIn::Primitives => { + if !enable_extensions.contains(ImplementedEnableExtension::MeshShader) { + return Err(Box::new(Error::EnableExtensionNotEnabled { + span, + kind: ImplementedEnableExtension::MeshShader.into(), + })); + } + } _ => {} } Ok(built_in) diff --git a/naga/src/front/wgsl/parse/directive/enable_extension.rs b/naga/src/front/wgsl/parse/directive/enable_extension.rs index 38d6d6719ca..d376c114ff0 100644 --- a/naga/src/front/wgsl/parse/directive/enable_extension.rs +++ b/naga/src/front/wgsl/parse/directive/enable_extension.rs @@ -10,6 +10,7 @@ use alloc::boxed::Box; /// Tracks the status of every enable-extension known to Naga. #[derive(Clone, Debug, Eq, PartialEq)] pub struct EnableExtensions { + mesh_shader: bool, dual_source_blending: bool, /// Whether `enable f16;` was written earlier in the shader module. f16: bool, @@ -19,6 +20,7 @@ pub struct EnableExtensions { impl EnableExtensions { pub(crate) const fn empty() -> Self { Self { + mesh_shader: false, f16: false, dual_source_blending: false, clip_distances: false, @@ -28,6 +30,7 @@ impl EnableExtensions { /// Add an enable-extension to the set requested by a module. pub(crate) fn add(&mut self, ext: ImplementedEnableExtension) { let field = match ext { + ImplementedEnableExtension::MeshShader => &mut self.mesh_shader, ImplementedEnableExtension::DualSourceBlending => &mut self.dual_source_blending, ImplementedEnableExtension::F16 => &mut self.f16, ImplementedEnableExtension::ClipDistances => &mut self.clip_distances, @@ -38,6 +41,7 @@ impl EnableExtensions { /// Query whether an enable-extension tracked here has been requested. pub(crate) const fn contains(&self, ext: ImplementedEnableExtension) -> bool { match ext { + ImplementedEnableExtension::MeshShader => self.mesh_shader, ImplementedEnableExtension::DualSourceBlending => self.dual_source_blending, ImplementedEnableExtension::F16 => self.f16, ImplementedEnableExtension::ClipDistances => self.clip_distances, @@ -70,6 +74,7 @@ impl EnableExtension { const F16: &'static str = "f16"; const CLIP_DISTANCES: &'static str = "clip_distances"; const DUAL_SOURCE_BLENDING: &'static str = "dual_source_blending"; + const MESH_SHADER: &'static str = "mesh_shading"; const SUBGROUPS: &'static str = "subgroups"; const PRIMITIVE_INDEX: &'static str = "primitive_index"; @@ -81,6 +86,7 @@ impl EnableExtension { Self::DUAL_SOURCE_BLENDING => { Self::Implemented(ImplementedEnableExtension::DualSourceBlending) } + Self::MESH_SHADER => Self::Implemented(ImplementedEnableExtension::MeshShader), Self::SUBGROUPS => Self::Unimplemented(UnimplementedEnableExtension::Subgroups), Self::PRIMITIVE_INDEX => { Self::Unimplemented(UnimplementedEnableExtension::PrimitiveIndex) @@ -93,6 +99,7 @@ impl EnableExtension { pub const fn to_ident(self) -> &'static str { match self { Self::Implemented(kind) => match kind { + ImplementedEnableExtension::MeshShader => Self::MESH_SHADER, ImplementedEnableExtension::DualSourceBlending => Self::DUAL_SOURCE_BLENDING, ImplementedEnableExtension::F16 => Self::F16, ImplementedEnableExtension::ClipDistances => Self::CLIP_DISTANCES, @@ -126,6 +133,8 @@ pub enum ImplementedEnableExtension { /// /// [`enable clip_distances;`]: https://www.w3.org/TR/WGSL/#extension-clip_distances ClipDistances, + /// Enables the `mesh_shader` extension, native only + MeshShader, } /// A variant of [`EnableExtension::Unimplemented`]. diff --git a/naga/src/front/wgsl/parse/mod.rs b/naga/src/front/wgsl/parse/mod.rs index c01ba4de30f..e4c04644347 100644 --- a/naga/src/front/wgsl/parse/mod.rs +++ b/naga/src/front/wgsl/parse/mod.rs @@ -178,6 +178,7 @@ struct BindingParser<'a> { sampling: ParsedAttribute, invariant: ParsedAttribute, blend_src: ParsedAttribute>>, + per_primitive: ParsedAttribute<()>, } impl<'a> BindingParser<'a> { @@ -238,6 +239,18 @@ impl<'a> BindingParser<'a> { lexer.skip(Token::Separator(',')); lexer.expect(Token::Paren(')'))?; } + "per_primitive" => { + if !lexer + .enable_extensions + .contains(ImplementedEnableExtension::MeshShader) + { + return Err(Box::new(Error::EnableExtensionNotEnabled { + span: name_span, + kind: ImplementedEnableExtension::MeshShader.into(), + })); + } + self.per_primitive.set((), name_span)?; + } _ => return Err(Box::new(Error::UnknownAttribute(name_span))), } Ok(()) @@ -251,9 +264,10 @@ impl<'a> BindingParser<'a> { self.sampling.value, self.invariant.value.unwrap_or_default(), self.blend_src.value, + self.per_primitive.value, ) { - (None, None, None, None, false, None) => Ok(None), - (Some(location), None, interpolation, sampling, false, blend_src) => { + (None, None, None, None, false, None, None) => Ok(None), + (Some(location), None, interpolation, sampling, false, blend_src, per_primitive) => { // Before handing over the completed `Module`, we call // `apply_default_interpolation` to ensure that the interpolation and // sampling have been explicitly specified on all vertex shader output and fragment @@ -263,17 +277,18 @@ impl<'a> BindingParser<'a> { interpolation, sampling, blend_src, + per_primitive: per_primitive.is_some(), })) } - (None, Some(crate::BuiltIn::Position { .. }), None, None, invariant, None) => { + (None, Some(crate::BuiltIn::Position { .. }), None, None, invariant, None, None) => { Ok(Some(ast::Binding::BuiltIn(crate::BuiltIn::Position { invariant, }))) } - (None, Some(built_in), None, None, false, None) => { + (None, Some(built_in), None, None, false, None, None) => { Ok(Some(ast::Binding::BuiltIn(built_in))) } - (_, _, _, _, _, _) => Err(Box::new(Error::InconsistentBinding(span))), + (_, _, _, _, _, _, _) => Err(Box::new(Error::InconsistentBinding(span))), } } } @@ -1318,7 +1333,7 @@ impl Parser { }; crate::AddressSpace::Storage { access } } - _ => conv::map_address_space(class_str, span)?, + _ => conv::map_address_space(class_str, span, &lexer.enable_extensions)?, }; lexer.expect(Token::Paren('>'))?; } @@ -1691,7 +1706,7 @@ impl Parser { "ptr" => { lexer.expect_generic_paren('<')?; let (ident, span) = lexer.next_ident_with_span()?; - let mut space = conv::map_address_space(ident, span)?; + let mut space = conv::map_address_space(ident, span, &lexer.enable_extensions)?; lexer.expect(Token::Separator(','))?; let base = self.type_decl(lexer, ctx)?; if let crate::AddressSpace::Storage { ref mut access } = space { @@ -2790,12 +2805,14 @@ impl Parser { // read attributes let mut binding = None; let mut stage = ParsedAttribute::default(); - let mut compute_span = Span::new(0, 0); + let mut compute_like_span = Span::new(0, 0); let mut workgroup_size = ParsedAttribute::default(); let mut early_depth_test = ParsedAttribute::default(); let (mut bind_index, mut bind_group) = (ParsedAttribute::default(), ParsedAttribute::default()); let mut id = ParsedAttribute::default(); + let mut payload = ParsedAttribute::default(); + let mut mesh_output = ParsedAttribute::default(); let mut must_use: ParsedAttribute = ParsedAttribute::default(); @@ -2854,7 +2871,51 @@ impl Parser { } "compute" => { stage.set(ShaderStage::Compute, name_span)?; - compute_span = name_span; + compute_like_span = name_span; + } + "task" => { + if !lexer + .enable_extensions + .contains(ImplementedEnableExtension::MeshShader) + { + return Err(Box::new(Error::EnableExtensionNotEnabled { + span: name_span, + kind: ImplementedEnableExtension::MeshShader.into(), + })); + } + stage.set(ShaderStage::Task, name_span)?; + compute_like_span = name_span; + } + "mesh" => { + if !lexer + .enable_extensions + .contains(ImplementedEnableExtension::MeshShader) + { + return Err(Box::new(Error::EnableExtensionNotEnabled { + span: name_span, + kind: ImplementedEnableExtension::MeshShader.into(), + })); + } + stage.set(ShaderStage::Mesh, name_span)?; + compute_like_span = name_span; + + lexer.expect(Token::Paren('('))?; + mesh_output.set(lexer.next_ident_with_span()?, name_span)?; + lexer.expect(Token::Paren(')'))?; + } + "payload" => { + if !lexer + .enable_extensions + .contains(ImplementedEnableExtension::MeshShader) + { + return Err(Box::new(Error::EnableExtensionNotEnabled { + span: name_span, + kind: ImplementedEnableExtension::MeshShader.into(), + })); + } + lexer.expect(Token::Paren('('))?; + payload.set(lexer.next_ident_with_span()?, name_span)?; + lexer.expect(Token::Paren(')'))?; } "workgroup_size" => { lexer.expect(Token::Paren('('))?; @@ -3020,13 +3081,16 @@ impl Parser { )?; Some(ast::GlobalDeclKind::Fn(ast::Function { entry_point: if let Some(stage) = stage.value { - if stage == ShaderStage::Compute && workgroup_size.value.is_none() { - return Err(Box::new(Error::MissingWorkgroupSize(compute_span))); + if stage.compute_like() && workgroup_size.value.is_none() { + return Err(Box::new(Error::MissingWorkgroupSize(compute_like_span))); } + Some(ast::EntryPoint { stage, early_depth_test: early_depth_test.value, workgroup_size: workgroup_size.value, + mesh_output_variable: mesh_output.value, + task_payload: payload.value, }) } else { None diff --git a/naga/src/ir/mod.rs b/naga/src/ir/mod.rs index 4093d823b4b..c3deabe706d 100644 --- a/naga/src/ir/mod.rs +++ b/naga/src/ir/mod.rs @@ -450,6 +450,15 @@ pub enum BuiltIn { LineIndices, /// Written in mesh shaders TriangleIndices, + + /// Written to a workgroup variable in mesh shaders + VertexCount, + /// Written to a workgroup variable in mesh shaders + Vertices, + /// Written to a workgroup variable in mesh shaders + PrimitiveCount, + /// Written to a workgroup variable in mesh shaders + Primitives, } /// Number of bytes per scalar. @@ -2211,8 +2220,6 @@ pub enum Statement { /// The specific operation we're performing on `query`. fun: RayQueryFunction, }, - /// A mesh shader intrinsic. - MeshFunction(MeshFunction), /// Calculate a bitmask using a boolean from each active thread in the subgroup SubgroupBallot { /// The [`SubgroupBallotResult`] expression representing this load's result. @@ -2569,21 +2576,21 @@ pub struct DocComments { } /// The output topology for a mesh shader. Note that mesh shaders don't allow things like triangle-strips. -#[derive(Debug, Clone, Copy)] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] #[cfg_attr(feature = "arbitrary", derive(Arbitrary))] pub enum MeshOutputTopology { /// Outputs individual vertices to be rendered as points. Points, - /// Outputs groups of 2 vertices to be renderedas lines . + /// Outputs groups of 2 vertices to be rendered as lines. Lines, /// Outputs groups of 3 vertices to be rendered as triangles. Triangles, } /// Information specific to mesh shader entry points. -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq, Eq)] #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] #[cfg_attr(feature = "arbitrary", derive(Arbitrary))] @@ -2603,29 +2610,8 @@ pub struct MeshStageInfo { pub vertex_output_type: Handle, /// The type used by primitive outputs, i.e. what is passed to `setPrimitive`. pub primitive_output_type: Handle, -} - -/// Mesh shader intrinsics -#[derive(Debug, Clone, Copy)] -#[cfg_attr(feature = "serialize", derive(Serialize))] -#[cfg_attr(feature = "deserialize", derive(Deserialize))] -#[cfg_attr(feature = "arbitrary", derive(Arbitrary))] -pub enum MeshFunction { - /// Sets the number of vertices and primitives that will be outputted. - SetMeshOutputs { - vertex_count: Handle, - primitive_count: Handle, - }, - /// Sets the output vertex at a given index. - SetVertex { - index: Handle, - value: Handle, - }, - /// Sets the output primitive at a given index. - SetPrimitive { - index: Handle, - value: Handle, - }, + /// The global variable holding the outputted vertices, primitives, and counts + pub output_variable: Handle, } /// Shader module. diff --git a/naga/src/proc/mod.rs b/naga/src/proc/mod.rs index eca63ee4fb5..64da0a9661e 100644 --- a/naga/src/proc/mod.rs +++ b/naga/src/proc/mod.rs @@ -27,6 +27,8 @@ use thiserror::Error; pub use type_methods::min_max_float_representable_by; pub use typifier::{compare_types, ResolveContext, ResolveError, TypeResolution}; +use crate::non_max_u32::NonMaxU32; + impl From for super::Scalar { fn from(format: super::StorageFormat) -> Self { use super::{ScalarKind as Sk, StorageFormat as Sf}; @@ -653,3 +655,178 @@ fn test_matrix_size() { 48, ); } + +impl crate::Module { + /// Extracts mesh shader info from a mesh output global variable. Used in frontends + /// and by validators. This only validates the output variable itself, and not the + /// vertex and primitive output types. + /// + /// The output contains the extracted mesh stage info, with overrides unset, + /// and then the overrides separately. This is because the overrides should be + /// treated as expressions elsewhere, but that requires mutably modifying the + /// module and the expressions should only be created at parse time, not validation + /// time. + #[allow(clippy::type_complexity)] + pub fn analyze_mesh_shader_info( + &self, + gv: crate::Handle, + ) -> ( + crate::MeshStageInfo, + [Option>; 2], + Option>, + ) { + use crate::span::AddSpan; + use crate::valid::EntryPointError; + #[derive(Default)] + struct OutError { + pub inner: Option, + } + impl OutError { + pub fn set(&mut self, err: EntryPointError) { + if self.inner.is_none() { + self.inner = Some(err); + } + } + } + + // Used to temporarily initialize stuff + let null_type = crate::Handle::new(NonMaxU32::new(0).unwrap()); + let mut output = crate::MeshStageInfo { + topology: crate::MeshOutputTopology::Triangles, + max_vertices: 0, + max_vertices_override: None, + max_primitives: 0, + max_primitives_override: None, + vertex_output_type: null_type, + primitive_output_type: null_type, + output_variable: gv, + }; + // Stores the error to output, if any. + let mut error = OutError::default(); + let r#type = &self.types[self.global_variables[gv].ty].inner; + + let mut topology = output.topology; + // Max, max override, type + let mut vertex_info = (0, None, null_type); + let mut primitive_info = (0, None, null_type); + + match r#type { + &crate::TypeInner::Struct { ref members, .. } => { + let mut builtins = crate::FastHashSet::default(); + for member in members { + match member.binding { + Some(crate::Binding::BuiltIn(crate::BuiltIn::VertexCount)) => { + // Must have type u32 + if self.types[member.ty].inner.scalar() != Some(crate::Scalar::U32) { + error.set(EntryPointError::BadMeshOutputVariableField); + } + // Each builtin should only occur once + if builtins.contains(&crate::BuiltIn::VertexCount) { + error.set(EntryPointError::BadMeshOutputVariableType); + } + builtins.insert(crate::BuiltIn::VertexCount); + } + Some(crate::Binding::BuiltIn(crate::BuiltIn::PrimitiveCount)) => { + // Must have type u32 + if self.types[member.ty].inner.scalar() != Some(crate::Scalar::U32) { + error.set(EntryPointError::BadMeshOutputVariableField); + } + // Each builtin should only occur once + if builtins.contains(&crate::BuiltIn::PrimitiveCount) { + error.set(EntryPointError::BadMeshOutputVariableType); + } + builtins.insert(crate::BuiltIn::PrimitiveCount); + } + Some(crate::Binding::BuiltIn( + crate::BuiltIn::Vertices | crate::BuiltIn::Primitives, + )) => { + let ty = &self.types[member.ty].inner; + // Analyze the array type to determine size and vertex/primitive type + let (a, b, c) = match ty { + &crate::TypeInner::Array { base, size, .. } => { + let ty = base; + let (max, max_override) = match size { + crate::ArraySize::Constant(a) => (a.get(), None), + crate::ArraySize::Pending(o) => (0, Some(o)), + crate::ArraySize::Dynamic => { + error.set(EntryPointError::BadMeshOutputVariableField); + (0, None) + } + }; + (max, max_override, ty) + } + _ => { + error.set(EntryPointError::BadMeshOutputVariableField); + (0, None, null_type) + } + }; + if matches!( + member.binding, + Some(crate::Binding::BuiltIn(crate::BuiltIn::Primitives)) + ) { + // Primitives require special analysis to determine topology + primitive_info = (a, b, c); + match self.types[c].inner { + crate::TypeInner::Struct { ref members, .. } => { + for member in members { + match member.binding { + Some(crate::Binding::BuiltIn( + crate::BuiltIn::PointIndex, + )) => { + topology = crate::MeshOutputTopology::Points; + } + Some(crate::Binding::BuiltIn( + crate::BuiltIn::LineIndices, + )) => { + topology = crate::MeshOutputTopology::Lines; + } + Some(crate::Binding::BuiltIn( + crate::BuiltIn::TriangleIndices, + )) => { + topology = crate::MeshOutputTopology::Triangles; + } + _ => (), + } + } + } + _ => (), + } + // Each builtin should only occur once + if builtins.contains(&crate::BuiltIn::Primitives) { + error.set(EntryPointError::BadMeshOutputVariableType); + } + builtins.insert(crate::BuiltIn::Primitives); + } else { + vertex_info = (a, b, c); + // Each builtin should only occur once + if builtins.contains(&crate::BuiltIn::Vertices) { + error.set(EntryPointError::BadMeshOutputVariableType); + } + builtins.insert(crate::BuiltIn::Vertices); + } + } + _ => error.set(EntryPointError::BadMeshOutputVariableType), + } + } + output = crate::MeshStageInfo { + topology, + max_vertices: vertex_info.0, + max_vertices_override: None, + vertex_output_type: vertex_info.2, + max_primitives: primitive_info.0, + max_primitives_override: None, + primitive_output_type: primitive_info.2, + ..output + } + } + _ => error.set(EntryPointError::BadMeshOutputVariableType), + } + ( + output, + [vertex_info.1, primitive_info.1], + error + .inner + .map(|a| a.with_span_handle(self.global_variables[gv].ty, &self.types)), + ) + } +} diff --git a/naga/src/proc/terminator.rs b/naga/src/proc/terminator.rs index f76d4c06a3b..b29ccb054a3 100644 --- a/naga/src/proc/terminator.rs +++ b/naga/src/proc/terminator.rs @@ -36,7 +36,6 @@ pub fn ensure_block_returns(block: &mut crate::Block) { | S::ImageStore { .. } | S::Call { .. } | S::RayQuery { .. } - | S::MeshFunction(..) | S::Atomic { .. } | S::ImageAtomic { .. } | S::WorkGroupUniformLoad { .. } diff --git a/naga/src/valid/analyzer.rs b/naga/src/valid/analyzer.rs index 6ef2ca0988d..e01a7b0b735 100644 --- a/naga/src/valid/analyzer.rs +++ b/naga/src/valid/analyzer.rs @@ -85,25 +85,6 @@ struct FunctionUniformity { exit: ExitFlags, } -/// Mesh shader related characteristics of a function. -#[derive(Debug, Clone, Default)] -#[cfg_attr(feature = "serialize", derive(serde::Serialize))] -#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] -#[cfg_attr(test, derive(PartialEq))] -pub struct FunctionMeshShaderInfo { - /// The type of value this function passes to [`SetVertex`], and the - /// expression that first established it. - /// - /// [`SetVertex`]: crate::ir::MeshFunction::SetVertex - pub vertex_type: Option<(Handle, Handle)>, - - /// The type of value this function passes to [`SetPrimitive`], and the - /// expression that first established it. - /// - /// [`SetPrimitive`]: crate::ir::MeshFunction::SetPrimitive - pub primitive_type: Option<(Handle, Handle)>, -} - impl ops::BitOr for FunctionUniformity { type Output = Self; fn bitor(self, other: Self) -> Self { @@ -321,9 +302,6 @@ pub struct FunctionInfo { /// See [`DiagnosticFilterNode`] for details on how the tree is represented and used in /// validation. diagnostic_filter_leaf: Option>, - - /// Mesh shader info for this function and its callees. - pub mesh_shader_info: FunctionMeshShaderInfo, } impl FunctionInfo { @@ -520,9 +498,6 @@ impl FunctionInfo { *mine |= *other; } - // Inherit mesh output types from our callees. - self.try_update_mesh_info(&callee.mesh_shader_info)?; - Ok(FunctionUniformity { result: callee.uniformity.clone(), exit: if callee.may_kill { @@ -1155,36 +1130,6 @@ impl FunctionInfo { } FunctionUniformity::new() } - S::MeshFunction(func) => { - self.available_stages |= ShaderStages::MESH; - match &func { - // TODO: double check all of this uniformity stuff. I frankly don't fully understand all of it. - &crate::MeshFunction::SetMeshOutputs { - vertex_count, - primitive_count, - } => { - let _ = self.add_ref(vertex_count); - let _ = self.add_ref(primitive_count); - FunctionUniformity::new() - } - &crate::MeshFunction::SetVertex { index, value } - | &crate::MeshFunction::SetPrimitive { index, value } => { - let _ = self.add_ref(index); - let _ = self.add_ref(value); - let ty = self.expressions[value.index()].ty.handle().ok_or( - FunctionError::InvalidMeshShaderOutputType(value).with_span(), - )?; - - if matches!(func, crate::MeshFunction::SetVertex { .. }) { - self.try_update_mesh_vertex_type(ty, value)?; - } else { - self.try_update_mesh_primitive_type(ty, value)?; - }; - - FunctionUniformity::new() - } - } - } S::SubgroupBallot { result: _, predicate, @@ -1230,72 +1175,6 @@ impl FunctionInfo { } Ok(combined_uniformity) } - - /// Note the type of value passed to [`SetVertex`]. - /// - /// Record that this function passed a value of type `ty` as the second - /// argument to the [`SetVertex`] builtin function. All calls to - /// `SetVertex` must pass the same type, and this must match the - /// function's [`vertex_output_type`]. - /// - /// [`SetVertex`]: crate::ir::MeshFunction::SetVertex - /// [`vertex_output_type`]: crate::ir::MeshStageInfo::vertex_output_type - fn try_update_mesh_vertex_type( - &mut self, - ty: Handle, - value: Handle, - ) -> Result<(), WithSpan> { - if let &Some(ref existing) = &self.mesh_shader_info.vertex_type { - if existing.0 != ty { - return Err( - FunctionError::ConflictingMeshOutputTypes(existing.1, value).with_span() - ); - } - } else { - self.mesh_shader_info.vertex_type = Some((ty, value)); - } - Ok(()) - } - - /// Note the type of value passed to [`SetPrimitive`]. - /// - /// Record that this function passed a value of type `ty` as the second - /// argument to the [`SetPrimitive`] builtin function. All calls to - /// `SetPrimitive` must pass the same type, and this must match the - /// function's [`primitive_output_type`]. - /// - /// [`SetPrimitive`]: crate::ir::MeshFunction::SetPrimitive - /// [`primitive_output_type`]: crate::ir::MeshStageInfo::primitive_output_type - fn try_update_mesh_primitive_type( - &mut self, - ty: Handle, - value: Handle, - ) -> Result<(), WithSpan> { - if let &Some(ref existing) = &self.mesh_shader_info.primitive_type { - if existing.0 != ty { - return Err( - FunctionError::ConflictingMeshOutputTypes(existing.1, value).with_span() - ); - } - } else { - self.mesh_shader_info.primitive_type = Some((ty, value)); - } - Ok(()) - } - - /// Update this function's mesh shader info, given that it calls `callee`. - fn try_update_mesh_info( - &mut self, - callee: &FunctionMeshShaderInfo, - ) -> Result<(), WithSpan> { - if let &Some(ref other_vertex) = &callee.vertex_type { - self.try_update_mesh_vertex_type(other_vertex.0, other_vertex.1)?; - } - if let &Some(ref other_primitive) = &callee.primitive_type { - self.try_update_mesh_primitive_type(other_primitive.0, other_primitive.1)?; - } - Ok(()) - } } impl ModuleInfo { @@ -1331,7 +1210,6 @@ impl ModuleInfo { sampling: crate::FastHashSet::default(), dual_source_blending: false, diagnostic_filter_leaf: fun.diagnostic_filter_leaf, - mesh_shader_info: FunctionMeshShaderInfo::default(), }; let resolve_context = ResolveContext::with_locals(module, &fun.local_variables, &fun.arguments); @@ -1465,7 +1343,6 @@ fn uniform_control_flow() { sampling: crate::FastHashSet::default(), dual_source_blending: false, diagnostic_filter_leaf: None, - mesh_shader_info: FunctionMeshShaderInfo::default(), }; let resolve_context = ResolveContext { constants: &Arena::new(), diff --git a/naga/src/valid/function.rs b/naga/src/valid/function.rs index 0216c6ef7f6..abf6bc430a6 100644 --- a/naga/src/valid/function.rs +++ b/naga/src/valid/function.rs @@ -1547,41 +1547,6 @@ impl super::Validator { crate::RayQueryFunction::Terminate => {} } } - S::MeshFunction(func) => { - let ensure_u32 = - |expr: Handle| -> Result<(), WithSpan> { - let u32_ty = TypeResolution::Value(Ti::Scalar(crate::Scalar::U32)); - let ty = context - .resolve_type_impl(expr, &self.valid_expression_set) - .map_err_inner(|source| { - FunctionError::Expression { - source, - handle: expr, - } - .with_span_handle(expr, context.expressions) - })?; - if !context.compare_types(&u32_ty, ty) { - return Err(FunctionError::InvalidMeshFunctionCall(expr) - .with_span_handle(expr, context.expressions)); - } - Ok(()) - }; - match func { - crate::MeshFunction::SetMeshOutputs { - vertex_count, - primitive_count, - } => { - ensure_u32(vertex_count)?; - ensure_u32(primitive_count)?; - } - crate::MeshFunction::SetVertex { index, value: _ } - | crate::MeshFunction::SetPrimitive { index, value: _ } => { - ensure_u32(index)?; - // Value is validated elsewhere (since the value type isn't known ahead of time but must match for all calls - // in a function or the function's called functions) - } - } - } S::SubgroupBallot { result, predicate } => { stages &= self.subgroup_stages; if !self.capabilities.contains(super::Capabilities::SUBGROUP) { diff --git a/naga/src/valid/handles.rs b/naga/src/valid/handles.rs index adb9f355c11..5b7fb3fab75 100644 --- a/naga/src/valid/handles.rs +++ b/naga/src/valid/handles.rs @@ -237,6 +237,7 @@ impl super::Validator { Self::validate_global_variable_handle(task_payload, global_variables)?; } if let Some(ref mesh_info) = entry_point.mesh_info { + Self::validate_global_variable_handle(mesh_info.output_variable, global_variables)?; validate_type(mesh_info.vertex_output_type)?; validate_type(mesh_info.primitive_output_type)?; for ov in mesh_info @@ -815,22 +816,6 @@ impl super::Validator { } Ok(()) } - crate::Statement::MeshFunction(func) => match func { - crate::MeshFunction::SetMeshOutputs { - vertex_count, - primitive_count, - } => { - validate_expr(vertex_count)?; - validate_expr(primitive_count)?; - Ok(()) - } - crate::MeshFunction::SetVertex { index, value } - | crate::MeshFunction::SetPrimitive { index, value } => { - validate_expr(index)?; - validate_expr(value)?; - Ok(()) - } - }, crate::Statement::SubgroupBallot { result, predicate } => { validate_expr_opt(predicate)?; validate_expr(result)?; diff --git a/naga/src/valid/interface.rs b/naga/src/valid/interface.rs index 779667c04ec..a040fd1604d 100644 --- a/naga/src/valid/interface.rs +++ b/naga/src/valid/interface.rs @@ -98,8 +98,6 @@ pub enum VaryingError { InvalidPerPrimitive, #[error("Non-builtin members of a mesh primitive output struct must be decorated with `@per_primitive`")] MissingPerPrimitive, - #[error("The `MESH_SHADER` capability must be enabled to use per-primitive fragment inputs.")] - PerPrimitiveNotAllowed, } #[derive(Clone, Debug, thiserror::Error)] @@ -141,24 +139,31 @@ pub enum EntryPointError { TaskPayloadWrongAddressSpace, #[error("For a task payload to be used, it must be declared with @payload")] WrongTaskPayloadUsed, - #[error("A function can only set vertex and primitive types that correspond to the mesh shader attributes")] - WrongMeshOutputType, - #[error("Only mesh shader entry points can write to mesh output vertices and primitives")] - UnexpectedMeshShaderOutput, - #[error("Mesh shader entry point cannot have a return type")] - UnexpectedMeshShaderEntryResult, #[error("Task shader entry point must return @builtin(mesh_task_size) vec3")] WrongTaskShaderEntryResult, - #[error("Mesh output type must be a user-defined struct.")] - InvalidMeshOutputType, - #[error("Mesh primitive outputs must have exactly one of `@builtin(triangle_indices)`, `@builtin(line_indices)`, or `@builtin(point_index)`")] - InvalidMeshPrimitiveOutputType, #[error("Task shaders must declare a task payload output")] ExpectedTaskPayload, #[error( - "The `MESH_SHADER` capability must be enabled to compile mesh shaders and task shaders." + "The `MESH_SHADER` capability must be enabled to compile mesh shaders and task shaders" )] MeshShaderCapabilityDisabled, + + #[error( + "Mesh shader output variable must be a struct with fields that are all allowed builtins" + )] + BadMeshOutputVariableType, + #[error("Mesh shader output variable fields must have types that are in accordance with the mesh shader spec")] + BadMeshOutputVariableField, + #[error("Mesh shader entry point cannot have a return type")] + UnexpectedMeshShaderEntryResult, + #[error( + "Mesh output type must be a user-defined struct with fields in alignment with the mesh shader spec" + )] + InvalidMeshOutputType, + #[error("Mesh primitive outputs must have exactly one of `@builtin(triangle_indices)`, `@builtin(line_indices)`, or `@builtin(point_index)`")] + InvalidMeshPrimitiveOutputType, + #[error("Mesh output global variable must live in the workgroup address space")] + WrongMeshOutputAddressSpace, } fn storage_usage(access: crate::StorageAccess) -> GlobalUse { @@ -390,7 +395,29 @@ impl VaryingContext<'_> { scalar: crate::Scalar::U32, }, ), + // Validated elsewhere, shouldn't be here + Bi::VertexCount | Bi::PrimitiveCount | Bi::Vertices | Bi::Primitives => { + (false, true) + } }; + match built_in { + Bi::CullPrimitive + | Bi::PointIndex + | Bi::LineIndices + | Bi::TriangleIndices + | Bi::MeshTaskSize + | Bi::VertexCount + | Bi::PrimitiveCount + | Bi::Vertices + | Bi::Primitives => { + if !self.capabilities.contains(Capabilities::MESH_SHADER) { + return Err(VaryingError::UnsupportedCapability( + Capabilities::MESH_SHADER, + )); + } + } + _ => (), + } if !visible { return Err(VaryingError::InvalidBuiltInStage(built_in)); @@ -408,7 +435,9 @@ impl VaryingContext<'_> { per_primitive, } => { if per_primitive && !self.capabilities.contains(Capabilities::MESH_SHADER) { - return Err(VaryingError::PerPrimitiveNotAllowed); + return Err(VaryingError::UnsupportedCapability( + Capabilities::MESH_SHADER, + )); } // Only IO-shareable types may be stored in locations. if !self.type_info[ty.index()] @@ -868,6 +897,7 @@ impl super::Validator { (crate::ShaderStage::Mesh, &None) => { return Err(EntryPointError::ExpectedMeshShaderAttributes.with_span()); } + (crate::ShaderStage::Mesh, &Some(..)) => {} (_, &Some(_)) => { return Err(EntryPointError::UnexpectedMeshShaderAttributes.with_span()); } @@ -1074,21 +1104,38 @@ impl super::Validator { // If this is a `Mesh` entry point, check its vertex and primitive output types. // We verified previously that only mesh shaders can have `mesh_info`. if let &Some(ref mesh_info) = &ep.mesh_info { - // Mesh shaders don't return any value. All their results are supplied through - // [`SetVertex`] and [`SetPrimitive`] calls. - if let Some((used_vertex_type, _)) = info.mesh_shader_info.vertex_type { - if used_vertex_type != mesh_info.vertex_output_type { - return Err(EntryPointError::WrongMeshOutputType - .with_span_handle(mesh_info.vertex_output_type, &module.types)); + if module.global_variables[mesh_info.output_variable].space + != crate::AddressSpace::WorkGroup + { + return Err(EntryPointError::WrongMeshOutputAddressSpace.with_span()); + } + + let mut implied = module.analyze_mesh_shader_info(mesh_info.output_variable); + if let Some(e) = implied.2 { + return Err(e); + } + + if let Some(e) = mesh_info.max_vertices_override { + if let crate::Expression::Override(o) = module.global_expressions[e] { + if implied.1[0] != Some(o) { + return Err(EntryPointError::BadMeshOutputVariableType.with_span()); + } } } - if let Some((used_primitive_type, _)) = info.mesh_shader_info.primitive_type { - if used_primitive_type != mesh_info.primitive_output_type { - return Err(EntryPointError::WrongMeshOutputType - .with_span_handle(mesh_info.primitive_output_type, &module.types)); + if let Some(e) = mesh_info.max_primitives_override { + if let crate::Expression::Override(o) = module.global_expressions[e] { + if implied.1[1] != Some(o) { + return Err(EntryPointError::BadMeshOutputVariableType.with_span()); + } } } + implied.0.max_vertices_override = mesh_info.max_vertices_override; + implied.0.max_primitives_override = mesh_info.max_primitives_override; + if implied.0 != *mesh_info { + return Err(EntryPointError::BadMeshOutputVariableType.with_span()); + } + self.validate_mesh_output_type( ep, module, @@ -1101,14 +1148,7 @@ impl super::Validator { mesh_info.primitive_output_type, MeshOutputType::PrimitiveOutput, )?; - } else { - // This is not a `Mesh` entry point, so ensure that it never tries to produce - // vertices or primitives. - if info.mesh_shader_info.vertex_type.is_some() - || info.mesh_shader_info.primitive_type.is_some() - { - return Err(EntryPointError::UnexpectedMeshShaderOutput.with_span()); - } + info.insert_global_use(GlobalUse::READ, mesh_info.output_variable); } Ok(info) diff --git a/naga/tests/in/wgsl/mesh-shader.toml b/naga/tests/in/wgsl/mesh-shader.toml new file mode 100644 index 00000000000..accbae9f2de --- /dev/null +++ b/naga/tests/in/wgsl/mesh-shader.toml @@ -0,0 +1,19 @@ +# Stolen from ray-query.toml + +god_mode = true +targets = "IR | ANALYSIS | METAL" + +[msl] +fake_missing_bindings = true +lang_version = [2, 4] +spirv_cross_compatibility = false +zero_initialize_workgroup_memory = false + +[hlsl] +shader_model = "V6_5" +fake_missing_bindings = true +zero_initialize_workgroup_memory = true + +[spv] +version = [1, 4] +capabilities = ["MeshShadingEXT"] diff --git a/naga/tests/in/wgsl/mesh-shader.wgsl b/naga/tests/in/wgsl/mesh-shader.wgsl new file mode 100644 index 00000000000..cdc7366b415 --- /dev/null +++ b/naga/tests/in/wgsl/mesh-shader.wgsl @@ -0,0 +1,74 @@ +enable mesh_shading; + +const positions = array( + vec4(0., 1., 0., 1.), + vec4(-1., -1., 0., 1.), + vec4(1., -1., 0., 1.) +); +const colors = array( + vec4(0., 1., 0., 1.), + vec4(0., 0., 1., 1.), + vec4(1., 0., 0., 1.) +); +struct TaskPayload { + colorMask: vec4, + visible: bool, +} +var taskPayload: TaskPayload; +var workgroupData: f32; +struct VertexOutput { + @builtin(position) position: vec4, + @location(0) color: vec4, +} +struct PrimitiveOutput { + @builtin(triangle_indices) index: vec3, + @builtin(cull_primitive) cull: bool, + @per_primitive @location(1) colorMask: vec4, +} +struct PrimitiveInput { + @per_primitive @location(1) colorMask: vec4, +} + +@task +@payload(taskPayload) +@workgroup_size(1) +fn ts_main() -> @builtin(mesh_task_size) vec3 { + workgroupData = 1.0; + taskPayload.colorMask = vec4(1.0, 1.0, 0.0, 1.0); + taskPayload.visible = true; + return vec3(3, 1, 1); +} + +struct MeshOutput { + @builtin(vertices) vertices: array, + @builtin(primitives) primitives: array, + @builtin(vertex_count) vertex_count: u32, + @builtin(primitive_count) primitive_count: u32, +} + +var mesh_output: MeshOutput; +@mesh(mesh_output) +@payload(taskPayload) +@workgroup_size(1) +fn ms_main(@builtin(local_invocation_index) index: u32, @builtin(global_invocation_id) id: vec3) { + mesh_output.vertex_count = 3; + mesh_output.primitive_count = 1; + workgroupData = 2.0; + + mesh_output.vertices[0].position = positions[0]; + mesh_output.vertices[0].color = colors[0] * taskPayload.colorMask; + + mesh_output.vertices[1].position = positions[1]; + mesh_output.vertices[1].color = colors[1] * taskPayload.colorMask; + + mesh_output.vertices[2].position = positions[2]; + mesh_output.vertices[2].color = colors[2] * taskPayload.colorMask; + + mesh_output.primitives[0].index = vec3(0, 1, 2); + mesh_output.primitives[0].cull = !taskPayload.visible; + mesh_output.primitives[0].colorMask = vec4(1.0, 0.0, 1.0, 1.0); +} +@fragment +fn fs_main(vertex: VertexOutput, primitive: PrimitiveInput) -> @location(0) vec4 { + return vertex.color * primitive.colorMask; +} diff --git a/naga/tests/out/analysis/spv-shadow.info.ron b/naga/tests/out/analysis/spv-shadow.info.ron index b08a28438ed..381f841d5d9 100644 --- a/naga/tests/out/analysis/spv-shadow.info.ron +++ b/naga/tests/out/analysis/spv-shadow.info.ron @@ -413,10 +413,6 @@ sampling: [], dual_source_blending: false, diagnostic_filter_leaf: None, - mesh_shader_info: ( - vertex_type: None, - primitive_type: None, - ), ), ( flags: ("EXPRESSIONS | BLOCKS | CONTROL_FLOW_UNIFORMITY | STRUCT_LAYOUTS | CONSTANTS | BINDINGS"), @@ -1595,10 +1591,6 @@ sampling: [], dual_source_blending: false, diagnostic_filter_leaf: None, - mesh_shader_info: ( - vertex_type: None, - primitive_type: None, - ), ), ], entry_points: [ @@ -1693,10 +1685,6 @@ sampling: [], dual_source_blending: false, diagnostic_filter_leaf: None, - mesh_shader_info: ( - vertex_type: None, - primitive_type: None, - ), ), ], const_expression_types: [ diff --git a/naga/tests/out/analysis/wgsl-access.info.ron b/naga/tests/out/analysis/wgsl-access.info.ron index d297b09a404..c22cd768f2e 100644 --- a/naga/tests/out/analysis/wgsl-access.info.ron +++ b/naga/tests/out/analysis/wgsl-access.info.ron @@ -1197,10 +1197,6 @@ sampling: [], dual_source_blending: false, diagnostic_filter_leaf: None, - mesh_shader_info: ( - vertex_type: None, - primitive_type: None, - ), ), ( flags: ("EXPRESSIONS | BLOCKS | CONTROL_FLOW_UNIFORMITY | STRUCT_LAYOUTS | CONSTANTS | BINDINGS"), @@ -2527,10 +2523,6 @@ sampling: [], dual_source_blending: false, diagnostic_filter_leaf: None, - mesh_shader_info: ( - vertex_type: None, - primitive_type: None, - ), ), ( flags: ("EXPRESSIONS | BLOCKS | CONTROL_FLOW_UNIFORMITY | STRUCT_LAYOUTS | CONSTANTS | BINDINGS"), @@ -2571,10 +2563,6 @@ sampling: [], dual_source_blending: false, diagnostic_filter_leaf: None, - mesh_shader_info: ( - vertex_type: None, - primitive_type: None, - ), ), ( flags: ("EXPRESSIONS | BLOCKS | CONTROL_FLOW_UNIFORMITY | STRUCT_LAYOUTS | CONSTANTS | BINDINGS"), @@ -2624,10 +2612,6 @@ sampling: [], dual_source_blending: false, diagnostic_filter_leaf: None, - mesh_shader_info: ( - vertex_type: None, - primitive_type: None, - ), ), ( flags: ("EXPRESSIONS | BLOCKS | CONTROL_FLOW_UNIFORMITY | STRUCT_LAYOUTS | CONSTANTS | BINDINGS"), @@ -2671,10 +2655,6 @@ sampling: [], dual_source_blending: false, diagnostic_filter_leaf: None, - mesh_shader_info: ( - vertex_type: None, - primitive_type: None, - ), ), ( flags: ("EXPRESSIONS | BLOCKS | CONTROL_FLOW_UNIFORMITY | STRUCT_LAYOUTS | CONSTANTS | BINDINGS"), @@ -2769,10 +2749,6 @@ sampling: [], dual_source_blending: false, diagnostic_filter_leaf: None, - mesh_shader_info: ( - vertex_type: None, - primitive_type: None, - ), ), ( flags: ("EXPRESSIONS | BLOCKS | CONTROL_FLOW_UNIFORMITY | STRUCT_LAYOUTS | CONSTANTS | BINDINGS"), @@ -2894,10 +2870,6 @@ sampling: [], dual_source_blending: false, diagnostic_filter_leaf: None, - mesh_shader_info: ( - vertex_type: None, - primitive_type: None, - ), ), ( flags: ("EXPRESSIONS | BLOCKS | CONTROL_FLOW_UNIFORMITY | STRUCT_LAYOUTS | CONSTANTS | BINDINGS"), @@ -2950,10 +2922,6 @@ sampling: [], dual_source_blending: false, diagnostic_filter_leaf: None, - mesh_shader_info: ( - vertex_type: None, - primitive_type: None, - ), ), ( flags: ("EXPRESSIONS | BLOCKS | CONTROL_FLOW_UNIFORMITY | STRUCT_LAYOUTS | CONSTANTS | BINDINGS"), @@ -3009,10 +2977,6 @@ sampling: [], dual_source_blending: false, diagnostic_filter_leaf: None, - mesh_shader_info: ( - vertex_type: None, - primitive_type: None, - ), ), ( flags: ("EXPRESSIONS | BLOCKS | CONTROL_FLOW_UNIFORMITY | STRUCT_LAYOUTS | CONSTANTS | BINDINGS"), @@ -3065,10 +3029,6 @@ sampling: [], dual_source_blending: false, diagnostic_filter_leaf: None, - mesh_shader_info: ( - vertex_type: None, - primitive_type: None, - ), ), ( flags: ("EXPRESSIONS | BLOCKS | CONTROL_FLOW_UNIFORMITY | STRUCT_LAYOUTS | CONSTANTS | BINDINGS"), @@ -3124,10 +3084,6 @@ sampling: [], dual_source_blending: false, diagnostic_filter_leaf: None, - mesh_shader_info: ( - vertex_type: None, - primitive_type: None, - ), ), ( flags: ("EXPRESSIONS | BLOCKS | CONTROL_FLOW_UNIFORMITY | STRUCT_LAYOUTS | CONSTANTS | BINDINGS"), @@ -3192,10 +3148,6 @@ sampling: [], dual_source_blending: false, diagnostic_filter_leaf: None, - mesh_shader_info: ( - vertex_type: None, - primitive_type: None, - ), ), ( flags: ("EXPRESSIONS | BLOCKS | CONTROL_FLOW_UNIFORMITY | STRUCT_LAYOUTS | CONSTANTS | BINDINGS"), @@ -3269,10 +3221,6 @@ sampling: [], dual_source_blending: false, diagnostic_filter_leaf: None, - mesh_shader_info: ( - vertex_type: None, - primitive_type: None, - ), ), ( flags: ("EXPRESSIONS | BLOCKS | CONTROL_FLOW_UNIFORMITY | STRUCT_LAYOUTS | CONSTANTS | BINDINGS"), @@ -3349,10 +3297,6 @@ sampling: [], dual_source_blending: false, diagnostic_filter_leaf: None, - mesh_shader_info: ( - vertex_type: None, - primitive_type: None, - ), ), ( flags: ("EXPRESSIONS | BLOCKS | CONTROL_FLOW_UNIFORMITY | STRUCT_LAYOUTS | CONSTANTS | BINDINGS"), @@ -3453,10 +3397,6 @@ sampling: [], dual_source_blending: false, diagnostic_filter_leaf: None, - mesh_shader_info: ( - vertex_type: None, - primitive_type: None, - ), ), ( flags: ("EXPRESSIONS | BLOCKS | CONTROL_FLOW_UNIFORMITY | STRUCT_LAYOUTS | CONSTANTS | BINDINGS"), @@ -3653,10 +3593,6 @@ sampling: [], dual_source_blending: false, diagnostic_filter_leaf: None, - mesh_shader_info: ( - vertex_type: None, - primitive_type: None, - ), ), ], entry_points: [ @@ -4354,10 +4290,6 @@ sampling: [], dual_source_blending: false, diagnostic_filter_leaf: None, - mesh_shader_info: ( - vertex_type: None, - primitive_type: None, - ), ), ( flags: ("EXPRESSIONS | BLOCKS | CONTROL_FLOW_UNIFORMITY | STRUCT_LAYOUTS | CONSTANTS | BINDINGS"), @@ -4810,10 +4742,6 @@ sampling: [], dual_source_blending: false, diagnostic_filter_leaf: None, - mesh_shader_info: ( - vertex_type: None, - primitive_type: None, - ), ), ( flags: ("EXPRESSIONS | BLOCKS | CONTROL_FLOW_UNIFORMITY | STRUCT_LAYOUTS | CONSTANTS | BINDINGS"), @@ -4884,10 +4812,6 @@ sampling: [], dual_source_blending: false, diagnostic_filter_leaf: None, - mesh_shader_info: ( - vertex_type: None, - primitive_type: None, - ), ), ], const_expression_types: [ diff --git a/naga/tests/out/analysis/wgsl-collatz.info.ron b/naga/tests/out/analysis/wgsl-collatz.info.ron index 2796f544510..219e016f8d7 100644 --- a/naga/tests/out/analysis/wgsl-collatz.info.ron +++ b/naga/tests/out/analysis/wgsl-collatz.info.ron @@ -275,10 +275,6 @@ sampling: [], dual_source_blending: false, diagnostic_filter_leaf: None, - mesh_shader_info: ( - vertex_type: None, - primitive_type: None, - ), ), ], entry_points: [ @@ -434,10 +430,6 @@ sampling: [], dual_source_blending: false, diagnostic_filter_leaf: None, - mesh_shader_info: ( - vertex_type: None, - primitive_type: None, - ), ), ], const_expression_types: [], diff --git a/naga/tests/out/analysis/wgsl-mesh-shader.info.ron b/naga/tests/out/analysis/wgsl-mesh-shader.info.ron new file mode 100644 index 00000000000..eacd33ad0f1 --- /dev/null +++ b/naga/tests/out/analysis/wgsl-mesh-shader.info.ron @@ -0,0 +1,1469 @@ +( + type_flags: [ + ("DATA | SIZED | COPY | IO_SHAREABLE | HOST_SHAREABLE | CREATION_RESOLVED | ARGUMENT | CONSTRUCTIBLE"), + ("DATA | SIZED | COPY | IO_SHAREABLE | HOST_SHAREABLE | CREATION_RESOLVED | ARGUMENT | CONSTRUCTIBLE"), + ("DATA | SIZED | COPY | CREATION_RESOLVED | ARGUMENT | CONSTRUCTIBLE"), + ("DATA | SIZED | COPY | CREATION_RESOLVED | ARGUMENT | CONSTRUCTIBLE"), + ("DATA | SIZED | COPY | IO_SHAREABLE | HOST_SHAREABLE | CREATION_RESOLVED | ARGUMENT | CONSTRUCTIBLE"), + ("DATA | SIZED | COPY | IO_SHAREABLE | HOST_SHAREABLE | CREATION_RESOLVED | ARGUMENT | CONSTRUCTIBLE"), + ("DATA | SIZED | COPY | IO_SHAREABLE | HOST_SHAREABLE | CREATION_RESOLVED | ARGUMENT | CONSTRUCTIBLE"), + ("DATA | SIZED | COPY | CREATION_RESOLVED | ARGUMENT | CONSTRUCTIBLE"), + ("DATA | SIZED | COPY | IO_SHAREABLE | HOST_SHAREABLE | CREATION_RESOLVED | ARGUMENT | CONSTRUCTIBLE"), + ("DATA | SIZED | COPY | HOST_SHAREABLE | CREATION_RESOLVED | ARGUMENT | CONSTRUCTIBLE"), + ("DATA | SIZED | COPY | CREATION_RESOLVED | ARGUMENT | CONSTRUCTIBLE"), + ("DATA | SIZED | COPY | CREATION_RESOLVED | ARGUMENT | CONSTRUCTIBLE"), + ], + functions: [], + entry_points: [ + ( + flags: ("EXPRESSIONS | BLOCKS | CONTROL_FLOW_UNIFORMITY | STRUCT_LAYOUTS | CONSTANTS | BINDINGS"), + available_stages: ("VERTEX | FRAGMENT | COMPUTE | MESH | TASK"), + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + may_kill: false, + sampling_set: [], + global_uses: [ + ("READ | WRITE"), + ("WRITE"), + (""), + ], + expressions: [ + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: Some(1), + ty: Value(Pointer( + base: 0, + space: WorkGroup, + )), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Float, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: Some(0), + ty: Value(Pointer( + base: 3, + space: TaskPayload, + )), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: Some(0), + ty: Value(Pointer( + base: 1, + space: TaskPayload, + )), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Float, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Float, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Float, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Float, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Handle(1), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: Some(0), + ty: Value(Pointer( + base: 3, + space: TaskPayload, + )), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: Some(0), + ty: Value(Pointer( + base: 2, + space: TaskPayload, + )), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Bool, + width: 1, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Uint, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Uint, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Uint, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Handle(6), + ), + ], + sampling: [], + dual_source_blending: false, + diagnostic_filter_leaf: None, + ), + ( + flags: ("EXPRESSIONS | BLOCKS | CONTROL_FLOW_UNIFORMITY | STRUCT_LAYOUTS | CONSTANTS | BINDINGS"), + available_stages: ("VERTEX | FRAGMENT | COMPUTE | MESH | TASK"), + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + may_kill: false, + sampling_set: [], + global_uses: [ + ("READ"), + ("WRITE"), + ("READ | WRITE"), + ], + expressions: [ + ( + uniformity: ( + non_uniform_result: Some(0), + requirements: (""), + ), + ref_count: 0, + assignable_global: None, + ty: Handle(5), + ), + ( + uniformity: ( + non_uniform_result: Some(1), + requirements: (""), + ), + ref_count: 0, + assignable_global: None, + ty: Handle(6), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: Some(2), + ty: Value(Pointer( + base: 11, + space: WorkGroup, + )), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: Some(2), + ty: Value(Pointer( + base: 5, + space: WorkGroup, + )), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Uint, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: Some(2), + ty: Value(Pointer( + base: 11, + space: WorkGroup, + )), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: Some(2), + ty: Value(Pointer( + base: 5, + space: WorkGroup, + )), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Uint, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: Some(1), + ty: Value(Pointer( + base: 0, + space: WorkGroup, + )), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Float, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: Some(2), + ty: Value(Pointer( + base: 11, + space: WorkGroup, + )), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: Some(2), + ty: Value(Pointer( + base: 9, + space: WorkGroup, + )), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: Some(2), + ty: Value(Pointer( + base: 4, + space: WorkGroup, + )), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: Some(2), + ty: Value(Pointer( + base: 1, + space: WorkGroup, + )), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Float, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Float, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Float, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Float, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Handle(1), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: Some(2), + ty: Value(Pointer( + base: 11, + space: WorkGroup, + )), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: Some(2), + ty: Value(Pointer( + base: 9, + space: WorkGroup, + )), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: Some(2), + ty: Value(Pointer( + base: 4, + space: WorkGroup, + )), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: Some(2), + ty: Value(Pointer( + base: 1, + space: WorkGroup, + )), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: Some(0), + ty: Value(Pointer( + base: 3, + space: TaskPayload, + )), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: Some(0), + ty: Value(Pointer( + base: 1, + space: TaskPayload, + )), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Handle(1), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Float, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Float, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Float, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Float, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Handle(1), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Handle(1), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: Some(2), + ty: Value(Pointer( + base: 11, + space: WorkGroup, + )), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: Some(2), + ty: Value(Pointer( + base: 9, + space: WorkGroup, + )), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: Some(2), + ty: Value(Pointer( + base: 4, + space: WorkGroup, + )), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: Some(2), + ty: Value(Pointer( + base: 1, + space: WorkGroup, + )), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Float, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Float, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Float, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Float, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Handle(1), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: Some(2), + ty: Value(Pointer( + base: 11, + space: WorkGroup, + )), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: Some(2), + ty: Value(Pointer( + base: 9, + space: WorkGroup, + )), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: Some(2), + ty: Value(Pointer( + base: 4, + space: WorkGroup, + )), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: Some(2), + ty: Value(Pointer( + base: 1, + space: WorkGroup, + )), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: Some(0), + ty: Value(Pointer( + base: 3, + space: TaskPayload, + )), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: Some(0), + ty: Value(Pointer( + base: 1, + space: TaskPayload, + )), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Handle(1), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Float, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Float, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Float, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Float, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Handle(1), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Handle(1), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: Some(2), + ty: Value(Pointer( + base: 11, + space: WorkGroup, + )), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: Some(2), + ty: Value(Pointer( + base: 9, + space: WorkGroup, + )), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: Some(2), + ty: Value(Pointer( + base: 4, + space: WorkGroup, + )), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: Some(2), + ty: Value(Pointer( + base: 1, + space: WorkGroup, + )), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Float, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Float, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Float, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Float, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Handle(1), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: Some(2), + ty: Value(Pointer( + base: 11, + space: WorkGroup, + )), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: Some(2), + ty: Value(Pointer( + base: 9, + space: WorkGroup, + )), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: Some(2), + ty: Value(Pointer( + base: 4, + space: WorkGroup, + )), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: Some(2), + ty: Value(Pointer( + base: 1, + space: WorkGroup, + )), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: Some(0), + ty: Value(Pointer( + base: 3, + space: TaskPayload, + )), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: Some(0), + ty: Value(Pointer( + base: 1, + space: TaskPayload, + )), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Handle(1), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Float, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Float, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Float, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Float, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Handle(1), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Handle(1), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: Some(2), + ty: Value(Pointer( + base: 11, + space: WorkGroup, + )), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: Some(2), + ty: Value(Pointer( + base: 10, + space: WorkGroup, + )), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: Some(2), + ty: Value(Pointer( + base: 7, + space: WorkGroup, + )), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: Some(2), + ty: Value(Pointer( + base: 6, + space: WorkGroup, + )), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Uint, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Uint, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Uint, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Handle(6), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: Some(2), + ty: Value(Pointer( + base: 11, + space: WorkGroup, + )), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: Some(2), + ty: Value(Pointer( + base: 10, + space: WorkGroup, + )), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: Some(2), + ty: Value(Pointer( + base: 7, + space: WorkGroup, + )), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: Some(2), + ty: Value(Pointer( + base: 2, + space: WorkGroup, + )), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: Some(0), + ty: Value(Pointer( + base: 3, + space: TaskPayload, + )), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: Some(0), + ty: Value(Pointer( + base: 2, + space: TaskPayload, + )), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Handle(2), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Handle(2), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: Some(2), + ty: Value(Pointer( + base: 11, + space: WorkGroup, + )), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: Some(2), + ty: Value(Pointer( + base: 10, + space: WorkGroup, + )), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: Some(2), + ty: Value(Pointer( + base: 7, + space: WorkGroup, + )), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: Some(2), + ty: Value(Pointer( + base: 1, + space: WorkGroup, + )), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Float, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Float, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Float, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Float, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Handle(1), + ), + ], + sampling: [], + dual_source_blending: false, + diagnostic_filter_leaf: None, + ), + ( + flags: ("EXPRESSIONS | BLOCKS | CONTROL_FLOW_UNIFORMITY | STRUCT_LAYOUTS | CONSTANTS | BINDINGS"), + available_stages: ("VERTEX | FRAGMENT | COMPUTE | MESH | TASK"), + uniformity: ( + non_uniform_result: Some(0), + requirements: (""), + ), + may_kill: false, + sampling_set: [], + global_uses: [ + (""), + (""), + (""), + ], + expressions: [ + ( + uniformity: ( + non_uniform_result: Some(0), + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Handle(4), + ), + ( + uniformity: ( + non_uniform_result: Some(1), + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Handle(8), + ), + ( + uniformity: ( + non_uniform_result: Some(0), + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Handle(1), + ), + ( + uniformity: ( + non_uniform_result: Some(1), + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Handle(1), + ), + ( + uniformity: ( + non_uniform_result: Some(0), + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Handle(1), + ), + ], + sampling: [], + dual_source_blending: false, + diagnostic_filter_leaf: None, + ), + ], + const_expression_types: [], +) \ No newline at end of file diff --git a/naga/tests/out/analysis/wgsl-overrides.info.ron b/naga/tests/out/analysis/wgsl-overrides.info.ron index a76c9c89c9b..92e99112e53 100644 --- a/naga/tests/out/analysis/wgsl-overrides.info.ron +++ b/naga/tests/out/analysis/wgsl-overrides.info.ron @@ -201,10 +201,6 @@ sampling: [], dual_source_blending: false, diagnostic_filter_leaf: None, - mesh_shader_info: ( - vertex_type: None, - primitive_type: None, - ), ), ], const_expression_types: [ diff --git a/naga/tests/out/analysis/wgsl-storage-textures.info.ron b/naga/tests/out/analysis/wgsl-storage-textures.info.ron index 35b5a7e320c..8bb298a6450 100644 --- a/naga/tests/out/analysis/wgsl-storage-textures.info.ron +++ b/naga/tests/out/analysis/wgsl-storage-textures.info.ron @@ -184,10 +184,6 @@ sampling: [], dual_source_blending: false, diagnostic_filter_leaf: None, - mesh_shader_info: ( - vertex_type: None, - primitive_type: None, - ), ), ( flags: ("EXPRESSIONS | BLOCKS | CONTROL_FLOW_UNIFORMITY | STRUCT_LAYOUTS | CONSTANTS | BINDINGS"), @@ -400,10 +396,6 @@ sampling: [], dual_source_blending: false, diagnostic_filter_leaf: None, - mesh_shader_info: ( - vertex_type: None, - primitive_type: None, - ), ), ], const_expression_types: [], diff --git a/naga/tests/out/ir/wgsl-mesh-shader.compact.ron b/naga/tests/out/ir/wgsl-mesh-shader.compact.ron new file mode 100644 index 00000000000..1147b017f5c --- /dev/null +++ b/naga/tests/out/ir/wgsl-mesh-shader.compact.ron @@ -0,0 +1,980 @@ +( + types: [ + ( + name: None, + inner: Scalar(( + kind: Float, + width: 4, + )), + ), + ( + name: None, + inner: Vector( + size: Quad, + scalar: ( + kind: Float, + width: 4, + ), + ), + ), + ( + name: None, + inner: Scalar(( + kind: Bool, + width: 1, + )), + ), + ( + name: Some("TaskPayload"), + inner: Struct( + members: [ + ( + name: Some("colorMask"), + ty: 1, + binding: None, + offset: 0, + ), + ( + name: Some("visible"), + ty: 2, + binding: None, + offset: 16, + ), + ], + span: 32, + ), + ), + ( + name: Some("VertexOutput"), + inner: Struct( + members: [ + ( + name: Some("position"), + ty: 1, + binding: Some(BuiltIn(Position( + invariant: false, + ))), + offset: 0, + ), + ( + name: Some("color"), + ty: 1, + binding: Some(Location( + location: 0, + interpolation: Some(Perspective), + sampling: Some(Center), + blend_src: None, + per_primitive: false, + )), + offset: 16, + ), + ], + span: 32, + ), + ), + ( + name: None, + inner: Scalar(( + kind: Uint, + width: 4, + )), + ), + ( + name: None, + inner: Vector( + size: Tri, + scalar: ( + kind: Uint, + width: 4, + ), + ), + ), + ( + name: Some("PrimitiveOutput"), + inner: Struct( + members: [ + ( + name: Some("index"), + ty: 6, + binding: Some(BuiltIn(TriangleIndices)), + offset: 0, + ), + ( + name: Some("cull"), + ty: 2, + binding: Some(BuiltIn(CullPrimitive)), + offset: 12, + ), + ( + name: Some("colorMask"), + ty: 1, + binding: Some(Location( + location: 1, + interpolation: Some(Perspective), + sampling: Some(Center), + blend_src: None, + per_primitive: true, + )), + offset: 16, + ), + ], + span: 32, + ), + ), + ( + name: Some("PrimitiveInput"), + inner: Struct( + members: [ + ( + name: Some("colorMask"), + ty: 1, + binding: Some(Location( + location: 1, + interpolation: Some(Perspective), + sampling: Some(Center), + blend_src: None, + per_primitive: true, + )), + offset: 0, + ), + ], + span: 16, + ), + ), + ( + name: None, + inner: Array( + base: 4, + size: Constant(3), + stride: 32, + ), + ), + ( + name: None, + inner: Array( + base: 7, + size: Constant(1), + stride: 32, + ), + ), + ( + name: Some("MeshOutput"), + inner: Struct( + members: [ + ( + name: Some("vertices"), + ty: 9, + binding: Some(BuiltIn(Vertices)), + offset: 0, + ), + ( + name: Some("primitives"), + ty: 10, + binding: Some(BuiltIn(Primitives)), + offset: 96, + ), + ( + name: Some("vertex_count"), + ty: 5, + binding: Some(BuiltIn(VertexCount)), + offset: 128, + ), + ( + name: Some("primitive_count"), + ty: 5, + binding: Some(BuiltIn(PrimitiveCount)), + offset: 132, + ), + ], + span: 144, + ), + ), + ], + special_types: ( + ray_desc: None, + ray_intersection: None, + ray_vertex_return: None, + external_texture_params: None, + external_texture_transfer_function: None, + predeclared_types: {}, + ), + constants: [], + overrides: [], + global_variables: [ + ( + name: Some("taskPayload"), + space: TaskPayload, + binding: None, + ty: 3, + init: None, + ), + ( + name: Some("workgroupData"), + space: WorkGroup, + binding: None, + ty: 0, + init: None, + ), + ( + name: Some("mesh_output"), + space: WorkGroup, + binding: None, + ty: 11, + init: None, + ), + ], + global_expressions: [], + functions: [], + entry_points: [ + ( + name: "ts_main", + stage: Task, + early_depth_test: None, + workgroup_size: (1, 1, 1), + workgroup_size_overrides: None, + function: ( + name: Some("ts_main"), + arguments: [], + result: Some(( + ty: 6, + binding: Some(BuiltIn(MeshTaskSize)), + )), + local_variables: [], + expressions: [ + GlobalVariable(1), + Literal(F32(1.0)), + GlobalVariable(0), + AccessIndex( + base: 2, + index: 0, + ), + Literal(F32(1.0)), + Literal(F32(1.0)), + Literal(F32(0.0)), + Literal(F32(1.0)), + Compose( + ty: 1, + components: [ + 4, + 5, + 6, + 7, + ], + ), + GlobalVariable(0), + AccessIndex( + base: 9, + index: 1, + ), + Literal(Bool(true)), + Literal(U32(3)), + Literal(U32(1)), + Literal(U32(1)), + Compose( + ty: 6, + components: [ + 12, + 13, + 14, + ], + ), + ], + named_expressions: {}, + body: [ + Store( + pointer: 0, + value: 1, + ), + Emit(( + start: 3, + end: 4, + )), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 8, + end: 9, + )), + Store( + pointer: 3, + value: 8, + ), + Emit(( + start: 10, + end: 11, + )), + Store( + pointer: 10, + value: 11, + ), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 15, + end: 16, + )), + Return( + value: Some(15), + ), + ], + diagnostic_filter_leaf: None, + ), + mesh_info: None, + task_payload: Some(0), + ), + ( + name: "ms_main", + stage: Mesh, + early_depth_test: None, + workgroup_size: (1, 1, 1), + workgroup_size_overrides: None, + function: ( + name: Some("ms_main"), + arguments: [ + ( + name: Some("index"), + ty: 5, + binding: Some(BuiltIn(LocalInvocationIndex)), + ), + ( + name: Some("id"), + ty: 6, + binding: Some(BuiltIn(GlobalInvocationId)), + ), + ], + result: None, + local_variables: [], + expressions: [ + FunctionArgument(0), + FunctionArgument(1), + GlobalVariable(2), + AccessIndex( + base: 2, + index: 2, + ), + Literal(U32(3)), + GlobalVariable(2), + AccessIndex( + base: 5, + index: 3, + ), + Literal(U32(1)), + GlobalVariable(1), + Literal(F32(2.0)), + GlobalVariable(2), + AccessIndex( + base: 10, + index: 0, + ), + AccessIndex( + base: 11, + index: 0, + ), + AccessIndex( + base: 12, + index: 0, + ), + Literal(F32(0.0)), + Literal(F32(1.0)), + Literal(F32(0.0)), + Literal(F32(1.0)), + Compose( + ty: 1, + components: [ + 14, + 15, + 16, + 17, + ], + ), + GlobalVariable(2), + AccessIndex( + base: 19, + index: 0, + ), + AccessIndex( + base: 20, + index: 0, + ), + AccessIndex( + base: 21, + index: 1, + ), + GlobalVariable(0), + AccessIndex( + base: 23, + index: 0, + ), + Load( + pointer: 24, + ), + Literal(F32(0.0)), + Literal(F32(1.0)), + Literal(F32(0.0)), + Literal(F32(1.0)), + Compose( + ty: 1, + components: [ + 26, + 27, + 28, + 29, + ], + ), + Binary( + op: Multiply, + left: 30, + right: 25, + ), + GlobalVariable(2), + AccessIndex( + base: 32, + index: 0, + ), + AccessIndex( + base: 33, + index: 1, + ), + AccessIndex( + base: 34, + index: 0, + ), + Literal(F32(-1.0)), + Literal(F32(-1.0)), + Literal(F32(0.0)), + Literal(F32(1.0)), + Compose( + ty: 1, + components: [ + 36, + 37, + 38, + 39, + ], + ), + GlobalVariable(2), + AccessIndex( + base: 41, + index: 0, + ), + AccessIndex( + base: 42, + index: 1, + ), + AccessIndex( + base: 43, + index: 1, + ), + GlobalVariable(0), + AccessIndex( + base: 45, + index: 0, + ), + Load( + pointer: 46, + ), + Literal(F32(0.0)), + Literal(F32(0.0)), + Literal(F32(1.0)), + Literal(F32(1.0)), + Compose( + ty: 1, + components: [ + 48, + 49, + 50, + 51, + ], + ), + Binary( + op: Multiply, + left: 52, + right: 47, + ), + GlobalVariable(2), + AccessIndex( + base: 54, + index: 0, + ), + AccessIndex( + base: 55, + index: 2, + ), + AccessIndex( + base: 56, + index: 0, + ), + Literal(F32(1.0)), + Literal(F32(-1.0)), + Literal(F32(0.0)), + Literal(F32(1.0)), + Compose( + ty: 1, + components: [ + 58, + 59, + 60, + 61, + ], + ), + GlobalVariable(2), + AccessIndex( + base: 63, + index: 0, + ), + AccessIndex( + base: 64, + index: 2, + ), + AccessIndex( + base: 65, + index: 1, + ), + GlobalVariable(0), + AccessIndex( + base: 67, + index: 0, + ), + Load( + pointer: 68, + ), + Literal(F32(1.0)), + Literal(F32(0.0)), + Literal(F32(0.0)), + Literal(F32(1.0)), + Compose( + ty: 1, + components: [ + 70, + 71, + 72, + 73, + ], + ), + Binary( + op: Multiply, + left: 74, + right: 69, + ), + GlobalVariable(2), + AccessIndex( + base: 76, + index: 1, + ), + AccessIndex( + base: 77, + index: 0, + ), + AccessIndex( + base: 78, + index: 0, + ), + Literal(U32(0)), + Literal(U32(1)), + Literal(U32(2)), + Compose( + ty: 6, + components: [ + 80, + 81, + 82, + ], + ), + GlobalVariable(2), + AccessIndex( + base: 84, + index: 1, + ), + AccessIndex( + base: 85, + index: 0, + ), + AccessIndex( + base: 86, + index: 1, + ), + GlobalVariable(0), + AccessIndex( + base: 88, + index: 1, + ), + Load( + pointer: 89, + ), + Unary( + op: LogicalNot, + expr: 90, + ), + GlobalVariable(2), + AccessIndex( + base: 92, + index: 1, + ), + AccessIndex( + base: 93, + index: 0, + ), + AccessIndex( + base: 94, + index: 2, + ), + Literal(F32(1.0)), + Literal(F32(0.0)), + Literal(F32(1.0)), + Literal(F32(1.0)), + Compose( + ty: 1, + components: [ + 96, + 97, + 98, + 99, + ], + ), + ], + named_expressions: { + 0: "index", + 1: "id", + }, + body: [ + Emit(( + start: 3, + end: 4, + )), + Store( + pointer: 3, + value: 4, + ), + Emit(( + start: 6, + end: 7, + )), + Store( + pointer: 6, + value: 7, + ), + Store( + pointer: 8, + value: 9, + ), + Emit(( + start: 11, + end: 12, + )), + Emit(( + start: 12, + end: 14, + )), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 18, + end: 19, + )), + Store( + pointer: 13, + value: 18, + ), + Emit(( + start: 20, + end: 21, + )), + Emit(( + start: 21, + end: 23, + )), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 24, + end: 26, + )), + Emit(( + start: 30, + end: 32, + )), + Store( + pointer: 22, + value: 31, + ), + Emit(( + start: 33, + end: 34, + )), + Emit(( + start: 34, + end: 36, + )), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 40, + end: 41, + )), + Store( + pointer: 35, + value: 40, + ), + Emit(( + start: 42, + end: 43, + )), + Emit(( + start: 43, + end: 45, + )), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 46, + end: 48, + )), + Emit(( + start: 52, + end: 54, + )), + Store( + pointer: 44, + value: 53, + ), + Emit(( + start: 55, + end: 56, + )), + Emit(( + start: 56, + end: 58, + )), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 62, + end: 63, + )), + Store( + pointer: 57, + value: 62, + ), + Emit(( + start: 64, + end: 65, + )), + Emit(( + start: 65, + end: 67, + )), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 68, + end: 70, + )), + Emit(( + start: 74, + end: 76, + )), + Store( + pointer: 66, + value: 75, + ), + Emit(( + start: 77, + end: 78, + )), + Emit(( + start: 78, + end: 80, + )), + Emit(( + start: 83, + end: 84, + )), + Store( + pointer: 79, + value: 83, + ), + Emit(( + start: 85, + end: 86, + )), + Emit(( + start: 86, + end: 88, + )), + Emit(( + start: 89, + end: 92, + )), + Store( + pointer: 87, + value: 91, + ), + Emit(( + start: 93, + end: 94, + )), + Emit(( + start: 94, + end: 96, + )), + Emit(( + start: 100, + end: 101, + )), + Store( + pointer: 95, + value: 100, + ), + Return( + value: None, + ), + ], + diagnostic_filter_leaf: None, + ), + mesh_info: Some(( + topology: Triangles, + max_vertices: 3, + max_vertices_override: None, + max_primitives: 1, + max_primitives_override: None, + vertex_output_type: 4, + primitive_output_type: 7, + output_variable: 2, + )), + task_payload: Some(0), + ), + ( + name: "fs_main", + stage: Fragment, + early_depth_test: None, + workgroup_size: (0, 0, 0), + workgroup_size_overrides: None, + function: ( + name: Some("fs_main"), + arguments: [ + ( + name: Some("vertex"), + ty: 4, + binding: None, + ), + ( + name: Some("primitive"), + ty: 8, + binding: None, + ), + ], + result: Some(( + ty: 1, + binding: Some(Location( + location: 0, + interpolation: Some(Perspective), + sampling: Some(Center), + blend_src: None, + per_primitive: false, + )), + )), + local_variables: [], + expressions: [ + FunctionArgument(0), + FunctionArgument(1), + AccessIndex( + base: 0, + index: 1, + ), + AccessIndex( + base: 1, + index: 0, + ), + Binary( + op: Multiply, + left: 2, + right: 3, + ), + ], + named_expressions: { + 0: "vertex", + 1: "primitive", + }, + body: [ + Emit(( + start: 2, + end: 5, + )), + Return( + value: Some(4), + ), + ], + diagnostic_filter_leaf: None, + ), + mesh_info: None, + task_payload: None, + ), + ], + diagnostic_filters: [], + diagnostic_filter_leaf: None, + doc_comments: None, +) \ No newline at end of file diff --git a/naga/tests/out/ir/wgsl-mesh-shader.ron b/naga/tests/out/ir/wgsl-mesh-shader.ron new file mode 100644 index 00000000000..1147b017f5c --- /dev/null +++ b/naga/tests/out/ir/wgsl-mesh-shader.ron @@ -0,0 +1,980 @@ +( + types: [ + ( + name: None, + inner: Scalar(( + kind: Float, + width: 4, + )), + ), + ( + name: None, + inner: Vector( + size: Quad, + scalar: ( + kind: Float, + width: 4, + ), + ), + ), + ( + name: None, + inner: Scalar(( + kind: Bool, + width: 1, + )), + ), + ( + name: Some("TaskPayload"), + inner: Struct( + members: [ + ( + name: Some("colorMask"), + ty: 1, + binding: None, + offset: 0, + ), + ( + name: Some("visible"), + ty: 2, + binding: None, + offset: 16, + ), + ], + span: 32, + ), + ), + ( + name: Some("VertexOutput"), + inner: Struct( + members: [ + ( + name: Some("position"), + ty: 1, + binding: Some(BuiltIn(Position( + invariant: false, + ))), + offset: 0, + ), + ( + name: Some("color"), + ty: 1, + binding: Some(Location( + location: 0, + interpolation: Some(Perspective), + sampling: Some(Center), + blend_src: None, + per_primitive: false, + )), + offset: 16, + ), + ], + span: 32, + ), + ), + ( + name: None, + inner: Scalar(( + kind: Uint, + width: 4, + )), + ), + ( + name: None, + inner: Vector( + size: Tri, + scalar: ( + kind: Uint, + width: 4, + ), + ), + ), + ( + name: Some("PrimitiveOutput"), + inner: Struct( + members: [ + ( + name: Some("index"), + ty: 6, + binding: Some(BuiltIn(TriangleIndices)), + offset: 0, + ), + ( + name: Some("cull"), + ty: 2, + binding: Some(BuiltIn(CullPrimitive)), + offset: 12, + ), + ( + name: Some("colorMask"), + ty: 1, + binding: Some(Location( + location: 1, + interpolation: Some(Perspective), + sampling: Some(Center), + blend_src: None, + per_primitive: true, + )), + offset: 16, + ), + ], + span: 32, + ), + ), + ( + name: Some("PrimitiveInput"), + inner: Struct( + members: [ + ( + name: Some("colorMask"), + ty: 1, + binding: Some(Location( + location: 1, + interpolation: Some(Perspective), + sampling: Some(Center), + blend_src: None, + per_primitive: true, + )), + offset: 0, + ), + ], + span: 16, + ), + ), + ( + name: None, + inner: Array( + base: 4, + size: Constant(3), + stride: 32, + ), + ), + ( + name: None, + inner: Array( + base: 7, + size: Constant(1), + stride: 32, + ), + ), + ( + name: Some("MeshOutput"), + inner: Struct( + members: [ + ( + name: Some("vertices"), + ty: 9, + binding: Some(BuiltIn(Vertices)), + offset: 0, + ), + ( + name: Some("primitives"), + ty: 10, + binding: Some(BuiltIn(Primitives)), + offset: 96, + ), + ( + name: Some("vertex_count"), + ty: 5, + binding: Some(BuiltIn(VertexCount)), + offset: 128, + ), + ( + name: Some("primitive_count"), + ty: 5, + binding: Some(BuiltIn(PrimitiveCount)), + offset: 132, + ), + ], + span: 144, + ), + ), + ], + special_types: ( + ray_desc: None, + ray_intersection: None, + ray_vertex_return: None, + external_texture_params: None, + external_texture_transfer_function: None, + predeclared_types: {}, + ), + constants: [], + overrides: [], + global_variables: [ + ( + name: Some("taskPayload"), + space: TaskPayload, + binding: None, + ty: 3, + init: None, + ), + ( + name: Some("workgroupData"), + space: WorkGroup, + binding: None, + ty: 0, + init: None, + ), + ( + name: Some("mesh_output"), + space: WorkGroup, + binding: None, + ty: 11, + init: None, + ), + ], + global_expressions: [], + functions: [], + entry_points: [ + ( + name: "ts_main", + stage: Task, + early_depth_test: None, + workgroup_size: (1, 1, 1), + workgroup_size_overrides: None, + function: ( + name: Some("ts_main"), + arguments: [], + result: Some(( + ty: 6, + binding: Some(BuiltIn(MeshTaskSize)), + )), + local_variables: [], + expressions: [ + GlobalVariable(1), + Literal(F32(1.0)), + GlobalVariable(0), + AccessIndex( + base: 2, + index: 0, + ), + Literal(F32(1.0)), + Literal(F32(1.0)), + Literal(F32(0.0)), + Literal(F32(1.0)), + Compose( + ty: 1, + components: [ + 4, + 5, + 6, + 7, + ], + ), + GlobalVariable(0), + AccessIndex( + base: 9, + index: 1, + ), + Literal(Bool(true)), + Literal(U32(3)), + Literal(U32(1)), + Literal(U32(1)), + Compose( + ty: 6, + components: [ + 12, + 13, + 14, + ], + ), + ], + named_expressions: {}, + body: [ + Store( + pointer: 0, + value: 1, + ), + Emit(( + start: 3, + end: 4, + )), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 8, + end: 9, + )), + Store( + pointer: 3, + value: 8, + ), + Emit(( + start: 10, + end: 11, + )), + Store( + pointer: 10, + value: 11, + ), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 15, + end: 16, + )), + Return( + value: Some(15), + ), + ], + diagnostic_filter_leaf: None, + ), + mesh_info: None, + task_payload: Some(0), + ), + ( + name: "ms_main", + stage: Mesh, + early_depth_test: None, + workgroup_size: (1, 1, 1), + workgroup_size_overrides: None, + function: ( + name: Some("ms_main"), + arguments: [ + ( + name: Some("index"), + ty: 5, + binding: Some(BuiltIn(LocalInvocationIndex)), + ), + ( + name: Some("id"), + ty: 6, + binding: Some(BuiltIn(GlobalInvocationId)), + ), + ], + result: None, + local_variables: [], + expressions: [ + FunctionArgument(0), + FunctionArgument(1), + GlobalVariable(2), + AccessIndex( + base: 2, + index: 2, + ), + Literal(U32(3)), + GlobalVariable(2), + AccessIndex( + base: 5, + index: 3, + ), + Literal(U32(1)), + GlobalVariable(1), + Literal(F32(2.0)), + GlobalVariable(2), + AccessIndex( + base: 10, + index: 0, + ), + AccessIndex( + base: 11, + index: 0, + ), + AccessIndex( + base: 12, + index: 0, + ), + Literal(F32(0.0)), + Literal(F32(1.0)), + Literal(F32(0.0)), + Literal(F32(1.0)), + Compose( + ty: 1, + components: [ + 14, + 15, + 16, + 17, + ], + ), + GlobalVariable(2), + AccessIndex( + base: 19, + index: 0, + ), + AccessIndex( + base: 20, + index: 0, + ), + AccessIndex( + base: 21, + index: 1, + ), + GlobalVariable(0), + AccessIndex( + base: 23, + index: 0, + ), + Load( + pointer: 24, + ), + Literal(F32(0.0)), + Literal(F32(1.0)), + Literal(F32(0.0)), + Literal(F32(1.0)), + Compose( + ty: 1, + components: [ + 26, + 27, + 28, + 29, + ], + ), + Binary( + op: Multiply, + left: 30, + right: 25, + ), + GlobalVariable(2), + AccessIndex( + base: 32, + index: 0, + ), + AccessIndex( + base: 33, + index: 1, + ), + AccessIndex( + base: 34, + index: 0, + ), + Literal(F32(-1.0)), + Literal(F32(-1.0)), + Literal(F32(0.0)), + Literal(F32(1.0)), + Compose( + ty: 1, + components: [ + 36, + 37, + 38, + 39, + ], + ), + GlobalVariable(2), + AccessIndex( + base: 41, + index: 0, + ), + AccessIndex( + base: 42, + index: 1, + ), + AccessIndex( + base: 43, + index: 1, + ), + GlobalVariable(0), + AccessIndex( + base: 45, + index: 0, + ), + Load( + pointer: 46, + ), + Literal(F32(0.0)), + Literal(F32(0.0)), + Literal(F32(1.0)), + Literal(F32(1.0)), + Compose( + ty: 1, + components: [ + 48, + 49, + 50, + 51, + ], + ), + Binary( + op: Multiply, + left: 52, + right: 47, + ), + GlobalVariable(2), + AccessIndex( + base: 54, + index: 0, + ), + AccessIndex( + base: 55, + index: 2, + ), + AccessIndex( + base: 56, + index: 0, + ), + Literal(F32(1.0)), + Literal(F32(-1.0)), + Literal(F32(0.0)), + Literal(F32(1.0)), + Compose( + ty: 1, + components: [ + 58, + 59, + 60, + 61, + ], + ), + GlobalVariable(2), + AccessIndex( + base: 63, + index: 0, + ), + AccessIndex( + base: 64, + index: 2, + ), + AccessIndex( + base: 65, + index: 1, + ), + GlobalVariable(0), + AccessIndex( + base: 67, + index: 0, + ), + Load( + pointer: 68, + ), + Literal(F32(1.0)), + Literal(F32(0.0)), + Literal(F32(0.0)), + Literal(F32(1.0)), + Compose( + ty: 1, + components: [ + 70, + 71, + 72, + 73, + ], + ), + Binary( + op: Multiply, + left: 74, + right: 69, + ), + GlobalVariable(2), + AccessIndex( + base: 76, + index: 1, + ), + AccessIndex( + base: 77, + index: 0, + ), + AccessIndex( + base: 78, + index: 0, + ), + Literal(U32(0)), + Literal(U32(1)), + Literal(U32(2)), + Compose( + ty: 6, + components: [ + 80, + 81, + 82, + ], + ), + GlobalVariable(2), + AccessIndex( + base: 84, + index: 1, + ), + AccessIndex( + base: 85, + index: 0, + ), + AccessIndex( + base: 86, + index: 1, + ), + GlobalVariable(0), + AccessIndex( + base: 88, + index: 1, + ), + Load( + pointer: 89, + ), + Unary( + op: LogicalNot, + expr: 90, + ), + GlobalVariable(2), + AccessIndex( + base: 92, + index: 1, + ), + AccessIndex( + base: 93, + index: 0, + ), + AccessIndex( + base: 94, + index: 2, + ), + Literal(F32(1.0)), + Literal(F32(0.0)), + Literal(F32(1.0)), + Literal(F32(1.0)), + Compose( + ty: 1, + components: [ + 96, + 97, + 98, + 99, + ], + ), + ], + named_expressions: { + 0: "index", + 1: "id", + }, + body: [ + Emit(( + start: 3, + end: 4, + )), + Store( + pointer: 3, + value: 4, + ), + Emit(( + start: 6, + end: 7, + )), + Store( + pointer: 6, + value: 7, + ), + Store( + pointer: 8, + value: 9, + ), + Emit(( + start: 11, + end: 12, + )), + Emit(( + start: 12, + end: 14, + )), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 18, + end: 19, + )), + Store( + pointer: 13, + value: 18, + ), + Emit(( + start: 20, + end: 21, + )), + Emit(( + start: 21, + end: 23, + )), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 24, + end: 26, + )), + Emit(( + start: 30, + end: 32, + )), + Store( + pointer: 22, + value: 31, + ), + Emit(( + start: 33, + end: 34, + )), + Emit(( + start: 34, + end: 36, + )), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 40, + end: 41, + )), + Store( + pointer: 35, + value: 40, + ), + Emit(( + start: 42, + end: 43, + )), + Emit(( + start: 43, + end: 45, + )), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 46, + end: 48, + )), + Emit(( + start: 52, + end: 54, + )), + Store( + pointer: 44, + value: 53, + ), + Emit(( + start: 55, + end: 56, + )), + Emit(( + start: 56, + end: 58, + )), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 62, + end: 63, + )), + Store( + pointer: 57, + value: 62, + ), + Emit(( + start: 64, + end: 65, + )), + Emit(( + start: 65, + end: 67, + )), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 68, + end: 70, + )), + Emit(( + start: 74, + end: 76, + )), + Store( + pointer: 66, + value: 75, + ), + Emit(( + start: 77, + end: 78, + )), + Emit(( + start: 78, + end: 80, + )), + Emit(( + start: 83, + end: 84, + )), + Store( + pointer: 79, + value: 83, + ), + Emit(( + start: 85, + end: 86, + )), + Emit(( + start: 86, + end: 88, + )), + Emit(( + start: 89, + end: 92, + )), + Store( + pointer: 87, + value: 91, + ), + Emit(( + start: 93, + end: 94, + )), + Emit(( + start: 94, + end: 96, + )), + Emit(( + start: 100, + end: 101, + )), + Store( + pointer: 95, + value: 100, + ), + Return( + value: None, + ), + ], + diagnostic_filter_leaf: None, + ), + mesh_info: Some(( + topology: Triangles, + max_vertices: 3, + max_vertices_override: None, + max_primitives: 1, + max_primitives_override: None, + vertex_output_type: 4, + primitive_output_type: 7, + output_variable: 2, + )), + task_payload: Some(0), + ), + ( + name: "fs_main", + stage: Fragment, + early_depth_test: None, + workgroup_size: (0, 0, 0), + workgroup_size_overrides: None, + function: ( + name: Some("fs_main"), + arguments: [ + ( + name: Some("vertex"), + ty: 4, + binding: None, + ), + ( + name: Some("primitive"), + ty: 8, + binding: None, + ), + ], + result: Some(( + ty: 1, + binding: Some(Location( + location: 0, + interpolation: Some(Perspective), + sampling: Some(Center), + blend_src: None, + per_primitive: false, + )), + )), + local_variables: [], + expressions: [ + FunctionArgument(0), + FunctionArgument(1), + AccessIndex( + base: 0, + index: 1, + ), + AccessIndex( + base: 1, + index: 0, + ), + Binary( + op: Multiply, + left: 2, + right: 3, + ), + ], + named_expressions: { + 0: "vertex", + 1: "primitive", + }, + body: [ + Emit(( + start: 2, + end: 5, + )), + Return( + value: Some(4), + ), + ], + diagnostic_filter_leaf: None, + ), + mesh_info: None, + task_payload: None, + ), + ], + diagnostic_filters: [], + diagnostic_filter_leaf: None, + doc_comments: None, +) \ No newline at end of file diff --git a/naga/tests/out/msl/wgsl-mesh-shader.msl b/naga/tests/out/msl/wgsl-mesh-shader.msl new file mode 100644 index 00000000000..5280464ea63 --- /dev/null +++ b/naga/tests/out/msl/wgsl-mesh-shader.msl @@ -0,0 +1,98 @@ +// language: metal2.4 +#include +#include + +using metal::uint; + +struct TaskPayload { + metal::float4 colorMask; + bool visible; + char _pad2[15]; +}; +struct VertexOutput { + metal::float4 position; + metal::float4 color; +}; +struct PrimitiveOutput { + metal::packed_uint3 index; + bool cull; + char _pad2[3]; + metal::float4 colorMask; +}; +struct PrimitiveInput { + metal::float4 colorMask; +}; +struct type_5 { + VertexOutput inner[3]; +}; +struct type_6 { + PrimitiveOutput inner[1]; +}; +struct MeshOutput { + type_5 vertices; + type_6 primitives; + uint vertex_count; + uint primitive_count; + char _pad4[8]; +}; + +struct ts_mainOutput { + metal::uint3 member [[TODO_MESH_BUILTIN]]; +}; +[[task]] +ts_mainOutput ts_main( + object_data TaskPayload& taskPayload [[payload]] +, threadgroup float& workgroupData +) { + workgroupData = 1.0; + taskPayload.colorMask = metal::float4(1.0, 1.0, 0.0, 1.0); + taskPayload.visible = true; + return ts_mainOutput { metal::uint3(3u, 1u, 1u) }; +} + + +struct ms_mainInput { +}; +[[mesh]] +void ms_main( + uint index [[thread_index_in_threadgroup]] +, metal::uint3 id [[thread_position_in_grid]] +, object_data TaskPayload const& taskPayload [[payload]] +, threadgroup float& workgroupData +, threadgroup MeshOutput& mesh_output +) { + mesh_output.vertex_count = 3u; + mesh_output.primitive_count = 1u; + workgroupData = 2.0; + mesh_output.vertices.inner[0].position = metal::float4(0.0, 1.0, 0.0, 1.0); + metal::float4 _e25 = taskPayload.colorMask; + mesh_output.vertices.inner[0].color = metal::float4(0.0, 1.0, 0.0, 1.0) * _e25; + mesh_output.vertices.inner[1].position = metal::float4(-1.0, -1.0, 0.0, 1.0); + metal::float4 _e47 = taskPayload.colorMask; + mesh_output.vertices.inner[1].color = metal::float4(0.0, 0.0, 1.0, 1.0) * _e47; + mesh_output.vertices.inner[2].position = metal::float4(1.0, -1.0, 0.0, 1.0); + metal::float4 _e69 = taskPayload.colorMask; + mesh_output.vertices.inner[2].color = metal::float4(1.0, 0.0, 0.0, 1.0) * _e69; + mesh_output.primitives.inner[0].index = metal::uint3(0u, 1u, 2u); + bool _e90 = taskPayload.visible; + mesh_output.primitives.inner[0].cull = !(_e90); + mesh_output.primitives.inner[0].colorMask = metal::float4(1.0, 0.0, 1.0, 1.0); + return; +} + + +struct fs_mainInput { + metal::float4 color [[user(loc0), center_perspective]]; + metal::float4 colorMask [[user(loc1), center_perspective]]; +}; +struct fs_mainOutput { + metal::float4 member_2 [[color(0)]]; +}; +fragment fs_mainOutput fs_main( + fs_mainInput varyings_2 [[stage_in]] +, metal::float4 position [[position]] +) { + const VertexOutput vertex_ = { position, varyings_2.color }; + const PrimitiveInput primitive = { varyings_2.colorMask }; + return fs_mainOutput { vertex_.color * primitive.colorMask }; +} diff --git a/tests/tests/wgpu-gpu/mesh_shader/mod.rs b/tests/tests/wgpu-gpu/mesh_shader/mod.rs index 069fa1bf567..675436a1d7e 100644 --- a/tests/tests/wgpu-gpu/mesh_shader/mod.rs +++ b/tests/tests/wgpu-gpu/mesh_shader/mod.rs @@ -5,12 +5,13 @@ use std::{ use wgpu::{util::DeviceExt, Backends}; use wgpu_test::{ - fail, gpu_test, FailureCase, GpuTestConfiguration, GpuTestInitializer, TestParameters, - TestingContext, + gpu_test, FailureCase, GpuTestConfiguration, GpuTestInitializer, TestParameters, TestingContext, }; /// Backends that support mesh shaders -const MESH_SHADER_BACKENDS: Backends = Backends::DX12.union(Backends::VULKAN); +const MESH_SHADER_BACKENDS: Backends = Backends::DX12 + .union(Backends::VULKAN) + .union(Backends::METAL); pub fn all_tests(tests: &mut Vec) { tests.extend([ @@ -23,11 +24,16 @@ pub fn all_tests(tests: &mut Vec) { MESH_MULTI_DRAW_INDIRECT_COUNT, MESH_PIPELINE_BASIC_MESH_NO_DRAW, MESH_PIPELINE_BASIC_TASK_MESH_FRAG_NO_DRAW, - MESH_DISABLED, ]); } // Same as in mesh shader example +fn compile_wgsl(device: &wgpu::Device) -> wgpu::ShaderModule { + device.create_shader_module(wgpu::ShaderModuleDescriptor { + label: None, + source: wgpu::ShaderSource::Wgsl(include_str!("shader.wgsl").into()), + }) +} fn compile_glsl(device: &wgpu::Device, shader_stage: &'static str) -> wgpu::ShaderModule { let cmd = std::process::Command::new("glslc") .args([ @@ -55,7 +61,6 @@ fn compile_glsl(device: &wgpu::Device, shader_stage: &'static str) -> wgpu::Shad }) } } - fn compile_hlsl( device: &wgpu::Device, entry: &str, @@ -107,39 +112,47 @@ fn get_shaders( Option, wgpu::ShaderModule, Option, + &'static str, + &'static str, + &'static str, ) { - // On backends that don't support mesh shaders, or for the MESH_DISABLED - // test, compile a dummy shader so we can construct a structurally valid - // pipeline description and test that `create_mesh_pipeline` fails. - // (In the case that the platform does support mesh shaders, the dummy - // shader is used to avoid requiring EXPERIMENTAL_PASSTHROUGH_SHADERS.) + // In the case that the platform does support mesh shaders, the dummy + // shader is used to avoid requiring EXPERIMENTAL_PASSTHROUGH_SHADERS. let dummy_shader = device.create_shader_module(wgpu::include_wgsl!("non_mesh.wgsl")); - if backend == wgpu::Backend::Vulkan { + if backend == wgpu::Backend::Metal { + let s = compile_wgsl(device); + ( + info.use_task.then(|| s.clone()), + s.clone(), + info.use_frag.then_some(s), + "ts_main", + if info.use_task { "ms_main" } else { "ms_no_ts" }, + "fs_main", + ) + } else if backend == wgpu::Backend::Vulkan { ( info.use_task.then(|| compile_glsl(device, "task")), - if info.use_mesh { - compile_glsl(device, "mesh") - } else { - dummy_shader - }, + compile_glsl(device, "mesh"), info.use_frag.then(|| compile_glsl(device, "frag")), + "main", + "main", + "main", ) } else if backend == wgpu::Backend::Dx12 { ( info.use_task .then(|| compile_hlsl(device, "Task", "as", test_name)), - if info.use_mesh { - compile_hlsl(device, "Mesh", "ms", test_name) - } else { - dummy_shader - }, + compile_hlsl(device, "Mesh", "ms", test_name), info.use_frag .then(|| compile_hlsl(device, "Frag", "ps", test_name)), + "main", + "main", + "main", ) } else { assert!(!MESH_SHADER_BACKENDS.contains(Backends::from(backend))); - assert!(!info.use_task && !info.use_mesh && !info.use_frag); - (None, dummy_shader, None) + assert!(!info.use_task && !info.use_frag); + (None, dummy_shader, None, "main", "main", "main") } } @@ -174,7 +187,6 @@ fn create_depth( struct MeshPipelineTestInfo { use_task: bool, - use_mesh: bool, use_frag: bool, draw: bool, } @@ -191,7 +203,8 @@ fn mesh_pipeline_build(ctx: &TestingContext, info: MeshPipelineTestInfo) { let (_depth_image, depth_view, depth_state) = create_depth(device); let test_hash = hash_testing_context(ctx).to_string(); - let (task, mesh, frag) = get_shaders(device, backend, &test_hash, &info); + let (task, mesh, frag, ts_name, ms_name, fs_name) = + get_shaders(device, backend, &test_hash, &info); let layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor { label: None, bind_group_layouts: &[], @@ -202,17 +215,17 @@ fn mesh_pipeline_build(ctx: &TestingContext, info: MeshPipelineTestInfo) { layout: Some(&layout), task: task.as_ref().map(|task| wgpu::TaskState { module: task, - entry_point: Some("main"), + entry_point: Some(ts_name), compilation_options: Default::default(), }), mesh: wgpu::MeshState { module: &mesh, - entry_point: Some("main"), + entry_point: Some(ms_name), compilation_options: Default::default(), }, fragment: frag.as_ref().map(|frag| wgpu::FragmentState { module: frag, - entry_point: Some("main"), + entry_point: Some(fs_name), targets: &[], compilation_options: Default::default(), }), @@ -273,11 +286,11 @@ fn mesh_draw(ctx: &TestingContext, draw_type: DrawType) { let test_hash = hash_testing_context(ctx).to_string(); let info = MeshPipelineTestInfo { use_task: true, - use_mesh: true, use_frag: true, draw: true, }; - let (task, mesh, frag) = get_shaders(device, backend, &test_hash, &info); + let (task, mesh, frag, ts_name, ms_name, fs_name) = + get_shaders(device, backend, &test_hash, &info); let task = task.unwrap(); let frag = frag.unwrap(); let layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor { @@ -290,17 +303,17 @@ fn mesh_draw(ctx: &TestingContext, draw_type: DrawType) { layout: Some(&layout), task: Some(wgpu::TaskState { module: &task, - entry_point: Some("main"), + entry_point: Some(ts_name), compilation_options: Default::default(), }), mesh: wgpu::MeshState { module: &mesh, - entry_point: Some("main"), + entry_point: Some(ms_name), compilation_options: Default::default(), }, fragment: Some(wgpu::FragmentState { module: &frag, - entry_point: Some("main"), + entry_point: Some(fs_name), targets: &[], compilation_options: Default::default(), }), @@ -400,7 +413,6 @@ pub static MESH_PIPELINE_BASIC_MESH: GpuTestConfiguration = &ctx, MeshPipelineTestInfo { use_task: false, - use_mesh: true, use_frag: false, draw: true, }, @@ -413,7 +425,6 @@ pub static MESH_PIPELINE_BASIC_TASK_MESH: GpuTestConfiguration = &ctx, MeshPipelineTestInfo { use_task: true, - use_mesh: true, use_frag: false, draw: true, }, @@ -426,7 +437,6 @@ pub static MESH_PIPELINE_BASIC_MESH_FRAG: GpuTestConfiguration = &ctx, MeshPipelineTestInfo { use_task: false, - use_mesh: true, use_frag: true, draw: true, }, @@ -439,7 +449,6 @@ pub static MESH_PIPELINE_BASIC_TASK_MESH_FRAG: GpuTestConfiguration = &ctx, MeshPipelineTestInfo { use_task: true, - use_mesh: true, use_frag: true, draw: true, }, @@ -452,7 +461,6 @@ pub static MESH_PIPELINE_BASIC_MESH_NO_DRAW: GpuTestConfiguration = &ctx, MeshPipelineTestInfo { use_task: false, - use_mesh: true, use_frag: false, draw: false, }, @@ -465,7 +473,6 @@ pub static MESH_PIPELINE_BASIC_TASK_MESH_FRAG_NO_DRAW: GpuTestConfiguration = &ctx, MeshPipelineTestInfo { use_task: true, - use_mesh: true, use_frag: true, draw: false, }, @@ -488,30 +495,3 @@ pub static MESH_MULTI_DRAW_INDIRECT_COUNT: GpuTestConfiguration = default_gpu_test_config(DrawType::MultiIndirectCount).run_sync(|ctx| { mesh_draw(&ctx, DrawType::MultiIndirectCount); }); - -/// When the mesh shading feature is disabled, calls to `create_mesh_pipeline` -/// should be rejected. This should be the case on all backends, not just the -/// ones where the feature could be turned on. -#[gpu_test] -pub static MESH_DISABLED: GpuTestConfiguration = GpuTestConfiguration::new().run_sync(|ctx| { - fail( - &ctx.device, - || { - mesh_pipeline_build( - &ctx, - MeshPipelineTestInfo { - use_task: false, - use_mesh: false, - use_frag: false, - draw: true, - }, - ); - }, - Some(concat![ - "Features Features { ", - "features_wgpu: FeaturesWGPU(EXPERIMENTAL_MESH_SHADER), ", - "features_webgpu: FeaturesWebGPU(0x0) ", - "} are required but not enabled on the device", - ]), - ) -}); diff --git a/tests/tests/wgpu-gpu/mesh_shader/shader.wgsl b/tests/tests/wgpu-gpu/mesh_shader/shader.wgsl new file mode 100644 index 00000000000..cdc7366b415 --- /dev/null +++ b/tests/tests/wgpu-gpu/mesh_shader/shader.wgsl @@ -0,0 +1,74 @@ +enable mesh_shading; + +const positions = array( + vec4(0., 1., 0., 1.), + vec4(-1., -1., 0., 1.), + vec4(1., -1., 0., 1.) +); +const colors = array( + vec4(0., 1., 0., 1.), + vec4(0., 0., 1., 1.), + vec4(1., 0., 0., 1.) +); +struct TaskPayload { + colorMask: vec4, + visible: bool, +} +var taskPayload: TaskPayload; +var workgroupData: f32; +struct VertexOutput { + @builtin(position) position: vec4, + @location(0) color: vec4, +} +struct PrimitiveOutput { + @builtin(triangle_indices) index: vec3, + @builtin(cull_primitive) cull: bool, + @per_primitive @location(1) colorMask: vec4, +} +struct PrimitiveInput { + @per_primitive @location(1) colorMask: vec4, +} + +@task +@payload(taskPayload) +@workgroup_size(1) +fn ts_main() -> @builtin(mesh_task_size) vec3 { + workgroupData = 1.0; + taskPayload.colorMask = vec4(1.0, 1.0, 0.0, 1.0); + taskPayload.visible = true; + return vec3(3, 1, 1); +} + +struct MeshOutput { + @builtin(vertices) vertices: array, + @builtin(primitives) primitives: array, + @builtin(vertex_count) vertex_count: u32, + @builtin(primitive_count) primitive_count: u32, +} + +var mesh_output: MeshOutput; +@mesh(mesh_output) +@payload(taskPayload) +@workgroup_size(1) +fn ms_main(@builtin(local_invocation_index) index: u32, @builtin(global_invocation_id) id: vec3) { + mesh_output.vertex_count = 3; + mesh_output.primitive_count = 1; + workgroupData = 2.0; + + mesh_output.vertices[0].position = positions[0]; + mesh_output.vertices[0].color = colors[0] * taskPayload.colorMask; + + mesh_output.vertices[1].position = positions[1]; + mesh_output.vertices[1].color = colors[1] * taskPayload.colorMask; + + mesh_output.vertices[2].position = positions[2]; + mesh_output.vertices[2].color = colors[2] * taskPayload.colorMask; + + mesh_output.primitives[0].index = vec3(0, 1, 2); + mesh_output.primitives[0].cull = !taskPayload.visible; + mesh_output.primitives[0].colorMask = vec4(1.0, 0.0, 1.0, 1.0); +} +@fragment +fn fs_main(vertex: VertexOutput, primitive: PrimitiveInput) -> @location(0) vec4 { + return vertex.color * primitive.colorMask; +} diff --git a/wgpu-core/src/device/mod.rs b/wgpu-core/src/device/mod.rs index ec5203f291b..c9e77e8ba45 100644 --- a/wgpu-core/src/device/mod.rs +++ b/wgpu-core/src/device/mod.rs @@ -514,6 +514,10 @@ pub fn create_validator( Caps::SHADER_BARYCENTRICS, features.intersects(wgt::Features::SHADER_BARYCENTRICS), ); + caps.set( + Caps::MESH_SHADER, + features.intersects(wgt::Features::EXPERIMENTAL_MESH_SHADER), + ); naga::valid::Validator::new(flags, caps) } diff --git a/wgpu-hal/src/metal/adapter.rs b/wgpu-hal/src/metal/adapter.rs index fa3d2fa8d41..369fd5a885c 100644 --- a/wgpu-hal/src/metal/adapter.rs +++ b/wgpu-hal/src/metal/adapter.rs @@ -607,6 +607,9 @@ impl super::PrivateCapabilities { let argument_buffers = device.argument_buffers_support(); + // Lmao + let is_virtual = device.name().to_lowercase().contains("virtual"); + Self { family_check, msl_version: if os_is_xr || version.at_least((14, 0), (17, 0), os_is_mac) { @@ -902,6 +905,12 @@ impl super::PrivateCapabilities { && (device.supports_family(MTLGPUFamily::Apple7) || device.supports_family(MTLGPUFamily::Mac2)), supports_shared_event: version.at_least((10, 14), (12, 0), os_is_mac), + mesh_shaders: family_check + && (device.supports_family(MTLGPUFamily::Metal3) + || device.supports_family(MTLGPUFamily::Apple7) + || device.supports_family(MTLGPUFamily::Mac2)) + // Mesh shaders don't work on virtual devices even if they should be supported. + && !is_virtual, supported_vertex_amplification_factor: { let mut factor = 1; // https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf#page=8 @@ -1023,6 +1032,8 @@ impl super::PrivateCapabilities { features.insert(F::SUBGROUP | F::SUBGROUP_BARRIER); } + features.set(F::EXPERIMENTAL_MESH_SHADER, self.mesh_shaders); + if self.supported_vertex_amplification_factor > 1 { features.insert(F::MULTIVIEW); } @@ -1102,10 +1113,11 @@ impl super::PrivateCapabilities { max_buffer_size: self.max_buffer_size, max_non_sampler_bindings: u32::MAX, - max_task_workgroup_total_count: 0, - max_task_workgroups_per_dimension: 0, + // See https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf, Maximum threadgroups per mesh shader grid + max_task_workgroup_total_count: 1024, + max_task_workgroups_per_dimension: 1024, max_mesh_multiview_view_count: 0, - max_mesh_output_layers: 0, + max_mesh_output_layers: self.max_texture_layers as u32, max_blas_primitive_count: 0, // When added: 2^28 from https://developer.apple.com/documentation/metal/mtlaccelerationstructureusage/extendedlimits max_blas_geometry_count: 0, // When added: 2^24 diff --git a/wgpu-hal/src/metal/command.rs b/wgpu-hal/src/metal/command.rs index ebabbd4c756..ec3089d1028 100644 --- a/wgpu-hal/src/metal/command.rs +++ b/wgpu-hal/src/metal/command.rs @@ -22,11 +22,9 @@ impl Default for super::CommandState { compute: None, raw_primitive_type: MTLPrimitiveType::Point, index: None, - raw_wg_size: MTLSize::new(0, 0, 0), stage_infos: Default::default(), storage_buffer_length_map: Default::default(), vertex_buffer_size_map: Default::default(), - work_group_memory_sizes: Vec::new(), push_constants: Vec::new(), pending_timer_queries: Vec::new(), } @@ -146,6 +144,127 @@ impl super::CommandEncoder { self.state.reset(); self.leave_blit(); } + + /// Updates the bindings for a single shader stage, called in `set_bind_group`. + #[expect(clippy::too_many_arguments)] + fn update_bind_group_state( + &mut self, + stage: naga::ShaderStage, + render_encoder: Option<&metal::RenderCommandEncoder>, + compute_encoder: Option<&metal::ComputeCommandEncoder>, + index_base: super::ResourceData, + bg_info: &super::BindGroupLayoutInfo, + dynamic_offsets: &[wgt::DynamicOffset], + group_index: u32, + group: &super::BindGroup, + ) { + let resource_indices = match stage { + naga::ShaderStage::Vertex => &bg_info.base_resource_indices.vs, + naga::ShaderStage::Fragment => &bg_info.base_resource_indices.fs, + naga::ShaderStage::Task => &bg_info.base_resource_indices.ts, + naga::ShaderStage::Mesh => &bg_info.base_resource_indices.ms, + naga::ShaderStage::Compute => &bg_info.base_resource_indices.cs, + }; + let buffers = match stage { + naga::ShaderStage::Vertex => group.counters.vs.buffers, + naga::ShaderStage::Fragment => group.counters.fs.buffers, + naga::ShaderStage::Task => group.counters.ts.buffers, + naga::ShaderStage::Mesh => group.counters.ms.buffers, + naga::ShaderStage::Compute => group.counters.cs.buffers, + }; + let mut changes_sizes_buffer = false; + for index in 0..buffers { + let buf = &group.buffers[(index_base.buffers + index) as usize]; + let mut offset = buf.offset; + if let Some(dyn_index) = buf.dynamic_index { + offset += dynamic_offsets[dyn_index as usize] as wgt::BufferAddress; + } + let a1 = (resource_indices.buffers + index) as u64; + let a2 = Some(buf.ptr.as_native()); + let a3 = offset; + match stage { + naga::ShaderStage::Vertex => render_encoder.unwrap().set_vertex_buffer(a1, a2, a3), + naga::ShaderStage::Fragment => { + render_encoder.unwrap().set_fragment_buffer(a1, a2, a3) + } + naga::ShaderStage::Task => render_encoder.unwrap().set_object_buffer(a1, a2, a3), + naga::ShaderStage::Mesh => render_encoder.unwrap().set_mesh_buffer(a1, a2, a3), + naga::ShaderStage::Compute => compute_encoder.unwrap().set_buffer(a1, a2, a3), + } + if let Some(size) = buf.binding_size { + let br = naga::ResourceBinding { + group: group_index, + binding: buf.binding_location, + }; + self.state.storage_buffer_length_map.insert(br, size); + changes_sizes_buffer = true; + } + } + if changes_sizes_buffer { + if let Some((index, sizes)) = self + .state + .make_sizes_buffer_update(stage, &mut self.temp.binding_sizes) + { + let a1 = index as _; + let a2 = (sizes.len() * WORD_SIZE) as u64; + let a3 = sizes.as_ptr().cast(); + match stage { + naga::ShaderStage::Vertex => { + render_encoder.unwrap().set_vertex_bytes(a1, a2, a3) + } + naga::ShaderStage::Fragment => { + render_encoder.unwrap().set_fragment_bytes(a1, a2, a3) + } + naga::ShaderStage::Task => render_encoder.unwrap().set_object_bytes(a1, a2, a3), + naga::ShaderStage::Mesh => render_encoder.unwrap().set_mesh_bytes(a1, a2, a3), + naga::ShaderStage::Compute => compute_encoder.unwrap().set_bytes(a1, a2, a3), + } + } + } + let samplers = match stage { + naga::ShaderStage::Vertex => group.counters.vs.samplers, + naga::ShaderStage::Fragment => group.counters.fs.samplers, + naga::ShaderStage::Task => group.counters.ts.samplers, + naga::ShaderStage::Mesh => group.counters.ms.samplers, + naga::ShaderStage::Compute => group.counters.cs.samplers, + }; + for index in 0..samplers { + let res = group.samplers[(index_base.samplers + index) as usize]; + let a1 = (resource_indices.samplers + index) as u64; + let a2 = Some(res.as_native()); + match stage { + naga::ShaderStage::Vertex => { + render_encoder.unwrap().set_vertex_sampler_state(a1, a2) + } + naga::ShaderStage::Fragment => { + render_encoder.unwrap().set_fragment_sampler_state(a1, a2) + } + naga::ShaderStage::Task => render_encoder.unwrap().set_object_sampler_state(a1, a2), + naga::ShaderStage::Mesh => render_encoder.unwrap().set_mesh_sampler_state(a1, a2), + naga::ShaderStage::Compute => compute_encoder.unwrap().set_sampler_state(a1, a2), + } + } + + let textures = match stage { + naga::ShaderStage::Vertex => group.counters.vs.textures, + naga::ShaderStage::Fragment => group.counters.fs.textures, + naga::ShaderStage::Task => group.counters.ts.textures, + naga::ShaderStage::Mesh => group.counters.ms.textures, + naga::ShaderStage::Compute => group.counters.cs.textures, + }; + for index in 0..textures { + let res = group.textures[(index_base.textures + index) as usize]; + let a1 = (resource_indices.textures + index) as u64; + let a2 = Some(res.as_native()); + match stage { + naga::ShaderStage::Vertex => render_encoder.unwrap().set_vertex_texture(a1, a2), + naga::ShaderStage::Fragment => render_encoder.unwrap().set_fragment_texture(a1, a2), + naga::ShaderStage::Task => render_encoder.unwrap().set_object_texture(a1, a2), + naga::ShaderStage::Mesh => render_encoder.unwrap().set_mesh_texture(a1, a2), + naga::ShaderStage::Compute => compute_encoder.unwrap().set_texture(a1, a2), + } + } + } } impl super::CommandState { @@ -155,7 +274,8 @@ impl super::CommandState { self.stage_infos.vs.clear(); self.stage_infos.fs.clear(); self.stage_infos.cs.clear(); - self.work_group_memory_sizes.clear(); + self.stage_infos.ts.clear(); + self.stage_infos.ms.clear(); self.push_constants.clear(); } @@ -702,168 +822,86 @@ impl crate::CommandEncoder for super::CommandEncoder { dynamic_offsets: &[wgt::DynamicOffset], ) { let bg_info = &layout.bind_group_infos[group_index as usize]; - - if let Some(ref encoder) = self.state.render { - let mut changes_sizes_buffer = false; - for index in 0..group.counters.vs.buffers { - let buf = &group.buffers[index as usize]; - let mut offset = buf.offset; - if let Some(dyn_index) = buf.dynamic_index { - offset += dynamic_offsets[dyn_index as usize] as wgt::BufferAddress; - } - encoder.set_vertex_buffer( - (bg_info.base_resource_indices.vs.buffers + index) as u64, - Some(buf.ptr.as_native()), - offset, - ); - if let Some(size) = buf.binding_size { - let br = naga::ResourceBinding { - group: group_index, - binding: buf.binding_location, - }; - self.state.storage_buffer_length_map.insert(br, size); - changes_sizes_buffer = true; - } - } - if changes_sizes_buffer { - if let Some((index, sizes)) = self.state.make_sizes_buffer_update( - naga::ShaderStage::Vertex, - &mut self.temp.binding_sizes, - ) { - encoder.set_vertex_bytes( - index as _, - (sizes.len() * WORD_SIZE) as u64, - sizes.as_ptr().cast(), - ); - } - } - - changes_sizes_buffer = false; - for index in 0..group.counters.fs.buffers { - let buf = &group.buffers[(group.counters.vs.buffers + index) as usize]; - let mut offset = buf.offset; - if let Some(dyn_index) = buf.dynamic_index { - offset += dynamic_offsets[dyn_index as usize] as wgt::BufferAddress; - } - encoder.set_fragment_buffer( - (bg_info.base_resource_indices.fs.buffers + index) as u64, - Some(buf.ptr.as_native()), - offset, + let render_encoder = self.state.render.clone(); + let compute_encoder = self.state.compute.clone(); + let mut update_stage = + |stage: naga::ShaderStage, + render_encoder: Option<&metal::RenderCommandEncoder>, + compute_encoder: Option<&metal::ComputeCommandEncoder>, + index_base: super::ResourceData| { + self.update_bind_group_state( + stage, + render_encoder, + compute_encoder, + index_base, + bg_info, + dynamic_offsets, + group_index, + group, ); - if let Some(size) = buf.binding_size { - let br = naga::ResourceBinding { - group: group_index, - binding: buf.binding_location, - }; - self.state.storage_buffer_length_map.insert(br, size); - changes_sizes_buffer = true; - } - } - if changes_sizes_buffer { - if let Some((index, sizes)) = self.state.make_sizes_buffer_update( - naga::ShaderStage::Fragment, - &mut self.temp.binding_sizes, - ) { - encoder.set_fragment_bytes( - index as _, - (sizes.len() * WORD_SIZE) as u64, - sizes.as_ptr().cast(), - ); - } - } - - for index in 0..group.counters.vs.samplers { - let res = group.samplers[index as usize]; - encoder.set_vertex_sampler_state( - (bg_info.base_resource_indices.vs.samplers + index) as u64, - Some(res.as_native()), - ); - } - for index in 0..group.counters.fs.samplers { - let res = group.samplers[(group.counters.vs.samplers + index) as usize]; - encoder.set_fragment_sampler_state( - (bg_info.base_resource_indices.fs.samplers + index) as u64, - Some(res.as_native()), - ); - } - - for index in 0..group.counters.vs.textures { - let res = group.textures[index as usize]; - encoder.set_vertex_texture( - (bg_info.base_resource_indices.vs.textures + index) as u64, - Some(res.as_native()), - ); - } - for index in 0..group.counters.fs.textures { - let res = group.textures[(group.counters.vs.textures + index) as usize]; - encoder.set_fragment_texture( - (bg_info.base_resource_indices.fs.textures + index) as u64, - Some(res.as_native()), - ); - } - + }; + if let Some(encoder) = render_encoder { + update_stage( + naga::ShaderStage::Vertex, + Some(&encoder), + None, + // All zeros, as vs comes first + super::ResourceData::default(), + ); + update_stage( + naga::ShaderStage::Task, + Some(&encoder), + None, + // All zeros, as ts comes first + super::ResourceData::default(), + ); + update_stage( + naga::ShaderStage::Mesh, + Some(&encoder), + None, + group.counters.ts.clone(), + ); + update_stage( + naga::ShaderStage::Fragment, + Some(&encoder), + None, + super::ResourceData { + buffers: group.counters.vs.buffers + + group.counters.ts.buffers + + group.counters.ms.buffers, + textures: group.counters.vs.textures + + group.counters.ts.textures + + group.counters.ms.textures, + samplers: group.counters.vs.samplers + + group.counters.ts.samplers + + group.counters.ms.samplers, + }, + ); // Call useResource on all textures and buffers used indirectly so they are alive for (resource, use_info) in group.resources_to_use.iter() { encoder.use_resource_at(resource.as_native(), use_info.uses, use_info.stages); } } - - if let Some(ref encoder) = self.state.compute { - let index_base = super::ResourceData { - buffers: group.counters.vs.buffers + group.counters.fs.buffers, - samplers: group.counters.vs.samplers + group.counters.fs.samplers, - textures: group.counters.vs.textures + group.counters.fs.textures, - }; - - let mut changes_sizes_buffer = false; - for index in 0..group.counters.cs.buffers { - let buf = &group.buffers[(index_base.buffers + index) as usize]; - let mut offset = buf.offset; - if let Some(dyn_index) = buf.dynamic_index { - offset += dynamic_offsets[dyn_index as usize] as wgt::BufferAddress; - } - encoder.set_buffer( - (bg_info.base_resource_indices.cs.buffers + index) as u64, - Some(buf.ptr.as_native()), - offset, - ); - if let Some(size) = buf.binding_size { - let br = naga::ResourceBinding { - group: group_index, - binding: buf.binding_location, - }; - self.state.storage_buffer_length_map.insert(br, size); - changes_sizes_buffer = true; - } - } - if changes_sizes_buffer { - if let Some((index, sizes)) = self.state.make_sizes_buffer_update( - naga::ShaderStage::Compute, - &mut self.temp.binding_sizes, - ) { - encoder.set_bytes( - index as _, - (sizes.len() * WORD_SIZE) as u64, - sizes.as_ptr().cast(), - ); - } - } - - for index in 0..group.counters.cs.samplers { - let res = group.samplers[(index_base.samplers + index) as usize]; - encoder.set_sampler_state( - (bg_info.base_resource_indices.cs.samplers + index) as u64, - Some(res.as_native()), - ); - } - for index in 0..group.counters.cs.textures { - let res = group.textures[(index_base.textures + index) as usize]; - encoder.set_texture( - (bg_info.base_resource_indices.cs.textures + index) as u64, - Some(res.as_native()), - ); - } - + if let Some(encoder) = compute_encoder { + update_stage( + naga::ShaderStage::Compute, + None, + Some(&encoder), + super::ResourceData { + buffers: group.counters.vs.buffers + + group.counters.ts.buffers + + group.counters.ms.buffers + + group.counters.fs.buffers, + textures: group.counters.vs.textures + + group.counters.ts.textures + + group.counters.ms.textures + + group.counters.fs.textures, + samplers: group.counters.vs.samplers + + group.counters.ts.samplers + + group.counters.ms.samplers + + group.counters.fs.samplers, + }, + ); // Call useResource on all textures and buffers used indirectly so they are alive for (resource, use_info) in group.resources_to_use.iter() { if !use_info.visible_in_compute { @@ -911,6 +949,20 @@ impl crate::CommandEncoder for super::CommandEncoder { state_pc.as_ptr().cast(), ) } + if stages.contains(wgt::ShaderStages::TASK) { + self.state.render.as_ref().unwrap().set_object_bytes( + layout.push_constants_infos.ts.unwrap().buffer_index as _, + (layout.total_push_constants as usize * WORD_SIZE) as _, + state_pc.as_ptr().cast(), + ) + } + if stages.contains(wgt::ShaderStages::MESH) { + self.state.render.as_ref().unwrap().set_object_bytes( + layout.push_constants_infos.ms.unwrap().buffer_index as _, + (layout.total_push_constants as usize * WORD_SIZE) as _, + state_pc.as_ptr().cast(), + ) + } } unsafe fn insert_debug_marker(&mut self, label: &str) { @@ -935,11 +987,22 @@ impl crate::CommandEncoder for super::CommandEncoder { unsafe fn set_render_pipeline(&mut self, pipeline: &super::RenderPipeline) { self.state.raw_primitive_type = pipeline.raw_primitive_type; - self.state.stage_infos.vs.assign_from(&pipeline.vs_info); + match pipeline.vs_info { + Some(ref info) => self.state.stage_infos.vs.assign_from(info), + None => self.state.stage_infos.vs.clear(), + } match pipeline.fs_info { Some(ref info) => self.state.stage_infos.fs.assign_from(info), None => self.state.stage_infos.fs.clear(), } + match pipeline.ts_info { + Some(ref info) => self.state.stage_infos.ts.assign_from(info), + None => self.state.stage_infos.ts.clear(), + } + match pipeline.ms_info { + Some(ref info) => self.state.stage_infos.ms.assign_from(info), + None => self.state.stage_infos.ms.clear(), + } let encoder = self.state.render.as_ref().unwrap(); encoder.set_render_pipeline_state(&pipeline.raw); @@ -954,7 +1017,7 @@ impl crate::CommandEncoder for super::CommandEncoder { encoder.set_depth_bias(bias.constant as f32, bias.slope_scale, bias.clamp); } - { + if pipeline.vs_info.is_some() { if let Some((index, sizes)) = self .state .make_sizes_buffer_update(naga::ShaderStage::Vertex, &mut self.temp.binding_sizes) @@ -966,7 +1029,7 @@ impl crate::CommandEncoder for super::CommandEncoder { ); } } - if pipeline.fs_lib.is_some() { + if pipeline.fs_info.is_some() { if let Some((index, sizes)) = self .state .make_sizes_buffer_update(naga::ShaderStage::Fragment, &mut self.temp.binding_sizes) @@ -978,6 +1041,56 @@ impl crate::CommandEncoder for super::CommandEncoder { ); } } + if let Some(ts_info) = &pipeline.ts_info { + // update the threadgroup memory sizes + while self.state.stage_infos.ms.work_group_memory_sizes.len() + < ts_info.work_group_memory_sizes.len() + { + self.state.stage_infos.ms.work_group_memory_sizes.push(0); + } + for (index, (cur_size, pipeline_size)) in self + .state + .stage_infos + .ms + .work_group_memory_sizes + .iter_mut() + .zip(ts_info.work_group_memory_sizes.iter()) + .enumerate() + { + let size = pipeline_size.next_multiple_of(16); + if *cur_size != size { + *cur_size = size; + encoder.set_object_threadgroup_memory_length(index as _, size as _); + } + } + if let Some((index, sizes)) = self + .state + .make_sizes_buffer_update(naga::ShaderStage::Task, &mut self.temp.binding_sizes) + { + encoder.set_object_bytes( + index as _, + (sizes.len() * WORD_SIZE) as u64, + sizes.as_ptr().cast(), + ); + } + } + if let Some(_ms_info) = &pipeline.ms_info { + // So there isn't an equivalent to + // https://developer.apple.com/documentation/metal/mtlrendercommandencoder/setthreadgroupmemorylength(_:offset:index:) + // for mesh shaders. This is probably because the CPU has less control over the dispatch sizes and such. Interestingly + // it also affects mesh shaders without task/object shaders, even though none of compute, task or fragment shaders + // behave this way. + if let Some((index, sizes)) = self + .state + .make_sizes_buffer_update(naga::ShaderStage::Mesh, &mut self.temp.binding_sizes) + { + encoder.set_mesh_bytes( + index as _, + (sizes.len() * WORD_SIZE) as u64, + sizes.as_ptr().cast(), + ); + } + } } unsafe fn set_index_buffer<'a>( @@ -1140,11 +1253,21 @@ impl crate::CommandEncoder for super::CommandEncoder { unsafe fn draw_mesh_tasks( &mut self, - _group_count_x: u32, - _group_count_y: u32, - _group_count_z: u32, + group_count_x: u32, + group_count_y: u32, + group_count_z: u32, ) { - unreachable!() + let encoder = self.state.render.as_ref().unwrap(); + let size = MTLSize { + width: group_count_x as u64, + height: group_count_y as u64, + depth: group_count_z as u64, + }; + encoder.draw_mesh_threadgroups( + size, + self.state.stage_infos.ts.raw_wg_size, + self.state.stage_infos.ms.raw_wg_size, + ); } unsafe fn draw_indirect( @@ -1183,11 +1306,20 @@ impl crate::CommandEncoder for super::CommandEncoder { unsafe fn draw_mesh_tasks_indirect( &mut self, - _buffer: &::Buffer, - _offset: wgt::BufferAddress, - _draw_count: u32, + buffer: &::Buffer, + mut offset: wgt::BufferAddress, + draw_count: u32, ) { - unreachable!() + let encoder = self.state.render.as_ref().unwrap(); + for _ in 0..draw_count { + encoder.draw_mesh_threadgroups_with_indirect_buffer( + &buffer.raw, + offset, + self.state.stage_infos.ts.raw_wg_size, + self.state.stage_infos.ms.raw_wg_size, + ); + offset += size_of::() as wgt::BufferAddress; + } } unsafe fn draw_indirect_count( @@ -1219,7 +1351,7 @@ impl crate::CommandEncoder for super::CommandEncoder { _count_offset: wgt::BufferAddress, _max_count: u32, ) { - unreachable!() + //TODO } // compute @@ -1295,7 +1427,8 @@ impl crate::CommandEncoder for super::CommandEncoder { } unsafe fn set_compute_pipeline(&mut self, pipeline: &super::ComputePipeline) { - self.state.raw_wg_size = pipeline.work_group_size; + let previous_sizes = + core::mem::take(&mut self.state.stage_infos.cs.work_group_memory_sizes); self.state.stage_infos.cs.assign_from(&pipeline.cs_info); let encoder = self.state.compute.as_ref().unwrap(); @@ -1313,20 +1446,23 @@ impl crate::CommandEncoder for super::CommandEncoder { } // update the threadgroup memory sizes - while self.state.work_group_memory_sizes.len() < pipeline.work_group_memory_sizes.len() { - self.state.work_group_memory_sizes.push(0); - } - for (index, (cur_size, pipeline_size)) in self + for (i, current_size) in self .state + .stage_infos + .cs .work_group_memory_sizes .iter_mut() - .zip(pipeline.work_group_memory_sizes.iter()) .enumerate() { - let size = pipeline_size.next_multiple_of(16); - if *cur_size != size { - *cur_size = size; - encoder.set_threadgroup_memory_length(index as _, size as _); + let prev_size = if i < previous_sizes.len() { + previous_sizes[i] + } else { + u32::MAX + }; + let size: u32 = current_size.next_multiple_of(16); + *current_size = size; + if size != prev_size { + encoder.set_threadgroup_memory_length(i as _, size as _); } } } @@ -1339,13 +1475,17 @@ impl crate::CommandEncoder for super::CommandEncoder { height: count[1] as u64, depth: count[2] as u64, }; - encoder.dispatch_thread_groups(raw_count, self.state.raw_wg_size); + encoder.dispatch_thread_groups(raw_count, self.state.stage_infos.cs.raw_wg_size); } } unsafe fn dispatch_indirect(&mut self, buffer: &super::Buffer, offset: wgt::BufferAddress) { let encoder = self.state.compute.as_ref().unwrap(); - encoder.dispatch_thread_groups_indirect(&buffer.raw, offset, self.state.raw_wg_size); + encoder.dispatch_thread_groups_indirect( + &buffer.raw, + offset, + self.state.stage_infos.cs.raw_wg_size, + ); } unsafe fn build_acceleration_structures<'a, T>( diff --git a/wgpu-hal/src/metal/device.rs b/wgpu-hal/src/metal/device.rs index 91533453569..26b4e2239cb 100644 --- a/wgpu-hal/src/metal/device.rs +++ b/wgpu-hal/src/metal/device.rs @@ -18,6 +18,11 @@ use metal::{ type DeviceResult = Result; +enum MetalGenericRenderPipelineDescriptor { + Standard(metal::RenderPipelineDescriptor), + Mesh(metal::MeshRenderPipelineDescriptor), +} + struct CompiledShader { library: metal::Library, function: metal::Function, @@ -1113,96 +1118,243 @@ impl crate::Device for super::Device { super::PipelineCache, >, ) -> Result { - let (desc_vertex_stage, desc_vertex_buffers) = match &desc.vertex_processor { - crate::VertexProcessor::Standard { - vertex_buffers, - vertex_stage, - } => (vertex_stage, *vertex_buffers), - crate::VertexProcessor::Mesh { .. } => unreachable!(), - }; - objc::rc::autoreleasepool(|| { - let descriptor = metal::RenderPipelineDescriptor::new(); - - let raw_triangle_fill_mode = match desc.primitive.polygon_mode { - wgt::PolygonMode::Fill => MTLTriangleFillMode::Fill, - wgt::PolygonMode::Line => MTLTriangleFillMode::Lines, - wgt::PolygonMode::Point => panic!( - "{:?} is not enabled for this backend", - wgt::Features::POLYGON_MODE_POINT - ), - }; - let (primitive_class, raw_primitive_type) = conv::map_primitive_topology(desc.primitive.topology); - // Vertex shader - let (vs_lib, vs_info) = { - let mut vertex_buffer_mappings = Vec::::new(); - for (i, vbl) in desc_vertex_buffers.iter().enumerate() { - let mut attributes = Vec::::new(); - for attribute in vbl.attributes.iter() { - attributes.push(naga::back::msl::AttributeMapping { - shader_location: attribute.shader_location, - offset: attribute.offset as u32, - format: convert_vertex_format_to_naga(attribute.format), + let vs_info; + let ts_info; + let ms_info; + + // Create the pipeline descriptor and do vertex/mesh pipeline specific setup + let descriptor = match desc.vertex_processor { + crate::VertexProcessor::Standard { + vertex_buffers, + ref vertex_stage, + } => { + // Vertex pipeline specific setup + + let descriptor = metal::RenderPipelineDescriptor::new(); + ts_info = None; + ms_info = None; + + // Collect vertex buffer mappings + let mut vertex_buffer_mappings = + Vec::::new(); + for (i, vbl) in vertex_buffers.iter().enumerate() { + let mut attributes = Vec::::new(); + for attribute in vbl.attributes.iter() { + attributes.push(naga::back::msl::AttributeMapping { + shader_location: attribute.shader_location, + offset: attribute.offset as u32, + format: convert_vertex_format_to_naga(attribute.format), + }); + } + + let mapping = naga::back::msl::VertexBufferMapping { + id: self.shared.private_caps.max_vertex_buffers - 1 - i as u32, + stride: if vbl.array_stride > 0 { + vbl.array_stride.try_into().unwrap() + } else { + vbl.attributes + .iter() + .map(|attribute| attribute.offset + attribute.format.size()) + .max() + .unwrap_or(0) + .try_into() + .unwrap() + }, + step_mode: match (vbl.array_stride == 0, vbl.step_mode) { + (true, _) => naga::back::msl::VertexBufferStepMode::Constant, + (false, wgt::VertexStepMode::Vertex) => { + naga::back::msl::VertexBufferStepMode::ByVertex + } + (false, wgt::VertexStepMode::Instance) => { + naga::back::msl::VertexBufferStepMode::ByInstance + } + }, + attributes, + }; + vertex_buffer_mappings.push(mapping); + } + + // Setup vertex shader + { + let vs = self.load_shader( + vertex_stage, + &vertex_buffer_mappings, + desc.layout, + primitive_class, + naga::ShaderStage::Vertex, + )?; + + descriptor.set_vertex_function(Some(&vs.function)); + if self.shared.private_caps.supports_mutability { + Self::set_buffers_mutability( + descriptor.vertex_buffers().unwrap(), + vs.immutable_buffer_mask, + ); + } + + vs_info = Some(super::PipelineStageInfo { + push_constants: desc.layout.push_constants_infos.vs, + sizes_slot: desc.layout.per_stage_map.vs.sizes_buffer, + sized_bindings: vs.sized_bindings, + vertex_buffer_mappings, + library: Some(vs.library), + raw_wg_size: Default::default(), + work_group_memory_sizes: vec![], }); } - vertex_buffer_mappings.push(naga::back::msl::VertexBufferMapping { - id: self.shared.private_caps.max_vertex_buffers - 1 - i as u32, - stride: if vbl.array_stride > 0 { - vbl.array_stride.try_into().unwrap() - } else { - vbl.attributes - .iter() - .map(|attribute| attribute.offset + attribute.format.size()) - .max() - .unwrap_or(0) - .try_into() - .unwrap() - }, - step_mode: match (vbl.array_stride == 0, vbl.step_mode) { - (true, _) => naga::back::msl::VertexBufferStepMode::Constant, - (false, wgt::VertexStepMode::Vertex) => { - naga::back::msl::VertexBufferStepMode::ByVertex + // Validate vertex buffer count + if desc.layout.total_counters.vs.buffers + (vertex_buffers.len() as u32) + > self.shared.private_caps.max_vertex_buffers + { + let msg = format!( + "pipeline needs too many buffers in the vertex stage: {} vertex and {} layout", + vertex_buffers.len(), + desc.layout.total_counters.vs.buffers + ); + return Err(crate::PipelineError::Linkage( + wgt::ShaderStages::VERTEX, + msg, + )); + } + + // Set the pipeline vertex buffer info + if !vertex_buffers.is_empty() { + let vertex_descriptor = metal::VertexDescriptor::new(); + for (i, vb) in vertex_buffers.iter().enumerate() { + let buffer_index = + self.shared.private_caps.max_vertex_buffers as u64 - 1 - i as u64; + let buffer_desc = + vertex_descriptor.layouts().object_at(buffer_index).unwrap(); + + // Metal expects the stride to be the actual size of the attributes. + // The semantics of array_stride == 0 can be achieved by setting + // the step function to constant and rate to 0. + if vb.array_stride == 0 { + let stride = vb + .attributes + .iter() + .map(|attribute| attribute.offset + attribute.format.size()) + .max() + .unwrap_or(0); + buffer_desc.set_stride(wgt::math::align_to(stride, 4)); + buffer_desc.set_step_function(MTLVertexStepFunction::Constant); + buffer_desc.set_step_rate(0); + } else { + buffer_desc.set_stride(vb.array_stride); + buffer_desc.set_step_function(conv::map_step_mode(vb.step_mode)); } - (false, wgt::VertexStepMode::Instance) => { - naga::back::msl::VertexBufferStepMode::ByInstance + + for at in vb.attributes { + let attribute_desc = vertex_descriptor + .attributes() + .object_at(at.shader_location as u64) + .unwrap(); + attribute_desc.set_format(conv::map_vertex_format(at.format)); + attribute_desc.set_buffer_index(buffer_index); + attribute_desc.set_offset(at.offset); } - }, - attributes, - }); + } + descriptor.set_vertex_descriptor(Some(vertex_descriptor)); + } + + MetalGenericRenderPipelineDescriptor::Standard(descriptor) } + crate::VertexProcessor::Mesh { + ref task_stage, + ref mesh_stage, + } => { + // Mesh pipeline specific setup + + vs_info = None; + let descriptor = metal::MeshRenderPipelineDescriptor::new(); + + // Setup task stage + if let Some(ref task_stage) = task_stage { + let ts = self.load_shader( + task_stage, + &[], + desc.layout, + primitive_class, + naga::ShaderStage::Task, + )?; + descriptor.set_object_function(Some(&ts.function)); + if self.shared.private_caps.supports_mutability { + Self::set_buffers_mutability( + descriptor.mesh_buffers().unwrap(), + ts.immutable_buffer_mask, + ); + } + ts_info = Some(super::PipelineStageInfo { + push_constants: desc.layout.push_constants_infos.ts, + sizes_slot: desc.layout.per_stage_map.ts.sizes_buffer, + sized_bindings: ts.sized_bindings, + vertex_buffer_mappings: vec![], + library: Some(ts.library), + raw_wg_size: ts.wg_size, + work_group_memory_sizes: ts.wg_memory_sizes, + }); + } else { + ts_info = None; + } - let vs = self.load_shader( - desc_vertex_stage, - &vertex_buffer_mappings, - desc.layout, - primitive_class, - naga::ShaderStage::Vertex, - )?; - - descriptor.set_vertex_function(Some(&vs.function)); - if self.shared.private_caps.supports_mutability { - Self::set_buffers_mutability( - descriptor.vertex_buffers().unwrap(), - vs.immutable_buffer_mask, - ); + // Setup mesh stage + { + let ms = self.load_shader( + mesh_stage, + &[], + desc.layout, + primitive_class, + naga::ShaderStage::Mesh, + )?; + descriptor.set_mesh_function(Some(&ms.function)); + if self.shared.private_caps.supports_mutability { + Self::set_buffers_mutability( + descriptor.mesh_buffers().unwrap(), + ms.immutable_buffer_mask, + ); + } + ms_info = Some(super::PipelineStageInfo { + push_constants: desc.layout.push_constants_infos.ms, + sizes_slot: desc.layout.per_stage_map.ms.sizes_buffer, + sized_bindings: ms.sized_bindings, + vertex_buffer_mappings: vec![], + library: Some(ms.library), + raw_wg_size: ms.wg_size, + work_group_memory_sizes: ms.wg_memory_sizes, + }); + } + + MetalGenericRenderPipelineDescriptor::Mesh(descriptor) } + }; - let info = super::PipelineStageInfo { - push_constants: desc.layout.push_constants_infos.vs, - sizes_slot: desc.layout.per_stage_map.vs.sizes_buffer, - sized_bindings: vs.sized_bindings, - vertex_buffer_mappings, + // Standard and mesh render pipeline descriptors don't inherit from the same interface, despite sharing + // many methods. This function lets us call a function by name on whichever descriptor we are using. + macro_rules! descriptor_fn { + ($method:ident $( ( $($args:expr),* ) )? ) => { + match descriptor { + MetalGenericRenderPipelineDescriptor::Standard(ref inner) => inner.$method$(($($args),*))?, + MetalGenericRenderPipelineDescriptor::Mesh(ref inner) => inner.$method$(($($args),*))?, + } }; + } - (vs.library, info) + let raw_triangle_fill_mode = match desc.primitive.polygon_mode { + wgt::PolygonMode::Fill => MTLTriangleFillMode::Fill, + wgt::PolygonMode::Line => MTLTriangleFillMode::Lines, + wgt::PolygonMode::Point => panic!( + "{:?} is not enabled for this backend", + wgt::Features::POLYGON_MODE_POINT + ), }; // Fragment shader - let (fs_lib, fs_info) = match desc.fragment_stage { + let fs_info = match desc.fragment_stage { Some(ref stage) => { let fs = self.load_shader( stage, @@ -1212,35 +1364,41 @@ impl crate::Device for super::Device { naga::ShaderStage::Fragment, )?; - descriptor.set_fragment_function(Some(&fs.function)); + descriptor_fn!(set_fragment_function(Some(&fs.function))); if self.shared.private_caps.supports_mutability { Self::set_buffers_mutability( - descriptor.fragment_buffers().unwrap(), + descriptor_fn!(fragment_buffers()).unwrap(), fs.immutable_buffer_mask, ); } - let info = super::PipelineStageInfo { + Some(super::PipelineStageInfo { push_constants: desc.layout.push_constants_infos.fs, sizes_slot: desc.layout.per_stage_map.fs.sizes_buffer, sized_bindings: fs.sized_bindings, vertex_buffer_mappings: vec![], - }; - - (Some(fs.library), Some(info)) + library: Some(fs.library), + raw_wg_size: Default::default(), + work_group_memory_sizes: vec![], + }) } None => { // TODO: This is a workaround for what appears to be a Metal validation bug // A pixel format is required even though no attachments are provided if desc.color_targets.is_empty() && desc.depth_stencil.is_none() { - descriptor.set_depth_attachment_pixel_format(MTLPixelFormat::Depth32Float); + descriptor_fn!(set_depth_attachment_pixel_format( + MTLPixelFormat::Depth32Float + )); } - (None, None) + None } }; + // Setup pipeline color attachments for (i, ct) in desc.color_targets.iter().enumerate() { - let at_descriptor = descriptor.color_attachments().object_at(i as u64).unwrap(); + let at_descriptor = descriptor_fn!(color_attachments()) + .object_at(i as u64) + .unwrap(); let ct = if let Some(color_target) = ct.as_ref() { color_target } else { @@ -1267,15 +1425,16 @@ impl crate::Device for super::Device { } } + // Setup depth stencil state let depth_stencil = match desc.depth_stencil { Some(ref ds) => { let raw_format = self.shared.private_caps.map_format(ds.format); let aspects = crate::FormatAspects::from(ds.format); if aspects.contains(crate::FormatAspects::DEPTH) { - descriptor.set_depth_attachment_pixel_format(raw_format); + descriptor_fn!(set_depth_attachment_pixel_format(raw_format)); } if aspects.contains(crate::FormatAspects::STENCIL) { - descriptor.set_stencil_attachment_pixel_format(raw_format); + descriptor_fn!(set_stencil_attachment_pixel_format(raw_format)); } let ds_descriptor = create_depth_stencil_desc(ds); @@ -1289,94 +1448,57 @@ impl crate::Device for super::Device { None => None, }; - if desc.layout.total_counters.vs.buffers + (desc_vertex_buffers.len() as u32) - > self.shared.private_caps.max_vertex_buffers - { - let msg = format!( - "pipeline needs too many buffers in the vertex stage: {} vertex and {} layout", - desc_vertex_buffers.len(), - desc.layout.total_counters.vs.buffers - ); - return Err(crate::PipelineError::Linkage( - wgt::ShaderStages::VERTEX, - msg, - )); - } - - if !desc_vertex_buffers.is_empty() { - let vertex_descriptor = metal::VertexDescriptor::new(); - for (i, vb) in desc_vertex_buffers.iter().enumerate() { - let buffer_index = - self.shared.private_caps.max_vertex_buffers as u64 - 1 - i as u64; - let buffer_desc = vertex_descriptor.layouts().object_at(buffer_index).unwrap(); - - // Metal expects the stride to be the actual size of the attributes. - // The semantics of array_stride == 0 can be achieved by setting - // the step function to constant and rate to 0. - if vb.array_stride == 0 { - let stride = vb - .attributes - .iter() - .map(|attribute| attribute.offset + attribute.format.size()) - .max() - .unwrap_or(0); - buffer_desc.set_stride(wgt::math::align_to(stride, 4)); - buffer_desc.set_step_function(MTLVertexStepFunction::Constant); - buffer_desc.set_step_rate(0); - } else { - buffer_desc.set_stride(vb.array_stride); - buffer_desc.set_step_function(conv::map_step_mode(vb.step_mode)); + // Setup multisample state + if desc.multisample.count != 1 { + //TODO: handle sample mask + match descriptor { + MetalGenericRenderPipelineDescriptor::Standard(ref inner) => { + inner.set_sample_count(desc.multisample.count as u64); } - - for at in vb.attributes { - let attribute_desc = vertex_descriptor - .attributes() - .object_at(at.shader_location as u64) - .unwrap(); - attribute_desc.set_format(conv::map_vertex_format(at.format)); - attribute_desc.set_buffer_index(buffer_index); - attribute_desc.set_offset(at.offset); + MetalGenericRenderPipelineDescriptor::Mesh(ref inner) => { + inner.set_raster_sample_count(desc.multisample.count as u64); } } - descriptor.set_vertex_descriptor(Some(vertex_descriptor)); - } - - if desc.multisample.count != 1 { - //TODO: handle sample mask - descriptor.set_sample_count(desc.multisample.count as u64); - descriptor - .set_alpha_to_coverage_enabled(desc.multisample.alpha_to_coverage_enabled); + descriptor_fn!(set_alpha_to_coverage_enabled( + desc.multisample.alpha_to_coverage_enabled + )); //descriptor.set_alpha_to_one_enabled(desc.multisample.alpha_to_one_enabled); } + // Set debug label if let Some(name) = desc.label { - descriptor.set_label(name); + descriptor_fn!(set_label(name)); } - if let Some(mv) = desc.multiview_mask { - descriptor.set_max_vertex_amplification_count(mv.get().count_ones() as u64); + descriptor_fn!(set_max_vertex_amplification_count( + mv.get().count_ones() as u64 + )); } - let raw = self - .shared - .device - .lock() - .new_render_pipeline_state(&descriptor) - .map_err(|e| { - crate::PipelineError::Linkage( - wgt::ShaderStages::VERTEX | wgt::ShaderStages::FRAGMENT, - format!("new_render_pipeline_state: {e:?}"), - ) - })?; + // Create the pipeline from descriptor + let raw = match descriptor { + MetalGenericRenderPipelineDescriptor::Standard(d) => { + self.shared.device.lock().new_render_pipeline_state(&d) + } + MetalGenericRenderPipelineDescriptor::Mesh(d) => { + self.shared.device.lock().new_mesh_render_pipeline_state(&d) + } + } + .map_err(|e| { + crate::PipelineError::Linkage( + wgt::ShaderStages::VERTEX | wgt::ShaderStages::FRAGMENT, + format!("new_render_pipeline_state: {e:?}"), + ) + })?; self.counters.render_pipelines.add(1); Ok(super::RenderPipeline { raw, - vs_lib, - fs_lib, vs_info, fs_info, + ts_info, + ms_info, raw_primitive_type, raw_triangle_fill_mode, raw_front_winding: conv::map_winding(desc.primitive.front_face), @@ -1444,10 +1566,13 @@ impl crate::Device for super::Device { } let cs_info = super::PipelineStageInfo { + library: Some(cs.library), push_constants: desc.layout.push_constants_infos.cs, sizes_slot: desc.layout.per_stage_map.cs.sizes_buffer, sized_bindings: cs.sized_bindings, vertex_buffer_mappings: vec![], + raw_wg_size: cs.wg_size, + work_group_memory_sizes: cs.wg_memory_sizes, }; if let Some(name) = desc.label { @@ -1468,13 +1593,7 @@ impl crate::Device for super::Device { self.counters.compute_pipelines.add(1); - Ok(super::ComputePipeline { - raw, - cs_info, - cs_lib: cs.library, - work_group_size: cs.wg_size, - work_group_memory_sizes: cs.wg_memory_sizes, - }) + Ok(super::ComputePipeline { raw, cs_info }) }) } diff --git a/wgpu-hal/src/metal/mod.rs b/wgpu-hal/src/metal/mod.rs index e23246a6a75..7258a885f25 100644 --- a/wgpu-hal/src/metal/mod.rs +++ b/wgpu-hal/src/metal/mod.rs @@ -302,6 +302,7 @@ struct PrivateCapabilities { int64_atomics: bool, float_atomics: bool, supports_shared_event: bool, + mesh_shaders: bool, supported_vertex_amplification_factor: u32, shader_barycentrics: bool, supports_memoryless_storage: bool, @@ -609,12 +610,16 @@ struct MultiStageData { vs: T, fs: T, cs: T, + ts: T, + ms: T, } const NAGA_STAGES: MultiStageData = MultiStageData { vs: naga::ShaderStage::Vertex, fs: naga::ShaderStage::Fragment, cs: naga::ShaderStage::Compute, + ts: naga::ShaderStage::Task, + ms: naga::ShaderStage::Mesh, }; impl ops::Index for MultiStageData { @@ -624,7 +629,8 @@ impl ops::Index for MultiStageData { naga::ShaderStage::Vertex => &self.vs, naga::ShaderStage::Fragment => &self.fs, naga::ShaderStage::Compute => &self.cs, - naga::ShaderStage::Task | naga::ShaderStage::Mesh => unreachable!(), + naga::ShaderStage::Task => &self.ts, + naga::ShaderStage::Mesh => &self.ms, } } } @@ -635,6 +641,8 @@ impl MultiStageData { vs: fun(&self.vs), fs: fun(&self.fs), cs: fun(&self.cs), + ts: fun(&self.ts), + ms: fun(&self.ms), } } fn map(self, fun: impl Fn(T) -> Y) -> MultiStageData { @@ -642,17 +650,23 @@ impl MultiStageData { vs: fun(self.vs), fs: fun(self.fs), cs: fun(self.cs), + ts: fun(self.ts), + ms: fun(self.ms), } } fn iter<'a>(&'a self) -> impl Iterator { iter::once(&self.vs) .chain(iter::once(&self.fs)) .chain(iter::once(&self.cs)) + .chain(iter::once(&self.ts)) + .chain(iter::once(&self.ms)) } fn iter_mut<'a>(&'a mut self) -> impl Iterator { iter::once(&mut self.vs) .chain(iter::once(&mut self.fs)) .chain(iter::once(&mut self.cs)) + .chain(iter::once(&mut self.ts)) + .chain(iter::once(&mut self.ms)) } } @@ -816,6 +830,8 @@ impl crate::DynShaderModule for ShaderModule {} #[derive(Debug, Default)] struct PipelineStageInfo { + #[allow(dead_code)] + library: Option, push_constants: Option, /// The buffer argument table index at which we pass runtime-sized arrays' buffer sizes. @@ -830,6 +846,12 @@ struct PipelineStageInfo { /// Info on all bound vertex buffers. vertex_buffer_mappings: Vec, + + /// The workgroup size for compute, task or mesh stages + raw_wg_size: MTLSize, + + /// The workgroup memory sizes for compute task or mesh stages + work_group_memory_sizes: Vec, } impl PipelineStageInfo { @@ -838,6 +860,9 @@ impl PipelineStageInfo { self.sizes_slot = None; self.sized_bindings.clear(); self.vertex_buffer_mappings.clear(); + self.library = None; + self.work_group_memory_sizes.clear(); + self.raw_wg_size = Default::default(); } fn assign_from(&mut self, other: &Self) { @@ -848,18 +873,21 @@ impl PipelineStageInfo { self.vertex_buffer_mappings.clear(); self.vertex_buffer_mappings .extend_from_slice(&other.vertex_buffer_mappings); + self.library = Some(other.library.as_ref().unwrap().clone()); + self.raw_wg_size = other.raw_wg_size; + self.work_group_memory_sizes.clear(); + self.work_group_memory_sizes + .extend_from_slice(&other.work_group_memory_sizes); } } #[derive(Debug)] pub struct RenderPipeline { raw: metal::RenderPipelineState, - #[allow(dead_code)] - vs_lib: metal::Library, - #[allow(dead_code)] - fs_lib: Option, - vs_info: PipelineStageInfo, + vs_info: Option, fs_info: Option, + ts_info: Option, + ms_info: Option, raw_primitive_type: MTLPrimitiveType, raw_triangle_fill_mode: MTLTriangleFillMode, raw_front_winding: MTLWinding, @@ -876,11 +904,7 @@ impl crate::DynRenderPipeline for RenderPipeline {} #[derive(Debug)] pub struct ComputePipeline { raw: metal::ComputePipelineState, - #[allow(dead_code)] - cs_lib: metal::Library, cs_info: PipelineStageInfo, - work_group_size: MTLSize, - work_group_memory_sizes: Vec, } unsafe impl Send for ComputePipeline {} @@ -954,7 +978,6 @@ struct CommandState { compute: Option, raw_primitive_type: MTLPrimitiveType, index: Option, - raw_wg_size: MTLSize, stage_infos: MultiStageData, /// Sizes of currently bound [`wgt::BufferBindingType::Storage`] buffers. @@ -980,7 +1003,6 @@ struct CommandState { vertex_buffer_size_map: FastHashMap, - work_group_memory_sizes: Vec, push_constants: Vec, /// Timer query that should be executed when the next pass starts. diff --git a/wgpu-types/src/features.rs b/wgpu-types/src/features.rs index 234d01d8544..e53b2b8468f 100644 --- a/wgpu-types/src/features.rs +++ b/wgpu-types/src/features.rs @@ -1169,12 +1169,11 @@ bitflags_array! { /// This is a native only feature. const UNIFORM_BUFFER_BINDING_ARRAYS = 1 << 47; - /// Enables mesh shaders and task shaders in mesh shader pipelines. + /// Enables mesh shaders and task shaders in mesh shader pipelines. This extension does NOT imply support for + /// compiling mesh shaders at runtime. Rather, the user must use custom passthrough shaders. /// /// Supported platforms: /// - Vulkan (with [VK_EXT_mesh_shader](https://registry.khronos.org/vulkan/specs/latest/man/html/VK_EXT_mesh_shader.html)) - /// - /// Potential Platforms: /// - DX12 /// - Metal /// diff --git a/wgpu-types/src/lib.rs b/wgpu-types/src/lib.rs index 33d2bf1eb71..5b02a27e64a 100644 --- a/wgpu-types/src/lib.rs +++ b/wgpu-types/src/lib.rs @@ -1057,14 +1057,12 @@ impl Limits { #[must_use] pub const fn using_recommended_minimum_mesh_shader_values(self) -> Self { Self { - // Literally just made this up as 256^2 or 2^16. - // My GPU supports 2^22, and compute shaders don't have this kind of limit. - // This very likely is never a real limiter - max_task_workgroup_total_count: 65536, - max_task_workgroups_per_dimension: 256, + // This is a common limit for apple devices. It's not immediately clear why. + max_task_workgroup_total_count: 1024, + max_task_workgroups_per_dimension: 1024, // llvmpipe reports 0 multiview count, which just means no multiview is allowed max_mesh_multiview_view_count: 0, - // llvmpipe once again requires this to be 8. An RTX 3060 supports well over 1024. + // llvmpipe once again requires this to be <=8. An RTX 3060 supports well over 1024. max_mesh_output_layers: 8, ..self }