Skip to content

Commit

Permalink
[mlir][Linalg] Revisit 0-D abstraction
Browse files Browse the repository at this point in the history
This revision takes advantage of the empty AffineMap to specify the
0-D edge case. This allows removing a bunch of annoying corner cases
that ended up impacting users of Linalg.

Differential Revision: https://reviews.llvm.org/D75831
  • Loading branch information
Nicolas Vasilache committed Mar 10, 2020
1 parent 4a0267e commit 47ec870
Show file tree
Hide file tree
Showing 11 changed files with 56 additions and 77 deletions.
3 changes: 2 additions & 1 deletion mlir/docs/Dialects/Affine.md
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,8 @@ affine-expr ::= `(` affine-expr `)`
| bare-id
| `-`? integer-literal
multi-dim-affine-expr ::= `(` affine-expr (`,` affine-expr)* `)`
multi-dim-affine-expr ::= `(` `)`
| `(` affine-expr (`,` affine-expr)* `)`
```

`ceildiv` is the ceiling function which maps the result of the division of its
Expand Down
4 changes: 3 additions & 1 deletion mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,9 @@ def DotOp : LinalgStructured_Op<"dot", [NInputs<2>, NOutputs<1>]> {
MLIRContext *context = getContext();
auto r_i = getAffineDimExpr(0, context);
return SmallVector<AffineMap, 8>{
AffineMap::get(1, 0, {r_i}), AffineMap::get(1, 0, {r_i}), AffineMap()};
AffineMap::get(1, 0, {r_i}),
AffineMap::get(1, 0, {r_i}),
AffineMap::get(1, 0, context)};
}
}];

Expand Down
8 changes: 6 additions & 2 deletions mlir/include/mlir/IR/AffineMap.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,11 @@ class AffineMap {
/// Returns a zero result affine map with no dimensions or symbols: () -> ().
static AffineMap get(MLIRContext *context);

/// Returns a zero result affine map with `dimCount` dimensions and
/// `symbolCount` symbols, e.g.: `(...) -> ()`.
static AffineMap get(unsigned dimCount, unsigned symbolCount,
MLIRContext *context);

static AffineMap get(unsigned dimCount, unsigned symbolCount,
ArrayRef<AffineExpr> results);

Expand Down Expand Up @@ -275,8 +280,7 @@ inline raw_ostream &operator<<(raw_ostream &os, AffineMap map) {
namespace llvm {

// AffineExpr hash just like pointers
template <>
struct DenseMapInfo<mlir::AffineMap> {
template <> struct DenseMapInfo<mlir::AffineMap> {
static mlir::AffineMap getEmptyKey() {
auto pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
return mlir::AffineMap(static_cast<mlir::AffineMap::ImplType *>(pointer));
Expand Down
10 changes: 2 additions & 8 deletions mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -356,15 +356,9 @@ static LogicalResult verifyGenericOp(GenericOpType op) {
<< idx << " to have " << nLoops
<< " dim(s) to match the number of loops";

if (m.getNumResults() == 1 && view.getRank() == 0) {
auto cst = m.getResult(0).template dyn_cast<AffineConstantExpr>();
if (!cst || cst.getValue() != 0)
return op.emitOpError("expected indexing_map #")
<< idx << " to be 0 to match 0-D view: " << view;
} else if (m.getNumResults() != view.getRank()) {
if (m.getNumResults() != view.getRank())
return op.emitOpError("expected indexing_map #")
<< idx << " results to match view rank: " << view;
}
}

auto concatMap = concatAffineMaps(indexingMaps);
Expand Down Expand Up @@ -886,7 +880,7 @@ AffineMap mlir::linalg::extractOrIdentityMap(Optional<AffineMap> maybeMap,
if (maybeMap)
return maybeMap.getValue();
if (rank == 0)
return AffineMap();
return AffineMap::get(context);
return AffineMap::getMultiDimIdentityMap(rank, context);
}

Expand Down
46 changes: 15 additions & 31 deletions mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ using edsc::op::operator==;
static SmallVector<ValueHandle, 8>
makeCanonicalAffineApplies(OpBuilder &b, Location loc, AffineMap map,
ArrayRef<Value> vals) {
if (map.isEmpty())
return {};
assert(map.getNumSymbols() == 0);
assert(map.getNumInputs() == vals.size());
SmallVector<ValueHandle, 8> res;
Expand Down Expand Up @@ -241,26 +243,17 @@ class LinalgScopedEmitter<IndexedValueType, GenericOp> {

// 1.a. Emit std_load from input views.
for (unsigned i = 0; i < nInputs; ++i) {
Value input = genericOp.getInput(i);
if (input.getType().cast<ShapedType>().getRank()) {
ValueHandleArray indexing(makeCanonicalAffineApplies(
b, loc, genericOp.getInputIndexingMap(i), allIvs));
indexedValues[i] = std_load(input, indexing);
} else {
indexedValues[i] = std_load(input);
}
ValueHandleArray indexing(makeCanonicalAffineApplies(
b, loc, genericOp.getInputIndexingMap(i), allIvs));
indexedValues[i] = std_load(genericOp.getInput(i), indexing);
}

// 1.b. Emit std_load from output views.
for (unsigned i = 0; i < nOutputs; ++i) {
Value output = genericOp.getOutputBuffer(i);
if (output.getType().cast<ShapedType>().getRank()) {
ValueHandleArray indexing(makeCanonicalAffineApplies(
b, loc, genericOp.getOutputIndexingMap(i), allIvs));
indexedValues[nInputs + i] = std_load(output, indexing);
} else {
indexedValues[nInputs + i] = std_load(output);
}
ValueHandleArray indexing(makeCanonicalAffineApplies(
b, loc, genericOp.getOutputIndexingMap(i), allIvs));
indexedValues[nInputs + i] = std_load(output, indexing);
}

auto funcOp = genericOp.getFunction();
Expand All @@ -272,13 +265,9 @@ class LinalgScopedEmitter<IndexedValueType, GenericOp> {
// 3. Emit std_store.
for (unsigned i = 0; i < nOutputs; ++i) {
Value output = genericOp.getOutputBuffer(i);
if (output.getType().cast<ShapedType>().getRank()) {
ValueHandleArray indexing(makeCanonicalAffineApplies(
b, loc, genericOp.getOutputIndexingMap(i), allIvs));
std_store(callOp->getResult(i), output, indexing);
} else {
std_store(callOp->getResult(i), output);
}
ValueHandleArray indexing(makeCanonicalAffineApplies(
b, loc, genericOp.getOutputIndexingMap(i), allIvs));
std_store(callOp->getResult(i), output, indexing);
}
return;
}
Expand All @@ -297,15 +286,10 @@ class LinalgScopedEmitter<IndexedValueType, GenericOp> {
auto *yieldOp = cast<YieldOp>(block.back()).getOperation();
assert(yieldOp->getNumOperands() == nOutputs);
for (unsigned i = 0; i < nOutputs; ++i) {
Value output = genericOp.getOutputBuffer(i);
if (output.getType().cast<ShapedType>().getRank()) {
ValueHandleArray indexing(makeCanonicalAffineApplies(
b, loc, genericOp.getOutputIndexingMap(i), allIvs));
std_store(map.lookup(yieldOp->getOperand(i)),
genericOp.getOutputBuffer(i), indexing);
} else {
std_store(map.lookup(yieldOp->getOperand(i)), output);
}
ValueHandleArray indexing(makeCanonicalAffineApplies(
b, loc, genericOp.getOutputIndexingMap(i), allIvs));
std_store(map.lookup(yieldOp->getOperand(i)),
genericOp.getOutputBuffer(i), indexing);
}
}
};
Expand Down
13 changes: 7 additions & 6 deletions mlir/lib/IR/AffineMap.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,8 @@ AffineMap AffineMap::compose(AffineMap map) {
exprs.reserve(getResults().size());
for (auto expr : getResults())
exprs.push_back(expr.compose(newMap));
return AffineMap::get(numDims, numSymbols, exprs);
return exprs.empty() ? AffineMap::get(numDims, 0, map.getContext())
: AffineMap::get(numDims, numSymbols, exprs);
}

bool AffineMap::isProjectedPermutation() {
Expand Down Expand Up @@ -325,7 +326,7 @@ AffineMap mlir::simplifyAffineMap(AffineMap map) {
}

AffineMap mlir::inversePermutation(AffineMap map) {
if (!map)
if (map.isEmpty())
return map;
assert(map.getNumSymbols() == 0 && "expected map without symbols");
SmallVector<AffineExpr, 4> exprs(map.getNumDims());
Expand All @@ -351,18 +352,18 @@ AffineMap mlir::inversePermutation(AffineMap map) {
AffineMap mlir::concatAffineMaps(ArrayRef<AffineMap> maps) {
unsigned numResults = 0;
for (auto m : maps)
numResults += (m && !m.isSingleConstant()) ? m.getNumResults() : 0;
numResults += m.getNumResults();
unsigned numDims = 0;
SmallVector<AffineExpr, 8> results;
results.reserve(numResults);
for (auto m : maps) {
if (!m || m.isSingleConstant())
continue;
assert(m.getNumSymbols() == 0 && "expected map without symbols");
results.append(m.getResults().begin(), m.getResults().end());
numDims = std::max(m.getNumDims(), numDims);
}
return numDims == 0 ? AffineMap() : AffineMap::get(numDims, 0, results);
return results.empty() ? AffineMap::get(numDims, /*numSymbols=*/0,
maps.front().getContext())
: AffineMap::get(numDims, /*numSymbols=*/0, results);
}

//===----------------------------------------------------------------------===//
Expand Down
5 changes: 5 additions & 0 deletions mlir/lib/IR/MLIRContext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -611,6 +611,11 @@ AffineMap AffineMap::get(MLIRContext *context) {
return getImpl(/*dimCount=*/0, /*symbolCount=*/0, /*results=*/{}, context);
}

AffineMap AffineMap::get(unsigned dimCount, unsigned symbolCount,
MLIRContext *context) {
return getImpl(dimCount, /*symbolCount=*/0, /*results=*/{}, context);
}

AffineMap AffineMap::get(unsigned dimCount, unsigned symbolCount,
ArrayRef<AffineExpr> results) {
// The number of results can't be zero.
Expand Down
15 changes: 9 additions & 6 deletions mlir/lib/Parser/Parser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3068,14 +3068,16 @@ AffineParser::parseAffineMapOfSSAIds(AffineMap &map,
};

// Parse a multi-dimensional affine expression (a comma-separated list of
// 1-d affine expressions); the list cannot be empty. Grammar:
// multi-dim-affine-expr ::= `(` affine-expr (`,` affine-expr)* `)
// 1-d affine expressions); the list can be empty. Grammar:
// multi-dim-affine-expr ::= `(` `)`
// | `(` affine-expr (`,` affine-expr)* `)`
if (parseCommaSeparatedListUntil(rightToken, parseElt,
/*allowEmptyList=*/true))
return failure();
// Parsed a valid affine map.
if (exprs.empty())
map = AffineMap::get(getContext());
map = AffineMap::get(numDimOperands, dimsAndSymbols.size() - numDimOperands,
getContext());
else
map = AffineMap::get(numDimOperands, dimsAndSymbols.size() - numDimOperands,
exprs);
Expand All @@ -3101,13 +3103,14 @@ AffineMap AffineParser::parseAffineMapRange(unsigned numDims,
};

