Skip to content

Commit

Permalink
NFC - clean up op accessor usage, std.load/store op verify, other stale info
Browse files Browse the repository at this point in the history

- also remove stale terminology/references in docs

- also remove stale terminology/references in docs

Signed-off-by: Uday Bondhugula <uday@polymagelabs.com>

Closes tensorflow/mlir#148

COPYBARA_INTEGRATE_REVIEW=tensorflow/mlir#148 from bondhugula:cleanup e846b641a3c2936e874138aff480a23cdbf66591
PiperOrigin-RevId: 271618279
  • Loading branch information
bondhugula authored and tensorflower-gardener committed Sep 27, 2019
1 parent ddf737c commit 74eabdd
Show file tree
Hide file tree
Showing 13 changed files with 79 additions and 103 deletions.
43 changes: 21 additions & 22 deletions mlir/g3doc/Rationale.md
Expand Up @@ -211,12 +211,12 @@ appear in subscripts, sizes of aggregate types and affine expressions. They are
also tightly coupled with `affine.apply` and load/store operations; having
`index` type is a necessary precondition of a value to be acceptable by these
operations. While it may be useful to have `memref<?xindex>` to express indirect
accesses in MLFunctions, e.g. sparse matrix manipulations or lookup tables, it
creates problems MLIR is not ready to address yet. MLIR needs to internally
store constants of aggregate types and emit code operating on values of those
types, which are subject to target-specific size and alignment constraints.
Since MLIR does not have a target description mechanism at the moment, it cannot
reliably emit such code. Moreover, some platforms may not support vectors of
accesses, e.g. sparse matrix manipulations or lookup tables, it creates problems
MLIR is not ready to address yet. MLIR needs to internally store constants of
aggregate types and emit code operating on values of those types, which are
subject to target-specific size and alignment constraints. Since MLIR does not
have a target description mechanism at the moment, it cannot reliably emit such
code. Moreover, some platforms may not support vectors of
type equivalent to `index`.

