Skip to content

Commit

Permalink
fma sqrt
Browse files Browse the repository at this point in the history
  • Loading branch information
burrbull committed Aug 9, 2022
1 parent cb1af0f commit c6fdb3d
Show file tree
Hide file tree
Showing 4 changed files with 208 additions and 223 deletions.
70 changes: 34 additions & 36 deletions src/f32x.rs
Original file line number Diff line number Diff line change
Expand Up @@ -964,43 +964,41 @@ macro_rules! impl_math_f32 {
/// This function may return infinity with a correct sign if the absolute value of the correct return value is greater than `1e+33`.
/// The error bounds of the returned value is `max(0.500_01 ULP, f32::MIN_POSITIVE)`.
pub fn fmaf(mut x: F32x, mut y: F32x, mut z: F32x) -> F32x {
/*
#ifdef ENABLE_FMA_SP
return vfma_vf_vf_vf_vf(x, y, z);
#else
*/
let h2 = x * y + z;
let mut q = ONE;
let o = h2.abs().simd_lt(F32x::splat(1e-38));
const C0: F32x = F1_25X;
let c1: F32x = C0 * C0;
let c2: F32x = c1 * c1;
{
x = o.select(x * c1, x);
y = o.select(y * c1, y);
z = o.select(z * c2, z);
q = o.select(ONE / c2, q);
if cfg!(target_feature = "fma") {
x.mla(y, z)
} else {
let h2 = x * y + z;
let mut q = ONE;
let o = h2.abs().simd_lt(F32x::splat(1e-38));
const C0: F32x = F1_25X;
let c1: F32x = C0 * C0;
let c2: F32x = c1 * c1;
{
x = o.select(x * c1, x);
y = o.select(y * c1, y);
z = o.select(z * c2, z);
q = o.select(ONE / c2, q);
}
let o = h2.abs().simd_gt(F32x::splat(1e+38));
{
x = o.select(x * (ONE / c1), x);
y = o.select(y * (ONE / c1), y);
z = o.select(z * (ONE / c2), z);
q = o.select(c2, q);
}
let d = x.mul_as_doubled(y) + z;
let ret = (x.simd_eq(ZERO) | y.simd_eq(ZERO)).select(z, F32x::from(d));
let mut o = z.is_infinite();
o = !x.is_infinite() & o;
o = !x.is_nan() & o;
o = !y.is_infinite() & o;
o = !y.is_nan() & o;
let h2 = o.select(z, h2);

o = h2.is_infinite() | h2.is_nan();

o.select(h2, ret * q)
}
let o = h2.abs().simd_gt(F32x::splat(1e+38));
{
x = o.select(x * (ONE / c1), x);
y = o.select(y * (ONE / c1), y);
z = o.select(z * (ONE / c2), z);
q = o.select(c2, q);
}
let d = x.mul_as_doubled(y) + z;
let ret = (x.simd_eq(ZERO) | y.simd_eq(ZERO)).select(z, F32x::from(d));
let mut o = z.is_infinite();
o = !x.is_infinite() & o;
o = !x.is_nan() & o;
o = !y.is_infinite() & o;
o = !y.is_nan() & o;
let h2 = o.select(z, h2);

o = h2.is_infinite() | h2.is_nan();

o.select(h2, ret * q)
// #endif
}

/// Square root function
Expand Down
119 changes: 60 additions & 59 deletions src/f32x/u05_impl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,72 +95,73 @@ macro_rules! impl_math_f32_u05 {
///
/// The error bound of the returned value is `0.5001 ULP`.
pub fn sqrtf(d: F32x) -> F32x {
/*
#if defined(ENABLE_FMA_SP)
vfloat q, w, x, y, z;
if cfg!(target_feature = "fma") {
let d = d.simd_lt(ZERO).select(NAN, d);

d = vsel_vf_vo_vf_vf(vlt_vo_vf_vf(d, vcast_vf_f(0)), vcast_vf_f(SLEEF_NANf), d);
let o = d.simd_lt(F32x::splat(5.293_955_920_339_377_e-23));
let d = o.select(d * F32x::splat(1.888_946_593_147_858_e+22), d);
let q = o.select(F32x::splat(7.275_957_614_183_426_e-12), F32x::splat(1.));

vopmask o = vlt_vo_vf_vf(d, vcast_vf_f(5.2939559203393770e-23f));
d = vsel_vf_vo_vf_vf(o, vmul_vf_vf_vf(d, vcast_vf_f(1.8889465931478580e+22f)), d);
q = vsel_vf_vo_vf_vf(o, vcast_vf_f(7.2759576141834260e-12f), vcast_vf_f(1.0f));
y = vreinterpret_vf_vi2(vsub_vi2_vi2_vi2(vcast_vi2_i(0x5f3759df), vsrl_vi2_vi2_i(vreinterpret_vi2_vf(d), 1)));
x = vmul_vf_vf_vf(d, y); w = vmul_vf_vf_vf(vcast_vf_f(0.5), y);
y = vfmanp_vf_vf_vf_vf(x, w, vcast_vf_f(0.5));
x = vfma_vf_vf_vf_vf(x, y, x); w = vfma_vf_vf_vf_vf(w, y, w);
y = vfmanp_vf_vf_vf_vf(x, w, vcast_vf_f(0.5));
x = vfma_vf_vf_vf_vf(x, y, x); w = vfma_vf_vf_vf_vf(w, y, w);
y = vfmanp_vf_vf_vf_vf(x, w, vcast_vf_f(1.5)); w = vadd_vf_vf_vf(w, w);
w = vmul_vf_vf_vf(w, y);
x = vmul_vf_vf_vf(w, d);
y = vfmapn_vf_vf_vf_vf(w, d, x); z = vfmanp_vf_vf_vf_vf(w, x, vcast_vf_f(1));
z = vfmanp_vf_vf_vf_vf(w, y, z); w = vmul_vf_vf_vf(vcast_vf_f(0.5), x);
w = vfma_vf_vf_vf_vf(w, z, y);
w = vadd_vf_vf_vf(w, x);
w = vmul_vf_vf_vf(w, q);
w = vsel_vf_vo_vf_vf(vor_vo_vo_vo(veq_vo_vf_vf(d, vcast_vf_f(0)),
veq_vo_vf_vf(d, vcast_vf_f(SLEEF_INFINITYf))), d, w);
w = vsel_vf_vo_vf_vf(vlt_vo_vf_vf(d, vcast_vf_f(0)), vcast_vf_f(SLEEF_NANf), w);
return w;
#else
*/

let d = d.simd_lt(ZERO).select(NAN, d);

let o = d.simd_lt(F32x::splat(5.293_955_920_339_377_e-23));
let d = o.select(d * F32x::splat(1.888_946_593_147_858_e+22), d);
let q = o.select(F32x::splat(7.275_957_614_183_426_e-12 * 0.5), HALF);

let o = d.simd_gt(F32x::splat(1.844_674_407_370_955_2_e+19));
let d = o.select(d * F32x::splat(5.421_010_862_427_522_e-20), d);
let q = o.select(F32x::splat(4_294_967_296.0 * 0.5), q);
let mut y = F32x::from_bits(
(I32x::splat(0x5f3759df) - (d.to_bits().cast::<i32>() >> I32x::splat(1)))
.cast(),
);

let mut x = F32x::from_bits(
(I32x::splat(0x_5f37_5a86)
- ((d + F32x::splat(1e-45)).to_bits() >> U32x::splat(1)).cast())
.cast(),
);
let mut x = d * y;
let mut w = HALF * y;
y = x.neg_mul_add(w, HALF);
x = x.mla(y, x);
w = w.mla(y, w);
y = x.neg_mul_add(w, HALF);
x = x.mla(y, x);
w = w.mla(y, w);

y = x.neg_mul_add(w, F32x::splat(1.5));
w += w;
w *= y;
x = w * d;
y = w.mul_sub(d, x);
let mut z = w.neg_mul_add(x, ONE);

z = w.neg_mul_add(y, z);
w = HALF * x;
w = w.mla(z, y);
w += x;

w *= q;

w = (d.simd_eq(ZERO) | d.simd_eq(INFINITY)).select(d, w);

d.simd_lt(ZERO).select(NAN, w)
} else {
let d = d.simd_lt(ZERO).select(NAN, d);

let o = d.simd_lt(F32x::splat(5.293_955_920_339_377_e-23));
let d = o.select(d * F32x::splat(1.888_946_593_147_858_e+22), d);
let q = o.select(F32x::splat(7.275_957_614_183_426_e-12 * 0.5), HALF);

let o = d.simd_gt(F32x::splat(1.844_674_407_370_955_2_e+19));
let d = o.select(d * F32x::splat(5.421_010_862_427_522_e-20), d);
let q = o.select(F32x::splat(4_294_967_296.0 * 0.5), q);

let mut x = F32x::from_bits(
(I32x::splat(0x_5f37_5a86)
- ((d + F32x::splat(1e-45)).to_bits() >> U32x::splat(1)).cast())
.cast(),
);

x *= F32x::splat(1.5) - HALF * d * x * x;
x *= F32x::splat(1.5) - HALF * d * x * x;
x *= F32x::splat(1.5) - HALF * d * x * x;
x *= d;
x *= F32x::splat(1.5) - HALF * d * x * x;
x *= F32x::splat(1.5) - HALF * d * x * x;
x *= F32x::splat(1.5) - HALF * d * x * x;
x *= d;

let d2 = (d + x.mul_as_doubled(x)) * x.recip_as_doubled();
let d2 = (d + x.mul_as_doubled(x)) * x.recip_as_doubled();

x = F32x::from(d2) * q;
x = F32x::from(d2) * q;

x = d.simd_eq(INFINITY).select(INFINITY, x);
d.simd_eq(ZERO).select(d, x)
// #endif
x = d.simd_eq(INFINITY).select(INFINITY, x);
d.simd_eq(ZERO).select(d, x)
}
}

#[test]
Expand Down
117 changes: 50 additions & 67 deletions src/f64x.rs
Original file line number Diff line number Diff line change
Expand Up @@ -874,21 +874,6 @@ macro_rules! impl_math_f64 {
t
}

#[inline]
fn splat2i(i0: i32, i1: i32) -> I64x {
    // Pack `i0` into the upper 32 bits and `i1` into the lower 32 bits of
    // every lane. The low half is zero-extended (`as u32` first): a plain
    // `i1 as i64` would sign-extend a negative `i1` and corrupt the upper
    // half, e.g. `splat2i(0, -1)` must be 0x0000_0000_ffff_ffff, not -1.
    I64x::splat(((i0 as i64) << 32) | i64::from(i1 as u32))
}

#[inline]
fn splat2u(u0: u32, u1: u32) -> I64x {
    // `u0` fills the upper 32 bits, `u1` the lower 32 bits; the halves do
    // not overlap, so OR-ing them is the same as adding them.
    let packed: u64 = (u64::from(u0) << 32) | u64::from(u1);
    // Reinterpret the 64-bit pattern as signed for the lane type.
    I64x::splat(packed as i64)
}

#[inline]
fn splat2uu(u0: u32, u1: u32) -> U64x {
    // `u0` fills the upper 32 bits, `u1` the lower 32 bits of every lane.
    let upper = u64::from(u0) << 32;
    let lower = u64::from(u1);
    U64x::splat(upper | lower)
}

/// Absolute value
#[inline]
pub fn fabs(x: F64x) -> F64x {
Expand Down Expand Up @@ -991,21 +976,21 @@ macro_rules! impl_math_f64 {
let mut xi2: I64x = x.to_bits().cast();
let c = x.is_sign_negative() ^ y.simd_ge(x);

let mut t = (xi2 ^ splat2u(0x_7fff_ffff, 0x_ffff_ffff)) + splat2i(0, 1);
t += swap_upper_lower(splat2i(0, 1) & t.simd_eq(splat2i(-1, 0)).to_int());
let mut t = (xi2 ^ I64x::splat(0x_7fff_ffff_ffff_ffff_u64 as _)) + I64x::splat(1);
t += swap_upper_lower(I64x::splat(1) & t.simd_eq(I64x::splat(0x_ffff_ffff_0000_0000_u64 as _)).to_int());
xi2 = c.select(F64x::from_bits(t.cast()), F64x::from_bits(xi2.cast())).to_bits().cast();

xi2 -= (x.simd_ne(y).to_int().cast() & splat2uu(0, 1)).cast();
xi2 -= (x.simd_ne(y).to_int().cast() & U64x::splat(1)).cast();

xi2 = x.simd_ne(y).select(
F64x::from_bits((
xi2 + swap_upper_lower(splat2i(0, -1) & xi2.simd_eq(splat2i(0, -1)).to_int())
xi2 + swap_upper_lower(I64x::splat(0x_ffff_ffff_u64 as _) & xi2.simd_eq(I64x::splat(0x_ffff_ffff_u64 as _)).to_int())
).cast()),
F64x::from_bits(xi2.cast()),
).to_bits().cast();

let mut t = (xi2 ^ splat2u(0x_7fff_ffff, 0x_ffff_ffff)) + splat2i(0, 1);
t += swap_upper_lower(splat2i(0, 1) & t.simd_eq(splat2i(-1, 0)).to_int());
let mut t = (xi2 ^ I64x::splat(0x_7fff_ffff_ffff_ffff_u64 as _)) + I64x::splat(1);
t += swap_upper_lower(I64x::splat(1) & t.simd_eq(I64x::splat(0x_ffff_ffff_0000_0000_u64 as _)).to_int());
xi2 = c.select(F64x::from_bits(t.cast()), F64x::from_bits(xi2.cast())).to_bits().cast();

let mut ret = F64x::from_bits(xi2.cast());
Expand Down Expand Up @@ -1042,8 +1027,8 @@ macro_rules! impl_math_f64 {
.select(x * D1_63X, x);

let mut xm = x.to_bits();
xm &= splat2uu(!0x_7ff0_0000, !0);
xm |= splat2uu(0x_3fe0_0000, 0);
xm &= U64x::splat(0x_800f_ffff_ffff_ffff);
xm |= U64x::splat(0x_3fe0_0000 << 32);

let ret = F64x::from_bits(xm);

Expand All @@ -1070,43 +1055,41 @@ macro_rules! impl_math_f64 {
/// This function may return infinity with a correct sign if the absolute value of the correct return value is greater than `1e+300`.
/// The error bounds of the returned value is `max(0.500_01 ULP, f64::MIN_POSITIVE)`.
pub fn fma(mut x: F64x, mut y: F64x, mut z: F64x) -> F64x {
    if cfg!(target_feature = "fma") {
        // Built with the `fma` target feature: defer to `mla`.
        // (The C original called `vfma_vd_vd_vd_vd` here; presumably `mla`
        // is a fused multiply-add in this configuration — confirm against
        // its definition.)
        x.mla(y, z)
    } else {
        // Software emulation of `x * y + z`.
        //
        // `h2` is the plainly rounded result; it is only used to choose a
        // scaling regime and to pass infinities / NaNs through unchanged.
        let mut h2 = x * y + z;
        let mut q = ONE;

        const C0: F64x = D1_54X;
        let c1: F64x = C0 * C0;
        let c2: F64x = c1 * c1;

        // Very small result: scale x and y up by c1 (so the product scales
        // by c2), scale z by c2, and remember to divide by c2 at the end.
        let o = h2.abs().simd_lt(F64x::splat(1e-300));
        {
            x = o.select(x * c1, x);
            y = o.select(y * c1, y);
            z = o.select(z * c2, z);
            q = o.select(ONE / c2, q);
        }

        // Very large result: scale down by the same factors and multiply
        // back by c2 at the end.
        let o = h2.abs().simd_gt(F64x::splat(1e+300));
        {
            x = o.select(x * (ONE / c1), x);
            y = o.select(y * (ONE / c1), y);
            z = o.select(z * (ONE / c2), z);
            q = o.select(c2, q);
        }

        // Product in doubled (two-double) precision, then add z and collapse
        // the pair back into a single value.
        let d = x.mul_as_doubled(y) + z;
        // A zero factor makes the product vanish: the answer is just z.
        let ret = (x.simd_eq(ZERO) | y.simd_eq(ZERO)).select(z, d.0 + d.1);

        // If z is infinite while x and y are both finite (neither infinite
        // nor NaN), the result is z itself.
        let mut o = z.is_infinite();
        o = !x.is_infinite() & o;
        o = !x.is_nan() & o;
        o = !y.is_infinite() & o;
        o = !y.is_nan() & o;
        h2 = o.select(z, h2);

        // Infinite / NaN results come from the plain computation; everything
        // else is the doubled-precision value with the scaling undone.
        let o = h2.is_infinite() | h2.is_nan();

        o.select(h2, ret * q)
    }
}

/// Square root function
Expand All @@ -1128,7 +1111,7 @@ macro_rules! impl_math_f64 {
#[inline]
fn toward0(x: F64x) -> F64x {
    // Returns nextafter(x, 0).
    // Adding -1 (all bits set) to the raw bit pattern decrements it by one,
    // which shrinks the magnitude by one step regardless of the sign bit.
    let t = F64x::from_bits(x.to_bits() + I64x::splat(-1).cast());
    // Zero has no neighbour closer to zero; such lanes yield ZERO.
    x.simd_eq(ZERO).select(ZERO, t)
}

Expand Down Expand Up @@ -1159,9 +1142,9 @@ macro_rules! impl_math_f64 {
for _ in 0..21 {
// ceil(log2(DBL_MAX) / 52)
let mut q = trunc_positive(toward0(r.0) * rd);
/* #ifndef ENABLE_FMA_DP
q = vreinterpret_vd_vm(vand_vm_vm_vm(vreinterpret_vm_vd(q), vcast_vm_i_i(0xffffffff, 0xfffffffe)));
#endif */
if cfg!(target_feature = "fma") {
q = F64x::from_bits(q.to_bits() & U64x::splat(0xffff_ffff_ffff_fffe));
}
q = ((F64x::splat(3.) * d).simd_gt(r.0) & r.0.simd_ge(d)).select(F64x::splat(2.), q);
q = ((d + d).simd_gt(r.0) & r.0.simd_ge(d)).select( ONE, q);
r = (r + q.mul_as_doubled(-d)).normalize();
Expand Down Expand Up @@ -1206,9 +1189,9 @@ macro_rules! impl_math_f64 {

for _ in 0..21 { // ceil(log2(DBL_MAX) / 52)
let mut q = rintk2(r.0 * rd);
/*#ifndef ENABLE_FMA_DP
q = vreinterpret_vd_vm(vand_vm_vm_vm(vreinterpret_vm_vd(q), vcast_vm_u64(UINT64_C(0xfffffffffffffffe))));
#endif*/
if cfg!(target_feature = "fma") {
q = F64x::from_bits(q.to_bits() & U64x::splat(0xffff_ffff_ffff_fffe));
}
q = r.0.abs().simd_lt(d * F64x::splat(1.5)).select(ONE.mul_sign(r.0), q);
q = (r.0.abs().simd_lt(d * HALF) | (!qisodd & r.0.abs().simd_eq(d * HALF)))
.select(ZERO, q);
Expand Down
Loading

0 comments on commit c6fdb3d

Please sign in to comment.