diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td index 86eff557291914..24b2cf1dc422bc 100644 --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -304,6 +304,11 @@ def ConvertMemRefToSPIRV : Pass<"convert-memref-to-spirv", "ModuleOp"> { let summary = "Convert MemRef dialect to SPIR-V dialect"; let constructor = "mlir::createConvertMemRefToSPIRVPass()"; let dependentDialects = ["spirv::SPIRVDialect"]; + let options = [ + Option<"boolNumBits", "bool-num-bits", + "int", /*default=*/"8", + "The number of bits to store a boolean value"> + ]; } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp index ddc312eb09145c..899c582796a6fc 100644 --- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp +++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp @@ -119,6 +119,17 @@ static Optional getAtomicOpScope(MemRefType t) { return {}; } +/// Casts the given `srcBool` into an integer of `dstType`. +static Value castBoolToIntN(Location loc, Value srcBool, Type dstType, + OpBuilder &builder) { + assert(srcBool.getType().isInteger(1)); + if (dstType.isInteger(1)) + return srcBool; + Value zero = spirv::ConstantOp::getZero(dstType, loc, builder); + Value one = spirv::ConstantOp::getOne(dstType, loc, builder); + return builder.create(loc, dstType, srcBool, one, zero); +} + //===----------------------------------------------------------------------===// // Operation conversion //===----------------------------------------------------------------------===// @@ -336,9 +347,7 @@ IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, dstType = typeConverter.convertType(loadOp.getType()); mask = spirv::ConstantOp::getOne(result.getType(), loc, rewriter); Value isOne = rewriter.create(loc, result, mask); - Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter); - Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter); - result = rewriter.create(loc, dstType, isOne, one, zero); + result = castBoolToIntN(loc, isOne, dstType, rewriter); } else if (result.getType().getIntOrFloatBitWidth() != static_cast(dstBits)) { result = rewriter.create(loc, dstType, result); @@ -392,6 +401,7 @@ IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, bool isBool = srcBits == 1; if (isBool) srcBits = typeConverter.getOptions().boolNumBits; + Type pointeeType = typeConverter.convertType(memrefType) .cast() .getPointeeType(); @@ -406,8 +416,11 @@ IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, assert(dstBits % srcBits == 0); if (srcBits == dstBits) { + Value storeVal = storeOperands.value(); + if (isBool) + storeVal = castBoolToIntN(loc, storeVal, dstType, rewriter); rewriter.replaceOpWithNewOp( - storeOp, accessChainOp.getResult(), storeOperands.value()); + storeOp, accessChainOp.getResult(), storeVal); return success(); } @@ -435,12 +448,8 @@ IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, clearBitsMask = rewriter.create(loc, dstType, clearBitsMask); Value storeVal = storeOperands.value(); - if (isBool) { - Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter); - Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter); - storeVal = - rewriter.create(loc, dstType, storeVal, one, zero); - } + if (isBool) + storeVal = castBoolToIntN(loc, storeVal, dstType, rewriter); storeVal = shiftValue(loc, storeVal, offset, mask, dstBits, rewriter); Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp, srcBits, dstBits, rewriter); diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.cpp index 79d137fd8a7e09..cd0d3429c6b582 100644 --- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.cpp +++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRVPass.cpp @@ -34,7 +34,9 @@ void ConvertMemRefToSPIRVPass::runOnOperation() { std::unique_ptr target = SPIRVConversionTarget::get(targetAttr); - SPIRVTypeConverter typeConverter(targetAttr); + SPIRVTypeConverter::Options options; + options.boolNumBits = this->boolNumBits; + SPIRVTypeConverter typeConverter(targetAttr, options); // Use UnrealizedConversionCast as the bridge so that we don't need to pull in // patterns for other dialects. diff --git a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir index bb8b8a5e86d0ee..daf0ea0797a548 100644 --- a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir +++ b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir @@ -1,9 +1,17 @@ -// RUN: mlir-opt -split-input-file -convert-memref-to-spirv %s -o - | FileCheck %s +// RUN: mlir-opt -split-input-file -convert-memref-to-spirv="bool-num-bits=8" %s -o - | FileCheck %s + +// Check that with proper compute and storage extensions, we don't need to +// perform special tricks. module attributes { spv.target_env = #spv.target_env< - #spv.vce, {}> + #spv.vce, {}> } { // CHECK-LABEL: @load_store_zero_rank_float @@ -57,6 +65,27 @@ func @load_store_unknown_dim(%i: index, %source: memref, %dest: memref, +// CHECK-SAME: %[[IDX:.+]]: index +func @store_i1(%dst: memref<4xi1>, %i: index) { + %true = constant true + // CHECK: %[[DST_CAST:.+]] = builtin.unrealized_conversion_cast %[[DST]] : memref<4xi1> to !spv.ptr [0])>, StorageBuffer> + // CHECK: %[[IDX_CAST:.+]] = builtin.unrealized_conversion_cast %[[IDX]] + // CHECK: %[[ZERO_0:.+]] = spv.Constant 0 : i32 + // CHECK: %[[ZERO_1:.+]] = spv.Constant 0 : i32 + // CHECK: %[[ONE:.+]] = spv.Constant 1 : i32 + // CHECK: %[[MUL:.+]] = spv.IMul %[[ONE]], %[[IDX_CAST]] : i32 + // CHECK: %[[ADD:.+]] = spv.IAdd %[[ZERO_1]], %[[MUL]] : i32 + // CHECK: %[[ADDR:.+]] = spv.AccessChain %[[DST_CAST]][%[[ZERO_0]], %[[ADD]]] : !spv.ptr [0])>, StorageBuffer>, i32, i32 + // CHECK: %[[ZERO_I8:.+]] = spv.Constant 0 : i8 + // CHECK: %[[ONE_I8:.+]] = spv.Constant 1 : i8 + // CHECK: %[[RES:.+]] = spv.Select %{{.+}}, %[[ONE_I8]], %[[ZERO_I8]] : i1, i8 + // CHECK: spv.Store "StorageBuffer" %[[ADDR]], %[[RES]] : i8 + memref.store %true, %dst[%i]: memref<4xi1> + return +} + } // end module // ----- @@ -88,10 +117,7 @@ func @load_i1(%arg0: memref) -> i1 { // CHECK: %[[T4:.+]] = spv.ShiftRightArithmetic %[[T3]], %[[T2]] : i32, i32 // Convert to i1 type. // CHECK: %[[ONE:.+]] = spv.Constant 1 : i32 - // CHECK: %[[ISONE:.+]] = spv.IEqual %[[T4]], %[[ONE]] : i32 - // CHECK: %[[FALSE:.+]] = spv.Constant false - // CHECK: %[[TRUE:.+]] = spv.Constant true - // CHECK: %[[RES:.+]] = spv.Select %[[ISONE]], %[[TRUE]], %[[FALSE]] : i1, i1 + // CHECK: %[[RES:.+]] = spv.IEqual %[[T4]], %[[ONE]] : i32 // CHECK: return %[[RES]] %0 = memref.load %arg0[] : memref return %0 : i1