Skip to content
Permalink
Browse files
Merge pull request #11321 from JosJuice/jitarm64-accurate-nans
JitArm64: Implement accurate NaNs
  • Loading branch information
lioncash committed Dec 7, 2022
2 parents 94faad0 + 06e60ac commit 000c6c4
Show file tree
Hide file tree
Showing 5 changed files with 403 additions and 127 deletions.
@@ -2173,6 +2173,12 @@ void ARM64FloatEmitter::EmitScalar2RegMisc(bool U, u32 size, u32 opcode, ARM64Re
(DecodeReg(Rn) << 5) | DecodeReg(Rd));
}

void ARM64FloatEmitter::EmitScalarPairwise(bool U, u32 size, u32 opcode, ARM64Reg Rd, ARM64Reg Rn)
{
Write32((1 << 30) | (U << 29) | (0b111100011 << 20) | (size << 22) | (opcode << 12) | (1 << 11) |
(DecodeReg(Rn) << 5) | DecodeReg(Rd));
}

void ARM64FloatEmitter::Emit2RegMisc(bool Q, bool U, u32 size, u32 opcode, ARM64Reg Rd, ARM64Reg Rn)
{
ASSERT_MSG(DYNA_REC, !IsSingle(Rd), "Singles are not supported!");
@@ -2985,6 +2991,28 @@ void ARM64FloatEmitter::FRSQRTE(ARM64Reg Rd, ARM64Reg Rn)
EmitScalar2RegMisc(1, IsDouble(Rd) ? 3 : 2, 0x1D, Rd, Rn);
}

// Scalar - pairwise
void ARM64FloatEmitter::FADDP(ARM64Reg Rd, ARM64Reg Rn)
{
EmitScalarPairwise(1, IsDouble(Rd), 0b01101, Rd, Rn);
}
void ARM64FloatEmitter::FMAXP(ARM64Reg Rd, ARM64Reg Rn)
{
EmitScalarPairwise(1, IsDouble(Rd), 0b01111, Rd, Rn);
}
void ARM64FloatEmitter::FMINP(ARM64Reg Rd, ARM64Reg Rn)
{
EmitScalarPairwise(1, IsDouble(Rd) ? 3 : 2, 0b01111, Rd, Rn);
}
void ARM64FloatEmitter::FMAXNMP(ARM64Reg Rd, ARM64Reg Rn)
{
EmitScalarPairwise(1, IsDouble(Rd), 0b01100, Rd, Rn);
}
void ARM64FloatEmitter::FMINNMP(ARM64Reg Rd, ARM64Reg Rn)
{
EmitScalarPairwise(1, IsDouble(Rd) ? 3 : 2, 0b01100, Rd, Rn);
}

