Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CIR] Vector types, comparison operators #432

Merged
merged 1 commit into from
Jan 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions clang/include/clang/CIR/Dialect/IR/CIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1824,6 +1824,31 @@ def VecCreateOp : CIR_Op<"vec.create", [Pure]> {
let hasVerifier = 1;
}

//===----------------------------------------------------------------------===//
// VecCmp
//===----------------------------------------------------------------------===//

def VecCmpOp : CIR_Op<"vec.cmp", [Pure, SameTypeOperands]> {

let summary = "Compare two vectors";
let description = [{
The `cir.vec.cmp` operation does an element-wise comparison of two vectors
of the same type. The result is a vector of the same size as the operands
whose element type is the signed integral type that is the same size as the
element type of the operands. The values in the result are 0 or -1.
}];

let arguments = (ins Arg<CmpOpKind, "cmp kind">:$kind, CIR_VectorType:$lhs,
CIR_VectorType:$rhs);
let results = (outs CIR_VectorType:$result);

let assemblyFormat = [{
`(` $kind `,` $lhs `,` $rhs `)` `:` type($lhs) `,` type($result) attr-dict
}];

let hasVerifier = 0;
}

//===----------------------------------------------------------------------===//
// BaseClassAddr
//===----------------------------------------------------------------------===//
Expand Down
63 changes: 33 additions & 30 deletions clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -765,6 +765,26 @@ class ScalarExprEmitter : public StmtVisitor<ScalarExprEmitter, mlir::Value> {
QualType LHSTy = E->getLHS()->getType();
QualType RHSTy = E->getRHS()->getType();

auto ClangCmpToCIRCmp = [](auto ClangCmp) -> mlir::cir::CmpOpKind {
switch (ClangCmp) {
case BO_LT:
return mlir::cir::CmpOpKind::lt;
case BO_GT:
return mlir::cir::CmpOpKind::gt;
case BO_LE:
return mlir::cir::CmpOpKind::le;
case BO_GE:
return mlir::cir::CmpOpKind::ge;
case BO_EQ:
return mlir::cir::CmpOpKind::eq;
case BO_NE:
return mlir::cir::CmpOpKind::ne;
default:
llvm_unreachable("unsupported comparison kind");
return mlir::cir::CmpOpKind(-1);
}
};

if (const MemberPointerType *MPT = LHSTy->getAs<MemberPointerType>()) {
assert(0 && "not implemented");
} else if (!LHSTy->isAnyComplexType() && !RHSTy->isAnyComplexType()) {
Expand All @@ -773,12 +793,18 @@ class ScalarExprEmitter : public StmtVisitor<ScalarExprEmitter, mlir::Value> {
mlir::Value RHS = BOInfo.RHS;

if (LHSTy->isVectorType()) {
// Cannot handle any vector just yet.
assert(0 && "not implemented");
// If AltiVec, the comparison results in a numeric type, so we use
// intrinsics comparing vectors and giving 0 or 1 as a result
if (!E->getType()->isVectorType())
assert(0 && "not implemented");
if (!E->getType()->isVectorType()) {
// If AltiVec, the comparison results in a numeric type, so we use
// intrinsics comparing vectors and giving 0 or 1 as a result
llvm_unreachable("NYI: AltiVec comparison");
} else {
// Other kinds of vectors. Element-wise comparison returning
// a vector.
mlir::cir::CmpOpKind Kind = ClangCmpToCIRCmp(E->getOpcode());
return Builder.create<mlir::cir::VecCmpOp>(
CGF.getLoc(BOInfo.Loc), CGF.getCIRType(BOInfo.Ty), Kind,
BOInfo.LHS, BOInfo.RHS);
}
}
if (BOInfo.isFixedPointOp()) {
assert(0 && "not implemented");
Expand All @@ -793,30 +819,7 @@ class ScalarExprEmitter : public StmtVisitor<ScalarExprEmitter, mlir::Value> {
llvm_unreachable("NYI");
}

mlir::cir::CmpOpKind Kind;
switch (E->getOpcode()) {
case BO_LT:
Kind = mlir::cir::CmpOpKind::lt;
break;
case BO_GT:
Kind = mlir::cir::CmpOpKind::gt;
break;
case BO_LE:
Kind = mlir::cir::CmpOpKind::le;
break;
case BO_GE:
Kind = mlir::cir::CmpOpKind::ge;
break;
case BO_EQ:
Kind = mlir::cir::CmpOpKind::eq;
break;
case BO_NE:
Kind = mlir::cir::CmpOpKind::ne;
break;
default:
llvm_unreachable("unsupported");
}

mlir::cir::CmpOpKind Kind = ClangCmpToCIRCmp(E->getOpcode());
return Builder.create<mlir::cir::CmpOp>(CGF.getLoc(BOInfo.Loc),
CGF.getCIRType(BOInfo.Ty), Kind,
BOInfo.LHS, BOInfo.RHS);
Expand Down
137 changes: 88 additions & 49 deletions clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,51 @@ void walkRegionSkipping(mlir::Region &region,
});
}

/// Convert from a CIR comparison kind to an LLVM IR integral comparison kind.
mlir::LLVM::ICmpPredicate
convertCmpKindToICmpPredicate(mlir::cir::CmpOpKind kind, bool isSigned) {
using CIR = mlir::cir::CmpOpKind;
using LLVMICmp = mlir::LLVM::ICmpPredicate;
switch (kind) {
case CIR::eq:
return LLVMICmp::eq;
case CIR::ne:
return LLVMICmp::ne;
case CIR::lt:
return (isSigned ? LLVMICmp::slt : LLVMICmp::ult);
case CIR::le:
return (isSigned ? LLVMICmp::sle : LLVMICmp::ule);
case CIR::gt:
return (isSigned ? LLVMICmp::sgt : LLVMICmp::ugt);
case CIR::ge:
return (isSigned ? LLVMICmp::sge : LLVMICmp::uge);
}
llvm_unreachable("Unknown CmpOpKind");
}

/// Convert from a CIR comparison kind to an LLVM IR floating-point comparison
/// kind.
mlir::LLVM::FCmpPredicate
convertCmpKindToFCmpPredicate(mlir::cir::CmpOpKind kind) {
using CIR = mlir::cir::CmpOpKind;
using LLVMFCmp = mlir::LLVM::FCmpPredicate;
switch (kind) {
case CIR::eq:
return LLVMFCmp::oeq;
case CIR::ne:
return LLVMFCmp::une;
case CIR::lt:
return LLVMFCmp::olt;
case CIR::le:
return LLVMFCmp::ole;
case CIR::gt:
return LLVMFCmp::ogt;
case CIR::ge:
return LLVMFCmp::oge;
}
llvm_unreachable("Unknown CmpOpKind");
}

} // namespace

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -1131,6 +1176,41 @@ class CIRVectorExtractLowering
}
};

