Skip to content

Commit

Permalink
[mlir] add permutation utility
Browse files Browse the repository at this point in the history
I found myself typing this code several times at different places
by now, so time to make this a general utility instead. Given
a permutation, it returns the permuted position of the input,
for example (i,j,k) -> (k,i,j) yields position 1 for input 0.

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D108347
  • Loading branch information
aartbik committed Aug 24, 2021
1 parent 194b080 commit a643bd3
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 0 deletions.
4 changes: 4 additions & 0 deletions mlir/include/mlir/IR/AffineMap.h
Expand Up @@ -162,6 +162,10 @@ class AffineMap {
/// when the caller knows it is safe to do so.
unsigned getDimPosition(unsigned idx) const;

/// Extracts the permuted position where given input index resides.
/// Fails when called on a non-permutation.
unsigned getPermutedPosition(unsigned input) const;

/// Return true if any affine expression involves AffineDimExpr `position`.
bool isFunctionOfDim(unsigned position) const {
return llvm::any_of(getResults(), [&](AffineExpr e) {
Expand Down
8 changes: 8 additions & 0 deletions mlir/lib/IR/AffineMap.cpp
Expand Up @@ -336,6 +336,14 @@ unsigned AffineMap::getDimPosition(unsigned idx) const {
return getResult(idx).cast<AffineDimExpr>().getPosition();
}

unsigned AffineMap::getPermutedPosition(unsigned input) const {
assert(isPermutation() && "invalid permutation request");
for (unsigned i = 0, numResults = getNumResults(); i < numResults; i++)
if (getDimPosition(i) == input)
return i;
llvm_unreachable("incorrect permutation request");
}

/// Folds the results of the application of an affine map on the provided
/// operands to a constant if possible. Returns false if the folding happens,
/// true otherwise.
Expand Down

0 comments on commit a643bd3

Please sign in to comment.