Skip to content

Commit

Permalink
[wgsl-in] Avoid splatting all binary operator expressions
Browse files Browse the repository at this point in the history
  • Loading branch information
fornwall committed Aug 16, 2023
1 parent 7a19f3a commit 83282ec
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 1 deletion.
8 changes: 7 additions & 1 deletion src/front/wgsl/lower/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -510,7 +510,13 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> {
left: &mut Handle<crate::Expression>,
right: &mut Handle<crate::Expression>,
) -> Result<(), Error<'source>> {
if op != crate::BinaryOperator::Multiply {
if matches!(
op,
crate::BinaryOperator::Add
| crate::BinaryOperator::Subtract
| crate::BinaryOperator::Divide
| crate::BinaryOperator::Modulo
) {
self.grow_types(*left)?.grow_types(*right)?;

let left_size = match *self.resolved_inner(*left) {
Expand Down
48 changes: 48 additions & 0 deletions src/front/wgsl/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -387,6 +387,54 @@ fn parse_expressions() {
}").unwrap();
}

#[test]
fn binary_expression_mixed_scalar_and_vector_operands() {
for (operand, expect_splat) in [
('<', false),
('>', false),
('&', false),
('|', false),
('+', true),
('-', true),
('*', false),
('/', true),
('%', true),
] {
let module = parse_str(&format!(
"
const some_vec = vec3<f32>(1.0, 1.0, 1.0);
@fragment
fn main() -> @location(0) vec4<f32> {{
if (all(1.0 {operand} some_vec)) {{
return vec4(0.0);
}}
return vec4(1.0);
}}
"
))
.unwrap();

let expressions = &&module.entry_points[0].function.expressions;

let found_expressions = expressions
.iter()
.filter(|&(_, e)| {
if let crate::Expression::Binary { left, .. } = *e {
matches!(
(expect_splat, &expressions[left]),
(false, &crate::Expression::Literal(crate::Literal::F32(..)))
| (true, &crate::Expression::Splat { .. })
)
} else {
false
}
})
.count();

assert_eq!(found_expressions, 1);
}
}

#[test]
fn parse_pointers() {
parse_str(
Expand Down

0 comments on commit 83282ec

Please sign in to comment.