Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reland: "[mlir][index][spirv] Add conversion for index to spirv" #69790

Merged
merged 2 commits into from Oct 22, 2023

Conversation

inbelic
Copy link
Contributor

@inbelic inbelic commented Oct 20, 2023

Due to an issue when lowering from scf to spirv as there was no conversion pass for index to spirv, we are motivated to add a conversion pass from the Index dialect to the SPIR-V dialect. Furthermore, we add the new conversion patterns to the scf-to-spirv conversion.

Fixes #63713

@llvmbot
Copy link
Collaborator

llvmbot commented Oct 20, 2023

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-spirv

Author: Finn Plummer (inbelic)

Changes

Due to an issue when lowering from scf to spirv as there was no conversion pass for index to spirv, we are motivated to add a conversion pass from the Index dialect to the SPIR-V dialect. Furthermore, we add the new conversion patterns to the scf-to-spirv conversion.

Fixes #63713


Patch is 33.89 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/69790.diff

10 Files Affected:

  • (added) mlir/include/mlir/Conversion/IndexToSPIRV/IndexToSPIRV.h (+30)
  • (modified) mlir/include/mlir/Conversion/Passes.h (+1)
  • (modified) mlir/include/mlir/Conversion/Passes.td (+22)
  • (modified) mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h (+8-3)
  • (modified) mlir/lib/Conversion/CMakeLists.txt (+1)
  • (added) mlir/lib/Conversion/IndexToSPIRV/CMakeLists.txt (+16)
  • (added) mlir/lib/Conversion/IndexToSPIRV/IndexToSPIRV.cpp (+418)
  • (modified) mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRVPass.cpp (+2)
  • (added) mlir/test/Conversion/IndexToSPRIV/index-to-spirv.mlir (+222)
  • (added) mlir/test/Conversion/SCFToSPIRV/use-indices.mlir (+28)
diff --git a/mlir/include/mlir/Conversion/IndexToSPIRV/IndexToSPIRV.h b/mlir/include/mlir/Conversion/IndexToSPIRV/IndexToSPIRV.h
new file mode 100644
index 000000000000000..58a1c5246eef999
--- /dev/null
+++ b/mlir/include/mlir/Conversion/IndexToSPIRV/IndexToSPIRV.h
@@ -0,0 +1,30 @@
+//===- IndexToSPIRV.h - Index to SPIRV dialect conversion -------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_INDEXTOSPIRV_INDEXTOSPIRV_H
+#define MLIR_CONVERSION_INDEXTOSPIRV_INDEXTOSPIRV_H
+
+#include "mlir/Pass/Pass.h"
+#include <memory>
+
+namespace mlir {
+class RewritePatternSet;
+class SPIRVTypeConverter;
+class Pass;
+
+#define GEN_PASS_DECL_CONVERTINDEXTOSPIRVPASS
+#include "mlir/Conversion/Passes.h.inc"
+
+namespace index {
+void populateIndexToSPIRVPatterns(SPIRVTypeConverter &converter,
+                                  RewritePatternSet &patterns);
+std::unique_ptr<OperationPass<>> createConvertIndexToSPIRVPass();
+} // namespace index
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_INDEXTOSPIRV_INDEXTOSPIRV_H
diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index e714f5070f23db8..c13c457fd97492a 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -35,6 +35,7 @@
 #include "mlir/Conversion/GPUToSPIRV/GPUToSPIRVPass.h"
 #include "mlir/Conversion/GPUToVulkan/ConvertGPUToVulkanPass.h"
 #include "mlir/Conversion/IndexToLLVM/IndexToLLVM.h"
+#include "mlir/Conversion/IndexToSPIRV/IndexToSPIRV.h"
 #include "mlir/Conversion/LinalgToStandard/LinalgToStandard.h"
 #include "mlir/Conversion/MathToFuncs/MathToFuncs.h"
 #include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 38b05c792d405ad..9979faed4251787 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -644,6 +644,28 @@ def ConvertIndexToLLVMPass : Pass<"convert-index-to-llvm"> {
   ];
 }
 
