diff --git a/clang/lib/CIR/CodeGen/CIRGenBuiltinX86.cpp b/clang/lib/CIR/CodeGen/CIRGenBuiltinX86.cpp
index 72e6bea244802..21b61b7482894 100644
--- a/clang/lib/CIR/CodeGen/CIRGenBuiltinX86.cpp
+++ b/clang/lib/CIR/CodeGen/CIRGenBuiltinX86.cpp
@@ -20,6 +20,7 @@
 #include "clang/Basic/TargetBuiltins.h"
 #include "clang/CIR/Dialect/IR/CIRTypes.h"
 #include "clang/CIR/MissingFeatures.h"
+#include "llvm/Support/Casting.h"
 #include "llvm/Support/ErrorHandling.h"
 
 using namespace clang;
@@ -362,6 +363,47 @@ static mlir::Value emitX86Muldq(CIRGenBuilderTy &builder, mlir::Location loc,
   return builder.createMul(loc, lhs, rhs);
 }
 
+/// Emit the masked half-to-float conversion used by the vcvtph2ps builtins:
+/// bitcast the source lanes to f16, extend them to \p dstTy, then select
+/// against the pass-through vector under the write mask.
+///
+/// \p ops holds, in order, the source vector, the pass-through vector and the
+/// integer write mask.
+///
+/// NOTE(review): the 512-bit builtin presumably passes a rounding/SAE operand
+/// in ops[3]; it is not inspected here - confirm whether non-default values
+/// need the x86.avx512.mask.vcvtph2ps.512 intrinsic instead.
+static mlir::Value emitX86CvtF16ToFloatExpr(CIRGenBuilderTy &builder,
+                                            mlir::Location loc,
+                                            mlir::Type dstTy,
+                                            SmallVectorImpl<mlir::Value> &ops) {
+  assert(ops.size() >= 3 && "expected src, passthru and mask operands");
+
+  mlir::Value src = ops[0];
+  const unsigned numDstElts = mlir::cast<cir::VectorType>(dstTy).getSize();
+  const unsigned numSrcElts =
+      mlir::cast<cir::VectorType>(src.getType()).getSize();
+
+  // The 128-bit builtin converts only the low four of its eight 16-bit source
+  // lanes, so narrow the source to the destination's element count first.
+  if (numSrcElts != numDstElts) {
+    assert(numDstElts == 4 && "unexpected vcvtph2ps vector size");
+    const int64_t lowLanes[] = {0, 1, 2, 3};
+    src = builder.createVecShuffle(loc, src, lowLanes);
+  }
+
+  // Reinterpret the integer lanes as CIR half-precision floats.
+  auto halfTy = cir::VectorType::get(cir::FP16Type::get(builder.getContext()),
+                                     numDstElts);
+  mlir::Value srcF16 = builder.createBitcast(loc, src, halfTy);
+  mlir::Value res = builder.createFloatingCast(srcF16, dstTy);
+
+  // Pass the integer mask through unchanged. NOTE(review): this relies on
+  // emitX86Select expanding the mask itself, sized from the result vector -
+  // confirm against its definition.
+  return emitX86Select(builder, loc, ops[2], res, ops[1]);
+}
+
 static mlir::Value emitX86vpcom(CIRGenBuilderTy &builder, mlir::Location loc,
                                 llvm::SmallVector<mlir::Value> ops,
                                 bool isSigned) {
@@ -1688,12 +1730,43 @@ CIRGenFunction::emitX86BuiltinExpr(unsigned builtinID, const CallExpr *expr) {
   case X86::BI__builtin_ia32_cmpnltsd:
   case X86::BI__builtin_ia32_cmpnlesd:
   case X86::BI__builtin_ia32_cmpordsd:
+    cgm.errorNYI(expr->getSourceRange(),
+                 std::string("unimplemented X86 builtin call: ") +
+                     getContext().BuiltinInfo.getName(builtinID));
+    return mlir::Value{};
   case X86::BI__builtin_ia32_vcvtph2ps_mask:
   case X86::BI__builtin_ia32_vcvtph2ps256_mask:
-  case X86::BI__builtin_ia32_vcvtph2ps512_mask:
-  case X86::BI__builtin_ia32_cvtneps2bf16_128_mask:
+  case X86::BI__builtin_ia32_vcvtph2ps512_mask: {
+    mlir::Location loc = getLoc(expr->getExprLoc());
+    return emitX86CvtF16ToFloatExpr(builder, loc, convertType(expr->getType()),
+                                    ops);
+  }
+  case X86::BI__builtin_ia32_cvtneps2bf16_128_mask: {
+    // The 128-bit intrinsic is itself masked: it receives the source, the
+    // pass-through vector and the mask expanded to an i1 vector.
+    mlir::Location loc = getLoc(expr->getExprLoc());
+    unsigned numElts = cast<cir::VectorType>(ops[0].getType()).getSize();
+    ops[2] = getMaskVecValue(builder, loc, ops[2], numElts);
+    return emitIntrinsicCallOp(builder, loc,
+                               "x86.avx512bf16.mask.cvtneps2bf16.128",
+                               convertType(expr->getType()), ops);
+  }
   case X86::BI__builtin_ia32_cvtneps2bf16_256_mask:
-  case X86::BI__builtin_ia32_cvtneps2bf16_512_mask:
+  case X86::BI__builtin_ia32_cvtneps2bf16_512_mask: {
+    mlir::Location loc = getLoc(expr->getExprLoc());
+    StringRef intrinsicName =
+        builtinID == X86::BI__builtin_ia32_cvtneps2bf16_256_mask
+            ? "x86.avx512bf16.cvtneps2bf16.256"
+            : "x86.avx512bf16.cvtneps2bf16.512";
+
+    // Unlike the 128-bit form, these intrinsics are unmasked and take only the
+    // source vector; the write mask is applied afterwards by the select, which
+    // receives the original integer mask.
+    mlir::Value res = emitIntrinsicCallOp(builder, loc, intrinsicName,
+                                          convertType(expr->getType()),
+                                          mlir::ValueRange{ops[0]});
+    return emitX86Select(builder, loc, ops[2], res, ops[1]);
+  }
   case X86::BI__cpuid:
   case X86::BI__cpuidex:
   case X86::BI__emul:
diff --git a/clang/test/CIR/CodeGenBuiltins/X86/avx512vlbf16-builtins.c b/clang/test/CIR/CodeGenBuiltins/X86/avx512vlbf16-builtins.c
new file mode 100644
index 0000000000000..0948ec85a6766
--- /dev/null
+++ b/clang/test/CIR/CodeGenBuiltins/X86/avx512vlbf16-builtins.c
@@ -0,0 +1,95 @@
+// RUN: %clang_cc1 -ffreestanding -fclangir -triple x86_64-unknown-linux-gnu -target-feature +avx512vl -target-feature +avx512bf16 -emit-cir %s -o - | FileCheck %s --check-prefix=CIR
+// RUN: %clang_cc1 -ffreestanding -triple x86_64-unknown-linux-gnu -target-feature +avx512vl -target-feature +avx512bf16 -emit-llvm %s -o - | FileCheck %s --check-prefix=OGCG
+// RUN: %clang_cc1 -ffreestanding -fclangir -triple x86_64-unknown-linux-gnu -target-feature +avx512vl -target-feature +avx512bf16 -emit-llvm %s -o - | FileCheck %s --check-prefix=LLVM
+
+#include <immintrin.h>
+
+// CIR-LABEL: test_mm512_mask_cvtneps_pbh
+// CIR: cir.call @_mm512_mask_cvtneps_pbh({{.*}}, {{.*}}, {{.*}})
+// LLVM-LABEL: test_mm512_mask_cvtneps_pbh
+// LLVM: call <16 x bfloat> @llvm.x86.avx512bf16.cvtneps2bf16.512
+// OGCG-LABEL: test_mm512_mask_cvtneps_pbh
+// OGCG: call <16 x bfloat> @llvm.x86.avx512bf16.cvtneps2bf16.512
+__m256bh test_mm512_mask_cvtneps_pbh(__m256bh src, __mmask16 k, __m512 a) {
+  return _mm512_mask_cvtneps_pbh(src, k, a);
+}
+
+// CIR-LABEL: test_mm512_maskz_cvtneps_pbh
+// CIR: cir.call @_mm512_maskz_cvtneps_pbh({{.*}}, {{.*}})
+// LLVM-LABEL: test_mm512_maskz_cvtneps_pbh
+// LLVM: call <16 x bfloat> @llvm.x86.avx512bf16.cvtneps2bf16.512
+// OGCG-LABEL: test_mm512_maskz_cvtneps_pbh
+// OGCG: call <16 x bfloat> @llvm.x86.avx512bf16.cvtneps2bf16.512
+__m256bh test_mm512_maskz_cvtneps_pbh(__mmask16 k, __m512 a) {
+  return _mm512_maskz_cvtneps_pbh(k, a);
+}
+
+// CIR-LABEL: test_mm256_mask_cvtneps_pbh
+// CIR: cir.call @_mm256_mask_cvtneps_pbh({{.*}}, {{.*}}, {{.*}})
+// LLVM-LABEL: test_mm256_mask_cvtneps_pbh
+// LLVM: call <8 x bfloat> @llvm.x86.avx512bf16.cvtneps2bf16.256
+// OGCG-LABEL: test_mm256_mask_cvtneps_pbh
+// OGCG: call <8 x bfloat> @llvm.x86.avx512bf16.cvtneps2bf16.256
+__m128bh test_mm256_mask_cvtneps_pbh(__m128bh src, __mmask8 k, __m256 a) {
+  return _mm256_mask_cvtneps_pbh(src, k, a);
+}
+
+// CIR-LABEL: test_mm256_maskz_cvtneps_pbh
+// CIR: cir.call @_mm256_maskz_cvtneps_pbh({{.*}}, {{.*}})
+// LLVM-LABEL: test_mm256_maskz_cvtneps_pbh
+// LLVM: call <8 x bfloat> @llvm.x86.avx512bf16.cvtneps2bf16.256
+// OGCG-LABEL: test_mm256_maskz_cvtneps_pbh
+// OGCG: call <8 x bfloat> @llvm.x86.avx512bf16.cvtneps2bf16.256
+__m128bh test_mm256_maskz_cvtneps_pbh(__mmask8 k, __m256 a) {
+  return _mm256_maskz_cvtneps_pbh(k, a);
+}
+
+// CIR-LABEL: test_mm_mask_cvtneps_pbh
+// CIR: cir.call @_mm_mask_cvtneps_pbh({{.*}}, {{.*}}, {{.*}})
+// LLVM-LABEL: test_mm_mask_cvtneps_pbh
+// LLVM: call <8 x bfloat> @llvm.x86.avx512bf16.mask.cvtneps2bf16.128
+// OGCG-LABEL: test_mm_mask_cvtneps_pbh
+// OGCG: call <8 x bfloat> @llvm.x86.avx512bf16.mask.cvtneps2bf16.128
+__m128bh test_mm_mask_cvtneps_pbh(__m128bh src, __mmask8 k, __m128 a) {
+  return _mm_mask_cvtneps_pbh(src, k, a);
+}
+
+// CIR-LABEL: test_mm_maskz_cvtneps_pbh
+// CIR: cir.call @_mm_maskz_cvtneps_pbh({{.*}}, {{.*}})
+// LLVM-LABEL: test_mm_maskz_cvtneps_pbh
+// LLVM: call <8 x bfloat> @llvm.x86.avx512bf16.mask.cvtneps2bf16.128
+// OGCG-LABEL: test_mm_maskz_cvtneps_pbh
+// OGCG: call <8 x bfloat> @llvm.x86.avx512bf16.mask.cvtneps2bf16.128
+__m128bh test_mm_maskz_cvtneps_pbh(__mmask8 k, __m128 a) {
+  return _mm_maskz_cvtneps_pbh(k, a);
+}
+
+// CIR-LABEL: test_mm512_cvtneps_pbh
+// CIR: cir.call @_mm512_cvtneps_pbh({{.*}})
+// LLVM-LABEL: test_mm512_cvtneps_pbh
+// LLVM: call <16 x bfloat> @llvm.x86.avx512bf16.cvtneps2bf16.512
+// OGCG-LABEL: test_mm512_cvtneps_pbh
+// OGCG: call <16 x bfloat> @llvm.x86.avx512bf16.cvtneps2bf16.512
+__m256bh test_mm512_cvtneps_pbh(__m512 a) {
+  return _mm512_cvtneps_pbh(a);
+}
+
+// CIR-LABEL: test_mm256_cvtneps_pbh
+// CIR: cir.call @_mm256_cvtneps_pbh({{.*}})
+// LLVM-LABEL: test_mm256_cvtneps_pbh
+// LLVM: call <8 x bfloat> @llvm.x86.avx512bf16.cvtneps2bf16.256
+// OGCG-LABEL: test_mm256_cvtneps_pbh
+// OGCG: call <8 x bfloat> @llvm.x86.avx512bf16.cvtneps2bf16.256
+__m128bh test_mm256_cvtneps_pbh(__m256 a) {
+  return _mm256_cvtneps_pbh(a);
+}
+
+// CIR-LABEL: test_mm_cvtneps_pbh
+// CIR: cir.call @_mm_cvtneps_pbh({{.*}})
+// LLVM-LABEL: test_mm_cvtneps_pbh
+// LLVM: call <8 x bfloat> @llvm.x86.avx512bf16.mask.cvtneps2bf16.128
+// OGCG-LABEL: test_mm_cvtneps_pbh
+// OGCG: call <8 x bfloat> @llvm.x86.avx512bf16.mask.cvtneps2bf16.128
+__m128bh test_mm_cvtneps_pbh(__m128 a) {
+  return _mm_cvtneps_pbh(a);
+}