From a643bd3189aeee138187509a9bfec2f798798d76 Mon Sep 17 00:00:00 2001 From: Aart Bik Date: Mon, 23 Aug 2021 19:00:38 -0700 Subject: [PATCH] [mlir] add permutation utility I found myself typing this code several times at different places by now, so time to make this a general utility instead. Given a permutation, it returns the permuted position of the input, for example (i,j,k) -> (k,i,j) yields position 1 for input 0. Reviewed By: ftynse Differential Revision: https://reviews.llvm.org/D108347 --- mlir/include/mlir/IR/AffineMap.h | 4 ++++ mlir/lib/IR/AffineMap.cpp | 8 ++++++++ 2 files changed, 12 insertions(+) diff --git a/mlir/include/mlir/IR/AffineMap.h b/mlir/include/mlir/IR/AffineMap.h index f687253b36feb..906c53db4b32d 100644 --- a/mlir/include/mlir/IR/AffineMap.h +++ b/mlir/include/mlir/IR/AffineMap.h @@ -162,6 +162,10 @@ class AffineMap { /// when the caller knows it is safe to do so. unsigned getDimPosition(unsigned idx) const; + /// Extracts the permuted position where given input index resides. + /// Fails when called on a non-permutation. + unsigned getPermutedPosition(unsigned input) const; + /// Return true if any affine expression involves AffineDimExpr `position`. bool isFunctionOfDim(unsigned position) const { return llvm::any_of(getResults(), [&](AffineExpr e) { diff --git a/mlir/lib/IR/AffineMap.cpp b/mlir/lib/IR/AffineMap.cpp index 2257617bd903e..9c6f25d3c53e0 100644 --- a/mlir/lib/IR/AffineMap.cpp +++ b/mlir/lib/IR/AffineMap.cpp @@ -336,6 +336,14 @@ unsigned AffineMap::getDimPosition(unsigned idx) const { return getResult(idx).cast().getPosition(); } +unsigned AffineMap::getPermutedPosition(unsigned input) const { + assert(isPermutation() && "invalid permutation request"); + for (unsigned i = 0, numResults = getNumResults(); i < numResults; i++) + if (getDimPosition(i) == input) + return i; + llvm_unreachable("incorrect permutation request"); +} + /// Folds the results of the application of an affine map on the provided /// operands to a constant if possible. Returns false if the folding happens, /// true otherwise.