diff --git a/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp b/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp index c4b6382a42bac..31d9023e4346d 100644 --- a/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp +++ b/mlir/lib/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.cpp @@ -275,6 +275,22 @@ static bool isBoolScalarOrVector(Type type) { return false; } +/// Returns true if scalar/vector type `a` and `b` have the same number of +/// bitwidth. +static bool hasSameBitwidth(Type a, Type b) { + auto getNumBitwidth = [](Type type) { + unsigned bw = 0; + if (type.isIntOrFloat()) + bw = type.getIntOrFloatBitWidth(); + else if (auto vecType = type.dyn_cast()) + bw = vecType.getElementTypeBitWidth() * vecType.getNumElements(); + return bw; + }; + unsigned aBW = getNumBitwidth(a); + unsigned bBW = getNumBitwidth(b); + return aBW != 0 && bBW != 0 && aBW == bBW; +} + //===----------------------------------------------------------------------===// // ConstantOp with composite type //===----------------------------------------------------------------------===// @@ -655,10 +671,11 @@ LogicalResult CmpIOpBooleanPattern::matchAndRewrite( switch (op.getPredicate()) { #define DISPATCH(cmpPredicate, spirvOp) \ - case cmpPredicate: \ - rewriter.replaceOpWithNewOp(op, op.getResult().getType(), \ - adaptor.getLhs(), adaptor.getRhs()); \ - return success(); + case cmpPredicate: { \ + rewriter.replaceOpWithNewOp(op, adaptor.getLhs(), \ + adaptor.getRhs()); \ + return success(); \ + } DISPATCH(arith::CmpIPredicate::eq, spirv::LogicalEqualOp); DISPATCH(arith::CmpIPredicate::ne, spirv::LogicalNotEqualOp); @@ -676,20 +693,23 @@ LogicalResult CmpIOpBooleanPattern::matchAndRewrite( LogicalResult CmpIOpPattern::matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { - Type operandType = op.getLhs().getType(); - if (isBoolScalarOrVector(operandType)) + Type srcType = op.getLhs().getType(); + if (isBoolScalarOrVector(srcType)) + return failure(); + Type dstType = getTypeConverter()->convertType(srcType); + if (!dstType) return failure(); switch (op.getPredicate()) { #define DISPATCH(cmpPredicate, spirvOp) \ case cmpPredicate: \ if (spirvOp::template hasTrait() && \ - operandType != this->getTypeConverter()->convertType(operandType)) { \ + srcType != dstType && !hasSameBitwidth(srcType, dstType)) { \ return op.emitError( \ "bitwidth emulation is not implemented yet on unsigned op"); \ } \ - rewriter.replaceOpWithNewOp(op, op.getResult().getType(), \ - adaptor.getLhs(), adaptor.getRhs()); \ + rewriter.replaceOpWithNewOp(op, adaptor.getLhs(), \ + adaptor.getRhs()); \ return success(); DISPATCH(arith::CmpIPredicate::eq, spirv::IEqualOp); @@ -718,8 +738,8 @@ CmpFOpPattern::matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor, switch (op.getPredicate()) { #define DISPATCH(cmpPredicate, spirvOp) \ case cmpPredicate: \ - rewriter.replaceOpWithNewOp(op, op.getResult().getType(), \ - adaptor.getLhs(), adaptor.getRhs()); \ + rewriter.replaceOpWithNewOp(op, adaptor.getLhs(), \ + adaptor.getRhs()); \ return success(); // Ordered. diff --git a/mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir b/mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir index 7d17359030d46..8925197512875 100644 --- a/mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir +++ b/mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir @@ -392,6 +392,15 @@ func.func @cmpi(%arg0 : i32, %arg1 : i32) { return } +// CHECK-LABEL: @vec1cmpi +func.func @vec1cmpi(%arg0 : vector<1xi32>, %arg1 : vector<1xi32>) { + // CHECK: spv.ULessThan + %0 = arith.cmpi ult, %arg0, %arg1 : vector<1xi32> + // CHECK: spv.SGreaterThan + %1 = arith.cmpi sgt, %arg0, %arg1 : vector<1xi32> + return +} + // CHECK-LABEL: @boolcmpi func.func @boolcmpi(%arg0 : i1, %arg1 : i1) { // CHECK: spv.LogicalEqual @@ -401,6 +410,15 @@ func.func @boolcmpi(%arg0 : i1, %arg1 : i1) { return } +// CHECK-LABEL: @vec1boolcmpi +func.func @vec1boolcmpi(%arg0 : vector<1xi1>, %arg1 : vector<1xi1>) { + // CHECK: spv.LogicalEqual + %0 = arith.cmpi eq, %arg0, %arg1 : vector<1xi1> + // CHECK: spv.LogicalNotEqual + %1 = arith.cmpi ne, %arg0, %arg1 : vector<1xi1> + return +} + // CHECK-LABEL: @vecboolcmpi func.func @vecboolcmpi(%arg0 : vector<4xi1>, %arg1 : vector<4xi1>) { // CHECK: spv.LogicalEqual @@ -1237,6 +1255,15 @@ func.func @cmpf(%arg0 : f32, %arg1 : f32) { return } +// CHECK-LABEL: @vec1cmpf +func.func @vec1cmpf(%arg0 : vector<1xf32>, %arg1 : vector<1xf32>) { + // CHECK: spv.FOrdGreaterThan + %0 = arith.cmpf ogt, %arg0, %arg1 : vector<1xf32> + // CHECK: spv.FUnordLessThan + %1 = arith.cmpf ult, %arg0, %arg1 : vector<1xf32> + return +} + } // end module // -----