Commit c3a6337
[CIR][X86] Implement lowering for AVX512 mask builtins (kadd, kand, kandn, kor, kxor, knot, kmov)
This patch adds CIR codegen support for AVX512 mask operations on X86 (kadd, kand, kandn, kor, kxor, knot, and kmov) at all supported mask widths. Each builtin now lowers to the expected vector<i1> and bitcast representation in CIR, matching the semantics of the corresponding LLVM intrinsics. The patch also adds comprehensive CIR/LLVM/OGCG tests for AVX512F, AVX512DQ, and AVX512BW to validate the lowering behavior.
1 parent 6abbbca commit c3a6337
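For reference, these builtins back the AVX512 mask intrinsics in <immintrin.h>. Below is a minimal usage sketch that exercises the newly lowered paths; the wrapper names come from clang's AVX512F/AVX512DQ headers rather than from this patch, and it assumes compilation with -mavx512f -mavx512dq:

#include <immintrin.h>

/* kand/knot on 16-bit masks (AVX512F): lowered via emitX86MaskLogic and
   the knot bitcast/not/bitcast pattern, respectively. */
__mmask16 combine(__mmask16 a, __mmask16 b) {
  return _mm512_kand(a, _mm512_knot(b));
}

/* kadd on 8-bit masks (AVX512DQ): lowered to the x86.avx512.kadd.b intrinsic. */
__mmask8 add_masks(__mmask8 a, __mmask8 b) {
  return _kadd_mask8(a, b);
}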

File tree: 4 files changed (+780, -2 lines)

clang/lib/CIR/CodeGen/CIRGenBuiltinX86.cpp

Lines changed: 69 additions & 2 deletions
@@ -85,6 +85,36 @@ static mlir::Value getMaskVecValue(CIRGenBuilderTy &builder, mlir::Location loc,
   return maskVec;
 }
 
+static mlir::Value emitX86MaskAddLogic(CIRGenBuilderTy &builder,
+                                       mlir::Location loc,
+                                       const std::string &intrinsicName,
+                                       SmallVectorImpl<mlir::Value> &ops) {
+
+  auto intTy = cast<cir::IntType>(ops[0].getType());
+  unsigned numElts = intTy.getWidth();
+  mlir::Value lhsVec = getMaskVecValue(builder, loc, ops[0], numElts);
+  mlir::Value rhsVec = getMaskVecValue(builder, loc, ops[1], numElts);
+  mlir::Type vecTy = lhsVec.getType();
+  mlir::Value resVec = emitIntrinsicCallOp(builder, loc, intrinsicName, vecTy,
+                                           mlir::ValueRange{lhsVec, rhsVec});
+  return builder.createBitcast(resVec, ops[0].getType());
+}
+
+static mlir::Value emitX86MaskLogic(CIRGenBuilderTy &builder,
+                                    mlir::Location loc,
+                                    cir::BinOpKind binOpKind,
+                                    SmallVectorImpl<mlir::Value> &ops,
+                                    bool invertLHS = false) {
+  unsigned numElts = cast<cir::IntType>(ops[0].getType()).getWidth();
+  mlir::Value lhs = getMaskVecValue(builder, loc, ops[0], numElts);
+  mlir::Value rhs = getMaskVecValue(builder, loc, ops[1], numElts);
+
+  if (invertLHS)
+    lhs = builder.createNot(lhs);
+  return builder.createBitcast(builder.createBinop(loc, lhs, binOpKind, rhs),
+                               ops[0].getType());
+}
+
 mlir::Value CIRGenFunction::emitX86BuiltinExpr(unsigned builtinID,
                                                const CallExpr *expr) {
   if (builtinID == Builtin::BI__builtin_cpu_is) {
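The invertLHS flag is what lets emitX86MaskLogic cover kandn and kxnor as well: kandn is NOT(lhs) AND rhs by definition, and kxnor, defined as NOT(lhs XOR rhs), is equal to NOT(lhs) XOR rhs, so both reduce to an ordinary binop with a complemented left operand. A scalar C model of that identity, illustrative only and not part of the patch:

/* Scalar model of the mask-logic lowering for 8-bit masks. */
static inline unsigned char kandn8(unsigned char a, unsigned char b) {
  return (unsigned char)(~a & b);  /* BinOpKind::And, invertLHS = true */
}

static inline unsigned char kxnor8(unsigned char a, unsigned char b) {
  /* ~(a ^ b) == (~a) ^ b, so kxnor needs no extra NOT after the binop. */
  return (unsigned char)(~a ^ b);  /* BinOpKind::Xor, invertLHS = true */
}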
@@ -743,38 +773,75 @@ mlir::Value CIRGenFunction::emitX86BuiltinExpr(unsigned builtinID,
   case X86::BI__builtin_ia32_ktestzsi:
   case X86::BI__builtin_ia32_ktestcdi:
   case X86::BI__builtin_ia32_ktestzdi:
+    cgm.errorNYI(expr->getSourceRange(),
+                 std::string("unimplemented X86 builtin call: ") +
+                     getContext().BuiltinInfo.getName(builtinID));
+    return {};
   case X86::BI__builtin_ia32_kaddqi:
+    return emitX86MaskAddLogic(builder, getLoc(expr->getExprLoc()),
+                               "x86.avx512.kadd.b", ops);
   case X86::BI__builtin_ia32_kaddhi:
+    return emitX86MaskAddLogic(builder, getLoc(expr->getExprLoc()),
+                               "x86.avx512.kadd.w", ops);
   case X86::BI__builtin_ia32_kaddsi:
+    return emitX86MaskAddLogic(builder, getLoc(expr->getExprLoc()),
+                               "x86.avx512.kadd.d", ops);
   case X86::BI__builtin_ia32_kadddi:
+    return emitX86MaskAddLogic(builder, getLoc(expr->getExprLoc()),
+                               "x86.avx512.kadd.q", ops);
   case X86::BI__builtin_ia32_kandqi:
   case X86::BI__builtin_ia32_kandhi:
   case X86::BI__builtin_ia32_kandsi:
   case X86::BI__builtin_ia32_kanddi:
+    return emitX86MaskLogic(builder, getLoc(expr->getExprLoc()),
+                            cir::BinOpKind::And, ops);
   case X86::BI__builtin_ia32_kandnqi:
   case X86::BI__builtin_ia32_kandnhi:
   case X86::BI__builtin_ia32_kandnsi:
   case X86::BI__builtin_ia32_kandndi:
+    return emitX86MaskLogic(builder, getLoc(expr->getExprLoc()),
+                            cir::BinOpKind::And, ops, true);
   case X86::BI__builtin_ia32_korqi:
   case X86::BI__builtin_ia32_korhi:
   case X86::BI__builtin_ia32_korsi:
   case X86::BI__builtin_ia32_kordi:
+    return emitX86MaskLogic(builder, getLoc(expr->getExprLoc()),
+                            cir::BinOpKind::Or, ops);
   case X86::BI__builtin_ia32_kxnorqi:
   case X86::BI__builtin_ia32_kxnorhi:
   case X86::BI__builtin_ia32_kxnorsi:
   case X86::BI__builtin_ia32_kxnordi:
+    return emitX86MaskLogic(builder, getLoc(expr->getExprLoc()),
+                            cir::BinOpKind::Xor, ops, true);
   case X86::BI__builtin_ia32_kxorqi:
   case X86::BI__builtin_ia32_kxorhi:
   case X86::BI__builtin_ia32_kxorsi:
   case X86::BI__builtin_ia32_kxordi:
+    return emitX86MaskLogic(builder, getLoc(expr->getExprLoc()),
+                            cir::BinOpKind::Xor, ops);
   case X86::BI__builtin_ia32_knotqi:
   case X86::BI__builtin_ia32_knothi:
   case X86::BI__builtin_ia32_knotsi:
-  case X86::BI__builtin_ia32_knotdi:
+  case X86::BI__builtin_ia32_knotdi: {
+    cir::IntType intTy = cast<cir::IntType>(ops[0].getType());
+    unsigned numElts = intTy.getWidth();
+    mlir::Value resVec =
+        getMaskVecValue(builder, getLoc(expr->getExprLoc()), ops[0], numElts);
+    return builder.createBitcast(builder.createNot(resVec), ops[0].getType());
+  }
   case X86::BI__builtin_ia32_kmovb:
   case X86::BI__builtin_ia32_kmovw:
   case X86::BI__builtin_ia32_kmovd:
-  case X86::BI__builtin_ia32_kmovq:
+  case X86::BI__builtin_ia32_kmovq: {
+    // Bitcast to vXi1 type and then back to integer. This gets the mask
+    // register type into the IR, but might be optimized out depending on
+    // what's around it.
+    cir::IntType intTy = cast<cir::IntType>(ops[0].getType());
+    unsigned numElts = intTy.getWidth();
+    mlir::Value resVec =
+        getMaskVecValue(builder, getLoc(expr->getExprLoc()), ops[0], numElts);
+    return builder.createBitcast(resVec, ops[0].getType());
+  }
   case X86::BI__builtin_ia32_kunpckdi:
   case X86::BI__builtin_ia32_kunpcksi:
   case X86::BI__builtin_ia32_kunpckhi:
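As the comment in the kmov case says, the bitcast round-trip exists only to materialize the vXi1 mask type in the IR and may be folded away by later optimizations. A hedged sketch of a caller, assuming the usual clang header wrapper _mm512_kmov over __builtin_ia32_kmovw:

#include <immintrin.h>

__mmask16 pin_to_mask_reg(__mmask16 m) {
  /* Lowers to bitcast i16 -> vector<16 x i1> -> i16 in CIR; the round-trip
     nudges the value toward a k-register but can be optimized out. */
  return _mm512_kmov(m);
}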
