Skip to content

Commit

Permalink
NEON fixes.
Browse files: browse the repository at this point in the history
  • Loading branch information
bitshifter committed Mar 25, 2024
1 parent 9f26d6f commit ecd93dc
Show file tree
Hide file tree
Showing 12 changed files with 57 additions and 25 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ jobs:
- aarch64-unknown-linux-gnu
- arm-unknown-linux-gnueabi
steps:
- uses: actions/checkout@v2
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: stable
Expand Down
4 changes: 1 addition & 3 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -56,14 +56,12 @@ libm = { version = "0.2", optional = true, default-features = false}
#rand_xoshiro = "0.6"
#serde_json = "1.0"

[target.'cfg(target_arch = "x86_64")'.dev-dependencies]
criterion = { version = "0.4", features = ["html_reports"] }
rand_xoshiro = "0.6"
# Set a size_xx feature so that this crate compiles properly with --all-targets --all-features
rkyv = { version = "0.7", default-features = false, features = ["size_32"] }
serde_json = "1.0"

[target.'cfg(not(target_arch = "wasm32"))'.dev-dependencies]
[target.'cfg(target_arch = "x86_64")'.dev-dependencies]
criterion = { version = "0.5", features = ["html_reports"] }

[target.'cfg(target_arch = "wasm32")'.dev-dependencies]
Expand Down
12 changes: 10 additions & 2 deletions codegen/templates/vec.rs.tera
Original file line number Diff line number Diff line change
Expand Up @@ -823,7 +823,7 @@ impl {{ self_t }} {
#[inline]
#[must_use]
pub fn element_sum(self) -> {{ scalar_t }} {
{% if is_scalar %}
{% if is_scalar or is_neon %}
{% for c in components %}
self.{{ c }} {% if not loop.last %} + {% endif %}
{%- endfor %}
Expand Down Expand Up @@ -861,6 +861,8 @@ impl {{ self_t }} {
{% elif dim == 4 %}
self.0.reduce_sum()
{% endif %}
{% else %}
unimplemented!()
{% endif %}
}

Expand All @@ -870,7 +872,7 @@ impl {{ self_t }} {
#[inline]
#[must_use]
pub fn element_product(self) -> {{ scalar_t }} {
{% if is_scalar %}
{% if is_scalar or is_neon %}
{% for c in components %}
self.{{ c }} {% if not loop.last %} * {% endif %}
{%- endfor %}
Expand Down Expand Up @@ -908,6 +910,8 @@ impl {{ self_t }} {
{% elif dim == 4 %}
self.0.reduce_product()
{% endif %}
{% else %}
unimplemented!()
{% endif %}
}

Expand Down Expand Up @@ -1653,6 +1657,10 @@ impl {{ self_t }} {
Self(f32x4_trunc(self.0))
{% elif is_coresimd %}
Self(self.0.trunc())
{% elif is_neon %}
Self(unsafe { vrndq_f32(self.0) })
{% else %}
unimplemented!()
{% endif %}
}

Expand Down
11 changes: 10 additions & 1 deletion codegen/templates/vec_mask.rs.tera
Original file line number Diff line number Diff line change
Expand Up @@ -298,11 +298,20 @@ impl {{ self_t }} {
}
{% elif is_coresimd %}
self.0.set(index, value)
{% elif is_neon %}
self.0 = match index {
{% for c in components %}
{{ loop.index0 }} => unsafe {
vsetq_lane_u32(MASK[value as usize], self.0, {{ loop.index0 }})
},
{%- endfor %}
_ => panic!("index out of bounds")
}
{% else %}
use crate::{{ vec_t }};
let mut v = {{ vec_t }}(self.0);
v[index] = f32::from_bits(MASK[value as usize]);
*self = Self(v.0);
self.0 = v.0;
{% endif %}
}

Expand Down
10 changes: 6 additions & 4 deletions src/bool/neon/bvec3a.rs
Original file line number Diff line number Diff line change
Expand Up @@ -106,10 +106,12 @@ impl BVec3A {
/// Panics if `index` is greater than 2.
#[inline]
pub fn set(&mut self, index: usize, value: bool) {
use crate::Vec3A;
let mut v = Vec3A(self.0);
v[index] = f32::from_bits(MASK[value as usize]);
*self = Self(v.0);
self.0 = match index {
0 => unsafe { vsetq_lane_u32(MASK[value as usize], self.0, 0) },
1 => unsafe { vsetq_lane_u32(MASK[value as usize], self.0, 1) },
2 => unsafe { vsetq_lane_u32(MASK[value as usize], self.0, 2) },
_ => panic!("index out of bounds"),
}
}

#[inline]
Expand Down
11 changes: 7 additions & 4 deletions src/bool/neon/bvec4a.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,10 +112,13 @@ impl BVec4A {
/// Panics if `index` is greater than 3.
#[inline]
pub fn set(&mut self, index: usize, value: bool) {
use crate::Vec4;
let mut v = Vec4(self.0);
v[index] = f32::from_bits(MASK[value as usize]);
*self = Self(v.0);
self.0 = match index {
0 => unsafe { vsetq_lane_u32(MASK[value as usize], self.0, 0) },
1 => unsafe { vsetq_lane_u32(MASK[value as usize], self.0, 1) },
2 => unsafe { vsetq_lane_u32(MASK[value as usize], self.0, 2) },
3 => unsafe { vsetq_lane_u32(MASK[value as usize], self.0, 3) },
_ => panic!("index out of bounds"),
}
}

#[inline]
Expand Down
2 changes: 1 addition & 1 deletion src/bool/sse2/bvec3a.rs
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ impl BVec3A {
use crate::Vec3A;
let mut v = Vec3A(self.0);
v[index] = f32::from_bits(MASK[value as usize]);
*self = Self(v.0);
self.0 = v.0;
}

#[inline]
Expand Down
2 changes: 1 addition & 1 deletion src/bool/sse2/bvec4a.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ impl BVec4A {
use crate::Vec4;
let mut v = Vec4(self.0);
v[index] = f32::from_bits(MASK[value as usize]);
*self = Self(v.0);
self.0 = v.0;
}

#[inline]
Expand Down
2 changes: 1 addition & 1 deletion src/bool/wasm32/bvec3a.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ impl BVec3A {
use crate::Vec3A;
let mut v = Vec3A(self.0);
v[index] = f32::from_bits(MASK[value as usize]);
*self = Self(v.0);
self.0 = v.0;
}

#[inline]
Expand Down
2 changes: 1 addition & 1 deletion src/bool/wasm32/bvec4a.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ impl BVec4A {
use crate::Vec4;
let mut v = Vec4(self.0);
v[index] = f32::from_bits(MASK[value as usize]);
*self = Self(v.0);
self.0 = v.0;
}

#[inline]
Expand Down
12 changes: 9 additions & 3 deletions src/f32/neon/vec3a.rs
Original file line number Diff line number Diff line change
Expand Up @@ -272,14 +272,18 @@ impl Vec3A {
/// In other words, this computes `self.x + self.y + ..`.
#[inline]
#[must_use]
pub fn element_sum(self) -> f32 {}
pub fn element_sum(self) -> f32 {
self.x + self.y + self.z
}

/// Returns the product of all elements of `self`.
///
/// In other words, this computes `self.x * self.y * ..`.
#[inline]
#[must_use]
pub fn element_product(self) -> f32 {}
pub fn element_product(self) -> f32 {
self.x * self.y * self.z
}

/// Returns a vector mask containing the result of a `==` comparison for each element of
/// `self` and `rhs`.
Expand Down Expand Up @@ -664,7 +668,9 @@ impl Vec3A {
/// always truncated towards zero.
#[inline]
#[must_use]
pub fn trunc(self) -> Self {}
pub fn trunc(self) -> Self {
Self(unsafe { vrndq_f32(self.0) })
}

/// Returns a vector containing the fractional part of the vector as `self - self.trunc()`.
///
Expand Down
12 changes: 9 additions & 3 deletions src/f32/neon/vec4.rs
Original file line number Diff line number Diff line change
Expand Up @@ -259,14 +259,18 @@ impl Vec4 {
/// In other words, this computes `self.x + self.y + ..`.
#[inline]
#[must_use]
pub fn element_sum(self) -> f32 {}
pub fn element_sum(self) -> f32 {
self.x + self.y + self.z + self.w
}

/// Returns the product of all elements of `self`.
///
/// In other words, this computes `self.x * self.y * ..`.
#[inline]
#[must_use]
pub fn element_product(self) -> f32 {}
pub fn element_product(self) -> f32 {
self.x * self.y * self.z * self.w
}

/// Returns a vector mask containing the result of a `==` comparison for each element of
/// `self` and `rhs`.
Expand Down Expand Up @@ -654,7 +658,9 @@ impl Vec4 {
/// always truncated towards zero.
#[inline]
#[must_use]
pub fn trunc(self) -> Self {}
pub fn trunc(self) -> Self {
Self(unsafe { vrndq_f32(self.0) })
}

/// Returns a vector containing the fractional part of the vector as `self - self.trunc()`.
///
Expand Down

0 comments on commit ecd93dc

Please sign in to comment.