Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement override-expression evaluation in functions #5387

Merged
merged 12 commits into from Mar 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
387 changes: 366 additions & 21 deletions naga/src/back/pipeline_constants.rs

Large diffs are not rendered by default.

6 changes: 6 additions & 0 deletions naga/src/block.rs
Expand Up @@ -65,6 +65,12 @@ impl Block {
self.span_info.splice(range.clone(), other.span_info);
self.body.splice(range, other.body);
}

pub fn span_into_iter(self) -> impl Iterator<Item = (Statement, Span)> {
let Block { body, span_info } = self;
body.into_iter().zip(span_info)
}

pub fn span_iter(&self) -> impl Iterator<Item = (&Statement, &Span)> {
let span_iter = self.span_info.iter();
self.body.iter().zip(span_iter)
Expand Down
2 changes: 1 addition & 1 deletion naga/src/front/wgsl/lower/mod.rs
Expand Up @@ -916,7 +916,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {

let init;
if let Some(init_ast) = v.init {
let mut ectx = ctx.as_const();
let mut ectx = ctx.as_override();
let lowered = self.expression_for_abstract(init_ast, &mut ectx)?;
let ty_res = crate::proc::TypeResolution::Handle(ty);
let converted = ectx
Expand Down
75 changes: 46 additions & 29 deletions naga/src/proc/constant_evaluator.rs
Expand Up @@ -258,6 +258,17 @@ enum Behavior<'a> {
Glsl(GlslRestrictions<'a>),
}

impl Behavior<'_> {
/// Returns `true` if the inner WGSL/GLSL restrictions are runtime restrictions.
const fn has_runtime_restrictions(&self) -> bool {
matches!(
self,
&Behavior::Wgsl(WgslRestrictions::Runtime(_))
| &Behavior::Glsl(GlslRestrictions::Runtime(_))
)
}
}

/// A context for evaluating constant expressions.
///
/// A `ConstantEvaluator` points at an expression arena to which it can append
Expand Down Expand Up @@ -699,37 +710,43 @@ impl<'a> ConstantEvaluator<'a> {
expr: Expression,
span: Span,
) -> Result<Handle<Expression>, ConstantEvaluatorError> {
match (
&self.behavior,
self.expression_kind_tracker.type_of_with_expr(&expr),
) {
// avoid errors on unimplemented functionality if possible
(
&Behavior::Wgsl(WgslRestrictions::Runtime(_))
| &Behavior::Glsl(GlslRestrictions::Runtime(_)),
ExpressionKind::Const,
) => match self.try_eval_and_append_impl(&expr, span) {
Err(
ConstantEvaluatorError::NotImplemented(_)
| ConstantEvaluatorError::InvalidBinaryOpArgs,
) => Ok(self.append_expr(expr, span, ExpressionKind::Runtime)),
res => res,
match self.expression_kind_tracker.type_of_with_expr(&expr) {
ExpressionKind::Const => {
let eval_result = self.try_eval_and_append_impl(&expr, span);
// We should be able to evaluate `Const` expressions at this
// point. If we failed to, then that probably means we just
// haven't implemented that part of constant evaluation. Work
// around this by simply emitting it as a run-time expression.
if self.behavior.has_runtime_restrictions()
&& matches!(
eval_result,
Err(ConstantEvaluatorError::NotImplemented(_)
| ConstantEvaluatorError::InvalidBinaryOpArgs,)
)
{
Ok(self.append_expr(expr, span, ExpressionKind::Runtime))
} else {
eval_result
}
}
ExpressionKind::Override => match self.behavior {
Behavior::Wgsl(WgslRestrictions::Override | WgslRestrictions::Runtime(_)) => {
Ok(self.append_expr(expr, span, ExpressionKind::Override))
}
Behavior::Wgsl(WgslRestrictions::Const) => {
Err(ConstantEvaluatorError::OverrideExpr)
}
Behavior::Glsl(_) => {
unreachable!()
}
},
(_, ExpressionKind::Const) => self.try_eval_and_append_impl(&expr, span),
(&Behavior::Wgsl(WgslRestrictions::Const), ExpressionKind::Override) => {
Err(ConstantEvaluatorError::OverrideExpr)
ExpressionKind::Runtime => {
if self.behavior.has_runtime_restrictions() {
Ok(self.append_expr(expr, span, ExpressionKind::Runtime))
} else {
Err(ConstantEvaluatorError::RuntimeExpr)
}
}
(
&Behavior::Wgsl(WgslRestrictions::Override | WgslRestrictions::Runtime(_)),
ExpressionKind::Override,
) => Ok(self.append_expr(expr, span, ExpressionKind::Override)),
(&Behavior::Glsl(_), ExpressionKind::Override) => unreachable!(),
(
&Behavior::Wgsl(WgslRestrictions::Runtime(_))
| &Behavior::Glsl(GlslRestrictions::Runtime(_)),
ExpressionKind::Runtime,
) => Ok(self.append_expr(expr, span, ExpressionKind::Runtime)),
(_, ExpressionKind::Runtime) => Err(ConstantEvaluatorError::RuntimeExpr),
}
}

Expand Down
4 changes: 2 additions & 2 deletions naga/src/valid/interface.rs
Expand Up @@ -31,7 +31,7 @@ pub enum GlobalVariableError {
Handle<crate::Type>,
#[source] Disalignment,
),
#[error("Initializer must be a const-expression")]
#[error("Initializer must be an override-expression")]
InitializerExprType,
#[error("Initializer doesn't match the variable type")]
InitializerType,
Expand Down Expand Up @@ -529,7 +529,7 @@ impl super::Validator {
}
}

if !global_expr_kind.is_const(init) {
if !global_expr_kind.is_const_or_override(init) {
return Err(GlobalVariableError::InitializerExprType);
}

Expand Down
9 changes: 9 additions & 0 deletions naga/tests/in/overrides-atomicCompareExchangeWeak.param.ron
@@ -0,0 +1,9 @@
(
spv: (
version: (1, 0),
separate_entry_points: true,
),
pipeline_constants: {
"o": 2.0
}
)
7 changes: 7 additions & 0 deletions naga/tests/in/overrides-atomicCompareExchangeWeak.wgsl
@@ -0,0 +1,7 @@
override o: i32;
var<workgroup> a: atomic<u32>;

@compute @workgroup_size(1)
fn f() {
atomicCompareExchangeWeak(&a, u32(o), 1u);
}
18 changes: 18 additions & 0 deletions naga/tests/in/overrides-ray-query.param.ron
@@ -0,0 +1,18 @@
(
god_mode: true,
spv: (
version: (1, 4),
separate_entry_points: true,
),
msl: (
lang_version: (2, 4),
spirv_cross_compatibility: false,
fake_missing_bindings: true,
zero_initialize_workgroup_memory: false,
per_entry_point_map: {},
inline_samplers: [],
),
pipeline_constants: {
"o": 2.0
}
)
21 changes: 21 additions & 0 deletions naga/tests/in/overrides-ray-query.wgsl
@@ -0,0 +1,21 @@
override o: f32;

@group(0) @binding(0)
var acc_struct: acceleration_structure;

@compute @workgroup_size(1)
fn main() {
var rq: ray_query;

let desc = RayDesc(
RAY_FLAG_TERMINATE_ON_FIRST_HIT,
0xFFu,
o * 17.0,
o * 19.0,
vec3<f32>(o * 23.0),
vec3<f32>(o * 29.0, o * 31.0, o * 37.0),
);
rayQueryInitialize(&rq, acc_struct, desc);

while (rayQueryProceed(&rq)) {}
}
10 changes: 9 additions & 1 deletion naga/tests/in/overrides.wgsl
Expand Up @@ -13,5 +13,13 @@

override inferred_f32 = 2.718;

var<private> gain_x_10: f32 = gain * 10.;

@compute @workgroup_size(1)
fn main() {}
fn main() {
var t = height * 5;
let a = !has_point_light;
var x = a;

var gain_x_100 = gain_x_10 * 10.;
}
148 changes: 146 additions & 2 deletions naga/tests/out/analysis/overrides.info.ron
Expand Up @@ -14,8 +14,143 @@
),
may_kill: false,
sampling_set: [],
global_uses: [],
expressions: [],
global_uses: [
("READ"),
],
expressions: [
(
uniformity: (
non_uniform_result: None,
requirements: (""),
),
ref_count: 1,
assignable_global: None,
ty: Handle(2),
),
(
uniformity: (
non_uniform_result: None,
requirements: (""),
),
ref_count: 1,
assignable_global: None,
ty: Value(Scalar((
kind: Float,
width: 4,
))),
),
(
uniformity: (
non_uniform_result: None,
requirements: (""),
),
ref_count: 1,
assignable_global: None,
ty: Value(Scalar((
kind: Float,
width: 4,
))),
),
(
uniformity: (
non_uniform_result: Some(4),
requirements: (""),
),
ref_count: 1,
assignable_global: None,
ty: Value(Pointer(
base: 2,
space: Function,
)),
),
(
uniformity: (
non_uniform_result: None,
requirements: (""),
),
ref_count: 1,
assignable_global: None,
ty: Handle(1),
),
(
uniformity: (
non_uniform_result: None,
requirements: (""),
),
ref_count: 1,
assignable_global: None,
ty: Handle(1),
),
(
uniformity: (
non_uniform_result: Some(7),
requirements: (""),
),
ref_count: 1,
assignable_global: None,
ty: Value(Pointer(
base: 1,
space: Function,
)),
),
(
uniformity: (
non_uniform_result: Some(8),
requirements: (""),
),
ref_count: 1,
assignable_global: Some(1),
ty: Value(Pointer(
base: 2,
space: Private,
)),
),
(
uniformity: (
non_uniform_result: Some(8),
requirements: (""),
),
ref_count: 1,
assignable_global: None,
ty: Handle(2),
),
(
uniformity: (
non_uniform_result: None,
requirements: (""),
),
ref_count: 1,
assignable_global: None,
ty: Value(Scalar((
kind: Float,
width: 4,
))),
),
(
uniformity: (
non_uniform_result: Some(8),
requirements: (""),
),
ref_count: 1,
assignable_global: None,
ty: Value(Scalar((
kind: Float,
width: 4,
))),
),
(
uniformity: (
non_uniform_result: Some(12),
requirements: (""),
),
ref_count: 1,
assignable_global: None,
ty: Value(Pointer(
base: 2,
space: Function,
)),
),
],
sampling: [],
dual_source_blending: false,
),
Expand Down Expand Up @@ -43,5 +178,14 @@
kind: Float,
width: 4,
))),
Handle(2),
Value(Scalar((
kind: Float,
width: 4,
))),
Value(Scalar((
kind: Float,
width: 4,
))),
],
)
10 changes: 10 additions & 0 deletions naga/tests/out/hlsl/overrides.hlsl
Expand Up @@ -6,8 +6,18 @@ static const float depth = 2.3;
static const float height = 4.6;
static const float inferred_f32_ = 2.718;

static float gain_x_10_ = 11.0;

[numthreads(1, 1, 1)]
void main()
{
float t = (float)0;
bool x = (bool)0;
float gain_x_100_ = (float)0;

t = 23.0;
x = true;
float _expr10 = gain_x_10_;
gain_x_100_ = (_expr10 * 10.0);
return;
}