Skip to content

Commit

Permalink
Workgroup uniform load (#2201)
Browse files Browse the repository at this point in the history
Implement the WGSL `workgroupUniformLoad` function.
  • Loading branch information
DJMcNab committed May 26, 2023
1 parent 1c17fa8 commit 907b7c7
Show file tree
Hide file tree
Showing 25 changed files with 462 additions and 13 deletions.
6 changes: 6 additions & 0 deletions src/back/dot/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,11 @@ impl StatementGraph {
}
"Atomic"
}
S::WorkGroupUniformLoad { pointer, result } => {
self.emits.push((id, result));
self.dependencies.push((id, pointer, "pointer"));
"WorkGroupUniformLoad"
}
S::RayQuery { query, ref fun } => {
self.dependencies.push((id, query, "query"));
match *fun {
Expand Down Expand Up @@ -570,6 +575,7 @@ fn write_function_expressions(
}
E::CallResult(_function) => ("CallResult".into(), 4),
E::AtomicResult { .. } => ("AtomicResult".into(), 4),
E::WorkGroupUniformLoadResult { .. } => ("WorkGroupUniformLoadResult".into(), 4),
E::ArrayLength(expr) => {
edges.insert("", expr);
("ArrayLength".into(), 7)
Expand Down
27 changes: 22 additions & 5 deletions src/back/glsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1824,7 +1824,7 @@ impl<'a, W: Write> Writer<'a, W> {

if let Some(name) = expr_name {
write!(self.out, "{level}")?;
self.write_named_expr(handle, name, ctx)?;
self.write_named_expr(handle, name, handle, ctx)?;
}
}
}
Expand Down Expand Up @@ -2114,6 +2114,19 @@ impl<'a, W: Write> Writer<'a, W> {
self.write_expr(value, ctx)?;
writeln!(self.out, ";")?
}
Statement::WorkGroupUniformLoad { pointer, result } => {
// GLSL doesn't have pointers, which means that this backend needs to ensure that
// the actual "loading" is happening between the two barriers.
// This is done in `Emit` by never emitting a variable name for pointer variables
self.write_barrier(crate::Barrier::WORK_GROUP, level)?;

let result_name = format!("{}{}", back::BAKE_PREFIX, result.index());
write!(self.out, "{level}")?;
// Expressions cannot have side effects, so just writing the expression here is fine.
self.write_named_expr(pointer, result_name, result, ctx)?;

self.write_barrier(crate::Barrier::WORK_GROUP, level)?;
}
// Stores a value into an image.
Statement::ImageStore {
image,
Expand Down Expand Up @@ -3289,7 +3302,8 @@ impl<'a, W: Write> Writer<'a, W> {
// These expressions never show up in `Emit`.
Expression::CallResult(_)
| Expression::AtomicResult { .. }
| Expression::RayQueryProceedResult => unreachable!(),
| Expression::RayQueryProceedResult
| Expression::WorkGroupUniformLoadResult { .. } => unreachable!(),
// `ArrayLength` is written as `expr.length()` and we convert it to a uint
Expression::ArrayLength(expr) => {
write!(self.out, "uint(")?;
Expand Down Expand Up @@ -3725,9 +3739,12 @@ impl<'a, W: Write> Writer<'a, W> {
&mut self,
handle: Handle<crate::Expression>,
name: String,
// The expression which is being named.
// Generally, this is the same as handle, except in WorkGroupUniformLoad
named: Handle<crate::Expression>,
ctx: &back::FunctionCtx,
) -> BackendResult {
match ctx.info[handle].ty {
match ctx.info[named].ty {
proc::TypeResolution::Handle(ty_handle) => match self.module.types[ty_handle].inner {
TypeInner::Struct { .. } => {
let ty_name = &self.names[&NameKey::Type(ty_handle)];
Expand All @@ -3742,7 +3759,7 @@ impl<'a, W: Write> Writer<'a, W> {
}
}

let base_ty_res = &ctx.info[handle].ty;
let base_ty_res = &ctx.info[named].ty;
let resolved = base_ty_res.inner_with(&self.module.types);

write!(self.out, " {name}")?;
Expand All @@ -3752,7 +3769,7 @@ impl<'a, W: Write> Writer<'a, W> {
write!(self.out, " = ")?;
self.write_expr(handle, ctx)?;
writeln!(self.out, ";")?;
self.named_expressions.insert(handle, name);
self.named_expressions.insert(named, name);

Ok(())
}
Expand Down
20 changes: 16 additions & 4 deletions src/back/hlsl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1338,7 +1338,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {

if let Some(name) = expr_name {
write!(self.out, "{level}")?;
self.write_named_expr(module, handle, name, func_ctx)?;
self.write_named_expr(module, handle, name, handle, func_ctx)?;
}
}
}
Expand Down Expand Up @@ -1899,6 +1899,14 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
writeln!(self.out, ", {res_name});")?;
self.named_expressions.insert(result, res_name);
}
Statement::WorkGroupUniformLoad { pointer, result } => {
self.write_barrier(crate::Barrier::WORK_GROUP, level)?;
write!(self.out, "{level}")?;
let name = format!("_expr{}", result.index());
self.write_named_expr(module, pointer, name, result, func_ctx)?;

self.write_barrier(crate::Barrier::WORK_GROUP, level)?;
}
Statement::Switch {
selector,
ref cases,
Expand Down Expand Up @@ -2933,6 +2941,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
// Nothing to do here, since call expression already cached
Expression::CallResult(_)
| Expression::AtomicResult { .. }
| Expression::WorkGroupUniformLoadResult { .. }
| Expression::RayQueryProceedResult => {}
}

Expand Down Expand Up @@ -3023,9 +3032,12 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
module: &Module,
handle: Handle<crate::Expression>,
name: String,
// The expression which is being named.
// Generally, this is the same as handle, except in WorkGroupUniformLoad
named: Handle<crate::Expression>,
ctx: &back::FunctionCtx,
) -> BackendResult {
match ctx.info[handle].ty {
match ctx.info[named].ty {
proc::TypeResolution::Handle(ty_handle) => match module.types[ty_handle].inner {
TypeInner::Struct { .. } => {
let ty_name = &self.names[&NameKey::Type(ty_handle)];
Expand All @@ -3040,7 +3052,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
}
}

let base_ty_res = &ctx.info[handle].ty;
let base_ty_res = &ctx.info[named].ty;
let resolved = base_ty_res.inner_with(&module.types);

write!(self.out, " {name}")?;
Expand All @@ -3051,7 +3063,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
write!(self.out, " = ")?;
self.write_expr(module, handle, ctx)?;
writeln!(self.out, ";")?;
self.named_expressions.insert(handle, name);
self.named_expressions.insert(named, name);

Ok(())
}
Expand Down
13 changes: 13 additions & 0 deletions src/back/msl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1880,6 +1880,7 @@ impl<W: Write> Writer<W> {
// has to be a named expression
crate::Expression::CallResult(_)
| crate::Expression::AtomicResult { .. }
| crate::Expression::WorkGroupUniformLoadResult { .. }
| crate::Expression::RayQueryProceedResult => {
unreachable!()
}
Expand Down Expand Up @@ -2842,6 +2843,18 @@ impl<W: Write> Writer<W> {
// done
writeln!(self.out, ";")?;
}
crate::Statement::WorkGroupUniformLoad { pointer, result } => {
self.write_barrier(crate::Barrier::WORK_GROUP, level)?;

write!(self.out, "{level}")?;
let name = self.namer.call("");
self.start_baking_expression(result, &context.expression, &name)?;
self.put_load(pointer, &context.expression, true)?;
self.named_expressions.insert(result, name);

writeln!(self.out, ";")?;
self.write_barrier(crate::Barrier::WORK_GROUP, level)?;
}
crate::Statement::RayQuery { query, ref fun } => {
match *fun {
crate::RayQueryFunction::Initialize {
Expand Down
41 changes: 41 additions & 0 deletions src/back/spv/block.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1089,6 +1089,7 @@ impl<'w> BlockContext<'w> {
crate::Expression::FunctionArgument(index) => self.function.parameter_id(index),
crate::Expression::CallResult(_)
| crate::Expression::AtomicResult { .. }
| crate::Expression::WorkGroupUniformLoadResult { .. }
| crate::Expression::RayQueryProceedResult => self.cached[expr_handle],
crate::Expression::As {
expr,
Expand Down Expand Up @@ -2209,6 +2210,46 @@ impl<'w> BlockContext<'w> {

block.body.push(instruction);
}
crate::Statement::WorkGroupUniformLoad { pointer, result } => {
self.writer
.write_barrier(crate::Barrier::WORK_GROUP, &mut block);
let result_type_id = self.get_expression_type_id(&self.fun_info[result].ty);
// Embed the body of
match self.write_expression_pointer(pointer, &mut block, None)? {
ExpressionPointer::Ready { pointer_id } => {
let id = self.gen_id();
block.body.push(Instruction::load(
result_type_id,
id,
pointer_id,
None,
));
self.cached[result] = id;
}
ExpressionPointer::Conditional { condition, access } => {
self.cached[result] = self.write_conditional_indexed_load(
result_type_id,
condition,
&mut block,
move |id_gen, block| {
// The in-bounds path. Perform the access and the load.
let pointer_id = access.result_id.unwrap();
let value_id = id_gen.next();
block.body.push(access);
block.body.push(Instruction::load(
result_type_id,
value_id,
pointer_id,
None,
));
value_id
},
)
}
}
self.writer
.write_barrier(crate::Barrier::WORK_GROUP, &mut block);
}
crate::Statement::RayQuery { query, ref fun } => {
self.write_ray_query_function(query, fun, &mut block);
}
Expand Down
13 changes: 12 additions & 1 deletion src/back/wgsl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -780,6 +780,16 @@ impl<W: Write> Writer<W> {
self.write_expr(module, value, func_ctx)?;
writeln!(self.out, ");")?
}
Statement::WorkGroupUniformLoad { pointer, result } => {
write!(self.out, "{level}")?;
// TODO: Obey named expressions here.
let res_name = format!("{}{}", back::BAKE_PREFIX, result.index());
self.start_named_expr(module, result, func_ctx, &res_name)?;
self.named_expressions.insert(result, res_name);
write!(self.out, "workgroupUniformLoad(")?;
self.write_expr(module, pointer, func_ctx)?;
writeln!(self.out, ");")?;
}
Statement::ImageStore {
image,
coordinate,
Expand Down Expand Up @@ -1633,7 +1643,8 @@ impl<W: Write> Writer<W> {
// Nothing to do here, since call expression already cached
Expression::CallResult(_)
| Expression::AtomicResult { .. }
| Expression::RayQueryProceedResult => {}
| Expression::RayQueryProceedResult
| Expression::WorkGroupUniformLoadResult { .. } => {}
}

Ok(())
Expand Down
1 change: 1 addition & 0 deletions src/front/glsl/constants.rs
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,7 @@ impl<'a> ConstantSolver<'a> {
Expression::Derivative { .. } => Err(ConstantSolvingError::Derivative),
Expression::Relational { .. } => Err(ConstantSolvingError::Relational),
Expression::CallResult { .. } => Err(ConstantSolvingError::Call),
Expression::WorkGroupUniformLoadResult { .. } => unreachable!(),
Expression::AtomicResult { .. } => Err(ConstantSolvingError::Atomic),
Expression::FunctionArgument(_) => Err(ConstantSolvingError::FunctionArg),
Expression::GlobalVariable(_) => Err(ConstantSolvingError::GlobalVariable),
Expand Down
1 change: 1 addition & 0 deletions src/front/spv/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3845,6 +3845,7 @@ impl<I: Iterator<Item = u32>> Frontend<I> {
}
}
}
S::WorkGroupUniformLoad { .. } => unreachable!(),
}
i += 1;
}
Expand Down
6 changes: 6 additions & 0 deletions src/front/wgsl/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,7 @@ pub enum Error<'a> {
found: u32,
},
FunctionReturnsVoid(Span),
InvalidWorkGroupUniformLoad(Span),
Other,
ExpectedArraySize(Span),
NonPositiveArrayLength(Span),
Expand Down Expand Up @@ -682,6 +683,11 @@ impl<'a> Error<'a> {
"perhaps you meant to call the function in a separate statement?".into(),
],
},
Error::InvalidWorkGroupUniformLoad(span) => ParseError {
message: "incorrect type passed to workgroupUniformLoad".into(),
labels: vec![(span, "".into())],
notes: vec!["passed type must be a workgroup pointer".into()],
},
Error::Other => ParseError {
message: "other error".to_string(),
labels: vec![],
Expand Down
30 changes: 30 additions & 0 deletions src/front/wgsl/lower/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1691,6 +1691,36 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
.push(crate::Statement::Barrier(crate::Barrier::WORK_GROUP), span);
return Ok(None);
}
"workgroupUniformLoad" => {
let mut args = ctx.prepare_args(arguments, 1, span);
let expr = args.next()?;
args.finish()?;

let pointer = self.expression(expr, ctx.reborrow())?;
ctx.grow_types(pointer)?;
let result_ty = match *ctx.resolved_inner(pointer) {
crate::TypeInner::Pointer {
base,
space: crate::AddressSpace::WorkGroup,
} => base,
ref other => {
log::error!("Type {other:?} passed to workgroupUniformLoad");
let span = ctx.ast_expressions.get_span(expr);
return Err(Error::InvalidWorkGroupUniformLoad(span));
}
};
let result = ctx.interrupt_emitter(
crate::Expression::WorkGroupUniformLoadResult { ty: result_ty },
span,
);
ctx.block.push(
crate::Statement::WorkGroupUniformLoad { pointer, result },
span,
);

ctx.grow_types(pointer)?;
return Ok(Some(result));
}
"textureStore" => {
let mut args = ctx.prepare_args(arguments, 3, span);

Expand Down
22 changes: 22 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1473,6 +1473,13 @@ pub enum Expression {
CallResult(Handle<Function>),
/// Result of an atomic operation.
AtomicResult { ty: Handle<Type>, comparison: bool },
/// Result of a [`WorkGroupUniformLoad`] statement.
///
/// [`WorkGroupUniformLoad`]: Statement::WorkGroupUniformLoad
WorkGroupUniformLoadResult {
/// The type of the result
ty: Handle<Type>,
},
/// Get the length of an array.
/// The expression must resolve to a pointer to an array with a dynamic size.
///
Expand Down Expand Up @@ -1732,6 +1739,21 @@ pub enum Statement {
/// [`AtomicResult`]: crate::Expression::AtomicResult
result: Handle<Expression>,
},
/// Load uniformly from a uniform pointer in the workgroup address space.
///
/// Corresponds to the [`workgroupUniformLoad`](https://www.w3.org/TR/WGSL/#workgroupUniformLoad-builtin)
/// built-in function of wgsl, and has the same barrier semantics
WorkGroupUniformLoad {
/// This must be of type [`Pointer`] in the [`WorkGroup`] address space
///
/// [`Pointer`]: TypeInner::Pointer
/// [`WorkGroup`]: AddressSpace::WorkGroup
pointer: Handle<Expression>,
/// The [`WorkGroupUniformLoadResult`] expression representing this load's result.
///
/// [`WorkGroupUniformLoadResult`]: Expression::WorkGroupUniformLoadResult
result: Handle<Expression>,
},
/// Calls a function.
///
/// If the `result` is `Some`, the corresponding expression has to be
Expand Down
1 change: 1 addition & 0 deletions src/proc/terminator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ pub fn ensure_block_returns(block: &mut crate::Block) {
| S::Call { .. }
| S::RayQuery { .. }
| S::Atomic { .. }
| S::WorkGroupUniformLoad { .. }
| S::Barrier(_)),
)
| None => block.push(S::Return { value: None }, Default::default()),
Expand Down
1 change: 1 addition & 0 deletions src/proc/typifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -658,6 +658,7 @@ impl<'a> ResolveContext<'a> {
| crate::BinaryOperator::ShiftRight => past(left)?.clone(),
},
crate::Expression::AtomicResult { ty, .. } => TypeResolution::Handle(ty),
crate::Expression::WorkGroupUniformLoadResult { ty } => TypeResolution::Handle(ty),
crate::Expression::Select { accept, .. } => past(accept)?.clone(),
crate::Expression::Derivative { expr, .. } => past(expr)?.clone(),
crate::Expression::Relational { fun, argument } => match fun {
Expand Down
Loading

0 comments on commit 907b7c7

Please sign in to comment.