Indirect access use cases can be alternatively supported by providing and
Expand Down Expand Up @@ -721,9 +721,9 @@ in a dilated convolution.
// input: [batch, input_height, input_width, input_feature]
// kernel: [kernel_height, kernel_width, input_feature, output_feature]
// output: [batch, output_height, output_width, output_feature]
func @conv2d(memref<16x1024x1024x3xf32, #lm0, vmem> %input,
memref<5x5x3x32xf32, #lm0, vmem> %kernel,
memref<16x512x512x32xf32, #lm0, vmem> %output) {
func @conv2d(memref<16x1024x1024x3xf32, #lm0, /*scratchpad=*/1> %input,
memref<5x5x3x32xf32, #lm0, /*scratchpad=*/1> %kernel,
memref<16x512x512x32xf32, #lm0, /*scratchpad=*/1> %output) {
affine.for %b = 0 to %batch {
affine.for %oh = 0 to %output_height {
affine.for %ow = 0 to %output_width {
Expand Down Expand Up @@ -794,14 +794,13 @@ At a high level, we have two alternatives here:
explicitly propagate the schedule into domains and model all the cleanup
code. An example and more detail on the schedule tree form is in the next
section.
1. Having two different forms of MLFunctions: an affine loop tree form
1. Having two different forms of "affine regions": an affine loop tree form
(AffineLoopTreeFunction) and a polyhedral schedule tree form as two
different forms of MLFunctions. Or in effect, having four different forms
for functions in MLIR instead of three: CFG Function,
AffineLoopTreeFunction, Polyhedral Schedule Tree function, and external
functions.
different forms. Or in effect, having four different forms for functions in
MLIR instead of three: CFG Function, AffineLoopTreeFunction, Polyhedral
Schedule Tree function, and external functions.

#### Schedule Tree Representation for MLFunctions
#### Schedule Tree Representation for Affine Regions

This representation is based on a simplified form of the domain/schedule
representation used by the polyhedral compiler community. Domains represent what
Expand All @@ -826,15 +825,15 @@ func @matmul(%A, %B, %C, %M, %N, %K) : (...) { // %M, N, K are symbols
mldim %t1 : {S1,S2,S3,S4,S5} floordiv (i, 128) {
mldim %t2 : {S1,S2,S3,S4,S5} floordiv (j, 128) {
// (%i, %j) = affine.apply (d0, d1) -> (128*d0, 128*d1) (%t1, %t2)
call dma_hbm_to_vmem(%C, %i, %j, %M, %N, %K)
call dma_mem_to_scratchpad(%C, %i, %j, %M, %N, %K)
with @intset_ij(%i, %j) [%M, %N, %K]
mldim %t3 : {S2,S3,S4,S5} floordiv (k, 128) {
// (%i, %j, %k) = affine.apply (d0, d1, d2)
// -> (128*d0, 128*d1, 128*d2) (%t1, %t2, %t3)
call dma_hbm_to_vmem(%A, ...) with #inset_ijk (%i, %j, %k) [%M, %N, %K]
call dma_mem_to_scratchpad(%A, ...) with #inset_ijk (%i, %j, %k) [%M, %N, %K]
// (%i, %j, %k) = affine.apply (d0, d1, d2)
// -> (128*d0, 128*d1, 128*d2) (%t1, %t2, %t3)
call dma_hbm_to_vmem(%B, ...) with #inset_ijk (%i, %j, %k) [%M, %N, %K]
call dma_mem_to_scratchpad(%B, ...) with #inset_ijk (%i, %j, %k) [%M, %N, %K]
mldim %t4 : {S4} i mod 128 {
mldim %t5 : {S4} j mod 128 {
mldim %t6 : {S4} k mod 128 {
Expand All @@ -846,7 +845,7 @@ func @matmul(%A, %B, %C, %M, %N, %K) : (...) { // %M, N, K are symbols
} // end mldim t4
} // end mldim t3
// (%i, %j) = affine.apply (d0, d1) -> (128*d0, 128*d1) (%t1, %t2)
call $dma_vmem_to_hbm_C ... with #intset(%i, %j) [%M, %N, %K]
call $dma_scratchpad_to_mem_C ... with #intset(%i, %j) [%M, %N, %K]
} // end mldim t2
} // end mldim t1
return
Expand Down Expand Up @@ -978,15 +977,15 @@ Example:
```mlir {.mlir}
##rel9 ( ) [s0] -> (r0, r1) : 0 <= r0 <= 1023, 0 <= r1 <= s0 - 1
func @cblas_reduce_ffi(memref<1024 x ? x f32, #layout_map0, hbm> %M) -> f32 [
func @cblas_reduce_ffi(memref<1024 x ? x f32, #layout_map0, /*mem=*/0> %M) -> f32 [
reads: {%M, ##rel9() }
writes: /* empty */
may_reads: /* empty */
may_writes: /* empty */
]
func @dma_hbm_to_vmem(memref<1024 x f32, #layout_map0, hbm> %a,
offset, memref<1024 x f32, #layout_map0, vmem> %b,
func @dma_mem_to_scratchpad(memref<1024 x f32, #layout_map0, /*mem=*/0> %a,
offset, memref<1024 x f32, #layout_map0, 1> %b,
memref<1024 x f32, #layout_map0> %c
) [
reads: {%M, ##rel9() }
Expand Down
2 changes: 1 addition & 1 deletion mlir/include/mlir/Analysis/NestedMatcher.h
@@ -1,4 +1,4 @@
//===- NestedMacher.h - Nested matcher for MLFunction -----------*- C++ -*-===//
//===- NestedMatcher.h - Nested matcher for Function ------------*- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//
Expand Down
4 changes: 1 addition & 3 deletions mlir/include/mlir/Dialect/AffineOps/AffineOps.h
Expand Up @@ -580,9 +580,7 @@ class AffineBound {
AffineValueMap getAsAffineValueMap();

unsigned getNumOperands() { return opEnd - opStart; }
Value *getOperand(unsigned idx) {
return op.getOperation()->getOperand(opStart + idx);
}
Value *getOperand(unsigned idx) { return op.getOperand(opStart + idx); }

using operand_iterator = AffineForOp::operand_iterator;
using operand_range = AffineForOp::operand_range;
Expand Down
29 changes: 15 additions & 14 deletions mlir/include/mlir/Dialect/StandardOps/Ops.td
Expand Up @@ -300,7 +300,8 @@ def CallIndirectOp : Std_Op<"call_indirect", [CallOpInterface]> {
let hasCanonicalizer = 1;
}

def CmpIOp : Std_Op<"cmpi", [NoSideEffect, SameTypeOperands, SameOperandsAndResultShape]> {
def CmpIOp : Std_Op<"cmpi",
[NoSideEffect, SameTypeOperands, SameOperandsAndResultShape]> {
let summary = "integer comparison operation";
let description = [{
The "cmpi" operation compares its two operands according to the integer
Expand Down Expand Up @@ -345,7 +346,8 @@ def CmpIOp : Std_Op<"cmpi", [NoSideEffect, SameTypeOperands, SameOperandsAndResu
let hasFolder = 1;
}

def CmpFOp : Std_Op<"cmpf", [NoSideEffect, SameTypeOperands, SameOperandsAndResultShape]> {
def CmpFOp : Std_Op<"cmpf",
[NoSideEffect, SameTypeOperands, SameOperandsAndResultShape]> {
let summary = "floating-point comparison operation";
let description = [{
The "cmpf" operation compares its two operands according to the float
Expand Down Expand Up @@ -431,12 +433,12 @@ def CondBranchOp : Std_Op<"cond_br", [Terminator]> {

/// Return the destination if the condition is true.
Block *getTrueDest() {
return getOperation()->getSuccessor(trueIndex);
return getSuccessor(trueIndex);
}

/// Return the destination if the condition is false.
Block *getFalseDest() {
return getOperation()->getSuccessor(falseIndex);
return getSuccessor(falseIndex);
}

// Accessors for operands to the 'true' destination.
Expand All @@ -461,7 +463,7 @@ def CondBranchOp : Std_Op<"cond_br", [Terminator]> {
}

unsigned getNumTrueOperands() {
return getOperation()->getNumSuccessorOperands(trueIndex);
return getNumSuccessorOperands(trueIndex);
}

/// Erase the operand at 'index' from the true operand list.
Expand All @@ -488,7 +490,7 @@ def CondBranchOp : Std_Op<"cond_br", [Terminator]> {
}

unsigned getNumFalseOperands() {
return getOperation()->getNumSuccessorOperands(falseIndex);
return getNumSuccessorOperands(falseIndex);
}

/// Erase the operand at 'index' from the false operand list.
Expand Down Expand Up @@ -624,8 +626,7 @@ def ExtractElementOp : Std_Op<"extract_element", [NoSideEffect]> {
Value *getAggregate() { return getOperand(0); }

operand_range getIndices() {
return {getOperation()->operand_begin() + 1,
getOperation()->operand_end()};
return {operand_begin() + 1, operand_end()};
}
}];

Expand Down Expand Up @@ -698,9 +699,7 @@ def LoadOp : Std_Op<"load"> {
return getMemRef()->getType().cast<MemRefType>();
}

operand_range getIndices() {
return {getOperation()->operand_begin() + 1, getOperation()->operand_end()};
}
operand_range getIndices() { return {operand_begin() + 1, operand_end()}; }
}];

let hasCanonicalizer = 1;
Expand Down Expand Up @@ -843,7 +842,8 @@ def SelectOp : Std_Op<"select", [NoSideEffect, SameOperandsAndResultShape]> {
let hasFolder = 1;
}

def SignExtendIOp : Std_Op<"sexti", [NoSideEffect, SameOperandsAndResultShape]> {
def SignExtendIOp : Std_Op<"sexti",
[NoSideEffect, SameOperandsAndResultShape]> {
let summary = "integer sign extension operation";
let description = [{
The integer sign extension operation takes an integer input of
Expand Down Expand Up @@ -930,7 +930,8 @@ def StoreOp : Std_Op<"store"> {
store %v, %A[%i, %j] : memref<4x128xf32, (d0, d1) -> (d0, d1), 0>
}];

let arguments = (ins AnyType:$value, AnyMemRef:$memref, Variadic<Index>:$indices);
let arguments = (ins AnyType:$value, AnyMemRef:$memref,
Variadic<Index>:$indices);

let builders = [OpBuilder<
"Builder *, OperationState &result, Value *valueToStore, Value *memref", [{
Expand All @@ -948,7 +949,7 @@ def StoreOp : Std_Op<"store"> {
}

operand_range getIndices() {
return {getOperation()->operand_begin() + 2, getOperation()->operand_end()};
return {operand_begin() + 2, operand_end()};
}
}];

Expand Down
34 changes: 17 additions & 17 deletions mlir/lib/Dialect/AffineOps/AffineOps.cpp
Expand Up @@ -188,7 +188,7 @@ void AffineApplyOp::build(Builder *builder, OperationState &result,

ParseResult AffineApplyOp::parse(OpAsmParser &parser, OperationState &result) {
auto &builder = parser.getBuilder();
auto affineIntTy = builder.getIndexType();
auto indexTy = builder.getIndexType();

AffineMapAttr mapAttr;
unsigned numDims;
Expand All @@ -204,7 +204,7 @@ ParseResult AffineApplyOp::parse(OpAsmParser &parser, OperationState &result) {
"dimension or symbol index mismatch");
}

result.types.append(map.getNumResults(), affineIntTy);
result.types.append(map.getNumResults(), indexTy);
return success();
}

Expand Down Expand Up @@ -1139,7 +1139,7 @@ static ParseResult parseBound(bool isLower, OperationState &result,
return p.emitError(p.getNameLoc(),
"expected only one loop bound operand");

// TODO: improve error message when SSA value is not an affine integer.
// TODO: improve error message when SSA value is not of index type.
// Currently it is 'use of value ... expects different type than prior uses'
if (p.resolveOperand(boundOpInfos.front(), builder.getIndexType(),
result.operands))
Expand Down Expand Up @@ -1754,7 +1754,7 @@ void AffineLoadOp::build(Builder *builder, OperationState &result,

ParseResult AffineLoadOp::parse(OpAsmParser &parser, OperationState &result) {
auto &builder = parser.getBuilder();
auto affineIntTy = builder.getIndexType();
auto indexTy = builder.getIndexType();

MemRefType type;
OpAsmParser::OperandType memrefInfo;
Expand All @@ -1767,7 +1767,7 @@ ParseResult AffineLoadOp::parse(OpAsmParser &parser, OperationState &result) {
parser.parseOptionalAttributeDict(result.attributes) ||
parser.parseColonType(type) ||
parser.resolveOperand(memrefInfo, type, result.operands) ||
parser.resolveOperands(mapOperands, affineIntTy, result.operands) ||
parser.resolveOperands(mapOperands, indexTy, result.operands) ||
parser.addTypeToList(type.getElementType(), result.types));
}

Expand Down Expand Up @@ -1845,24 +1845,24 @@ void AffineStoreOp::build(Builder *builder, OperationState &result,
}

ParseResult AffineStoreOp::parse(OpAsmParser &parser, OperationState &result) {
auto affineIntTy = parser.getBuilder().getIndexType();
auto indexTy = parser.getBuilder().getIndexType();

MemRefType type;
OpAsmParser::OperandType storeValueInfo;
OpAsmParser::OperandType memrefInfo;
AffineMapAttr mapAttr;
SmallVector<OpAsmParser::OperandType, 1> mapOperands;
return failure(
parser.parseOperand(storeValueInfo) || parser.parseComma() ||
parser.parseOperand(memrefInfo) ||
parser.parseAffineMapOfSSAIds(mapOperands, mapAttr, getMapAttrName(),
result.attributes) ||
parser.parseOptionalAttributeDict(result.attributes) ||
parser.parseColonType(type) ||
parser.resolveOperand(storeValueInfo, type.getElementType(),
result.operands) ||
parser.resolveOperand(memrefInfo, type, result.operands) ||
parser.resolveOperands(mapOperands, affineIntTy, result.operands));
return failure(parser.parseOperand(storeValueInfo) || parser.parseComma() ||
parser.parseOperand(memrefInfo) ||
parser.parseAffineMapOfSSAIds(mapOperands, mapAttr,
getMapAttrName(),
result.attributes) ||
parser.parseOptionalAttributeDict(result.attributes) ||
parser.parseColonType(type) ||
parser.resolveOperand(storeValueInfo, type.getElementType(),
result.operands) ||
parser.resolveOperand(memrefInfo, type, result.operands) ||
parser.resolveOperands(mapOperands, indexTy, result.operands));
}

void AffineStoreOp::print(OpAsmPrinter &p) {
Expand Down
25 changes: 12 additions & 13 deletions mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
Expand Up @@ -406,14 +406,14 @@ static ParseResult parseLoadOp(OpAsmParser &parser, OperationState &result) {
SmallVector<OpAsmParser::OperandType, 4> indexInfo;
ViewType type;

auto affineIntTy = parser.getBuilder().getIndexType();
auto indexTy = parser.getBuilder().getIndexType();
return failure(
parser.parseOperand(viewInfo) ||
parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
parser.parseOptionalAttributeDict(result.attributes) ||
parser.parseColonType(type) ||
parser.resolveOperand(viewInfo, type, result.operands) ||
parser.resolveOperands(indexInfo, affineIntTy, result.operands) ||
parser.resolveOperands(indexInfo, indexTy, result.operands) ||
parser.addTypeToList(type.getElementType(), result.types));
}

Expand All @@ -438,15 +438,14 @@ static void print(OpAsmPrinter &p, RangeOp op) {
static ParseResult parseRangeOp(OpAsmParser &parser, OperationState &result) {
SmallVector<OpAsmParser::OperandType, 3> rangeInfo(3);
RangeType type;
auto affineIntTy = parser.getBuilder().getIndexType();
return failure(
parser.parseOperand(rangeInfo[0]) || parser.parseColon() ||
parser.parseOperand(rangeInfo[1]) || parser.parseColon() ||
parser.parseOperand(rangeInfo[2]) ||
parser.parseOptionalAttributeDict(result.attributes) ||
parser.parseColonType(type) ||
parser.resolveOperands(rangeInfo, affineIntTy, result.operands) ||
parser.addTypeToList(type, result.types));
auto indexTy = parser.getBuilder().getIndexType();
return failure(parser.parseOperand(rangeInfo[0]) || parser.parseColon() ||
parser.parseOperand(rangeInfo[1]) || parser.parseColon() ||
parser.parseOperand(rangeInfo[2]) ||
parser.parseOptionalAttributeDict(result.attributes) ||
parser.parseColonType(type) ||
parser.resolveOperands(rangeInfo, indexTy, result.operands) ||
parser.addTypeToList(type, result.types));
}

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -538,7 +537,7 @@ static ParseResult parseStoreOp(OpAsmParser &parser, OperationState &result) {
SmallVector<OpAsmParser::OperandType, 4> indexInfo;
ViewType viewType;

auto affineIntTy = parser.getBuilder().getIndexType();
auto indexTy = parser.getBuilder().getIndexType();
return failure(
parser.parseOperand(storeValueInfo) || parser.parseComma() ||
parser.parseOperand(viewInfo) ||
Expand All @@ -548,7 +547,7 @@ static ParseResult parseStoreOp(OpAsmParser &parser, OperationState &result) {
parser.resolveOperand(storeValueInfo, viewType.getElementType(),
result.operands) ||
parser.resolveOperand(viewInfo, viewType, result.operands) ||
parser.resolveOperands(indexInfo, affineIntTy, result.operands));
parser.resolveOperands(indexInfo, indexTy, result.operands));
}

static LogicalResult verify(linalg::StoreOp op) {
Expand Down

0 comments on commit 74eabdd

Please sign in to comment.