Allow memref_cast from static strides to dynamic strides.
memref_cast already supports casting memrefs with static shapes to
memrefs with dynamic shapes. The same should be true for strides as
well, i.e., a memref with static strides can be cast to a memref with
dynamic strides.
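
For illustration, a minimal sketch of the newly allowed cast (it mirrors the test added to core-ops.mlir below; %src and %dst are placeholder names):

    // Erase static offset and stride information, replacing it with
    // dynamic values.
    %dst = memref_cast %src : memref<64x16x4xf32, offset: 0, strides: [64, 4, 1]>
             to memref<64x16x4xf32, offset: ?, strides: [?, ?, ?]>

The reverse cast, back to a static offset and strides, is accepted by the same compatibility check and asserts the static values at runtime.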

PiperOrigin-RevId: 282381862
Mahesh Ravishankar authored and tensorflower-gardener committed Nov 25, 2019
1 parent 0114554 commit 1ea231b
Showing 4 changed files with 57 additions and 3 deletions.
10 changes: 10 additions & 0 deletions mlir/include/mlir/Dialect/StandardOps/Ops.td
@@ -831,6 +831,16 @@ def MemRefCastOp : CastOp<"memref_cast"> {
    %2 = memref_cast %1 : memref<?x?xf32> to memref<4x4xf32>
    Erase static shape information, replacing it with dynamic information.
    %3 = memref_cast %1 : memref<4xf32> to memref<?xf32>
+
+    The same holds true for offsets and strides.
+
+    Assert that the input dynamic offset and strides match the destination static offset and strides.
+    %4 = memref_cast %1 : memref<12x4xf32, offset:?, strides: [?, ?]> to
+           memref<12x4xf32, offset:5, strides: [4, 1]>
+    Erase static offset and stride information, replacing it with
+    dynamic information.
+    %5 = memref_cast %1 : memref<12x4xf32, offset:5, strides: [4, 1]> to
+           memref<12x4xf32, offset:?, strides: [?, ?]>
  }];

let arguments = (ins AnyMemRef:$source);
24 changes: 22 additions & 2 deletions mlir/lib/Dialect/StandardOps/Ops.cpp
@@ -1777,8 +1777,28 @@ bool MemRefCastOp::areCastCompatible(Type a, Type b) {
    return false;
  if (aT.getElementType() != bT.getElementType())
    return false;
-  if (aT.getAffineMaps() != bT.getAffineMaps())
-    return false;
+  if (aT.getAffineMaps() != bT.getAffineMaps()) {
+    int64_t aOffset, bOffset;
+    SmallVector<int64_t, 4> aStrides, bStrides;
+    if (failed(getStridesAndOffset(aT, aStrides, aOffset)) ||
+        failed(getStridesAndOffset(bT, bStrides, bOffset)) ||
+        aStrides.size() != bStrides.size())
+      return false;
+
+    // The stride along each dimension, and the offset, are compatible if
+    // the value is the same in the source and target memrefs, or if either
+    // one of them is dynamic (see the description of MemRefCastOp for
+    // details).
+    auto checkCompatible = [](int64_t a, int64_t b) {
+      return (a == MemRefType::getDynamicStrideOrOffset() ||
+              b == MemRefType::getDynamicStrideOrOffset() || a == b);
+    };
+    if (!checkCompatible(aOffset, bOffset))
+      return false;
+    for (auto aStride : enumerate(aStrides))
+      if (!checkCompatible(aStride.value(), bStrides[aStride.index()]))
+        return false;
+  }
  if (aT.getMemorySpace() != bT.getMemorySpace())
    return false;

10 changes: 9 additions & 1 deletion mlir/test/IR/core-ops.mlir
@@ -13,6 +13,7 @@
// CHECK-DAG: #[[VIEW_MAP3:map[0-9]+]] = (d0, d1)[s0] -> (d0 * s0 + d1)

// CHECK-DAG: #[[BASE_MAP0:map[0-9]+]] = (d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)
+// CHECK-DAG: #[[BASE_MAP3:map[0-9]+]] = (d0, d1, d2)[s0, s1, s2, s3] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3)
// CHECK-DAG: #[[SUBVIEW_MAP0:map[0-9]+]] = (d0, d1, d2)[s0, s1, s2, s3] -> (d0 * s1 + d1 * s2 + d2 * s3 + s0)

// CHECK-DAG: #[[BASE_MAP1:map[0-9]+]] = (d0)[s0] -> (d0 + s0)
@@ -476,12 +477,19 @@ func @tensor_cast(%arg0: tensor<*xf32>, %arg1 : tensor<4x4xf32>, %arg2: tensor<?
}

// CHECK-LABEL: func @memref_cast(%arg0
-func @memref_cast(%arg0: memref<4xf32>, %arg1 : memref<?xf32>) {
+func @memref_cast(%arg0: memref<4xf32>, %arg1 : memref<?xf32>, %arg2 : memref<64x16x4xf32, offset: 0, strides: [64, 4, 1]>) {
  // CHECK: %0 = memref_cast %arg0 : memref<4xf32> to memref<?xf32>
  %0 = memref_cast %arg0 : memref<4xf32> to memref<?xf32>

  // CHECK: %1 = memref_cast %arg1 : memref<?xf32> to memref<4xf32>
  %1 = memref_cast %arg1 : memref<?xf32> to memref<4xf32>
+
+  // CHECK: {{%.*}} = memref_cast %arg2 : memref<64x16x4xf32, #[[BASE_MAP0]]> to memref<64x16x4xf32, #[[BASE_MAP3]]>
+  %2 = memref_cast %arg2 : memref<64x16x4xf32, offset: 0, strides: [64, 4, 1]> to memref<64x16x4xf32, offset: ?, strides: [?, ?, ?]>
+
+  // CHECK: {{%.*}} = memref_cast {{%.*}} : memref<64x16x4xf32, #[[BASE_MAP3]]> to memref<64x16x4xf32, #[[BASE_MAP0]]>
+  %3 = memref_cast %2 : memref<64x16x4xf32, offset: ?, strides: [?, ?, ?]> to memref<64x16x4xf32, offset: 0, strides: [64, 4, 1]>

  return
}

16 changes: 16 additions & 0 deletions mlir/test/IR/invalid-ops.mlir
@@ -951,3 +951,19 @@ func @invalid_subview(%arg0 : index, %arg1 : memref<?x8x?xf32>) {
%0 = subview %arg1[%c0, %c0, %c0][%c1, %arg0, %c1][%c1, %c1, %c1] : memref<?x8x?xf32> to memref<?x8x?xf32, offset:?, strides:[?, ?, ?]>
return
}

+// -----
+
+func @invalid_memref_cast(%arg0 : memref<12x4x16xf32, offset:0, strides:[64, 16, 1]>) {
+  // expected-error@+1{{operand type 'memref<12x4x16xf32, (d0, d1, d2) -> (d0 * 64 + d1 * 16 + d2)>' and result type 'memref<12x4x16xf32, (d0, d1, d2) -> (d0 * 128 + d1 * 32 + d2 * 2)>' are cast incompatible}}
+  %0 = memref_cast %arg0 : memref<12x4x16xf32, offset:0, strides:[64, 16, 1]> to memref<12x4x16xf32, offset:0, strides:[128, 32, 2]>
+  return
+}
+
+// -----
+
+func @invalid_memref_cast(%arg0 : memref<12x4x16xf32, offset:0, strides:[64, 16, 1]>) {
+  // expected-error@+1{{operand type 'memref<12x4x16xf32, (d0, d1, d2) -> (d0 * 64 + d1 * 16 + d2)>' and result type 'memref<12x4x16xf32, (d0, d1, d2) -> (d0 * 64 + d1 * 16 + d2 + 16)>' are cast incompatible}}
+  %0 = memref_cast %arg0 : memref<12x4x16xf32, offset:0, strides:[64, 16, 1]> to memref<12x4x16xf32, offset:16, strides:[64, 16, 1]>
+  return
+}
