458 changes: 458 additions & 0 deletions mlir/include/mlir/Dialect/Index/IR/IndexOps.td

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions mlir/include/mlir/InitAllDialects.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/GPU/TransformOps/GPUTransformOps.h"
#include "mlir/Dialect/Index/IR/IndexDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
Expand Down Expand Up @@ -87,6 +88,7 @@ inline void registerAllDialects(DialectRegistry &registry) {
emitc::EmitCDialect,
func::FuncDialect,
gpu::GPUDialect,
index::IndexDialect,
LLVM::LLVMDialect,
linalg::LinalgDialect,
math::MathDialect,
Expand Down
1 change: 1 addition & 0 deletions mlir/include/mlir/Support/LLVM.h
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ namespace mlir {
using llvm::cast;
using llvm::cast_or_null;
using llvm::dyn_cast;
using llvm::dyn_cast_if_present;
using llvm::dyn_cast_or_null;
using llvm::isa;
using llvm::isa_and_nonnull;
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Conversion/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ add_subdirectory(GPUToNVVM)
add_subdirectory(GPUToROCDL)
add_subdirectory(GPUToSPIRV)
add_subdirectory(GPUToVulkan)
add_subdirectory(IndexToLLVM)
add_subdirectory(LinalgToLLVM)
add_subdirectory(LinalgToSPIRV)
add_subdirectory(LinalgToStandard)
Expand Down
17 changes: 17 additions & 0 deletions mlir/lib/Conversion/IndexToLLVM/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
add_mlir_conversion_library(MLIRIndexToLLVM
IndexToLLVM.cpp

ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/IndexToLLVM

DEPENDS
MLIRConversionPassIncGen

LINK_COMPONENTS
Core

LINK_LIBS PUBLIC
MLIRIndexDialect
MLIRLLVMCommonConversion
MLIRLLVMDialect
)
347 changes: 347 additions & 0 deletions mlir/lib/Conversion/IndexToLLVM/IndexToLLVM.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,347 @@
//===- IndexToLLVM.cpp - Index to LLVM dialect conversion -------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/IndexToLLVM/IndexToLLVM.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Dialect/Index/IR/IndexAttrs.h"
#include "mlir/Dialect/Index/IR/IndexDialect.h"
#include "mlir/Dialect/Index/IR/IndexOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Pass/Pass.h"

using namespace mlir;
using namespace index;

namespace {

//===----------------------------------------------------------------------===//
// ConvertIndexCeilDivS
//===----------------------------------------------------------------------===//

/// Convert `ceildivs(n, m)` into `x = m > 0 ? -1 : 1` and then
/// `n*m > 0 ? (n+x)/m + 1 : -(-n/m)`.
struct ConvertIndexCeilDivS : mlir::ConvertOpToLLVMPattern<CeilDivSOp> {
using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;

LogicalResult
matchAndRewrite(CeilDivSOp op, CeilDivSOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Value n = adaptor.getLhs();
Value m = adaptor.getRhs();
Value zero = rewriter.create<LLVM::ConstantOp>(loc, n.getType(), 0);
Value posOne = rewriter.create<LLVM::ConstantOp>(loc, n.getType(), 1);
Value negOne = rewriter.create<LLVM::ConstantOp>(loc, n.getType(), -1);

// Compute `x`.
Value mPos =
rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::sgt, m, zero);
Value x = rewriter.create<LLVM::SelectOp>(loc, mPos, negOne, posOne);

// Compute the positive result.
Value nPlusX = rewriter.create<LLVM::AddOp>(loc, n, x);
Value nPlusXDivM = rewriter.create<LLVM::SDivOp>(loc, nPlusX, m);
Value posRes = rewriter.create<LLVM::AddOp>(loc, nPlusXDivM, posOne);

// Compute the negative result.
Value negN = rewriter.create<LLVM::SubOp>(loc, zero, n);
Value negNDivM = rewriter.create<LLVM::SDivOp>(loc, negN, m);
Value negRes = rewriter.create<LLVM::SubOp>(loc, zero, negNDivM);

// Pick the positive result if `n` and `m` have the same sign and `n` is
// non-zero, i.e. `(n > 0) == (m > 0) && n != 0`.
Value nPos =
rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::sgt, n, zero);
Value sameSign =
rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::eq, nPos, mPos);
Value nNonZero =
rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::ne, n, zero);
Value cmp = rewriter.create<LLVM::AndOp>(loc, sameSign, nNonZero);
rewriter.replaceOpWithNewOp<LLVM::SelectOp>(op, cmp, posRes, negRes);
return success();
}
};

//===----------------------------------------------------------------------===//
// ConvertIndexCeilDivU
//===----------------------------------------------------------------------===//

/// Convert `ceildivu(n, m)` into `n == 0 ? 0 : (n-1)/m + 1`.
struct ConvertIndexCeilDivU : mlir::ConvertOpToLLVMPattern<CeilDivUOp> {
using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;

LogicalResult
matchAndRewrite(CeilDivUOp op, CeilDivUOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Value n = adaptor.getLhs();
Value m = adaptor.getRhs();
Value zero = rewriter.create<LLVM::ConstantOp>(loc, n.getType(), 0);
Value one = rewriter.create<LLVM::ConstantOp>(loc, n.getType(), 1);

// Compute the non-zero result.
Value minusOne = rewriter.create<LLVM::SubOp>(loc, n, one);
Value quotient = rewriter.create<LLVM::UDivOp>(loc, minusOne, m);
Value plusOne = rewriter.create<LLVM::AddOp>(loc, quotient, one);

// Pick the result.
Value cmp =
rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::eq, n, zero);
rewriter.replaceOpWithNewOp<LLVM::SelectOp>(op, cmp, zero, plusOne);
return success();
}
};

//===----------------------------------------------------------------------===//
// ConvertIndexFloorDivS
//===----------------------------------------------------------------------===//

/// Convert `floordivs(n, m)` into `x = m < 0 ? 1 : -1` and then
/// `n*m < 0 ? -1 - (x-n)/m : n/m`.
struct ConvertIndexFloorDivS : mlir::ConvertOpToLLVMPattern<FloorDivSOp> {
using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;

LogicalResult
matchAndRewrite(FloorDivSOp op, FloorDivSOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Value n = adaptor.getLhs();
Value m = adaptor.getRhs();
Value zero = rewriter.create<LLVM::ConstantOp>(loc, n.getType(), 0);
Value posOne = rewriter.create<LLVM::ConstantOp>(loc, n.getType(), 1);
Value negOne = rewriter.create<LLVM::ConstantOp>(loc, n.getType(), -1);

// Compute `x`.
Value mNeg =
rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::slt, m, zero);
Value x = rewriter.create<LLVM::SelectOp>(loc, mNeg, posOne, negOne);

// Compute the negative result.
Value xMinusN = rewriter.create<LLVM::SubOp>(loc, x, n);
Value xMinusNDivM = rewriter.create<LLVM::SDivOp>(loc, xMinusN, m);
Value negRes = rewriter.create<LLVM::SubOp>(loc, negOne, xMinusNDivM);

// Compute the positive result.
Value posRes = rewriter.create<LLVM::SDivOp>(loc, n, m);

// Pick the negative result if `n` and `m` have different signs and `n` is
// non-zero, i.e. `(n < 0) != (m < 0) && n != 0`.
Value nNeg =
rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::slt, n, zero);
Value diffSign =
rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::ne, nNeg, mNeg);
Value nNonZero =
rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::ne, n, zero);
Value cmp = rewriter.create<LLVM::AndOp>(loc, diffSign, nNonZero);
rewriter.replaceOpWithNewOp<LLVM::SelectOp>(op, cmp, negRes, posRes);
return success();
}
};

