Skip to content

Commit

Permalink
[ARM] Fix Crashes in fp16/bf16 Inline Asm
Browse files Browse the repository at this point in the history
We were still seeing occasional crashes with inline assembly blocks
using fp16/bf16 after my previous patches:
- https://reviews.llvm.org/rGff4027d152d0
- https://reviews.llvm.org/rG7d15212b8c0c
- https://reviews.llvm.org/rG20b2d11896d9

It turns out:
- The original two commits were wrong, and we should have always been
  choosing the SPR register class, not the HPR register class, so that
  LLVM's SelectionDAGBuilder correctly did the right splits/joins.
- The `splitValueIntoRegisterParts`/`joinRegisterPartsIntoValue` changes
  from rG20b2d11896d9 are still correct, even though they sometimes
  result in inefficient codegen of casts between fp16/bf16 and i32/f32
  (which is visible in these tests).

This patch fixes crashes in `getCopyToParts` and when trying to select
`(bf16 (bitconvert (fp16 ...)))` dags when Neon is enabled.

This patch also adds support for passing fp16/bf16 values using the 'x'
constraint that is LLVM-specific. This should broadly match how we pass
with 't' and 'w', but with a different set of valid S registers.

Differential Revision: https://reviews.llvm.org/D147715
  • Loading branch information
lenary committed Apr 13, 2023
1 parent 4a6a4f8 commit 9ee4fe6
Show file tree
Hide file tree
Showing 2 changed files with 342 additions and 30 deletions.
18 changes: 3 additions & 15 deletions llvm/lib/Target/ARM/ARMISelLowering.cpp
Expand Up @@ -20347,13 +20347,7 @@ RCPair ARMTargetLowering::getRegForInlineAsmConstraint(
case 'w':
if (VT == MVT::Other)
break;
if (VT == MVT::f16)
return RCPair(0U, Subtarget->hasFullFP16() ? &ARM::HPRRegClass
: &ARM::SPRRegClass);
if (VT == MVT::bf16)
return RCPair(0U, Subtarget->hasBF16() ? &ARM::HPRRegClass
: &ARM::SPRRegClass);
if (VT == MVT::f32)
if (VT == MVT::f32 || VT == MVT::f16 || VT == MVT::bf16)
return RCPair(0U, &ARM::SPRRegClass);
if (VT.getSizeInBits() == 64)
return RCPair(0U, &ARM::DPRRegClass);
Expand All @@ -20363,7 +20357,7 @@ RCPair ARMTargetLowering::getRegForInlineAsmConstraint(
case 'x':
if (VT == MVT::Other)
break;
if (VT == MVT::f32)
if (VT == MVT::f32 || VT == MVT::f16 || VT == MVT::bf16)
return RCPair(0U, &ARM::SPR_8RegClass);
if (VT.getSizeInBits() == 64)
return RCPair(0U, &ARM::DPR_8RegClass);
Expand All @@ -20373,13 +20367,7 @@ RCPair ARMTargetLowering::getRegForInlineAsmConstraint(
case 't':
if (VT == MVT::Other)
break;
if (VT == MVT::f16)
return RCPair(0U, Subtarget->hasFullFP16() ? &ARM::HPRRegClass
: &ARM::SPRRegClass);
if (VT == MVT::bf16)
return RCPair(0U, Subtarget->hasBF16() ? &ARM::HPRRegClass
: &ARM::SPRRegClass);
if (VT == MVT::f32 || VT == MVT::i32)
if (VT == MVT::f32 || VT == MVT::i32 || VT == MVT::f16 || VT == MVT::bf16)
return RCPair(0U, &ARM::SPRRegClass);
if (VT.getSizeInBits() == 64)
return RCPair(0U, &ARM::DPR_VFP2RegClass);
Expand Down

0 comments on commit 9ee4fe6

Please sign in to comment.