From a037d88929460ff9571927c56d6db215be086149 Mon Sep 17 00:00:00 2001
From: Lei Zhang
Date: Mon, 13 May 2024 15:10:25 -0400
Subject: [PATCH] [mlir][gpu] Support extf before contract when converting to
 MMA ops (#91988)

This commit allows `inferFragType` to see through all arith.ext ops and
other elementwise users before reaching the contract op for figuring out
the fragment type.
---
 .../Conversion/VectorToGPU/VectorToGPU.cpp      | 15 ++++++++---
 .../VectorToGPU/vector-to-mma-ops.mlir          | 27 +++++++++++++++++++
 2 files changed, 38 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
index 782cc92f83fee..332f0a2eecfcf 100644
--- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
+++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
@@ -515,6 +515,14 @@ struct CombineTransferReadOpTranspose final
 // TODO: Change the GPU dialect to abstract the layout at the this level and
 // only care about it during lowering to NVVM.
 static const char *inferFragType(Operation *op) {
+  // We can have arith.ext ops before reaching contract ops. See through them
+  // and other kinds of elementwise ops.
+  if (op->hasOneUse()) {
+    Operation *userOp = *op->user_begin();
+    if (userOp->hasTrait<OpTrait::Elementwise>())
+      return inferFragType(userOp);
+  }
+
   for (Operation *users : op->getUsers()) {
     auto contract = dyn_cast<vector::ContractionOp>(users);
     if (!contract)
@@ -560,13 +568,12 @@ convertTransferReadOp(RewriterBase &rewriter, vector::TransferReadOp op,
   if (op->hasOneUse()) {
     auto *user = *op->user_begin();
     // Infer the signedness of the mma type from the integer extend.
-    bool isSignedExtend = isa<arith::ExtSIOp>(user);
-    if (isSignedExtend || isa<arith::ExtUIOp>(user)) {
+    if (isa<arith::ExtSIOp, arith::ExtUIOp>(user)) {
       elType = IntegerType::get(
           op.getContext(), cast<IntegerType>(elType).getWidth(),
-          isSignedExtend ? IntegerType::Signed : IntegerType::Unsigned);
+          isa<arith::ExtSIOp>(user) ? IntegerType::Signed
+                                    : IntegerType::Unsigned);
       mappingResult = user->getResult(0);
-      fragType = inferFragType(user);
     }
   }
   gpu::MMAMatrixType type =
diff --git a/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir b/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir
index 962ed7de584a2..8526ff1392599 100644
--- a/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir
+++ b/mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir
@@ -490,3 +490,30 @@ func.func @fold_transpose_into_transfer_read(%alloc: memref<64x128xf16>, %vector
 }
 
 // -----
+
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-LABEL: func @cast_f16_to_f32_read
+// CHECK: %[[A:.+]] = gpu.subgroup_mma_load_matrix {{.+}} {leadDimension = 16 : index} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "AOp">
+// CHECK: %[[C:.+]] = gpu.subgroup_mma_load_matrix {{.+}} {leadDimension = 16 : index} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "COp">
+// CHECK: %[[AE:.+]] = gpu.subgroup_mma_elementwise extf %[[A]] : (!gpu.mma_matrix<16x16xf16, "AOp">) -> !gpu.mma_matrix<16x16xf32, "AOp">
+// CHECK: %[[CE:.+]] = gpu.subgroup_mma_elementwise extf %[[C]] : (!gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf32, "COp">
+// CHECK: %[[B:.+]] = gpu.subgroup_mma_load_matrix {{.+}} {leadDimension = 16 : index, transpose} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "BOp">
+// CHECK: %[[BE:.+]] = gpu.subgroup_mma_elementwise extf %[[B]] : (!gpu.mma_matrix<16x16xf16, "BOp">) -> !gpu.mma_matrix<16x16xf32, "BOp">
+// CHECK: gpu.subgroup_mma_compute %[[AE]], %[[BE]], %[[CE]]
+func.func @cast_f16_to_f32_read(%arg0: memref<16x16xf16>, %arg1: memref<16x16xf16>, %arg2: memref<16x16xf16>, %arg3: memref<16x16xf32>) {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0.000000e+00 : f16
+  %A = vector.transfer_read %arg0[%c0, %c0], %cst {in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16>
+  %B = vector.transfer_read %arg1[%c0, %c0], %cst {in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16>
+  %C = vector.transfer_read %arg2[%c0, %c0], %cst {in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16>
+  %Aext = arith.extf %A : vector<16x16xf16> to vector<16x16xf32>
+  %Bext = arith.extf %B : vector<16x16xf16> to vector<16x16xf32>
+  %Cext = arith.extf %C : vector<16x16xf16> to vector<16x16xf32>
+  %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+    %Aext, %Bext, %Cext : vector<16x16xf32>, vector<16x16xf32> into vector<16x16xf32>
+  vector.transfer_write %D, %arg3[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf32>, memref<16x16xf32>
+  return
+}
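
For illustration, here is a minimal MLIR sketch of the IR shape this patch newly
handles, condensed from the @cast_f16_to_f32_read test above; the function name,
map names, and shapes are illustrative only, not additional test coverage. The
transfer_read result's single user is an arith.extf, so inferFragType must recurse
through that elementwise op before it can see the vector.contract and assign the
"AOp"/"BOp"/"COp" fragment kinds.

  #mapA = affine_map<(d0, d1, d2) -> (d0, d2)>
  #mapB = affine_map<(d0, d1, d2) -> (d1, d2)>
  #mapC = affine_map<(d0, d1, d2) -> (d0, d1)>

  func.func @extf_before_contract(%lhs_mem: memref<16x16xf16>,
                                  %rhs: vector<16x16xf32>,
                                  %acc: vector<16x16xf32>) -> vector<16x16xf32> {
    %c0 = arith.constant 0 : index
    %pad = arith.constant 0.0 : f16
    // The transfer_read result's only user is arith.extf, an elementwise op...
    %lhs = vector.transfer_read %lhs_mem[%c0, %c0], %pad {in_bounds = [true, true]}
        : memref<16x16xf16>, vector<16x16xf16>
    %lhs32 = arith.extf %lhs : vector<16x16xf16> to vector<16x16xf32>
    // ...so fragment-type inference looks through it to reach the contract,
    // where %lhs32 is the LHS operand and hence gets the "AOp" fragment type.
    %d = vector.contract {indexing_maps = [#mapA, #mapB, #mapC],
                          iterator_types = ["parallel", "parallel", "reduction"],
                          kind = #vector.kind<add>}
        %lhs32, %rhs, %acc : vector<16x16xf32>, vector<16x16xf32> into vector<16x16xf32>
    return %d : vector<16x16xf32>
  }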