Skip to content

Commit

Permalink
Extract MemRefType::getStridesAndOffset as a free function and fix dy…
Browse files Browse the repository at this point in the history
…namic offset determination.

This also adds coverage with a missing test, which uncovered a bug in the conditional for testing whether an offset is dynamic or not.

PiperOrigin-RevId: 272505798
  • Loading branch information
Nicolas Vasilache authored and tensorflower-gardener committed Oct 2, 2019
1 parent f294e0e commit 9604bb6
Show file tree
Hide file tree
Showing 6 changed files with 50 additions and 42 deletions.
48 changes: 25 additions & 23 deletions mlir/include/mlir/IR/StandardTypes.h
Expand Up @@ -367,31 +367,8 @@ class MemRefType
/// Returns the memory space in which data referred to by this memref resides.
unsigned getMemorySpace() const;

/// Returns the strides of the MemRef if the layout map is in strided form.
/// MemRefs with layout maps in strided form include:
/// 1. empty or identity layout map, in which case the stride information is
/// the canonical form computed from sizes;
/// 2. single affine map layout of the form `K + k0 * d0 + ... kn * dn`,
/// where K and ki's are constants or symbols.
///
/// A stride specification is a list of integer values that are either static
/// or dynamic (encoded with kDynamicStrideOrOffset). Strides encode the
/// distance in the number of elements between successive entries along a
/// particular dimension. For example, `memref<42x16xf32, (64 * d0 + d1)>`
/// specifies a view into a non-contiguous memory region of `42` by `16` `f32`
/// elements in which the distance between two consecutive elements along the
/// outer dimension is `1` and the distance between two consecutive elements
/// along the inner dimension is `64`.
///
/// If a simple strided form cannot be extracted from the composition of the
/// layout map, returns llvm::None.
///
/// The convention is that the strides for dimensions d0, .. dn appear in
/// order to make indexing intuitive into the result.
static constexpr int64_t kDynamicStrideOrOffset =
std::numeric_limits<int64_t>::min();
LogicalResult getStridesAndOffset(SmallVectorImpl<int64_t> &strides,
int64_t &offset) const;

static bool kindof(unsigned kind) { return kind == StandardTypes::MemRef; }

Expand Down Expand Up @@ -492,6 +469,31 @@ class NoneType : public Type::TypeBase<NoneType, Type> {
static bool kindof(unsigned kind) { return kind == StandardTypes::None; }
};

/// Returns the strides of the MemRef if the layout map is in strided form.
/// MemRefs with layout maps in strided form include:
/// 1. empty or identity layout map, in which case the stride information is
/// the canonical form computed from sizes;
/// 2. single affine map layout of the form `K + k0 * d0 + ... kn * dn`,
/// where K and ki's are constants or symbols.
///
/// A stride specification is a list of integer values that are either static
/// or dynamic (encoded with kDynamicStrideOrOffset). Strides encode the
/// distance in the number of elements between successive entries along a
/// particular dimension. For example, `memref<42x16xf32, (64 * d0 + d1)>`
/// specifies a view into a non-contiguous memory region of `42` by `16` `f32`
/// elements in which the distance between two consecutive elements along the
/// outer dimension is `1` and the distance between two consecutive elements
/// along the inner dimension is `64`.
///
/// If a simple strided form cannot be extracted from the composition of the
/// layout map, returns llvm::None.
///
/// The convention is that the strides for dimensions d0, .. dn appear in
/// order to make indexing intuitive into the result.
LogicalResult getStridesAndOffset(MemRefType t,
SmallVectorImpl<int64_t> &strides,
int64_t &offset);