class CIRVectorCmpOpLowering
: public mlir::OpConversionPattern<mlir::cir::VecCmpOp> {
public:
using OpConversionPattern<mlir::cir::VecCmpOp>::OpConversionPattern;

mlir::LogicalResult
matchAndRewrite(mlir::cir::VecCmpOp op, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
assert(op.getType().isa<mlir::cir::VectorType>() &&
op.getLhs().getType().isa<mlir::cir::VectorType>() &&
op.getRhs().getType().isa<mlir::cir::VectorType>() &&
"Vector compare with non-vector type");
// LLVM IR vector comparison returns a vector of i1. This one-bit vector
// must be sign-extended to the correct result type.
auto elementType =
op.getLhs().getType().dyn_cast<mlir::cir::VectorType>().getEltType();
mlir::Value bitResult;
if (auto intType = elementType.dyn_cast<mlir::cir::IntType>()) {
bitResult = rewriter.create<mlir::LLVM::ICmpOp>(
op.getLoc(),
convertCmpKindToICmpPredicate(op.getKind(), intType.isSigned()),
adaptor.getLhs(), adaptor.getRhs());
} else if (elementType.isa<mlir::FloatType>()) {
bitResult = rewriter.create<mlir::LLVM::FCmpOp>(
op.getLoc(), convertCmpKindToFCmpPredicate(op.getKind()),
adaptor.getLhs(), adaptor.getRhs());
} else {
return op.emitError() << "unsupported type for VecCmpOp: " << elementType;
}
rewriter.replaceOpWithNewOp<mlir::LLVM::SExtOp>(
op, typeConverter->convertType(op.getType()), bitResult);
return mlir::success();
}
};

