149 changes: 134 additions & 15 deletions mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -412,6 +412,139 @@ struct DimOpLowering : public ConvertOpToLLVMPattern<memref::DimOp> {
}
};

/// Shared base class for lowering memref load/store-like operations to the
/// LLVM dialect. Only memref types that are convertible and use the identity
/// layout map are matched; derived patterns inherit the machinery needed to
/// address a specific element of the underlying data buffer.
template <typename Derived>
struct LoadStoreOpLowering : public ConvertOpToLLVMPattern<Derived> {
  using ConvertOpToLLVMPattern<Derived>::ConvertOpToLLVMPattern;
  using ConvertOpToLLVMPattern<Derived>::isConvertibleAndHasIdentityMaps;
  using Base = LoadStoreOpLowering<Derived>;

  LogicalResult match(Derived op) const override {
    if (isConvertibleAndHasIdentityMaps(op.getMemRefType()))
      return success();
    return failure();
  }
};

/// Wrap a llvm.cmpxchg operation in a while loop so that the operation can be
/// retried until it succeeds in atomically storing a new value into memory.
///
/// +---------------------------------+
/// | <code before the AtomicRMWOp> |
/// | <compute initial %loaded> |
/// | br loop(%loaded) |
/// +---------------------------------+
/// |
/// -------| |
/// | v v
/// | +--------------------------------+
/// | | loop(%loaded): |
/// | | <body contents> |
/// | | %pair = cmpxchg |
/// | | %ok = %pair[0] |
/// | | %new = %pair[1] |
/// | | cond_br %ok, end, loop(%new) |
/// | +--------------------------------+
/// | | |
/// |----------- |
/// v
/// +--------------------------------+
/// | end: |
/// | <code after the AtomicRMWOp> |
/// +--------------------------------+
///
struct GenericAtomicRMWOpLowering
    : public LoadStoreOpLowering<memref::GenericAtomicRMWOp> {
  using Base::Base;

  LogicalResult
  matchAndRewrite(memref::GenericAtomicRMWOp atomicOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = atomicOp.getLoc();
    // LLVM type of the element being atomically updated.
    Type valueType = typeConverter->convertType(atomicOp.getResult().getType());

    // Split the block into initial, loop, and ending parts.
    auto *initBlock = rewriter.getInsertionBlock();
    // The loop block carries the most recently loaded value as a block
    // argument (%loaded in the diagram above the class).
    auto *loopBlock = rewriter.createBlock(
        initBlock->getParent(), std::next(Region::iterator(initBlock)),
        valueType, loc);
    auto *endBlock = rewriter.createBlock(
        loopBlock->getParent(), std::next(Region::iterator(loopBlock)));

    // Operations range to be moved to `endBlock`: everything that followed
    // the atomic op in the original block. Captured before any new ops are
    // appended to `initBlock` below.
    auto opsToMoveStart = atomicOp->getIterator();
    auto opsToMoveEnd = initBlock->back().getIterator();

    // Compute the loaded value and branch to the loop block.
    rewriter.setInsertionPointToEnd(initBlock);
    auto memRefType = atomicOp.memref().getType().cast<MemRefType>();
    auto dataPtr = getStridedElementPtr(loc, memRefType, adaptor.memref(),
                                        adaptor.indices(), rewriter);
    Value init = rewriter.create<LLVM::LoadOp>(loc, dataPtr);
    rewriter.create<LLVM::BrOp>(loc, init, loopBlock);

    // Prepare the body of the loop block.
    rewriter.setInsertionPointToStart(loopBlock);

    // Clone the GenericAtomicRMWOp region and extract the result. The region
    // argument (the "current value") is remapped to the loop block argument.
    auto loopArgument = loopBlock->getArgument(0);
    BlockAndValueMapping mapping;
    mapping.map(atomicOp.getCurrentValue(), loopArgument);
    Block &entryBlock = atomicOp.body().front();
    for (auto &nestedOp : entryBlock.without_terminator()) {
      Operation *clone = rewriter.clone(nestedOp, mapping);
      mapping.map(nestedOp.getResults(), clone->getResults());
    }
    // The value yielded by the region terminator is the desired new value.
    Value result = mapping.lookup(entryBlock.getTerminator()->getOperand(0));

    // Prepare the epilog of the loop block.
    // Append the cmpxchg op to the end of the loop block.
    auto successOrdering = LLVM::AtomicOrdering::acq_rel;
    auto failureOrdering = LLVM::AtomicOrdering::monotonic;
    auto boolType = IntegerType::get(rewriter.getContext(), 1);
    // cmpxchg yields a {value, success-flag} struct.
    auto pairType = LLVM::LLVMStructType::getLiteral(rewriter.getContext(),
                                                     {valueType, boolType});
    auto cmpxchg = rewriter.create<LLVM::AtomicCmpXchgOp>(
        loc, pairType, dataPtr, loopArgument, result, successOrdering,
        failureOrdering);
    // Extract the %new_loaded and %ok values from the pair.
    Value newLoaded = rewriter.create<LLVM::ExtractValueOp>(
        loc, valueType, cmpxchg, rewriter.getI64ArrayAttr({0}));
    Value ok = rewriter.create<LLVM::ExtractValueOp>(
        loc, boolType, cmpxchg, rewriter.getI64ArrayAttr({1}));

    // Conditionally branch to the end or back to the loop depending on %ok.
    rewriter.create<LLVM::CondBrOp>(loc, ok, endBlock, ArrayRef<Value>(),
                                    loopBlock, newLoaded);

    rewriter.setInsertionPointToEnd(endBlock);
    // Move the trailing ops of the original block past the loop; uses of the
    // atomic op's result are rewired to the newly loaded value.
    moveOpsRange(atomicOp.getResult(), newLoaded, std::next(opsToMoveStart),
                 std::next(opsToMoveEnd), rewriter);

    // The 'result' of the atomic_rmw op is the newly loaded value.
    rewriter.replaceOp(atomicOp, {newLoaded});

    return success();
  }

private:
  // Clones a segment of ops [start, end) at the current insertion point,
  // remapping uses of `oldResult` to `newResult`, and erases the originals.
  void moveOpsRange(ValueRange oldResult, ValueRange newResult,
                    Block::iterator start, Block::iterator end,
                    ConversionPatternRewriter &rewriter) const {
    BlockAndValueMapping mapping;
    mapping.map(oldResult, newResult);
    SmallVector<Operation *, 2> opsToErase;
    for (auto it = start; it != end; ++it) {
      rewriter.clone(*it, mapping);
      opsToErase.push_back(&*it);
    }
    for (auto *it : opsToErase)
      rewriter.eraseOp(it);
  }
};

