Support non-minor-identity maps (#178992)
Conversation
|
@llvm/pr-subscribers-mlir-gpu @llvm/pr-subscribers-mlir Author: Michael Platings (mplatings) Changes: Builds on #176785. Full diff: https://github.com/llvm/llvm-project/pull/178992.diff 2 Files Affected:
diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
index 98434357f826f..65433ae6eb1c6 100644
--- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
+++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
@@ -95,42 +95,71 @@ static bool contractSupportsMMAMatrixType(vector::ContractionOp contract,
return true;
}
-// Return true if the given map represents a transposed matrix load,
-// i.e. (d0, d1, ...) -> (dn-1, dn-2).
-static bool isTransposeMatrixLoadMap(AffineMap permutationMap) {
+// Test whether the permutation map's first result corresponds to its last
+// dimension.
+//
+// In contexts where we only accept maps that have the last (most minor)
+// dimension as exactly one of the results, this is sufficient to classify
+// whether it represents a transpose.
+static bool isFirstResultLastMapDimension(AffineMap permutationMap) {
MLIRContext *ctx = permutationMap.getContext();
- // Local OpBuilder is fine here, we just build attributes.
- OpBuilder b(ctx);
- auto nDim = permutationMap.getNumDims();
- AffineExpr zero = b.getAffineConstantExpr(0);
- if (nDim < 2) {
- // Support transposed+broadcasted cases: affine_map<(d0) -> (d0, 0)>.
- AffineExpr dim0 = b.getAffineDimExpr(0);
- return permutationMap == AffineMap::get(1, 0, {dim0, zero}, ctx);
- }
-
- AffineExpr innerDim = b.getAffineDimExpr(nDim - 1);
- AffineExpr outerDim = b.getAffineDimExpr(nDim - 2);
- // Support both transposed and transposed+broadcasted cases.
- return permutationMap == AffineMap::get(nDim, 0, {innerDim, outerDim}, ctx) ||
- permutationMap == AffineMap::get(nDim, 0, {innerDim, zero}, ctx);
+ unsigned nDim = permutationMap.getNumDims();
+ return nDim && permutationMap.getNumResults() &&
+ permutationMap.getResult(0) == getAffineDimExpr(nDim - 1, ctx);
}
-// Return the stide for the second-to-last dimension of |type| if it is a memref
-// and has a constant stride.
-static std::optional<int64_t> getStaticallyKnownRowStride(ShapedType type) {
+// Return the `leadDimension` (row stride) implied by |permutationMap| for
+// |type|, if |type| is a memref with a statically-known layout.
+//
+// The `leadDimension` is the stride (in elements) between consecutive rows in
+// the 2D view described by |permutationMap|. This helper supports the subset
+// of maps permitted by vector.transfer_read:
+// - Exactly 2 results.
+// - Each result is either an affine dimension or the constant 0 (broadcast).
+//
+// Constraints:
+// - Requires the most minor memref stride to be 1.
+//
+// Broadcast:
+// - If either result is constant 0, the implied `leadDimension` is 0.
+static std::optional<int64_t>
+getStaticallyKnownRowStride(ShapedType type, AffineMap permutationMap) {
auto memrefType = dyn_cast<MemRefType>(type);
if (!memrefType)
- return false;
+ return std::nullopt;
// If the memref is 0 or 1D the horizontal stride is 0.
if (memrefType.getRank() < 2)
return 0;
int64_t offset = 0;
- SmallVector<int64_t, 2> strides;
+ SmallVector<int64_t> strides;
if (failed(memrefType.getStridesAndOffset(strides, offset)) ||
strides.back() != 1)
return std::nullopt;
- int64_t stride = strides[strides.size() - 2];
+
+ if (permutationMap.getNumResults() != 2)
+ return std::nullopt;
+
+ unsigned strideIndex = strides.size();
+
+ for (AffineExpr result : permutationMap.getResults()) {
+ if (auto dim = dyn_cast<AffineDimExpr>(result)) {
+ strideIndex = std::min(strideIndex, dim.getPosition());
+ continue;
+ }
+ auto cst = dyn_cast<AffineConstantExpr>(result);
+ if (!cst || cst.getValue() != 0)
+ return std::nullopt;
+ // A broadcast result forces row stride to 0.
+ return 0;
+ }
+
+ // Structural validity check: ensure that the map selects at least one
+ // dimension more major than the most minor dimension. This also excludes
+ // degenerate cases where both results map to the most minor dimension.
+ if (strideIndex + 1 >= strides.size())
+ return std::nullopt;
+
+ int64_t stride = strides[strideIndex];
if (stride == ShapedType::kDynamic)
return std::nullopt;
return stride;
@@ -141,7 +170,9 @@ static bool transferReadSupportsMMAMatrixType(vector::TransferReadOp readOp) {
if (readOp.getMask() || readOp.hasOutOfBoundsDim() ||
readOp.getVectorType().getRank() != 2)
return false;
- if (!getStaticallyKnownRowStride(readOp.getShapedType()))
+
+ AffineMap permutationMap = readOp.getPermutationMap();
+ if (!getStaticallyKnownRowStride(readOp.getShapedType(), permutationMap))
return false;
// Only allow integer types if the signedness can be inferred.
@@ -150,14 +181,10 @@ static bool transferReadSupportsMMAMatrixType(vector::TransferReadOp readOp) {
!isa<arith::ExtUIOp>(*readOp->user_begin())))
return false;
- AffineMap map = readOp.getPermutationMap();
MLIRContext *ctx = readOp.getContext();
- AffineExpr innerDim = getAffineDimExpr(map.getNumDims() - 1, ctx);
- AffineExpr zero = getAffineConstantExpr(0, ctx);
- auto broadcastInnerDim =
- AffineMap::get(map.getNumDims(), 0, {zero, innerDim}, ctx);
- return map.isMinorIdentity() || map == broadcastInnerDim ||
- isTransposeMatrixLoadMap(map);
+ AffineExpr innerDim = getAffineDimExpr(permutationMap.getNumDims() - 1, ctx);
+ return permutationMap.getResult(0) == innerDim ||
+ permutationMap.getResult(1) == innerDim;
}
// Return true if the transfer op can be converted to a MMA matrix store.
@@ -170,12 +197,18 @@ transferWriteSupportsMMAMatrixType(vector::TransferWriteOp writeOp) {
if (writeOp.getMask() || writeOp.hasOutOfBoundsDim() ||
writeOp.getVectorType().getRank() != 2)
return false;
- if (!getStaticallyKnownRowStride(writeOp.getShapedType()))
+
+ AffineMap permutationMap = writeOp.getPermutationMap();
+ std::optional<int64_t> stride =
+ getStaticallyKnownRowStride(writeOp.getShapedType(), permutationMap);
+ // Stride of zero means broadcast which is not permitted for writes.
+ if (!stride.has_value() || stride.value() == 0)
return false;
+
+ MLIRContext *ctx = writeOp.getContext();
+ AffineExpr innerDim = getAffineDimExpr(permutationMap.getNumDims() - 1, ctx);
// TODO: Support transpose once it is added to GPU dialect ops.
- if (!writeOp.getPermutationMap().isMinorIdentity())
- return false;
- return true;
+ return permutationMap.getResult(1) == innerDim;
}
/// Return true if the constant is a splat to a 2D vector so that it can be
@@ -547,21 +580,19 @@ convertTransferReadOp(RewriterBase &rewriter, vector::TransferReadOp op,
assert(transferReadSupportsMMAMatrixType(op) &&
"expected convertible operation");
+ AffineMap permutationMap = op.getPermutationMap();
std::optional<int64_t> stride =
- getStaticallyKnownRowStride(op.getShapedType());
+ getStaticallyKnownRowStride(op.getShapedType(), permutationMap);
if (!stride.has_value()) {
LDBG() << "no stride";
return rewriter.notifyMatchFailure(op, "no stride");
}
- AffineMap map = op.getPermutationMap();
- bool isTranspose = isTransposeMatrixLoadMap(map);
-
- // Handle broadcast by setting the stride to 0.
- if (auto cstExpr = dyn_cast<AffineConstantExpr>(map.getResult(isTranspose))) {
- assert(cstExpr.getValue() == 0);
- stride = 0;
- }
+ // transferReadSupportsMMAMatrixType ensures that either of the map results is
+ // the most minor dimension. Under this constraint, whether the map represents
+ // a transposed view can be inferred from whether the first result is the most
+ // minor memref dimension.
+ bool isTranspose = isFirstResultLastMapDimension(permutationMap);
Value mappingResult = op.getResult();
auto elType = op.getVectorType().getElementType();
@@ -597,7 +628,7 @@ convertTransferWriteOp(RewriterBase &rewriter, vector::TransferWriteOp op,
assert(transferWriteSupportsMMAMatrixType(op));
std::optional<int64_t> stride =
- getStaticallyKnownRowStride(op.getShapedType());
+ getStaticallyKnownRowStride(op.getShapedType(), op.getPermutationMap());
if (!stride.has_value()) {
LDBG() << "no stride";
return rewriter.notifyMatchFailure(op, "no stride");
diff --git a/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir b/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir
index ef72901750479..bf858789c7e07 100644
--- a/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir
+++ b/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir
@@ -1,5 +1,21 @@
// RUN: mlir-opt %s -pass-pipeline="builtin.module(func.func(convert-vector-to-gpu),canonicalize)" --split-input-file | FileCheck %s
+// -----
+
+// The pass currently only works for 2D vector transfers.
+// CHECK-LABEL: func @no_convert_3d
+// CHECK-NOT: gpu
+func.func @no_convert_3d(%arg0: memref<2x2x2xf16>) {
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant 0.0 : f16
+ %A = vector.transfer_read %arg0[%c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<2x2x2xf16>, vector<2x2x2xf16>
+ %B = arith.addf %A, %A : vector<2x2x2xf16>
+ vector.transfer_write %B, %arg0[%c0, %c0, %c0] {in_bounds = [true, true, true]} : vector<2x2x2xf16>, memref<2x2x2xf16>
+ return
+}
+
+// -----
+
#map0 = affine_map<(d0, d1) -> (d1, d0)>
#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
#map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
@@ -555,3 +571,94 @@ func.func @addf(%arg0: memref<16x16xf16>, %arg1: memref<16x16xf16>, %arg2: memre
vector.transfer_write %C, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf16>, memref<16x16xf16>
return
}
+
+// -----
+
+// CHECK-LABEL: func @matmul_with_strides
+// CHECK-DAG: %[[A:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%{{.*}}, %{{.*}}] {leadDimension = 16 : index} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "AOp">
+// CHECK-DAG: %[[B:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%{{.*}}, %{{.*}}, %{{.*}}] {leadDimension = 96 : index} : memref<16x6x16xf16> -> !gpu.mma_matrix<16x16xf16, "BOp">
+// CHECK-DAG: %[[C:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%{{.*}}, %{{.*}}] {leadDimension = 144 : index} : memref<16x9x16xf16> -> !gpu.mma_matrix<16x16xf16, "COp">
+// CHECK: %[[D:.+]] = gpu.subgroup_mma_compute %[[A]], %[[B]], %[[C]] : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp"> -> !gpu.mma_matrix<16x16xf16, "COp">
+// CHECK: gpu.subgroup_mma_store_matrix %[[D]], %{{.*}}[%{{.*}}, %{{.*}}, %{{.*}}] {leadDimension = 144 : index} : !gpu.mma_matrix<16x16xf16, "COp">, memref<16x9x16xf16>
+func.func @matmul_with_strides(%arg0: memref<16x16xf16>, %arg1: memref<16x6x16xf16>, %arg2: memref<16x9x16xf16>) {
+ %cst_0 = arith.constant dense<0.000000e+00> : vector<16x16xf16>
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant 0.000000e+00 : f16
+ %A = vector.transfer_read %arg0[%c0, %c0], %cst {in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16>
+ %B = vector.transfer_read %arg1[%c0, %c0, %c0], %cst {permutation_map = affine_map<(d0, d1, d2) -> (d0, d2)>, in_bounds = [true, true]} : memref<16x6x16xf16>, vector<16x16xf16>
+ %C = vector.transfer_read %arg2[%c0, %c0, %c0], %cst {in_bounds = [true, true], permutation_map = affine_map<(d0, d1, d2) -> (d0, d2)>} : memref<16x9x16xf16>, vector<16x16xf16>
+ %D = vector.contract {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A, %B, %C : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
+ vector.transfer_write %D, %arg2[%c0, %c0, %c0] {in_bounds = [true, true], permutation_map = affine_map<(d0, d1, d2) -> (d0, d2)>} : vector<16x16xf16>, memref<16x9x16xf16>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func @read_transpose_with_strides_3d
+func.func @read_transpose_with_strides_3d(%arg0: memref<5x7x3xf16>, %arg1: memref<2x5x3xf16>, %arg2: memref<3x5xf16>) {
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant 0.0 : f16
+ // CHECK: gpu.subgroup_mma_load_matrix %{{.*}} {leadDimension = 21 : index, transpose} : memref<5x7x3xf16> -> !gpu.mma_matrix<3x5xf16, "COp">
+ %A = vector.transfer_read %arg0[%c0, %c0, %c0], %cst {in_bounds = [true, true], permutation_map = affine_map<(d0, d1, d2) -> (d2, d0)>} : memref<5x7x3xf16>, vector<3x5xf16>
+ // CHECK: gpu.subgroup_mma_load_matrix %{{.*}} {leadDimension = 3 : index, transpose} : memref<2x5x3xf16> -> !gpu.mma_matrix<3x5xf16, "COp">
+ %B = vector.transfer_read %arg1[%c0, %c0, %c0], %cst {in_bounds = [true, true], permutation_map = affine_map<(d0, d1, d2) -> (d2, d1)>} : memref<2x5x3xf16>, vector<3x5xf16>
+ %C = arith.addf %A, %B : vector<3x5xf16>
+ vector.transfer_write %C, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<3x5xf16>, memref<3x5xf16>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func @read_transpose_with_strides_4d
+func.func @read_transpose_with_strides_4d(%arg0: memref<5x7x11x3xf16>, %arg1: memref<2x5x11x3xf16>, %arg2: memref<3x5xf16>) {
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant 0.0 : f16
+ // CHECK: gpu.subgroup_mma_load_matrix %{{.*}} {leadDimension = 231 : index, transpose} : memref<5x7x11x3xf16> -> !gpu.mma_matrix<3x5xf16, "COp">
+ %A = vector.transfer_read %arg0[%c0, %c0, %c0, %c0], %cst {in_bounds = [true, true], permutation_map = affine_map<(d0, d1, d2, d3) -> (d3, d0)>} : memref<5x7x11x3xf16>, vector<3x5xf16>
+ // CHECK: gpu.subgroup_mma_load_matrix %{{.*}} {leadDimension = 33 : index, transpose} : memref<2x5x11x3xf16> -> !gpu.mma_matrix<3x5xf16, "COp">
+ %B = vector.transfer_read %arg1[%c0, %c0, %c0, %c0], %cst {in_bounds = [true, true], permutation_map = affine_map<(d0, d1, d2, d3) -> (d3, d1)>} : memref<2x5x11x3xf16>, vector<3x5xf16>
+ %C = arith.addf %A, %B : vector<3x5xf16>
+ vector.transfer_write %C, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<3x5xf16>, memref<3x5xf16>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func @no_convert_read_transpose_not_last_dim
+// CHECK-NOT: gpu
+func.func @no_convert_read_transpose_not_last_dim(%arg0: memref<2x2x2xf16>, %arg1: memref<2x2xf16>) {
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant 0.0 : f16
+ // Legal map, but does not map the last memref dim so should not be lowered to an MMA load.
+ %A = vector.transfer_read %arg0[%c0, %c0, %c0], %cst {in_bounds = [true, true], permutation_map = affine_map<(d0, d1, d2) -> (d1, d0)>} : memref<2x2x2xf16>, vector<2x2xf16>
+ %B = arith.addf %A, %A : vector<2x2xf16>
+ vector.transfer_write %B, %arg1[%c0, %c0] {in_bounds = [true, true]} : vector<2x2xf16>, memref<2x2xf16>
+ return
+}
+
+// -----
+
+// Transpose write is not supported.
+// CHECK-LABEL: func @no_convert_write_transpose
+// CHECK-NOT: gpu
+func.func @no_convert_write_transpose(%arg0: memref<2x2xf16>) {
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant 0.0 : f16
+ %A = vector.transfer_read %arg0[%c0, %c0], %cst {in_bounds = [true, true]} : memref<2x2xf16>, vector<2x2xf16>
+ %B = arith.addf %A, %A : vector<2x2xf16>
+ vector.transfer_write %B, %arg0[%c0, %c0] {in_bounds = [true, true], permutation_map = affine_map<(d0, d1) -> (d1, d0)>} : vector<2x2xf16>, memref<2x2xf16>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func @read_transpose_with_broadcast_3d
+func.func @read_transpose_with_broadcast_3d(%arg0: memref<2x2x2xf16>, %arg1: memref<2x2xf16>) {
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant 0.0 : f16
+ // CHECK: gpu.subgroup_mma_load_matrix %{{.*}} {leadDimension = 0 : index, transpose} : memref<2x2x2xf16> -> !gpu.mma_matrix<2x2xf16, "COp">
+ %A = vector.transfer_read %arg0[%c0, %c0, %c0], %cst {in_bounds = [true, true], permutation_map = affine_map<(d0, d1, d2) -> (d2, 0)>} : memref<2x2x2xf16>, vector<2x2xf16>
+ %B = arith.addf %A, %A : vector<2x2xf16>
+ vector.transfer_write %B, %arg1[%c0, %c0] {in_bounds = [true, true]} : vector<2x2xf16>, memref<2x2xf16>
+ return
+}
|
|
I have a feeling this might be fixing the wrong problem. I'm not super familiar with the maths, but I'm guessing that some kind of loop reordering would usually be possible. Jack mentioned that non-contiguous loads are not supported on some hardware, so that would be another reason not to land this. |
|
I'm persuaded that this is a reasonable change after all. It is already possible to have the same effect using memref.subview. This change makes the subview unnecessary. In the following example, |
4cf2ed1 to
cbd01c7
Compare
|
LGTM. Leave it for Jack to review. |
FranklandJack
left a comment
There was a problem hiding this comment.
Nice, thanks for all the extra tests that didn't have existing coverage.
Add support for lowering vector.transfer_read to gpu.subgroup_mma_load_matrix when the permutation_map is a transpose that involves non-minor dimensions, e.g. (d0, d1, d2) -> (d2, d0).
cbd01c7 to
8b15d28
Compare
|
LGTM! |
Builds on #176785.
Add support for lowering vector.transfer_read to gpu.subgroup_mma_load_matrix when the permutation_map is a transpose that involves non-minor dimensions, e.g. (d0, d1, d2) -> (d2, d0).