class CIRVAStartLowering
: public mlir::OpConversionPattern<mlir::cir::VAStartOp> {
public:
Expand Down Expand Up @@ -1833,50 +1913,6 @@ class CIRCmpOpLowering : public mlir::OpConversionPattern<mlir::cir::CmpOp> {
public:
using OpConversionPattern<mlir::cir::CmpOp>::OpConversionPattern;

mlir::LLVM::ICmpPredicate convertToICmpPredicate(mlir::cir::CmpOpKind kind,
bool isSigned) const {
using CIR = mlir::cir::CmpOpKind;
using LLVMICmp = mlir::LLVM::ICmpPredicate;

switch (kind) {
case CIR::eq:
return LLVMICmp::eq;
case CIR::ne:
return LLVMICmp::ne;
case CIR::lt:
return (isSigned ? LLVMICmp::slt : LLVMICmp::ult);
case CIR::le:
return (isSigned ? LLVMICmp::sle : LLVMICmp::ule);
case CIR::gt:
return (isSigned ? LLVMICmp::sgt : LLVMICmp::ugt);
case CIR::ge:
return (isSigned ? LLVMICmp::sge : LLVMICmp::uge);
}
llvm_unreachable("Unknown CmpOpKind");
}

mlir::LLVM::FCmpPredicate
convertToFCmpPredicate(mlir::cir::CmpOpKind kind) const {
using CIR = mlir::cir::CmpOpKind;
using LLVMFCmp = mlir::LLVM::FCmpPredicate;

switch (kind) {
case CIR::eq:
return LLVMFCmp::ueq;
case CIR::ne:
return LLVMFCmp::une;
case CIR::lt:
return LLVMFCmp::ult;
case CIR::le:
return LLVMFCmp::ule;
case CIR::gt:
return LLVMFCmp::ugt;
case CIR::ge:
return LLVMFCmp::uge;
}
llvm_unreachable("Unknown CmpOpKind");
}

mlir::LogicalResult
matchAndRewrite(mlir::cir::CmpOp cmpOp, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
Expand All @@ -1885,15 +1921,17 @@ class CIRCmpOpLowering : public mlir::OpConversionPattern<mlir::cir::CmpOp> {

// Lower to LLVM comparison op.
if (auto intTy = type.dyn_cast<mlir::cir::IntType>()) {
auto kind = convertToICmpPredicate(cmpOp.getKind(), intTy.isSigned());
auto kind =
convertCmpKindToICmpPredicate(cmpOp.getKind(), intTy.isSigned());
llResult = rewriter.create<mlir::LLVM::ICmpOp>(
cmpOp.getLoc(), kind, adaptor.getLhs(), adaptor.getRhs());
} else if (auto ptrTy = type.dyn_cast<mlir::cir::PointerType>()) {
auto kind = convertToICmpPredicate(cmpOp.getKind(), /* isSigned=*/false);
auto kind = convertCmpKindToICmpPredicate(cmpOp.getKind(),
/* isSigned=*/false);
llResult = rewriter.create<mlir::LLVM::ICmpOp>(
cmpOp.getLoc(), kind, adaptor.getLhs(), adaptor.getRhs());
} else if (type.isa<mlir::FloatType>()) {
auto kind = convertToFCmpPredicate(cmpOp.getKind());
auto kind = convertCmpKindToFCmpPredicate(cmpOp.getKind());
llResult = rewriter.create<mlir::LLVM::FCmpOp>(
cmpOp.getLoc(), kind, adaptor.getLhs(), adaptor.getRhs());
} else {
Expand Down Expand Up @@ -2088,8 +2126,9 @@ void populateCIRToLLVMConversionPatterns(mlir::RewritePatternSet &patterns,
CIRTernaryOpLowering, CIRGetMemberOpLowering, CIRSwitchOpLowering,
CIRPtrDiffOpLowering, CIRCopyOpLowering, CIRMemCpyOpLowering,
CIRFAbsOpLowering, CIRVTableAddrPointOpLowering, CIRVectorCreateLowering,
CIRVectorInsertLowering, CIRVectorExtractLowering, CIRStackSaveLowering,
CIRStackRestoreLowering>(converter, patterns.getContext());
CIRVectorInsertLowering, CIRVectorExtractLowering, CIRVectorCmpOpLowering,
CIRStackSaveLowering, CIRStackRestoreLowering>(converter,
patterns.getContext());
}

namespace {
Expand Down
29 changes: 29 additions & 0 deletions clang/test/CIR/CodeGen/vectype.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

typedef int vi4 __attribute__((vector_size(16)));
typedef double vd2 __attribute__((vector_size(16)));
typedef long long vll2 __attribute__((vector_size(16)));

void vector_int_test(int x) {

Expand Down Expand Up @@ -49,6 +50,20 @@ void vector_int_test(int x) {
// CHECK: %{{[0-9]+}} = cir.unary(minus, %{{[0-9]+}}) : !cir.vector<!s32i x 4>, !cir.vector<!s32i x 4>
vi4 n = ~a;
// CHECK: %{{[0-9]+}} = cir.unary(not, %{{[0-9]+}}) : !cir.vector<!s32i x 4>, !cir.vector<!s32i x 4>

// Comparisons
vi4 o = a == b;
// CHECK: %{{[0-9]+}} = cir.vec.cmp(eq, %{{[0-9]+}}, %{{[0-9]+}}) : <!s32i x 4>, <!s32i x 4>
vi4 p = a != b;
// CHECK: %{{[0-9]+}} = cir.vec.cmp(ne, %{{[0-9]+}}, %{{[0-9]+}}) : <!s32i x 4>, <!s32i x 4>
vi4 q = a < b;
// CHECK: %{{[0-9]+}} = cir.vec.cmp(lt, %{{[0-9]+}}, %{{[0-9]+}}) : <!s32i x 4>, <!s32i x 4>
vi4 r = a > b;
// CHECK: %{{[0-9]+}} = cir.vec.cmp(gt, %{{[0-9]+}}, %{{[0-9]+}}) : <!s32i x 4>, <!s32i x 4>
vi4 s = a <= b;
// CHECK: %{{[0-9]+}} = cir.vec.cmp(le, %{{[0-9]+}}, %{{[0-9]+}}) : <!s32i x 4>, <!s32i x 4>
vi4 t = a >= b;
// CHECK: %{{[0-9]+}} = cir.vec.cmp(ge, %{{[0-9]+}}, %{{[0-9]+}}) : <!s32i x 4>, <!s32i x 4>
}

void vector_double_test(int x, double y) {
Expand Down Expand Up @@ -86,4 +101,18 @@ void vector_double_test(int x, double y) {
// CHECK: %{{[0-9]+}} = cir.unary(plus, %{{[0-9]+}}) : !cir.vector<f64 x 2>, !cir.vector<f64 x 2>
vd2 m = -a;
// CHECK: %{{[0-9]+}} = cir.unary(minus, %{{[0-9]+}}) : !cir.vector<f64 x 2>, !cir.vector<f64 x 2>

// Comparisons
vll2 o = a == b;
// CHECK: %{{[0-9]+}} = cir.vec.cmp(eq, %{{[0-9]+}}, %{{[0-9]+}}) : <f64 x 2>, <!s64i x 2>
vll2 p = a != b;
// CHECK: %{{[0-9]+}} = cir.vec.cmp(ne, %{{[0-9]+}}, %{{[0-9]+}}) : <f64 x 2>, <!s64i x 2>
vll2 q = a < b;
// CHECK: %{{[0-9]+}} = cir.vec.cmp(lt, %{{[0-9]+}}, %{{[0-9]+}}) : <f64 x 2>, <!s64i x 2>
vll2 r = a > b;
// CHECK: %{{[0-9]+}} = cir.vec.cmp(gt, %{{[0-9]+}}, %{{[0-9]+}}) : <f64 x 2>, <!s64i x 2>
vll2 s = a <= b;
// CHECK: %{{[0-9]+}} = cir.vec.cmp(le, %{{[0-9]+}}, %{{[0-9]+}}) : <f64 x 2>, <!s64i x 2>
vll2 t = a >= b;
// CHECK: %{{[0-9]+}} = cir.vec.cmp(ge, %{{[0-9]+}}, %{{[0-9]+}}) : <f64 x 2>, <!s64i x 2>
}
10 changes: 5 additions & 5 deletions clang/test/CIR/Lowering/cmp.cir
Original file line number Diff line number Diff line change
Expand Up @@ -36,27 +36,27 @@ module {
%23 = cir.load %2 : cir.ptr <f32>, f32
%24 = cir.load %3 : cir.ptr <f32>, f32
%25 = cir.cmp(gt, %23, %24) : f32, !cir.bool
// CHECK: llvm.fcmp "ugt"
// CHECK: llvm.fcmp "ogt"
%26 = cir.load %2 : cir.ptr <f32>, f32
%27 = cir.load %3 : cir.ptr <f32>, f32
%28 = cir.cmp(eq, %26, %27) : f32, !cir.bool
// CHECK: llvm.fcmp "ueq"
// CHECK: llvm.fcmp "oeq"
%29 = cir.load %2 : cir.ptr <f32>, f32
%30 = cir.load %3 : cir.ptr <f32>, f32
%31 = cir.cmp(lt, %29, %30) : f32, !cir.bool
// CHECK: llvm.fcmp "ult"
// CHECK: llvm.fcmp "olt"
%32 = cir.load %2 : cir.ptr <f32>, f32
%33 = cir.load %3 : cir.ptr <f32>, f32
%34 = cir.cmp(ge, %32, %33) : f32, !cir.bool
// CHECK: llvm.fcmp "uge"
// CHECK: llvm.fcmp "oge"
%35 = cir.load %2 : cir.ptr <f32>, f32
%36 = cir.load %3 : cir.ptr <f32>, f32
%37 = cir.cmp(ne, %35, %36) : f32, !cir.bool
// CHECK: llvm.fcmp "une"
%38 = cir.load %2 : cir.ptr <f32>, f32
%39 = cir.load %3 : cir.ptr <f32>, f32
%40 = cir.cmp(le, %38, %39) : f32, !cir.bool
// CHECK: llvm.fcmp "ule"
// CHECK: llvm.fcmp "ole"

// Pointer comparisons.
%41 = cir.cmp(ne, %0, %1) : !cir.ptr<!s32i>, !cir.bool
Expand Down
Loading
Loading