Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 105 additions & 4 deletions libc/src/__support/math/rsqrtf16.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,12 +58,11 @@ LIBC_INLINE static constexpr float16 rsqrtf16(float16 x) {
return FPBits::quiet_nan().get_val();
}

// x = +inf => rsqrt(x) = 0
return FPBits::zero().get_val();
// x = +inf => rsqrt(x) = +0
return FPBits::zero(xbits.sign()).get_val();
}

// TODO: add integer based implementation when LIBC_TARGET_CPU_HAS_FPU_FLOAT
// is not defined
#ifdef LIBC_TARGET_CPU_HAS_FPU_FLOAT
float result = 1.0f / fputil::sqrt<float>(fputil::cast<float>(x));

// Targeted post-corrections to ensure correct rounding in half for specific
Expand All @@ -76,6 +75,108 @@ LIBC_INLINE static constexpr float16 rsqrtf16(float16 x) {
}

return fputil::cast<float16>(result);

#else
// Range reduction:
// x can be expressed as m*2^e, with mantissa m and integer exponent e
// rsqrtf16(x) = rsqrtf16(m*2^e)
// rsqrtf16(m*2^e) = 1/sqrt(m) * 1/sqrt(2^e) = 1/sqrt(m) * 1/2^(e/2)
// 1/sqrt(m) * 1/2^(e/2) = 1/sqrt(m) * 2^(-e/2)

// Compute reduction directly from half bits to avoid frexp/ldexp overhead.
int exponent = 0;
int signifcand = 0; // same as mantissa, but int
uint16_t eh = static_cast<uint16_t>((x_abs >> 10) & 0x1F);
uint16_t frac = static_cast<uint16_t>(x_abs & 0x3FF);

int result;
if (eh != 0) {
// ((2^-1 + frac/2^11) * 2) * 2^(eh-15)

// Normal: x = (1 + frac/2^10) * 2^(eh-15) = ((0.5 + frac/2^11) * 2) *
// 2^(eh-15)
// => mantissa in [0.5,1): m = 0.5 + frac/2^11, exponent = (eh - 15) + 1 =
// eh - 14
exponent = static_cast<int>(eh) - 14;
mantissa = 0.5f + static_cast<float>(frac) * 0x1.0p-11f;
} else {
// Subnormal: x = (frac/2^10) * 2^(1-15) = frac * 2^-24.
// Normalize frac so that bit 9 becomes 1; then mantissa m = (frac <<
// t)/2^10 ∈ [0.5,1) and exponent E = -14 - t so that x = m * 2^E.
if (LIBC_UNLIKELY(frac == 0)) {
// Already handled by the zero check above; kept as a safeguard.
return FPBits::inf(Sign::POS).get_val();
}
int shifts = 0;
while ((frac & 0x200u) == 0u) { // bring into [0x200, 0x3FF]
frac <<= 1;
++shifts;
}
exponent = -14 - shifts;
mantissa = static_cast<float>(frac) * 0x1.0p-10f;
}

float result = 0.0f;
int exp_floored = -(exponent >> 1);

if (mantissa == 0.5f) {
// When mantissa is 0.5f, x was a power of 2 (or subnormal that normalizes
// this way). 1/sqrt(0.5f) = sqrt(2.0f).
// If exponent is odd (exponent = 2k + 1):
// rsqrt(x) = (1/sqrt(0.5)) * 2^(-(2k+1)/2) = sqrt(2) * 2^(-k-0.5)
// = sqrt(2) * 2^(-k) * (1/sqrt(2)) = 2^(-k)
// exp_floored = -((2k+1)>>1) = -(k) = -k
// So result = ldexp(1.0f, exp_floored)
// If exponent is even (exponent = 2k):
// rsqrt(x) = (1/sqrt(0.5)) * 2^(-2k/2) = sqrt(2) * 2^(-k)
// exp_floored = -((2k)>>1) = -(k) = -k
// So result = ldexp(sqrt(2.0f), exp_floored)
if (exponent & 1) {
result = fputil::ldexp(1.0f, exp_floored);
} else {
constexpr float SQRT_2_F = 0x1.6a09e6p0f; // sqrt(2.0f)
result = fputil::ldexp(SQRT_2_F, exp_floored);
}
} else {
// Degree-4 minimax polynomial (single-precision coefficients) generated
// with Sollya: P = fpminimax(1/sqrt(x), 4,
// [|single,single,single,single,single|], [0.5;1])
float y = fputil::polyeval(mantissa,
0x1.771256p1f, // c0
-0x1.5e7c4ap2f, // c1
0x1.b3851cp2f, // c2
-0x1.1a27ep2f, // c3
0x1.265c66p0f); // c4

// Newton-Raphson iteration in float (use multiply_add to leverage FMA when
// available):
float y2 = y * y;
float factor = fputil::multiply_add(-0.5f * mantissa, y2, 1.5f);
y = y * factor;

result = fputil::ldexp(y, exp_floored);
if (exponent & 1) {
constexpr float ONE_OVER_SQRT2 = 0x1.6a09e6p-1f; // 1/sqrt(2)
result *= ONE_OVER_SQRT2;
}

// Targeted post-correction: for the specific half-precision mantissa
// pattern M == 0x011F we observe a consistent -1 ULP bias across exponents.
// Apply a tiny upward nudge to cross the rounding boundary in all modes.
const uint16_t half_mantissa = static_cast<uint16_t>(x_abs & 0x3ff);
if (half_mantissa == 0x011F) {
// Nudge up to fix consistent -1 ULP at that mantissa boundary
result = fputil::multiply_add(result, 0x1.0p-21f,
result); // result *= (1 + 2^-21)
} else if (half_mantissa == 0x0313) {
// Nudge down to fix +1 ULP under upward rounding at this mantissa
// boundary
result = fputil::multiply_add(result, -0x1.0p-21f,
result); // result *= (1 - 2^-21)
}
}
return fputil::cast<float16>(result);
#endif
}

} // namespace math
Expand Down
Loading