Skip to content

Commit

Permalink
Subgroup Operations (#5301)
Browse files Browse the repository at this point in the history
Co-authored-by: Jacob Hughes <j@distanthills.org>
Co-authored-by: Connor Fitzgerald <connorwadefitzgerald@gmail.com>
Co-authored-by: atlas dostal <rodol@rivalrebels.com>
  • Loading branch information
4 people committed Apr 17, 2024
1 parent 0dc9dd6 commit ea77d56
Show file tree
Hide file tree
Showing 64 changed files with 3,328 additions and 70 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ Bottom level categories:
### Bug Fixes

#### General
- Add `SUBGROUP, SUBGROUP_VERTEX, SUBGROUP_BARRIER` features. By @exrook and @lichtso in [#5301](https://github.com/gfx-rs/wgpu/pull/5301)
- Fix `serde` feature not compiling for `wgpu-types`. By @KirmesBude in [#5149](https://github.com/gfx-rs/wgpu/pull/5149)
- Fix the validation of vertex and index ranges. By @nical in [#5144](https://github.com/gfx-rs/wgpu/pull/5144) and [#5156](https://github.com/gfx-rs/wgpu/pull/5156)
- Fix panic when creating a surface while no backend is available. By @wumpf [#5166](https://github.com/gfx-rs/wgpu/pull/5166)
Expand Down
4 changes: 4 additions & 0 deletions naga-cli/src/bin/naga.rs
Original file line number Diff line number Diff line change
Expand Up @@ -424,6 +424,8 @@ fn run() -> Result<(), Box<dyn std::error::Error>> {

// Validate the IR before compaction.
let info = match naga::valid::Validator::new(params.validation_flags, validation_caps)
.subgroup_stages(naga::valid::ShaderStages::all())
.subgroup_operations(naga::valid::SubgroupOperationSet::all())
.validate(&module)
{
Ok(info) => Some(info),
Expand Down Expand Up @@ -760,6 +762,8 @@ fn bulk_validate(args: Args, params: &Parameters) -> Result<(), Box<dyn std::err

let mut validator =
naga::valid::Validator::new(params.validation_flags, naga::valid::Capabilities::all());
validator.subgroup_stages(naga::valid::ShaderStages::all());
validator.subgroup_operations(naga::valid::SubgroupOperationSet::all());

if let Err(error) = validator.validate(&module) {
invalid.push(input_path.clone());
Expand Down
90 changes: 90 additions & 0 deletions naga/src/back/dot/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,94 @@ impl StatementGraph {
crate::RayQueryFunction::Terminate => "RayQueryTerminate",
}
}
S::SubgroupBallot { result, predicate } => {
if let Some(predicate) = predicate {
self.dependencies.push((id, predicate, "predicate"));
}
self.emits.push((id, result));
"SubgroupBallot"
}
S::SubgroupCollectiveOperation {
op,
collective_op,
argument,
result,
} => {
self.dependencies.push((id, argument, "arg"));
self.emits.push((id, result));
match (collective_op, op) {
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::All) => {
"SubgroupAll"
}
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Any) => {
"SubgroupAny"
}
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Add) => {
"SubgroupAdd"
}
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Mul) => {
"SubgroupMul"
}
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Max) => {
"SubgroupMax"
}
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Min) => {
"SubgroupMin"
}
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::And) => {
"SubgroupAnd"
}
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Or) => {
"SubgroupOr"
}
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Xor) => {
"SubgroupXor"
}
(
crate::CollectiveOperation::ExclusiveScan,
crate::SubgroupOperation::Add,
) => "SubgroupExclusiveAdd",
(
crate::CollectiveOperation::ExclusiveScan,
crate::SubgroupOperation::Mul,
) => "SubgroupExclusiveMul",
(
crate::CollectiveOperation::InclusiveScan,
crate::SubgroupOperation::Add,
) => "SubgroupInclusiveAdd",
(
crate::CollectiveOperation::InclusiveScan,
crate::SubgroupOperation::Mul,
) => "SubgroupInclusiveMul",
_ => unimplemented!(),
}
}
S::SubgroupGather {
mode,
argument,
result,
} => {
match mode {
crate::GatherMode::BroadcastFirst => {}
crate::GatherMode::Broadcast(index)
| crate::GatherMode::Shuffle(index)
| crate::GatherMode::ShuffleDown(index)
| crate::GatherMode::ShuffleUp(index)
| crate::GatherMode::ShuffleXor(index) => {
self.dependencies.push((id, index, "index"))
}
}
self.dependencies.push((id, argument, "arg"));
self.emits.push((id, result));
match mode {
crate::GatherMode::BroadcastFirst => "SubgroupBroadcastFirst",
crate::GatherMode::Broadcast(_) => "SubgroupBroadcast",
crate::GatherMode::Shuffle(_) => "SubgroupShuffle",
crate::GatherMode::ShuffleDown(_) => "SubgroupShuffleDown",
crate::GatherMode::ShuffleUp(_) => "SubgroupShuffleUp",
crate::GatherMode::ShuffleXor(_) => "SubgroupShuffleXor",
}
}
};
// Set the last node to the merge node
last_node = merge_id;
Expand Down Expand Up @@ -587,6 +675,8 @@ fn write_function_expressions(
let ty = if committed { "Committed" } else { "Candidate" };
(format!("rayQueryGet{}Intersection", ty).into(), 4)
}
E::SubgroupBallotResult => ("SubgroupBallotResult".into(), 4),
E::SubgroupOperationResult { .. } => ("SubgroupOperationResult".into(), 4),
};