+//===----------------------------------------------------------------------===//
+// ConvertIndexToSPIRVPass
+//===----------------------------------------------------------------------===//
+
+def ConvertIndexToSPIRVPass : Pass<"convert-index-to-spirv"> {
+  let summary = "Lower the `index` dialect to the `spirv` dialect.";
+  let description = [{
+    This pass lowers Index dialect operations to SPIR-V dialect operations.
+    Operation conversions are 1-to-1 except for the exotic divides: `ceildivs`,
+    `ceildivu`, and `floordivs`. The index bitwidth will be 32 or 64 as
+    specified by use-64bit-index.
+  }];
+
+  let dependentDialects = ["::mlir::spirv::SPIRVDialect"];
+
+  let options = [
+    Option<"use64bitIndex", "use-64bit-index",
+           "bool", /*default=*/"false",
+           "Use 64-bit integers to convert index types">
+  ];
+}
+
 //===----------------------------------------------------------------------===//
 // LinalgToStandard
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
index 89ded981d38f9f4..933d62e35fce8cd 100644
--- a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
+++ b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
@@ -55,13 +55,13 @@ struct SPIRVConversionOptions {
   /// values will be packed into one 32-bit value to be memory efficient.
   bool emulateLT32BitScalarTypes{true};
 
-  /// Use 64-bit integers to convert index types.
-  bool use64bitIndex{false};
-
   /// Whether to enable fast math mode during conversion. If true, various
   /// patterns would assume no NaN/infinity numbers as inputs, and thus there
   /// will be no special guards emitted to check and handle such cases.
   bool enableFastMathMode{false};
+
+  /// Use 64-bit integers when converting index types.
+  bool use64bitIndex{false};
 };
 
 /// Type conversion from builtin types to SPIR-V types for shader interface.
