diff --git a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h index 911a030d4a7e58..2186107851d918 100644 --- a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h +++ b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h @@ -49,9 +49,13 @@ class SPIRVTypeConverter : public TypeConverter { /// values will be packed into one 32-bit value to be memory efficient. bool emulateNon32BitScalarTypes; + /// The number of bits to store a boolean value. It is eight bits by + /// default. + unsigned boolNumBits; + // Note: we need this instead of inline initializers becuase of // https://bugs.llvm.org/show_bug.cgi?id=36684 - Options() : emulateNon32BitScalarTypes(true) {} + Options() : emulateNon32BitScalarTypes(true), boolNumBits(8) {} }; explicit SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr, diff --git a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp index ed66252e20ae59..397d26b0499d8c 100644 --- a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp +++ b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp @@ -991,6 +991,9 @@ IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, loadOperands.indices(), loc, rewriter); int srcBits = memrefType.getElementType().getIntOrFloatBitWidth(); + bool isBool = srcBits == 1; + if (isBool) + srcBits = typeConverter.getOptions().boolNumBits; auto dstType = typeConverter.convertType(memrefType) .cast() .getPointeeType() @@ -1044,6 +1047,18 @@ IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, shiftValue); result = rewriter.create(loc, dstType, result, shiftValue); + + if (isBool) { + dstType = typeConverter.convertType(loadOp.getType()); + mask = spirv::ConstantOp::getOne(result.getType(), loc, rewriter); + Value isOne = rewriter.create(loc, result, mask); + Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter); + Value one = 
spirv::ConstantOp::getOne(dstType, loc, rewriter); + result = rewriter.create(loc, dstType, isOne, one, zero); + } else if (result.getType().getIntOrFloatBitWidth() != + static_cast(dstBits)) { + result = rewriter.create(loc, dstType, result); + } rewriter.replaceOp(loadOp, result); assert(accessChainOp.use_empty()); @@ -1117,6 +1132,10 @@ IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, spirv::getElementPtr(typeConverter, memrefType, storeOperands.memref(), storeOperands.indices(), loc, rewriter); int srcBits = memrefType.getElementType().getIntOrFloatBitWidth(); + + bool isBool = srcBits == 1; + if (isBool) + srcBits = typeConverter.getOptions().boolNumBits; auto dstType = typeConverter.convertType(memrefType) .cast() .getPointeeType() @@ -1156,8 +1175,14 @@ IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, rewriter.create(loc, dstType, mask, offset); clearBitsMask = rewriter.create(loc, dstType, clearBitsMask); - Value storeVal = - shiftValue(loc, storeOperands.value(), offset, mask, dstBits, rewriter); + Value storeVal = storeOperands.value(); + if (isBool) { + Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter); + Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter); + storeVal = + rewriter.create(loc, dstType, storeVal, one, zero); + } + storeVal = shiftValue(loc, storeVal, offset, mask, dstBits, rewriter); Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp, srcBits, dstBits, rewriter); Optional scope = getAtomicOpScope(memrefType); diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp index 4e2dc01108b2d0..d26299fa82a9ce 100644 --- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp @@ -153,6 +153,10 @@ SPIRVTypeConverter::getStorageClassForMemorySpace(unsigned space) { #undef STORAGE_SPACE_MAP_FN } +const SPIRVTypeConverter::Options 
&SPIRVTypeConverter::getOptions() const { + return options; +} + #undef STORAGE_SPACE_MAP_LIST // TODO: This is a utility function that should probably be exposed by the @@ -342,9 +346,66 @@ static Type convertTensorType(const spirv::TargetEnv &targetEnv, return spirv::ArrayType::get(arrayElemType, arrayElemCount, *arrayElemSize); } +static Type convertBoolMemrefType(const spirv::TargetEnv &targetEnv, + const SPIRVTypeConverter::Options &options, + MemRefType type) { + if (!type.hasStaticShape()) { + LLVM_DEBUG(llvm::dbgs() + << type << " dynamic shape on i1 is not supported yet\n"); + return nullptr; + } + + Optional storageClass = + SPIRVTypeConverter::getStorageClassForMemorySpace( + type.getMemorySpaceAsInt()); + if (!storageClass) { + LLVM_DEBUG(llvm::dbgs() + << type << " illegal: cannot convert memory space\n"); + return nullptr; + } + + unsigned numBoolBits = options.boolNumBits; + if (numBoolBits != 8) { + LLVM_DEBUG(llvm::dbgs() + << "using non-8-bit storage for bool types unimplemented"); + return nullptr; + } + auto elementType = IntegerType::get(type.getContext(), numBoolBits) + .dyn_cast(); + if (!elementType) + return nullptr; + Type arrayElemType = + convertScalarType(targetEnv, options, elementType, storageClass); + if (!arrayElemType) + return nullptr; + Optional arrayElemSize = getTypeNumBytes(options, arrayElemType); + if (!arrayElemSize) { + LLVM_DEBUG(llvm::dbgs() + << type << " illegal: cannot deduce converted element size\n"); + return nullptr; + } + + int64_t memrefSize = (type.getNumElements() * numBoolBits + 7) / 8; + auto arrayElemCount = (memrefSize + *arrayElemSize - 1) / *arrayElemSize; + auto arrayType = + spirv::ArrayType::get(arrayElemType, arrayElemCount, *arrayElemSize); + + // Wrap in a struct to satisfy Vulkan interface requirements. Memrefs with + // workgroup storage class do not need the struct to be laid out explicitly. + auto structType = *storageClass == spirv::StorageClass::Workgroup + ? 
spirv::StructType::get(arrayType) + : spirv::StructType::get(arrayType, 0); + return spirv::PointerType::get(structType, *storageClass); +} + static Type convertMemrefType(const spirv::TargetEnv &targetEnv, const SPIRVTypeConverter::Options &options, MemRefType type) { + if (type.getElementType().isa() && + type.getElementTypeBitWidth() == 1) { + return convertBoolMemrefType(targetEnv, options, type); + } + Optional storageClass = SPIRVTypeConverter::getStorageClassForMemorySpace( type.getMemorySpaceAsInt()); diff --git a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir index d074969febff19..86d390a2ce703a 100644 --- a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir +++ b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir @@ -911,12 +911,40 @@ func @load_store_zero_rank_int(%arg0: memref, %arg1: memref) { // Check that access chain indices are properly adjusted if non-32-bit types are // emulated via 32-bit types. -// TODO: Test i1 and i64 types. +// TODO: Test i64 types. 
module attributes { spv.target_env = #spv.target_env< #spv.vce, {}> } { +// CHECK-LABEL: @load_i1 +func @load_i1(%arg0: memref) -> i1 { + // CHECK: %[[ZERO:.+]] = spv.Constant 0 : i32 + // CHECK: %[[FOUR1:.+]] = spv.Constant 4 : i32 + // CHECK: %[[QUOTIENT:.+]] = spv.SDiv %[[ZERO]], %[[FOUR1]] : i32 + // CHECK: %[[PTR:.+]] = spv.AccessChain %{{.+}}[%[[ZERO]], %[[QUOTIENT]]] + // CHECK: %[[LOAD:.+]] = spv.Load "StorageBuffer" %[[PTR]] + // CHECK: %[[FOUR2:.+]] = spv.Constant 4 : i32 + // CHECK: %[[EIGHT:.+]] = spv.Constant 8 : i32 + // CHECK: %[[IDX:.+]] = spv.UMod %[[ZERO]], %[[FOUR2]] : i32 + // CHECK: %[[BITS:.+]] = spv.IMul %[[IDX]], %[[EIGHT]] : i32 + // CHECK: %[[VALUE:.+]] = spv.ShiftRightArithmetic %[[LOAD]], %[[BITS]] : i32, i32 + // CHECK: %[[MASK:.+]] = spv.Constant 255 : i32 + // CHECK: %[[T1:.+]] = spv.BitwiseAnd %[[VALUE]], %[[MASK]] : i32 + // CHECK: %[[T2:.+]] = spv.Constant 24 : i32 + // CHECK: %[[T3:.+]] = spv.ShiftLeftLogical %[[T1]], %[[T2]] : i32, i32 + // CHECK: %[[T4:.+]] = spv.ShiftRightArithmetic %[[T3]], %[[T2]] : i32, i32 + // Convert to i1 type. 
+ // CHECK: %[[ONE:.+]] = spv.Constant 1 : i32 + // CHECK: %[[ISONE:.+]] = spv.IEqual %[[T4]], %[[ONE]] : i32 + // CHECK: %[[FALSE:.+]] = spv.Constant false + // CHECK: %[[TRUE:.+]] = spv.Constant true + // CHECK: %[[RES:.+]] = spv.Select %[[ISONE]], %[[TRUE]], %[[FALSE]] : i1, i1 + // CHECK: spv.ReturnValue %[[RES]] : i1 + %0 = memref.load %arg0[] : memref + return %0 : i1 +} + // CHECK-LABEL: @load_i8 func @load_i8(%arg0: memref) { // CHECK: %[[ZERO:.+]] = spv.Constant 0 : i32 @@ -982,6 +1010,31 @@ func @load_f32(%arg0: memref) { return } +// CHECK-LABEL: @store_i1 +// CHECK: (%[[ARG0:.+]]: {{.*}}, %[[ARG1:.+]]: i1) +func @store_i1(%arg0: memref, %value: i1) { + // CHECK: %[[ZERO:.+]] = spv.Constant 0 : i32 + // CHECK: %[[FOUR:.+]] = spv.Constant 4 : i32 + // CHECK: %[[EIGHT:.+]] = spv.Constant 8 : i32 + // CHECK: %[[IDX:.+]] = spv.UMod %[[ZERO]], %[[FOUR]] : i32 + // CHECK: %[[OFFSET:.+]] = spv.IMul %[[IDX]], %[[EIGHT]] : i32 + // CHECK: %[[MASK1:.+]] = spv.Constant 255 : i32 + // CHECK: %[[TMP1:.+]] = spv.ShiftLeftLogical %[[MASK1]], %[[OFFSET]] : i32, i32 + // CHECK: %[[MASK:.+]] = spv.Not %[[TMP1]] : i32 + // CHECK: %[[ZERO1:.+]] = spv.Constant 0 : i32 + // CHECK: %[[ONE1:.+]] = spv.Constant 1 : i32 + // CHECK: %[[CASTED_ARG1:.+]] = spv.Select %[[ARG1]], %[[ONE1]], %[[ZERO1]] : i1, i32 + // CHECK: %[[CLAMPED_VAL:.+]] = spv.BitwiseAnd %[[CASTED_ARG1]], %[[MASK1]] : i32 + // CHECK: %[[STORE_VAL:.+]] = spv.ShiftLeftLogical %[[CLAMPED_VAL]], %[[OFFSET]] : i32, i32 + // CHECK: %[[FOUR2:.+]] = spv.Constant 4 : i32 + // CHECK: %[[ACCESS_IDX:.+]] = spv.SDiv %[[ZERO]], %[[FOUR2]] : i32 + // CHECK: %[[PTR:.+]] = spv.AccessChain %[[ARG0]][%[[ZERO]], %[[ACCESS_IDX]]] + // CHECK: spv.AtomicAnd "Device" "AcquireRelease" %[[PTR]], %[[MASK]] + // CHECK: spv.AtomicOr "Device" "AcquireRelease" %[[PTR]], %[[STORE_VAL]] + memref.store %value, %arg0[] : memref + return +} + // CHECK-LABEL: @store_i8 // CHECK: (%[[ARG0:.+]]: {{.*}}, %[[ARG1:.+]]: i32) func @store_i8(%arg0: memref, 
%value: i8) { diff --git a/mlir/test/Conversion/StandardToSPIRV/std-types-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/std-types-to-spirv.mlir index cacc3e762c149d..58513124907a51 100644 --- a/mlir/test/Conversion/StandardToSPIRV/std-types-to-spirv.mlir +++ b/mlir/test/Conversion/StandardToSPIRV/std-types-to-spirv.mlir @@ -286,20 +286,6 @@ func @memref_mem_space( // ----- -// Check that boolean memref is not supported at the moment. -module attributes { - spv.target_env = #spv.target_env<#spv.vce, {}> -} { - -// CHECK-LABEL: func @memref_type({{%.*}}: memref<3xi1>) -func @memref_type(%arg0: memref<3xi1>) { - return -} - -} // end module - -// ----- - // Check that using non-32-bit scalar types in interface storage classes // requires special capability and extension: convert them to 32-bit if not // satisfied. @@ -307,6 +293,11 @@ module attributes { spv.target_env = #spv.target_env<#spv.vce, {}> } { +// An i1 is stored in 8 bits, so 5xi1 has 40 bits, which is stored in 2xi32. +// CHECK-LABEL: spv.func @memref_1bit_type +// CHECK-SAME: !spv.ptr [0])>, StorageBuffer> +func @memref_1bit_type(%arg0: memref<5xi1>) { return } + // CHECK-LABEL: spv.func @memref_8bit_StorageBuffer // CHECK-SAME: !spv.ptr [0])>, StorageBuffer> // NOEMU-LABEL: func @memref_8bit_StorageBuffer