// give uniform expressions an outline
Expand Down
23 changes: 23 additions & 0 deletions naga/src/back/glsl/features.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ bitflags::bitflags! {
const INSTANCE_INDEX = 1 << 22;
/// Sample specific LODs of cube / array shadow textures
const TEXTURE_SHADOW_LOD = 1 << 23;
/// Subgroup operations
const SUBGROUP_OPERATIONS = 1 << 24;
}
}

Expand Down Expand Up @@ -117,6 +119,7 @@ impl FeaturesManager {
check_feature!(SAMPLE_VARIABLES, 400, 300);
check_feature!(DYNAMIC_ARRAY_SIZE, 430, 310);
check_feature!(DUAL_SOURCE_BLENDING, 330, 300 /* with extension */);
check_feature!(SUBGROUP_OPERATIONS, 430, 310);
match version {
Version::Embedded { is_webgl: true, .. } => check_feature!(MULTI_VIEW, 140, 300),
_ => check_feature!(MULTI_VIEW, 140, 310),
Expand Down Expand Up @@ -259,6 +262,22 @@ impl FeaturesManager {
writeln!(out, "#extension GL_EXT_texture_shadow_lod : require")?;
}

if self.0.contains(Features::SUBGROUP_OPERATIONS) {
// https://registry.khronos.org/OpenGL/extensions/KHR/KHR_shader_subgroup.txt
writeln!(out, "#extension GL_KHR_shader_subgroup_basic : require")?;
writeln!(out, "#extension GL_KHR_shader_subgroup_vote : require")?;
writeln!(
out,
"#extension GL_KHR_shader_subgroup_arithmetic : require"
)?;
writeln!(out, "#extension GL_KHR_shader_subgroup_ballot : require")?;
writeln!(out, "#extension GL_KHR_shader_subgroup_shuffle : require")?;
writeln!(
out,
"#extension GL_KHR_shader_subgroup_shuffle_relative : require"
)?;
}

Ok(())
}
}
Expand Down Expand Up @@ -518,6 +537,10 @@ impl<'a, W> Writer<'a, W> {
}
}
}
Expression::SubgroupBallotResult |
Expression::SubgroupOperationResult { .. } => {
features.request(Features::SUBGROUP_OPERATIONS)
}
_ => {}
}
}
Expand Down
131 changes: 130 additions & 1 deletion naga/src/back/glsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2390,6 +2390,125 @@ impl<'a, W: Write> Writer<'a, W> {
writeln!(self.out, ");")?;
}
Statement::RayQuery { .. } => unreachable!(),
Statement::SubgroupBallot { result, predicate } => {
write!(self.out, "{level}")?;
let res_name = format!("{}{}", back::BAKE_PREFIX, result.index());
let res_ty = ctx.info[result].ty.inner_with(&self.module.types);
self.write_value_type(res_ty)?;
write!(self.out, " {res_name} = ")?;
self.named_expressions.insert(result, res_name);

write!(self.out, "subgroupBallot(")?;
match predicate {
Some(predicate) => self.write_expr(predicate, ctx)?,
None => write!(self.out, "true")?,
}
writeln!(self.out, ");")?;
}
Statement::SubgroupCollectiveOperation {
op,
collective_op,
argument,
result,
} => {
write!(self.out, "{level}")?;
let res_name = format!("{}{}", back::BAKE_PREFIX, result.index());
let res_ty = ctx.info[result].ty.inner_with(&self.module.types);
self.write_value_type(res_ty)?;
write!(self.out, " {res_name} = ")?;
self.named_expressions.insert(result, res_name);

match (collective_op, op) {
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::All) => {
write!(self.out, "subgroupAll(")?
}
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Any) => {
write!(self.out, "subgroupAny(")?
}
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Add) => {
write!(self.out, "subgroupAdd(")?
}
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Mul) => {
write!(self.out, "subgroupMul(")?
}
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Max) => {
write!(self.out, "subgroupMax(")?
}
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Min) => {
write!(self.out, "subgroupMin(")?
}
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::And) => {
write!(self.out, "subgroupAnd(")?
}
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Or) => {
write!(self.out, "subgroupOr(")?
}
(crate::CollectiveOperation::Reduce, crate::SubgroupOperation::Xor) => {
write!(self.out, "subgroupXor(")?
}
(crate::CollectiveOperation::ExclusiveScan, crate::SubgroupOperation::Add) => {
write!(self.out, "subgroupExclusiveAdd(")?
}
(crate::CollectiveOperation::ExclusiveScan, crate::SubgroupOperation::Mul) => {
write!(self.out, "subgroupExclusiveMul(")?
}
(crate::CollectiveOperation::InclusiveScan, crate::SubgroupOperation::Add) => {
write!(self.out, "subgroupInclusiveAdd(")?
}
(crate::CollectiveOperation::InclusiveScan, crate::SubgroupOperation::Mul) => {
write!(self.out, "subgroupInclusiveMul(")?
}
_ => unimplemented!(),
}
self.write_expr(argument, ctx)?;
writeln!(self.out, ");")?;
}
Statement::SubgroupGather {
mode,
argument,
result,
} => {
write!(self.out, "{level}")?;
let res_name = format!("{}{}", back::BAKE_PREFIX, result.index());
let res_ty = ctx.info[result].ty.inner_with(&self.module.types);
self.write_value_type(res_ty)?;
write!(self.out, " {res_name} = ")?;
self.named_expressions.insert(result, res_name);

match mode {
crate::GatherMode::BroadcastFirst => {
write!(self.out, "subgroupBroadcastFirst(")?;
}
crate::GatherMode::Broadcast(_) => {
write!(self.out, "subgroupBroadcast(")?;
}
crate::GatherMode::Shuffle(_) => {
write!(self.out, "subgroupShuffle(")?;
}
crate::GatherMode::ShuffleDown(_) => {
write!(self.out, "subgroupShuffleDown(")?;
}
crate::GatherMode::ShuffleUp(_) => {
write!(self.out, "subgroupShuffleUp(")?;
}
crate::GatherMode::ShuffleXor(_) => {
write!(self.out, "subgroupShuffleXor(")?;
}
}
self.write_expr(argument, ctx)?;
match mode {
crate::GatherMode::BroadcastFirst => {}
crate::GatherMode::Broadcast(index)
| crate::GatherMode::Shuffle(index)
| crate::GatherMode::ShuffleDown(index)
| crate::GatherMode::ShuffleUp(index)
| crate::GatherMode::ShuffleXor(index) => {
write!(self.out, ", ")?;
self.write_expr(index, ctx)?;
}
}
writeln!(self.out, ");")?;
}
}