// Parse a multi-dimensional affine expression (a comma-separated list of
// 1-d affine expressions); the list cannot be empty. Grammar:
// multi-dim-affine-expr ::= `(` affine-expr (`,` affine-expr)* `)
// 1-d affine expressions). Grammar:
// multi-dim-affine-expr ::= `(` `)`
// | `(` affine-expr (`,` affine-expr)* `)`
if (parseCommaSeparatedListUntil(Token::r_paren, parseElt, true))
return AffineMap();

if (exprs.empty())
return AffineMap::get(getContext());
return AffineMap::get(numDims, numSymbols, getContext());

// Parsed a valid affine map.
return AffineMap::get(numDims, numSymbols, exprs);
Expand Down
19 changes: 2 additions & 17 deletions mlir/test/Dialect/Linalg/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -170,30 +170,15 @@ func @generic_symbol_in_map(%arg0: memref<i32>) {

func @foo(%0: i32) -> i32 { return %0: i32 }

func @generic_wrong_dim_in_map(%arg0: memref<i32>) {
func @generic_wrong_dim_in_map(%arg0: memref<1xi32>) {
// expected-error @+1 {{op expected indexing_map #0 to have 1 dim(s) to match the number of loops}}
linalg.generic {
args_in = 0,
args_out = 1,
fun = @foo,
indexing_maps = [ affine_map<() -> (0)> ],
iterator_types = ["parallel"]
} %arg0: memref<i32>
}

// -----

func @foo(%0: i32) -> i32 { return %0: i32 }

func @generic_zero_d_view(%arg0: memref<i32>) {
// expected-error @+1 {{op expected indexing_map #0 to be 0 to match 0-D view: 'memref<i32>'}}
linalg.generic {
args_in = 0,
args_out = 1,
fun = @foo,
indexing_maps = [ affine_map<() -> (1)> ],
iterator_types = []
} %arg0: memref<i32>
} %arg0: memref<1xi32>
}

// -----
Expand Down
8 changes: 4 additions & 4 deletions mlir/test/Dialect/Linalg/loops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -360,7 +360,7 @@ func @indexed_generic_region(
// -----

#broadcast_access = [
affine_map<(i, j) -> (0)>,
affine_map<(i, j) -> ()>,
affine_map<(i, j) -> (i, j)>
]

Expand Down Expand Up @@ -414,7 +414,7 @@ func @indexed_generic_op_zero_rank(%arg0: memref<i32>, %arg1: memref<3x4xi32>)

#reduce_1D_access = [
affine_map<(i) -> (i)>,
affine_map<(i) -> (0)>
affine_map<(i) -> ()>
]

#trait_reduce_1D = {
Expand Down Expand Up @@ -446,8 +446,8 @@ func @generic_op_1D_reduce(%arg0: memref<?xf32>, %arg1: memref<f32>)

#reduce_init_1D_access = [
affine_map<(i) -> (i)>,
affine_map<(i) -> (0)>,
affine_map<(i) -> (0)>
affine_map<(i) -> ()>,
affine_map<(i) -> ()>
]

#trait_reduce_init_1D = {
Expand Down
2 changes: 1 addition & 1 deletion mlir/test/Dialect/Linalg/roundtrip.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,7 @@ func @indexed_generic_with_tensor_input_and_output(
// -----

#broadcast_access = [
affine_map<(i, j) -> (0)>,
affine_map<(i, j) -> ()>,
affine_map<(i, j) -> (i, j)>
]

Expand Down

0 comments on commit 47ec870

Please sign in to comment.