Skip to content

Commit

Permalink
Introduce the ability for "isolated from above" ops to introduce shad…
Browse files Browse the repository at this point in the history
…owing

names for the basic block arguments in their body.

PiperOrigin-RevId: 265084627
  • Loading branch information
lattner authored and tensorflower-gardener committed Aug 23, 2019
1 parent 0017796 commit 31a003d
Show file tree
Hide file tree
Showing 5 changed files with 85 additions and 11 deletions.
7 changes: 7 additions & 0 deletions mlir/include/mlir/IR/OpImplementation.h
Expand Up @@ -85,6 +85,13 @@ class OpAsmPrinter {
virtual void printRegion(Region &blocks, bool printEntryBlockArgs = true,
bool printBlockTerminators = true) = 0;

/// Renumber the arguments for the specified region to the same names as the
/// SSA values in namesToUse. This may only be used for IsolatedFromAbove
/// operations. If any entry in namesToUse is null, the corresponding
/// argument name is left alone.
virtual void shadowRegionArgs(Region &region,
ArrayRef<Value *> namesToUse) = 0;

/// Prints an affine map of SSA ids, where SSA id names are used in place
/// of dims/symbols.
/// Operand values must come from single-result sources, and be valid
Expand Down
61 changes: 54 additions & 7 deletions mlir/lib/IR/AsmPrinter.cpp
Expand Up @@ -1244,6 +1244,12 @@ class OperationPrinter : public ModulePrinter, private OpAsmPrinter {
os.indent(currentIndent) << "}";
}

/// Renumber the arguments for the specified region to the same names as the
/// SSA values in namesToUse. This may only be used for IsolatedFromAbove
/// operations. If any entry in namesToUse is null, the corresponding
/// argument name is left alone.
void shadowRegionArgs(Region &region, ArrayRef<Value *> namesToUse) override;

void printAffineMapOfSSAIds(AffineMapAttr mapAttr,
ArrayRef<Value *> operands) override {
AffineMap map = mapAttr.getValue();
Expand All @@ -1270,9 +1276,14 @@ class OperationPrinter : public ModulePrinter, private OpAsmPrinter {
void numberValueID(Value *value);
void numberValuesInRegion(Region &region);
void numberValuesInBlock(Block &block);
void printValueID(Value *value, bool printResultNo = true) const;
void printValueID(Value *value, bool printResultNo = true) const {
printValueIDImpl(value, printResultNo, os);
}

private:
void printValueIDImpl(Value *value, bool printResultNo,
raw_ostream &stream) const;

/// Uniques the given value name within the printer. If the given name
/// conflicts, it is automatically renamed.
StringRef uniqueValueName(StringRef name);
Expand Down Expand Up @@ -1491,7 +1502,8 @@ void OperationPrinter::print(Operation *op) {
printTrailingLocation(op->getLoc());
}

void OperationPrinter::printValueID(Value *value, bool printResultNo) const {
void OperationPrinter::printValueIDImpl(Value *value, bool printResultNo,
raw_ostream &stream) const {
int resultNo = -1;
auto lookupValue = value;

Expand All @@ -1507,21 +1519,56 @@ void OperationPrinter::printValueID(Value *value, bool printResultNo) const {

auto it = valueIDs.find(lookupValue);
if (it == valueIDs.end()) {
os << "<<INVALID SSA VALUE>>";
stream << "<<INVALID SSA VALUE>>";
return;
}

os << '%';
stream << '%';
if (it->second != nameSentinel) {
os << it->second;
stream << it->second;
} else {
auto nameIt = valueNames.find(lookupValue);
assert(nameIt != valueNames.end() && "Didn't have a name entry?");
os << nameIt->second;
stream << nameIt->second;
}

if (resultNo != -1 && printResultNo)
os << '#' << resultNo;
stream << '#' << resultNo;
}

/// Renumber the arguments for the specified region to the same names as the
/// SSA values in namesToUse. This may only be used for IsolatedFromAbove
/// operations. If any entry in namesToUse is null, the corresponding
/// argument name is left alone.
void OperationPrinter::shadowRegionArgs(Region &region,
ArrayRef<Value *> namesToUse) {
assert(!region.empty() && "cannot shadow arguments of an empty region");
assert(region.front().getNumArguments() == namesToUse.size() &&
"incorrect number of names passed in");
assert(region.getParentOp()->isKnownIsolatedFromAbove() &&
"only KnownIsolatedFromAbove ops can shadow names");

SmallVector<char, 16> nameStr;
for (unsigned i = 0, e = namesToUse.size(); i != e; ++i) {
auto *nameToUse = namesToUse[i];
if (nameToUse == nullptr)
continue;

auto *nameToReplace = region.front().getArgument(i);

nameStr.clear();
llvm::raw_svector_ostream nameStream(nameStr);
printValueIDImpl(nameToUse, /*printResultNo=*/true, nameStream);

// Entry block arguments should already have a pretty "arg" name.
assert(valueIDs[nameToReplace] == nameSentinel);

// Use the name without the leading %.
auto name = StringRef(nameStream.str()).drop_front();

// Overwrite the name.
valueNames[nameToReplace] = name.copy(usedNameAllocator);
}
}

void OperationPrinter::printOperation(Operation *op) {
Expand Down
20 changes: 16 additions & 4 deletions mlir/test/IR/parser.mlir
Expand Up @@ -1055,13 +1055,25 @@ func @op_with_region_args() {
// CHECK-LABEL: func @op_with_passthrough_region_args
func @op_with_passthrough_region_args() {
// CHECK: [[VAL:%.*]] = constant
// CHECK: "test.isolated_region"([[VAL]])
// CHECK-NEXT: ^{{.*}}([[ARG:%.*]]: index)
// CHECK-NEXT: "foo.consumer"([[ARG]]) : (index)

%0 = constant 10 : index

// CHECK: test.isolated_region [[VAL]] {
// CHECK-NEXT: "foo.consumer"([[VAL]]) : (index)
// CHECK-NEXT: }
test.isolated_region %0 {
"foo.consumer"(%0) : (index) -> ()
}

// CHECK: [[VAL:%.*]]:2 = "foo.op"
%result:2 = "foo.op"() : () -> (index, index)

// CHECK: test.isolated_region [[VAL]]#1 {
// CHECK-NEXT: "foo.consumer"([[VAL]]#1) : (index)
// CHECK-NEXT: }
test.isolated_region %result#1 {
"foo.consumer"(%result#1) : (index) -> ()
}

return
}

7 changes: 7 additions & 0 deletions mlir/test/lib/TestDialect/TestDialect.cpp
Expand Up @@ -54,6 +54,13 @@ static ParseResult parseIsolatedRegionOp(OpAsmParser *parser,
/*enableNameShadowing=*/true);
}

static void print(OpAsmPrinter *p, IsolatedRegionOp op) {
*p << "test.isolated_region ";
p->printOperand(op.getOperand());
p->shadowRegionArgs(op.region(), op.getOperand());
p->printRegion(op.region(), /*printEntryBlockArgs=*/false);
}

//===----------------------------------------------------------------------===//
// Test PolyForOp - parse list of region arguments.
//===----------------------------------------------------------------------===//
Expand Down
1 change: 1 addition & 0 deletions mlir/test/lib/TestDialect/TestOps.td
Expand Up @@ -704,6 +704,7 @@ def IsolatedRegionOp : TEST_Op<"isolated_region", [IsolatedFromAbove]> {
let arguments = (ins Index:$input);
let regions = (region SizedRegion<1>:$region);
let parser = [{ return ::parse$cppClass(parser, result); }];
let printer = [{ return ::print(p, *this); }];
}

def PolyForOp : TEST_Op<"polyfor">
Expand Down

0 comments on commit 31a003d

Please sign in to comment.