Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions naga/src/back/msl/keywords.rs
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,7 @@ pub const RESERVED: &[&str] = &[
super::writer::MODF_FUNCTION,
super::writer::ABS_FUNCTION,
super::writer::DIV_FUNCTION,
super::writer::DOT_FUNCTION,
super::writer::MOD_FUNCTION,
super::writer::NEG_FUNCTION,
super::writer::F2I32_FUNCTION,
Expand Down
102 changes: 67 additions & 35 deletions naga/src/back/msl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ pub(crate) const MODF_FUNCTION: &str = "naga_modf";
pub(crate) const FREXP_FUNCTION: &str = "naga_frexp";
pub(crate) const ABS_FUNCTION: &str = "naga_abs";
pub(crate) const DIV_FUNCTION: &str = "naga_div";
pub(crate) const DOT_FUNCTION: &str = "naga_dot";
pub(crate) const MOD_FUNCTION: &str = "naga_mod";
pub(crate) const NEG_FUNCTION: &str = "naga_neg";
pub(crate) const F2I32_FUNCTION: &str = "naga_f2i32";
Expand Down Expand Up @@ -2331,26 +2332,26 @@ impl<W: Write> Writer<W> {
crate::TypeInner::Vector {
scalar:
crate::Scalar {
// Resolve float values to MSL's builtin dot function.
kind: crate::ScalarKind::Float,
..
},
..
} => "dot",
crate::TypeInner::Vector { size, .. } => {
return self.put_dot_product(
arg,
arg1.unwrap(),
size as usize,
|writer, arg, index| {
// Write the vector expression; this expression is marked to be
// cached so unless it can't be cached (for example, it's a Constant)
// it shouldn't produce large expressions.
writer.put_expression(arg, context, true)?;
// Access the current component on the vector.
write!(writer.out, ".{}", back::COMPONENTS[index])?;
Ok(())
},
);
crate::TypeInner::Vector { size, scalar }
if matches!(
scalar.kind,
crate::ScalarKind::Sint | crate::ScalarKind::Uint
) =>
{
// Integer vector dot: call our mangled helper `dot_{type}{N}(a, b)`.
let fun_name = self.get_dot_wrapper_function_helper_name(scalar, size);
write!(self.out, "{fun_name}(")?;
self.put_expression(arg, context, true)?;
write!(self.out, ", ")?;
self.put_expression(arg1.unwrap(), context, true)?;
write!(self.out, ")")?;
return Ok(());
}
_ => unreachable!(
"Correct TypeInner for dot product should be already validated"
Expand Down Expand Up @@ -3367,26 +3368,6 @@ impl<W: Write> Writer<W> {
} = *expr
{
match fun {
crate::MathFunction::Dot => {
// WGSL's `dot` function works on any `vecN` type, but Metal's only
// works on floating-point vectors, so we emit inline code for
// integer vector `dot` calls. But that code uses each argument `N`
// times, once for each component (see `put_dot_product`), so to
// avoid duplicated evaluation, we must bake integer operands.

// check what kind of product this is depending
// on the resolve type of the Dot function itself
let inner = context.resolve_type(expr_handle);
if let crate::TypeInner::Scalar(scalar) = *inner {
match scalar.kind {
crate::ScalarKind::Sint | crate::ScalarKind::Uint => {
self.need_bake_expressions.insert(arg);
self.need_bake_expressions.insert(arg1.unwrap());
}
_ => {}
}
}
}
crate::MathFunction::Dot4U8Packed | crate::MathFunction::Dot4I8Packed => {
self.need_bake_expressions.insert(arg);
self.need_bake_expressions.insert(arg1.unwrap());
Expand Down Expand Up @@ -5803,6 +5784,18 @@ template <typename A>
Ok(())
}

/// Build the mangled helper name for integer vector dot products.
/// Result format: `{DOT_FUNCTION}_{type}{N}` (e.g., `naga_dot_int3`).
fn get_dot_wrapper_function_helper_name(
&self,
scalar: crate::Scalar,
size: crate::VectorSize,
) -> String {
let type_name = scalar.to_msl_name();
let size_suffix = common::vector_size_str(size);
format!("{DOT_FUNCTION}_{type_name}{size_suffix}")
}

#[allow(clippy::too_many_arguments)]
fn write_wrapped_math_function(
&mut self,
Expand Down Expand Up @@ -5858,6 +5851,45 @@ template <typename A>
writeln!(self.out, "}}")?;
writeln!(self.out)?;
}

crate::MathFunction::Dot => match *arg_ty {
crate::TypeInner::Vector { size, scalar }
if matches!(
scalar.kind,
crate::ScalarKind::Sint | crate::ScalarKind::Uint
) =>
{
// De-duplicate per (fun, arg type) like other wrapped math functions
let wrapped = WrappedFunction::Math {
fun,
arg_ty: (Some(size), scalar),
};
if !self.wrapped_functions.insert(wrapped) {
return Ok(());
}

let mut vec_ty = String::new();
put_numeric_type(&mut vec_ty, scalar, &[size])?;
let mut ret_ty = String::new();
put_numeric_type(&mut ret_ty, scalar, &[])?;

let fun_name = self.get_dot_wrapper_function_helper_name(scalar, size);

// Emit function signature and body using put_dot_product for the expression
writeln!(self.out, "{ret_ty} {fun_name}({vec_ty} a, {vec_ty} b) {{")?;
let level = back::Level(1);
write!(self.out, "{level}return ")?;
self.put_dot_product("a", "b", size as usize, |writer, name, index| {
write!(writer.out, "{name}.{}", back::COMPONENTS[index])?;
Ok(())
})?;
writeln!(self.out, ";")?;
writeln!(self.out, "}}")?;
writeln!(self.out)?;
}
_ => {}
},

_ => {}
}
Ok(())
Expand Down
12 changes: 10 additions & 2 deletions naga/tests/out/msl/wgsl-functions.msl
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,22 @@ metal::float2 test_fma(
return metal::fma(a, b, c);
}

int naga_dot_int2(metal::int2 a, metal::int2 b) {
return ( + a.x * b.x + a.y * b.y);
}

uint naga_dot_uint3(metal::uint3 a, metal::uint3 b) {
return ( + a.x * b.x + a.y * b.y + a.z * b.z);
}

int test_integer_dot_product(
) {
metal::int2 a_2_ = metal::int2(1);
metal::int2 b_2_ = metal::int2(1);
int c_2_ = ( + a_2_.x * b_2_.x + a_2_.y * b_2_.y);
int c_2_ = naga_dot_int2(a_2_, b_2_);
metal::uint3 a_3_ = metal::uint3(1u);
metal::uint3 b_3_ = metal::uint3(1u);
uint c_3_ = ( + a_3_.x * b_3_.x + a_3_.y * b_3_.y + a_3_.z * b_3_.z);
uint c_3_ = naga_dot_uint3(a_3_, b_3_);
return 32;
}

Expand Down
16 changes: 10 additions & 6 deletions naga/tests/out/msl/wgsl-int64.msl
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,10 @@ long naga_abs(long val) {
return metal::select(as_type<long>(-as_type<ulong>(val)), val, val >= 0);
}

long naga_dot_long2(metal::long2 a, metal::long2 b) {
return ( + a.x * b.x + a.y * b.y);
}

long int64_function(
long x,
thread long& private_variable,
Expand Down Expand Up @@ -111,11 +115,9 @@ long int64_function(
long _e130 = val;
val = as_type<long>(as_type<ulong>(_e130) + as_type<ulong>(metal::clamp(_e126, _e127, _e128)));
long _e132 = val;
metal::long2 _e133 = metal::long2(_e132);
long _e134 = val;
metal::long2 _e135 = metal::long2(_e134);
long _e137 = val;
val = as_type<long>(as_type<ulong>(_e137) + as_type<ulong>(( + _e133.x * _e135.x + _e133.y * _e135.y)));
val = as_type<long>(as_type<ulong>(_e137) + as_type<ulong>(naga_dot_long2(metal::long2(_e132), metal::long2(_e134))));
long _e139 = val;
long _e140 = val;
long _e142 = val;
Expand All @@ -135,6 +137,10 @@ ulong naga_f2u64(float value) {
return static_cast<ulong>(metal::clamp(value, 0.0, 18446743000000000000.0));
}

ulong naga_dot_ulong2(metal::ulong2 a, metal::ulong2 b) {
return ( + a.x * b.x + a.y * b.y);
}

ulong uint64_function(
ulong x_1,
constant UniformCompatible& input_uniform,
Expand Down Expand Up @@ -199,11 +205,9 @@ ulong uint64_function(
ulong _e125 = val_1;
val_1 = _e125 + metal::clamp(_e121, _e122, _e123);
ulong _e127 = val_1;
metal::ulong2 _e128 = metal::ulong2(_e127);
ulong _e129 = val_1;
metal::ulong2 _e130 = metal::ulong2(_e129);
ulong _e132 = val_1;
val_1 = _e132 + ( + _e128.x * _e130.x + _e128.y * _e130.y);
val_1 = _e132 + naga_dot_ulong2(metal::ulong2(_e127), metal::ulong2(_e129));
ulong _e134 = val_1;
ulong _e135 = val_1;
ulong _e137 = val_1;
Expand Down
Loading