//===----------------------------------------------------------------------===//
// CovnertIndexCast
//===----------------------------------------------------------------------===//

/// Convert a cast op. If the materialized index type is the same as the other
/// type, fold away the op. Otherwise, truncate or extend the op as appropriate.
/// Signed casts sign extend when the result bitwidth is larger. Unsigned casts
/// zero extend when the result bitwidth is larger.
template <typename CastOp, typename ExtOp>
struct ConvertIndexCast : public mlir::ConvertOpToLLVMPattern<CastOp> {
using mlir::ConvertOpToLLVMPattern<CastOp>::ConvertOpToLLVMPattern;

LogicalResult
matchAndRewrite(CastOp op, typename CastOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Type in = adaptor.getInput().getType();
Type out = this->getTypeConverter()->convertType(op.getType());
if (in == out)
rewriter.replaceOp(op, adaptor.getInput());
else if (in.getIntOrFloatBitWidth() > out.getIntOrFloatBitWidth())
rewriter.replaceOpWithNewOp<LLVM::TruncOp>(op, out, adaptor.getInput());
else
rewriter.replaceOpWithNewOp<ExtOp>(op, out, adaptor.getInput());
return success();
}
};

using ConvertIndexCastS = ConvertIndexCast<CastSOp, LLVM::SExtOp>;
using ConvertIndexCastU = ConvertIndexCast<CastUOp, LLVM::ZExtOp>;

//===----------------------------------------------------------------------===//
// ConvertIndexCmp
//===----------------------------------------------------------------------===//

/// Assert that the LLVM comparison enum lines up with index's enum.
static constexpr bool checkPredicates(LLVM::ICmpPredicate lhs,
IndexCmpPredicate rhs) {
return static_cast<int>(lhs) == static_cast<int>(rhs);
}

static_assert(
LLVM::getMaxEnumValForICmpPredicate() ==
getMaxEnumValForIndexCmpPredicate() &&
checkPredicates(LLVM::ICmpPredicate::eq, IndexCmpPredicate::EQ) &&
checkPredicates(LLVM::ICmpPredicate::ne, IndexCmpPredicate::NE) &&
checkPredicates(LLVM::ICmpPredicate::sge, IndexCmpPredicate::SGE) &&
checkPredicates(LLVM::ICmpPredicate::sgt, IndexCmpPredicate::SGT) &&
checkPredicates(LLVM::ICmpPredicate::sle, IndexCmpPredicate::SLE) &&
checkPredicates(LLVM::ICmpPredicate::slt, IndexCmpPredicate::SLT) &&
checkPredicates(LLVM::ICmpPredicate::uge, IndexCmpPredicate::UGE) &&
checkPredicates(LLVM::ICmpPredicate::ugt, IndexCmpPredicate::UGT) &&
checkPredicates(LLVM::ICmpPredicate::ule, IndexCmpPredicate::ULE) &&
checkPredicates(LLVM::ICmpPredicate::ult, IndexCmpPredicate::ULT),
"LLVM ICmpPredicate mismatches IndexCmpPredicate");

struct ConvertIndexCmp : public mlir::ConvertOpToLLVMPattern<CmpOp> {
using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;

LogicalResult
matchAndRewrite(CmpOp op, CmpOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// The LLVM enum has the same values as the index predicate enums.
rewriter.replaceOpWithNewOp<LLVM::ICmpOp>(
op, *LLVM::symbolizeICmpPredicate(static_cast<uint32_t>(op.getPred())),
adaptor.getLhs(), adaptor.getRhs());
return success();
}
};

//===----------------------------------------------------------------------===//
// ConvertIndexSizeOf
//===----------------------------------------------------------------------===//

/// Lower `index.sizeof` to a constant with the value of the index bitwidth.
struct ConvertIndexSizeOf : public mlir::ConvertOpToLLVMPattern<SizeOfOp> {
using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;

LogicalResult
matchAndRewrite(SizeOfOp op, SizeOfOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(
op, getTypeConverter()->getIndexType(),
getTypeConverter()->getIndexTypeBitwidth());
return success();
}
};

//===----------------------------------------------------------------------===//
// ConvertIndexConstant
//===----------------------------------------------------------------------===//

/// Convert an index constant. Truncate the value as appropriate.
struct ConvertIndexConstant : public mlir::ConvertOpToLLVMPattern<ConstantOp> {
using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;

LogicalResult
matchAndRewrite(ConstantOp op, ConstantOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Type type = getTypeConverter()->getIndexType();
APInt value = op.getValue().trunc(type.getIntOrFloatBitWidth());
rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(
op, type, IntegerAttr::get(type, value));
return success();
}
};

//===----------------------------------------------------------------------===//
// Trivial Conversions
//===----------------------------------------------------------------------===//

using ConvertIndexAdd = mlir::OneToOneConvertToLLVMPattern<AddOp, LLVM::AddOp>;
using ConvertIndexSub = mlir::OneToOneConvertToLLVMPattern<SubOp, LLVM::SubOp>;
using ConvertIndexMul = mlir::OneToOneConvertToLLVMPattern<MulOp, LLVM::MulOp>;
using ConvertIndexDivS =
mlir::OneToOneConvertToLLVMPattern<DivSOp, LLVM::SDivOp>;
using ConvertIndexDivU =
mlir::OneToOneConvertToLLVMPattern<DivUOp, LLVM::UDivOp>;
using ConvertIndexRemS =
mlir::OneToOneConvertToLLVMPattern<RemSOp, LLVM::SRemOp>;
using ConvertIndexRemU =
mlir::OneToOneConvertToLLVMPattern<RemUOp, LLVM::URemOp>;
using ConvertIndexMaxS =
mlir::OneToOneConvertToLLVMPattern<MaxSOp, LLVM::SMaxOp>;
using ConvertIndexMaxU =
mlir::OneToOneConvertToLLVMPattern<MaxUOp, LLVM::UMaxOp>;
using ConvertIndexBoolConstant =
mlir::OneToOneConvertToLLVMPattern<BoolConstantOp, LLVM::ConstantOp>;

} // namespace

//===----------------------------------------------------------------------===//
// Pattern Population
//===----------------------------------------------------------------------===//

void index::populateIndexToLLVMConversionPatterns(
LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) {
patterns.insert<
// clang-format off
ConvertIndexAdd,
ConvertIndexSub,
ConvertIndexMul,
ConvertIndexDivS,
ConvertIndexDivU,
ConvertIndexRemS,
ConvertIndexRemU,
ConvertIndexMaxS,
ConvertIndexMaxU,
ConvertIndexCeilDivS,
ConvertIndexCeilDivU,
ConvertIndexFloorDivS,
ConvertIndexCastS,
ConvertIndexCastU,
ConvertIndexCmp,
ConvertIndexSizeOf,
ConvertIndexConstant,
ConvertIndexBoolConstant
// clang-format on
>(typeConverter);
}

