Merge pull request #18289 from fp64/sse2-vfpu-dot

Add SSE2 version of vfpu_dot

hrydgard committed Oct 3, 2023
2 parents cd0b4fc + 49ac4c6 commit 7c184a7

Showing 1 changed file with 174 additions and 1 deletion.
Core/MIPS/MIPSVFPUUtils.cpp
@@ -598,7 +598,8 @@ float Float16ToFloat32(unsigned short l)
return f;
}

-float vfpu_dot(const float a[4], const float b[4]) {
+// Reference C++ version.
+static float vfpu_dot_cpp(const float a[4], const float b[4]) {
static const int EXTRA_BITS = 2;
float2int result;
float2int src[2];
@@ -711,6 +712,178 @@ float vfpu_dot(const float a[4], const float b[4]) {
return result.f;
}

#if defined(__SSE2__)

#include <emmintrin.h>

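// Computes, per 32-bit lane, the high half of the unsigned 32x32->64-bit
// product. Each lane matches this scalar sketch (mulhi32 is an
// illustrative name, not part of this commit):
//
//   static inline uint32_t mulhi32(uint32_t a, uint32_t b) {
//       return (uint32_t)(((uint64_t)a * (uint64_t)b) >> 32);
//   }
//
// _mm_mul_epu32 only multiplies the even lanes (0 and 2), so the odd
// lanes are shuffled into even positions and multiplied separately; the
// high halves of the four 64-bit products are then interleaved back
// together.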
static inline __m128i mulhi32x4(__m128i a, __m128i b) {
__m128i m02 = _mm_mul_epu32(a, b);
__m128i m13 = _mm_mul_epu32(
_mm_shuffle_epi32(a, _MM_SHUFFLE(3, 3, 1, 1)),
_mm_shuffle_epi32(b, _MM_SHUFFLE(3, 3, 1, 1)));
__m128i m = _mm_unpacklo_epi32(
_mm_shuffle_epi32(m02, _MM_SHUFFLE(3, 2, 3, 1)),
_mm_shuffle_epi32(m13, _MM_SHUFFLE(3, 2, 3, 1)));
return m;
}

// Values of rounding_mode:
// -1 - detect at runtime
//  0 - assume round-to-nearest-ties-to-even
//  1 - do the rounding manually in integer math
template<int rounding_mode = -1>
static float vfpu_dot_sse2(const float a[4], const float b[4]) {
static const int EXTRA_BITS = 2;

bool is_default_rounding_mode = (rounding_mode == 0);
if (rounding_mode == -1) {
volatile float test05 = 5.9604644775390625e-08f; // 0.5*2^-23
volatile float test15 = 1.78813934326171875e-07f; // 1.5*2^-23
const float res15 = 1.0000002384185791015625f; // 1+2^-22
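// Why these probes work: 1 + 0.5*2^-23 is exactly halfway between 1
// and 1+2^-23, and round-to-nearest-ties-to-even resolves the tie down
// to 1.0 (even mantissa); 1 + 1.5*2^-23 is exactly halfway between
// 1+2^-23 and 1+2^-22 and resolves up to 1+2^-22 (again even). Every
// other IEEE rounding mode fails at least one of the two checks.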
test05 += 1.0f;
test15 += 1.0f;
is_default_rounding_mode = (test05 == 1.0f && test15 == res15);
}
__m128 A = _mm_loadu_ps(a);
__m128 B = _mm_loadu_ps(b);
// Extract exponents.
__m128 exp_mask = _mm_castsi128_ps(_mm_set1_epi32(0x7F800000));
__m128 eA = _mm_and_ps(A, exp_mask);
__m128 eB = _mm_and_ps(B, exp_mask);
__m128i exps = _mm_srli_epi32(_mm_add_epi32(
_mm_castps_si128(eA),
_mm_castps_si128(eB)), 23);
// Find maximum exponent, stored as float32 in [1;2),
// so we can use _mm_max_ps() with normal arguments.
__m128 t = _mm_or_ps(_mm_castsi128_ps(exps), _mm_set1_ps(1.0f));
t = _mm_max_ps(t, _mm_castsi128_ps(_mm_shuffle_epi32(_mm_castps_si128(t), _MM_SHUFFLE(2, 3, 0, 1))));
t = _mm_max_ps(t, _mm_castsi128_ps(_mm_shuffle_epi32(_mm_castps_si128(t), _MM_SHUFFLE(1, 0, 3, 2))));
t = _mm_max_ps(t, _mm_castsi128_ps(_mm_set1_epi32(0x3F80007F)));
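// (0x3F80007F is 1.0f with 127 in its low bits; this clamps mexp
// to at least 127, i.e. max_exp below to at least 0.)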
int32_t mexp = _mm_cvtsi128_si32(_mm_castps_si128(t)) & 511;
// NOTE: mexp is doubly-biased, same for exps.
int32_t max_exp = mexp - 127;
// Fall back to the reference version on exponent overflow or
// non-finite inputs (x - x is NaN for Inf/NaN, so the self-compare
// below detects them).
__m128 finiteA = _mm_sub_ps(A, A);
__m128 finiteB = _mm_sub_ps(B, B);
finiteA = _mm_cmpeq_ps(finiteA, finiteA);
finiteB = _mm_cmpeq_ps(finiteB, finiteB);
if (max_exp >= 255 || _mm_movemask_ps(_mm_and_ps(finiteA, finiteB)) != 15) return vfpu_dot_cpp(a, b);
// Extract significands.
__m128i mA = _mm_or_si128(_mm_and_si128(_mm_castps_si128(A), _mm_set1_epi32(0x007FFFFF)), _mm_set1_epi32(0x00800000));
__m128i mB = _mm_or_si128(_mm_and_si128(_mm_castps_si128(B), _mm_set1_epi32(0x007FFFFF)), _mm_set1_epi32(0x00800000));
// Multiply.
// NOTE: vfpu_dot does multiplication as
// ((x<<EXTRA_BITS)*(y<<EXTRA_BITS))>>(23+EXTRA_BITS),
// here we do (x*y)>>(23-EXTRA_BITS-1),
// which produces twice the result (neither expression
// overflows in our case). We need that because our
// variable-shift scheme (below) must shift by at least 1 bit.
static const int s = 32 - (23 - EXTRA_BITS - 1), s0 = s / 2, s1 = s - s0;
// We compute ((x*y)>>shift) as
// (((x*y)<<(32-shift))>>32), which we express as
// (((x<<s0)*(y<<s1))>>32) (neither shift overflows).
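// Concretely: s = 32 - 20 = 12, so s0 = s1 = 6. The 24-bit
// significands shifted left by 6 still fit in 32 bits, and
// mulhi32x4 returns ((x<<6)*(y<<6)) >> 32 = (x*y) >> 20.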
__m128i m = mulhi32x4(_mm_slli_epi32(mA, s0), _mm_slli_epi32(mB, s1));
// Shift according to max_exp. Since SSE2 doesn't have
// variable per-lane shifts, we multiply *again*,
// specifically, x>>y turns into (x<<(1<<(32-y)))>>32.
// We compute 1<<(32-y) using floating-point casts.
// NOTE: the cast for 1<<31 produces the correct value,
// since the _mm_cvttps_epi32 error code just happens
// to be 0x80000000.
// So (since we pre-multiplied m by 2), we need
// (m>>1)>>(mexp-exps),
// i.e. m>>(mexp+1-exps),
// i.e. (m<<(32-(mexp+1-exps)))>>32,
// i.e. (m<<(exps-(mexp-31)))>>32.
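// Scalar sketch of the shift-by-multiply trick (shr_via_mulhi is an
// illustrative name, not part of this commit; valid for 1 <= k <= 32):
//
//   static inline uint32_t shr_via_mulhi(uint32_t x, unsigned k) {
//       uint64_t mul = 1ull << (32 - k); // computed via float bits below
//       return (uint32_t)(((uint64_t)x * mul) >> 32);
//   }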
__m128i amounts = _mm_sub_epi32(exps, _mm_set1_epi32(mexp - 31));
// Clamp by 0. Both zero and negative amounts produce zero,
// since they correspond to right-shifting by 32 or more bits.
amounts = _mm_and_si128(amounts, _mm_cmpgt_epi32(amounts, _mm_set1_epi32(0)));
// Set up multipliers.
__m128i bits = _mm_add_epi32(_mm_set1_epi32(0x3F800000), _mm_slli_epi32(amounts, 23));
__m128i muls = _mm_cvttps_epi32(_mm_castsi128_ps(bits));
m = mulhi32x4(m, muls);
// Extract signs.
__m128i signs = _mm_cmpgt_epi32(
_mm_set1_epi32(0),
_mm_xor_si128(_mm_castps_si128(A), _mm_castps_si128(B)));
// Apply signs to m.
m = _mm_sub_epi32(_mm_xor_si128(m, signs), signs);
// Horizontal sum.
// See https://stackoverflow.com/questions/6996764/fastest-way-to-do-horizontal-sse-vector-sum-or-other-reduction
__m128i h64 = _mm_shuffle_epi32(m, _MM_SHUFFLE(1, 0, 3, 2));
__m128i s64 = _mm_add_epi32(h64, m);
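// _mm_shufflelo_epi16 with this pattern swaps the two 32-bit values
// in the low 64 bits; after the final add, only lane 0 matters.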
__m128i h32 = _mm_shufflelo_epi16(s64, _MM_SHUFFLE(1, 0, 3, 2));
__m128i s32 = _mm_add_epi32(s64, h32);
int32_t mant_sum = _mm_cvtsi128_si32(s32);

// The rest is scalar.
uint32_t sign_sum = 0;
if (mant_sum < 0) {
sign_sum = 0x80000000;
mant_sum = -mant_sum;
}

// Truncate off the extra bits now. We want to zero them for rounding purposes.
mant_sum >>= EXTRA_BITS;

if (mant_sum == 0 || max_exp <= 0) {
return 0.0f;
}

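// In the default-rounding path, the int-to-float conversion performs
// the round-to-nearest-even step in hardware; the result is then
// picked apart again. 0x96 is 150 = 127 + 23, the biased exponent of
// a float whose 24-bit significand is an integer in [2^23, 2^24).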
if (is_default_rounding_mode) {
float2int r;
r.f = (float)mant_sum;
mant_sum = (r.i & 0x007FFFFF) | 0x00800000;
max_exp += (r.i >> 23) - 0x96;
} else {
int8_t shift = (int8_t)clz32_nonzero(mant_sum) - 8;
if (shift < 0) {
// Round to even if we'd shift away a 0.5.
const uint32_t round_bit = 1 << (-shift - 1);
if ((mant_sum & round_bit) && (mant_sum & (round_bit << 1))) {
mant_sum += round_bit;
shift = (int8_t)clz32_nonzero(mant_sum) - 8;
} else if ((mant_sum & round_bit) && (mant_sum & (round_bit - 1))) {
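// Not a tie: bits below the round bit are set, so round up.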
mant_sum += round_bit;
shift = (int8_t)clz32_nonzero(mant_sum) - 8;
}
mant_sum >>= -shift;
max_exp += -shift;
} else {
mant_sum <<= shift;
max_exp -= shift;
}
_dbg_assert_msg_((mant_sum & 0x00800000) != 0, "Mantissa wrong: %08x", mant_sum);
}

if (max_exp >= 255) {
max_exp = 255;
mant_sum = 0;
} else if (max_exp <= 0) {
return 0.0f;
}

float2int result;
result.i = sign_sum | (max_exp << 23) | (mant_sum & 0x007FFFFF);
return result.f;
}

#endif // defined(__SSE2__)

float vfpu_dot(const float a[4], const float b[4]) {
#if defined(__SSE2__)
return vfpu_dot_sse2(a, b);
#else
return vfpu_dot_cpp(a, b);
#endif
}
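
// A consistency check one could run (illustrative sketch, not part of
// this commit): the SSE2 path is intended to be bit-exact against the
// reference, so comparing bit patterns via float2int should never fail.
//
//   static bool vfpu_dot_matches(const float a[4], const float b[4]) {
//   #if defined(__SSE2__)
//       float2int s, c;
//       s.f = vfpu_dot_sse2(a, b);
//       c.f = vfpu_dot_cpp(a, b);
//       return s.i == c.i;
//   #else
//       return true;
//   #endif
//   }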

//==============================================================================
// The code below attempts to exactly match behaviour of
// PSP's vrnd instructions. See investigation starting around