// Scalar - 2 Source
void ARM64FloatEmitter::ADD(ARM64Reg Rd, ARM64Reg Rn, ARM64Reg Rm)
{
@@ -1130,6 +1130,13 @@ class ARM64FloatEmitter
void FRECPE(ARM64Reg Rd, ARM64Reg Rn);
void FRSQRTE(ARM64Reg Rd, ARM64Reg Rn);

// Scalar - pairwise
void FADDP(ARM64Reg Rd, ARM64Reg Rn);
void FMAXP(ARM64Reg Rd, ARM64Reg Rn);
void FMINP(ARM64Reg Rd, ARM64Reg Rn);
void FMAXNMP(ARM64Reg Rd, ARM64Reg Rn);
void FMINNMP(ARM64Reg Rd, ARM64Reg Rn);

// Scalar - 2 Source
void ADD(ARM64Reg Rd, ARM64Reg Rn, ARM64Reg Rm);
void FADD(ARM64Reg Rd, ARM64Reg Rn, ARM64Reg Rm);
@@ -1296,6 +1303,7 @@ class ARM64FloatEmitter
void EmitThreeSame(bool U, u32 size, u32 opcode, ARM64Reg Rd, ARM64Reg Rn, ARM64Reg Rm);
void EmitCopy(bool Q, u32 op, u32 imm5, u32 imm4, ARM64Reg Rd, ARM64Reg Rn);
void EmitScalar2RegMisc(bool U, u32 size, u32 opcode, ARM64Reg Rd, ARM64Reg Rn);
void EmitScalarPairwise(bool U, u32 size, u32 opcode, ARM64Reg Rd, ARM64Reg Rn);
void Emit2RegMisc(bool Q, bool U, u32 size, u32 opcode, ARM64Reg Rd, ARM64Reg Rn);
void EmitLoadStoreSingleStructure(bool L, bool R, u32 opcode, bool S, u32 size, ARM64Reg Rt,
ARM64Reg Rn);
@@ -177,6 +177,10 @@ class JitArm64 : public JitBase, public Arm64Gen::ARM64CodeBlock, public CommonA

void FloatCompare(UGeckoInstruction inst, bool upper = false);

// temp_gpr can be INVALID_REG if single is true
void EmitQuietNaNBitConstant(Arm64Gen::ARM64Reg dest_reg, bool single,
Arm64Gen::ARM64Reg temp_gpr);

bool IsFPRStoreSafe(size_t guest_reg) const;

protected:
@@ -3,6 +3,8 @@

#include "Core/PowerPC/JitArm64/Jit.h"

#include <optional>

#include "Common/Arm64Emitter.h"
#include "Common/CPUDetect.h"
#include "Common/CommonTypes.h"
@@ -66,11 +68,19 @@ void JitArm64::fp_arith(UGeckoInstruction inst)
FALLBACK_IF(inst.Rc);
FALLBACK_IF(jo.fp_exceptions || (jo.div_by_zero_exceptions && inst.SUBOP5 == 18));

u32 a = inst.FA, b = inst.FB, c = inst.FC, d = inst.FD;
u32 op5 = inst.SUBOP5;
const u32 a = inst.FA;
const u32 b = inst.FB;
const u32 c = inst.FC;
const u32 d = inst.FD;
const u32 op5 = inst.SUBOP5;

const bool use_c = op5 >= 25; // fmul and all kind of fmaddXX
const bool use_b = op5 != 25; // fmul uses no B
const bool fma = use_b && use_c;
const bool negate_result = (op5 & ~0x1) == 30;

// Addition and subtraction can't generate new NaNs, they can only take NaNs from inputs
const bool can_generate_nan = (op5 & ~0x1) != 20;

const bool output_is_single = inst.OPCD == 59;
const bool inaccurate_fma = op5 > 25 && !Config::Get(Config::SESSION_USE_FMA);
@@ -82,53 +92,69 @@ void JitArm64::fp_arith(UGeckoInstruction inst)
};
const bool inputs_are_singles = inputs_are_singles_func();

const RegType type =
(inputs_are_singles && output_is_single) ? RegType::LowerPairSingle : RegType::LowerPair;
const bool single = inputs_are_singles && output_is_single;
const RegType type = single ? RegType::LowerPairSingle : RegType::LowerPair;
const RegType type_out =
output_is_single ? (inputs_are_singles ? RegType::DuplicatedSingle : RegType::Duplicated) :
RegType::LowerPair;
const auto reg_encoder =
(inputs_are_singles && output_is_single) ? EncodeRegToSingle : EncodeRegToDouble;
const auto reg_encoder = single ? EncodeRegToSingle : EncodeRegToDouble;

const ARM64Reg VA = reg_encoder(fpr.R(a, type));
const ARM64Reg VB = use_b ? reg_encoder(fpr.R(b, type)) : ARM64Reg::INVALID_REG;
ARM64Reg VC = use_c ? reg_encoder(fpr.R(c, type)) : ARM64Reg::INVALID_REG;
const ARM64Reg VC = use_c ? reg_encoder(fpr.R(c, type)) : ARM64Reg::INVALID_REG;
const ARM64Reg VD = reg_encoder(fpr.RW(d, type_out));

ARM64Reg V0Q = ARM64Reg::INVALID_REG;
ARM64Reg V1Q = ARM64Reg::INVALID_REG;

