Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve the polyfill for workgroup variable zero initialization #5521

Draft
wants to merge 12 commits into
base: trunk
Choose a base branch
from
92 changes: 71 additions & 21 deletions naga/src/back/glsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,12 +53,13 @@ use crate::{
use features::FeaturesManager;
use std::{
cmp::Ordering,
fmt,
fmt::{Error as FmtError, Write},
fmt::{self, Error as FmtError, Write},
mem,
};
use thiserror::Error;

use super::zero_init;

/// Contains the features related code and the features querying method
mod features;
/// Contains a constant with a slice of all the reserved keywords RESERVED_KEYWORDS
Expand Down Expand Up @@ -1685,10 +1686,10 @@ impl<'a, W: Write> Writer<'a, W> {
// Close the parentheses and open braces to start the function body
writeln!(self.out, ") {{")?;

if self.options.zero_initialize_workgroup_memory
&& ctx.ty.is_compute_entry_point(self.module)
{
self.write_workgroup_variables_initialization(&ctx)?;
if self.options.zero_initialize_workgroup_memory {
if let Some(workgroup_size) = ctx.ty.compute_entry_point_workgroup_size(self.module) {
self.write_workgroup_variables_initialization(&ctx, workgroup_size)?;
}
}

// Compose the function arguments from globals, in case of an entry point.
Expand Down Expand Up @@ -1780,31 +1781,80 @@ impl<'a, W: Write> Writer<'a, W> {
fn write_workgroup_variables_initialization(
&mut self,
ctx: &back::FunctionCtx,
workgroup_size: [u32; 3],
) -> BackendResult {
let mut vars = self
let vars = self
.module
.global_variables
.iter()
.filter(|&(handle, var)| {
!ctx.info[handle].is_empty() && var.space == crate::AddressSpace::WorkGroup
})
.peekable();

if vars.peek().is_some() {
let level = back::Level(1);
});
let zero_init_res =
zero_init::zero_init(&self.module, vars, workgroup_size.into_iter().product());
if zero_init_res.is_empty() {
return Ok(());
}

writeln!(self.out, "{level}if (gl_LocalInvocationID == uvec3(0u)) {{")?;
let mut level = back::Level(1);
let mut remainder = None;

for (handle, var) in vars {
let name = &self.names[&NameKey::GlobalVariable(handle)];
write!(self.out, "{}{} = ", level.next(), name)?;
self.write_zero_init_value(var.ty)?;
writeln!(self.out, ";")?;
for (handle, init) in zero_init_res {
match init {
zero_init::ZeroInitKind::LocalPlusIndex {
index,
if_less_than,
} => {
if if_less_than != remainder {
let Some(if_less_than) = if_less_than else {
panic!("Got decrementing index")
};
remainder = Some(if_less_than);
writeln!(
self.out,
"{level}if (gl_LocalInvocationIndex < {if_less_than}u) {{"
)?;
level = level.next();
}
let var = &self.module.global_variables[handle];
let base_type = match &self.module.types[var.ty].inner {
TypeInner::Array { base, .. } => base,
_ => unreachable!(),
};
let name = &self.names[&NameKey::GlobalVariable(handle)];
if let Some(index) = index {
write!(
self.out,
"{}{}[gl_LocalInvocationIndex + {index}u] = ",
level, name
)?;
} else {
write!(self.out, "{}{}[gl_LocalInvocationIndex] = ", level, name)?;
}
self.write_zero_init_value(*base_type)?;
writeln!(self.out, ";")?;
}
zero_init::ZeroInitKind::NotArray => {
if remainder != Some(1) {
writeln!(self.out, "{level}if (gl_LocalInvocationIndex < 1u) {{")?;
level = level.next();
remainder = Some(1);
}
let name = &self.names[&NameKey::GlobalVariable(handle)];
write!(self.out, "{}{} = ", level.next(), name)?;
let var = &self.module.global_variables[handle];
self.write_zero_init_value(var.ty)?;
writeln!(self.out, ";")?;
}
}

writeln!(self.out, "{level}}}")?;
self.write_barrier(crate::Barrier::WORK_GROUP, level)?;
}
// Close all opened brackets
for level in (1..level.0).rev() {
writeln!(self.out, "{}}}", back::Level(level))?;
}
level = back::Level(1);

self.write_barrier(crate::Barrier::WORK_GROUP, level)?;

Ok(())
}
Expand Down
104 changes: 78 additions & 26 deletions naga/src/back/hlsl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use super::{
BackendResult, Error, Options,
};
use crate::{
back,
back::{self, zero_init},
proc::{self, NameKey},
valid, Handle, Module, ScalarKind, ShaderStage, TypeInner,
};
Expand Down Expand Up @@ -1113,8 +1113,8 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
// Write function name
write!(self.out, " {name}(")?;

let need_workgroup_variables_initialization =
self.need_workgroup_variables_initialization(func_ctx, module);
let workgroup_size_for_initialization =
self.workgroup_size_for_variables_initialization(func_ctx, module);

// Write function arguments for non entry point functions
match func_ctx.ty {
Expand Down Expand Up @@ -1169,11 +1169,11 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
}
}

if need_workgroup_variables_initialization {
if workgroup_size_for_initialization.is_some() {
if !func.arguments.is_empty() {
write!(self.out, ", ")?;
}
write!(self.out, "uint3 __local_invocation_id : SV_GroupThreadID")?;
write!(self.out, "uint __local_invocation_index : SV_GroupIndex")?;
}
}
}
Expand All @@ -1197,8 +1197,8 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
writeln!(self.out)?;
writeln!(self.out, "{{")?;

if need_workgroup_variables_initialization {
self.write_workgroup_variables_initialization(func_ctx, module)?;
if let Some(workgroup_size) = workgroup_size_for_initialization {
self.write_workgroup_variables_initialization(func_ctx, module, workgroup_size)?;
}

if let back::FunctionType::EntryPoint(index) = func_ctx.ty {
Expand Down Expand Up @@ -1249,43 +1249,95 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
Ok(())
}

fn need_workgroup_variables_initialization(
fn workgroup_size_for_variables_initialization(
&mut self,
func_ctx: &back::FunctionCtx,
module: &Module,
) -> bool {
self.options.zero_initialize_workgroup_memory
&& func_ctx.ty.is_compute_entry_point(module)
) -> Option<[u32; 3]> {
(self.options.zero_initialize_workgroup_memory
&& module.global_variables.iter().any(|(handle, var)| {
!func_ctx.info[handle].is_empty() && var.space == crate::AddressSpace::WorkGroup
})
}))
.then_some(|| ())
.and(func_ctx.ty.compute_entry_point_workgroup_size(module))
}

fn write_workgroup_variables_initialization(
&mut self,
func_ctx: &back::FunctionCtx,
module: &Module,
workgroup_size: [u32; 3],
) -> BackendResult {
let level = back::Level(1);

writeln!(
self.out,
"{level}if (all(__local_invocation_id == uint3(0u, 0u, 0u))) {{"
)?;

let vars = module.global_variables.iter().filter(|&(handle, var)| {
!func_ctx.info[handle].is_empty() && var.space == crate::AddressSpace::WorkGroup
});

for (handle, var) in vars {
let name = &self.names[&NameKey::GlobalVariable(handle)];
write!(self.out, "{}{} = ", level.next(), name)?;
self.write_default_init(module, var.ty)?;
writeln!(self.out, ";")?;
let zero_init_res =
zero_init::zero_init(&module, vars, workgroup_size.into_iter().product());
if zero_init_res.is_empty() {
return Ok(());
}

let mut level = back::Level(1);
let mut remainder = None;

for (handle, init) in zero_init_res {
match init {
zero_init::ZeroInitKind::LocalPlusIndex {
index,
if_less_than,
} => {
if if_less_than != remainder {
let Some(if_less_than) = if_less_than else {
panic!("Got decrementing index")
};
remainder = Some(if_less_than);
writeln!(
self.out,
"{level}if (__local_invocation_index < {if_less_than}u) {{"
)?;
level = level.next();
}
let var = &module.global_variables[handle];
let base_type = match &module.types[var.ty].inner {
TypeInner::Array { base, .. } => base,
_ => unreachable!(),
};
let name = &self.names[&NameKey::GlobalVariable(handle)];
if let Some(index) = index {
write!(
self.out,
"{}{}[__local_invocation_index + {index}u] = ",
level, name
)?;
} else {
write!(self.out, "{}{}[__local_invocation_index] = ", level, name)?;
}
self.write_default_init(module, *base_type)?;
writeln!(self.out, ";")?;
}
zero_init::ZeroInitKind::NotArray => {
if remainder != Some(1) {
writeln!(self.out, "{level}if (__local_invocation_index < 1u) {{")?;
level = level.next();
remainder = Some(1);
}
let name = &self.names[&NameKey::GlobalVariable(handle)];
write!(self.out, "{}{} = ", level.next(), name)?;
let var = &module.global_variables[handle];
self.write_default_init(module, var.ty)?;
writeln!(self.out, ";")?;
}
}
}
// Close all opened brackets
for level in (1..level.0).rev() {
writeln!(self.out, "{}}}", back::Level(level))?;
}
level = back::Level(1);

writeln!(self.out, "{level}}}")?;
self.write_barrier(crate::Barrier::WORK_GROUP, level)
self.write_barrier(crate::Barrier::WORK_GROUP, level)?;
Ok(())
}

/// Helper method used to write statements
Expand Down
16 changes: 13 additions & 3 deletions naga/src/back/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,14 @@ pub mod wgsl;
))]
pub mod pipeline_constants;

#[cfg(any(
feature = "hlsl-out",
feature = "msl-out",
feature = "spv-out",
feature = "glsl-out"
))]
pub(crate) mod zero_init;

/// Names of vector components.
pub const COMPONENTS: &[char] = &['x', 'y', 'z', 'w'];
/// Indent for backends.
Expand Down Expand Up @@ -86,12 +94,14 @@ pub enum FunctionType {

impl FunctionType {
/// Returns true if the function is an entry point for a compute shader.
pub fn is_compute_entry_point(&self, module: &crate::Module) -> bool {
pub fn compute_entry_point_workgroup_size(&self, module: &crate::Module) -> Option<[u32; 3]> {
match *self {
FunctionType::EntryPoint(index) => {
module.entry_points[index as usize].stage == crate::ShaderStage::Compute
let entry_point = &module.entry_points[index as usize];
(entry_point.stage == crate::ShaderStage::Compute)
.then_some(entry_point.workgroup_size)
}
FunctionType::Function(_) => false,
FunctionType::Function(_) => None,
}
}
}
Expand Down
Loading
Loading