/// Given a list of strides (in which MemRefType::kDynamicStrideOrOffset
/// represents a dynamic value), return the single result AffineMap which
/// represents the linearized strided layout map. Dimensions correspond to the
Expand Down
14 changes: 7 additions & 7 deletions mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
Expand Up @@ -152,7 +152,7 @@ static unsigned kStridePosInMemRefDescriptor = 3;
Type LLVMTypeConverter::convertMemRefType(MemRefType type) {
int64_t offset;
SmallVector<int64_t, 4> strides;
bool strideSuccess = succeeded(type.getStridesAndOffset(strides, offset));
bool strideSuccess = succeeded(getStridesAndOffset(type, strides, offset));
assert(strideSuccess &&
"Non-strided layout maps must have been normalized away");
(void)strideSuccess;
Expand Down Expand Up @@ -571,14 +571,14 @@ struct AllocOpLowering : public LLVMLegalizationPattern<AllocOp> {

int64_t offset;
SmallVector<int64_t, 4> strides;
auto successStrides = type.getStridesAndOffset(strides, offset);
auto successStrides = getStridesAndOffset(type, strides, offset);
if (failed(successStrides))
return matchFailure();

// Dynamic strides are ok if they can be deduced from dynamic sizes (which
// is guaranteed when succeeded(successStrides)).
// Dynamic offset however can never be alloc'ed.
if (offset != MemRefType::kDynamicStrideOrOffset)
// is guaranteed when succeeded(successStrides)). Dynamic offset however can
// never be alloc'ed.
if (offset == MemRefType::kDynamicStrideOrOffset)
return matchFailure();

return matchSuccess();
Expand Down Expand Up @@ -652,7 +652,7 @@ struct AllocOpLowering : public LLVMLegalizationPattern<AllocOp> {

int64_t offset;
SmallVector<int64_t, 4> strides;
auto successStrides = type.getStridesAndOffset(strides, offset);
auto successStrides = getStridesAndOffset(type, strides, offset);
assert(succeeded(successStrides) && "unexpected non-strided memref");
(void)successStrides;
assert(offset != MemRefType::kDynamicStrideOrOffset &&
Expand Down Expand Up @@ -952,7 +952,7 @@ struct LoadStoreOpLowering : public LLVMLegalizationPattern<Derived> {
auto ptrType = getMemRefElementPtrType(type, this->lowering);
int64_t offset;
SmallVector<int64_t, 4> strides;
auto successStrides = type.getStridesAndOffset(strides, offset);
auto successStrides = getStridesAndOffset(type, strides, offset);
assert(succeeded(successStrides) && "unexpected non-strided memref");
(void)successStrides;
return getStridedElementPtr(loc, ptrType, memRefDesc, indices, strides,
Expand Down
4 changes: 2 additions & 2 deletions mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
Expand Up @@ -322,7 +322,7 @@ void mlir::linalg::SliceOp::build(Builder *b, OperationState &result,
auto memRefType = base->getType().cast<MemRefType>();
int64_t offset;
SmallVector<int64_t, 4> strides;
auto res = memRefType.getStridesAndOffset(strides, offset);
auto res = getStridesAndOffset(memRefType, strides, offset);
assert(succeeded(res) && strides.size() == indexings.size());
(void)res;

Expand Down Expand Up @@ -466,7 +466,7 @@ void mlir::linalg::TransposeOp::build(Builder *b, OperationState &result,
// Compute permuted strides.
int64_t offset;
SmallVector<int64_t, 4> strides;
auto res = memRefType.getStridesAndOffset(strides, offset);
auto res = getStridesAndOffset(memRefType, strides, offset);
assert(succeeded(res) && strides.size() == static_cast<unsigned>(rank));
(void)res;
auto map = makeStridedLinearLayoutMap(strides, offset, b->getContext());
Expand Down
18 changes: 9 additions & 9 deletions mlir/lib/IR/StandardTypes.cpp
Expand Up @@ -23,7 +23,6 @@
#include "mlir/Support/STLExtras.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Support/raw_ostream.h"

using namespace mlir;
using namespace mlir::detail;
Expand Down Expand Up @@ -544,9 +543,10 @@ static void extractStridesFromTerm(AffineExpr e,
llvm_unreachable("unexpected binary operation");
}

LogicalResult MemRefType::getStridesAndOffset(SmallVectorImpl<int64_t> &strides,
int64_t &offset) const {
auto affineMaps = getAffineMaps();
LogicalResult mlir::getStridesAndOffset(MemRefType t,
SmallVectorImpl<int64_t> &strides,
int64_t &offset) {
auto affineMaps = t.getAffineMaps();
// For now strides are only computed on a single affine map with a single
// result (i.e. the closed subset of linearization maps that are compatible
// with striding semantics).
Expand All @@ -555,22 +555,22 @@ LogicalResult MemRefType::getStridesAndOffset(SmallVectorImpl<int64_t> &strides,
return failure();
AffineExpr stridedExpr;
if (affineMaps.empty() || affineMaps[0].isIdentity()) {
if (getRank() == 0) {
if (t.getRank() == 0) {
// Handle 0-D corner case.
offset = 0;
return success();
}
stridedExpr = makeCanonicalStridedLayoutExpr(getShape(), getContext());
stridedExpr = makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext());
} else if (affineMaps[0].getNumResults() == 1) {
stridedExpr = affineMaps[0].getResult(0);
}
if (!stridedExpr)
return failure();

bool failed = false;
strides = SmallVector<int64_t, 4>(getRank(), 0);
strides = SmallVector<int64_t, 4>(t.getRank(), 0);
bool seenOffset = false;
SmallVector<bool, 4> seen(getRank(), false);
SmallVector<bool, 4> seen(t.getRank(), false);
if (stridedExpr.isa<AffineBinaryOpExpr>()) {
stridedExpr.walk([&](AffineExpr e) {
if (!failed)
Expand Down Expand Up @@ -688,6 +688,6 @@ AffineMap mlir::makeStridedLinearLayoutMap(ArrayRef<int64_t> strides,
bool mlir::isStrided(MemRefType t) {
int64_t offset;
SmallVector<int64_t, 4> stridesAndOffset;
auto res = t.getStridesAndOffset(stridesAndOffset, offset);
auto res = getStridesAndOffset(t, stridesAndOffset, offset);
return succeeded(res);
}
6 changes: 6 additions & 0 deletions mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir
Expand Up @@ -11,3 +11,9 @@ func @address_space(%arg0 : memref<32xf32, (d0) -> (d0), 7>) {
std.return
}

// CHECK-LABEL: func @strided_memref(
func @strided_memref(%ind: index) {
%0 = alloc()[%ind] : memref<32x64xf32, (i, j)[M] -> (32 + M * i + j)>
std.return
}

2 changes: 1 addition & 1 deletion mlir/test/lib/Transforms/TestMemRefStrideCalculation.cpp
Expand Up @@ -37,7 +37,7 @@ void TestMemRefStrideCalculation::runOnFunction() {
auto memrefType = allocOp.getResult()->getType().cast<MemRefType>();
int64_t offset;
SmallVector<int64_t, 4> strides;
if (failed(memrefType.getStridesAndOffset(strides, offset))) {
if (failed(getStridesAndOffset(memrefType, strides, offset))) {
llvm::outs() << "MemRefType " << memrefType << " cannot be converted to "
<< "strided form\n";
return;
Expand Down

0 comments on commit 9604bb6

Please sign in to comment.