ARM64Reg rounded_c_reg = VC;
if (round_c)
{
ASSERT_MSG(DYNA_REC, !inputs_are_singles, "Tried to apply 25-bit precision to single");

V1Q = fpr.GetReg();
V0Q = fpr.GetReg();
rounded_c_reg = reg_encoder(V0Q);
Force25BitPrecision(rounded_c_reg, VC);
}

Force25BitPrecision(reg_encoder(V1Q), VC);
VC = reg_encoder(V1Q);
ARM64Reg inaccurate_fma_reg = VD;
if (fma && inaccurate_fma && VD == VB)
{
if (V0Q == ARM64Reg::INVALID_REG)
V0Q = fpr.GetReg();
inaccurate_fma_reg = reg_encoder(V0Q);
}

ARM64Reg inaccurate_fma_temp_reg = VD;
if (inaccurate_fma && d == b)
ARM64Reg result_reg = VD;
const bool preserve_d =
m_accurate_nans && (VD == VA || (use_b && VD == VB) || (use_c && VD == VC));
if (preserve_d)
{
V0Q = fpr.GetReg();
V1Q = fpr.GetReg();
result_reg = reg_encoder(V1Q);
}

const ARM64Reg temp_gpr = m_accurate_nans && !single ? gpr.GetReg() : ARM64Reg::INVALID_REG;

inaccurate_fma_temp_reg = reg_encoder(V0Q);
if (m_accurate_nans)
{
if (V0Q == ARM64Reg::INVALID_REG)
V0Q = fpr.GetReg();
}

switch (op5)
{
case 18:
m_float_emit.FDIV(VD, VA, VB);
m_float_emit.FDIV(result_reg, VA, VB);
break;
case 20:
m_float_emit.FSUB(VD, VA, VB);
m_float_emit.FSUB(result_reg, VA, VB);
break;
case 21:
m_float_emit.FADD(VD, VA, VB);
m_float_emit.FADD(result_reg, VA, VB);
break;
case 25:
m_float_emit.FMUL(VD, VA, VC);
m_float_emit.FMUL(result_reg, VA, rounded_c_reg);
break;
// While it may seem like PowerPC's nmadd/nmsub map to AArch64's nmadd/msub [sic],
// the subtly different definitions affect how signed zeroes are handled.
@@ -138,39 +164,116 @@ void JitArm64::fp_arith(UGeckoInstruction inst)
case 30: // fnmsub: "D = -(A*C - B)" vs "Vd = -((-Va) + Vn*Vm)"
if (inaccurate_fma)
{
m_float_emit.FMUL(inaccurate_fma_temp_reg, VA, VC);
m_float_emit.FSUB(VD, inaccurate_fma_temp_reg, VB);
m_float_emit.FMUL(inaccurate_fma_reg, VA, rounded_c_reg);
m_float_emit.FSUB(result_reg, inaccurate_fma_reg, VB);
}
else
{
m_float_emit.FNMSUB(VD, VA, VC, VB);
m_float_emit.FNMSUB(result_reg, VA, rounded_c_reg, VB);
}
if (op5 == 30)
m_float_emit.FNEG(VD, VD);
break;
case 29: // fmadd: "D = A*C + B" vs "Vd = Va + Vn*Vm"
case 31: // fnmadd: "D = -(A*C + B)" vs "Vd = -(Va + Vn*Vm)"
if (inaccurate_fma)
{
m_float_emit.FMUL(inaccurate_fma_temp_reg, VA, VC);
m_float_emit.FADD(VD, inaccurate_fma_temp_reg, VB);
m_float_emit.FMUL(inaccurate_fma_reg, VA, rounded_c_reg);
m_float_emit.FADD(result_reg, inaccurate_fma_reg, VB);
}
else
{
m_float_emit.FMADD(VD, VA, VC, VB);
m_float_emit.FMADD(result_reg, VA, rounded_c_reg, VB);
}
if (op5 == 31)
m_float_emit.FNEG(VD, VD);
break;
default:
ASSERT_MSG(DYNA_REC, 0, "fp_arith");
break;
}

