[mlir] make remaining memref dialect ops produce strided layouts
The following three ops in the memref dialect, transpose, expand_shape and
collapse_shape, were originally designed to operate on memrefs with strided
layouts, but they had to go through the affine map representation because the
type did not support anything else. Make these ops produce memref values with
StridedLayoutAttr instead, now that it is available.
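For illustration, a minimal sketch (hypothetical, not part of the patch) of the API-level difference; constant names follow the MLIR C++ API of this period:

// Hypothetical sketch: build memref<?x?xf32, strided<[?, 1], offset: ?>>.
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"

using namespace mlir;

MemRefType makeDynamicStridedType(MLIRContext *ctx) {
  int64_t dynSize = ShapedType::kDynamicSize;
  int64_t dynSO = ShapedType::kDynamicStrideOrOffset;
  // New form: the strides and offset are first-class attribute payload.
  auto layout = StridedLayoutAttr::get(ctx, /*offset=*/dynSO,
                                       /*strides=*/{dynSO, 1});
  // The old form encoded the same layout as an affine map, e.g.
  // makeStridedLinearLayoutMap({dynSO, 1}, dynSO, ctx).
  return MemRefType::get({dynSize, dynSize}, Float32Type::get(ctx), layout);
}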

Depends On D133938

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D133947
ftynse committed Sep 16, 2022
1 parent 2791162 commit 46b90a7
Showing 13 changed files with 98 additions and 143 deletions.
5 changes: 0 additions & 5 deletions mlir/include/mlir/IR/BuiltinTypes.h
@@ -466,11 +466,6 @@ bool isStrided(MemRefType t);
 /// Return null if the layout is not compatible with a strided layout.
 AffineMap getStridedLinearLayoutMap(MemRefType t);
 
-/// Helper determining if a memref is static-shape and contiguous-row-major
-/// layout, while still allowing for an arbitrary offset (any static or
-/// dynamic value).
-bool isStaticShapeAndContiguousRowMajor(MemRefType memrefType);
-
 } // namespace mlir
 
 #endif // MLIR_IR_BUILTINTYPES_H
20 changes: 19 additions & 1 deletion mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -961,7 +961,25 @@ struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
     auto srcType = op.getSource().getType().cast<BaseMemRefType>();
     auto targetType = op.getTarget().getType().cast<BaseMemRefType>();
 
-    auto isContiguousMemrefType = [](BaseMemRefType type) {
+    auto isStaticShapeAndContiguousRowMajor = [](MemRefType type) {
+      if (!type.hasStaticShape())
+        return false;
+
+      SmallVector<int64_t> strides;
+      int64_t offset;
+      if (failed(getStridesAndOffset(type, strides, offset)))
+        return false;
+
+      int64_t runningStride = 1;
+      for (unsigned i = strides.size(); i > 0; --i) {
+        if (strides[i - 1] != runningStride)
+          return false;
+        runningStride *= type.getDimSize(i - 1);
+      }
+      return true;
+    };
+
+    auto isContiguousMemrefType = [&](BaseMemRefType type) {
       auto memrefType = type.dyn_cast<mlir::MemRefType>();
       // We can use memcpy for memrefs if they have an identity layout or are
       // contiguous with an arbitrary offset. Ignore empty memrefs, which is a
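As a standalone illustration (plain C++, independent of MLIR; the function name is hypothetical), here is the row-major contiguity test the lambda above performs: the innermost stride must be 1, each outer stride must equal the product of all inner dimension sizes, and the offset is deliberately never inspected:

#include <cassert>
#include <cstdint>
#include <vector>

// Mirrors the check above on plain size/stride vectors.
static bool isContiguousRowMajor(const std::vector<int64_t> &shape,
                                 const std::vector<int64_t> &strides) {
  int64_t runningStride = 1;
  for (size_t i = strides.size(); i > 0; --i) {
    if (strides[i - 1] != runningStride)
      return false;
    runningStride *= shape[i - 1];
  }
  return true;
}

int main() {
  assert(isContiguousRowMajor({16, 2}, {2, 1}));  // memref<16x2xi32>: memcpy ok
  assert(!isContiguousRowMajor({16, 2}, {4, 1})); // padded rows: not memcpy-able
  return 0;
}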
75 changes: 34 additions & 41 deletions mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -1761,7 +1761,7 @@ SmallVector<ReassociationExprs, 4> ExpandShapeOp::getReassociationExprs() {
 
 /// Compute the layout map after expanding a given source MemRef type with the
 /// specified reassociation indices.
-static FailureOr<AffineMap>
+static FailureOr<StridedLayoutAttr>
 computeExpandedLayoutMap(MemRefType srcType, ArrayRef<int64_t> resultShape,
                          ArrayRef<ReassociationIndices> reassociation) {
   int64_t srcOffset;
@@ -1798,8 +1798,7 @@ computeExpandedLayoutMap(MemRefType srcType, ArrayRef<int64_t> resultShape,
   }
   auto resultStrides = llvm::to_vector<8>(llvm::reverse(reverseResultStrides));
   resultStrides.resize(resultShape.size(), 1);
-  return makeStridedLinearLayoutMap(resultStrides, srcOffset,
-                                    srcType.getContext());
+  return StridedLayoutAttr::get(srcType.getContext(), srcOffset, resultStrides);
 }
 
 static FailureOr<MemRefType>
@@ -1814,14 +1813,12 @@ computeExpandedType(MemRefType srcType, ArrayRef<int64_t> resultShape,
   }
 
   // Source may not be contiguous. Compute the layout map.
-  FailureOr<AffineMap> computedLayout =
+  FailureOr<StridedLayoutAttr> computedLayout =
       computeExpandedLayoutMap(srcType, resultShape, reassociation);
   if (failed(computedLayout))
     return failure();
-  auto computedType =
-      MemRefType::get(resultShape, srcType.getElementType(), *computedLayout,
-                      srcType.getMemorySpaceAsInt());
-  return canonicalizeStridedLayout(computedType);
+  return MemRefType::get(resultShape, srcType.getElementType(), *computedLayout,
+                         srcType.getMemorySpace());
 }
 
 void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
@@ -1855,10 +1852,9 @@ LogicalResult ExpandShapeOp::verify() {
     return emitOpError("invalid source layout map");
 
   // Check actual result type.
-  auto canonicalizedResultType = canonicalizeStridedLayout(resultType);
-  if (*expectedResultType != canonicalizedResultType)
+  if (*expectedResultType != resultType)
     return emitOpError("expected expanded type to be ")
-           << *expectedResultType << " but found " << canonicalizedResultType;
+           << *expectedResultType << " but found " << resultType;
 
   return success();
 }
@@ -1877,7 +1873,7 @@ void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
 /// not possible to check this by inspecting a MemRefType in the general case.
 /// If non-contiguity cannot be checked statically, the collapse is assumed to
 /// be valid (and thus accepted by this function) unless `strict = true`.
-static FailureOr<AffineMap>
+static FailureOr<StridedLayoutAttr>
 computeCollapsedLayoutMap(MemRefType srcType,
                           ArrayRef<ReassociationIndices> reassociation,
                           bool strict = false) {
@@ -1940,13 +1936,12 @@ computeCollapsedLayoutMap(MemRefType srcType,
       return failure();
     }
   }
-  return makeStridedLinearLayoutMap(resultStrides, srcOffset,
-                                    srcType.getContext());
+  return StridedLayoutAttr::get(srcType.getContext(), srcOffset, resultStrides);
 }
 
 bool CollapseShapeOp::isGuaranteedCollapsible(
     MemRefType srcType, ArrayRef<ReassociationIndices> reassociation) {
-  // MemRefs with standard layout are always collapsible.
+  // MemRefs with identity layout are always collapsible.
   if (srcType.getLayout().isIdentity())
     return true;
 
@@ -1978,14 +1973,12 @@ computeCollapsedType(MemRefType srcType,
   // Source may not be fully contiguous. Compute the layout map.
   // Note: Dimensions that are collapsed into a single dim are assumed to be
   // contiguous.
-  FailureOr<AffineMap> computedLayout =
+  FailureOr<StridedLayoutAttr> computedLayout =
      computeCollapsedLayoutMap(srcType, reassociation);
   assert(succeeded(computedLayout) &&
          "invalid source layout map or collapsing non-contiguous dims");
-  auto computedType =
-      MemRefType::get(resultShape, srcType.getElementType(), *computedLayout,
-                      srcType.getMemorySpaceAsInt());
-  return canonicalizeStridedLayout(computedType);
+  return MemRefType::get(resultShape, srcType.getElementType(), *computedLayout,
+                         srcType.getMemorySpace());
 }
 
 void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src,
@@ -2021,21 +2014,19 @@ LogicalResult CollapseShapeOp::verify() {
     // Source may not be fully contiguous. Compute the layout map.
     // Note: Dimensions that are collapsed into a single dim are assumed to be
    // contiguous.
-    FailureOr<AffineMap> computedLayout =
+    FailureOr<StridedLayoutAttr> computedLayout =
         computeCollapsedLayoutMap(srcType, getReassociationIndices());
     if (failed(computedLayout))
       return emitOpError(
           "invalid source layout map or collapsing non-contiguous dims");
-    auto computedType =
+    expectedResultType =
         MemRefType::get(resultType.getShape(), srcType.getElementType(),
-                        *computedLayout, srcType.getMemorySpaceAsInt());
-    expectedResultType = canonicalizeStridedLayout(computedType);
+                        *computedLayout, srcType.getMemorySpace());
   }
 
-  auto canonicalizedResultType = canonicalizeStridedLayout(resultType);
-  if (expectedResultType != canonicalizedResultType)
+  if (expectedResultType != resultType)
     return emitOpError("expected collapsed type to be ")
-           << expectedResultType << " but found " << canonicalizedResultType;
+           << expectedResultType << " but found " << resultType;
 
   return success();
 }
@@ -2709,24 +2700,26 @@ static MemRefType inferTransposeResultType(MemRefType memRefType,
                                            AffineMap permutationMap) {
   auto rank = memRefType.getRank();
   auto originalSizes = memRefType.getShape();
-  // Compute permuted sizes.
-  SmallVector<int64_t, 4> sizes(rank, 0);
-  for (const auto &en : llvm::enumerate(permutationMap.getResults()))
-    sizes[en.index()] =
-        originalSizes[en.value().cast<AffineDimExpr>().getPosition()];
-
-  // Compute permuted strides.
   int64_t offset;
-  SmallVector<int64_t, 4> strides;
-  auto res = getStridesAndOffset(memRefType, strides, offset);
-  assert(succeeded(res) && strides.size() == static_cast<unsigned>(rank));
+  SmallVector<int64_t, 4> originalStrides;
+  auto res = getStridesAndOffset(memRefType, originalStrides, offset);
+  assert(succeeded(res) &&
+         originalStrides.size() == static_cast<unsigned>(rank));
   (void)res;
-  auto map =
-      makeStridedLinearLayoutMap(strides, offset, memRefType.getContext());
-  map = permutationMap ? map.compose(permutationMap) : map;
+
+  // Compute permuted sizes and strides.
+  SmallVector<int64_t> sizes(rank, 0);
+  SmallVector<int64_t> strides(rank, 1);
+  for (const auto &en : llvm::enumerate(permutationMap.getResults())) {
+    unsigned position = en.value().cast<AffineDimExpr>().getPosition();
+    sizes[en.index()] = originalSizes[position];
+    strides[en.index()] = originalStrides[position];
+  }
 
   return MemRefType::Builder(memRefType)
       .setShape(sizes)
-      .setLayout(AffineMapAttr::get(map));
+      .setLayout(
+          StridedLayoutAttr::get(memRefType.getContext(), offset, strides));
 }
 
 void TransposeOp::build(OpBuilder &b, OperationState &result, Value in,
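For intuition, a standalone sketch (hypothetical, not from the patch) of what the rewritten inferTransposeResultType computes: the result's sizes and strides are the source's, permuted, while the offset is untouched; no affine-map composition is needed anymore:

#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  // Source memref<2x4x8xf32> with row-major strides [32, 8, 1].
  std::vector<int64_t> sizes = {2, 4, 8};
  std::vector<int64_t> strides = {32, 8, 1};
  std::vector<int> perm = {2, 0, 1}; // permutation (i, j, k) -> (k, i, j)

  std::vector<int64_t> newSizes(3), newStrides(3);
  for (int i = 0; i < 3; ++i) {
    newSizes[i] = sizes[perm[i]];
    newStrides[i] = strides[perm[i]];
  }
  // Result: memref<8x2x4xf32, strided<[1, 32, 8]>>.
  std::printf("strides = [%lld, %lld, %lld]\n", (long long)newStrides[0],
              (long long)newStrides[1], (long long)newStrides[2]);
  return 0;
}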
@@ -136,11 +136,10 @@ struct CollapseShapeOpInterface
       int64_t offset;
       if (failed(getStridesAndOffset(bufferType, strides, offset)))
         return failure();
-      AffineMap resultLayout =
-          makeStridedLinearLayoutMap({}, offset, op->getContext());
-      resultType =
-          MemRefType::get({}, tensorResultType.getElementType(), resultLayout,
-                          bufferType.getMemorySpaceAsInt());
+      resultType = MemRefType::get(
+          {}, tensorResultType.getElementType(),
+          StridedLayoutAttr::get(op->getContext(), offset, {}),
+          bufferType.getMemorySpace());
     }
 
     replaceOpWithNewBufferizedOp<memref::CollapseShapeOp>(
3 changes: 2 additions & 1 deletion mlir/lib/IR/AsmPrinter.cpp
@@ -2250,7 +2250,8 @@ void AsmPrinter::Impl::printType(Type type) {
       os << 'x';
     }
     printType(memrefTy.getElementType());
-    if (!memrefTy.getLayout().isIdentity()) {
+    MemRefLayoutAttrInterface layout = memrefTy.getLayout();
+    if (!layout.isa<AffineMapAttr>() || !layout.isIdentity()) {
       os << ", ";
       printAttribute(memrefTy.getLayout(), AttrTypeElision::May);
     }
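The intended effect of this printer change, sketched as comments (assumed behavior, derived from the condition above rather than quoted from the patch):

// An identity AffineMapAttr layout is still elided:
//   memref<4x8xf32>
// A StridedLayoutAttr is always printed, even when it happens to describe
// the identity row-major layout:
//   memref<4x8xf32, strided<[8, 1]>>
// A non-identity affine map is printed as before:
//   memref<4x8xf32, affine_map<(d0, d1) -> (d1, d0)>>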
37 changes: 0 additions & 37 deletions mlir/lib/IR/BuiltinTypes.cpp
@@ -1027,40 +1027,3 @@ AffineMap mlir::getStridedLinearLayoutMap(MemRefType t) {
     return AffineMap();
   return makeStridedLinearLayoutMap(strides, offset, t.getContext());
 }
-
-/// Return the AffineExpr representation of the offset, assuming `memRefType`
-/// is a strided memref.
-static AffineExpr getOffsetExpr(MemRefType memrefType) {
-  SmallVector<AffineExpr> strides;
-  AffineExpr offset;
-  if (failed(getStridesAndOffset(memrefType, strides, offset)))
-    assert(false && "expected strided memref");
-  return offset;
-}
-
-/// Helper to construct a contiguous MemRefType of `shape`, `elementType` and
-/// `offset` AffineExpr.
-static MemRefType makeContiguousRowMajorMemRefType(MLIRContext *context,
-                                                   ArrayRef<int64_t> shape,
-                                                   Type elementType,
-                                                   AffineExpr offset) {
-  AffineExpr canonical = makeCanonicalStridedLayoutExpr(shape, context);
-  AffineExpr contiguousRowMajor = canonical + offset;
-  AffineMap contiguousRowMajorMap =
-      AffineMap::inferFromExprList({contiguousRowMajor})[0];
-  return MemRefType::get(shape, elementType, contiguousRowMajorMap);
-}
-
-/// Helper determining if a memref is static-shape and contiguous-row-major
-/// layout, while still allowing for an arbitrary offset (any static or
-/// dynamic value).
-bool mlir::isStaticShapeAndContiguousRowMajor(MemRefType memrefType) {
-  if (!memrefType.hasStaticShape())
-    return false;
-  AffineExpr offset = getOffsetExpr(memrefType);
-  MemRefType contiguousRowMajorMemRefType = makeContiguousRowMajorMemRefType(
-      memrefType.getContext(), memrefType.getShape(),
-      memrefType.getElementType(), offset);
-  return canonicalizeStridedLayout(memrefType) ==
-         canonicalizeStridedLayout(contiguousRowMajorMemRefType);
-}
36 changes: 17 additions & 19 deletions mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir
@@ -609,7 +609,7 @@ func.func @address_space(%arg0 : memref<32xf32, affine_map<(d0) -> (d0)>, 7>) {
 // CHECK: llvm.extractvalue {{.*}}[3, 2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
 // CHECK: llvm.insertvalue {{.*}}[3, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
 func.func @transpose(%arg0: memref<?x?x?xf32, strided<[?, ?, 1], offset: ?>>) {
-  %0 = memref.transpose %arg0 (i, j, k) -> (k, i, j) : memref<?x?x?xf32, strided<[?, ?, 1], offset: ?>> to memref<?x?x?xf32, affine_map<(d0, d1, d2)[s0, s1, s2] -> (d2 * s1 + s0 + d0 * s2 + d1)>>
+  %0 = memref.transpose %arg0 (i, j, k) -> (k, i, j) : memref<?x?x?xf32, strided<[?, ?, 1], offset: ?>> to memref<?x?x?xf32, strided<[1, ?, ?], offset: ?>>
   return
 }
 
@@ -725,12 +725,12 @@ func.func @collapse_shape_static(%arg0: memref<1x3x4x1x5xf32>) -> memref<3x4x5xf
 // -----
 
 func.func @collapse_shape_dynamic_with_non_identity_layout(
-        %arg0 : memref<4x?x?xf32, affine_map<(d0, d1, d2)[s0, s1] -> (d0 * s1 + s0 + d1 * 4 + d2)>>) ->
-        memref<4x?xf32, affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>> {
+        %arg0 : memref<4x?x?xf32, strided<[?, 4, 1], offset: ?>>) ->
+        memref<4x?xf32, strided<[?, ?], offset: ?>> {
   %0 = memref.collapse_shape %arg0 [[0], [1, 2]]:
-    memref<4x?x?xf32, affine_map<(d0, d1, d2)[s0, s1] -> (d0 * s1 + s0 + d1 * 4 + d2)>> into
-    memref<4x?xf32, affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>>
-  return %0 : memref<4x?xf32, affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>>
+    memref<4x?x?xf32, strided<[?, 4, 1], offset: ?>> into
+    memref<4x?xf32, strided<[?, ?], offset: ?>>
+  return %0 : memref<4x?xf32, strided<[?, ?], offset: ?>>
 }
 // CHECK-LABEL: func @collapse_shape_dynamic_with_non_identity_layout(
 // CHECK: llvm.mlir.undef : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
@@ -898,12 +898,12 @@ func.func @expand_shape_dynamic(%arg0 : memref<1x?xf32>) -> memref<1x2x?xf32> {
 // -----
 
 func.func @expand_shape_dynamic_with_non_identity_layout(
-        %arg0 : memref<1x?xf32, affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>>) ->
-        memref<1x2x?xf32, affine_map<(d0, d1, d2)[s0, s1, s2, s3] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3)>> {
+        %arg0 : memref<1x?xf32, strided<[?, ?], offset: ?>>) ->
+        memref<1x2x?xf32, strided<[?, ?, ?], offset: ?>> {
   %0 = memref.expand_shape %arg0 [[0], [1, 2]]:
-    memref<1x?xf32, affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>> into
-    memref<1x2x?xf32, affine_map<(d0, d1, d2)[s0, s1, s2, s3] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3)>>
-  return %0 : memref<1x2x?xf32, affine_map<(d0, d1, d2)[s0, s1, s2, s3] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3)>>
+    memref<1x?xf32, strided<[?, ?], offset: ?>> into
+    memref<1x2x?xf32, strided<[?, ?, ?], offset: ?>>
+  return %0 : memref<1x2x?xf32, strided<[?, ?, ?], offset: ?>>
 }
 // CHECK-LABEL: func @expand_shape_dynamic_with_non_identity_layout(
 // CHECK: llvm.mlir.undef : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
@@ -982,10 +982,10 @@ func.func @atomic_rmw(%I : memref<10xi32>, %ival : i32, %F : memref<10xf32>, %fv
 // -----
 
 // CHECK-LABEL: func @collapse_static_shape_with_non_identity_layout
-func.func @collapse_static_shape_with_non_identity_layout(%arg: memref<1x1x8x8xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 64 + s0 + d1 * 64 + d2 * 8 + d3)>>) -> memref<64xf32, affine_map<(d0)[s0] -> (d0 + s0)>> {
+func.func @collapse_static_shape_with_non_identity_layout(%arg: memref<1x1x8x8xf32, strided<[64, 64, 8, 1], offset: ?>>) -> memref<64xf32, strided<[1], offset: ?>> {
   // CHECK-NOT: memref.collapse_shape
-  %1 = memref.collapse_shape %arg [[0, 1, 2, 3]] : memref<1x1x8x8xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 64 + s0 + d1 * 64 + d2 * 8 + d3)>> into memref<64xf32, affine_map<(d0)[s0] -> (d0 + s0)>>
-  return %1 : memref<64xf32, affine_map<(d0)[s0] -> (d0 + s0)>>
+  %1 = memref.collapse_shape %arg [[0, 1, 2, 3]] : memref<1x1x8x8xf32, strided<[64, 64, 8, 1], offset: ?>> into memref<64xf32, strided<[1], offset: ?>>
+  return %1 : memref<64xf32, strided<[1], offset: ?>>
 }
 
 // -----
@@ -1069,13 +1069,11 @@ func.func @memref_copy_contiguous(%in: memref<16x2xi32>, %offset: index) {
 // -----
 
 // CHECK-LABEL: func @memref_copy_0d_offset
-#map0 = affine_map<(d0) -> (d0 + 1)>
-#map1 = affine_map<() -> (1)>
 func.func @memref_copy_0d_offset(%in: memref<2xi32>) {
   %buf = memref.alloc() : memref<i32>
-  %sub = memref.subview %in[1] [1] [1] : memref<2xi32> to memref<1xi32, #map0>
-  %scalar = memref.collapse_shape %sub [] : memref<1xi32, #map0> into memref<i32, #map1>
-  memref.copy %scalar, %buf : memref<i32, #map1> to memref<i32>
+  %sub = memref.subview %in[1] [1] [1] : memref<2xi32> to memref<1xi32, strided<[1], offset: 1>>
+  %scalar = memref.collapse_shape %sub [] : memref<1xi32, strided<[1], offset: 1>> into memref<i32, strided<[], offset: 1>>
+  memref.copy %scalar, %buf : memref<i32, strided<[], offset: 1>> to memref<i32>
   // CHECK: llvm.intr.memcpy
   return
 }
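The collapsed strides in the tests above follow one rule, sketched here on a static example (standalone C++, hypothetical values): a collapsed group keeps the stride of its innermost dimension, and the collapse is valid only if every dimension in the group has stride equal to the next-inner stride times that inner dimension's size:

#include <cstdint>
#include <cstdio>

int main() {
  // memref<4x5x6xf32>, row-major strides [30, 6, 1], reassociation [[0], [1, 2]].
  int64_t sizes[] = {4, 5, 6};
  int64_t strides[] = {30, 6, 1};
  // Group [1, 2] is contiguous because strides[1] == strides[2] * sizes[2].
  bool contiguous = strides[1] == strides[2] * sizes[2];
  // Each group keeps its innermost stride: [30, 1].
  std::printf("contiguous=%d collapsed strides=[%lld, %lld]\n",
              (int)contiguous, (long long)strides[0], (long long)strides[2]);
  // -> memref<4x30xf32, strided<[30, 1]>>, the identity layout in this case.
  return 0;
}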
@@ -23,8 +23,8 @@ func.func @buffer_forwarding_conflict(
   %f = linalg.fill ins(%f0 : f32) outs(%a : tensor<?xf32>) -> tensor<?xf32>
 
   // CHECK: memref.copy %[[FUNC_ARG]], %[[ALLOC]] : memref<?xf32> to memref<?xf32>
-  // CHECK: %[[SV0_ALLOC:.*]] = memref.subview %[[ALLOC]][0] [%[[sz]]] [1] : memref<?xf32> to memref<?xf32>
-  // CHECK: memref.copy %[[EXTRACT_SLICE_ALLOC]], %[[SV0_ALLOC]] : memref<?xf32> to memref<?xf32>
+  // CHECK: %[[SV0_ALLOC:.*]] = memref.subview %[[ALLOC]][0] [%[[sz]]] [1] : memref<?xf32> to memref<?xf32, strided<[1]>>
+  // CHECK: memref.copy %[[EXTRACT_SLICE_ALLOC]], %[[SV0_ALLOC]] : memref<?xf32> to memref<?xf32, strided<[1]>>
   %r0 = tensor.insert_slice %f into %t[0][%sz][1]: tensor<?xf32> into tensor<?xf32>
 
   // CHECK: %[[T_SUBVIEW:.*]] = memref.subview %[[FUNC_ARG]][42] [%[[sz]]] [1]
6 changes: 2 additions & 4 deletions mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -6,8 +6,6 @@
 // Test that we can lower all the way to LLVM without crashing, don't check results here.
 // DISABLED: mlir-opt %s --convert-linalg-to-llvm -o=/dev/null 2>&1
 
-// CHECK: #[[$strided3DT:.*]] = affine_map<(d0, d1, d2)[s0, s1, s2] -> (d2 * s1 + s0 + d1 * s2 + d0)>
-
 func.func @views(%arg0: index) {
   %c0 = arith.constant 0 : index
   %0 = arith.muli %arg0, %arg0 : index
@@ -70,12 +68,12 @@ func.func @fill_view(%arg0: memref<?xf32, strided<[1], offset: ?>>, %arg1: f32)
 // -----
 
 func.func @transpose(%arg0: memref<?x?x?xf32, strided<[?, ?, 1], offset: ?>>) {
-  %0 = memref.transpose %arg0 (i, j, k) -> (k, j, i) : memref<?x?x?xf32, strided<[?, ?, 1], offset: ?>> to memref<?x?x?xf32, affine_map<(d0, d1, d2)[s0, s1, s2] -> (d2 * s1 + s0 + d1 * s2 + d0)>>
+  %0 = memref.transpose %arg0 (i, j, k) -> (k, j, i) : memref<?x?x?xf32, strided<[?, ?, 1], offset: ?>> to memref<?x?x?xf32, strided<[1, ?, ?], offset: ?>>
   return
 }
 // CHECK-LABEL: func @transpose
 //       CHECK:   memref.transpose %{{.*}} ([[i:.*]], [[j:.*]], [[k:.*]]) -> ([[k]], [[j]], [[i]]) :
-//  CHECK-SAME:      memref<?x?x?xf32, strided<[?, ?, 1], offset: ?>> to memref<?x?x?xf32, #[[$strided3DT]]>
+//  CHECK-SAME:      memref<?x?x?xf32, strided<[?, ?, 1], offset: ?>> to memref<?x?x?xf32, strided<[1, ?, ?], offset: ?>>
 
 // -----
 
2 changes: 1 addition & 1 deletion mlir/test/Dialect/MemRef/invalid.mlir
@@ -424,7 +424,7 @@ func.func @expand_shape_to_smaller_rank(%arg0: memref<1xf32>) {
 
 func.func @expand_shape_invalid_result_layout(
     %arg0: memref<30x20xf32, strided<[4000, 2], offset: 100>>) {
-  // expected-error @+1 {{expected expanded type to be 'memref<2x15x20xf32, affine_map<(d0, d1, d2) -> (d0 * 60000 + d1 * 4000 + d2 * 2 + 100)>>' but found 'memref<2x15x20xf32, affine_map<(d0, d1, d2) -> (d0 * 5000 + d1 * 4000 + d2 * 2 + 100)>>'}}
+  // expected-error @+1 {{expected expanded type to be 'memref<2x15x20xf32, strided<[60000, 4000, 2], offset: 100>>' but found 'memref<2x15x20xf32, strided<[5000, 4000, 2], offset: 100>>'}}
   %0 = memref.expand_shape %arg0 [[0, 1], [2]] :
       memref<30x20xf32, strided<[4000, 2], offset: 100>>
       into memref<2x15x20xf32, strided<[5000, 4000, 2], offset: 100>>
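The numbers in this diagnostic are simple stride arithmetic (standalone sketch, not from the patch): splitting the outer dimension of 30 elements (stride 4000) into 2x15 makes the new outermost stride 15 * 4000 = 60000; the test writes 5000 on purpose to trigger the error:

#include <cstdint>
#include <cstdio>

int main() {
  int64_t srcStride = 4000; // stride of the dimension being expanded
  int64_t innerSize = 15;   // 30 is split into 2 x 15
  std::printf("expected outer stride = %lld\n",
              (long long)(innerSize * srcStride)); // 60000, not 5000
  return 0;
}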