
[mlir][vector] Restrict DropInnerMostUnitDimsTransfer{Read|Write} #96218

Merged (5 commits) on Jul 12, 2024
27 changes: 27 additions & 0 deletions mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1394,6 +1394,33 @@ class DropInnerMostUnitDimsTransferWrite
if (dimsToDrop == 0)
return failure();

// We need to consider 3 cases for the dim to drop:
// 1. if "in bounds", it can safely be assumed that the corresponding
// index is equal to 0 (safe to collapse) (*)
// 2. if "out of bounds" and the corresponding index is 0, it is
// effectively "in bounds" (safe to collapse)
// 3. if "out of bounds" and the corresponding index is != 0,
// be conservative and bail out (not safe to collapse)
// (*) This pattern only drops unit dims, so the only possible "in bounds"
// index is "0". This could be added as a folder.
// TODO: Deal with 3. by e.g. propagating the "out of bounds" flag to other
// dims.
bool indexOutOfBounds = true;
if (writeOp.getInBounds())
indexOutOfBounds = llvm::any_of(
llvm::zip(writeOp.getInBounds()->getValue().take_back(dimsToDrop),
writeOp.getIndices().take_back(dimsToDrop)),
[](auto zipped) {
auto inBounds = cast<BoolAttr>(std::get<0>(zipped)).getValue();
auto nonZeroIdx = !isZeroIndex(std::get<1>(zipped));
return !inBounds && nonZeroIdx;
});
else
indexOutOfBounds = !llvm::all_of(
writeOp.getIndices().take_back(dimsToDrop), isZeroIndex);
if (indexOutOfBounds)
return failure();
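The three cases in the comment above reduce to a single two-input predicate per dropped dim. A minimal standalone C++ sketch (hypothetical helper name, STL only — not the MLIR code itself):

```cpp
#include <cassert>

// Sketch: collapsing a unit dim is safe unless it is both marked
// out-of-bounds AND accessed at a (possibly) non-zero index.
bool safeToCollapse(bool inBounds, bool idxIsZero) {
  // Case 1: in bounds                  -> safe.
  // Case 2: out of bounds, index == 0  -> effectively in bounds, safe.
  // Case 3: out of bounds, index != 0  -> not safe, bail out.
  return inBounds || idxIsZero;
}
```

This is exactly the complement of the `!inBounds && nonZeroIdx` condition used inside the lambda above.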
@MacDue (Member) commented on Jul 11, 2024:
This looks much more complex than I'd expect. I don't see the need to check the indices at all.

Also, dropping the dim if the index is zero and marked as out-of-bounds does not seem valid. If index zero is out-of-bounds, then we can't safely write to that unit dimension (I'm not sure of the precise meaning of that, but it does not seem like a safe transform in that case...).

Why not just:

auto inBounds = writeOp.getInBoundsValues();
auto droppedInBounds = ArrayRef<bool>(inBounds).take_back(dimsToDrop);
if (llvm::is_contained(droppedInBounds, false)) 
  return failure();
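The suggested check can be sketched in standalone C++ (hypothetical function name; `std::vector` and `std::find` stand in for `ArrayRef::take_back` and `llvm::is_contained`):

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

// Sketch of the suggested bail-out: reject the rewrite if any of the
// trailing `dimsToDrop` in_bounds entries is false.
bool canDropInnerUnitDims(const std::vector<bool> &inBounds,
                          std::size_t dimsToDrop) {
  // Equivalent of ArrayRef<bool>(inBounds).take_back(dimsToDrop).
  auto first = inBounds.end() - static_cast<std::ptrdiff_t>(dimsToDrop);
  // Equivalent of !llvm::is_contained(droppedInBounds, false).
  return std::find(first, inBounds.end(), false) == inBounds.end();
}
```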

@MacDue (Member) commented on Jul 11, 2024:

(or maybe even):

if (writeOp.hasOutOfBoundsDim())
  return failure();

out-of-bounds dims are an edge-case (that I've not really seen used in practice, so are probably not worth worrying about much).

@MacDue (Member) commented on Jul 11, 2024:

Side note: If index 0 and in-bounds = false for a unit-dim actually means in-bounds = true, that should be an xferOp fold (not something everyone that looks at in-bounds values needs to interpret).

Edit: Looks like it already is! https://godbolt.org/z/9Gqo37YeG :) It's an actual fold too (foldTransferInBoundsAttribute, so it's running pretty much all the time).
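The unit-dim reasoning behind that fold can be sketched in standalone C++ (hypothetical function name and a simplifying assumption that all indices are known constants — the real `foldTransferInBoundsAttribute` works on the op's indices, permutation map, and shapes):

```cpp
#include <cstdint>
#include <vector>

// Sketch: a dim marked out-of-bounds can be re-marked in-bounds when the
// (constant) index plus the vector extent provably fits in the memref
// extent. For a unit memref dim read at index 0 by a unit vector dim,
// 0 + 1 <= 1, so in_bounds = false folds to in_bounds = true.
std::vector<bool> foldInBounds(const std::vector<int64_t> &indices,
                               const std::vector<int64_t> &vecShape,
                               const std::vector<int64_t> &memShape,
                               std::vector<bool> inBounds) {
  for (std::size_t d = 0; d < inBounds.size(); ++d)
    if (!inBounds[d] && indices[d] + vecShape[d] <= memShape[d])
      inBounds[d] = true;
  return inBounds;
}
```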

PR author (Contributor) commented:

dropping the dim if the index is zero and marked as out-of-bounds does not seem valid. If index zero is out-of-bounds, then we can't safely write to that unit dimension

As you observed further down:

If index 0 and in-bounds = false for a unit-dim actually means in-bounds = true

:) So:

  1. if the index is == 0, then in_bounds is effectively irrelevant (safe to collapse),
  2. if the index is != 0 but in_bounds = true, the index is effectively == 0 (safe to collapse),
  3. if the index is != 0 and in_bounds = false, bail out.

Note that once in_bounds is mandatory, I will be able to simplify the above as:

if (llvm::any_of(
        llvm::zip(writeOp.getInBounds()->getValue().take_back(dimsToDrop),
                  writeOp.getIndices().take_back(dimsToDrop)),
        [](auto zipped) {
          auto inBounds = cast<BoolAttr>(std::get<0>(zipped)).getValue();
          auto nonZeroIdx = !isZeroIndex(std::get<1>(zipped));
          return !inBounds && nonZeroIdx;
        }))
  return failure();

out-of-bounds dims are an edge-case

Not sure we can say that - the default for in_bounds is "out of bounds".

PR author (Contributor) commented:

It's an actual fold too (foldTransferInBoundsAttribute, so it's running pretty much all the time).

I knew about the folder, but incorrectly assumed that the "folder" wasn't guaranteed to be run before the pattern. I was wrong:

/// Also performs folding and simple dead-code elimination before attempting to
/// match any of the provided patterns.


So:

Why not just:

auto inBounds = writeOp.getInBoundsValues();
auto droppedInBounds = ArrayRef(inBounds).take_back(dimsToDrop);
if (llvm::is_contained(droppedInBounds, false))
  return failure();

Yeah, that's perfectly sufficient and covers all the cases. Thanks!

@nujaa (Contributor) commented on Jul 11, 2024:

Not sure we can say that - the default for in_bounds is "out of bounds".

I think this is definitely safer.


auto resultTargetVecType =
VectorType::get(targetType.getShape().drop_back(dimsToDrop),
targetType.getElementType(),
@@ -113,19 +113,6 @@ func.func @contiguous_inner_most_outer_dim_dyn_scalable_inner_dim(%a: index, %b:

// -----

func.func @contiguous_inner_most_dim_non_zero_idx(%A: memref<16x1xf32>, %i:index) -> (vector<8x1xf32>) {
%c0 = arith.constant 0 : index
%f0 = arith.constant 0.0 : f32
%1 = vector.transfer_read %A[%i, %c0], %f0 : memref<16x1xf32>, vector<8x1xf32>
return %1 : vector<8x1xf32>
}
// CHECK: func @contiguous_inner_most_dim_non_zero_idx(%[[SRC:.+]]: memref<16x1xf32>, %[[I:.+]]: index) -> vector<8x1xf32>
// CHECK: %[[SRC_0:.+]] = memref.subview %[[SRC]]
// CHECK-SAME: memref<16x1xf32> to memref<16xf32, strided<[1]>>
// CHECK: %[[V:.+]] = vector.transfer_read %[[SRC_0]]
// CHECK: %[[RESULT:.+]] = vector.shape_cast %[[V]] : vector<8xf32> to vector<8x1xf32>
// CHECK: return %[[RESULT]]

// The index to be dropped is != 0 - this is currently not supported.
func.func @negative_contiguous_inner_most_dim_non_zero_idxs(%A: memref<16x1xf32>, %i:index) -> (vector<8x1xf32>) {
%f0 = arith.constant 0.0 : f32
@@ -136,24 +123,6 @@ func.func @negative_contiguous_inner_most_dim_non_zero_idxs(%A: memref<16x1xf32>
// CHECK-NOT: memref.subview
// CHECK: vector.transfer_read

// Same as the top example within this split, but with the outer vector
// dim scalable. Note that this example only makes sense when "8 = [8]" (i.e.
// vscale = 1). This is assumed (implicitly) via the `in_bounds` attribute.

func.func @contiguous_inner_most_dim_non_zero_idx_scalable_inner_dim(%A: memref<16x1xf32>, %i:index) -> (vector<[8]x1xf32>) {
%c0 = arith.constant 0 : index
%f0 = arith.constant 0.0 : f32
%1 = vector.transfer_read %A[%i, %c0], %f0 : memref<16x1xf32>, vector<[8]x1xf32>
return %1 : vector<[8]x1xf32>
}
// CHECK-LABEL: func @contiguous_inner_most_dim_non_zero_idx_scalable_inner_dim(
// CHECK-SAME: %[[SRC:.+]]: memref<16x1xf32>, %[[I:.+]]: index) -> vector<[8]x1xf32>
// CHECK: %[[SRC_0:.+]] = memref.subview %[[SRC]]
// CHECK-SAME: memref<16x1xf32> to memref<16xf32, strided<[1]>>
// CHECK: %[[V:.+]] = vector.transfer_read %[[SRC_0]]
// CHECK: %[[RESULT:.+]] = vector.shape_cast %[[V]] : vector<[8]xf32> to vector<[8]x1xf32>
// CHECK: return %[[RESULT]]

// -----

func.func @contiguous_inner_most_dim_with_subview(%A: memref<1000x1xf32>, %i:index, %ii:index) -> (vector<4x1xf32>) {
@@ -367,6 +336,63 @@ func.func @contiguous_inner_most_dynamic_outer_scalable_inner_dim(%a: index, %b:

// -----

// Test the impact of changing the in_bounds attribute. The behaviour will
// depend on whether the index is == 0 or != 0.

// The index to be dropped is == 0, so it's safe to collapse. The other index
// should be preserved correctly.
func.func @contiguous_inner_most_zero_idx_in_bounds(%arg0: memref<16x1xf32>, %arg1: vector<8x1xf32>, %i: index) {
%c0 = arith.constant 0 : index
vector.transfer_write %arg1, %arg0[%i, %c0] {in_bounds = [true, true]} : vector<8x1xf32>, memref<16x1xf32>
return
}
// CHECK-LABEL: func.func @contiguous_inner_most_zero_idx_in_bounds(
// CHECK-SAME: %[[MEM:.*]]: memref<16x1xf32>,
// CHECK-SAME: %[[VEC:.*]]: vector<8x1xf32>,
// CHECK-SAME: %[[IDX:.*]]: index) {
// CHECK: %[[SV:.*]] = memref.subview %[[MEM]][0, 0] [16, 1] [1, 1] : memref<16x1xf32> to memref<16xf32, strided<[1]>>
// CHECK: %[[SC:.*]] = vector.shape_cast %[[VEC]] : vector<8x1xf32> to vector<8xf32>
// CHECK: vector.transfer_write %[[SC]], %[[SV]]{{\[}}%[[IDX]]] {in_bounds = [true]} : vector<8xf32>, memref<16xf32, strided<[1]>>

func.func @contiguous_inner_most_zero_idx_out_of_bounds(%arg0: memref<16x1xf32>, %arg1: vector<8x1xf32>, %i: index) {
%c0 = arith.constant 0 : index
vector.transfer_write %arg1, %arg0[%i, %c0] {in_bounds = [true, false]} : vector<8x1xf32>, memref<16x1xf32>
return
}
// CHECK-LABEL: func.func @contiguous_inner_most_zero_idx_out_of_bounds
// CHECK-SAME: %[[MEM:.*]]: memref<16x1xf32>,
// CHECK-SAME: %[[VEC:.*]]: vector<8x1xf32>,
// CHECK-SAME: %[[IDX:.*]]: index) {
// CHECK: %[[SV:.*]] = memref.subview %[[MEM]][0, 0] [16, 1] [1, 1] : memref<16x1xf32> to memref<16xf32, strided<[1]>>
// CHECK: %[[SC:.*]] = vector.shape_cast %[[VEC]] : vector<8x1xf32> to vector<8xf32>
// CHECK: vector.transfer_write %[[SC]], %[[SV]]{{\[}}%[[IDX]]] {in_bounds = [true]} : vector<8xf32>, memref<16xf32, strided<[1]>>

// The index to be dropped is unknown, but since it's "in bounds", it has to be
// == 0. It's safe to collapse the corresponding dim.

func.func @contiguous_inner_most_dim_non_zero_idx_in_bounds(%arg0: memref<16x1xf32>, %arg1: vector<8x1xf32>, %i: index) {
vector.transfer_write %arg1, %arg0[%i, %i] {in_bounds = [true, true]} : vector<8x1xf32>, memref<16x1xf32>
return
}
// CHECK-LABEL: func @contiguous_inner_most_dim_non_zero_idx_in_bounds
// CHECK-SAME: %[[MEM:.*]]: memref<16x1xf32>,
// CHECK-SAME: %[[VEC:.*]]: vector<8x1xf32>,
// CHECK-SAME: %[[IDX:.*]]: index) {
// CHECK: %[[SV:.*]] = memref.subview %[[MEM]][0, 0] [16, 1] [1, 1] : memref<16x1xf32> to memref<16xf32, strided<[1]>>
// CHECK: %[[SC:.*]] = vector.shape_cast %[[VEC]] : vector<8x1xf32> to vector<8xf32>
// CHECK: vector.transfer_write %[[SC]], %[[SV]]{{\[}}%[[IDX]]] {in_bounds = [true]} : vector<8xf32>, memref<16xf32, strided<[1]>>

func.func @negative_contiguous_inner_most_dim_non_zero_idx_out_of_bounds(%arg0: memref<16x1xf32>, %arg1: vector<8x1xf32>, %i: index) {
vector.transfer_write %arg1, %arg0[%i, %i] {in_bounds = [true, false]} : vector<8x1xf32>, memref<16x1xf32>
return
}
// CHECK-LABEL: func @negative_contiguous_inner_most_dim_non_zero_idx_out_of_bounds
// CHECK-NOT: memref.subview
// CHECK-NOT: memref.shape_cast
// CHECK: vector.transfer_write

// -----

func.func @drop_inner_most_dim(%arg0: memref<1x512x16x1xf32, strided<[8192, 16, 1, 1], offset: ?>>, %arg1: vector<1x16x16x1xf32>, %arg2: index) {
%c0 = arith.constant 0 : index
vector.transfer_write %arg1, %arg0[%c0, %arg2, %c0, %c0]