54 changes: 24 additions & 30 deletions mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,13 @@

include "mlir/Dialect/SPIRV/SPIRVBase.td"
include "mlir/Analysis/CallInterfaces.td"
include "mlir/Analysis/ControlFlowInterfaces.td"

// -----

def SPV_BranchOp : SPV_Op<"Branch", [InFunctionScope, Terminator]> {
def SPV_BranchOp : SPV_Op<"Branch", [
DeclareOpInterfaceMethods<BranchOpInterface>, InFunctionScope,
Terminator]> {
let summary = "Unconditional branch to target block.";

let description = [{
Expand All @@ -41,7 +44,7 @@ def SPV_BranchOp : SPV_Op<"Branch", [InFunctionScope, Terminator]> {
```
}];

let arguments = (ins);
let arguments = (ins Variadic<SPV_Type>:$targetOperands);

let results = (outs);

Expand All @@ -53,7 +56,8 @@ def SPV_BranchOp : SPV_Op<"Branch", [InFunctionScope, Terminator]> {
OpBuilder<
"Builder *, OperationState &state, "
"Block *successor, ValueRange arguments = {}", [{
state.addSuccessor(successor, arguments);
state.addSuccessors(successor);
state.addOperands(arguments);
}]
>
];
Expand All @@ -70,13 +74,16 @@ def SPV_BranchOp : SPV_Op<"Branch", [InFunctionScope, Terminator]> {

let autogenSerialization = 0;

let assemblyFormat = "successors attr-dict";
let assemblyFormat = [{
$target (`(` $targetOperands^ `:` type($targetOperands) `)`)? attr-dict
}];
}

// -----

def SPV_BranchConditionalOp : SPV_Op<"BranchConditional",
[InFunctionScope, Terminator]> {
def SPV_BranchConditionalOp : SPV_Op<"BranchConditional", [
AttrSizedOperandSegments, DeclareOpInterfaceMethods<BranchOpInterface>,
InFunctionScope, Terminator]> {
let summary = [{
If Condition is true, branch to true block, otherwise branch to false
block.
Expand Down Expand Up @@ -117,13 +124,15 @@ def SPV_BranchConditionalOp : SPV_Op<"BranchConditional",

let arguments = (ins
SPV_Bool:$condition,
Variadic<SPV_Type>:$trueTargetOperands,
Variadic<SPV_Type>:$falseTargetOperands,
OptionalAttr<I32ArrayAttr>:$branch_weights
);

let results = (outs);

let successors = (successor AnySuccessor:$trueTarget,
AnySuccessor:$falseTarget);
AnySuccessor:$falseTarget);

let builders = [
OpBuilder<
Expand All @@ -132,21 +141,18 @@ def SPV_BranchConditionalOp : SPV_Op<"BranchConditional",
"Block *falseBlock, ValueRange falseArguments, "
"Optional<std::pair<uint32_t, uint32_t>> weights = {}",
[{
state.addOperands(condition);
state.addSuccessor(trueBlock, trueArguments);
state.addSuccessor(falseBlock, falseArguments);
ArrayAttr weightsAttr;
if (weights) {
auto attr =
weightsAttr =
builder->getI32ArrayAttr({static_cast<int32_t>(weights->first),
static_cast<int32_t>(weights->second)});
state.addAttribute("branch_weights", attr);
}
build(builder, state, condition, trueArguments, falseArguments,
weightsAttr, trueBlock, falseBlock);
}]
>
];

let skipDefaultBuilders = 1;

let autogenSerialization = 0;

let extraClassDeclaration = [{
Expand All @@ -161,34 +167,22 @@ def SPV_BranchConditionalOp : SPV_Op<"BranchConditional",

/// Returns the number of arguments to the true target block.
unsigned getNumTrueBlockArguments() {
return getNumSuccessorOperands(kTrueIndex);
return trueTargetOperands().size();
}

/// Returns the number of arguments to the false target block.
unsigned getNumFalseBlockArguments() {
return getNumSuccessorOperands(kFalseIndex);
return falseTargetOperands().size();
}

// Iterator and range support for true target block arguments.
operand_iterator true_block_argument_begin() {
return operand_begin() + getTrueBlockArgumentIndex();
}
operand_iterator true_block_argument_end() {
return true_block_argument_begin() + getNumTrueBlockArguments();
}
operand_range getTrueBlockArguments() {
return {true_block_argument_begin(), true_block_argument_end()};
return trueTargetOperands();
}

// Iterator and range support for false target block arguments.
operand_iterator false_block_argument_begin() {
return true_block_argument_end();
}
operand_iterator false_block_argument_end() {
return false_block_argument_begin() + getNumFalseBlockArguments();
}
operand_range getFalseBlockArguments() {
return {false_block_argument_begin(), false_block_argument_end()};
return falseTargetOperands();
}

private:
Expand Down
1 change: 1 addition & 0 deletions mlir/include/mlir/Dialect/SPIRV/SPIRVOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#ifndef MLIR_DIALECT_SPIRV_SPIRVOPS_H_
#define MLIR_DIALECT_SPIRV_SPIRVOPS_H_

#include "mlir/Analysis/ControlFlowInterfaces.h"
#include "mlir/Dialect/SPIRV/SPIRVTypes.h"
#include "mlir/IR/Function.h"
#include "llvm/Support/PointerLikeTypeTraits.h"
Expand Down
1 change: 1 addition & 0 deletions mlir/include/mlir/Dialect/StandardOps/IR/Ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#define MLIR_DIALECT_STANDARDOPS_IR_OPS_H

#include "mlir/Analysis/CallInterfaces.h"
#include "mlir/Analysis/ControlFlowInterfaces.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpImplementation.h"
Expand Down
85 changes: 49 additions & 36 deletions mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#define STANDARD_OPS

include "mlir/Analysis/CallInterfaces.td"
include "mlir/Analysis/ControlFlowInterfaces.td"
include "mlir/IR/OpAsmInterface.td"

def Std_Dialect : Dialect {
Expand Down Expand Up @@ -331,7 +332,8 @@ def AtomicRMWOp : Std_Op<"atomic_rmw", [
// BranchOp
//===----------------------------------------------------------------------===//

def BranchOp : Std_Op<"br", [Terminator]> {
def BranchOp : Std_Op<"br",
[DeclareOpInterfaceMethods<BranchOpInterface>, Terminator]> {
let summary = "branch operation";
let description = [{
The "br" operation represents a branch operation in a function.
Expand All @@ -345,10 +347,13 @@ def BranchOp : Std_Op<"br", [Terminator]> {
^bb3(%3: tensor<*xf32>):
}];

let arguments = (ins Variadic<AnyType>:$destOperands);
let successors = (successor AnySuccessor:$dest);

let builders = [OpBuilder<"Builder *, OperationState &result, Block *dest", [{
result.addSuccessor(dest, llvm::None);
let builders = [OpBuilder<"Builder *, OperationState &result, Block *dest, "
"ValueRange destOperands = {}", [{
result.addSuccessors(dest);
result.addOperands(destOperands);
}]>];

// BranchOp is fully verified by traits.
Expand All @@ -363,7 +368,9 @@ def BranchOp : Std_Op<"br", [Terminator]> {
}];

let hasCanonicalizer = 1;
let assemblyFormat = "$dest attr-dict";
let assemblyFormat = [{
$dest (`(` $destOperands^ `:` type($destOperands) `)`)? attr-dict
}];
}

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -668,7 +675,9 @@ def CmpIOp : Std_Op<"cmpi",
// CondBranchOp
//===----------------------------------------------------------------------===//

def CondBranchOp : Std_Op<"cond_br", [Terminator]> {
def CondBranchOp : Std_Op<"cond_br",
[AttrSizedOperandSegments, DeclareOpInterfaceMethods<BranchOpInterface>,
Terminator]> {
let summary = "conditional branch operation";
let description = [{
The "cond_br" operation represents a conditional branch operation in a
Expand All @@ -685,9 +694,24 @@ def CondBranchOp : Std_Op<"cond_br", [Terminator]> {
...
}];

let arguments = (ins I1:$condition);
let arguments = (ins I1:$condition,
Variadic<AnyType>:$trueDestOperands,
Variadic<AnyType>:$falseDestOperands);
let successors = (successor AnySuccessor:$trueDest, AnySuccessor:$falseDest);

let builders = [OpBuilder<
"Builder *builder, OperationState &result, Value condition,"
"Block *trueDest, ValueRange trueOperands,"
"Block *falseDest, ValueRange falseOperands", [{
build(builder, result, condition, trueOperands, falseOperands, trueDest,
falseDest);
}]>, OpBuilder<
"Builder *builder, OperationState &result, Value condition,"
"Block *trueDest, Block *falseDest, ValueRange falseOperands = {}", [{
build(builder, result, condition, trueDest, ValueRange(), falseDest,
falseOperands);
}]>];

// CondBranchOp is fully verified by traits.
let verifier = ?;

Expand Down Expand Up @@ -719,23 +743,13 @@ def CondBranchOp : Std_Op<"cond_br", [Terminator]> {
setOperand(getTrueDestOperandIndex() + idx, value);
}

operand_iterator true_operand_begin() {
return operand_begin() + getTrueDestOperandIndex();
}
operand_iterator true_operand_end() {
return true_operand_begin() + getNumTrueOperands();
}
operand_range getTrueOperands() {
return {true_operand_begin(), true_operand_end()};
}
operand_range getTrueOperands() { return trueDestOperands(); }

unsigned getNumTrueOperands() {
return getNumSuccessorOperands(trueIndex);
}
unsigned getNumTrueOperands() { return getTrueOperands().size(); }

/// Erase the operand at 'index' from the true operand list.
void eraseTrueOperand(unsigned index) {
getOperation()->eraseSuccessorOperand(trueIndex, index);
eraseSuccessorOperand(trueIndex, index);
}

// Accessors for operands to the 'false' destination.
Expand All @@ -748,21 +762,13 @@ def CondBranchOp : Std_Op<"cond_br", [Terminator]> {
setOperand(getFalseDestOperandIndex() + idx, value);
}

operand_iterator false_operand_begin() { return true_operand_end(); }
operand_iterator false_operand_end() {
return false_operand_begin() + getNumFalseOperands();
}
operand_range getFalseOperands() {
return {false_operand_begin(), false_operand_end()};
}
operand_range getFalseOperands() { return falseDestOperands(); }

unsigned getNumFalseOperands() {
return getNumSuccessorOperands(falseIndex);
}
unsigned getNumFalseOperands() { return getFalseOperands().size(); }

/// Erase the operand at 'index' from the false operand list.
void eraseFalseOperand(unsigned index) {
getOperation()->eraseSuccessorOperand(falseIndex, index);
eraseSuccessorOperand(falseIndex, index);
}

private:
Expand All @@ -776,7 +782,12 @@ def CondBranchOp : Std_Op<"cond_br", [Terminator]> {
}];

let hasCanonicalizer = 1;
let assemblyFormat = "$condition `,` successors attr-dict";
let assemblyFormat = [{
$condition `,`
$trueDest (`(` $trueDestOperands^ `:` type($trueDestOperands) `)`)? `,`
$falseDest (`(` $falseDestOperands^ `:` type($falseDestOperands) `)`)?
attr-dict
}];
}

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -1710,16 +1721,18 @@ def SubViewOp : Std_Op<"subview", [AttrSizedOperandSegments, NoSideEffect]> {
}
}];

// TODO(b/144779634, ravishankarm) : Use different arguments for
// offsets, sizes and strides.
let arguments = (ins
AnyMemRef:$source,
Variadic<Index>:$offsets,
Variadic<Index>:$sizes,
Variadic<Index>:$strides,
I32ElementsAttr:$operand_segment_sizes
Variadic<Index>:$strides
);
let results = (outs AnyMemRef);
let results = (outs AnyMemRef:$result);

let assemblyFormat = [{
$source `[` $offsets `]` `[` $sizes `]` `[` $strides `]` attr-dict `:`
type($source) `to` type($result)
}];

let builders = [
OpBuilder<
Expand Down
10 changes: 7 additions & 3 deletions mlir/include/mlir/Dialect/VectorOps/VectorOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -713,9 +713,8 @@ def Vector_ReshapeOp :
Vector_Op<"reshape", [AttrSizedOperandSegments, NoSideEffect]>,
Arguments<(ins AnyVector:$vector, Variadic<Index>:$input_shape,
Variadic<Index>:$output_shape,
I64ArrayAttr:$fixed_vector_sizes,
I32ElementsAttr:$operand_segment_sizes)>,
Results<(outs AnyVector)> {
I64ArrayAttr:$fixed_vector_sizes)>,
Results<(outs AnyVector:$result)> {
let summary = "vector reshape operation";
let description = [{
Reshapes its vector operand from 'input_shape' to 'output_shape' maintaining
Expand Down Expand Up @@ -822,6 +821,11 @@ def Vector_ReshapeOp :
static StringRef getInputShapeAttrName() { return "input_shape"; }
static StringRef getOutputShapeAttrName() { return "output_shape"; }
}];

let assemblyFormat = [{
$vector `,` `[` $input_shape `]` `,` `[` $output_shape `]` `,`
$fixed_vector_sizes attr-dict `:` type($vector) `to` type($result)
}];
}

def Vector_StridedSliceOp :
Expand Down
16 changes: 10 additions & 6 deletions mlir/include/mlir/IR/Block.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@
#include "mlir/IR/Visitors.h"

namespace mlir {
class TypeRange;
template <typename ValueRangeT> class ValueTypeRange;

/// `Block` represents an ordered list of `Operation`s.
class Block : public IRObjectWithUseList<BlockOperand>,
public llvm::ilist_node_with_parent<Block, Region> {
Expand Down Expand Up @@ -67,6 +70,9 @@ class Block : public IRObjectWithUseList<BlockOperand>,

BlockArgListType getArguments() { return arguments; }

/// Return a range containing the types of the arguments for this block.
ValueTypeRange<BlockArgListType> getArgumentTypes();

using args_iterator = BlockArgListType::iterator;
using reverse_args_iterator = BlockArgListType::reverse_iterator;
args_iterator args_begin() { return getArguments().begin(); }
Expand All @@ -85,15 +91,13 @@ class Block : public IRObjectWithUseList<BlockOperand>,
BlockArgument insertArgument(args_iterator it, Type type);

/// Add one argument to the argument list for each type specified in the list.
iterator_range<args_iterator> addArguments(ArrayRef<Type> types);
iterator_range<args_iterator> addArguments(TypeRange types);

// Add one value to the argument list at the specified position.
/// Add one value to the argument list at the specified position.
BlockArgument insertArgument(unsigned index, Type type);

/// Erase the argument at 'index' and remove it from the argument list. If
/// 'updatePredTerms' is set to true, this argument is also removed from the
/// terminators of each predecessor to this block.
void eraseArgument(unsigned index, bool updatePredTerms = true);
/// Erase the argument at 'index' and remove it from the argument list.
void eraseArgument(unsigned index);

unsigned getNumArguments() { return arguments.size(); }
BlockArgument getArgument(unsigned i) { return arguments[i]; }
Expand Down
158 changes: 123 additions & 35 deletions mlir/include/mlir/IR/OpDefinition.h
Original file line number Diff line number Diff line change
Expand Up @@ -381,6 +381,10 @@ LogicalResult verifyResultsAreBoolLike(Operation *op);
LogicalResult verifyResultsAreFloatLike(Operation *op);
LogicalResult verifyResultsAreSignlessIntegerLike(Operation *op);
LogicalResult verifyIsTerminator(Operation *op);
LogicalResult verifyZeroSuccessor(Operation *op);
LogicalResult verifyOneSuccessor(Operation *op);
LogicalResult verifyNSuccessors(Operation *op, unsigned numSuccessors);
LogicalResult verifyAtLeastNSuccessors(Operation *op, unsigned numSuccessors);
LogicalResult verifyOperandSizeAttr(Operation *op, StringRef sizeAttrName);
LogicalResult verifyResultSizeAttr(Operation *op, StringRef sizeAttrName);
} // namespace impl
Expand Down Expand Up @@ -410,6 +414,9 @@ class TraitBase {
}
};

//===----------------------------------------------------------------------===//
// Operand Traits

namespace detail {
/// Utility trait base that provides accessors for derived traits that have
/// multiple operands.
Expand Down Expand Up @@ -522,6 +529,9 @@ template <typename ConcreteType>
class VariadicOperands
: public detail::MultiOperandTraitBase<ConcreteType, VariadicOperands> {};

//===----------------------------------------------------------------------===//
// Result Traits

/// This class provides return value APIs for ops that are known to have
/// zero results.
template <typename ConcreteType>
Expand Down Expand Up @@ -644,6 +654,119 @@ template <typename ConcreteType>
class VariadicResults
: public detail::MultiResultTraitBase<ConcreteType, VariadicResults> {};

//===----------------------------------------------------------------------===//
// Terminator Traits

/// This class provides the API for ops that are known to be terminators.
template <typename ConcreteType>
class IsTerminator : public TraitBase<ConcreteType, IsTerminator> {
public:
static AbstractOperation::OperationProperties getTraitProperties() {
return static_cast<AbstractOperation::OperationProperties>(
OperationProperty::Terminator);
}
static LogicalResult verifyTrait(Operation *op) {
return impl::verifyIsTerminator(op);
}
};

/// This class provides verification for ops that are known to have zero
/// successors.
template <typename ConcreteType>
class ZeroSuccessor : public TraitBase<ConcreteType, ZeroSuccessor> {
public:
static LogicalResult verifyTrait(Operation *op) {
return impl::verifyZeroSuccessor(op);
}
};

namespace detail {
/// Utility trait base that provides accessors for derived traits that have
/// multiple successors.
template <typename ConcreteType, template <typename> class TraitType>
struct MultiSuccessorTraitBase : public TraitBase<ConcreteType, TraitType> {
using succ_iterator = Operation::succ_iterator;
using succ_range = SuccessorRange;

/// Return the number of successors.
unsigned getNumSuccessors() {
return this->getOperation()->getNumSuccessors();
}

/// Return the successor at `index`.
Block *getSuccessor(unsigned i) {
return this->getOperation()->getSuccessor(i);
}

/// Set the successor at `index`.
void setSuccessor(Block *block, unsigned i) {
return this->getOperation()->setSuccessor(block, i);
}

/// Successor iterator access.
succ_iterator succ_begin() { return this->getOperation()->succ_begin(); }
succ_iterator succ_end() { return this->getOperation()->succ_end(); }
succ_range getSuccessors() { return this->getOperation()->getSuccessors(); }
};
} // end namespace detail

/// This class provides APIs for ops that are known to have a single successor.
template <typename ConcreteType>
class OneSuccessor : public TraitBase<ConcreteType, OneSuccessor> {
public:
Block *getSuccessor() { return this->getOperation()->getSuccessor(0); }
void setSuccessor(Block *succ) {
this->getOperation()->setSuccessor(succ, 0);
}

static LogicalResult verifyTrait(Operation *op) {
return impl::verifyOneSuccessor(op);
}
};

/// This class provides the API for ops that are known to have a specified
/// number of successors.
template <unsigned N>
class NSuccessors {
public:
static_assert(N > 1, "use ZeroSuccessor/OneSuccessor for N < 2");

template <typename ConcreteType>
class Impl : public detail::MultiSuccessorTraitBase<ConcreteType,
NSuccessors<N>::Impl> {
public:
static LogicalResult verifyTrait(Operation *op) {
return impl::verifyNSuccessors(op, N);
}
};
};

/// This class provides APIs for ops that are known to have at least a specified
/// number of successors.
template <unsigned N>
class AtLeastNSuccessors {
public:
template <typename ConcreteType>
class Impl
: public detail::MultiSuccessorTraitBase<ConcreteType,
AtLeastNSuccessors<N>::Impl> {
public:
static LogicalResult verifyTrait(Operation *op) {
return impl::verifyAtLeastNSuccessors(op, N);
}
};
};

/// This class provides the API for ops which have an unknown number of
/// successors.
template <typename ConcreteType>
class VariadicSuccessors
: public detail::MultiSuccessorTraitBase<ConcreteType, VariadicSuccessors> {
};

//===----------------------------------------------------------------------===//
// Misc Traits

/// This class provides verification for ops that are known to have the same
/// operand shape: all operands are scalars, vectors/tensors of the same
/// shape.
Expand Down Expand Up @@ -789,41 +912,6 @@ class SameTypeOperands : public TraitBase<ConcreteType, SameTypeOperands> {
}
};

/// This class provides the API for ops that are known to be terminators.
template <typename ConcreteType>
class IsTerminator : public TraitBase<ConcreteType, IsTerminator> {
public:
static AbstractOperation::OperationProperties getTraitProperties() {
return static_cast<AbstractOperation::OperationProperties>(
OperationProperty::Terminator);
}
static LogicalResult verifyTrait(Operation *op) {
return impl::verifyIsTerminator(op);
}

unsigned getNumSuccessors() {
return this->getOperation()->getNumSuccessors();
}
unsigned getNumSuccessorOperands(unsigned index) {
return this->getOperation()->getNumSuccessorOperands(index);
}

Block *getSuccessor(unsigned index) {
return this->getOperation()->getSuccessor(index);
}

void setSuccessor(Block *block, unsigned index) {
return this->getOperation()->setSuccessor(block, index);
}

void addSuccessorOperand(unsigned index, Value value) {
return this->getOperation()->addSuccessorOperand(index, value);
}
void addSuccessorOperands(unsigned index, ArrayRef<Value> values) {
return this->getOperation()->addSuccessorOperand(index, values);
}
};

/// This class provides the API for ops that are known to be isolated from
/// above.
template <typename ConcreteType>
Expand Down
28 changes: 18 additions & 10 deletions mlir/include/mlir/IR/OpImplementation.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,12 @@ class OpAsmPrinter {
/// provide a valid type for the attribute.
virtual void printAttributeWithoutType(Attribute attr) = 0;

/// Print a successor, and use list, of a terminator operation given the
/// terminator and the successor index.
virtual void printSuccessorAndUseList(Operation *term, unsigned index) = 0;
/// Print the given successor.
virtual void printSuccessor(Block *successor) = 0;

/// Print the successor and its operands.
virtual void printSuccessorAndUseList(Block *successor,
ValueRange succOperands) = 0;

/// If the specified operation has attributes, print out an attribute
/// dictionary with their values. elidedAttrs allows the client to ignore
Expand Down Expand Up @@ -120,8 +123,7 @@ class OpAsmPrinter {

/// Print the complete type of an operation in functional form.
void printFunctionalType(Operation *op) {
printFunctionalType(op->getNonSuccessorOperands().getTypes(),
op->getResultTypes());
printFunctionalType(op->getOperandTypes(), op->getResultTypes());
}
/// Print the two given type ranges in a functional form.
template <typename InputRangeT, typename ResultRangeT>
Expand Down Expand Up @@ -188,6 +190,11 @@ inline OpAsmPrinter &operator<<(OpAsmPrinter &p, bool value) {
return p << (value ? StringRef("true") : "false");
}

inline OpAsmPrinter &operator<<(OpAsmPrinter &p, Block *value) {
p.printSuccessor(value);
return p;
}

template <typename ValueRangeT>
inline OpAsmPrinter &operator<<(OpAsmPrinter &p,
const ValueTypeRange<ValueRangeT> &types) {
Expand Down Expand Up @@ -574,15 +581,16 @@ class OpAsmParser {
// Successor Parsing
//===--------------------------------------------------------------------===//

/// Parse a single operation successor.
virtual ParseResult parseSuccessor(Block *&dest) = 0;

/// Parse an optional operation successor.
virtual OptionalParseResult parseOptionalSuccessor(Block *&dest) = 0;

/// Parse a single operation successor and its operand list.
virtual ParseResult
parseSuccessorAndUseList(Block *&dest, SmallVectorImpl<Value> &operands) = 0;

/// Parse an optional operation successor and its operand list.
virtual OptionalParseResult
parseOptionalSuccessorAndUseList(Block *&dest,
SmallVectorImpl<Value> &operands) = 0;

//===--------------------------------------------------------------------===//
// Type Parsing
//===--------------------------------------------------------------------===//
Expand Down
49 changes: 1 addition & 48 deletions mlir/include/mlir/IR/Operation.h
Original file line number Diff line number Diff line change
Expand Up @@ -374,7 +374,7 @@ class Operation final
}

//===--------------------------------------------------------------------===//
// Terminators
// Successors
//===--------------------------------------------------------------------===//

MutableArrayRef<BlockOperand> getBlockOperands() {
Expand All @@ -387,62 +387,15 @@ class Operation final
succ_iterator successor_end() { return getSuccessors().end(); }
SuccessorRange getSuccessors() { return SuccessorRange(this); }

/// Return the operands of this operation that are *not* successor arguments.
operand_range getNonSuccessorOperands();

operand_range getSuccessorOperands(unsigned index);

Value getSuccessorOperand(unsigned succIndex, unsigned opIndex) {
assert(!isKnownNonTerminator() && "only terminators may have successors");
assert(opIndex < getNumSuccessorOperands(succIndex));
return getOperand(getSuccessorOperandIndex(succIndex) + opIndex);
}

bool hasSuccessors() { return numSuccs != 0; }
unsigned getNumSuccessors() { return numSuccs; }
unsigned getNumSuccessorOperands(unsigned index) {
assert(!isKnownNonTerminator() && "only terminators may have successors");
assert(index < getNumSuccessors());
return getBlockOperands()[index].numSuccessorOperands;
}

Block *getSuccessor(unsigned index) {
assert(index < getNumSuccessors());
return getBlockOperands()[index].get();
}
void setSuccessor(Block *block, unsigned index);

/// Erase a specific operand from the operand list of the successor at
/// 'index'.
void eraseSuccessorOperand(unsigned succIndex, unsigned opIndex) {
assert(succIndex < getNumSuccessors());
assert(opIndex < getNumSuccessorOperands(succIndex));
getOperandStorage().eraseOperand(getSuccessorOperandIndex(succIndex) +
opIndex);
--getBlockOperands()[succIndex].numSuccessorOperands;
}

/// Get the index of the first operand of the successor at the provided
/// index.
unsigned getSuccessorOperandIndex(unsigned index);

/// Return a pair (successorIndex, successorArgIndex) containing the index
/// of the successor that `operandIndex` belongs to and the index of the
/// argument to that successor that `operandIndex` refers to.
///
/// If `operandIndex` is not a successor operand, None is returned.
Optional<std::pair<unsigned, unsigned>>
decomposeSuccessorOperandIndex(unsigned operandIndex);

/// Returns the `BlockArgument` corresponding to operand `operandIndex` in
/// some successor, or None if `operandIndex` isn't a successor operand index.
Optional<BlockArgument> getSuccessorBlockArgument(unsigned operandIndex) {
auto decomposed = decomposeSuccessorOperandIndex(operandIndex);
if (!decomposed.hasValue())
return None;
return getSuccessor(decomposed->first)->getArgument(decomposed->second);
}

//===--------------------------------------------------------------------===//
// Accessors for various properties of operations
//===--------------------------------------------------------------------===//
Expand Down
21 changes: 20 additions & 1 deletion mlir/include/mlir/IR/OperationSupport.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ class Pattern;
class Region;
class ResultRange;
class RewritePattern;
class SuccessorRange;
class Type;
class Value;
class ValueRange;
Expand Down Expand Up @@ -316,7 +317,12 @@ struct OperationState {
attributes.append(newAttributes.begin(), newAttributes.end());
}

void addSuccessor(Block *successor, ValueRange succOperands);
/// Add an array of successors.
void addSuccessors(ArrayRef<Block *> newSuccessors) {
successors.append(newSuccessors.begin(), newSuccessors.end());
}
void addSuccessors(Block *successor) { successors.push_back(successor); }
void addSuccessors(SuccessorRange newSuccessors);

/// Create a region that should be attached to the operation. These regions
/// can be filled in immediately without waiting for Operation to be
Expand Down Expand Up @@ -563,10 +569,19 @@ class TypeRange
explicit TypeRange(OperandRange values);
explicit TypeRange(ResultRange values);
explicit TypeRange(ValueRange values);
explicit TypeRange(ArrayRef<Value> values);
explicit TypeRange(ArrayRef<BlockArgument> values)
: TypeRange(ArrayRef<Value>(values.data(), values.size())) {}
template <typename ValueRangeT>
TypeRange(ValueTypeRange<ValueRangeT> values)
: TypeRange(ValueRangeT(values.begin().getCurrent(),
values.end().getCurrent())) {}
template <typename Arg,
typename = typename std::enable_if_t<
std::is_constructible<ArrayRef<Type>, Arg>::value>>
TypeRange(Arg &&arg) : TypeRange(ArrayRef<Type>(std::forward<Arg>(arg))) {}
TypeRange(std::initializer_list<Type> types)
: TypeRange(ArrayRef<Type>(types)) {}

private:
/// The owner of the range is either:
Expand Down Expand Up @@ -639,6 +654,10 @@ class OperandRange final
type_range getTypes() const { return {begin(), end()}; }
auto getType() const { return getTypes(); }

/// Return the operand index of the first element of this range. The range
/// must not be empty.
unsigned getBeginOperandIndex() const;

private:
/// See `detail::indexed_accessor_range_base` for details.
static OpOperand *offset_base(OpOperand *object, ptrdiff_t index) {
Expand Down
7 changes: 0 additions & 7 deletions mlir/include/mlir/IR/UseDefLists.h
Original file line number Diff line number Diff line change
Expand Up @@ -295,13 +295,6 @@ class BlockOperand : public IROperand<BlockOperand, Block *> {

/// Return which operand this is in the operand list of the User.
unsigned getOperandNumber();

private:
/// The number of OpOperands that correspond with this block operand.
unsigned numSuccessorOperands = 0;

/// Allow access to 'numSuccessorOperands'.
friend Operation;
};

//===----------------------------------------------------------------------===//
Expand Down
66 changes: 1 addition & 65 deletions mlir/include/mlir/Transforms/DialectConversion.h
Original file line number Diff line number Diff line change
Expand Up @@ -227,44 +227,13 @@ class ConversionPattern : public RewritePattern {
/// Hook for derived classes to implement rewriting. `op` is the (first)
/// operation matched by the pattern, `operands` is a list of rewritten values
/// that are passed to this operation, `rewriter` can be used to emit the new
/// operations. This function must be reimplemented if the
/// ConversionPattern ever needs to replace an operation that does not
/// have successors. This function should not fail. If some specific cases of
/// operations. This function should not fail. If some specific cases of
/// the operation are not supported, these cases should not be matched.
virtual void rewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
llvm_unreachable("unimplemented rewrite");
}

/// Hook for derived classes to implement rewriting. `op` is the (first)
/// operation matched by the pattern, `properOperands` is a list of rewritten
/// values that are passed to the operation itself, `destinations` is a list
/// of (potentially rewritten) successor blocks, `operands` is a list of lists
/// of rewritten values passed to each of the successors, co-indexed with
/// `destinations`, `rewriter` can be used to emit the new operations. It must
/// be reimplemented if the ConversionPattern ever needs to replace a
/// terminator operation that has successors. This function should not fail
/// the pass. If some specific cases of the operation are not supported,
/// these cases should not be matched.
virtual void rewrite(Operation *op, ArrayRef<Value> properOperands,
ArrayRef<Block *> destinations,
ArrayRef<ArrayRef<Value>> operands,
ConversionPatternRewriter &rewriter) const {
llvm_unreachable("unimplemented rewrite for terminators");
}

/// Hook for derived classes to implement combined matching and rewriting.
virtual PatternMatchResult
matchAndRewrite(Operation *op, ArrayRef<Value> properOperands,
ArrayRef<Block *> destinations,
ArrayRef<ArrayRef<Value>> operands,
ConversionPatternRewriter &rewriter) const {
if (!match(op))
return matchFailure();
rewrite(op, properOperands, destinations, operands, rewriter);
return matchSuccess();
}

/// Hook for derived classes to implement combined matching and rewriting.
virtual PatternMatchResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
Expand Down Expand Up @@ -297,21 +266,6 @@ struct OpConversionPattern : public ConversionPattern {
ConversionPatternRewriter &rewriter) const final {
rewrite(cast<SourceOp>(op), operands, rewriter);
}
void rewrite(Operation *op, ArrayRef<Value> properOperands,
ArrayRef<Block *> destinations,
ArrayRef<ArrayRef<Value>> operands,
ConversionPatternRewriter &rewriter) const final {
rewrite(cast<SourceOp>(op), properOperands, destinations, operands,
rewriter);
}
PatternMatchResult
matchAndRewrite(Operation *op, ArrayRef<Value> properOperands,
ArrayRef<Block *> destinations,
ArrayRef<ArrayRef<Value>> operands,
ConversionPatternRewriter &rewriter) const final {
return matchAndRewrite(cast<SourceOp>(op), properOperands, destinations,
operands, rewriter);
}
PatternMatchResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
Expand All @@ -328,24 +282,6 @@ struct OpConversionPattern : public ConversionPattern {
llvm_unreachable("must override matchAndRewrite or a rewrite method");
}

virtual void rewrite(SourceOp op, ArrayRef<Value> properOperands,
ArrayRef<Block *> destinations,
ArrayRef<ArrayRef<Value>> operands,
ConversionPatternRewriter &rewriter) const {
llvm_unreachable("unimplemented rewrite for terminators");
}

virtual PatternMatchResult
matchAndRewrite(SourceOp op, ArrayRef<Value> properOperands,
ArrayRef<Block *> destinations,
ArrayRef<ArrayRef<Value>> operands,
ConversionPatternRewriter &rewriter) const {
if (!match(op))
return matchFailure();
rewrite(op, properOperands, destinations, operands, rewriter);
return matchSuccess();
}

virtual PatternMatchResult
matchAndRewrite(SourceOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
Expand Down
4 changes: 4 additions & 0 deletions mlir/lib/Analysis/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ set(LLVM_OPTIONAL_SOURCES
AffineAnalysis.cpp
AffineStructures.cpp
CallGraph.cpp
ControlFlowInterfaces.cpp
Dominance.cpp
InferTypeOpInterface.cpp
Liveness.cpp
Expand All @@ -14,6 +15,7 @@ set(LLVM_OPTIONAL_SOURCES

add_llvm_library(MLIRAnalysis
CallGraph.cpp
ControlFlowInterfaces.cpp
InferTypeOpInterface.cpp
Liveness.cpp
SliceAnalysis.cpp
Expand All @@ -26,6 +28,7 @@ add_llvm_library(MLIRAnalysis
add_dependencies(MLIRAnalysis
MLIRAffineOps
MLIRCallOpInterfacesIncGen
MLIRControlFlowInterfacesIncGen
MLIRTypeInferOpInterfaceIncGen
MLIRLoopOps
)
Expand All @@ -45,6 +48,7 @@ add_llvm_library(MLIRLoopAnalysis
add_dependencies(MLIRLoopAnalysis
MLIRAffineOps
MLIRCallOpInterfacesIncGen
MLIRControlFlowInterfacesIncGen
MLIRTypeInferOpInterfaceIncGen
MLIRLoopOps
)
Expand Down
101 changes: 101 additions & 0 deletions mlir/lib/Analysis/ControlFlowInterfaces.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
//===- ControlFlowInterfaces.h - ControlFlow Interfaces -------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Analysis/ControlFlowInterfaces.h"
#include "mlir/IR/StandardTypes.h"

using namespace mlir;

//===----------------------------------------------------------------------===//
// ControlFlowInterfaces
//===----------------------------------------------------------------------===//

#include "mlir/Analysis/ControlFlowInterfaces.cpp.inc"

//===----------------------------------------------------------------------===//
// BranchOpInterface
//===----------------------------------------------------------------------===//

/// Erase an operand from a branch operation that is used as a successor
/// operand. 'operandIndex' is the operand within 'operands' to be erased.
void mlir::detail::eraseBranchSuccessorOperand(OperandRange operands,
unsigned operandIndex,
Operation *op) {
assert(operandIndex < operands.size() &&
"invalid index for successor operands");

// Erase the operand from the operation.
size_t fullOperandIndex = operands.getBeginOperandIndex() + operandIndex;
op->eraseOperand(fullOperandIndex);

// If this operation has an OperandSegmentSizeAttr, keep it up to date.
auto operandSegmentAttr =
op->getAttrOfType<DenseElementsAttr>("operand_segment_sizes");
if (!operandSegmentAttr)
return;

// Find the segment containing the full operand index and decrement it.
// TODO: This seems like a general utility that could be added somewhere.
SmallVector<int32_t, 4> values(operandSegmentAttr.getValues<int32_t>());
unsigned currentSize = 0;
for (unsigned i = 0, e = values.size(); i != e; ++i) {
currentSize += values[i];
if (fullOperandIndex < currentSize) {
--values[i];
break;
}
}
op->setAttr("operand_segment_sizes",
DenseIntElementsAttr::get(operandSegmentAttr.getType(), values));
}

/// Returns the `BlockArgument` corresponding to operand `operandIndex` in some
/// successor if 'operandIndex' is within the range of 'operands', or None if
/// `operandIndex` isn't a successor operand index.
Optional<BlockArgument> mlir::detail::getBranchSuccessorArgument(
Optional<OperandRange> operands, unsigned operandIndex, Block *successor) {
// Check that the operands are valid.
if (!operands || operands->empty())
return llvm::None;

// Check to ensure that this operand is within the range.
unsigned operandsStart = operands->getBeginOperandIndex();
if (operandIndex < operandsStart ||
operandIndex >= (operandsStart + operands->size()))
return llvm::None;

// Index the successor.
unsigned argIndex = operandIndex - operandsStart;
return successor->getArgument(argIndex);
}

/// Verify that the given operands match those of the given successor block.
LogicalResult
mlir::detail::verifyBranchSuccessorOperands(Operation *op, unsigned succNo,
Optional<OperandRange> operands) {
if (!operands)
return success();

// Check the count.
unsigned operandCount = operands->size();
Block *destBB = op->getSuccessor(succNo);
if (operandCount != destBB->getNumArguments())
return op->emitError() << "branch has " << operandCount
<< " operands for successor #" << succNo
<< ", but target block has "
<< destBB->getNumArguments();

// Check the types.
auto operandIt = operands->begin();
for (unsigned i = 0; i != operandCount; ++i, ++operandIt) {
if ((*operandIt).getType() != destBB->getArgument(i).getType())
return op->emitError() << "type mismatch for bb argument #" << i
<< " of successor #" << succNo;
}
return success();
}
16 changes: 5 additions & 11 deletions mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,7 @@ struct GPUAllReduceOpLowering : public ConvertToLLVMPattern {

// Add branch before inserted body, into body.
block = block->getNextNode();
rewriter.create<LLVM::BrOp>(loc, ArrayRef<Value>{},
llvm::makeArrayRef(block), ValueRange());
rewriter.create<LLVM::BrOp>(loc, ValueRange(), block);

// Replace all gpu.yield ops with branch out of body.
for (; block != split; block = block->getNextNode()) {
Expand All @@ -100,8 +99,7 @@ struct GPUAllReduceOpLowering : public ConvertToLLVMPattern {
continue;
rewriter.setInsertionPointToEnd(block);
rewriter.replaceOpWithNewOp<LLVM::BrOp>(
terminator, ArrayRef<Value>{}, llvm::makeArrayRef(split),
ValueRange(terminator->getOperand(0)));
terminator, terminator->getOperand(0), split);
}

// Return accumulator result.
Expand Down Expand Up @@ -254,13 +252,10 @@ struct GPUAllReduceOpLowering : public ConvertToLLVMPattern {
Block *continueBlock = rewriter.splitBlock(elseBlock, elseBlock->begin());

rewriter.setInsertionPointToEnd(currentBlock);
rewriter.create<LLVM::CondBrOp>(loc, llvm::makeArrayRef(condition),
ArrayRef<Block *>{thenBlock, elseBlock});
rewriter.create<LLVM::CondBrOp>(loc, condition, thenBlock, elseBlock);

auto addBranch = [&](ValueRange operands) {
rewriter.create<LLVM::BrOp>(loc, ArrayRef<Value>{},
llvm::makeArrayRef(continueBlock),
llvm::makeArrayRef(operands));
rewriter.create<LLVM::BrOp>(loc, operands, continueBlock);
};

rewriter.setInsertionPointToStart(thenBlock);
Expand Down Expand Up @@ -645,8 +640,7 @@ struct GPUReturnOpLowering : public ConvertToLLVMPattern {
PatternMatchResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, operands,
ArrayRef<Block *>());
rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, operands);
return matchSuccess();
}
};
Expand Down
61 changes: 23 additions & 38 deletions mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2185,13 +2185,10 @@ struct OneToOneLLVMTerminatorLowering
using Super = OneToOneLLVMTerminatorLowering<SourceOp, TargetOp>;

PatternMatchResult
matchAndRewrite(Operation *op, ArrayRef<Value> properOperands,
ArrayRef<Block *> destinations,
ArrayRef<ArrayRef<Value>> operands,
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
SmallVector<ValueRange, 2> operandRanges(operands.begin(), operands.end());
rewriter.replaceOpWithNewOp<TargetOp>(op, properOperands, destinations,
operandRanges, op->getAttrs());
rewriter.replaceOpWithNewOp<TargetOp>(op, operands, op->getSuccessors(),
op->getAttrs());
return this->matchSuccess();
}
};
Expand All @@ -2213,13 +2210,12 @@ struct ReturnOpLowering : public LLVMLegalizationPattern<ReturnOp> {
// If ReturnOp has 0 or 1 operand, create it and return immediately.
if (numArguments == 0) {
rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(
op, ArrayRef<Value>(), ArrayRef<Block *>(), op->getAttrs());
op, ArrayRef<Type>(), ArrayRef<Value>(), op->getAttrs());
return matchSuccess();
}
if (numArguments == 1) {
rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(
op, ArrayRef<Value>(operands.front()), ArrayRef<Block *>(),
op->getAttrs());
op, ArrayRef<Type>(), operands.front(), op->getAttrs());
return matchSuccess();
}

Expand All @@ -2234,8 +2230,8 @@ struct ReturnOpLowering : public LLVMLegalizationPattern<ReturnOp> {
op->getLoc(), packedType, packed, operands[i],
rewriter.getI64ArrayAttr(i));
}
rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(
op, llvm::makeArrayRef(packed), ArrayRef<Block *>(), op->getAttrs());
rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, ArrayRef<Type>(), packed,
op->getAttrs());
return matchSuccess();
}
};
Expand Down Expand Up @@ -2742,10 +2738,8 @@ struct AtomicCmpXchgOpLowering : public LoadStoreOpLowering<AtomicRMWOp> {
auto memRefType = atomicOp.getMemRefType();
auto dataPtr = getDataPtr(loc, memRefType, adaptor.memref(),
adaptor.indices(), rewriter, getModule());
auto init = rewriter.create<LLVM::LoadOp>(loc, dataPtr);
std::array<Value, 1> brRegionOperands{init};
std::array<ValueRange, 1> brOperands{brRegionOperands};
rewriter.create<LLVM::BrOp>(loc, ArrayRef<Value>{}, loopBlock, brOperands);
Value init = rewriter.create<LLVM::LoadOp>(loc, dataPtr);
rewriter.create<LLVM::BrOp>(loc, init, loopBlock);

// Prepare the body of the loop block.
rewriter.setInsertionPointToStart(loopBlock);
Expand All @@ -2768,19 +2762,14 @@ struct AtomicCmpXchgOpLowering : public LoadStoreOpLowering<AtomicRMWOp> {
loc, pairType, dataPtr, loopArgument, select, successOrdering,
failureOrdering);
// Extract the %new_loaded and %ok values from the pair.
auto newLoaded = rewriter.create<LLVM::ExtractValueOp>(
Value newLoaded = rewriter.create<LLVM::ExtractValueOp>(
loc, valueType, cmpxchg, rewriter.getI64ArrayAttr({0}));
auto ok = rewriter.create<LLVM::ExtractValueOp>(
Value ok = rewriter.create<LLVM::ExtractValueOp>(
loc, boolType, cmpxchg, rewriter.getI64ArrayAttr({1}));

// Conditionally branch to the end or back to the loop depending on %ok.
std::array<Value, 1> condBrProperOperands{ok};
std::array<Block *, 2> condBrDestinations{endBlock, loopBlock};
std::array<Value, 1> condBrRegionOperands{newLoaded};
std::array<ValueRange, 2> condBrOperands{ArrayRef<Value>{},
condBrRegionOperands};
rewriter.create<LLVM::CondBrOp>(loc, condBrProperOperands,
condBrDestinations, condBrOperands);
rewriter.create<LLVM::CondBrOp>(loc, ok, endBlock, ArrayRef<Value>(),
loopBlock, newLoaded);

// The 'result' of the atomic_rmw op is the newly loaded value.
rewriter.replaceOp(op, {newLoaded});
Expand All @@ -2792,7 +2781,9 @@ struct AtomicCmpXchgOpLowering : public LoadStoreOpLowering<AtomicRMWOp> {
} // namespace

static void ensureDistinctSuccessors(Block &bb) {
auto *terminator = bb.getTerminator();
Operation *terminator = bb.getTerminator();
if (terminator->getNumSuccessors() < 2)
return;

// Find repeated successors with arguments.
llvm::SmallDenseMap<Block *, SmallVector<int, 4>> successorPositions;
Expand All @@ -2811,21 +2802,15 @@ static void ensureDistinctSuccessors(Block &bb) {
// There is no need to pass arguments to the dummy block because it will be
// dominated by the original block and can therefore use any values defined in
// the original block.
OpBuilder builder(terminator->getContext());
for (const auto &successor : successorPositions) {
const auto &positions = successor.second;
// Start from the second occurrence of a block in the successor list.
for (auto position = std::next(positions.begin()), end = positions.end();
position != end; ++position) {
auto *dummyBlock = new Block();
bb.getParent()->push_back(dummyBlock);
auto builder = OpBuilder(dummyBlock);
SmallVector<Value, 8> operands(
terminator->getSuccessorOperands(*position));
builder.create<BranchOp>(terminator->getLoc(), successor.first, operands);
terminator->setSuccessor(dummyBlock, *position);
for (int i = 0, e = terminator->getNumSuccessorOperands(*position); i < e;
++i)
terminator->eraseSuccessorOperand(*position, i);
for (int position : llvm::drop_begin(successor.second, 1)) {
Block *dummyBlock = builder.createBlock(bb.getParent());
terminator->setSuccessor(dummyBlock, position);
dummyBlock->addArguments(successor.first->getArgumentTypes());
builder.create<BranchOp>(terminator->getLoc(), successor.first,
dummyBlock->getArguments());
}
}
}
Expand Down
5 changes: 2 additions & 3 deletions mlir/lib/Dialect/AffineOps/AffineOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1621,9 +1621,8 @@ static LogicalResult verify(AffineIfOp op) {
"symbol count must match");

// Verify that the operands are valid dimension/symbols.
if (failed(verifyDimAndSymbolIdentifiers(
op, op.getOperation()->getNonSuccessorOperands(),
condition.getNumDims())))
if (failed(verifyDimAndSymbolIdentifiers(op, op.getOperands(),
condition.getNumDims())))
return failure();

// Verify that the entry of each child region does not have arguments.
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Dialect/LLVMIR/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ add_mlir_dialect_library(MLIRLLVMIR
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/LLVMIR
)
add_dependencies(MLIRLLVMIR MLIRLLVMOpsIncGen MLIRLLVMConversionsIncGen MLIROpenMP LLVMFrontendOpenMP LLVMAsmParser LLVMCore LLVMSupport)
add_dependencies(MLIRLLVMIR MLIRControlFlowInterfacesIncGen MLIRLLVMOpsIncGen MLIRLLVMConversionsIncGen MLIROpenMP LLVMFrontendOpenMP LLVMAsmParser LLVMCore LLVMSupport)
target_link_libraries(MLIRLLVMIR LLVMAsmParser LLVMCore LLVMSupport LLVMFrontendOpenMP MLIROpenMP MLIRIR)

add_mlir_dialect_library(MLIRNVVMIR
Expand Down
67 changes: 51 additions & 16 deletions mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,28 @@ static ParseResult parseAllocaOp(OpAsmParser &parser, OperationState &result) {
return success();
}

//===----------------------------------------------------------------------===//
// LLVM::BrOp
//===----------------------------------------------------------------------===//

Optional<OperandRange> BrOp::getSuccessorOperands(unsigned index) {
assert(index == 0 && "invalid successor index");
return getOperands();
}

bool BrOp::canEraseSuccessorOperand() { return true; }

//===----------------------------------------------------------------------===//
// LLVM::CondBrOp
//===----------------------------------------------------------------------===//

Optional<OperandRange> CondBrOp::getSuccessorOperands(unsigned index) {
assert(index < getNumSuccessors() && "invalid successor index");
return index == 0 ? trueDestOperands() : falseDestOperands();
}

bool CondBrOp::canEraseSuccessorOperand() { return true; }

//===----------------------------------------------------------------------===//
// Printing/parsing for LLVM::LoadOp.
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -228,9 +250,16 @@ static ParseResult parseStoreOp(OpAsmParser &parser, OperationState &result) {
return success();
}

///===----------------------------------------------------------------------===//
/// Verifying/Printing/Parsing for LLVM::InvokeOp.
///===----------------------------------------------------------------------===//
///===---------------------------------------------------------------------===//
/// LLVM::InvokeOp
///===---------------------------------------------------------------------===//

Optional<OperandRange> InvokeOp::getSuccessorOperands(unsigned index) {
assert(index < getNumSuccessors() && "invalid successor index");
return index == 0 ? normalDestOperands() : unwindDestOperands();
}

bool InvokeOp::canEraseSuccessorOperand() { return true; }

static LogicalResult verify(InvokeOp op) {
if (op.getNumResults() > 1)
Expand All @@ -249,7 +278,7 @@ static LogicalResult verify(InvokeOp op) {
return success();
}

static void printInvokeOp(OpAsmPrinter &p, InvokeOp &op) {
static void printInvokeOp(OpAsmPrinter &p, InvokeOp op) {
auto callee = op.callee();
bool isDirect = callee.hasValue();

Expand All @@ -263,17 +292,16 @@ static void printInvokeOp(OpAsmPrinter &p, InvokeOp &op) {

p << '(' << op.getOperands().drop_front(isDirect ? 0 : 1) << ')';
p << " to ";
p.printSuccessorAndUseList(op.getOperation(), 0);
p.printSuccessorAndUseList(op.normalDest(), op.normalDestOperands());
p << " unwind ";
p.printSuccessorAndUseList(op.getOperation(), 1);
p.printSuccessorAndUseList(op.unwindDest(), op.unwindDestOperands());

p.printOptionalAttrDict(op.getAttrs(), {"callee"});

SmallVector<Type, 8> argTypes(
llvm::drop_begin(op.getOperandTypes(), isDirect ? 0 : 1));

p << " : "
<< FunctionType::get(argTypes, op.getResultTypes(), op.getContext());
p.printOptionalAttrDict(op.getAttrs(),
{InvokeOp::getOperandSegmentSizeAttr(), "callee"});
p << " : ";
p.printFunctionalType(
llvm::drop_begin(op.getOperandTypes(), isDirect ? 0 : 1),
op.getResultTypes());
}

/// <operation> ::= `llvm.invoke` (function-id | ssa-use) `(` ssa-use-list `)`
Expand All @@ -287,6 +315,7 @@ static ParseResult parseInvokeOp(OpAsmParser &parser, OperationState &result) {
llvm::SMLoc trailingTypeLoc;
Block *normalDest, *unwindDest;
SmallVector<Value, 4> normalOperands, unwindOperands;
Builder &builder = parser.getBuilder();

// Parse an operand list that will, in practice, contain 0 or 1 operand. In
// case of an indirect call, there will be 1 operand before `(`. In case of a
Expand Down Expand Up @@ -322,7 +351,6 @@ static ParseResult parseInvokeOp(OpAsmParser &parser, OperationState &result) {
return parser.emitError(trailingTypeLoc,
"expected function with 0 or 1 result");

Builder &builder = parser.getBuilder();
auto *llvmDialect =
builder.getContext()->getRegisteredDialect<LLVM::LLVMDialect>();
LLVM::LLVMType llvmResultType;
Expand Down Expand Up @@ -361,8 +389,15 @@ static ParseResult parseInvokeOp(OpAsmParser &parser, OperationState &result) {

result.addTypes(llvmResultType);
}
result.addSuccessor(normalDest, normalOperands);
result.addSuccessor(unwindDest, unwindOperands);
result.addSuccessors({normalDest, unwindDest});
result.addOperands(normalOperands);
result.addOperands(unwindOperands);

result.addAttribute(
InvokeOp::getOperandSegmentSizeAttr(),
builder.getI32VectorAttr({static_cast<int32_t>(operands.size()),
static_cast<int32_t>(normalOperands.size()),
static_cast<int32_t>(unwindOperands.size())}));
return success();
}

Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Dialect/SPIRV/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ add_mlir_dialect_library(MLIRSPIRV
)

add_dependencies(MLIRSPIRV
MLIRControlFlowInterfacesIncGen
MLIRSPIRVAvailabilityIncGen
MLIRSPIRVCanonicalizationIncGen
MLIRSPIRVEnumAvailabilityIncGen
Expand Down
48 changes: 37 additions & 11 deletions mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -942,16 +942,35 @@ static LogicalResult verify(spirv::BitcastOp bitcastOp) {
return success();
}

//===----------------------------------------------------------------------===//
// spv.BranchOp
//===----------------------------------------------------------------------===//

Optional<OperandRange> spirv::BranchOp::getSuccessorOperands(unsigned index) {
assert(index == 0 && "invalid successor index");
return getOperands();
}

bool spirv::BranchOp::canEraseSuccessorOperand() { return true; }

//===----------------------------------------------------------------------===//
// spv.BranchConditionalOp
//===----------------------------------------------------------------------===//

Optional<OperandRange>
spirv::BranchConditionalOp::getSuccessorOperands(unsigned index) {
assert(index < 2 && "invalid successor index");
return index == kTrueIndex ? getTrueBlockArguments()
: getFalseBlockArguments();
}

bool spirv::BranchConditionalOp::canEraseSuccessorOperand() { return true; }

static ParseResult parseBranchConditionalOp(OpAsmParser &parser,
OperationState &state) {
auto &builder = parser.getBuilder();
OpAsmParser::OperandType condInfo;
Block *dest;
SmallVector<Value, 4> destOperands;

// Parse the condition.
Type boolTy = builder.getI1Type();
Expand All @@ -976,17 +995,24 @@ static ParseResult parseBranchConditionalOp(OpAsmParser &parser,
}

// Parse the true branch.
SmallVector<Value, 4> trueOperands;
if (parser.parseComma() ||
parser.parseSuccessorAndUseList(dest, destOperands))
parser.parseSuccessorAndUseList(dest, trueOperands))
return failure();
state.addSuccessor(dest, destOperands);
state.addSuccessors(dest);
state.addOperands(trueOperands);

// Parse the false branch.
destOperands.clear();
SmallVector<Value, 4> falseOperands;
if (parser.parseComma() ||
parser.parseSuccessorAndUseList(dest, destOperands))
parser.parseSuccessorAndUseList(dest, falseOperands))
return failure();
state.addSuccessor(dest, destOperands);
state.addSuccessors(dest);
state.addOperands(falseOperands);
state.addAttribute(
spirv::BranchConditionalOp::getOperandSegmentSizeAttr(),
builder.getI32VectorAttr({1, static_cast<int32_t>(trueOperands.size()),
static_cast<int32_t>(falseOperands.size())}));

return success();
}
Expand All @@ -1004,11 +1030,11 @@ static void print(spirv::BranchConditionalOp branchOp, OpAsmPrinter &printer) {
}

printer << ", ";
printer.printSuccessorAndUseList(branchOp.getOperation(),
spirv::BranchConditionalOp::kTrueIndex);
printer.printSuccessorAndUseList(branchOp.getTrueBlock(),
branchOp.getTrueBlockArguments());
printer << ", ";
printer.printSuccessorAndUseList(branchOp.getOperation(),
spirv::BranchConditionalOp::kFalseIndex);
printer.printSuccessorAndUseList(branchOp.getFalseBlock(),
branchOp.getFalseBlockArguments());
}

static LogicalResult verify(spirv::BranchConditionalOp branchOp) {
Expand Down Expand Up @@ -1894,7 +1920,7 @@ static inline bool hasOneBranchOpTo(Block &srcBlock, Block &dstBlock) {
return false;

auto branchOp = dyn_cast<spirv::BranchOp>(srcBlock.back());
return branchOp && branchOp.getSuccessor(0) == &dstBlock;
return branchOp && branchOp.getSuccessor() == &dstBlock;
}

static LogicalResult verify(spirv::LoopOp loopOp) {
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Dialect/StandardOps/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ add_mlir_dialect_library(MLIRStandardOps
add_dependencies(MLIRStandardOps

MLIRCallOpInterfacesIncGen
MLIRControlFlowInterfacesIncGen
MLIREDSC
MLIRIR
MLIRStandardOpsIncGen
Expand Down
67 changes: 18 additions & 49 deletions mlir/lib/Dialect/StandardOps/IR/Ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -477,19 +477,26 @@ struct SimplifyBrToBlockWithSinglePred : public OpRewritePattern<BranchOp> {
};
} // end anonymous namespace.

Block *BranchOp::getDest() { return getSuccessor(0); }
Block *BranchOp::getDest() { return getSuccessor(); }

void BranchOp::setDest(Block *block) { return setSuccessor(block, 0); }
void BranchOp::setDest(Block *block) { return setSuccessor(block); }

void BranchOp::eraseOperand(unsigned index) {
getOperation()->eraseSuccessorOperand(0, index);
getOperation()->eraseOperand(index);
}

void BranchOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {
results.insert<SimplifyBrToBlockWithSinglePred>(context);
}

Optional<OperandRange> BranchOp::getSuccessorOperands(unsigned index) {
assert(index == 0 && "invalid successor index");
return getOperands();
}

bool BranchOp::canEraseSuccessorOperand() { return true; }

//===----------------------------------------------------------------------===//
// CallOp
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -749,6 +756,13 @@ void CondBranchOp::getCanonicalizationPatterns(
results.insert<SimplifyConstCondBranchPred>(context);
}

Optional<OperandRange> CondBranchOp::getSuccessorOperands(unsigned index) {
assert(index < getNumSuccessors() && "invalid successor index");
return index == trueIndex ? getTrueOperands() : getFalseOperands();
}

bool CondBranchOp::canEraseSuccessorOperand() { return true; }

//===----------------------------------------------------------------------===//
// Constant*Op
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -1811,10 +1825,7 @@ void mlir::SubViewOp::build(Builder *b, OperationState &result, Value source,
ArrayRef<NamedAttribute> attrs) {
if (!resultType)
resultType = inferSubViewResultType(source.getType().cast<MemRefType>());
auto segmentAttr = b->getI32VectorAttr(
{1, static_cast<int>(offsets.size()), static_cast<int32_t>(sizes.size()),
static_cast<int32_t>(strides.size())});
build(b, result, resultType, source, offsets, sizes, strides, segmentAttr);
build(b, result, resultType, source, offsets, sizes, strides);
result.addAttributes(attrs);
}

Expand All @@ -1824,48 +1835,6 @@ void mlir::SubViewOp::build(Builder *b, OperationState &result, Type resultType,
resultType);
}

static ParseResult parseSubViewOp(OpAsmParser &parser, OperationState &result) {
OpAsmParser::OperandType srcInfo;
SmallVector<OpAsmParser::OperandType, 4> offsetsInfo;
SmallVector<OpAsmParser::OperandType, 4> sizesInfo;
SmallVector<OpAsmParser::OperandType, 4> stridesInfo;
auto indexType = parser.getBuilder().getIndexType();
Type srcType, dstType;
if (parser.parseOperand(srcInfo) ||
parser.parseOperandList(offsetsInfo, OpAsmParser::Delimiter::Square) ||
parser.parseOperandList(sizesInfo, OpAsmParser::Delimiter::Square) ||
parser.parseOperandList(stridesInfo, OpAsmParser::Delimiter::Square)) {
return failure();
}

auto builder = parser.getBuilder();
result.addAttribute(
SubViewOp::getOperandSegmentSizeAttr(),
builder.getI32VectorAttr({1, static_cast<int>(offsetsInfo.size()),
static_cast<int32_t>(sizesInfo.size()),
static_cast<int32_t>(stridesInfo.size())}));

return failure(
parser.parseOptionalAttrDict(result.attributes) ||
parser.parseColonType(srcType) ||
parser.resolveOperand(srcInfo, srcType, result.operands) ||
parser.resolveOperands(offsetsInfo, indexType, result.operands) ||
parser.resolveOperands(sizesInfo, indexType, result.operands) ||
parser.resolveOperands(stridesInfo, indexType, result.operands) ||
parser.parseKeywordType("to", dstType) ||
parser.addTypeToList(dstType, result.types));
}

static void print(OpAsmPrinter &p, SubViewOp op) {
p << op.getOperationName() << ' ' << op.getOperand(0) << '[' << op.offsets()
<< "][" << op.sizes() << "][" << op.strides() << ']';

std::array<StringRef, 1> elidedAttrs = {
SubViewOp::getOperandSegmentSizeAttr()};
p.printOptionalAttrDict(op.getAttrs(), elidedAttrs);
p << " : " << op.getOperand(0).getType() << " to " << op.getType();
}

static LogicalResult verify(SubViewOp op) {
auto baseType = op.getBaseMemRefType().cast<MemRefType>();
auto subViewType = op.getType();
Expand Down
52 changes: 0 additions & 52 deletions mlir/lib/Dialect/VectorOps/VectorOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -963,58 +963,6 @@ static LogicalResult verify(OuterProductOp op) {
// ReshapeOp
//===----------------------------------------------------------------------===//

static void print(OpAsmPrinter &p, ReshapeOp op) {
p << op.getOperationName() << " " << op.vector() << ", [" << op.input_shape()
<< "], [" << op.output_shape() << "], " << op.fixed_vector_sizes();
std::array<StringRef, 2> elidedAttrs = {
ReshapeOp::getOperandSegmentSizeAttr(),
ReshapeOp::getFixedVectorSizesAttrName()};
p.printOptionalAttrDict(op.getAttrs(), elidedAttrs);
p << " : " << op.getInputVectorType() << " to " << op.getOutputVectorType();
}

// TODO(b/146516564) Consider passing number of inner vector dimensions that
// are fixed, instead of their values in 'fixesVectorSizes' array attr.
//
// operation ::= ssa-id `=` `vector.reshape` ssa-use, `[` ssa-use-list `]`,
// `[` ssa-use-list `]`, `[` array-attribute `]`
// `:` vector-type 'to' vector-type
//
static ParseResult parseReshapeOp(OpAsmParser &parser, OperationState &result) {
OpAsmParser::OperandType inputInfo;
SmallVector<OpAsmParser::OperandType, 4> inputShapeInfo;
SmallVector<OpAsmParser::OperandType, 4> outputShapeInfo;
ArrayAttr fixedVectorSizesAttr;
StringRef attrName = ReshapeOp::getFixedVectorSizesAttrName();
auto indexType = parser.getBuilder().getIndexType();
if (parser.parseOperand(inputInfo) || parser.parseComma() ||
parser.parseOperandList(inputShapeInfo, OpAsmParser::Delimiter::Square) ||
parser.parseComma() ||
parser.parseOperandList(outputShapeInfo,
OpAsmParser::Delimiter::Square) ||
parser.parseComma()) {
return failure();
}

auto builder = parser.getBuilder();
result.addAttribute(
ReshapeOp::getOperandSegmentSizeAttr(),
builder.getI32VectorAttr({1, static_cast<int32_t>(inputShapeInfo.size()),
static_cast<int32_t>(outputShapeInfo.size())}));
Type inputType;
Type outputType;
return failure(
parser.parseAttribute(fixedVectorSizesAttr, attrName,
result.attributes) ||
parser.parseOptionalAttrDict(result.attributes) ||
parser.parseColonType(inputType) ||
parser.resolveOperand(inputInfo, inputType, result.operands) ||
parser.resolveOperands(inputShapeInfo, indexType, result.operands) ||
parser.resolveOperands(outputShapeInfo, indexType, result.operands) ||
parser.parseKeywordType("to", outputType) ||
parser.addTypeToList(outputType, result.types));
}

static LogicalResult verify(ReshapeOp op) {
// Verify that rank(numInputs/outputs) + numFixedVec dim matches vec rank.
auto inputVectorType = op.getInputVectorType();
Expand Down
35 changes: 16 additions & 19 deletions mlir/lib/IR/AsmPrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1964,9 +1964,13 @@ class OperationPrinter : public ModulePrinter, private OpAsmPrinter {
/*withKeyword=*/true);
}

/// Print the given successor.
void printSuccessor(Block *successor) override;

/// Print an operation successor with the operands used for the block
/// arguments.
void printSuccessorAndUseList(Operation *term, unsigned index) override;
void printSuccessorAndUseList(Block *successor,
ValueRange succOperands) override;

/// Print the given region.
void printRegion(Region &region, bool printEntryBlockArgs,
Expand Down Expand Up @@ -2062,23 +2066,14 @@ void OperationPrinter::printGenericOp(Operation *op) {
os << '"';
printEscapedString(op->getName().getStringRef(), os);
os << "\"(";

// Get the list of operands that are not successor operands.
unsigned totalNumSuccessorOperands = 0;
unsigned numSuccessors = op->getNumSuccessors();
for (unsigned i = 0; i < numSuccessors; ++i)
totalNumSuccessorOperands += op->getNumSuccessorOperands(i);
unsigned numProperOperands = op->getNumOperands() - totalNumSuccessorOperands;
interleaveComma(op->getOperands().take_front(numProperOperands),
[&](Value value) { printValueID(value); });

interleaveComma(op->getOperands(), [&](Value value) { printValueID(value); });
os << ')';

// For terminators, print the list of successors and their operands.
if (numSuccessors != 0) {
if (op->getNumSuccessors() != 0) {
os << '[';
interleaveComma(llvm::seq<unsigned>(0, numSuccessors),
[&](unsigned i) { printSuccessorAndUseList(op, i); });
interleaveComma(op->getSuccessors(),
[&](Block *successor) { printBlockName(successor); });
os << ']';
}

Expand Down Expand Up @@ -2167,12 +2162,14 @@ void OperationPrinter::printValueID(Value value, bool printResultNo) const {
state->getSSANameState().printValueID(value, printResultNo, os);
}

void OperationPrinter::printSuccessorAndUseList(Operation *term,
unsigned index) {
printBlockName(term->getSuccessor(index));
void OperationPrinter::printSuccessor(Block *successor) {
printBlockName(successor);
}

auto succOperands = term->getSuccessorOperands(index);
if (succOperands.begin() == succOperands.end())
void OperationPrinter::printSuccessorAndUseList(Block *successor,
ValueRange succOperands) {
printBlockName(successor);
if (succOperands.empty())
return;

os << '(';
Expand Down
31 changes: 10 additions & 21 deletions mlir/lib/IR/Block.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -143,20 +143,23 @@ void Block::recomputeOpOrder() {
// Argument list management.
//===----------------------------------------------------------------------===//

/// Return a range containing the types of the arguments for this block.
auto Block::getArgumentTypes() -> ValueTypeRange<BlockArgListType> {
return ValueTypeRange<BlockArgListType>(getArguments());
}

BlockArgument Block::addArgument(Type type) {
BlockArgument arg = BlockArgument::create(type, this);
arguments.push_back(arg);
return arg;
}

/// Add one argument to the argument list for each type specified in the list.
auto Block::addArguments(ArrayRef<Type> types)
-> iterator_range<args_iterator> {
arguments.reserve(arguments.size() + types.size());
auto initialSize = arguments.size();
for (auto type : types) {
auto Block::addArguments(TypeRange types) -> iterator_range<args_iterator> {
size_t initialSize = arguments.size();
arguments.reserve(initialSize + types.size());
for (auto type : types)
addArgument(type);
}
return {arguments.data() + initialSize, arguments.data() + arguments.size()};
}

Expand All @@ -167,22 +170,8 @@ BlockArgument Block::insertArgument(unsigned index, Type type) {
return arg;
}

void Block::eraseArgument(unsigned index, bool updatePredTerms) {
void Block::eraseArgument(unsigned index) {
assert(index < arguments.size());

// If requested, update predecessors. We do this first since this block might
// be a predecessor of itself and use this block argument as a successor
// operand.
if (updatePredTerms) {
// Erase this argument from each of the predecessor's terminator.
for (auto predIt = pred_begin(), predE = pred_end(); predIt != predE;
++predIt) {
auto *predTerminator = (*predIt)->getTerminator();
predTerminator->eraseSuccessorOperand(predIt.getSuccessorIndex(), index);
}
}

// Delete the argument.
arguments[index].destroy();
arguments.erase(arguments.begin() + index);
}
Expand Down
215 changes: 58 additions & 157 deletions mlir/lib/IR/Operation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -111,14 +111,10 @@ Operation *Operation::create(Location location, OperationName name,
NamedAttributeList attributes,
ArrayRef<Block *> successors, unsigned numRegions,
bool resizableOperandList) {
unsigned numSuccessors = successors.size();

// We only need to allocate additional memory for a subset of results.
unsigned numTrailingResults = OpResult::getNumTrailing(resultTypes.size());

// Input operands are nullptr-separated for each successor, the null operands
// aren't actually stored.
unsigned numOperands = operands.size() - numSuccessors;
unsigned numSuccessors = successors.size();
unsigned numOperands = operands.size();

// Compute the byte size for the operation and the operand storage.
auto byteSize = totalSizeToAlloc<detail::TrailingOpResult, BlockOperand,
Expand Down Expand Up @@ -152,56 +148,17 @@ Operation *Operation::create(Location location, OperationName name,
for (unsigned i = 0; i != numRegions; ++i)
new (&op->getRegion(i)) Region(op);

// Initialize the results and operands.
// Initialize the operands.
new (&op->getOperandStorage())
detail::OperandStorage(numOperands, resizableOperandList);
auto opOperands = op->getOpOperands();
for (unsigned i = 0; i != numOperands; ++i)
new (&opOperands[i]) OpOperand(op, operands[i]);

// Initialize normal operands.
unsigned operandIt = 0, operandE = operands.size();
unsigned nextOperand = 0;
for (; operandIt != operandE; ++operandIt) {
// Null operands are used as sentinels between successor operand lists. If
// we encounter one here, break and handle the successor operands lists
// separately below.
if (!operands[operandIt])
break;
new (&opOperands[nextOperand++]) OpOperand(op, operands[operandIt]);
}

unsigned currentSuccNum = 0;
if (operandIt == operandE) {
// Verify that the amount of sentinel operands is equivalent to the number
// of successors.
assert(currentSuccNum == numSuccessors);
return op;
}

assert(!op->isKnownNonTerminator() &&
"Unexpected nullptr in operand list when creating non-terminator.");
auto instBlockOperands = op->getBlockOperands();
unsigned *succOperandCount = nullptr;

for (; operandIt != operandE; ++operandIt) {
// If we encounter a sentinel branch to the next operand update the count
// variable.
if (!operands[operandIt]) {
assert(currentSuccNum < numSuccessors);

new (&instBlockOperands[currentSuccNum])
BlockOperand(op, successors[currentSuccNum]);
succOperandCount =
&instBlockOperands[currentSuccNum].numSuccessorOperands;
++currentSuccNum;
continue;
}
new (&opOperands[nextOperand++]) OpOperand(op, operands[operandIt]);
++(*succOperandCount);
}

// Verify that the amount of sentinel operands is equivalent to the number of
// successors.
assert(currentSuccNum == numSuccessors);
// Initialize the successors.
auto blockOperands = op->getBlockOperands();
for (unsigned i = 0; i != numSuccessors; ++i)
new (&blockOperands[i]) BlockOperand(op, successors[i]);

return op;
}
Expand Down Expand Up @@ -564,49 +521,6 @@ void Operation::setSuccessor(Block *block, unsigned index) {
getBlockOperands()[index].set(block);
}

auto Operation::getNonSuccessorOperands() -> operand_range {
return getOperands().take_front(hasSuccessors() ? getSuccessorOperandIndex(0)
: getNumOperands());
}

/// Get the index of the first operand of the successor at the provided
/// index.
unsigned Operation::getSuccessorOperandIndex(unsigned index) {
assert(!isKnownNonTerminator() && "only terminators may have successors");
assert(index < getNumSuccessors());

// Count the number of operands for each of the successors after, and
// including, the one at 'index'. This is based upon the assumption that all
// non successor operands are placed at the beginning of the operand list.
auto blockOperands = getBlockOperands().drop_front(index);
unsigned postSuccessorOpCount =
std::accumulate(blockOperands.begin(), blockOperands.end(), 0u,
[](unsigned cur, const BlockOperand &operand) {
return cur + operand.numSuccessorOperands;
});
return getNumOperands() - postSuccessorOpCount;
}

Optional<std::pair<unsigned, unsigned>>
Operation::decomposeSuccessorOperandIndex(unsigned operandIndex) {
assert(!isKnownNonTerminator() && "only terminators may have successors");
assert(operandIndex < getNumOperands());
unsigned currentOperandIndex = getNumOperands();
auto blockOperands = getBlockOperands();
for (unsigned i = 0, e = getNumSuccessors(); i < e; i++) {
unsigned successorIndex = e - i - 1;
currentOperandIndex -= blockOperands[successorIndex].numSuccessorOperands;
if (currentOperandIndex <= operandIndex)
return std::make_pair(successorIndex, operandIndex - currentOperandIndex);
}
return None;
}

auto Operation::getSuccessorOperands(unsigned index) -> operand_range {
unsigned succOperandIndex = getSuccessorOperandIndex(index);
return getOperands().slice(succOperandIndex, getNumSuccessorOperands(index));
}

/// Attempt to fold this operation using the Op's registered foldHook.
LogicalResult Operation::fold(ArrayRef<Attribute> operands,
SmallVectorImpl<OpFoldResult> &results) {
Expand Down Expand Up @@ -645,39 +559,20 @@ Operation *Operation::cloneWithoutRegions(BlockAndValueMapping &mapper) {
SmallVector<Value, 8> operands;
SmallVector<Block *, 2> successors;

operands.reserve(getNumOperands() + getNumSuccessors());

if (getNumSuccessors() == 0) {
// Non-branching operations can just add all the operands.
for (auto opValue : getOperands())
operands.push_back(mapper.lookupOrDefault(opValue));
} else {
// We add the operands separated by nullptr's for each successor.
unsigned firstSuccOperand =
getNumSuccessors() ? getSuccessorOperandIndex(0) : getNumOperands();
auto opOperands = getOpOperands();

unsigned i = 0;
for (; i != firstSuccOperand; ++i)
operands.push_back(mapper.lookupOrDefault(opOperands[i].get()));

successors.reserve(getNumSuccessors());
for (unsigned succ = 0, e = getNumSuccessors(); succ != e; ++succ) {
successors.push_back(mapper.lookupOrDefault(getSuccessor(succ)));
// Remap the operands.
operands.reserve(getNumOperands());
for (auto opValue : getOperands())
operands.push_back(mapper.lookupOrDefault(opValue));

// Add sentinel to delineate successor operands.
operands.push_back(nullptr);
// Remap the successors.
successors.reserve(getNumSuccessors());
for (Block *successor : getSuccessors())
successors.push_back(mapper.lookupOrDefault(successor));

// Remap the successors operands.
for (auto operand : getSuccessorOperands(succ))
operands.push_back(mapper.lookupOrDefault(operand));
}
}

unsigned numRegions = getNumRegions();
auto *newOp =
Operation::create(getLoc(), getName(), getResultTypes(), operands, attrs,
successors, numRegions, hasResizableOperandsList());
// Create the new operation.
auto *newOp = Operation::create(getLoc(), getName(), getResultTypes(),
operands, attrs, successors, getNumRegions(),
hasResizableOperandsList());

// Remember the mapping of any results.
for (unsigned i = 0, e = getNumResults(); i != e; ++i)
Expand Down Expand Up @@ -942,52 +837,58 @@ LogicalResult OpTrait::impl::verifySameOperandsAndResultType(Operation *op) {
return success();
}

static LogicalResult verifySuccessor(Operation *op, unsigned succNo) {
Operation::operand_range operands = op->getSuccessorOperands(succNo);
unsigned operandCount = op->getNumSuccessorOperands(succNo);
Block *destBB = op->getSuccessor(succNo);
if (operandCount != destBB->getNumArguments())
return op->emitError() << "branch has " << operandCount
<< " operands for successor #" << succNo
<< ", but target block has "
<< destBB->getNumArguments();

auto operandIt = operands.begin();
for (unsigned i = 0, e = operandCount; i != e; ++i, ++operandIt) {
if ((*operandIt).getType() != destBB->getArgument(i).getType())
return op->emitError() << "type mismatch for bb argument #" << i
<< " of successor #" << succNo;
}

LogicalResult OpTrait::impl::verifyIsTerminator(Operation *op) {
Block *block = op->getBlock();
// Verify that the operation is at the end of the respective parent block.
if (!block || &block->back() != op)
return op->emitOpError("must be the last operation in the parent block");
return success();
}

static LogicalResult verifyTerminatorSuccessors(Operation *op) {
auto *parent = op->getParentRegion();

// Verify that the operands lines up with the BB arguments in the successor.
for (unsigned i = 0, e = op->getNumSuccessors(); i != e; ++i) {
auto *succ = op->getSuccessor(i);
for (Block *succ : op->getSuccessors())
if (succ->getParent() != parent)
return op->emitError("reference to block defined in another region");
if (failed(verifySuccessor(op, i)))
return failure();
}
return success();
}

LogicalResult OpTrait::impl::verifyIsTerminator(Operation *op) {
Block *block = op->getBlock();
// Verify that the operation is at the end of the respective parent block.
if (!block || &block->back() != op)
return op->emitOpError("must be the last operation in the parent block");

// Verify the state of the successor blocks.
if (op->getNumSuccessors() != 0 && failed(verifyTerminatorSuccessors(op)))
return failure();
LogicalResult OpTrait::impl::verifyZeroSuccessor(Operation *op) {
if (op->getNumSuccessors() != 0) {
return op->emitOpError("requires 0 successors but found ")
<< op->getNumSuccessors();
}
return success();
}

LogicalResult OpTrait::impl::verifyOneSuccessor(Operation *op) {
if (op->getNumSuccessors() != 1) {
return op->emitOpError("requires 1 successor but found ")
<< op->getNumSuccessors();
}
return verifyTerminatorSuccessors(op);
}
LogicalResult OpTrait::impl::verifyNSuccessors(Operation *op,
unsigned numSuccessors) {
if (op->getNumSuccessors() != numSuccessors) {
return op->emitOpError("requires ")
<< numSuccessors << " successors but found "
<< op->getNumSuccessors();
}
return verifyTerminatorSuccessors(op);
}
LogicalResult OpTrait::impl::verifyAtLeastNSuccessors(Operation *op,
unsigned numSuccessors) {
if (op->getNumSuccessors() < numSuccessors) {
return op->emitOpError("requires at least ")
<< numSuccessors << " successors but found "
<< op->getNumSuccessors();
}
return verifyTerminatorSuccessors(op);
}

LogicalResult OpTrait::impl::verifyResultsAreBoolLike(Operation *op) {
for (auto resultType : op->getResultTypes()) {
auto elementType = getTensorOrVectorElementType(resultType);
Expand Down
17 changes: 11 additions & 6 deletions mlir/lib/IR/OperationSupport.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,15 +42,11 @@ OperationState::OperationState(Location location, StringRef name,
}

void OperationState::addOperands(ValueRange newOperands) {
assert(successors.empty() && "Non successor operands should be added first.");
operands.append(newOperands.begin(), newOperands.end());
}

void OperationState::addSuccessor(Block *successor, ValueRange succOperands) {
successors.push_back(successor);
// Insert a sentinel operand to mark a barrier between successor operands.
operands.push_back(nullptr);
operands.append(succOperands.begin(), succOperands.end());
void OperationState::addSuccessors(SuccessorRange newSuccessors) {
successors.append(newSuccessors.begin(), newSuccessors.end());
}

Region *OperationState::addRegion() {
Expand Down Expand Up @@ -150,6 +146,8 @@ TypeRange::TypeRange(OperandRange values)
TypeRange::TypeRange(ResultRange values)
: TypeRange(values.getBase()->getResultTypes().slice(values.getStartIndex(),
values.size())) {}
TypeRange::TypeRange(ArrayRef<Value> values)
: TypeRange(values.data(), values.size()) {}
TypeRange::TypeRange(ValueRange values) : TypeRange(OwnerT(), values.size()) {
detail::ValueRangeOwner owner = values.begin().getBase();
if (auto *op = reinterpret_cast<Operation *>(owner.ptr.dyn_cast<void *>()))
Expand Down Expand Up @@ -183,6 +181,13 @@ Type TypeRange::dereference_iterator(OwnerT object, ptrdiff_t index) {
OperandRange::OperandRange(Operation *op)
: OperandRange(op->getOpOperands().data(), op->getNumOperands()) {}

/// Return the operand index of the first element of this range. The range
/// must not be empty.
unsigned OperandRange::getBeginOperandIndex() const {
assert(!empty() && "range must not be empty");
return base->getOperandNumber();
}

//===----------------------------------------------------------------------===//
// ResultRange

Expand Down
83 changes: 35 additions & 48 deletions mlir/lib/Parser/Parser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3301,13 +3301,11 @@ class OperationParser : public Parser {
/// Parse an operation instance.
ParseResult parseOperation();

/// Parse a single operation successor and its operand list.
ParseResult parseSuccessorAndUseList(Block *&dest,
SmallVectorImpl<Value> &operands);
/// Parse a single operation successor.
ParseResult parseSuccessor(Block *&dest);

/// Parse a comma-separated list of operation successors in brackets.
ParseResult parseSuccessors(SmallVectorImpl<Block *> &destinations,
SmallVectorImpl<SmallVector<Value, 4>> &operands);
ParseResult parseSuccessors(SmallVectorImpl<Block *> &destinations);

/// Parse an operation instance that is in the generic form.
Operation *parseGenericOperation();
Expand Down Expand Up @@ -3797,46 +3795,32 @@ ParseResult OperationParser::parseOperation() {
return success();
}

/// Parse a single operation successor and its operand list.
/// Parse a single operation successor.
///
/// successor ::= block-id branch-use-list?
/// branch-use-list ::= `(` ssa-use-list ':' type-list-no-parens `)`
/// successor ::= block-id
///
ParseResult
OperationParser::parseSuccessorAndUseList(Block *&dest,
SmallVectorImpl<Value> &operands) {
ParseResult OperationParser::parseSuccessor(Block *&dest) {
// Verify branch is identifier and get the matching block.
if (!getToken().is(Token::caret_identifier))
return emitError("expected block name");
dest = getBlockNamed(getTokenSpelling(), getToken().getLoc());
consumeToken();

// Handle optional arguments.
if (consumeIf(Token::l_paren) &&
(parseOptionalSSAUseAndTypeList(operands) ||
parseToken(Token::r_paren, "expected ')' to close argument list"))) {
return failure();
}

return success();
}

/// Parse a comma-separated list of operation successors in brackets.
///
/// successor-list ::= `[` successor (`,` successor )* `]`
///
ParseResult OperationParser::parseSuccessors(
SmallVectorImpl<Block *> &destinations,
SmallVectorImpl<SmallVector<Value, 4>> &operands) {
ParseResult
OperationParser::parseSuccessors(SmallVectorImpl<Block *> &destinations) {
if (parseToken(Token::l_square, "expected '['"))
return failure();

auto parseElt = [this, &destinations, &operands]() {
auto parseElt = [this, &destinations] {
Block *dest;
SmallVector<Value, 4> destOperands;
auto res = parseSuccessorAndUseList(dest, destOperands);
ParseResult res = parseSuccessor(dest);
destinations.push_back(dest);
operands.push_back(destOperands);
return res;
};
return parseCommaSeparatedListUntil(Token::r_square, parseElt,
Expand Down Expand Up @@ -3880,24 +3864,23 @@ Operation *OperationParser::parseGenericOperation() {

// Parse the operand list.
SmallVector<SSAUseInfo, 8> operandInfos;

if (parseToken(Token::l_paren, "expected '(' to start operand list") ||
parseOptionalSSAUseList(operandInfos) ||
parseToken(Token::r_paren, "expected ')' to end operand list")) {
return nullptr;
}

// Parse the successor list but don't add successors to the result yet to
// avoid messing up with the argument order.
SmallVector<Block *, 2> successors;
SmallVector<SmallVector<Value, 4>, 2> successorOperands;
// Parse the successor list.
if (getToken().is(Token::l_square)) {
// Check if the operation is a known terminator.
const AbstractOperation *abstractOp = result.name.getAbstractOperation();
if (abstractOp && !abstractOp->hasProperty(OperationProperty::Terminator))
return emitError("successors in non-terminator"), nullptr;
if (parseSuccessors(successors, successorOperands))

SmallVector<Block *, 2> successors;
if (parseSuccessors(successors))
return nullptr;
result.addSuccessors(successors);
}

// Parse the region list.
Expand Down Expand Up @@ -3948,13 +3931,6 @@ Operation *OperationParser::parseGenericOperation() {
return nullptr;
}

// Add the successors, and their operands after the proper operands.
for (auto succ : llvm::zip(successors, successorOperands)) {
Block *successor = std::get<0>(succ);
const SmallVector<Value, 4> &operands = std::get<1>(succ);
result.addSuccessor(successor, operands);
}

// Parse a location if one is present.
if (parseOptionalTrailingLocation(result.location))
return nullptr;
Expand Down Expand Up @@ -4421,20 +4397,31 @@ class CustomOpAsmParser : public OpAsmParser {
// Successor Parsing
//===--------------------------------------------------------------------===//

/// Parse a single operation successor and its operand list.
ParseResult
parseSuccessorAndUseList(Block *&dest,
SmallVectorImpl<Value> &operands) override {
return parser.parseSuccessorAndUseList(dest, operands);
/// Parse a single operation successor.
ParseResult parseSuccessor(Block *&dest) override {
return parser.parseSuccessor(dest);
}

/// Parse an optional operation successor and its operand list.
OptionalParseResult
parseOptionalSuccessorAndUseList(Block *&dest,
SmallVectorImpl<Value> &operands) override {
OptionalParseResult parseOptionalSuccessor(Block *&dest) override {
if (parser.getToken().isNot(Token::caret_identifier))
return llvm::None;
return parseSuccessorAndUseList(dest, operands);
return parseSuccessor(dest);
}

/// Parse a single operation successor and its operand list.
ParseResult
parseSuccessorAndUseList(Block *&dest,
SmallVectorImpl<Value> &operands) override {
if (parseSuccessor(dest))
return failure();

// Handle optional arguments.
if (succeeded(parseOptionalLParen()) &&
(parser.parseOptionalSSAUseAndTypeList(operands) || parseRParen())) {
return failure();
}
return success();
}

//===--------------------------------------------------------------------===//
Expand Down
20 changes: 14 additions & 6 deletions mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -634,21 +634,29 @@ LogicalResult Importer::processInstruction(llvm::Instruction *inst) {
auto *brInst = cast<llvm::BranchInst>(inst);
OperationState state(loc,
brInst->isConditional() ? "llvm.cond_br" : "llvm.br");
SmallVector<Value, 4> ops;
if (brInst->isConditional()) {
Value condition = processValue(brInst->getCondition());
if (!condition)
return failure();
ops.push_back(condition);
state.addOperands(condition);
}
state.addOperands(ops);
SmallVector<Block *, 4> succs;
for (auto *succ : llvm::reverse(brInst->successors())) {

std::array<int32_t, 3> operandSegmentSizes = {1, 0, 0};
for (int i : llvm::seq<int>(0, brInst->getNumSuccessors())) {
auto *succ = brInst->getSuccessor(i);
SmallVector<Value, 4> blockArguments;
if (failed(processBranchArgs(brInst, succ, blockArguments)))
return failure();
state.addSuccessor(blocks[succ], blockArguments);
state.addSuccessors(blocks[succ]);
state.addOperands(blockArguments);
operandSegmentSizes[i + 1] = blockArguments.size();
}

if (brInst->isConditional()) {
state.addAttribute(LLVM::CondBrOp::getOperandSegmentSizeAttr(),
b.getI32VectorAttr(operandSegmentSizes));
}

b.createOperation(state);
return success();
}
Expand Down
6 changes: 3 additions & 3 deletions mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -356,7 +356,7 @@ LogicalResult ModuleTranslation::convertOperation(Operation &opInst,
// Emit branches. We need to look up the remapped blocks and ignore the block
// arguments that were transformed into PHI nodes.
if (auto brOp = dyn_cast<LLVM::BrOp>(opInst)) {
builder.CreateBr(blockMapping[brOp.getSuccessor(0)]);
builder.CreateBr(blockMapping[brOp.getSuccessor()]);
return success();
}
if (auto condbrOp = dyn_cast<LLVM::CondBrOp>(opInst)) {
Expand Down Expand Up @@ -501,8 +501,8 @@ static Value getPHISourceValue(Block *current, Block *pred,
"different blocks");

return condBranchOp.getSuccessor(0) == current
? terminator.getSuccessorOperand(0, index)
: terminator.getSuccessorOperand(1, index);
? condBranchOp.trueDestOperands()[index]
: condBranchOp.falseDestOperands()[index];
}

void ModuleTranslation::connectPHINodes(LLVMFuncOp func) {
Expand Down
28 changes: 1 addition & 27 deletions mlir/lib/Transforms/DialectConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1005,33 +1005,7 @@ ConversionPattern::matchAndRewrite(Operation *op,
SmallVector<Value, 4> operands;
auto &dialectRewriter = static_cast<ConversionPatternRewriter &>(rewriter);
dialectRewriter.getImpl().remapValues(op->getOperands(), operands);

// If this operation has no successors, invoke the rewrite directly.
if (op->getNumSuccessors() == 0)
return matchAndRewrite(op, operands, dialectRewriter);

// Otherwise, we need to remap the successors.
SmallVector<Block *, 2> destinations;
destinations.reserve(op->getNumSuccessors());

SmallVector<ArrayRef<Value>, 2> operandsPerDestination;
unsigned firstSuccessorOperand = op->getSuccessorOperandIndex(0);
for (unsigned i = 0, seen = 0, e = op->getNumSuccessors(); i < e; ++i) {
destinations.push_back(op->getSuccessor(i));

// Lookup the successors operands.
unsigned n = op->getNumSuccessorOperands(i);
operandsPerDestination.push_back(
llvm::makeArrayRef(operands.data() + firstSuccessorOperand + seen, n));
seen += n;
}

// Rewrite the operation.
return matchAndRewrite(
op,
llvm::makeArrayRef(operands.data(),
operands.data() + firstSuccessorOperand),
destinations, operandsPerDestination, dialectRewriter);
return matchAndRewrite(op, operands, dialectRewriter);
}

//===----------------------------------------------------------------------===//
Expand Down
53 changes: 44 additions & 9 deletions mlir/lib/Transforms/Utils/RegionUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//

#include "mlir/Transforms/RegionUtils.h"
#include "mlir/Analysis/ControlFlowInterfaces.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/RegionGraphTraits.h"
Expand Down Expand Up @@ -172,8 +173,9 @@ static bool isUseSpeciallyKnownDead(OpOperand &use, LiveMap &liveMap) {
// node, rather than to the terminator op itself, a terminator op can't e.g.
// "print" the value of a successor operand.
if (owner->isKnownTerminator()) {
if (auto arg = owner->getSuccessorBlockArgument(operandIndex))
return !liveMap.wasProvenLive(*arg);
if (BranchOpInterface branchInterface = dyn_cast<BranchOpInterface>(owner))
if (auto arg = branchInterface.getSuccessorBlockArgument(operandIndex))
return !liveMap.wasProvenLive(*arg);
return false;
}
return false;
Expand All @@ -200,6 +202,29 @@ static bool isOpIntrinsicallyLive(Operation *op) {
}

static void propagateLiveness(Region &region, LiveMap &liveMap);

static void propagateTerminatorLiveness(Operation *op, LiveMap &liveMap) {
// Terminators are always live.
liveMap.setProvedLive(op);

// Check to see if we can reason about the successor operands and mutate them.
BranchOpInterface branchInterface = dyn_cast<BranchOpInterface>(op);
if (!branchInterface || !branchInterface.canEraseSuccessorOperand()) {
for (Block *successor : op->getSuccessors())
for (BlockArgument arg : successor->getArguments())
liveMap.setProvedLive(arg);
return;
}

// If we can't reason about the operands to a successor, conservatively mark
// all arguments as live.
for (unsigned i = 0, e = op->getNumSuccessors(); i != e; ++i) {
if (!branchInterface.getSuccessorOperands(i))
for (BlockArgument arg : op->getSuccessor(i)->getArguments())
liveMap.setProvedLive(arg);
}
}

static void propagateLiveness(Operation *op, LiveMap &liveMap) {
// All Value's are either a block argument or an op result.
// We call processValue on those cases.
Expand All @@ -208,6 +233,10 @@ static void propagateLiveness(Operation *op, LiveMap &liveMap) {
for (Region &region : op->getRegions())
propagateLiveness(region, liveMap);

// Process terminator operations.
if (op->isKnownTerminator())
return propagateTerminatorLiveness(op, liveMap);

// Process the op itself.
if (isOpIntrinsicallyLive(op)) {
liveMap.setProvedLive(op);
Expand Down Expand Up @@ -238,22 +267,28 @@ static void propagateLiveness(Region &region, LiveMap &liveMap) {

static void eraseTerminatorSuccessorOperands(Operation *terminator,
LiveMap &liveMap) {
BranchOpInterface branchOp = dyn_cast<BranchOpInterface>(terminator);
if (!branchOp)
return;

for (unsigned succI = 0, succE = terminator->getNumSuccessors();
succI < succE; succI++) {
// Iterating successors in reverse is not strictly needed, since we
// aren't erasing any successors. But it is slightly more efficient
// since it will promote later operands of the terminator being erased
// first, reducing the quadratic-ness.
unsigned succ = succE - succI - 1;
for (unsigned argI = 0, argE = terminator->getNumSuccessorOperands(succ);
argI < argE; argI++) {
Optional<OperandRange> succOperands = branchOp.getSuccessorOperands(succ);
if (!succOperands)
continue;
Block *successor = terminator->getSuccessor(succ);

for (unsigned argI = 0, argE = succOperands->size(); argI < argE; ++argI) {
// Iterating args in reverse is needed for correctness, to avoid
// shifting later args when earlier args are erased.
unsigned arg = argE - argI - 1;
Value value = terminator->getSuccessor(succ)->getArgument(arg);
if (!liveMap.wasProvenLive(value)) {
terminator->eraseSuccessorOperand(succ, arg);
}
if (!liveMap.wasProvenLive(successor->getArgument(arg)))
branchOp.eraseSuccessorOperand(succ, arg);
}
}
}
Expand Down Expand Up @@ -294,7 +329,7 @@ static LogicalResult deleteDeadness(MutableArrayRef<Region> regions,
// earlier arguments.
for (unsigned i = 0, e = block.getNumArguments(); i < e; i++)
if (!liveMap.wasProvenLive(block.getArgument(e - i - 1))) {
block.eraseArgument(e - i - 1, /*updatePredTerms=*/false);
block.eraseArgument(e - i - 1);
erasedAnything = true;
}
}
Expand Down
6 changes: 3 additions & 3 deletions mlir/test/Conversion/LoopsToGPU/parallel_loop.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -270,17 +270,17 @@ module {
// CHECK: [[VAL_31:%.*]] = affine.min #[[MAP3]]([[VAL_28]]){{\[}}[[VAL_30]]]
// CHECK: [[VAL_32:%.*]] = dim [[VAL_0]], 1 : memref<?x?xf32, #[[MAP0]]>
// CHECK: [[VAL_33:%.*]] = affine.min #[[MAP4]]([[VAL_29]]){{\[}}[[VAL_32]]]
// CHECK: [[VAL_34:%.*]] = std.subview [[VAL_0]]{{\[}}[[VAL_28]], [[VAL_29]]]{{\[}}[[VAL_31]], [[VAL_33]]]{{\[}}[[VAL_3]], [[VAL_3]]] : memref<?x?xf32, #[[MAP0]]> to memref<?x?xf32, #[[MAP5]]>
// CHECK: [[VAL_34:%.*]] = subview [[VAL_0]]{{\[}}[[VAL_28]], [[VAL_29]]] {{\[}}[[VAL_31]], [[VAL_33]]] {{\[}}[[VAL_3]], [[VAL_3]]] : memref<?x?xf32, #[[MAP0]]> to memref<?x?xf32, #[[MAP5]]>
// CHECK: [[VAL_35:%.*]] = dim [[VAL_1]], 0 : memref<?x?xf32, #[[MAP0]]>
// CHECK: [[VAL_36:%.*]] = affine.min #[[MAP3]]([[VAL_28]]){{\[}}[[VAL_35]]]
// CHECK: [[VAL_37:%.*]] = dim [[VAL_1]], 1 : memref<?x?xf32, #[[MAP0]]>
// CHECK: [[VAL_38:%.*]] = affine.min #[[MAP4]]([[VAL_29]]){{\[}}[[VAL_37]]]
// CHECK: [[VAL_39:%.*]] = std.subview [[VAL_1]]{{\[}}[[VAL_28]], [[VAL_29]]]{{\[}}[[VAL_36]], [[VAL_38]]]{{\[}}[[VAL_3]], [[VAL_3]]] : memref<?x?xf32, #[[MAP0]]> to memref<?x?xf32, #[[MAP5]]>
// CHECK: [[VAL_39:%.*]] = subview [[VAL_1]]{{\[}}[[VAL_28]], [[VAL_29]]] {{\[}}[[VAL_36]], [[VAL_38]]] {{\[}}[[VAL_3]], [[VAL_3]]] : memref<?x?xf32, #[[MAP0]]> to memref<?x?xf32, #[[MAP5]]>
// CHECK: [[VAL_40:%.*]] = dim [[VAL_2]], 0 : memref<?x?xf32, #[[MAP0]]>
// CHECK: [[VAL_41:%.*]] = affine.min #[[MAP3]]([[VAL_28]]){{\[}}[[VAL_40]]]
// CHECK: [[VAL_42:%.*]] = dim [[VAL_2]], 1 : memref<?x?xf32, #[[MAP0]]>
// CHECK: [[VAL_43:%.*]] = affine.min #[[MAP4]]([[VAL_29]]){{\[}}[[VAL_42]]]
// CHECK: [[VAL_44:%.*]] = std.subview [[VAL_2]]{{\[}}[[VAL_28]], [[VAL_29]]]{{\[}}[[VAL_41]], [[VAL_43]]]{{\[}}[[VAL_3]], [[VAL_3]]] : memref<?x?xf32, #[[MAP0]]> to memref<?x?xf32, #[[MAP5]]>
// CHECK: [[VAL_44:%.*]] = subview [[VAL_2]]{{\[}}[[VAL_28]], [[VAL_29]]] {{\[}}[[VAL_41]], [[VAL_43]]] {{\[}}[[VAL_3]], [[VAL_3]]] : memref<?x?xf32, #[[MAP0]]> to memref<?x?xf32, #[[MAP5]]>
// CHECK: [[VAL_45:%.*]] = affine.apply #[[MAP2]]([[VAL_22]]){{\[}}[[VAL_3]], [[VAL_4]]]
// CHECK: [[VAL_46:%.*]] = cmpi "slt", [[VAL_45]], [[VAL_31]] : index
// CHECK: loop.if [[VAL_46]] {
Expand Down
10 changes: 5 additions & 5 deletions mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -556,16 +556,16 @@ func @dfs_block_order(%arg0: i32) -> (i32) {
}
// CHECK-LABEL: func @cond_br_same_target(%arg0: !llvm.i1, %arg1: !llvm.i32, %arg2: !llvm.i32)
func @cond_br_same_target(%arg0: i1, %arg1: i32, %arg2 : i32) -> (i32) {
// CHECK-NEXT: llvm.cond_br %arg0, ^[[origBlock:bb[0-9]+]](%arg1 : !llvm.i32), ^[[dummyBlock:bb[0-9]+]]
// CHECK-NEXT: llvm.cond_br %arg0, ^[[origBlock:bb[0-9]+]](%arg1 : !llvm.i32), ^[[dummyBlock:bb[0-9]+]](%arg2 : !llvm.i32)
cond_br %arg0, ^bb1(%arg1 : i32), ^bb1(%arg2 : i32)

// CHECK: ^[[origBlock]](%0: !llvm.i32):
// CHECK-NEXT: llvm.return %0 : !llvm.i32
// CHECK: ^[[origBlock]](%[[BLOCKARG1:.*]]: !llvm.i32):
// CHECK-NEXT: llvm.return %[[BLOCKARG1]] : !llvm.i32
^bb1(%0 : i32):
return %0 : i32

// CHECK: ^[[dummyBlock]]:
// CHECK-NEXT: llvm.br ^[[origBlock]](%arg2 : !llvm.i32)
// CHECK: ^[[dummyBlock]](%[[BLOCKARG2:.*]]: !llvm.i32):
// CHECK-NEXT: llvm.br ^[[origBlock]](%[[BLOCKARG2]] : !llvm.i32)
}

// CHECK-LABEL: func @fcmp(%arg0: !llvm.float, %arg1: !llvm.float) {
Expand Down
Loading