diff --git a/mlir/docs/DefiningDialects/Operations.md b/mlir/docs/DefiningDialects/Operations.md index f988bebea1223..7c1be84727476 100644 --- a/mlir/docs/DefiningDialects/Operations.md +++ b/mlir/docs/DefiningDialects/Operations.md @@ -1649,6 +1649,15 @@ inline constexpr MyBitEnum operator&(MyBitEnum a, MyBitEnum b) { inline constexpr MyBitEnum operator^(MyBitEnum a, MyBitEnum b) { return static_cast(static_cast(a) ^ static_cast(b)); } +inline constexpr MyBitEnum &operator|=(MyBitEnum &a, MyBitEnum b) { + return a = a | b; +} +inline constexpr MyBitEnum &operator&=(MyBitEnum &a, MyBitEnum b) { + return a = a & b; +} +inline constexpr MyBitEnum &operator^=(MyBitEnum &a, MyBitEnum b) { + return a = a ^ b; +} inline constexpr MyBitEnum operator~(MyBitEnum bits) { // Ensure only bits that can be present in the enum are set return static_cast(~static_cast(bits) & static_cast(15u)); diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp index 49d06497dbeea..f44552c4556c2 100644 --- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp +++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp @@ -512,7 +512,7 @@ calculateMemoryRequirements(Value accessedPtr, bool isNontemporal, if (!sizeInBytes.has_value()) return failure(); - memoryAccess = memoryAccess | spirv::MemoryAccess::Aligned; + memoryAccess |= spirv::MemoryAccess::Aligned; auto memAccessAttr = spirv::MemoryAccessAttr::get(ctx, memoryAccess); auto alignmentValue = preferredAlignment ? preferredAlignment : *sizeInBytes; auto alignment = IntegerAttr::get(IntegerType::get(ctx, 32), alignmentValue); diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp index 04f56b9691fd1..5061a4454a7fd 100644 --- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp +++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp @@ -753,7 +753,7 @@ struct VectorLoadOpConverter final spirv::MemoryAccessAttr memoryAccessAttr; IntegerAttr alignmentAttr; if (alignment.has_value()) { - memoryAccess = memoryAccess | spirv::MemoryAccess::Aligned; + memoryAccess |= spirv::MemoryAccess::Aligned; memoryAccessAttr = spirv::MemoryAccessAttr::get(rewriter.getContext(), memoryAccess); alignmentAttr = rewriter.getI32IntegerAttr(alignment.value()); @@ -822,7 +822,7 @@ struct VectorStoreOpConverter final spirv::MemoryAccessAttr memoryAccessAttr; IntegerAttr alignmentAttr; if (alignment.has_value()) { - memoryAccess = memoryAccess | spirv::MemoryAccess::Aligned; + memoryAccess |= spirv::MemoryAccess::Aligned; memoryAccessAttr = spirv::MemoryAccessAttr::get(rewriter.getContext(), memoryAccess); alignmentAttr = rewriter.getI32IntegerAttr(alignment.value()); diff --git a/mlir/lib/Dialect/LLVMIR/Transforms/DIScopeForLLVMFuncOp.cpp b/mlir/lib/Dialect/LLVMIR/Transforms/DIScopeForLLVMFuncOp.cpp index 5ceae9b16af20..67573c4ee6061 100644 --- a/mlir/lib/Dialect/LLVMIR/Transforms/DIScopeForLLVMFuncOp.cpp +++ b/mlir/lib/Dialect/LLVMIR/Transforms/DIScopeForLLVMFuncOp.cpp @@ -77,7 +77,7 @@ static void addScopeToFunction(LLVM::LLVMFuncOp llvmFunc, auto subprogramFlags = LLVM::DISubprogramFlags::Optimized; if (!llvmFunc.isExternal()) { id = DistinctAttr::create(UnitAttr::get(context)); - subprogramFlags = subprogramFlags | LLVM::DISubprogramFlags::Definition; + subprogramFlags |= LLVM::DISubprogramFlags::Definition; } else { compileUnitAttr = {}; } diff --git a/mlir/tools/mlir-tblgen/EnumsGen.cpp b/mlir/tools/mlir-tblgen/EnumsGen.cpp index d152763f7382e..d4d32f5885971 100644 --- a/mlir/tools/mlir-tblgen/EnumsGen.cpp +++ b/mlir/tools/mlir-tblgen/EnumsGen.cpp @@ -364,6 +364,9 @@ getAllBitsUnsetCase(llvm::ArrayRef cases) { // inline constexpr operator|( a, b); // inline constexpr operator&( a, b); // inline constexpr operator^( a, b); +// inline constexpr &operator|=( &a, b); +// inline constexpr &operator&=( &a, b); +// inline constexpr &operator^=( &a, b); // inline constexpr operator~( bits); // inline constexpr bool bitEnumContainsAll( bits, bit); // inline constexpr bool bitEnumContainsAny( bits, bit); @@ -385,6 +388,15 @@ inline constexpr {0} operator&({0} a, {0} b) {{ inline constexpr {0} operator^({0} a, {0} b) {{ return static_cast<{0}>(static_cast<{1}>(a) ^ static_cast<{1}>(b)); } +inline constexpr {0} &operator|=({0} &a, {0} b) {{ + return a = a | b; +} +inline constexpr {0} &operator&=({0} &a, {0} b) {{ + return a = a & b; +} +inline constexpr {0} &operator^=({0} &a, {0} b) {{ + return a = a ^ b; +} inline constexpr {0} operator~({0} bits) {{ // Ensure only bits that can be present in the enum are set return static_cast<{0}>(~static_cast<{1}>(bits) & static_cast<{1}>({2}u));