std::vector<FixupBranch> nan_fixups;
if (m_accurate_nans)
{
// Check if we need to handle NaNs
m_float_emit.FCMP(result_reg);
FixupBranch no_nan = B(CCFlags::CC_VC);
FixupBranch nan = B();
SetJumpTarget(no_nan);

SwitchToFarCode();
SetJumpTarget(nan);

const ARM64Reg quiet_bit_reg = reg_encoder(V0Q);

EmitQuietNaNBitConstant(quiet_bit_reg, inputs_are_singles && output_is_single, temp_gpr);

std::vector<ARM64Reg> inputs;
inputs.push_back(VA);
if (use_b && VA != VB)
inputs.push_back(VB);
if (use_c && VA != VC && (!use_b || VB != VC))
inputs.push_back(VC);

// If any inputs are NaNs, pick the first NaN of them and OR it with the quiet bit
for (size_t i = 0; i < inputs.size(); ++i)
{
// Skip checking if the input is a NaN if it's the last input and we're guaranteed to have at
// least one NaN input
const bool check_input = can_generate_nan || i != inputs.size() - 1;

const ARM64Reg input = inputs[i];
FixupBranch skip;
if (check_input)
{
m_float_emit.FCMP(input);
skip = B(CCFlags::CC_VC);
}

m_float_emit.ORR(EncodeRegToDouble(VD), EncodeRegToDouble(input),
EncodeRegToDouble(quiet_bit_reg));
nan_fixups.push_back(B());

if (check_input)
SetJumpTarget(skip);
}

std::optional<FixupBranch> nan_early_fixup;
if (can_generate_nan)
{
// There was no NaN in any of the inputs, so the NaN must have been generated by the
// arithmetic instruction. In this case, the result is already correct.
if (negate_result)
{
if (result_reg != VD)
m_float_emit.MOV(EncodeRegToDouble(VD), EncodeRegToDouble(result_reg));
nan_fixups.push_back(B());
}
else
{
nan_early_fixup = B();
}
}

SwitchToNearCode();

if (nan_early_fixup)
SetJumpTarget(*nan_early_fixup);
}

// PowerPC's nmadd/nmsub perform rounding before the final negation, which is not the case
// for any of AArch64's FMA instructions, so we negate using a separate instruction.
if (negate_result)
m_float_emit.FNEG(VD, result_reg);
else if (result_reg != VD)
m_float_emit.MOV(EncodeRegToDouble(VD), EncodeRegToDouble(result_reg));

for (FixupBranch fixup : nan_fixups)
SetJumpTarget(fixup);

if (V0Q != ARM64Reg::INVALID_REG)
fpr.Unlock(V0Q);
if (V1Q != ARM64Reg::INVALID_REG)
fpr.Unlock(V1Q);
if (temp_gpr != ARM64Reg::INVALID_REG)
gpr.Unlock(temp_gpr);

if (output_is_single)
{
@@ -782,6 +885,29 @@ void JitArm64::ConvertSingleToDoublePair(size_t guest_reg, ARM64Reg dest_reg, AR
}
}

void JitArm64::EmitQuietNaNBitConstant(ARM64Reg dest_reg, bool single, ARM64Reg temp_gpr)
{
// dest_reg = QNaN & ~SNaN
//
// (Alternatively, dest_reg = QNaN would also work, but that would take
// two instructions to emit even for singles)

if (single)
{
m_float_emit.MOVI(32, dest_reg, 0x40, 16);
}
else
{
ASSERT(temp_gpr != ARM64Reg::INVALID_REG);

MOVI2R(EncodeRegTo64(temp_gpr), 0x0008'0000'0000'0000);
if (IsQuad(dest_reg))
m_float_emit.DUP(64, dest_reg, EncodeRegTo64(temp_gpr));
else
m_float_emit.FMOV(dest_reg, EncodeRegTo64(temp_gpr));
}
}

bool JitArm64::IsFPRStoreSafe(size_t guest_reg) const
{
return js.fpr_is_store_safe[guest_reg];

0 comments on commit 000c6c4

Please sign in to comment.