/// Returns the LLVM type of the global variable given the memref type `type`.
static Type convertGlobalMemrefTypeToLLVM(MemRefType type,
LLVMTypeConverter &typeConverter) {
Expand Down Expand Up @@ -520,21 +653,6 @@ struct GetGlobalMemrefOpLowering : public AllocLikeOpLLVMLowering {
}
};

// Common base for load and store operations on MemRefs. Restricts the match
// to supported MemRef types. Provides functionality to emit code accessing a
// specific element of the underlying data buffer.
template <typename Derived>
struct LoadStoreOpLowering : public ConvertOpToLLVMPattern<Derived> {
using ConvertOpToLLVMPattern<Derived>::ConvertOpToLLVMPattern;
using ConvertOpToLLVMPattern<Derived>::isConvertibleAndHasIdentityMaps;
using Base = LoadStoreOpLowering<Derived>;

LogicalResult match(Derived op) const override {
MemRefType type = op.getMemRefType();
return isConvertibleAndHasIdentityMaps(type) ? success() : failure();
}
};

// Load operation is lowered to obtaining a pointer to the indexed element
// and loading it.
struct LoadOpLowering : public LoadStoreOpLowering<memref::LoadOp> {
Expand Down Expand Up @@ -1683,6 +1801,7 @@ void mlir::populateMemRefToLLVMConversionPatterns(LLVMTypeConverter &converter,
AtomicRMWOpLowering,
AssumeAlignmentOpLowering,
DimOpLowering,
GenericAtomicRMWOpLowering,
GlobalMemrefOpLowering,
GetGlobalMemrefOpLowering,
LoadOpLowering,
Expand Down
135 changes: 0 additions & 135 deletions mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -565,21 +565,6 @@ struct UnrealizedConversionCastOpLowering
}
};

// Common base for load and store operations on MemRefs. Restricts the match
// to supported MemRef types. Provides functionality to emit code accessing a
// specific element of the underlying data buffer.
template <typename Derived>
struct LoadStoreOpLowering : public ConvertOpToLLVMPattern<Derived> {
using ConvertOpToLLVMPattern<Derived>::ConvertOpToLLVMPattern;
using ConvertOpToLLVMPattern<Derived>::isConvertibleAndHasIdentityMaps;
using Base = LoadStoreOpLowering<Derived>;

LogicalResult match(Derived op) const override {
MemRefType type = op.getMemRefType();
return isConvertibleAndHasIdentityMaps(type) ? success() : failure();
}
};

// Base class for LLVM IR lowering terminator operations with successors.
template <typename SourceOp, typename TargetOp>
struct OneToOneLLVMTerminatorLowering
Expand Down Expand Up @@ -771,125 +756,6 @@ struct SplatNdOpLowering : public ConvertOpToLLVMPattern<SplatOp> {
}
};

/// Wrap a llvm.cmpxchg operation in a while loop so that the operation can be
/// retried until it succeeds in atomically storing a new value into memory.
///
/// +---------------------------------+
/// | <code before the AtomicRMWOp> |
/// | <compute initial %loaded> |
/// | br loop(%loaded) |
/// +---------------------------------+
/// |
/// -------| |
/// | v v
/// | +--------------------------------+
/// | | loop(%loaded): |
/// | | <body contents> |
/// | | %pair = cmpxchg |
/// | | %ok = %pair[0] |
/// | | %new = %pair[1] |
/// | | cond_br %ok, end, loop(%new) |
/// | +--------------------------------+
/// | | |
/// |----------- |
/// v
/// +--------------------------------+
/// | end: |
/// | <code after the AtomicRMWOp> |
/// +--------------------------------+
///
struct GenericAtomicRMWOpLowering
: public LoadStoreOpLowering<GenericAtomicRMWOp> {
using Base::Base;

LogicalResult
matchAndRewrite(GenericAtomicRMWOp atomicOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {

auto loc = atomicOp.getLoc();
Type valueType = typeConverter->convertType(atomicOp.getResult().getType());

// Split the block into initial, loop, and ending parts.
auto *initBlock = rewriter.getInsertionBlock();
auto *loopBlock = rewriter.createBlock(
initBlock->getParent(), std::next(Region::iterator(initBlock)),
valueType, loc);
auto *endBlock = rewriter.createBlock(
loopBlock->getParent(), std::next(Region::iterator(loopBlock)));

// Operations range to be moved to `endBlock`.
auto opsToMoveStart = atomicOp->getIterator();
auto opsToMoveEnd = initBlock->back().getIterator();

// Compute the loaded value and branch to the loop block.
rewriter.setInsertionPointToEnd(initBlock);
auto memRefType = atomicOp.getMemref().getType().cast<MemRefType>();
auto dataPtr = getStridedElementPtr(loc, memRefType, adaptor.getMemref(),
adaptor.getIndices(), rewriter);
Value init = rewriter.create<LLVM::LoadOp>(loc, dataPtr);
rewriter.create<LLVM::BrOp>(loc, init, loopBlock);

// Prepare the body of the loop block.
rewriter.setInsertionPointToStart(loopBlock);

// Clone the GenericAtomicRMWOp region and extract the result.
auto loopArgument = loopBlock->getArgument(0);
BlockAndValueMapping mapping;
mapping.map(atomicOp.getCurrentValue(), loopArgument);
Block &entryBlock = atomicOp.body().front();
for (auto &nestedOp : entryBlock.without_terminator()) {
Operation *clone = rewriter.clone(nestedOp, mapping);
mapping.map(nestedOp.getResults(), clone->getResults());
}
Value result = mapping.lookup(entryBlock.getTerminator()->getOperand(0));

// Prepare the epilog of the loop block.
// Append the cmpxchg op to the end of the loop block.
auto successOrdering = LLVM::AtomicOrdering::acq_rel;
auto failureOrdering = LLVM::AtomicOrdering::monotonic;
auto boolType = IntegerType::get(rewriter.getContext(), 1);
auto pairType = LLVM::LLVMStructType::getLiteral(rewriter.getContext(),
{valueType, boolType});
auto cmpxchg = rewriter.create<LLVM::AtomicCmpXchgOp>(
loc, pairType, dataPtr, loopArgument, result, successOrdering,
failureOrdering);
// Extract the %new_loaded and %ok values from the pair.
Value newLoaded = rewriter.create<LLVM::ExtractValueOp>(
loc, valueType, cmpxchg, rewriter.getI64ArrayAttr({0}));
Value ok = rewriter.create<LLVM::ExtractValueOp>(
loc, boolType, cmpxchg, rewriter.getI64ArrayAttr({1}));

// Conditionally branch to the end or back to the loop depending on %ok.
rewriter.create<LLVM::CondBrOp>(loc, ok, endBlock, ArrayRef<Value>(),
loopBlock, newLoaded);

rewriter.setInsertionPointToEnd(endBlock);
moveOpsRange(atomicOp.getResult(), newLoaded, std::next(opsToMoveStart),
std::next(opsToMoveEnd), rewriter);

// The 'result' of the atomic_rmw op is the newly loaded value.
rewriter.replaceOp(atomicOp, {newLoaded});

return success();
}

private:
// Clones a segment of ops [start, end) and erases the original.
void moveOpsRange(ValueRange oldResult, ValueRange newResult,
Block::iterator start, Block::iterator end,
ConversionPatternRewriter &rewriter) const {
BlockAndValueMapping mapping;
mapping.map(oldResult, newResult);
SmallVector<Operation *, 2> opsToErase;
for (auto it = start; it != end; ++it) {
rewriter.clone(*it, mapping);
opsToErase.push_back(&*it);
}
for (auto *it : opsToErase)
rewriter.eraseOp(it);
}
};

} // namespace

