diff --git a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp index b38dd8effe669..7763831141c6b 100644 --- a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp @@ -759,6 +759,9 @@ struct GreedyFusion { const DenseSet &srcEscapingMemRefs, unsigned producerId, unsigned consumerId, bool removeSrcNode) { + // We can't generate private memrefs if their size can't be computed. + if (!getMemRefIntOrFloatEltSizeInBytes(cast(memref.getType()))) + return false; const Node *consumerNode = mdg->getNode(consumerId); // If `memref` is an escaping one, do not create a private memref // for the below scenarios, since doing so will leave the escaping diff --git a/mlir/test/Dialect/Affine/loop-fusion-4.mlir b/mlir/test/Dialect/Affine/loop-fusion-4.mlir index 2830235431c76..07d2d06f1451d 100644 --- a/mlir/test/Dialect/Affine/loop-fusion-4.mlir +++ b/mlir/test/Dialect/Affine/loop-fusion-4.mlir @@ -1,4 +1,5 @@ // RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(func.func(affine-loop-fusion{mode=producer}))' -split-input-file | FileCheck %s --check-prefix=PRODUCER-CONSUMER +// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(func.func(affine-loop-fusion{mode=producer fusion-maximal}))' -split-input-file | FileCheck %s --check-prefix=PRODUCER-CONSUMER-MAXIMAL // RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(func.func(affine-loop-fusion{fusion-maximal mode=sibling}))' -split-input-file | FileCheck %s --check-prefix=SIBLING-MAXIMAL // RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(spirv.func(affine-loop-fusion{mode=producer}))' -split-input-file | FileCheck %s --check-prefix=SPIRV @@ -345,3 +346,37 @@ func.func @same_memref_load_multiple_stores(%producer : memref<32xf32>, %produce // PRODUCER-CONSUMER-NEXT: } return } + +#map = affine_map<()[s0] -> (s0 + 5)> +#map1 = affine_map<()[s0] -> (s0 + 17)> + +// Test with non-int/float memref types. + +// PRODUCER-CONSUMER-MAXIMAL-LABEL: func @memref_index_type +func.func @memref_index_type() { + %0 = llvm.mlir.constant(2 : index) : i64 + %2 = llvm.mlir.constant(0 : index) : i64 + %3 = builtin.unrealized_conversion_cast %2 : i64 to index + %alloc = memref.alloc() {alignment = 64 : i64} : memref<8x18xf32> + %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<3xf32> + %alloc_2 = memref.alloc() {alignment = 64 : i64} : memref<3xindex> + affine.for %arg3 = 0 to 3 { + %4 = affine.load %alloc_2[%arg3] : memref<3xindex> + %5 = builtin.unrealized_conversion_cast %4 : index to i64 + %6 = llvm.sub %0, %5 : i64 + %7 = builtin.unrealized_conversion_cast %6 : i64 to index + affine.store %7, %alloc_2[%arg3] : memref<3xindex> + } + affine.for %arg3 = 0 to 3 { + %4 = affine.load %alloc_2[%arg3] : memref<3xindex> + %5 = affine.apply #map()[%4] + %6 = affine.apply #map1()[%3] + %7 = memref.load %alloc[%5, %6] : memref<8x18xf32> + affine.store %7, %alloc_1[%arg3] : memref<3xf32> + } + // Expect fusion. + // PRODUCER-CONSUMER-MAXIMAL: affine.for + // PRODUCER-CONSUMER-MAXIMAL-NOT: affine.for + // PRODUCER-CONSUMER-MAXIMAL: return + return +}