Skip to content

Commit

Permalink
reflect changes in std::simd
Browse files Browse the repository at this point in the history
  • Loading branch information
kade committed Sep 17, 2022
1 parent 7957e1d commit 0a630e4
Showing 1 changed file with 11 additions and 2 deletions.
13 changes: 11 additions & 2 deletions src/simd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@ use std::simd::Simd;
#[cfg(target_feature="avx")]
use std::arch::x86_64::*;

#[cfg(not(target_feature="avx"))]
use std::simd::{SimdFloat, SimdPartialOrd};

// Number of single-precision floating-point numbers per vector
#[cfg(target_feature="avx")]
pub const LANES : usize = 8;
Expand All @@ -24,6 +27,7 @@ pub(crate) use simd_load;
// https://stackoverflow.com/questions/13879609
// https://stackoverflow.com/questions/41303780
#[cfg(target_feature="avx")]
#[inline]
pub fn horizontal_sum(x : Simd<f32, 8>) -> f32
{
unsafe {
Expand All @@ -40,6 +44,7 @@ pub fn horizontal_sum(x : Simd<f32, 8>) -> f32
}

#[cfg(not(target_feature="avx"))]
#[inline]
pub fn horizontal_sum(x : Simd<f32, LANES>) -> f32
{
return x.reduce_sum();
Expand Down Expand Up @@ -80,6 +85,7 @@ pub fn horizontal_sum(x : Simd<f32, LANES>) -> f32
// I'm not certain.
//
#[cfg(target_feature="avx")]
#[inline]
pub fn relu_ps(x : Simd<f32, 8>) -> Simd<f32, 8>
{
let x = __m256::from(x);
Expand All @@ -88,13 +94,15 @@ pub fn relu_ps(x : Simd<f32, 8>) -> Simd<f32, 8>
}

#[cfg(not(target_feature="avx"))]
#[inline]
pub fn relu_ps(x : Simd<f32, LANES>) -> Simd<f32, LANES>
{
return x.max(x * Simd::splat(0.03125));
return x.simd_max(x * Simd::splat(0.03125));
}

// The same caveat in re broadcast/set1 as above applies here.
#[cfg(target_feature="avx")]
#[inline]
pub fn d_relu_ps(x : Simd<f32, 8>) -> Simd<f32, 8>
{
unsafe {
Expand All @@ -109,7 +117,8 @@ pub fn d_relu_ps(x : Simd<f32, 8>) -> Simd<f32, 8>
}

#[cfg(not(target_feature="avx"))]
#[inline]
pub fn d_relu_ps(x : Simd<f32, LANES>) -> Simd<f32, LANES>
{
return x.lanes_lt(Simd::splat(0.0)).select(Simd::splat(0.03125), Simd::splat(1.0));
return x.simd_lt(Simd::splat(0.0)).select(Simd::splat(0.03125), Simd::splat(1.0));
}

0 comments on commit 0a630e4

Please sign in to comment.