Ok(())
Expand Down Expand Up @@ -3658,7 +3777,9 @@ impl<'a, W: Write> Writer<'a, W> {
Expression::CallResult(_)
| Expression::AtomicResult { .. }
| Expression::RayQueryProceedResult
| Expression::WorkGroupUniformLoadResult { .. } => unreachable!(),
| Expression::WorkGroupUniformLoadResult { .. }
| Expression::SubgroupOperationResult { .. }
| Expression::SubgroupBallotResult => unreachable!(),
// `ArrayLength` is written as `expr.length()` and we convert it to a uint
Expression::ArrayLength(expr) => {
write!(self.out, "uint(")?;
Expand Down Expand Up @@ -4227,6 +4348,9 @@ impl<'a, W: Write> Writer<'a, W> {
if flags.contains(crate::Barrier::WORK_GROUP) {
writeln!(self.out, "{level}memoryBarrierShared();")?;
}
if flags.contains(crate::Barrier::SUB_GROUP) {
writeln!(self.out, "{level}subgroupMemoryBarrier();")?;
}
writeln!(self.out, "{level}barrier();")?;
Ok(())
}
Expand Down Expand Up @@ -4496,6 +4620,11 @@ const fn glsl_built_in(built_in: crate::BuiltIn, options: VaryingOptions) -> &'s
Bi::WorkGroupId => "gl_WorkGroupID",
Bi::WorkGroupSize => "gl_WorkGroupSize",
Bi::NumWorkGroups => "gl_NumWorkGroups",
// subgroup
Bi::NumSubgroups => "gl_NumSubgroups",
Bi::SubgroupId => "gl_SubgroupID",
Bi::SubgroupSize => "gl_SubgroupSize",
Bi::SubgroupInvocationId => "gl_SubgroupInvocationID",
}
}

Expand Down
5 changes: 5 additions & 0 deletions naga/src/back/hlsl/conv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,11 @@ impl crate::BuiltIn {
// to this field will get replaced with references to `SPECIAL_CBUF_VAR`
// in `Writer::write_expr`.
Self::NumWorkGroups => "SV_GroupID",
// These builtins map to functions
Self::SubgroupSize
| Self::SubgroupInvocationId
| Self::NumSubgroups
| Self::SubgroupId => unreachable!(),
Self::BaseInstance | Self::BaseVertex | Self::WorkGroupSize => {
return Err(Error::Unimplemented(format!("builtin {self:?}")))
}
Expand Down
Loading

0 comments on commit ea77d56

Please sign in to comment.