diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp index abe5be7638255..00f94e48a3f9a 100644 --- a/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp @@ -2825,6 +2825,8 @@ void DAGTypeLegalizer::SoftPromoteHalfResult(SDNode *N, unsigned ResNo) { report_fatal_error("Do not know how to soft promote this operator's " "result!"); + case ISD::ARITH_FENCE: + R = SoftPromoteHalfRes_ARITH_FENCE(N); break; case ISD::BITCAST: R = SoftPromoteHalfRes_BITCAST(N); break; case ISD::ConstantFP: R = SoftPromoteHalfRes_ConstantFP(N); break; case ISD::EXTRACT_VECTOR_ELT: @@ -2904,6 +2906,11 @@ void DAGTypeLegalizer::SoftPromoteHalfResult(SDNode *N, unsigned ResNo) { SetSoftPromotedHalf(SDValue(N, ResNo), R); } +SDValue DAGTypeLegalizer::SoftPromoteHalfRes_ARITH_FENCE(SDNode *N) { + return DAG.getNode(ISD::ARITH_FENCE, SDLoc(N), MVT::i16, + BitConvertToInteger(N->getOperand(0))); +} + SDValue DAGTypeLegalizer::SoftPromoteHalfRes_BITCAST(SDNode *N) { return BitConvertToInteger(N->getOperand(0)); } diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h index 49be824deb513..e9714f6f72b6b 100644 --- a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h +++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h @@ -726,6 +726,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer { void SetSoftPromotedHalf(SDValue Op, SDValue Result); void SoftPromoteHalfResult(SDNode *N, unsigned ResNo); + SDValue SoftPromoteHalfRes_ARITH_FENCE(SDNode *N); SDValue SoftPromoteHalfRes_BinOp(SDNode *N); SDValue SoftPromoteHalfRes_BITCAST(SDNode *N); SDValue SoftPromoteHalfRes_ConstantFP(SDNode *N); diff --git a/llvm/test/CodeGen/X86/arithmetic_fence2.ll b/llvm/test/CodeGen/X86/arithmetic_fence2.ll index 6a854b58fc02d..3c2ef21527f50 100644 --- a/llvm/test/CodeGen/X86/arithmetic_fence2.ll +++ b/llvm/test/CodeGen/X86/arithmetic_fence2.ll @@ -157,6 +157,160 @@ define <8 x float> @f6(<8 x float> %a) { ret <8 x float> %3 } +define half @f7(half %a) nounwind { +; X86-LABEL: f7: +; X86: # %bb.0: +; X86-NEXT: pinsrw $0, {{[0-9]+}}(%esp), %xmm0 +; X86-NEXT: #ARITH_FENCE +; X86-NEXT: retl +; +; X64-LABEL: f7: +; X64: # %bb.0: +; X64-NEXT: #ARITH_FENCE +; X64-NEXT: retq + %b = call half @llvm.arithmetic.fence.f16(half %a) + ret half %b +} + +define bfloat @f8(bfloat %a) nounwind { +; X86-LABEL: f8: +; X86: # %bb.0: +; X86-NEXT: movzwl {{[0-9]+}}(%esp), %eax +; X86-NEXT: #ARITH_FENCE +; X86-NEXT: pinsrw $0, %eax, %xmm0 +; X86-NEXT: retl +; +; X64-LABEL: f8: +; X64: # %bb.0: +; X64-NEXT: pextrw $0, %xmm0, %eax +; X64-NEXT: #ARITH_FENCE +; X64-NEXT: pinsrw $0, %eax, %xmm0 +; X64-NEXT: retq + %b = call bfloat @llvm.arithmetic.fence.bf16(bfloat %a) + ret bfloat %b +} + +define <2 x half> @f9(<2 x half> %a) nounwind { +; X86-LABEL: f9: +; X86: # %bb.0: +; X86-NEXT: movdqa %xmm0, %xmm1 +; X86-NEXT: psrld $16, %xmm1 +; X86-NEXT: #ARITH_FENCE +; X86-NEXT: #ARITH_FENCE +; X86-NEXT: punpcklwd {{.*#+}} xmm0 = xmm0[0],xmm1[0],xmm0[1],xmm1[1],xmm0[2],xmm1[2],xmm0[3],xmm1[3] +; X86-NEXT: retl +; +; X64-LABEL: f9: +; X64: # %bb.0: +; X64-NEXT: movdqa %xmm0, %xmm1 +; X64-NEXT: psrld $16, %xmm1 +; X64-NEXT: #ARITH_FENCE +; X64-NEXT: #ARITH_FENCE +; X64-NEXT: punpcklwd {{.*#+}} xmm0 = xmm0[0],xmm1[0],xmm0[1],xmm1[1],xmm0[2],xmm1[2],xmm0[3],xmm1[3] +; X64-NEXT: retq + %b = call <2 x half> @llvm.arithmetic.fence.v2f16(<2 x half> %a) + ret <2 x half> %b +} + +define <3 x bfloat> @f10(<3 x bfloat> %a) nounwind { +; X86-LABEL: f10: +; X86: # %bb.0: +; X86-NEXT: pextrw $0, %xmm0, %eax +; X86-NEXT: movdqa %xmm0, %xmm1 +; X86-NEXT: psrld $16, %xmm1 +; X86-NEXT: pextrw $0, %xmm1, %ecx +; X86-NEXT: shufps {{.*#+}} xmm0 = xmm0[1,1,1,1] +; X86-NEXT: pextrw $0, %xmm0, %edx +; X86-NEXT: #ARITH_FENCE +; X86-NEXT: #ARITH_FENCE +; X86-NEXT: #ARITH_FENCE +; X86-NEXT: pinsrw $0, %eax, %xmm0 +; X86-NEXT: pinsrw $0, %ecx, %xmm1 +; X86-NEXT: punpcklwd {{.*#+}} xmm0 = xmm0[0],xmm1[0],xmm0[1],xmm1[1],xmm0[2],xmm1[2],xmm0[3],xmm1[3] +; X86-NEXT: pinsrw $0, %edx, %xmm1 +; X86-NEXT: punpckldq {{.*#+}} xmm0 = xmm0[0],xmm1[0],xmm0[1],xmm1[1] +; X86-NEXT: retl +; +; X64-LABEL: f10: +; X64: # %bb.0: +; X64-NEXT: pextrw $0, %xmm0, %eax +; X64-NEXT: movdqa %xmm0, %xmm1 +; X64-NEXT: psrld $16, %xmm1 +; X64-NEXT: pextrw $0, %xmm1, %ecx +; X64-NEXT: shufps {{.*#+}} xmm0 = xmm0[1,1,1,1] +; X64-NEXT: pextrw $0, %xmm0, %edx +; X64-NEXT: #ARITH_FENCE +; X64-NEXT: #ARITH_FENCE +; X64-NEXT: #ARITH_FENCE +; X64-NEXT: pinsrw $0, %eax, %xmm0 +; X64-NEXT: pinsrw $0, %ecx, %xmm1 +; X64-NEXT: punpcklwd {{.*#+}} xmm0 = xmm0[0],xmm1[0],xmm0[1],xmm1[1],xmm0[2],xmm1[2],xmm0[3],xmm1[3] +; X64-NEXT: pinsrw $0, %edx, %xmm1 +; X64-NEXT: punpckldq {{.*#+}} xmm0 = xmm0[0],xmm1[0],xmm0[1],xmm1[1] +; X64-NEXT: retq + %b = call <3 x bfloat> @llvm.arithmetic.fence.v3bf16(<3 x bfloat> %a) + ret <3 x bfloat> %b +} + +define <4 x bfloat> @f11(<4 x bfloat> %a) nounwind { +; X86-LABEL: f11: +; X86: # %bb.0: +; X86-NEXT: pushl %esi +; X86-NEXT: movdqa %xmm0, %xmm1 +; X86-NEXT: psrlq $48, %xmm1 +; X86-NEXT: pextrw $0, %xmm1, %eax +; X86-NEXT: movdqa %xmm0, %xmm1 +; X86-NEXT: shufps {{.*#+}} xmm1 = xmm1[1,1],xmm0[1,1] +; X86-NEXT: pextrw $0, %xmm1, %edx +; X86-NEXT: pextrw $0, %xmm0, %ecx +; X86-NEXT: psrld $16, %xmm0 +; X86-NEXT: pextrw $0, %xmm0, %esi +; X86-NEXT: #ARITH_FENCE +; X86-NEXT: #ARITH_FENCE +; X86-NEXT: #ARITH_FENCE +; X86-NEXT: #ARITH_FENCE +; X86-NEXT: pinsrw $0, %eax, %xmm0 +; X86-NEXT: pinsrw $0, %edx, %xmm1 +; X86-NEXT: punpcklwd {{.*#+}} xmm1 = xmm1[0],xmm0[0],xmm1[1],xmm0[1],xmm1[2],xmm0[2],xmm1[3],xmm0[3] +; X86-NEXT: pinsrw $0, %ecx, %xmm0 +; X86-NEXT: pinsrw $0, %esi, %xmm2 +; X86-NEXT: punpcklwd {{.*#+}} xmm0 = xmm0[0],xmm2[0],xmm0[1],xmm2[1],xmm0[2],xmm2[2],xmm0[3],xmm2[3] +; X86-NEXT: punpckldq {{.*#+}} xmm0 = xmm0[0],xmm1[0],xmm0[1],xmm1[1] +; X86-NEXT: popl %esi +; X86-NEXT: retl +; +; X64-LABEL: f11: +; X64: # %bb.0: +; X64-NEXT: movdqa %xmm0, %xmm1 +; X64-NEXT: psrlq $48, %xmm1 +; X64-NEXT: pextrw $0, %xmm1, %eax +; X64-NEXT: movdqa %xmm0, %xmm1 +; X64-NEXT: shufps {{.*#+}} xmm1 = xmm1[1,1],xmm0[1,1] +; X64-NEXT: pextrw $0, %xmm1, %ecx +; X64-NEXT: pextrw $0, %xmm0, %edx +; X64-NEXT: psrld $16, %xmm0 +; X64-NEXT: pextrw $0, %xmm0, %esi +; X64-NEXT: #ARITH_FENCE +; X64-NEXT: #ARITH_FENCE +; X64-NEXT: #ARITH_FENCE +; X64-NEXT: #ARITH_FENCE +; X64-NEXT: pinsrw $0, %eax, %xmm0 +; X64-NEXT: pinsrw $0, %ecx, %xmm1 +; X64-NEXT: punpcklwd {{.*#+}} xmm1 = xmm1[0],xmm0[0],xmm1[1],xmm0[1],xmm1[2],xmm0[2],xmm1[3],xmm0[3] +; X64-NEXT: pinsrw $0, %edx, %xmm0 +; X64-NEXT: pinsrw $0, %esi, %xmm2 +; X64-NEXT: punpcklwd {{.*#+}} xmm0 = xmm0[0],xmm2[0],xmm0[1],xmm2[1],xmm0[2],xmm2[2],xmm0[3],xmm2[3] +; X64-NEXT: punpckldq {{.*#+}} xmm0 = xmm0[0],xmm1[0],xmm0[1],xmm1[1] +; X64-NEXT: retq + %b = call <4 x bfloat> @llvm.arithmetic.fence.v4bf16(<4 x bfloat> %a) + ret <4 x bfloat> %b +} + +declare half @llvm.arithmetic.fence.f16(half) +declare bfloat @llvm.arithmetic.fence.bf16(bfloat) +declare <2 x half> @llvm.arithmetic.fence.v2f16(<2 x half>) +declare <3 x bfloat> @llvm.arithmetic.fence.v3bf16(<3 x bfloat>) +declare <4 x bfloat> @llvm.arithmetic.fence.v4bf16(<4 x bfloat>) declare float @llvm.arithmetic.fence.f32(float) declare double @llvm.arithmetic.fence.f64(double) declare <2 x float> @llvm.arithmetic.fence.v2f32(<2 x float>)