//===----------------------------------------------------------------------===//
// ODS-Generated Definitions
//===----------------------------------------------------------------------===//

namespace mlir {
#define GEN_PASS_DEF_CONVERTINDEXTOLLVMPASS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

//===----------------------------------------------------------------------===//
// Pass Definition
//===----------------------------------------------------------------------===//

namespace {
struct ConvertIndexToLLVMPass
: public impl::ConvertIndexToLLVMPassBase<ConvertIndexToLLVMPass> {
using Base::Base;

void runOnOperation() override;
};
} // namespace

void ConvertIndexToLLVMPass::runOnOperation() {
// Configure dialect conversion.
ConversionTarget target(getContext());
target.addIllegalDialect<IndexDialect>();
target.addLegalDialect<LLVM::LLVMDialect>();

// Set LLVM lowering options.
LowerToLLVMOptions options(&getContext());
if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
options.overrideIndexBitwidth(indexBitwidth);
LLVMTypeConverter typeConverter(&getContext(), options);

// Populate patterns and run the conversion.
RewritePatternSet patterns(&getContext());
populateIndexToLLVMConversionPatterns(typeConverter, patterns);

if (failed(
applyPartialConversion(getOperation(), target, std::move(patterns))))
return signalPassFailure();
}
1 change: 1 addition & 0 deletions mlir/lib/Dialect/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ add_subdirectory(DLTI)
add_subdirectory(EmitC)
add_subdirectory(Func)
add_subdirectory(GPU)
add_subdirectory(Index)
add_subdirectory(Linalg)
add_subdirectory(LLVMIR)
add_subdirectory(Math)
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Dialect/Index/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
add_subdirectory(IR)
12 changes: 12 additions & 0 deletions mlir/lib/Dialect/Index/IR/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
add_mlir_dialect_library(MLIRIndexDialect
IndexAttrs.cpp
IndexDialect.cpp
IndexOps.cpp

DEPENDS
MLIRIndexOpsIncGen

LINK_LIBS PUBLIC
MLIRDialect
MLIRIR
)
36 changes: 36 additions & 0 deletions mlir/lib/Dialect/Index/IR/IndexAttrs.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
//===- IndexAttrs.cpp - Index attribute definitions ------------------------==//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Index/IR/IndexAttrs.h"
#include "mlir/Dialect/Index/IR/IndexDialect.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/DialectImplementation.h"
#include "llvm/ADT/TypeSwitch.h"

using namespace mlir;
using namespace mlir::index;

//===----------------------------------------------------------------------===//
// IndexDialect
//===----------------------------------------------------------------------===//

void IndexDialect::registerAttributes() {
addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/Index/IR/IndexAttrs.cpp.inc"
>();
}

//===----------------------------------------------------------------------===//
// ODS-Generated Declarations
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Index/IR/IndexEnums.cpp.inc"

#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/Index/IR/IndexAttrs.cpp.inc"
27 changes: 27 additions & 0 deletions mlir/lib/Dialect/Index/IR/IndexDialect.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
//===- IndexDialect.cpp - Index dialect definition -------------------------==//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Index/IR/IndexDialect.h"

using namespace mlir;
using namespace mlir::index;

//===----------------------------------------------------------------------===//
// IndexDialect
//===----------------------------------------------------------------------===//

void IndexDialect::initialize() {
registerAttributes();
registerOperations();
}

//===----------------------------------------------------------------------===//
// ODS-Generated Definitions
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Index/IR/IndexOpsDialect.cpp.inc"
376 changes: 376 additions & 0 deletions mlir/lib/Dialect/Index/IR/IndexOps.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,376 @@
//===- IndexOps.cpp - Index operation definitions --------------------------==//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Index/IR/IndexOps.h"
#include "mlir/Dialect/Index/IR/IndexAttrs.h"
#include "mlir/Dialect/Index/IR/IndexDialect.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/OpImplementation.h"

using namespace mlir;
using namespace mlir::index;

//===----------------------------------------------------------------------===//
// IndexDialect
//===----------------------------------------------------------------------===//

void IndexDialect::registerOperations() {
addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Index/IR/IndexOps.cpp.inc"
>();
}

Operation *IndexDialect::materializeConstant(OpBuilder &b, Attribute value,
Type type, Location loc) {
// Materialize bool constants as `i1`.
if (auto boolValue = dyn_cast<BoolAttr>(value)) {
if (!type.isSignlessInteger(1))
return nullptr;
return b.create<BoolConstantOp>(loc, type, boolValue);
}

// Materialize integer attributes as `index`.
if (auto indexValue = dyn_cast<IntegerAttr>(value)) {
if (!indexValue.getType().isa<IndexType>() || !type.isa<IndexType>())
return nullptr;
assert(indexValue.getValue().getBitWidth() ==
IndexType::kInternalStorageBitWidth);
return b.create<ConstantOp>(loc, indexValue);
}

return nullptr;
}

//===----------------------------------------------------------------------===//
// Fold Utilities
//===----------------------------------------------------------------------===//

/// Fold an index operation irrespective of the target bitwidth. The
/// operation must satisfy the property:
///
/// ```
/// trunc(f(a, b)) = f(trunc(a), trunc(b))
/// ```
///
/// For all values of `a` and `b`. The function accepts a lambda that computes
/// the integer result, which in turn must satisfy the above property.
static OpFoldResult foldBinaryOpUnchecked(
ArrayRef<Attribute> operands,
function_ref<APInt(const APInt &, const APInt &)> calculate) {
assert(operands.size() == 2 && "binary operation expected 2 operands");
auto lhs = dyn_cast_if_present<IntegerAttr>(operands[0]);
auto rhs = dyn_cast_if_present<IntegerAttr>(operands[1]);
if (!lhs || !rhs)
return {};

APInt result = calculate(lhs.getValue(), rhs.getValue());
assert(result.trunc(32) ==
calculate(lhs.getValue().trunc(32), rhs.getValue().trunc(32)));
return IntegerAttr::get(IndexType::get(lhs.getContext()), std::move(result));
}

/// Fold an index operation only if the truncated 64-bit result matches the
/// 32-bit result for operations that don't satisfy the above property. These
/// are operations where the upper bits of the operands can affect the lower
/// bits of the results.
///
/// The function accepts a lambda that computes the integer result in both
/// 64-bit and 32-bit. If either call returns `None`, the operation is not
/// folded.
static OpFoldResult foldBinaryOpChecked(
ArrayRef<Attribute> operands,
function_ref<Optional<APInt>(const APInt &, const APInt &lhs)> calculate) {
assert(operands.size() == 2 && "binary operation expected 2 operands");
auto lhs = dyn_cast_if_present<IntegerAttr>(operands[0]);
auto rhs = dyn_cast_if_present<IntegerAttr>(operands[1]);
// Only fold index operands.
if (!lhs || !rhs)
return {};

// Compute the 64-bit result and the 32-bit result.
Optional<APInt> result64 = calculate(lhs.getValue(), rhs.getValue());
if (!result64)
return {};
Optional<APInt> result32 =
calculate(lhs.getValue().trunc(32), rhs.getValue().trunc(32));
if (!result32)
return {};
// Compare the truncated 64-bit result to the 32-bit result.
if (result64->trunc(32) != *result32)
return {};
// The operation can be folded for these particular operands.
return IntegerAttr::get(IndexType::get(lhs.getContext()),
std::move(*result64));
}

//===----------------------------------------------------------------------===//
// AddOp
//===----------------------------------------------------------------------===//

OpFoldResult AddOp::fold(ArrayRef<Attribute> operands) {
return foldBinaryOpUnchecked(
operands, [](const APInt &lhs, const APInt &rhs) { return lhs + rhs; });
}

//===----------------------------------------------------------------------===//
// SubOp
//===----------------------------------------------------------------------===//

OpFoldResult SubOp::fold(ArrayRef<Attribute> operands) {
return foldBinaryOpUnchecked(
operands, [](const APInt &lhs, const APInt &rhs) { return lhs - rhs; });
}

//===----------------------------------------------------------------------===//
// MulOp
//===----------------------------------------------------------------------===//

OpFoldResult MulOp::fold(ArrayRef<Attribute> operands) {
return foldBinaryOpUnchecked(
operands, [](const APInt &lhs, const APInt &rhs) { return lhs * rhs; });
}

//===----------------------------------------------------------------------===//
// DivSOp
//===----------------------------------------------------------------------===//

OpFoldResult DivSOp::fold(ArrayRef<Attribute> operands) {
return foldBinaryOpChecked(
operands, [](const APInt &lhs, const APInt &rhs) -> Optional<APInt> {
// Don't fold division by zero.
if (rhs.isZero())
return None;
return lhs.sdiv(rhs);
});
}

//===----------------------------------------------------------------------===//
// DivUOp
//===----------------------------------------------------------------------===//

OpFoldResult DivUOp::fold(ArrayRef<Attribute> operands) {
return foldBinaryOpChecked(
operands, [](const APInt &lhs, const APInt &rhs) -> Optional<APInt> {
// Don't fold division by zero.
if (rhs.isZero())
return None;
return lhs.udiv(rhs);
});
}

//===----------------------------------------------------------------------===//
// CeilDivSOp
//===----------------------------------------------------------------------===//

/// Compute `ceildivs(n, m)` as `x = m > 0 ? -1 : 1` and then
/// `n*m > 0 ? (n+x)/m + 1 : -(-n/m)`.
static Optional<APInt> calculateCeilDivS(const APInt &n, const APInt &m) {
// Don't fold division by zero.
if (m.isZero())
return None;
// Short-circuit the zero case.
if (n.isZero())
return n;

bool mGtZ = m.sgt(0);
if (n.sgt(0) != mGtZ) {
// If the operands have different signs, compute the negative result. Signed
// division overflow is not possible, since if `m == -1`, `n` can be at most
// `INT_MAX`, and `-INT_MAX != INT_MIN` in two's complement.
return -(-n).sdiv(m);
}
// Otherwise, compute the positive result. Signed division overflow is not
// possible since if `m == -1`, `x` will be `1`.
int64_t x = mGtZ ? -1 : 1;
return (n + x).sdiv(m) + 1;
}

OpFoldResult CeilDivSOp::fold(ArrayRef<Attribute> operands) {
return foldBinaryOpChecked(operands, calculateCeilDivS);
}

//===----------------------------------------------------------------------===//
// CeilDivUOp
//===----------------------------------------------------------------------===//

OpFoldResult CeilDivUOp::fold(ArrayRef<Attribute> operands) {
// Compute `ceildivu(n, m)` as `n == 0 ? 0 : (n-1)/m + 1`.
return foldBinaryOpChecked(
operands, [](const APInt &n, const APInt &m) -> Optional<APInt> {
// Don't fold division by zero.
if (m.isZero())
return None;
// Short-circuit the zero case.
if (n.isZero())
return n;

return (n - 1).udiv(m) + 1;
});
}

//===----------------------------------------------------------------------===//
// FloorDivSOp
//===----------------------------------------------------------------------===//

/// Compute `floordivs(n, m)` as `x = m < 0 ? 1 : -1` and then
/// `n*m < 0 ? -1 - (x-n)/m : n/m`.
static Optional<APInt> calculateFloorDivS(const APInt &n, const APInt &m) {
// Don't fold division by zero.
if (m.isZero())
return None;
// Short-circuit the zero case.
if (n.isZero())
return n;

bool mLtZ = m.slt(0);
if (n.slt(0) == mLtZ) {
// If the operands have the same sign, compute the positive result.
return n.sdiv(m);
}
// If the operands have different signs, compute the negative result. Signed
// division overflow is not possible since if `m == -1`, `x` will be 1 and
// `n` can be at most `INT_MAX`.
int64_t x = mLtZ ? 1 : -1;
return -1 - (x - n).sdiv(m);
}

OpFoldResult FloorDivSOp::fold(ArrayRef<Attribute> operands) {
return foldBinaryOpChecked(operands, calculateFloorDivS);
}

//===----------------------------------------------------------------------===//
// RemSOp
//===----------------------------------------------------------------------===//

OpFoldResult RemSOp::fold(ArrayRef<Attribute> operands) {
return foldBinaryOpChecked(operands, [](const APInt &lhs, const APInt &rhs) {
return lhs.srem(rhs);
});
}

//===----------------------------------------------------------------------===//
// RemUOp
//===----------------------------------------------------------------------===//

OpFoldResult RemUOp::fold(ArrayRef<Attribute> operands) {
return foldBinaryOpChecked(operands, [](const APInt &lhs, const APInt &rhs) {
return lhs.urem(rhs);
});
}

//===----------------------------------------------------------------------===//
// MaxSOp
//===----------------------------------------------------------------------===//

OpFoldResult MaxSOp::fold(ArrayRef<Attribute> operands) {
return foldBinaryOpChecked(operands, [](const APInt &lhs, const APInt &rhs) {
return lhs.sgt(rhs) ? lhs : rhs;
});
}

//===----------------------------------------------------------------------===//
// MaxUOp
//===----------------------------------------------------------------------===//

OpFoldResult MaxUOp::fold(ArrayRef<Attribute> operands) {
return foldBinaryOpChecked(operands, [](const APInt &lhs, const APInt &rhs) {
return lhs.ugt(rhs) ? lhs : rhs;
});
}

//===----------------------------------------------------------------------===//
// CastSOp
//===----------------------------------------------------------------------===//

bool CastSOp::areCastCompatible(TypeRange lhsTypes, TypeRange rhsTypes) {
return lhsTypes.front().isa<IndexType>() != rhsTypes.front().isa<IndexType>();
}

//===----------------------------------------------------------------------===//
// CastUOp
//===----------------------------------------------------------------------===//

bool CastUOp::areCastCompatible(TypeRange lhsTypes, TypeRange rhsTypes) {
return lhsTypes.front().isa<IndexType>() != rhsTypes.front().isa<IndexType>();
}

//===----------------------------------------------------------------------===//
// CmpOp
//===----------------------------------------------------------------------===//

/// Compare two integers according to the comparison predicate.
bool compareIndices(const APInt &lhs, const APInt &rhs,
IndexCmpPredicate pred) {
switch (pred) {
case IndexCmpPredicate::EQ:
return lhs.eq(rhs);
case IndexCmpPredicate::NE:
return lhs.ne(rhs);
case IndexCmpPredicate::SGE:
return lhs.sge(rhs);
case IndexCmpPredicate::SGT:
return lhs.sgt(rhs);
case IndexCmpPredicate::SLE:
return lhs.sle(rhs);
case IndexCmpPredicate::SLT:
return lhs.slt(rhs);
case IndexCmpPredicate::UGE:
return lhs.uge(rhs);
case IndexCmpPredicate::UGT:
return lhs.ugt(rhs);
case IndexCmpPredicate::ULE:
return lhs.ule(rhs);
case IndexCmpPredicate::ULT:
return lhs.ult(rhs);
}
llvm_unreachable("unhandled IndexCmpPredicate predicate");
}

OpFoldResult CmpOp::fold(ArrayRef<Attribute> operands) {
assert(operands.size() == 2 && "compare expected 2 operands");
auto lhs = dyn_cast_if_present<IntegerAttr>(operands[0]);
auto rhs = dyn_cast_if_present<IntegerAttr>(operands[1]);
if (!lhs || !rhs)
return {};

// Perform the comparison in 64-bit and 32-bit.
bool result64 = compareIndices(lhs.getValue(), rhs.getValue(), getPred());
bool result32 = compareIndices(lhs.getValue().trunc(32),
rhs.getValue().trunc(32), getPred());
if (result64 != result32)
return {};
return BoolAttr::get(getContext(), result64);
}

//===----------------------------------------------------------------------===//
// ConstantOp
//===----------------------------------------------------------------------===//

OpFoldResult ConstantOp::fold(ArrayRef<Attribute> operands) {
return getValueAttr();
}

void ConstantOp::build(OpBuilder &b, OperationState &state, int64_t value) {
build(b, state, b.getIndexType(), b.getIndexAttr(value));
}

//===----------------------------------------------------------------------===//
// BoolConstantOp
//===----------------------------------------------------------------------===//

OpFoldResult BoolConstantOp::fold(ArrayRef<Attribute> operands) {
return getValueAttr();
}

//===----------------------------------------------------------------------===//
// ODS-Generated Definitions
//===----------------------------------------------------------------------===//

#define GET_OP_CLASSES
#include "mlir/Dialect/Index/IR/IndexOps.cpp.inc"
176 changes: 176 additions & 0 deletions mlir/test/Conversion/IndexToLLVM/index-to-llvm.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
// RUN: mlir-opt %s -convert-index-to-llvm | FileCheck %s
// RUN: mlir-opt %s -convert-index-to-llvm=index-bitwidth=32 | FileCheck %s --check-prefix=INDEX32
// RUN: mlir-opt %s -convert-index-to-llvm=index-bitwidth=64 | FileCheck %s --check-prefix=INDEX64

// CHECK-LABEL: @trivial_ops
func.func @trivial_ops(%a: index, %b: index) {
// CHECK: llvm.add
%0 = index.add %a, %b
// CHECK: llvm.sub
%1 = index.sub %a, %b
// CHECK: llvm.mul
%2 = index.mul %a, %b
// CHECK: llvm.sdiv
%3 = index.divs %a, %b
// CHECK: llvm.udiv
%4 = index.divu %a, %b
// CHECK: llvm.srem
%5 = index.rems %a, %b
// CHECK: llvm.urem
%6 = index.remu %a, %b
// CHECK: llvm.intr.smax
%7 = index.maxs %a, %b
// CHECK: llvm.intr.umax
%8 = index.maxu %a, %b
// CHECK: llvm.mlir.constant(true
%9 = index.bool.constant true
return
}

// CHECK-LABEL: @ceildivs
// CHECK-SAME: %[[NI:.*]]: index, %[[MI:.*]]: index
func.func @ceildivs(%n: index, %m: index) -> index {
// CHECK: %[[N:.*]] = builtin.unrealized_conversion_cast %[[NI]]
// CHECK: %[[M:.*]] = builtin.unrealized_conversion_cast %[[MI]]
// CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0 :
// CHECK: %[[POS_ONE:.*]] = llvm.mlir.constant(1 :
// CHECK: %[[NEG_ONE:.*]] = llvm.mlir.constant(-1 :

// CHECK: %[[M_POS:.*]] = llvm.icmp "sgt" %[[M]], %[[ZERO]]
// CHECK: %[[X:.*]] = llvm.select %[[M_POS]], %[[NEG_ONE]], %[[POS_ONE]]

// CHECK: %[[N_PLUS_X:.*]] = llvm.add %[[N]], %[[X]]
// CHECK: %[[N_PLUS_X_DIV_M:.*]] = llvm.sdiv %[[N_PLUS_X]], %[[M]]
// CHECK: %[[POS_RES:.*]] = llvm.add %[[N_PLUS_X_DIV_M]], %[[POS_ONE]]

// CHECK: %[[NEG_N:.*]] = llvm.sub %[[ZERO]], %[[N]]
// CHECK: %[[NEG_N_DIV_M:.*]] = llvm.sdiv %[[NEG_N]], %[[M]]
// CHECK: %[[NEG_RES:.*]] = llvm.sub %[[ZERO]], %[[NEG_N_DIV_M]]

// CHECK: %[[N_POS:.*]] = llvm.icmp "sgt" %[[N]], %[[ZERO]]
// CHECK: %[[SAME_SIGN:.*]] = llvm.icmp "eq" %[[N_POS]], %[[M_POS]]
// CHECK: %[[N_NON_ZERO:.*]] = llvm.icmp "ne" %[[N]], %[[ZERO]]
// CHECK: %[[CMP:.*]] = llvm.and %[[SAME_SIGN]], %[[N_NON_ZERO]]
// CHECK: %[[RESULT:.*]] = llvm.select %[[CMP]], %[[POS_RES]], %[[NEG_RES]]
%result = index.ceildivs %n, %m

// CHECK: %[[RESULTI:.*]] = builtin.unrealized_conversion_cast %[[RESULT]]
// CHECK: return %[[RESULTI]]
return %result : index
}

// CHECK-LABEL: @ceildivu
// CHECK-SAME: %[[NI:.*]]: index, %[[MI:.*]]: index
func.func @ceildivu(%n: index, %m: index) -> index {
// CHECK: %[[N:.*]] = builtin.unrealized_conversion_cast %[[NI]]
// CHECK: %[[M:.*]] = builtin.unrealized_conversion_cast %[[MI]]
// CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0 :
// CHECK: %[[ONE:.*]] = llvm.mlir.constant(1 :

// CHECK: %[[MINUS_ONE:.*]] = llvm.sub %[[N]], %[[ONE]]
// CHECK: %[[QUOTIENT:.*]] = llvm.udiv %[[MINUS_ONE]], %[[M]]
// CHECK: %[[PLUS_ONE:.*]] = llvm.add %[[QUOTIENT]], %[[ONE]]

// CHECK: %[[CMP:.*]] = llvm.icmp "eq" %[[N]], %[[ZERO]]
// CHECK: %[[RESULT:.*]] = llvm.select %[[CMP]], %[[ZERO]], %[[PLUS_ONE]]
%result = index.ceildivu %n, %m

// CHECK: %[[RESULTI:.*]] = builtin.unrealized_conversion_cast %[[RESULT]]
// CHECK: return %[[RESULTI]]
return %result : index
}

// CHECK-LABEL: @floordivs
// CHECK-SAME: %[[NI:.*]]: index, %[[MI:.*]]: index
func.func @floordivs(%n: index, %m: index) -> index {
// CHECK: %[[N:.*]] = builtin.unrealized_conversion_cast %[[NI]]
// CHECK: %[[M:.*]] = builtin.unrealized_conversion_cast %[[MI]]
// CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0 :
// CHECK: %[[POS_ONE:.*]] = llvm.mlir.constant(1 :
// CHECK: %[[NEG_ONE:.*]] = llvm.mlir.constant(-1 :

// CHECK: %[[M_NEG:.*]] = llvm.icmp "slt" %[[M]], %[[ZERO]]
// CHECK: %[[X:.*]] = llvm.select %[[M_NEG]], %[[POS_ONE]], %[[NEG_ONE]]

// CHECK: %[[X_MINUS_N:.*]] = llvm.sub %[[X]], %[[N]]
// CHECK: %[[X_MINUS_N_DIV_M:.*]] = llvm.sdiv %[[X_MINUS_N]], %[[M]]
// CHECK: %[[NEG_RES:.*]] = llvm.sub %[[NEG_ONE]], %[[X_MINUS_N_DIV_M]]

// CHECK: %[[POS_RES:.*]] = llvm.sdiv %[[N]], %[[M]]

// CHECK: %[[N_NEG:.*]] = llvm.icmp "slt" %[[N]], %[[ZERO]]
// CHECK: %[[DIFF_SIGN:.*]] = llvm.icmp "ne" %[[N_NEG]], %[[M_NEG]]
// CHECK: %[[N_NON_ZERO:.*]] = llvm.icmp "ne" %[[N]], %[[ZERO]]
// CHECK: %[[CMP:.*]] = llvm.and %[[DIFF_SIGN]], %[[N_NON_ZERO]]
// CHECK: %[[RESULT:.*]] = llvm.select %[[CMP]], %[[NEG_RES]], %[[POS_RES]]
%result = index.floordivs %n, %m

// CHECK: %[[RESULTI:.*]] = builtin.unrealized_conversion_cast %[[RESULT]]
// CHECK: return %[[RESULTI]]
return %result : index
}

// INDEX32-LABEL: @index_cast_from
// INDEX64-LABEL: @index_cast_from
// INDEX32-SAME: %[[AI:.*]]: index
// INDEX64-SAME: %[[AI:.*]]: index
func.func @index_cast_from(%a: index) -> (i64, i32, i64, i32) {
// INDEX32: %[[A:.*]] = builtin.unrealized_conversion_cast %[[AI]] : index to i32
// INDEX64: %[[A:.*]] = builtin.unrealized_conversion_cast %[[AI]] : index to i64

// INDEX32: %[[V0:.*]] = llvm.sext %[[A]] : i32 to i64
%0 = index.casts %a : index to i64
// INDEX64: %[[V1:.*]] = llvm.trunc %[[A]] : i64 to i32
%1 = index.casts %a : index to i32
// INDEX32: %[[V2:.*]] = llvm.zext %[[A]] : i32 to i64
%2 = index.castu %a : index to i64
// INDEX64: %[[V3:.*]] = llvm.trunc %[[A]] : i64 to i32
%3 = index.castu %a : index to i32

// INDEX32: return %[[V0]], %[[A]], %[[V2]], %[[A]]
// INDEX64: return %[[A]], %[[V1]], %[[A]], %[[V3]]
return %0, %1, %2, %3 : i64, i32, i64, i32
}

// INDEX32-LABEL: @index_cast_to
// INDEX64-LABEL: @index_cast_to
// INDEX32-SAME: %[[A:.*]]: i32, %[[B:.*]]: i64
// INDEX64-SAME: %[[A:.*]]: i32, %[[B:.*]]: i64
func.func @index_cast_to(%a: i32, %b: i64) -> (index, index, index, index) {
// INDEX64: %[[V0:.*]] = llvm.sext %[[A]] : i32 to i64
%0 = index.casts %a : i32 to index
// INDEX32: %[[V1:.*]] = llvm.trunc %[[B]] : i64 to i32
%1 = index.casts %b : i64 to index
// INDEX64: %[[V2:.*]] = llvm.zext %[[A]] : i32 to i64
%2 = index.castu %a : i32 to index
// INDEX32: %[[V3:.*]] = llvm.trunc %[[B]] : i64 to i32
%3 = index.castu %b : i64 to index
return %0, %1, %2, %3 : index, index, index, index
}

// INDEX32-LABEL: @index_sizeof
// INDEX64-LABEL: @index_sizeof
func.func @index_sizeof() {
// INDEX32-NEXT: llvm.mlir.constant(32 : i32)
// INDEX64-NEXT: llvm.mlir.constant(64 : i64)
%0 = index.sizeof
return
}

// INDEX32-LABEL: @index_constant
// INDEX64-LABEL: @index_constant
func.func @index_constant() {
// INDEX32: llvm.mlir.constant(-2100000000 : i32) : i32
// INDEX64: llvm.mlir.constant(-2100000000 : i64) : i64
%0 = index.constant -2100000000
// INDEX32: llvm.mlir.constant(2100000000 : i32) : i32
// INDEX64: llvm.mlir.constant(2100000000 : i64) : i64
%1 = index.constant 2100000000
// INDEX32: llvm.mlir.constant(1294967296 : i32) : i32
// INDEX64: llvm.mlir.constant(-3000000000 : i64) : i64
%2 = index.constant -3000000000
// INDEX32: llvm.mlir.constant(-1294967296 : i32) : i32
// INDEX64: llvm.mlir.constant(3000000000 : i64) : i64
%3 = index.constant 3000000000
return
}
319 changes: 319 additions & 0 deletions mlir/test/Dialect/Index/index-canonicalize.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,319 @@
// RUN: mlir-opt %s -canonicalize | FileCheck %s

// CHECK-LABEL: @add
func.func @add() -> (index, index) {
%0 = index.constant 1
%1 = index.constant 2100
%2 = index.constant 3000000001
%3 = index.constant 4000002100
// Folds normally.
%4 = index.add %0, %1
// Folds even though values exceed INT32_MAX.
%5 = index.add %2, %3

// CHECK-DAG: %[[A:.*]] = index.constant 2101
// CHECK-DAG: %[[B:.*]] = index.constant 7000002101
// CHECK: return %[[A]], %[[B]]
return %4, %5 : index, index
}

// CHECK-LABEL: @add_overflow
func.func @add_overflow() -> (index, index) {
%0 = index.constant 2000000000
%1 = index.constant 8000000000000000000
// Folds normally.
%2 = index.add %0, %0
// Folds and overflows.
%3 = index.add %1, %1

// CHECK-DAG: %[[A:.*]] = index.constant 4{{0+}}
// CHECK-DAG: %[[B:.*]] = index.constant -2446{{[0-9]+}}
// CHECK: return %[[A]], %[[B]]
return %2, %3 : index, index
}

// CHECK-LABEL: @sub
func.func @sub() -> index {
%0 = index.constant -2000000000
%1 = index.constant 3000000000
%2 = index.sub %0, %1
// CHECK: %[[A:.*]] = index.constant -5{{0+}}
// CHECK: return %[[A]]
return %2 : index
}

// CHECK-LABEL: @mul
func.func @mul() -> index {
%0 = index.constant 8000000002000000000
%1 = index.constant 2
%2 = index.mul %0, %1
// CHECK: %[[A:.*]] = index.constant -2446{{[0-9]+}}
// CHECK: return %[[A]]
return %2 : index
}

// CHECK-LABEL: @divs
func.func @divs() -> index {
%0 = index.constant -2
%1 = index.constant 0x200000000
%2 = index.divs %1, %0
// CHECK: %[[A:.*]] = index.constant -429{{[0-9]+}}
// CHECK: return %[[A]]
return %2 : index
}

// CHECK-LABEL: @divs_nofold
func.func @divs_nofold() -> (index, index) {
%0 = index.constant 0
%1 = index.constant 0x100000000
%2 = index.constant 2

// Divide by zero.
// CHECK: index.divs
%3 = index.divs %2, %0
// 32-bit result differs from 64-bit.
// CHECK: index.divs
%4 = index.divs %1, %2

return %3, %4 : index, index
}

// CHECK-LABEL: @divu
func.func @divu() -> index {
%0 = index.constant -2
%1 = index.constant 0x200000000
%2 = index.divu %1, %0
// CHECK: %[[A:.*]] = index.constant 0
// CHECK: return %[[A]]
return %2 : index
}

// CHECK-LABEL: @divu_nofold
func.func @divu_nofold() -> (index, index) {
%0 = index.constant 0
%1 = index.constant 0x100000000
%2 = index.constant 2

// Divide by zero.
// CHECK: index.divu
%3 = index.divu %2, %0
// 32-bit result differs from 64-bit.
// CHECK: index.divu
%4 = index.divu %1, %2

return %3, %4 : index, index
}

// CHECK-LABEL: @ceildivs
func.func @ceildivs() -> (index, index, index) {
%c0 = index.constant 0
%c2 = index.constant 2
%c5 = index.constant 5

// CHECK-DAG: %[[A:.*]] = index.constant 0
%0 = index.ceildivs %c0, %c5

// CHECK-DAG: %[[B:.*]] = index.constant 1
%1 = index.ceildivs %c2, %c5

// CHECK-DAG: %[[C:.*]] = index.constant 3
%2 = index.ceildivs %c5, %c2

// CHECK: return %[[A]], %[[B]], %[[C]]
return %0, %1, %2 : index, index, index
}

// CHECK-LABEL: @ceildivs_neg
func.func @ceildivs_neg() -> index {
%c5 = index.constant -5
%c2 = index.constant 2
// CHECK: %[[A:.*]] = index.constant -2
%0 = index.ceildivs %c5, %c2
// CHECK: return %[[A]]
return %0 : index
}

// CHECK-LABEL: @ceildivs_edge
func.func @ceildivs_edge() -> (index, index) {
%cn1 = index.constant -1
%cIntMin = index.constant -2147483648
%cIntMax = index.constant 2147483647

// The result is 0 on 32-bit.
// CHECK-DAG: %[[A:.*]] = index.constant 2147483648
%0 = index.ceildivs %cIntMin, %cn1

// CHECK-DAG: %[[B:.*]] = index.constant -2147483647
%1 = index.ceildivs %cIntMax, %cn1

// CHECK: return %[[A]], %[[B]]
return %0, %1 : index, index
}

// CHECK-LABEL: @ceildivu
func.func @ceildivu() -> index {
%0 = index.constant 0x200000001
%1 = index.constant 2
// CHECK: %[[A:.*]] = index.constant 429{{[0-9]+}}7
%2 = index.ceildivu %0, %1
// CHECK: return %[[A]]
return %2 : index
}

// CHECK-LABEL: @floordivs
func.func @floordivs() -> index {
%0 = index.constant -5
%1 = index.constant 2
// CHECK: %[[A:.*]] = index.constant -3
%2 = index.floordivs %0, %1
// CHECK: return %[[A]]
return %2 : index
}

// CHECK-LABEL: @floordivs_edge
func.func @floordivs_edge() -> (index, index) {
%cIntMin = index.constant -2147483648
%cIntMax = index.constant 2147483647
%n1 = index.constant -1
%p1 = index.constant 1

// CHECK-DAG: %[[A:.*]] = index.constant -2147483648
// CHECK-DAG: %[[B:.*]] = index.constant -2147483647
%0 = index.floordivs %cIntMin, %p1
%1 = index.floordivs %cIntMax, %n1

// CHECK: return %[[A]], %[[B]]
return %0, %1 : index, index
}

// CHECK-LABEL: @floordivs_nofold
func.func @floordivs_nofold() -> index {
%lhs = index.constant 0x100000000
%c2 = index.constant 2

// 32-bit result differs from 64-bit.
// CHECK: index.floordivs
%0 = index.floordivs %lhs, %c2

return %0 : index
}

// CHECK-LABEL: @rems
func.func @rems() -> index {
%lhs = index.constant -5
%rhs = index.constant 2
// CHECK: %[[A:.*]] = index.constant -1
%0 = index.rems %lhs, %rhs
// CHECK: return %[[A]]
return %0 : index
}

// CHECK-LABEL: @rems_nofold
func.func @rems_nofold() -> index {
%lhs = index.constant 2
%rhs = index.constant 0x100000001
// 32-bit result differs from 64-bit.
// CHECK: index.rems
%0 = index.rems %lhs, %rhs
return %0 : index
}

// CHECK-LABEL: @remu
func.func @remu() -> index {
%lhs = index.constant 2
%rhs = index.constant -1
// CHECK: %[[A:.*]] = index.constant 2
%0 = index.remu %lhs, %rhs
// CHECK: return %[[A]]
return %0 : index
}

// CHECK-LABEL: @remu_nofold
func.func @remu_nofold() -> index {
%lhs = index.constant 2
%rhs = index.constant 0x100000001
// 32-bit result differs from 64-bit.
// CHECK: index.remu
%0 = index.remu %lhs, %rhs
return %0 : index
}

// CHECK-LABEL: @maxs
func.func @maxs() -> index {
%lhs = index.constant -4
%rhs = index.constant 2
// CHECK: %[[A:.*]] = index.constant 2
%0 = index.maxs %lhs, %rhs
// CHECK: return %[[A]]
return %0 : index
}

// CHECK-LABEL: @maxs_nofold
func.func @maxs_nofold() -> index {
%lhs = index.constant 1
%rhs = index.constant 0x100000000
// 32-bit result differs from 64-bit.
// CHECK: index.maxs
%0 = index.maxs %lhs, %rhs
return %0 : index
}

// CHECK-LABEL: @maxs_edge
func.func @maxs_edge() -> index {
%lhs = index.constant 1
%rhs = index.constant 0x100000001
// Truncated 64-bit result is the same as 32-bit.
// CHECK: %[[A:.*]] = index.constant 429{{[0-9]+}}
%0 = index.maxs %lhs, %rhs
// CHECK: return %[[A]]
return %0 : index
}

// CHECK-LABEL: @maxu
func.func @maxu() -> index {
%lhs = index.constant -1
%rhs = index.constant 1
// CHECK: %[[A:.*]] = index.constant -1
%0 = index.maxu %lhs, %rhs
// CHECK: return %[[A]]
return %0 : index
}

// CHECK-LABEL: @cmp
func.func @cmp() -> (i1, i1, i1, i1) {
%a = index.constant 0
%b = index.constant -1
%c = index.constant -2
%d = index.constant 4

%0 = index.cmp slt(%a, %b)
%1 = index.cmp ugt(%b, %a)
%2 = index.cmp ne(%d, %a)
%3 = index.cmp sgt(%b, %a)

// CHECK-DAG: %[[TRUE:.*]] = index.bool.constant true
// CHECK-DAG: %[[FALSE:.*]] = index.bool.constant false
// CHECK: return %[[FALSE]], %[[TRUE]], %[[TRUE]], %[[FALSE]]
return %0, %1, %2, %3 : i1, i1, i1, i1
}

// CHECK-LABEL: @cmp_nofold
func.func @cmp_nofold() -> i1 {
%lhs = index.constant 1
%rhs = index.constant 0x100000000
// 32-bit result differs from 64-bit.
// CHECK: index.cmp slt
%0 = index.cmp slt(%lhs, %rhs)
return %0 : i1
}

// CHECK-LABEL: @cmp_edge
func.func @cmp_edge() -> i1 {
%lhs = index.constant 1
%rhs = index.constant 0x100000002
// 64-bit result is the same as 32-bit.
// CHECK: %[[TRUE:.*]] = index.bool.constant true
%0 = index.cmp slt(%lhs, %rhs)
// CHECK: return %[[TRUE]]
return %0 : i1
}
31 changes: 31 additions & 0 deletions mlir/test/Dialect/Index/index-errors.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
// RUN: mlir-opt -split-input-file -verify-diagnostics %s

func.func @invalid_cast(%a: index) {
// expected-error @below {{cast incompatible}}
%0 = index.casts %a : index to index
return
}

// -----

func.func @invalid_cast(%a: i64) {
// expected-error @below {{cast incompatible}}
%0 = index.casts %a : i64 to i64
return
}

// -----

func.func @invalid_cast(%a: index) {
// expected-error @below {{cast incompatible}}
%0 = index.castu %a : index to index
return
}

// -----

func.func @invalid_cast(%a: i64) {
// expected-error @below {{cast incompatible}}
%0 = index.castu %a : i64 to i64
return
}
102 changes: 102 additions & 0 deletions mlir/test/Dialect/Index/index-ops.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
// RUN: mlir-opt %s | mlir-opt | FileCheck %s

// CHECK-LABEL: @binary_ops
// CHECK-SAME: %[[A:.*]]: index, %[[B:.*]]: index
func.func @binary_ops(%a: index, %b: index) {
// CHECK-NEXT: index.add %[[A]], %[[B]]
%0 = index.add %a, %b
// CHECK-NEXT: index.sub %[[A]], %[[B]]
%1 = index.sub %a, %b
// CHECK-NEXT: index.mul %[[A]], %[[B]]
%2 = index.mul %a, %b
// CHECK-NEXT: index.divs %[[A]], %[[B]]
%3 = index.divs %a, %b
// CHECK-NEXT: index.divu %[[A]], %[[B]]
%4 = index.divu %a, %b
// CHECK-NEXT: index.ceildivs %[[A]], %[[B]]
%5 = index.ceildivs %a, %b
// CHECK-NEXT: index.ceildivu %[[A]], %[[B]]
%6 = index.ceildivu %a, %b
// CHECK-NEXT: index.floordivs %[[A]], %[[B]]
%7 = index.floordivs %a, %b
// CHECK-NEXT: index.rems %[[A]], %[[B]]
%8 = index.rems %a, %b
// CHECK-NEXT: index.remu %[[A]], %[[B]]
%9 = index.remu %a, %b
// CHECK-NEXT: index.maxs %[[A]], %[[B]]
%10 = index.maxs %a, %b
// CHECK-NEXT: index.maxu %[[A]], %[[B]]
%11 = index.maxu %a, %b
return
}

// CHECK-LABEL: @cmp_op
// CHECK-SAME: %[[A:.*]]: index, %[[B:.*]]: index
func.func @cmp_op(%a: index, %b: index) {
// CHECK-NEXT: index.cmp eq(%[[A]], %[[B]])
%0 = index.cmp eq(%a, %b)
// CHECK-NEXT: index.cmp ne(%[[A]], %[[B]])
%1 = index.cmp ne(%a, %b)
// CHECK-NEXT: index.cmp slt(%[[A]], %[[B]])
%2 = index.cmp slt(%a, %b)
// CHECK-NEXT: index.cmp sle(%[[A]], %[[B]])
%3 = index.cmp sle(%a, %b)
// CHECK-NEXT: index.cmp sgt(%[[A]], %[[B]])
%4 = index.cmp sgt(%a, %b)
// CHECK-NEXT: index.cmp sge(%[[A]], %[[B]])
%5 = index.cmp sge(%a, %b)
// CHECK-NEXT: index.cmp ult(%[[A]], %[[B]])
%6 = index.cmp ult(%a, %b)
// CHECK-NEXT: index.cmp ule(%[[A]], %[[B]])
%7 = index.cmp ule(%a, %b)
// CHECK-NEXT: index.cmp ugt(%[[A]], %[[B]])
%8 = index.cmp ugt(%a, %b)
// CHECK-NEXT: index.cmp uge(%[[A]], %[[B]])
%9 = index.cmp uge(%a, %b)
return
}

// CHECK-LABEL: @sizeof_op
func.func @sizeof_op() {
// CHECK: index.sizeof
%0 = index.sizeof
return
}

// CHECK-LABEL: @constant_op
func.func @constant_op() {
// CHECK-NEXT: index.constant 0
%0 = index.constant 0
// CHECK-NEXT: index.constant 1
%1 = index.constant 1
// CHECK-NEXT: index.constant 42
%2 = index.constant 42
return
}

// CHECK-LABEL: @bool_constant_op
func.func @bool_constant_op() {
// CHECK-NEXT: index.bool.constant true
%0 = index.bool.constant true
// CHECK-NEXT: index.bool.constant false
%1 = index.bool.constant false
return
}

// CHECK-LABEL: @cast_op
// CHECK-SAME: %[[A:.*]]: index, %[[B:.*]]: i32, %[[C:.*]]: i64
func.func @cast_op(%a: index, %b: i32, %c: i64) {
// CHECK-NEXT: index.casts %[[A]] : index to i64
%0 = index.casts %a : index to i64
// CHECK-NEXT: index.casts %[[B]] : i32 to index
%1 = index.casts %b : i32 to index
// CHECK-NEXT: index.casts %[[C]] : i64 to index
%2 = index.casts %c : i64 to index
// CHECK-NEXT: index.castu %[[A]] : index to i64
%3 = index.castu %a : index to i64
// CHECK-NEXT: index.castu %[[B]] : i32 to index
%4 = index.castu %b : i32 to index
// CHECK-NEXT: index.castu %[[C]] : i64 to index
%5 = index.castu %c : i64 to index
return
}
1 change: 1 addition & 0 deletions mlir/test/mlir-opt/commandline.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
// CHECK-NEXT: emitc
// CHECK-NEXT: func
// CHECK-NEXT: gpu
// CHECK-NEXT: index
// CHECK-NEXT: linalg
// CHECK-NEXT: llvm
// CHECK-NEXT: math
Expand Down