diff --git a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp index 07c99f06ab2c3..ea367b1faa201 100644 --- a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp +++ b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp @@ -141,20 +141,25 @@ class CountLeadingZerosPattern final return failure(); Location loc = countOp.getLoc(); - Value allOneBits = getScalarOrVectorI32Constant(type, -1, rewriter, loc); - Value val32 = getScalarOrVectorI32Constant(type, 32, rewriter, loc); + Value input = adaptor.getOperand(); + Value val1 = getScalarOrVectorI32Constant(type, 1, rewriter, loc); Value val31 = getScalarOrVectorI32Constant(type, 31, rewriter, loc); - Value msb = - rewriter.create(loc, adaptor.getOperand()); - // We need to subtract from 31 given that the index is from the least - // significant bit. - Value sub = rewriter.create(loc, val31, msb); - // If the integer has all zero bits, GLSL FindUMsb would return -1. So - // theoretically (31 - FindUMsb) should still give the correct result. - // However, certain Vulkan implementations have driver bugs regarding it. - // So handle the corner case explicity to workaround it. - Value cmp = rewriter.create(loc, msb, allOneBits); - rewriter.replaceOpWithNewOp(countOp, cmp, val32, sub); + Value val32 = getScalarOrVectorI32Constant(type, 32, rewriter, loc); + + Value msb = rewriter.create(loc, input); + // We need to subtract from 31 given that the index returned by GLSL + // FindUMsb is counted from the least significant bit. Theoretically this + // also gives the correct result even if the integer has all zero bits, in + // which case GLSL FindUMsb would return -1. + Value subMsb = rewriter.create(loc, val31, msb); + // However, certain Vulkan implementations have driver bugs for the corner + // case where the input is zero. And.. it can be smart to optimize a select + // only involving the corner case. So separately compute the result when the + // input is either zero or one. + Value subInput = rewriter.create(loc, val32, input); + Value cmp = rewriter.create(loc, input, val1); + rewriter.replaceOpWithNewOp(countOp, cmp, subInput, + subMsb); return success(); } }; diff --git a/mlir/test/Conversion/MathToSPIRV/math-to-glsl-spirv.mlir b/mlir/test/Conversion/MathToSPIRV/math-to-glsl-spirv.mlir index d8126d4e956c6..a3067af661f64 100644 --- a/mlir/test/Conversion/MathToSPIRV/math-to-glsl-spirv.mlir +++ b/mlir/test/Conversion/MathToSPIRV/math-to-glsl-spirv.mlir @@ -82,13 +82,14 @@ func.func @float32_ternary_vector(%a: vector<4xf32>, %b: vector<4xf32>, // CHECK-LABEL: @ctlz_scalar // CHECK-SAME: (%[[VAL:.+]]: i32) func.func @ctlz_scalar(%val: i32) -> i32 { - // CHECK-DAG: %[[MAX:.+]] = spv.Constant -1 : i32 - // CHECK-DAG: %[[V32:.+]] = spv.Constant 32 : i32 + // CHECK-DAG: %[[V1:.+]] = spv.Constant 1 : i32 // CHECK-DAG: %[[V31:.+]] = spv.Constant 31 : i32 + // CHECK-DAG: %[[V32:.+]] = spv.Constant 32 : i32 // CHECK: %[[MSB:.+]] = spv.GLSL.FindUMsb %[[VAL]] : i32 - // CHECK: %[[SUB:.+]] = spv.ISub %[[V31]], %[[MSB]] : i32 - // CHECK: %[[CMP:.+]] = spv.IEqual %[[MSB]], %[[MAX]] : i32 - // CHECK: %[[R:.+]] = spv.Select %[[CMP]], %[[V32]], %[[SUB]] : i1, i32 + // CHECK: %[[SUB1:.+]] = spv.ISub %[[V31]], %[[MSB]] : i32 + // CHECK: %[[SUB2:.+]] = spv.ISub %[[V32]], %[[VAL]] : i32 + // CHECK: %[[CMP:.+]] = spv.ULessThanEqual %[[VAL]], %[[V1]] : i32 + // CHECK: %[[R:.+]] = spv.Select %[[CMP]], %[[SUB2]], %[[SUB1]] : i1, i32 // CHECK: return %[[R]] %0 = math.ctlz %val : i32 return %0 : i32 @@ -98,7 +99,7 @@ func.func @ctlz_scalar(%val: i32) -> i32 { func.func @ctlz_vector1(%val: vector<1xi32>) -> vector<1xi32> { // CHECK: spv.GLSL.FindUMsb // CHECK: spv.ISub - // CHECK: spv.IEqual + // CHECK: spv.ULessThanEqual // CHECK: spv.Select %0 = math.ctlz %val : vector<1xi32> return %0 : vector<1xi32> @@ -107,14 +108,14 @@ func.func @ctlz_vector1(%val: vector<1xi32>) -> vector<1xi32> { // CHECK-LABEL: @ctlz_vector2 // CHECK-SAME: (%[[VAL:.+]]: vector<2xi32>) func.func @ctlz_vector2(%val: vector<2xi32>) -> vector<2xi32> { - // CHECK-DAG: %[[MAX:.+]] = spv.Constant dense<-1> : vector<2xi32> - // CHECK-DAG: %[[V32:.+]] = spv.Constant dense<32> : vector<2xi32> + // CHECK-DAG: %[[V1:.+]] = spv.Constant dense<1> : vector<2xi32> // CHECK-DAG: %[[V31:.+]] = spv.Constant dense<31> : vector<2xi32> + // CHECK-DAG: %[[V32:.+]] = spv.Constant dense<32> : vector<2xi32> // CHECK: %[[MSB:.+]] = spv.GLSL.FindUMsb %[[VAL]] : vector<2xi32> - // CHECK: %[[SUB:.+]] = spv.ISub %[[V31]], %[[MSB]] : vector<2xi32> - // CHECK: %[[CMP:.+]] = spv.IEqual %[[MSB]], %[[MAX]] : vector<2xi32> - // CHECK: %[[R:.+]] = spv.Select %[[CMP]], %[[V32]], %[[SUB]] : vector<2xi1>, vector<2xi32> - // CHECK: return %[[R]] + // CHECK: %[[SUB1:.+]] = spv.ISub %[[V31]], %[[MSB]] : vector<2xi32> + // CHECK: %[[SUB2:.+]] = spv.ISub %[[V32]], %[[VAL]] : vector<2xi32> + // CHECK: %[[CMP:.+]] = spv.ULessThanEqual %[[VAL]], %[[V1]] : vector<2xi32> + // CHECK: %[[R:.+]] = spv.Select %[[CMP]], %[[SUB2]], %[[SUB1]] : vector<2xi1>, vector<2xi32> %0 = math.ctlz %val : vector<2xi32> return %0 : vector<2xi32> }