@@ -77,6 +77,11 @@ class SPIRVTypeConverter : public TypeConverter {
   /// Gets the SPIR-V correspondence for the standard index type.
   Type getIndexType() const;
 
+  /// Gets the bitwidth of the index type when converted to SPIR-V.
+  unsigned getIndexTypeBitwidth() const {
+    return options.use64bitIndex ? 64 : 32;
+  }
+
   const spirv::TargetEnv &getTargetEnv() const { return targetEnv; }
 
   /// Returns the options controlling the SPIR-V type converter.
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index 35790254be137be..7e1c7bcf9a8678a 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -24,6 +24,7 @@ add_subdirectory(GPUToROCDL)
 add_subdirectory(GPUToSPIRV)
 add_subdirectory(GPUToVulkan)
 add_subdirectory(IndexToLLVM)
+add_subdirectory(IndexToSPIRV)
 add_subdirectory(LinalgToStandard)
 add_subdirectory(LLVMCommon)
 add_subdirectory(MathToFuncs)
diff --git a/mlir/lib/Conversion/IndexToSPIRV/CMakeLists.txt b/mlir/lib/Conversion/IndexToSPIRV/CMakeLists.txt
new file mode 100644
index 000000000000000..e3b279d915a15dd
--- /dev/null
+++ b/mlir/lib/Conversion/IndexToSPIRV/CMakeLists.txt
@@ -0,0 +1,16 @@
+add_mlir_conversion_library(MLIRIndexToSPIRV
+  IndexToSPIRV.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/IndexToSPIRV
+
+  DEPENDS
+  MLIRConversionPassIncGen
+
+  LINK_COMPONENTS
+  Core
+
+  LINK_LIBS PUBLIC
+  MLIRIndexDialect
+  MLIRSPIRVDialect
+  )
diff --git a/mlir/lib/Conversion/IndexToSPIRV/IndexToSPIRV.cpp b/mlir/lib/Conversion/IndexToSPIRV/IndexToSPIRV.cpp
new file mode 100644
index 000000000000000..b58efc096e2eafb
--- /dev/null
+++ b/mlir/lib/Conversion/IndexToSPIRV/IndexToSPIRV.cpp
@@ -0,0 +1,418 @@
+//===- IndexToSPIRV.cpp - Index to SPIRV dialect conversion -----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/IndexToSPIRV/IndexToSPIRV.h"
+#include "../SPIRVCommon/Pattern.h"
+#include "mlir/Dialect/Index/IR/IndexDialect.h"
+#include "mlir/Dialect/Index/IR/IndexOps.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
+#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
+#include "mlir/Pass/Pass.h"
+
+using namespace mlir;
+using namespace index;
+
+namespace {
+
+//===----------------------------------------------------------------------===//
+// Trivial Conversions
+//===----------------------------------------------------------------------===//
+
+using ConvertIndexAdd = spirv::ElementwiseOpPattern<AddOp, spirv::IAddOp>;
+using ConvertIndexSub = spirv::ElementwiseOpPattern<SubOp, spirv::ISubOp>;
+using ConvertIndexMul = spirv::ElementwiseOpPattern<MulOp, spirv::IMulOp>;
+using ConvertIndexDivS = spirv::ElementwiseOpPattern<DivSOp, spirv::SDivOp>;
+using ConvertIndexDivU = spirv::ElementwiseOpPattern<DivUOp, spirv::UDivOp>;
+using ConvertIndexRemS = spirv::ElementwiseOpPattern<RemSOp, spirv::SRemOp>;
+using ConvertIndexRemU = spirv::ElementwiseOpPattern<RemUOp, spirv::UModOp>;
+using ConvertIndexMaxS = spirv::ElementwiseOpPattern<MaxSOp, spirv::GLSMaxOp>;
+using ConvertIndexMaxU = spirv::ElementwiseOpPattern<MaxUOp, spirv::GLUMaxOp>;
+using ConvertIndexMinS = spirv::ElementwiseOpPattern<MinSOp, spirv::GLSMinOp>;
+using ConvertIndexMinU = spirv::ElementwiseOpPattern<MinUOp, spirv::GLUMinOp>;
+
+using ConvertIndexShl =
+    spirv::ElementwiseOpPattern<ShlOp, spirv::ShiftLeftLogicalOp>;
+using ConvertIndexShrS =
+    spirv::ElementwiseOpPattern<ShrSOp, spirv::ShiftRightArithmeticOp>;
+using ConvertIndexShrU =
+    spirv::ElementwiseOpPattern<ShrUOp, spirv::ShiftRightLogicalOp>;
+
+/// It is the case that when we convert bitwise operations to SPIR-V operations
+/// we must take into account the special pattern in SPIR-V that if the
+/// operands are boolean values, then SPIR-V uses `SPIRVLogicalOp`. Otherwise,
+/// for non-boolean operands, SPIR-V should use `SPIRVBitwiseOp`. However,
+/// index.add is never a boolean operation so we can directly convert it to the
+/// Bitwise[And|Or]Op.
+using ConvertIndexAnd = spirv::ElementwiseOpPattern<AndOp, spirv::BitwiseAndOp>;
+using ConvertIndexOr = spirv::ElementwiseOpPattern<OrOp, spirv::BitwiseOrOp>;
+using ConvertIndexXor = spirv::ElementwiseOpPattern<XOrOp, spirv::BitwiseXorOp>;
+
+//===----------------------------------------------------------------------===//
+// ConvertConstantBool
+//===----------------------------------------------------------------------===//
+
+// Converts index.bool.constant operation to spirv.Constant.
+struct ConvertIndexConstantBoolOpPattern final
+    : OpConversionPattern<BoolConstantOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(BoolConstantOp op, BoolConstantOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    rewriter.replaceOpWithNewOp<spirv::ConstantOp>(op, op.getType(),
+                                                   op.getValueAttr());
+    return success();
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// ConvertConstant
+//===----------------------------------------------------------------------===//
+
+// Converts index.constant op to spirv.Constant. Will truncate from i64 to i32
+// when required.
+struct ConvertIndexConstantOpPattern final : OpConversionPattern<ConstantOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(ConstantOp op, ConstantOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto *typeConverter = this->template getTypeConverter<SPIRVTypeConverter>();
+    Type indexType = typeConverter->getIndexType();
+
+    APInt value = op.getValue().trunc(typeConverter->getIndexTypeBitwidth());
+    rewriter.replaceOpWithNewOp<spirv::ConstantOp>(
+        op, indexType, IntegerAttr::get(indexType, value));
+    return success();
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// ConvertIndexCeilDivS
+//===----------------------------------------------------------------------===//
+
+/// Convert `ceildivs(n, m)` into `x = m > 0 ? -1 : 1` and then
+/// `n*m > 0 ? (n+x)/m + 1 : -(-n/m)`. Formula taken from the equivalent
+/// conversion in IndexToLLVM.
+struct ConvertIndexCeilDivSPattern final : OpConversionPattern<CeilDivSOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(CeilDivSOp op, CeilDivSOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    Value n = adaptor.getLhs();
+    Type n_type = n.getType();
+    Value m = adaptor.getRhs();
+
+    // Define the constants
+    Value zero = rewriter.create<spirv::ConstantOp>(
+        loc, n_type, IntegerAttr::get(n_type, 0));
+    Value posOne = rewriter.create<spirv::ConstantOp>(
+        loc, n_type, IntegerAttr::get(n_type, 1));
+    Value negOne = rewriter.create<spirv::ConstantOp>(
+        loc, n_type, IntegerAttr::get(n_type, -1));
+
+    // Compute `x`.
+    Value mPos = rewriter.create<spirv::SGreaterThanOp>(loc, m, zero);
+    Value x = rewriter.create<spirv::SelectOp>(loc, mPos, negOne, posOne);
+
+    // Compute the positive result.
+    Value nPlusX = rewriter.create<spirv::IAddOp>(loc, n, x);
+    Value nPlusXDivM = rewriter.create<spirv::SDivOp>(loc, nPlusX, m);
+    Value posRes = rewriter.create<spirv::IAddOp>(loc, nPlusXDivM, posOne);
+
+    // Compute the negative result.
+    Value negN = rewriter.create<spirv::ISubOp>(loc, zero, n);
+    Value negNDivM = rewriter.create<spirv::SDivOp>(loc, negN, m);
+    Value negRes = rewriter.create<spirv::ISubOp>(loc, zero, negNDivM);
+
+    // Pick the positive result if `n` and `m` have the same sign and `n` is
+    // non-zero, i.e. `(n > 0) == (m > 0) && n != 0`.
+    Value nPos = rewriter.create<spirv::SGreaterThanOp>(loc, n, zero);
+    Value sameSign = rewriter.create<spirv::LogicalEqualOp>(loc, nPos, mPos);
+    Value nNonZero = rewriter.create<spirv::INotEqualOp>(loc, n, zero);
+    Value cmp = rewriter.create<spirv::LogicalAndOp>(loc, sameSign, nNonZero);
+    rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, cmp, posRes, negRes);
+    return success();
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// ConvertIndexCeilDivU
+//===----------------------------------------------------------------------===//
+
+/// Convert `ceildivu(n, m)` into `n == 0 ? 0 : (n-1)/m + 1`. Formula taken
+/// from the equivalent conversion in IndexToLLVM.
+struct ConvertIndexCeilDivUPattern final : OpConversionPattern<CeilDivUOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(CeilDivUOp op, CeilDivUOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    Value n = adaptor.getLhs();
+    Type n_type = n.getType();
+    Value m = adaptor.getRhs();
+
+    // Define the constants
+    Value zero = rewriter.create<spirv::ConstantOp>(
+        loc, n_type, IntegerAttr::get(n_type, 0));
+    Value one = rewriter.create<spirv::ConstantOp>(loc, n_type,
+                                                   IntegerAttr::get(n_type, 1));
+
+    // Compute the non-zero result.
+    Value minusOne = rewriter.create<spirv::ISubOp>(loc, n, one);
+    Value quotient = rewriter.create<spirv::UDivOp>(loc, minusOne, m);
+    Value plusOne = rewriter.create<spirv::IAddOp>(loc, quotient, one);
+
+    // Pick the result
+    Value cmp = rewriter.create<spirv::IEqualOp>(loc, n, zero);
+    rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, cmp, zero, plusOne);
+    return success();
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// ConvertIndexFloorDivS
+//===----------------------------------------------------------------------===//
+
+/// Convert `floordivs(n, m)` into `x = m < 0 ? 1 : -1` and then
+/// `n*m < 0 ? -1 - (x-n)/m : n/m`. Formula taken from the equivalent conversion
+/// in IndexToLLVM.
+struct ConvertIndexFloorDivSPattern final : OpConversionPattern<FloorDivSOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(FloorDivSOp op, FloorDivSOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    Value n = adaptor.getLhs();
+    Type n_type = n.getType();
+    Value m = adaptor.getRhs();
+
+    // Define the constants
+    Value zero = rewriter.create<spirv::ConstantOp>(
+        loc, n_type, IntegerAttr::get(n_type, 0));
+    Value posOne = rewriter.create<spirv::ConstantOp>(
+        loc, n_type, IntegerAttr::get(n_type, 1));
+    Value negOne = rewriter.create<spirv::ConstantOp>(
+        loc, n_type, IntegerAttr::get(n_type, -1));
+
+    // Compute `x`.
+    Value mNeg = rewriter.create<spirv::SLessThanOp>(loc, m, zero);
+    Value x = rewriter.create<spirv::SelectOp>(loc, mNeg, posOne, negOne);
+
+    // Compute the negative result
+    Value xMinusN = rewriter.create<spirv::ISubOp>(loc, x, n);
+    Value xMinusNDivM = rewriter.create<spirv::SDivOp>(loc, xMinusN, m);
+    Value negRes = rewriter.create<spirv::ISubOp>(loc, negOne, xMinusNDivM);
+
+    // Compute the positive result.
+    Value posRes = rewriter.create<spirv::SDivOp>(loc, n, m);
+
+    // Pick the negative result if `n` and `m` have different signs and `n` is
+    // non-zero, i.e. `(n < 0) != (m < 0) && n != 0`.
+    Value nNeg = rewriter.create<spirv::SLessThanOp>(loc, n, zero);
+    Value diffSign = rewriter.create<spirv::LogicalNotEqualOp>(loc, nNeg, mNeg);
+    Value nNonZero = rewriter.create<spirv::INotEqualOp>(loc, n, zero);
+
+    Value cmp = rewriter.create<spirv::LogicalAndOp>(loc, diffSign, nNonZero);
+    rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, cmp, posRes, negRes);
+    return success();
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// ConvertIndexCast
+//===----------------------------------------------------------------------===//
+
+/// Convert a cast op. If the materialized index type is the same as the other
+/// type, fold away the op. Otherwise, use the Convert SPIR-V operation.
+/// Signed casts sign extend when the result bitwidth is larger. Unsigned casts
+/// zero extend when the result bitwidth is larger.
+template <typename CastOp, typename ConvertOp>
+struct ConvertIndexCast final : OpConversionPattern<CastOp> {
+  using OpConversionPattern<CastOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(CastOp op, typename CastOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto *typeConverter = this->template getTypeConverter<SPIRVTypeConverter>();
+    Type indexType = typeConverter->getIndexType();
+
+    Type srcType = adaptor.getInput().getType();
+    Type dstType = op.getType();
+    if (isa<IndexType>(srcType)) {
+      srcType = indexType;
+    }
+    if (isa<IndexType>(dstType)) {
+      dstType = indexType;
+    }
+
+    if (srcType == dstType) {
+      rewriter.replaceOp(op, adaptor.getInput());
+    } else {
+      rewriter.template replaceOpWithNewOp<ConvertOp>(op, dstType,
+                                                      adaptor.getOperands());
+    }
+    return success();
+  }
+};
+
+using ConvertIndexCastS = ConvertIndexCast<CastSOp, spirv::SConvertOp>;
+using ConvertIndexCastU = ConvertIndexCast<CastUOp, spirv::UConvertOp>;
+
+//===----------------------------------------------------------------------===//
+// ConvertIndexCmp
+//===----------------------------------------------------------------------===//
+
+// Helper template to replace the operation
+template <typename ICmpOp>
+static LogicalResult rewriteCmpOp(CmpOp op, CmpOpAdaptor adaptor,
+                                  ConversionPatternRewriter &rewriter) {
+  rewriter.replaceOpWithNewOp<ICmpOp>(op, adaptor.getLhs(), adaptor.getRhs());
+  return success();
+}
+
+struct ConvertIndexCmpPattern final : OpConversionPattern<CmpOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(CmpOp op, CmpOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // We must convert the predicates to the corresponding int comparions.
+    switch (op.getPred()) {
+    case IndexCmpPredicate::EQ:
+      return rewriteCmpOp<spirv::IEqualOp>(op, adaptor, rewriter);
+    case IndexCmpPredicate::NE:
+      return rewriteCmpOp<spirv::INotEqualOp>(op, adaptor, rewriter);
+    case IndexCmpPredicate::SGE:
+      return rewriteCmpOp<spirv::SGreaterThanEqualOp>(op, adaptor, rewriter);
+    case IndexCmpPredicate::SGT:
+      return rewriteCmpOp<spirv::SGreaterThanOp>(op, adaptor, rewriter);
+    case IndexCmpPredicate::SLE:
+      return rewriteCmpOp<spirv::SLessThanEqualOp>(op, adaptor, rewriter);
+    case IndexCmpPredicate::SLT:
+      return rewriteCmpOp<spirv::SLessThanOp>(op, adaptor, rewriter);
+    case IndexCmpPredicate::UGE:
+      return rewriteCmpOp<spirv::UGreaterThanEqualOp>(op, adaptor, rewriter);
+    case IndexCmpPredicate::UGT:
+      return rewriteCmpOp<spirv::UGreaterThanOp>(op, adaptor, rewriter);
+    case IndexCmpPredicate::ULE:
+      return rewriteCmpOp<spirv::ULessThanEqualOp>(op, adaptor, rewriter);
+    case IndexCmpPredicate::ULT:
+      return rewriteCmpOp<spirv::ULessThanOp>(op, adaptor, rewriter);
+    }
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// ConvertIndexSizeOf
+//===----------------------------------------------------------------------===//
+
+/// Lower `index.sizeof` to a constant with the value of the index bitwidth.
+struct ConvertIndexSizeOf final : OpConversionPatte...
[truncated]

@inbelic
Copy link
Contributor Author

inbelic commented Oct 20, 2023

This pull request is the same as #68085, however we will ensure that the build-bot will not fail

@inbelic inbelic changed the title Inbelic/conv index to spirv [mlir][index][spirv] Add conversion for index to spirv Oct 20, 2023
@kuhar
Copy link
Member

kuhar commented Oct 21, 2023

@inbelic Can you confirm this builds with -DBUILD_SHARED_LIBS=1?

@kuhar kuhar changed the title [mlir][index][spirv] Add conversion for index to spirv Reland: "[mlir][index][spirv] Add conversion for index to spirv" Oct 21, 2023
inbelic and others added 2 commits October 21, 2023 11:03
Due to an issue when lowering from scf to spirv as there was no
conversion pass for index to spirv, we are motivated to add a
conversion pass from the Index dialect to the SPIR-V dialect.
Furthermore, we add the new conversion patterns to the scf-to-spirv
conversion.

Fixes llvm#63713
@inbelic
Copy link
Contributor Author

inbelic commented Oct 21, 2023

Was able to reproduce the error locally with the additional flag that you recommended @kuhar.
Adding MLIRSPIRVDialect and MLIRSPIRVConversion to the MLIRIndexToSPIRV deps, as well as, adding MLIRIndexToSPIRV to the MLIRSCFToSPIRV resolved the issue.

Copy link
Member

@kuhar kuhar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!

@kuhar kuhar merged commit 5aee156 into llvm:main Oct 22, 2023
3 checks passed
@inbelic inbelic deleted the inbelic/conv-index-to-spirv branch October 24, 2023 10:50
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[mlir][spirv] Support index to spir-v dialect conversion
4 participants