diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp index 42a231b82d47ba..3598b0b6f73a2d 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -3746,6 +3746,10 @@ SDValue AArch64TargetLowering::LowerBITCAST(SDValue Op, if (OpVT != MVT::f16 && OpVT != MVT::bf16) return SDValue(); + // Bitcasts between f16 and bf16 are legal. + if (ArgVT == MVT::f16 || ArgVT == MVT::bf16) + return Op; + assert(ArgVT == MVT::i16); SDLoc DL(Op); diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.td b/llvm/lib/Target/AArch64/AArch64InstrInfo.td index 4ea70da56c237e..740ebd52b85fa7 100644 --- a/llvm/lib/Target/AArch64/AArch64InstrInfo.td +++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.td @@ -7546,6 +7546,9 @@ def : Pat<(i64 (bitconvert (f64 FPR64:$Xn))), def : Pat<(i64 (bitconvert (v1f64 V64:$Vn))), (COPY_TO_REGCLASS V64:$Vn, GPR64)>; +def : Pat<(f16 (bitconvert (bf16 FPR16:$src))), (f16 FPR16:$src)>; +def : Pat<(bf16 (bitconvert (f16 FPR16:$src))), (bf16 FPR16:$src)>; + let Predicates = [IsLE] in { def : Pat<(v1i64 (bitconvert (v2i32 FPR64:$src))), (v1i64 FPR64:$src)>; def : Pat<(v1i64 (bitconvert (v4i16 FPR64:$src))), (v1i64 FPR64:$src)>; diff --git a/llvm/test/CodeGen/AArch64/bf16.ll b/llvm/test/CodeGen/AArch64/bf16.ll index f6226b86777a83..49545cb30c09d0 100644 --- a/llvm/test/CodeGen/AArch64/bf16.ll +++ b/llvm/test/CodeGen/AArch64/bf16.ll @@ -82,3 +82,17 @@ define { <8 x bfloat>, <8 x bfloat>* } @test_store_post_v8bf16(<8 x bfloat> %val ret { <8 x bfloat>, <8 x bfloat>* } %res } + +define bfloat @test_bitcast_halftobfloat(half %a) nounwind { +; CHECK-LABEL: test_bitcast_halftobfloat: +; CHECK-NEXT: ret + %r = bitcast half %a to bfloat + ret bfloat %r +} + +define half @test_bitcast_bfloattohalf(bfloat %a) nounwind { +; CHECK-LABEL: test_bitcast_bfloattohalf: +; CHECK-NEXT: ret + %r = bitcast bfloat %a to half + ret half %r +}