Skip to content

Commit

Permalink
[mlir][linalg] fix crash when promoting rank-reducing memref.subviews
Browse files Browse the repository at this point in the history
This change adds support for promoting `linalg` operation operands that
are produced by rank-reducing `memref.subview` ops.

Differential Revision: https://reviews.llvm.org/D127086
  • Loading branch information
christopherbate committed Jun 6, 2022
1 parent a8cf78c commit 99069ab
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 4 deletions.
7 changes: 6 additions & 1 deletion mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
Expand Up @@ -25,6 +25,7 @@
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/FoldUtils.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
Expand Down Expand Up @@ -219,7 +220,11 @@ FailureOr<PromotionInfo> mlir::linalg::promoteSubviewAsNewBuffer(
SmallVector<OpFoldResult> partialSizes;
fullSizes.reserve(rank);
partialSizes.reserve(rank);
llvm::SmallBitVector droppedDims = subView.getDroppedDims();
int64_t resultDimIdx = 0;
for (const auto &en : llvm::enumerate(subView.getOrCreateRanges(b, loc))) {
if (droppedDims[en.index()])
continue;
auto rangeValue = en.value();
// Try to extract a tight constant.
LLVM_DEBUG(llvm::dbgs() << "Extract tightest: " << rangeValue.size << "\n");
Expand All @@ -232,7 +237,7 @@ FailureOr<PromotionInfo> mlir::linalg::promoteSubviewAsNewBuffer(
LLVM_DEBUG(llvm::dbgs() << "Extracted tightest: " << size << "\n");
fullSizes.push_back(size);
partialSizes.push_back(
b.createOrFold<memref::DimOp>(loc, subView, en.index()));
b.createOrFold<memref::DimOp>(loc, subView, resultDimIdx++));
}
SmallVector<int64_t, 4> dynSizes(fullSizes.size(), -1);
// If a callback is not specified, then use the default implementation for
Expand Down
49 changes: 46 additions & 3 deletions mlir/test/Dialect/Linalg/promote.mlir
@@ -1,6 +1,6 @@
// RUN: mlir-opt %s -linalg-promote-subviews | FileCheck %s
// RUN: mlir-opt %s -linalg-promote-subviews="test-promote-dynamic" | FileCheck %s --check-prefix=DYNAMIC
// RUN: mlir-opt %s -linalg-promote-subviews="test-use-alloca" | FileCheck %s --check-prefix=ALLOCA
// RUN: mlir-opt %s -linalg-promote-subviews -split-input-file | FileCheck %s
// RUN: mlir-opt %s -linalg-promote-subviews="test-promote-dynamic" -split-input-file | FileCheck %s --check-prefix=DYNAMIC
// RUN: mlir-opt %s -linalg-promote-subviews="test-use-alloca" -split-input-file | FileCheck %s --check-prefix=ALLOCA

#map1 = affine_map<(d0) -> (d0 + 2)>
#map2 = affine_map<(d0) -> (d0 + 4)>
Expand Down Expand Up @@ -145,3 +145,46 @@ func.func @matmul_f64(%A: memref<?xi8>, %M: index, %N: index, %K: index) {
// CHECK: memref.dealloc %[[tmpA_f64]] : memref<64xi8>
// CHECK: memref.dealloc %[[tmpB_f64]] : memref<96xi8>
// CHECK: memref.dealloc %[[tmpC_f64]] : memref<48xi8>

// -----

#map0 = affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3, s4] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3 + d3 * s4)>
#map2 = affine_map<(d0, d1)[s0] -> (d0 * 128 + s0 + d1)>
#map5 = affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>
#map6 = affine_map<(d0, d1, d2) -> (d0, d2)>
#map7 = affine_map<(d0, d1, d2) -> (d1, d2)>
#map8 = affine_map<(d0, d1, d2) -> (d0, d1)>

// CHECK: promote_rank_reducing_subviews([[arg0:%.+]]: memref<{{.*}}>, [[arg1:%.+]]: memref<{{.*}}>, [[arg2:%.+]]: memref<{{.*}}>, [[lb1:%.+]]: index, [[lb2:%.+]]: index, [[lb3:%.+]]: index, [[lb4:%.+]]: index, [[lb5:%.+]]: index, [[lb6:%.+]]: index, [[ub1:%.+]]: index, [[ub2:%.+]]: index
// Regression test for linalg promotion when operands come from rank-reducing
// memref.subview ops: each subview below takes a 4-D source and produces a
// 2-D result (the unit dims are dropped), which previously crashed the
// promotion pass. The CHECK lines above verify that promotion still emits the
// alloc/view/subview/copy sequence for each operand at the reduced rank.
func.func @promote_rank_reducing_subviews(%arg0: memref<?x?x?x64xf32, #map0>, %arg1: memref<128x3x3x64xf32, #map0>, %arg2: memref<?x?x?x128xf32>,
%arg3: index, %arg4: index, %arg5: index, %arg6: index, %arg7: index, %arg8: index, %ub1: index, %ub2: index) {
// Rank-reducing subviews: 4-D sources, 2-D results (sizes [1, 1, %ub1, 32],
// [128, 1, 1, 32], [1, 1, %ub2, 128] — the size-1 dims are dropped).
%13 = memref.subview %arg0[%arg3, 0, %arg4, %arg8] [1, 1, %ub1, 32] [1, 1, 1, 1] : memref<?x?x?x64xf32, #map0> to memref<?x32xf32, #map5>
%14 = memref.subview %arg1[0, %arg6, %arg7, %arg8] [128, 1, 1, 32] [1, 1, 1, 1] : memref<128x3x3x64xf32, #map0> to memref<128x32xf32, #map5>
%9 = memref.subview %arg2[%arg3, %arg4, %arg5, 0] [1, 1, %ub2, 128] [1, 1, 1, 1] : memref<?x?x?x128xf32> to memref<?x128xf32, #map2>

// CHECK: [[a_alloc:%.+]] = memref.alloc
// CHECK: [[a_view:%.+]] = memref.view [[a_alloc]]{{.*}}
// CHECK: [[a_pro_subview:%.+]] = memref.subview [[a_view]][0, 0] [[[ub1]], {{%.+}}] [1, 1]

// CHECK: memref.alloc
// CHECK: [[b_view:%.+]] = memref.view
// CHECK: [[b_pro_subview:%.+]] = memref.subview [[b_view]]

// CHECK: memref.alloc
// CHECK: [[c_view:%.+]] = memref.view
// CHECK: [[c_pro_subview:%.+]] = memref.subview [[c_view]]

// CHECK-COUNT-3: memref.copy
// CHECK: linalg.generic
// CHECK-SAME: ins([[a_pro_subview]], [[b_pro_subview]]
// CHECK-SAME: outs([[c_pro_subview]]

// A matmul-shaped contraction consuming the rank-reduced 2-D views; after
// promotion its ins/outs must be the promoted subviews (checked above).
linalg.generic {indexing_maps = [#map6, #map7, #map8], iterator_types = ["parallel", "parallel", "reduction"]} ins(%13, %14 : memref<?x32xf32, #map5>, memref<128x32xf32, #map5>) outs(%9 : memref<?x128xf32, #map2>) {
^bb0(%arg9: f32, %arg10: f32, %arg11: f32):
%15 = arith.mulf %arg9, %arg10 : f32
%16 = arith.addf %arg11, %15 : f32
linalg.yield %16 : f32
}

return
}

0 comments on commit 99069ab

Please sign in to comment.