void mlir::populateStdToLLVMFuncOpConversionPattern(
Expand All @@ -911,7 +777,6 @@ void mlir::populateStdToLLVMConversionPatterns(LLVMTypeConverter &converter,
CallOpLowering,
CondBranchOpLowering,
ConstantOpLowering,
GenericAtomicRMWOpLowering,
ReturnOpLowering,
SelectOpLowering,
SplatOpLowering,
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Dialect/Complex/IR/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,6 @@ add_mlir_dialect_library(MLIRComplex
LINK_LIBS PUBLIC
MLIRArithmetic
MLIRDialect
MLIRInferTypeOpInterface
MLIRIR
MLIRStandard
)
8 changes: 4 additions & 4 deletions mlir/lib/Dialect/Complex/IR/ComplexDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@

#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"

using namespace mlir;

Expand All @@ -25,9 +24,10 @@ Operation *complex::ComplexDialect::materializeConstant(OpBuilder &builder,
Attribute value,
Type type,
Location loc) {
// TODO complex.constant
if (type.isa<ComplexType>())
return builder.create<ConstantOp>(loc, type, value);
if (complex::ConstantOp::isBuildableWith(value, type)) {
return builder.create<complex::ConstantOp>(loc, type,
value.cast<ArrayAttr>());
}
if (arith::ConstantOp::isBuildableWith(value, type))
return builder.create<arith::ConstantOp>(loc, type, value);
return nullptr;
Expand Down
64 changes: 61 additions & 3 deletions mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,54 @@ using namespace mlir;
using namespace mlir::complex;

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
// ConstantOp
//===----------------------------------------------------------------------===//

#define GET_OP_CLASSES
#include "mlir/Dialect/Complex/IR/ComplexOps.cpp.inc"
/// Folds a complex.constant to its attribute value; constants always fold.
OpFoldResult ConstantOp::fold(ArrayRef<Attribute> operands) {
  assert(operands.empty() && "constant has no operands");
  return getValue();
}

/// Suggests "cst" as the printed SSA name for complex.constant results.
void ConstantOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "cst");
}

/// Returns true if `value`/`type` form a valid complex.constant: the value
/// must be a two-element array attribute whose elements both have the
/// element type of the complex result type.
bool ConstantOp::isBuildableWith(Attribute value, Type type) {
  auto arrayValue = value.dyn_cast<ArrayAttr>();
  if (!arrayValue || arrayValue.size() != 2)
    return false;
  auto resultType = type.dyn_cast<ComplexType>();
  if (!resultType)
    return false;
  Type elementType = resultType.getElementType();
  return arrayValue[0].getType() == elementType &&
         arrayValue[1].getType() == elementType;
}

/// Verifies a complex.constant: the 'value' attribute must be a two-element
/// array whose element attributes both match the element type of the complex
/// result type.
static LogicalResult verify(ConstantOp op) {
  ArrayAttr arrayAttr = op.getValue();
  if (arrayAttr.size() != 2) {
    return op.emitOpError(
        "requires 'value' to be a complex constant, represented as array of "
        "two values");
  }

  // Both the real and imaginary parts must carry the complex element type.
  auto complexEltTy = op.getType().getElementType();
  if (complexEltTy != arrayAttr[0].getType() ||
      complexEltTy != arrayAttr[1].getType()) {
    return op.emitOpError()
           << "requires attribute's element types (" << arrayAttr[0].getType()
           << ", " << arrayAttr[1].getType()
           << ") to match the element type of the op's return type ("
           << complexEltTy << ")";
  }
  return success();
}

//===----------------------------------------------------------------------===//
// CreateOp
//===----------------------------------------------------------------------===//

OpFoldResult CreateOp::fold(ArrayRef<Attribute> operands) {
assert(operands.size() == 2 && "binary op takes two operands");
Expand All @@ -32,6 +75,10 @@ OpFoldResult CreateOp::fold(ArrayRef<Attribute> operands) {
return {};
}

//===----------------------------------------------------------------------===//
// ImOp
//===----------------------------------------------------------------------===//

OpFoldResult ImOp::fold(ArrayRef<Attribute> operands) {
assert(operands.size() == 1 && "unary op takes 1 operand");
ArrayAttr arrayAttr = operands[0].dyn_cast_or_null<ArrayAttr>();
Expand All @@ -42,6 +89,10 @@ OpFoldResult ImOp::fold(ArrayRef<Attribute> operands) {
return {};
}

//===----------------------------------------------------------------------===//
// ReOp
//===----------------------------------------------------------------------===//

OpFoldResult ReOp::fold(ArrayRef<Attribute> operands) {
assert(operands.size() == 1 && "unary op takes 1 operand");
ArrayAttr arrayAttr = operands[0].dyn_cast_or_null<ArrayAttr>();
Expand All @@ -51,3 +102,10 @@ OpFoldResult ReOp::fold(ArrayRef<Attribute> operands) {
return createOp.getOperand(0);
return {};
}

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//

#define GET_OP_CLASSES
#include "mlir/Dialect/Complex/IR/ComplexOps.cpp.inc"
1 change: 0 additions & 1 deletion mlir/lib/Dialect/GPU/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@ add_mlir_dialect_library(MLIRGPUOps
MLIRMemRef
MLIRSideEffectInterfaces
MLIRSupport
MLIRLLVMIR
)

add_mlir_dialect_library(MLIRGPUTransforms
Expand Down
24 changes: 15 additions & 9 deletions mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
#include "mlir/Dialect/GPU/GPUDialect.h"

#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
Expand Down Expand Up @@ -226,21 +225,28 @@ LogicalResult GPUDialect::verifyOperationAttribute(Operation *op,

// Check that `launch_func` refers to a well-formed kernel function.
Operation *kernelFunc = module.lookupSymbol(launchOp.kernelAttr());
auto kernelGPUFunction = dyn_cast_or_null<gpu::GPUFuncOp>(kernelFunc);
auto kernelLLVMFunction = dyn_cast_or_null<LLVM::LLVMFuncOp>(kernelFunc);
if (!kernelGPUFunction && !kernelLLVMFunction)
if (!kernelFunc)
return launchOp.emitOpError("kernel function '")
<< launchOp.kernel() << "' is undefined";
auto kernelConvertedFunction = dyn_cast<FunctionOpInterface>(kernelFunc);
if (!kernelConvertedFunction) {
InFlightDiagnostic diag = launchOp.emitOpError()
<< "referenced kernel '" << launchOp.kernel()
<< "' is not a function";
diag.attachNote(kernelFunc->getLoc()) << "see the kernel definition here";
return diag;
}

if (!kernelFunc->getAttrOfType<mlir::UnitAttr>(
GPUDialect::getKernelFuncAttrName()))
return launchOp.emitOpError("kernel function is missing the '")
<< GPUDialect::getKernelFuncAttrName() << "' attribute";

// TODO: if the kernel function has been converted to
// the LLVM dialect but the caller hasn't (which happens during the
// separate compilation), do not check type correspondence as it would
// require the verifier to be aware of the LLVM type conversion.
if (kernelLLVMFunction)
// TODO: If the kernel isn't a GPU function (which happens during separate
// compilation), do not check type correspondence as it would require the
// verifier to be aware of the type conversion.
auto kernelGPUFunction = dyn_cast<gpu::GPUFuncOp>(kernelFunc);
if (!kernelGPUFunction)
return success();

unsigned actualNumArguments = launchOp.getNumKernelOperands();
Expand Down
83 changes: 83 additions & 0 deletions mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -945,6 +945,89 @@ static LogicalResult verify(DmaWaitOp op) {
return success();
}

//===----------------------------------------------------------------------===//
// GenericAtomicRMWOp
//===----------------------------------------------------------------------===//

/// Builds a memref.generic_atomic_rmw operating on `memref` at the position
/// given by the index values `ivs`. Pre-creates the body region with a single
/// entry block whose argument carries the current element value.
void GenericAtomicRMWOp::build(OpBuilder &builder, OperationState &result,
                               Value memref, ValueRange ivs) {
  result.addOperands(memref);
  result.addOperands(ivs);

  if (auto memrefType = memref.getType().dyn_cast<MemRefType>()) {
    // The op's result type is the memref element type.
    Type elementType = memrefType.getElementType();
    result.addTypes(elementType);

    // Entry block argument: the value currently stored at the indexed
    // location, as seen by the region body.
    Region *bodyRegion = result.addRegion();
    bodyRegion->push_back(new Block());
    bodyRegion->addArgument(elementType, memref.getLoc());
  }
}

/// Verifies structural invariants of memref.generic_atomic_rmw: the region
/// has exactly one entry block argument whose type matches the result type,
/// and the body contains only side-effect-free operations.
static LogicalResult verify(GenericAtomicRMWOp op) {
  auto &body = op.getRegion();
  if (body.getNumArguments() != 1)
    return op.emitOpError("expected single number of entry block arguments");

  // The block argument is the current stored value; it must match the result.
  if (op.getResult().getType() != body.getArgument(0).getType())
    return op.emitOpError(
        "expected block argument of the same type result type");

  // The body may be re-executed when lowered to a cmpxchg retry loop (see the
  // MemRefToLLVM lowering), so side effects could be observed multiple times;
  // reject any op with effects.
  bool hasSideEffects =
      body.walk([&](Operation *nestedOp) {
            if (MemoryEffectOpInterface::hasNoEffect(nestedOp))
              return WalkResult::advance();
            nestedOp->emitError(
                "body of 'memref.generic_atomic_rmw' should contain "
                "only operations with no side effects");
            return WalkResult::interrupt();
          })
          .wasInterrupted();
  return hasSideEffects ? failure() : success();
}

/// Parses the custom assembly form:
///   memref.generic_atomic_rmw %memref[%ivs] : memref-type { region } attrs
static ParseResult parseGenericAtomicRMWOp(OpAsmParser &parser,
                                           OperationState &result) {
  OpAsmParser::OperandType memref;
  Type memrefType;
  SmallVector<OpAsmParser::OperandType, 4> ivs;

  // Index operands are always of index type.
  Type indexType = parser.getBuilder().getIndexType();
  if (parser.parseOperand(memref) ||
      parser.parseOperandList(ivs, OpAsmParser::Delimiter::Square) ||
      parser.parseColonType(memrefType) ||
      parser.resolveOperand(memref, memrefType, result.operands) ||
      parser.resolveOperands(ivs, indexType, result.operands))
    return failure();

  Region *body = result.addRegion();
  if (parser.parseRegion(*body, llvm::None, llvm::None) ||
      parser.parseOptionalAttrDict(result.attributes))
    return failure();
  // The op yields a value of the memref's element type.
  result.types.push_back(memrefType.cast<MemRefType>().getElementType());
  return success();
}

/// Prints memref.generic_atomic_rmw in its custom assembly form:
///   memref.generic_atomic_rmw %memref[%ivs] : memref-type { region } attrs
static void print(OpAsmPrinter &p, GenericAtomicRMWOp op) {
  p << ' ' << op.memref() << "[" << op.indices()
    << "] : " << op.memref().getType() << ' ';
  p.printRegion(op.getRegion());
  p.printOptionalAttrDict(op->getAttrs());
}

//===----------------------------------------------------------------------===//
// AtomicYieldOp
//===----------------------------------------------------------------------===//

/// Checks that the value yielded by a memref.atomic_yield has the same type
/// as the single result of its enclosing operation.
static LogicalResult verify(AtomicYieldOp op) {
  Type yieldedType = op.result().getType();
  Type enclosingType = op->getParentOp()->getResultTypes().front();
  if (enclosingType == yieldedType)
    return success();
  return op.emitOpError() << "types mismatch between yield op: " << yieldedType
                          << " and its parent: " << enclosingType;
}

//===----------------------------------------------------------------------===//
// GlobalOp
//===----------------------------------------------------------------------===//
Expand Down
124 changes: 6 additions & 118 deletions mlir/lib/Dialect/StandardOps/IR/Ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -131,88 +131,6 @@ LogicalResult AssertOp::canonicalize(AssertOp op, PatternRewriter &rewriter) {
return failure();
}

//===----------------------------------------------------------------------===//
// GenericAtomicRMWOp
//===----------------------------------------------------------------------===//

void GenericAtomicRMWOp::build(OpBuilder &builder, OperationState &result,
Value memref, ValueRange ivs) {
result.addOperands(memref);
result.addOperands(ivs);

if (auto memrefType = memref.getType().dyn_cast<MemRefType>()) {
Type elementType = memrefType.getElementType();
result.addTypes(elementType);

Region *bodyRegion = result.addRegion();
bodyRegion->push_back(new Block());
bodyRegion->addArgument(elementType, memref.getLoc());
}
}

static LogicalResult verify(GenericAtomicRMWOp op) {
auto &body = op.getRegion();
if (body.getNumArguments() != 1)
return op.emitOpError("expected single number of entry block arguments");

if (op.getResult().getType() != body.getArgument(0).getType())
return op.emitOpError(
"expected block argument of the same type result type");

bool hasSideEffects =
body.walk([&](Operation *nestedOp) {
if (MemoryEffectOpInterface::hasNoEffect(nestedOp))
return WalkResult::advance();
nestedOp->emitError("body of 'generic_atomic_rmw' should contain "
"only operations with no side effects");
return WalkResult::interrupt();
})
.wasInterrupted();
return hasSideEffects ? failure() : success();
}

static ParseResult parseGenericAtomicRMWOp(OpAsmParser &parser,
OperationState &result) {
OpAsmParser::OperandType memref;
Type memrefType;
SmallVector<OpAsmParser::OperandType, 4> ivs;

Type indexType = parser.getBuilder().getIndexType();
if (parser.parseOperand(memref) ||
parser.parseOperandList(ivs, OpAsmParser::Delimiter::Square) ||
parser.parseColonType(memrefType) ||
parser.resolveOperand(memref, memrefType, result.operands) ||
parser.resolveOperands(ivs, indexType, result.operands))
return failure();

Region *body = result.addRegion();
if (parser.parseRegion(*body, llvm::None, llvm::None) ||
parser.parseOptionalAttrDict(result.attributes))
return failure();
result.types.push_back(memrefType.cast<MemRefType>().getElementType());
return success();
}

static void print(OpAsmPrinter &p, GenericAtomicRMWOp op) {
p << ' ' << op.getMemref() << "[" << op.getIndices()
<< "] : " << op.getMemref().getType() << ' ';
p.printRegion(op.getRegion());
p.printOptionalAttrDict(op->getAttrs());
}

//===----------------------------------------------------------------------===//
// AtomicYieldOp
//===----------------------------------------------------------------------===//

static LogicalResult verify(AtomicYieldOp op) {
Type parentType = op->getParentOp()->getResultTypes().front();
Type resultType = op.getResult().getType();
if (parentType != resultType)
return op.emitOpError() << "types mismatch between yield op: " << resultType
<< " and its parent: " << parentType;
return success();
}

//===----------------------------------------------------------------------===//
// BranchOp
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -669,8 +587,8 @@ static void print(OpAsmPrinter &p, ConstantOp &op) {
p << ' ';
p << op.getValue();

// If the value is a symbol reference or Array, print a trailing type.
if (op.getValue().isa<SymbolRefAttr, ArrayAttr>())
// If the value is a symbol reference, print a trailing type.
if (op.getValue().isa<SymbolRefAttr>())
p << " : " << op.getType();
}

Expand All @@ -681,10 +599,9 @@ static ParseResult parseConstantOp(OpAsmParser &parser,
parser.parseAttribute(valueAttr, "value", result.attributes))
return failure();

// If the attribute is a symbol reference or array, then we expect a trailing
// type.
// If the attribute is a symbol reference, then we expect a trailing type.
Type type;
if (!valueAttr.isa<SymbolRefAttr, ArrayAttr>())
if (!valueAttr.isa<SymbolRefAttr>())
type = valueAttr.getType();
else if (parser.parseColonType(type))
return failure();
Expand All @@ -705,24 +622,6 @@ static LogicalResult verify(ConstantOp &op) {
return op.emitOpError() << "requires attribute's type (" << value.getType()
<< ") to match op's return type (" << type << ")";

if (auto complexTy = type.dyn_cast<ComplexType>()) {
auto arrayAttr = value.dyn_cast<ArrayAttr>();
if (!complexTy || arrayAttr.size() != 2)
return op.emitOpError(
"requires 'value' to be a complex constant, represented as array of "
"two values");
auto complexEltTy = complexTy.getElementType();
if (complexEltTy != arrayAttr[0].getType() ||
complexEltTy != arrayAttr[1].getType()) {
return op.emitOpError()
<< "requires attribute's element types (" << arrayAttr[0].getType()
<< ", " << arrayAttr[1].getType()
<< ") to match the element type of the op's return type ("
<< complexEltTy << ")";
}
return success();
}

if (type.isa<FunctionType>()) {
auto fnAttr = value.dyn_cast<FlatSymbolRefAttr>();
if (!fnAttr)
Expand Down Expand Up @@ -769,19 +668,8 @@ bool ConstantOp::isBuildableWith(Attribute value, Type type) {
// SymbolRefAttr can only be used with a function type.
if (value.isa<SymbolRefAttr>())
return type.isa<FunctionType>();
// The attribute must have the same type as 'type'.
if (!value.getType().isa<NoneType>() && value.getType() != type)
return false;
// Finally, check that the attribute kind is handled.
if (auto arrAttr = value.dyn_cast<ArrayAttr>()) {
auto complexTy = type.dyn_cast<ComplexType>();
if (!complexTy)
return false;
auto complexEltTy = complexTy.getElementType();
return arrAttr.size() == 2 && arrAttr[0].getType() == complexEltTy &&
arrAttr[1].getType() == complexEltTy;
}
return value.isa<UnitAttr>();
// Otherwise, this must be a UnitAttr.
return value.isa<UnitAttr>() && type.isa<NoneType>();
}

//===----------------------------------------------------------------------===//
Expand Down
12 changes: 6 additions & 6 deletions mlir/lib/Dialect/StandardOps/Transforms/ExpandOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,17 +28,17 @@ namespace {

/// Converts `atomic_rmw` that cannot be lowered to a simple atomic op with
/// AtomicRMWOpLowering pattern, e.g. with "minf" or "maxf" attributes, to
/// `generic_atomic_rmw` with the expanded code.
/// `memref.generic_atomic_rmw` with the expanded code.
///
/// %x = atomic_rmw "maxf" %fval, %F[%i] : (f32, memref<10xf32>) -> f32
///
/// will be lowered to
///
/// %x = std.generic_atomic_rmw %F[%i] : memref<10xf32> {
/// %x = memref.generic_atomic_rmw %F[%i] : memref<10xf32> {
/// ^bb0(%current: f32):
/// %cmp = arith.cmpf "ogt", %current, %fval : f32
/// %new_value = select %cmp, %current, %fval : f32
/// atomic_yield %new_value : f32
/// memref.atomic_yield %new_value : f32
/// }
struct AtomicRMWOpConverter : public OpRewritePattern<memref::AtomicRMWOp> {
public:
Expand All @@ -59,16 +59,16 @@ struct AtomicRMWOpConverter : public OpRewritePattern<memref::AtomicRMWOp> {
}

auto loc = op.getLoc();
auto genericOp =
rewriter.create<GenericAtomicRMWOp>(loc, op.memref(), op.indices());
auto genericOp = rewriter.create<memref::GenericAtomicRMWOp>(
loc, op.memref(), op.indices());
OpBuilder bodyBuilder =
OpBuilder::atBlockEnd(genericOp.getBody(), rewriter.getListener());

Value lhs = genericOp.getCurrentValue();
Value rhs = op.value();
Value cmp = bodyBuilder.create<arith::CmpFOp>(loc, predicate, lhs, rhs);
Value select = bodyBuilder.create<SelectOp>(loc, cmp, lhs, rhs);
bodyBuilder.create<AtomicYieldOp>(loc, select);
bodyBuilder.create<memref::AtomicYieldOp>(loc, select);

rewriter.replaceOp(op, genericOp.getResult());
return success();
Expand Down
19 changes: 19 additions & 0 deletions mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -892,3 +892,22 @@ func @collapse_static_shape_with_non_identity_layout(%arg: memref<1x1x8x8xf32, a
%1 = memref.collapse_shape %arg [[0, 1, 2, 3]] : memref<1x1x8x8xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 64 + s0 + d1 * 64 + d2 * 8 + d3)>> into memref<64xf32, affine_map<(d0)[s0] -> (d0 + s0)>>
return %1 : memref<64xf32, affine_map<(d0)[s0] -> (d0 + s0)>>
}

// -----

// CHECK-LABEL: func @generic_atomic_rmw
func @generic_atomic_rmw(%I : memref<10xi32>, %i : index) {
%x = memref.generic_atomic_rmw %I[%i] : memref<10xi32> {
^bb0(%old_value : i32):
memref.atomic_yield %old_value : i32
}
// CHECK: [[init:%.*]] = llvm.load %{{.*}} : !llvm.ptr<i32>
// CHECK-NEXT: llvm.br ^bb1([[init]] : i32)
// CHECK-NEXT: ^bb1([[loaded:%.*]]: i32):
// CHECK-NEXT: [[pair:%.*]] = llvm.cmpxchg %{{.*}}, [[loaded]], [[loaded]]
// CHECK-SAME: acq_rel monotonic : i32
// CHECK-NEXT: [[new:%.*]] = llvm.extractvalue [[pair]][0]
// CHECK-NEXT: [[ok:%.*]] = llvm.extractvalue [[pair]][1]
// CHECK-NEXT: llvm.cond_br [[ok]], ^bb2, ^bb1([[new]] : i32)
llvm.return
}
27 changes: 0 additions & 27 deletions mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -486,33 +486,6 @@ func @splat(%a: vector<4xf32>, %b: f32) -> vector<4xf32> {

// -----

// Legacy std-dialect counterpart of the memref-to-llvm generic_atomic_rmw
// test (this hunk is removed by the patch): checks the cmpxchg retry loop and
// that code after the op lands in the continuation block ^bb2.
// CHECK-LABEL: func @generic_atomic_rmw
func @generic_atomic_rmw(%I : memref<10xi32>, %i : index) -> i32 {
%x = generic_atomic_rmw %I[%i] : memref<10xi32> {
^bb0(%old_value : i32):
// Body ignores the loaded value and always yields the constant 1.
%c1 = arith.constant 1 : i32
atomic_yield %c1 : i32
}
// CHECK: [[init:%.*]] = llvm.load %{{.*}} : !llvm.ptr<i32>
// CHECK-NEXT: llvm.br ^bb1([[init]] : i32)
// CHECK-NEXT: ^bb1([[loaded:%.*]]: i32):
// CHECK-NEXT: [[c1:%.*]] = llvm.mlir.constant(1 : i32)
// CHECK-NEXT: [[pair:%.*]] = llvm.cmpxchg %{{.*}}, [[loaded]], [[c1]]
// CHECK-SAME: acq_rel monotonic : i32
// CHECK-NEXT: [[new:%.*]] = llvm.extractvalue [[pair]][0]
// CHECK-NEXT: [[ok:%.*]] = llvm.extractvalue [[pair]][1]
// CHECK-NEXT: llvm.cond_br [[ok]], ^bb2, ^bb1([[new]] : i32)
// CHECK-NEXT: ^bb2:
// Code after the atomic op: uses the op's result, lowered to [[new]].
%c2 = arith.constant 2 : i32
%add = arith.addi %c2, %x : i32
return %add : i32
// CHECK-NEXT: [[c2:%.*]] = llvm.mlir.constant(2 : i32)
// CHECK-NEXT: [[add:%.*]] = llvm.add [[c2]], [[new]] : i32
// CHECK-NEXT: llvm.return [[add]]
}

// -----

// CHECK-LABEL: func @ceilf(
// CHECK-SAME: f32
func @ceilf(%arg0 : f32) {
Expand Down
4 changes: 2 additions & 2 deletions mlir/test/Dialect/Complex/canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ func @create_of_real_and_imag_different_operand(
// Canonicalization test: complex.re of a constant complex folds to an
// arith.constant holding the real part.
func @real_of_const() -> f32 {
// CHECK: %[[CST:.*]] = arith.constant 1.000000e+00 : f32
// CHECK-NEXT: return %[[CST]] : f32
// NOTE(review): the next two lines are the old std `constant` form and the
// new `complex.constant` form shown side by side (diff-render artifact) —
// they redefine %complex; confirm only the complex.constant line remains.
%complex = constant [1.0 : f32, 0.0 : f32] : complex<f32>
%complex = complex.constant [1.0 : f32, 0.0 : f32] : complex<f32>
%1 = complex.re %complex : complex<f32>
return %1 : f32
}
Expand All @@ -47,7 +47,7 @@ func @real_of_create_op() -> f32 {
// Canonicalization test: complex.im of a constant complex folds to an
// arith.constant holding the imaginary part.
func @imag_of_const() -> f32 {
// CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
// CHECK-NEXT: return %[[CST]] : f32
// NOTE(review): old `constant` and new `complex.constant` lines appear side
// by side (diff-render artifact); confirm only the complex.constant remains.
%complex = constant [1.0 : f32, 0.0 : f32] : complex<f32>
%complex = complex.constant [1.0 : f32, 0.0 : f32] : complex<f32>
%1 = complex.im %complex : complex<f32>
return %1 : f32
}
Expand Down
23 changes: 23 additions & 0 deletions mlir/test/Dialect/Complex/invalid.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
// Verifier tests for complex.constant: each function triggers one expected
// diagnostic (checked by -verify-diagnostics).
// RUN: mlir-opt -split-input-file %s -verify-diagnostics

// Attribute must be an array of exactly two values (real, imaginary).
func @complex_constant_wrong_array_attribute_length() {
// expected-error @+1 {{requires 'value' to be a complex constant, represented as array of two values}}
%0 = complex.constant [1.0 : f32] : complex<f32>
return
}

// -----

// Element types of the attribute must match the result element type.
func @complex_constant_wrong_element_types() {
// expected-error @+1 {{requires attribute's element types ('f32', 'f32') to match the element type of the op's return type ('f64')}}
%0 = complex.constant [1.0 : f32, -1.0 : f32] : complex<f64>
return
}

// -----

// Mixed element types in the attribute are also rejected.
func @complex_constant_two_different_element_types() {
// expected-error @+1 {{requires attribute's element types ('f32', 'f64') to match the element type of the op's return type ('f64')}}
%0 = complex.constant [1.0 : f32, -1.0 : f64] : complex<f64>
return
}
7 changes: 6 additions & 1 deletion mlir/test/Dialect/Complex/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,12 @@
// CHECK-LABEL: func @ops(
// CHECK-SAME: %[[F:.*]]: f32) {
func @ops(%f: f32) {
// CHECK: complex.constant [1.{{.*}}, -1.{{.*}}] : complex<f64>
%cst_f64 = complex.constant [0.1, -1.0] : complex<f64>

// CHECK: complex.constant [1.{{.*}} : f32, -1.{{.*}} : f32] : complex<f32>
%cst_f32 = complex.constant [0.1 : f32, -1.0 : f32] : complex<f32>

// CHECK: %[[C:.*]] = complex.create %[[F]], %[[F]] : complex<f32>
%complex = complex.create %f, %f : complex<f32>

Expand Down Expand Up @@ -51,4 +57,3 @@ func @ops(%f: f32) {
%diff = complex.sub %complex, %complex : complex<f32>
return
}

15 changes: 15 additions & 0 deletions mlir/test/Dialect/GPU/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,21 @@ module attributes {gpu.container_module} {

// -----

// gpu.launch_func must reference a gpu.func; here the symbol resolves to a
// memref.global, so the verifier reports it is not a function.
module attributes {gpu.container_module} {
gpu.module @kernels {
// expected-note@+1 {{see the kernel definition here}}
memref.global "private" @kernel_1 : memref<4xi32>
}

func @launch_func_undefined_function(%sz : index) {
// expected-error@+1 {{referenced kernel '@kernels::@kernel_1' is not a function}}
gpu.launch_func @kernels::@kernel_1 blocks in (%sz, %sz, %sz) threads in (%sz, %sz, %sz)
return
}
}

// -----

module attributes {gpu.container_module} {
module @kernels {
gpu.func @kernel_1(%arg1 : !llvm.ptr<f32>) kernel {
Expand Down
60 changes: 60 additions & 0 deletions mlir/test/Dialect/MemRef/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -910,3 +910,63 @@ func @atomic_rmw_expects_int(%I: memref<16x10xf32>, %i : index, %val : f32) {
%x = memref.atomic_rmw addi %val, %I[%i, %i] : (f32, memref<16x10xf32>) -> f32
return
}

// -----

// Entry block must have exactly one argument (the current memory value).
func @generic_atomic_rmw_wrong_arg_num(%I: memref<10xf32>, %i : index) {
// expected-error@+1 {{expected single number of entry block arguments}}
%x = memref.generic_atomic_rmw %I[%i] : memref<10xf32> {
^bb0(%arg0 : f32, %arg1 : f32):
%c1 = arith.constant 1.0 : f32
memref.atomic_yield %c1 : f32
}
return
}

// -----

// Block argument type must match the op's result (memref element) type.
func @generic_atomic_rmw_wrong_arg_type(%I: memref<10xf32>, %i : index) {
// expected-error@+1 {{expected block argument of the same type result type}}
%x = memref.generic_atomic_rmw %I[%i] : memref<10xf32> {
^bb0(%old_value : i32):
%c1 = arith.constant 1.0 : f32
memref.atomic_yield %c1 : f32
}
return
}

// -----

// Generic form is used so a result type differing from the memref element
// type can be written down at all; the verifier must still reject it.
func @generic_atomic_rmw_result_type_mismatch(%I: memref<10xf32>, %i : index) {
// expected-error@+1 {{failed to verify that result type matches element type of memref}}
%0 = "memref.generic_atomic_rmw"(%I, %i) ({
^bb0(%old_value: f32):
%c1 = arith.constant 1.0 : f32
memref.atomic_yield %c1 : f32
}) : (memref<10xf32>, index) -> i32
return
}

// -----

// Body must be side-effect free; memref.alloc inside the region is rejected.
func @generic_atomic_rmw_has_side_effects(%I: memref<10xf32>, %i : index) {
// expected-error@+4 {{should contain only operations with no side effects}}
%x = memref.generic_atomic_rmw %I[%i] : memref<10xf32> {
^bb0(%old_value : f32):
%c1 = arith.constant 1.0 : f32
%buf = memref.alloc() : memref<2048xf32>
memref.atomic_yield %c1 : f32
}
}

// -----

// The yielded type must match the parent op's result type.
func @atomic_yield_type_mismatch(%I: memref<10xf32>, %i : index) {
// expected-error@+4 {{op types mismatch between yield op: 'i32' and its parent: 'f32'}}
%x = memref.generic_atomic_rmw %I[%i] : memref<10xf32> {
^bb0(%old_value : f32):
%c1 = arith.constant 1 : i32
memref.atomic_yield %c1 : i32
}
return
}
14 changes: 14 additions & 0 deletions mlir/test/Dialect/MemRef/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -246,3 +246,17 @@ func @atomic_rmw(%I: memref<10xf32>, %val: f32, %i : index) {
// CHECK: memref.atomic_rmw addf [[VAL]], [[BUF]]{{\[}}[[I]]]
return
}

// Round-trip (parse/print) test for memref.generic_atomic_rmw, including
// that an extra discretionary attribute is preserved.
// CHECK-LABEL: func @generic_atomic_rmw
// CHECK-SAME: ([[BUF:%.*]]: memref<1x2xf32>, [[I:%.*]]: index, [[J:%.*]]: index)
func @generic_atomic_rmw(%I: memref<1x2xf32>, %i : index, %j : index) {
%x = memref.generic_atomic_rmw %I[%i, %j] : memref<1x2xf32> {
// CHECK-NEXT: memref.generic_atomic_rmw [[BUF]]{{\[}}[[I]], [[J]]] : memref
^bb0(%old_value : f32):
%c1 = arith.constant 1.0 : f32
%out = arith.addf %c1, %old_value : f32
memref.atomic_yield %out : f32
// CHECK: index_attr = 8 : index
} { index_attr = 8 : index }
return
}
4 changes: 2 additions & 2 deletions mlir/test/Dialect/Standard/expand-ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@ func @atomic_rmw_to_generic(%F: memref<10xf32>, %f: f32, %i: index) -> f32 {
%x = memref.atomic_rmw maxf %f, %F[%i] : (f32, memref<10xf32>) -> f32
return %x : f32
}
// CHECK: %0 = generic_atomic_rmw %arg0[%arg2] : memref<10xf32> {
// CHECK: %0 = memref.generic_atomic_rmw %arg0[%arg2] : memref<10xf32> {
// CHECK: ^bb0([[CUR_VAL:%.*]]: f32):
// CHECK: [[CMP:%.*]] = arith.cmpf ogt, [[CUR_VAL]], [[f]] : f32
// CHECK: [[SELECT:%.*]] = select [[CMP]], [[CUR_VAL]], [[f]] : f32
// CHECK: atomic_yield [[SELECT]] : f32
// CHECK: memref.atomic_yield [[SELECT]] : f32
// CHECK: }
// CHECK: return %0 : f32

Expand Down
24 changes: 0 additions & 24 deletions mlir/test/Dialect/Standard/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -8,30 +8,6 @@ func @unsupported_attribute() {

// -----

// Old std-dialect complex-constant verifier tests (this hunk is removed by
// the patch; the memref/complex counterparts replace it).
func @complex_constant_wrong_array_attribute_length() {
// expected-error @+1 {{requires 'value' to be a complex constant, represented as array of two values}}
%0 = constant [1.0 : f32] : complex<f32>
return
}

// -----

func @complex_constant_wrong_element_types() {
// expected-error @+1 {{requires attribute's element types ('f32', 'f32') to match the element type of the op's return type ('f64')}}
%0 = constant [1.0 : f32, -1.0 : f32] : complex<f64>
return
}

// -----

func @complex_constant_two_different_element_types() {
// expected-error @+1 {{requires attribute's element types ('f32', 'f64') to match the element type of the op's return type ('f64')}}
%0 = constant [1.0 : f32, -1.0 : f64] : complex<f64>
return
}

// -----

func @return_i32_f32() -> (i32, f32) {
%0 = arith.constant 1 : i32
%1 = arith.constant 1. : f32
Expand Down
12 changes: 0 additions & 12 deletions mlir/test/Dialect/Standard/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -51,18 +51,6 @@ func @switch_i64(%flag : i64, %caseOperand : i32) {
return
}

// Old std-dialect round-trip tests for complex constants (removed by the
// patch in favor of complex.constant in Dialect/Complex/ops.mlir).
// CHECK-LABEL: func @constant_complex_f32(
func @constant_complex_f32() -> complex<f32> {
%result = constant [0.1 : f32, -1.0 : f32] : complex<f32>
return %result : complex<f32>
}

// CHECK-LABEL: func @constant_complex_f64(
func @constant_complex_f64() -> complex<f64> {
%result = constant [0.1 : f64, -1.0 : f64] : complex<f64>
return %result : complex<f64>
}

// CHECK-LABEL: func @vector_splat_0d(
func @vector_splat_0d(%a: f32) -> vector<f32> {
// CHECK: splat %{{.*}} : vector<f32>
Expand Down
14 changes: 0 additions & 14 deletions mlir/test/IR/core-ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -325,20 +325,6 @@ func @unranked_tensor_load_store(%0 : memref<*xi32>, %1 : tensor<*xi32>) {
return
}

// Old std-dialect round-trip test for generic_atomic_rmw (removed by the
// patch; replaced by the memref.generic_atomic_rmw test in MemRef/ops.mlir).
// CHECK-LABEL: func @generic_atomic_rmw
// CHECK-SAME: ([[BUF:%.*]]: memref<1x2xf32>, [[I:%.*]]: index, [[J:%.*]]: index)
func @generic_atomic_rmw(%I: memref<1x2xf32>, %i : index, %j : index) {
%x = generic_atomic_rmw %I[%i, %j] : memref<1x2xf32> {
// CHECK-NEXT: generic_atomic_rmw [[BUF]]{{\[}}[[I]], [[J]]] : memref
^bb0(%old_value : f32):
%c1 = arith.constant 1.0 : f32
%out = arith.addf %c1, %old_value : f32
atomic_yield %out : f32
// CHECK: index_attr = 8 : index
} { index_attr = 8 : index }
return
}

// CHECK-LABEL: func @assume_alignment
// CHECK-SAME: %[[MEMREF:.*]]: memref<4x4xf16>
func @assume_alignment(%0: memref<4x4xf16>) {
Expand Down
60 changes: 0 additions & 60 deletions mlir/test/IR/invalid-ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -127,63 +127,3 @@ func @invalid_splat(%v : f32) { // expected-note {{prior use here}}
// expected-error@-1 {{expects different type than prior uses}}
return
}

// -----

// Old std-dialect verifier tests for generic_atomic_rmw (removed by the
// patch; the memref.generic_atomic_rmw versions in MemRef/invalid.mlir
// replace them one-for-one).
func @generic_atomic_rmw_wrong_arg_num(%I: memref<10xf32>, %i : index) {
// expected-error@+1 {{expected single number of entry block arguments}}
%x = generic_atomic_rmw %I[%i] : memref<10xf32> {
^bb0(%arg0 : f32, %arg1 : f32):
%c1 = arith.constant 1.0 : f32
atomic_yield %c1 : f32
}
return
}

// -----

func @generic_atomic_rmw_wrong_arg_type(%I: memref<10xf32>, %i : index) {
// expected-error@+1 {{expected block argument of the same type result type}}
%x = generic_atomic_rmw %I[%i] : memref<10xf32> {
^bb0(%old_value : i32):
%c1 = arith.constant 1.0 : f32
atomic_yield %c1 : f32
}
return
}

// -----

func @generic_atomic_rmw_result_type_mismatch(%I: memref<10xf32>, %i : index) {
// expected-error@+1 {{failed to verify that result type matches element type of memref}}
%0 = "std.generic_atomic_rmw"(%I, %i) ({
^bb0(%old_value: f32):
%c1 = arith.constant 1.0 : f32
atomic_yield %c1 : f32
}) : (memref<10xf32>, index) -> i32
return
}

// -----

func @generic_atomic_rmw_has_side_effects(%I: memref<10xf32>, %i : index) {
// expected-error@+4 {{should contain only operations with no side effects}}
%x = generic_atomic_rmw %I[%i] : memref<10xf32> {
^bb0(%old_value : f32):
%c1 = arith.constant 1.0 : f32
%buf = memref.alloc() : memref<2048xf32>
atomic_yield %c1 : f32
}
}

// -----

func @atomic_yield_type_mismatch(%I: memref<10xf32>, %i : index) {
// expected-error@+4 {{op types mismatch between yield op: 'i32' and its parent: 'f32'}}
%x = generic_atomic_rmw %I[%i] : memref<10xf32> {
^bb0(%old_value : f32):
%c1 = arith.constant 1 : i32
atomic_yield %c1 : i32
}
return
}