[mlir] Move std.generic_atomic_rmw to the memref dialect
This is part of splitting up the standard dialect. The move makes sense anyway,
given that the memref dialect already holds memref.atomic_rmw, which is the non-region
sibling operation of std.generic_atomic_rmw (the relationship is even clearer given that
the two ops have nearly the same description, modulo how they represent the inner computation).
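
As a minimal illustration of that sibling relationship (a hand-written sketch mirroring the examples in the op documentation below, assuming `%value`, `%I`, and `%i` are defined in the enclosing function), the two ops express the same read-modify-write; the region form simply spells out the update in a body:

```mlir
// Non-region form: the update is selected by the "addf" kind.
%x = memref.atomic_rmw "addf" %value, %I[%i] : (f32, memref<10xf32>) -> f32

// Region form: the update is written out in the body and yielded.
%y = memref.generic_atomic_rmw %I[%i] : memref<10xf32> {
^bb0(%current_value : f32):
  %updated = arith.addf %value, %current_value : f32
  memref.atomic_yield %updated : f32
}
```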

Differential Revision: https://reviews.llvm.org/D118209
River707 committed Jan 26, 2022
1 parent 480cd4c commit 632a4f8
Showing 14 changed files with 395 additions and 413 deletions.
79 changes: 77 additions & 2 deletions mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -698,6 +698,81 @@ def MemRef_DmaWaitOp : MemRef_Op<"dma_wait"> {
let hasFolder = 1;
}

//===----------------------------------------------------------------------===//
// GenericAtomicRMWOp
//===----------------------------------------------------------------------===//

def GenericAtomicRMWOp : MemRef_Op<"generic_atomic_rmw", [
SingleBlockImplicitTerminator<"AtomicYieldOp">,
TypesMatchWith<"result type matches element type of memref",
"memref", "result",
"$_self.cast<MemRefType>().getElementType()">
]> {
let summary = "atomic read-modify-write operation with a region";
let description = [{
The `memref.generic_atomic_rmw` operation provides a way to perform a
read-modify-write sequence that is free from data races. The memref operand
represents the buffer that the read and write will be performed against, as
accessed by the specified indices. The arity of the indices is the rank of
the memref. The result represents the latest value that was stored. The
region contains the code for the modification itself. The entry block has
a single argument that represents the value stored in `memref[indices]`
before the write is performed. No side-effecting ops are allowed in the
body of `GenericAtomicRMWOp`.

Example:

```mlir
%x = memref.generic_atomic_rmw %I[%i] : memref<10xf32> {
^bb0(%current_value : f32):
%c1 = arith.constant 1.0 : f32
%inc = arith.addf %c1, %current_value : f32
memref.atomic_yield %inc : f32
}
```
}];

let arguments = (ins
MemRefOf<[AnySignlessInteger, AnyFloat]>:$memref,
Variadic<Index>:$indices);

let results = (outs
AnyTypeOf<[AnySignlessInteger, AnyFloat]>:$result);

let regions = (region AnyRegion:$atomic_body);

let skipDefaultBuilders = 1;
let builders = [OpBuilder<(ins "Value":$memref, "ValueRange":$ivs)>];

let extraClassDeclaration = [{
// TODO: remove post migrating callers.
Region &body() { return getRegion(); }

// The value stored in memref[ivs].
Value getCurrentValue() {
return getRegion().getArgument(0);
}
MemRefType getMemRefType() {
return memref().getType().cast<MemRefType>();
}
}];
}

def AtomicYieldOp : MemRef_Op<"atomic_yield", [
HasParent<"GenericAtomicRMWOp">,
NoSideEffect,
Terminator
]> {
let summary = "yield operation for GenericAtomicRMWOp";
let description = [{
"memref.atomic_yield" yields an SSA value from a
GenericAtomicRMWOp region.
}];

let arguments = (ins AnyType:$result);
let assemblyFormat = "$result attr-dict `:` type($result)";
}

//===----------------------------------------------------------------------===//
// GetGlobalOp
//===----------------------------------------------------------------------===//
@@ -1687,7 +1762,7 @@ def AtomicRMWOp : MemRef_Op<"atomic_rmw", [
]> {
let summary = "atomic read-modify-write operation";
let description = [{
The `atomic_rmw` operation provides a way to perform a read-modify-write
The `memref.atomic_rmw` operation provides a way to perform a read-modify-write
sequence that is free from data races. The kind enumeration specifies the
modification to perform. The value operand represents the new value to be
applied during the modification. The memref operand represents the buffer
@@ -1698,7 +1773,7 @@ def AtomicRMWOp : MemRef_Op<"atomic_rmw", [
Example:

```mlir
%x = arith.atomic_rmw "addf" %value, %I[%i] : (f32, memref<10xf32>) -> f32
%x = memref.atomic_rmw "addf" %value, %I[%i] : (f32, memref<10xf32>) -> f32
```
}];

70 changes: 0 additions & 70 deletions mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -178,76 +178,6 @@ def AssertOp : Std_Op<"assert"> {
let hasCanonicalizeMethod = 1;
}

def GenericAtomicRMWOp : Std_Op<"generic_atomic_rmw", [
SingleBlockImplicitTerminator<"AtomicYieldOp">,
TypesMatchWith<"result type matches element type of memref",
"memref", "result",
"$_self.cast<MemRefType>().getElementType()">
]> {
let summary = "atomic read-modify-write operation with a region";
let description = [{
The `generic_atomic_rmw` operation provides a way to perform a read-modify-write
sequence that is free from data races. The memref operand represents the
buffer that the read and write will be performed against, as accessed by
the specified indices. The arity of the indices is the rank of the memref.
The result represents the latest value that was stored. The region contains
the code for the modification itself. The entry block has a single argument
that represents the value stored in `memref[indices]` before the write is
performed. No side-effecting ops are allowed in the body of
`GenericAtomicRMWOp`.

Example:

```mlir
%x = generic_atomic_rmw %I[%i] : memref<10xf32> {
^bb0(%current_value : f32):
%c1 = arith.constant 1.0 : f32
%inc = arith.addf %c1, %current_value : f32
atomic_yield %inc : f32
}
```
}];

let arguments = (ins
MemRefOf<[AnySignlessInteger, AnyFloat]>:$memref,
Variadic<Index>:$indices);

let results = (outs
AnyTypeOf<[AnySignlessInteger, AnyFloat]>:$result);

let regions = (region AnyRegion:$atomic_body);

let skipDefaultBuilders = 1;
let builders = [OpBuilder<(ins "Value":$memref, "ValueRange":$ivs)>];

let extraClassDeclaration = [{
// TODO: remove post migrating callers.
Region &body() { return getRegion(); }

// The value stored in memref[ivs].
Value getCurrentValue() {
return getRegion().getArgument(0);
}
MemRefType getMemRefType() {
return getMemref().getType().cast<MemRefType>();
}
}];
}

def AtomicYieldOp : Std_Op<"atomic_yield", [
HasParent<"GenericAtomicRMWOp">,
NoSideEffect,
Terminator
]> {
let summary = "yield operation for GenericAtomicRMWOp";
let description = [{
"atomic_yield" yields an SSA value from a GenericAtomicRMWOp region.
}];

let arguments = (ins AnyType:$result);
let assemblyFormat = "$result attr-dict `:` type($result)";
}

//===----------------------------------------------------------------------===//
// BranchOp
//===----------------------------------------------------------------------===//
149 changes: 134 additions & 15 deletions mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -412,6 +412,139 @@ struct DimOpLowering : public ConvertOpToLLVMPattern<memref::DimOp> {
}
};

/// Common base for load and store operations on MemRefs. Restricts the match
/// to supported MemRef types. Provides functionality to emit code accessing a
/// specific element of the underlying data buffer.
template <typename Derived>
struct LoadStoreOpLowering : public ConvertOpToLLVMPattern<Derived> {
using ConvertOpToLLVMPattern<Derived>::ConvertOpToLLVMPattern;
using ConvertOpToLLVMPattern<Derived>::isConvertibleAndHasIdentityMaps;
using Base = LoadStoreOpLowering<Derived>;

LogicalResult match(Derived op) const override {
MemRefType type = op.getMemRefType();
return isConvertibleAndHasIdentityMaps(type) ? success() : failure();
}
};

/// Wrap a llvm.cmpxchg operation in a while loop so that the operation can be
/// retried until it succeeds in atomically storing a new value into memory.
///
/// +---------------------------------+
/// | <code before the AtomicRMWOp> |
/// | <compute initial %loaded> |
/// | br loop(%loaded) |
/// +---------------------------------+
/// |
/// -------| |
/// | v v
/// | +--------------------------------+
/// | | loop(%loaded): |
/// | | <body contents> |
/// | | %pair = cmpxchg |
/// | | %ok = %pair[0] |
/// | | %new = %pair[1] |
/// | | cond_br %ok, end, loop(%new) |
/// | +--------------------------------+
/// | | |
/// |----------- |
/// v
/// +--------------------------------+
/// | end: |
/// | <code after the AtomicRMWOp> |
/// +--------------------------------+
///
struct GenericAtomicRMWOpLowering
: public LoadStoreOpLowering<memref::GenericAtomicRMWOp> {
using Base::Base;

LogicalResult
matchAndRewrite(memref::GenericAtomicRMWOp atomicOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = atomicOp.getLoc();
Type valueType = typeConverter->convertType(atomicOp.getResult().getType());

// Split the block into initial, loop, and ending parts.
auto *initBlock = rewriter.getInsertionBlock();
auto *loopBlock = rewriter.createBlock(
initBlock->getParent(), std::next(Region::iterator(initBlock)),
valueType, loc);
auto *endBlock = rewriter.createBlock(
loopBlock->getParent(), std::next(Region::iterator(loopBlock)));

// Operations range to be moved to `endBlock`.
auto opsToMoveStart = atomicOp->getIterator();
auto opsToMoveEnd = initBlock->back().getIterator();

// Compute the loaded value and branch to the loop block.
rewriter.setInsertionPointToEnd(initBlock);
auto memRefType = atomicOp.memref().getType().cast<MemRefType>();
auto dataPtr = getStridedElementPtr(loc, memRefType, adaptor.memref(),
adaptor.indices(), rewriter);
Value init = rewriter.create<LLVM::LoadOp>(loc, dataPtr);
rewriter.create<LLVM::BrOp>(loc, init, loopBlock);

// Prepare the body of the loop block.
rewriter.setInsertionPointToStart(loopBlock);

// Clone the GenericAtomicRMWOp region and extract the result.
auto loopArgument = loopBlock->getArgument(0);
BlockAndValueMapping mapping;
mapping.map(atomicOp.getCurrentValue(), loopArgument);
Block &entryBlock = atomicOp.body().front();
for (auto &nestedOp : entryBlock.without_terminator()) {
Operation *clone = rewriter.clone(nestedOp, mapping);
mapping.map(nestedOp.getResults(), clone->getResults());
}
Value result = mapping.lookup(entryBlock.getTerminator()->getOperand(0));

// Prepare the epilog of the loop block.
// Append the cmpxchg op to the end of the loop block.
auto successOrdering = LLVM::AtomicOrdering::acq_rel;
auto failureOrdering = LLVM::AtomicOrdering::monotonic;
auto boolType = IntegerType::get(rewriter.getContext(), 1);
auto pairType = LLVM::LLVMStructType::getLiteral(rewriter.getContext(),
{valueType, boolType});
auto cmpxchg = rewriter.create<LLVM::AtomicCmpXchgOp>(
loc, pairType, dataPtr, loopArgument, result, successOrdering,
failureOrdering);
// Extract the %new_loaded and %ok values from the pair.
Value newLoaded = rewriter.create<LLVM::ExtractValueOp>(
loc, valueType, cmpxchg, rewriter.getI64ArrayAttr({0}));
Value ok = rewriter.create<LLVM::ExtractValueOp>(
loc, boolType, cmpxchg, rewriter.getI64ArrayAttr({1}));

// Conditionally branch to the end or back to the loop depending on %ok.
rewriter.create<LLVM::CondBrOp>(loc, ok, endBlock, ArrayRef<Value>(),
loopBlock, newLoaded);

rewriter.setInsertionPointToEnd(endBlock);
moveOpsRange(atomicOp.getResult(), newLoaded, std::next(opsToMoveStart),
std::next(opsToMoveEnd), rewriter);

// The 'result' of the atomic_rmw op is the newly loaded value.
rewriter.replaceOp(atomicOp, {newLoaded});

return success();
}

private:
// Clones a segment of ops [start, end) and erases the original.
void moveOpsRange(ValueRange oldResult, ValueRange newResult,
Block::iterator start, Block::iterator end,
ConversionPatternRewriter &rewriter) const {
BlockAndValueMapping mapping;
mapping.map(oldResult, newResult);
SmallVector<Operation *, 2> opsToErase;
for (auto it = start; it != end; ++it) {
rewriter.clone(*it, mapping);
opsToErase.push_back(&*it);
}
for (auto *it : opsToErase)
rewriter.eraseOp(it);
}
};
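
For reference, a rough sketch of the IR this pattern produces (hand-written here, not part of the commit; exact LLVM-dialect assembly syntax may vary between MLIR versions) for a `generic_atomic_rmw` that adds `1.0` to an `f32` element:

```mlir
// %ptr is the strided element pointer computed from the memref descriptor.
  %init = llvm.load %ptr : !llvm.ptr<f32>
  llvm.br ^loop(%init : f32)
^loop(%loaded: f32):   // retry loop carrying the last observed value
  %c1  = llvm.mlir.constant(1.000000e+00 : f32) : f32
  %new = llvm.fadd %c1, %loaded : f32
  %pair = llvm.cmpxchg %ptr, %loaded, %new acq_rel monotonic : f32
  %val = llvm.extractvalue %pair[0] : !llvm.struct<(f32, i1)>
  %ok  = llvm.extractvalue %pair[1] : !llvm.struct<(f32, i1)>
  llvm.cond_br %ok, ^end, ^loop(%val : f32)
^end:                  // %val replaces the generic_atomic_rmw result
```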

/// Returns the LLVM type of the global variable given the memref type `type`.
static Type convertGlobalMemrefTypeToLLVM(MemRefType type,
LLVMTypeConverter &typeConverter) {
@@ -520,21 +653,6 @@ struct GetGlobalMemrefOpLowering : public AllocLikeOpLLVMLowering {
}
};

// Common base for load and store operations on MemRefs. Restricts the match
// to supported MemRef types. Provides functionality to emit code accessing a
// specific element of the underlying data buffer.
template <typename Derived>
struct LoadStoreOpLowering : public ConvertOpToLLVMPattern<Derived> {
using ConvertOpToLLVMPattern<Derived>::ConvertOpToLLVMPattern;
using ConvertOpToLLVMPattern<Derived>::isConvertibleAndHasIdentityMaps;
using Base = LoadStoreOpLowering<Derived>;

LogicalResult match(Derived op) const override {
MemRefType type = op.getMemRefType();
return isConvertibleAndHasIdentityMaps(type) ? success() : failure();
}
};

// Load operation is lowered to obtaining a pointer to the indexed element
// and loading it.
struct LoadOpLowering : public LoadStoreOpLowering<memref::LoadOp> {
@@ -1683,6 +1801,7 @@ void mlir::populateMemRefToLLVMConversionPatterns(LLVMTypeConverter &converter,
AtomicRMWOpLowering,
AssumeAlignmentOpLowering,
DimOpLowering,
GenericAtomicRMWOpLowering,
GlobalMemrefOpLowering,
GetGlobalMemrefOpLowering,
LoadOpLowering,