diff --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp index 18e857c81af8d..cb0c829719565 100644 --- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp +++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp @@ -238,6 +238,16 @@ struct CmpFOpLowering : public ConvertOpToLLVMPattern { ConversionPatternRewriter &rewriter) const override; }; +struct SelectOpOneToNLowering : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + using Adaptor = + typename ConvertOpToLLVMPattern::OneToNOpAdaptor; + + LogicalResult + matchAndRewrite(arith::SelectOp op, Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; + } // namespace //===----------------------------------------------------------------------===// @@ -479,6 +489,32 @@ CmpFOpLowering::matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor, rewriter); } +//===----------------------------------------------------------------------===// +// SelectOpOneToNLowering +//===----------------------------------------------------------------------===// + +/// Pattern for arith.select where the true/false values lower to multiple +/// SSA values (1:N conversion). This pattern generates multiple arith.select +/// than can be lowered by the 1:1 arith.select pattern. +LogicalResult SelectOpOneToNLowering::matchAndRewrite( + arith::SelectOp op, Adaptor adaptor, + ConversionPatternRewriter &rewriter) const { + // In case of a 1:1 conversion, the 1:1 pattern will match. + if (llvm::hasSingleElement(adaptor.getTrueValue())) + return rewriter.notifyMatchFailure( + op, "not a 1:N conversion, 1:1 pattern will match"); + if (!op.getCondition().getType().isInteger(1)) + return rewriter.notifyMatchFailure(op, + "non-i1 conditions are not supported"); + SmallVector results; + for (auto [trueValue, falseValue] : + llvm::zip_equal(adaptor.getTrueValue(), adaptor.getFalseValue())) + results.push_back(arith::SelectOp::create( + rewriter, op.getLoc(), op.getCondition(), trueValue, falseValue)); + rewriter.replaceOpWithMultiple(op, {results}); + return success(); +} + //===----------------------------------------------------------------------===// // Pass Definition //===----------------------------------------------------------------------===// @@ -587,6 +623,7 @@ void mlir::arith::populateArithToLLVMConversionPatterns( RemSIOpLowering, RemUIOpLowering, SelectOpLowering, + SelectOpOneToNLowering, ShLIOpLowering, ShRSIOpLowering, ShRUIOpLowering, diff --git a/mlir/test/Conversion/ArithToLLVM/type-conversion.mlir b/mlir/test/Conversion/ArithToLLVM/type-conversion.mlir new file mode 100644 index 0000000000000..e3a0c82a628ba --- /dev/null +++ b/mlir/test/Conversion/ArithToLLVM/type-conversion.mlir @@ -0,0 +1,15 @@ +// RUN: mlir-opt %s -test-llvm-legalize-patterns -split-input-file | FileCheck %s +// RUN: mlir-opt %s -test-llvm-legalize-patterns="allow-pattern-rollback=0" -split-input-file | FileCheck %s + +// CHECK-LABEL: llvm.func @arith_select( +// CHECK-SAME: %[[arg0:.*]]: i1, %[[arg1:.*]]: i18, %[[arg2:.*]]: i18, %[[arg3:.*]]: i18, %[[arg4:.*]]: i18) -> !llvm.struct<(i18, i18)> +// CHECK: %[[select0:.*]] = llvm.select %[[arg0]], %[[arg1]], %[[arg3]] : i1, i18 +// CHECK: %[[select1:.*]] = llvm.select %[[arg0]], %[[arg2]], %[[arg4]] : i1, i18 +// CHECK: %[[i0:.*]] = llvm.mlir.poison : !llvm.struct<(i18, i18)> +// CHECK: %[[i1:.*]] = llvm.insertvalue %[[select0]], %[[i0]][0] : !llvm.struct<(i18, i18)> +// CHECK: %[[i2:.*]] = llvm.insertvalue %[[select1]], %[[i1]][1] : !llvm.struct<(i18, i18)> +// CHECK: llvm.return %[[i2]] +func.func @arith_select(%arg0: i1, %arg1: i17, %arg2: i17) -> (i17) { + %0 = arith.select %arg0, %arg1, %arg2 : i17 + return %0 : i17 +} diff --git a/mlir/test/lib/Dialect/LLVM/TestPatterns.cpp b/mlir/test/lib/Dialect/LLVM/TestPatterns.cpp index 9d30ae43cccc1..69a3d98bc09e4 100644 --- a/mlir/test/lib/Dialect/LLVM/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/LLVM/TestPatterns.cpp @@ -6,6 +6,7 @@ // //===----------------------------------------------------------------------===// +#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" #include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h" #include "mlir/Conversion/LLVMCommon/TypeConverter.h" @@ -70,6 +71,7 @@ struct TestLLVMLegalizePatternsPass // Populate patterns. mlir::RewritePatternSet patterns(ctx); patterns.add(ctx, converter); + arith::populateArithToLLVMConversionPatterns(converter, patterns); populateFuncToLLVMConversionPatterns(converter, patterns); cf::populateControlFlowToLLVMConversionPatterns(converter, patterns);