[mlir][gpu] Support extf before contract when converting to MMA ops (#91988)

This commit allows `inferFragType` to see through arith.ext ops
and other elementwise users before reaching the contract op when
figuring out the fragment type.
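
For illustration, a minimal sketch of the IR shape this enables, distilled
from the new test added below (value names and the elided contract attributes
are illustrative):

    %a   = vector.transfer_read %mem[%c0, %c0], %cst
             : memref<16x16xf16>, vector<16x16xf16>
    // The read's only user is an elementwise op rather than the contract;
    // inferFragType now recurses through it to find the contract.
    %a32 = arith.extf %a : vector<16x16xf16> to vector<16x16xf32>
    %d   = vector.contract { ... } %a32, %b32, %c32
             : vector<16x16xf32>, vector<16x16xf32> into vector<16x16xf32>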
antiagainst committed May 13, 2024
1 parent 5944579 commit a037d88
Showing 2 changed files with 38 additions and 4 deletions.
15 changes: 11 additions & 4 deletions mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
@@ -515,6 +515,14 @@ struct CombineTransferReadOpTranspose final
 // TODO: Change the GPU dialect to abstract the layout at this level and
 // only care about it during lowering to NVVM.
 static const char *inferFragType(Operation *op) {
+  // We can have arith.ext ops before reaching contract ops. See through them
+  // and other kinds of elementwise ops.
+  if (op->hasOneUse()) {
+    Operation *userOp = *op->user_begin();
+    if (userOp->hasTrait<OpTrait::Elementwise>())
+      return inferFragType(userOp);
+  }
+
   for (Operation *users : op->getUsers()) {
     auto contract = dyn_cast<vector::ContractionOp>(users);
     if (!contract)
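
Because the walk above recurses whenever an op has exactly one elementwise
user, a chain of casts is also seen through. A hypothetical chain (not taken
from this patch's test) showing how the recursion proceeds:

    %a = vector.transfer_read ...  : memref<16x16xi8>, vector<16x16xi8>
    // %a's single user is elementwise: recurse into %b.
    %b = arith.extsi %a : vector<16x16xi8> to vector<16x16xi16>
    // %b's single user is again elementwise: recurse into %c.
    %c = arith.extsi %b : vector<16x16xi16> to vector<16x16xi32>
    // %c is used by the contract; its operand position (lhs/rhs/acc)
    // determines the "AOp"/"BOp"/"COp" fragment type for the original read.
    %d = vector.contract { ... } %c, %rhs, %acc ...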
@@ -560,13 +568,12 @@ convertTransferReadOp(RewriterBase &rewriter, vector::TransferReadOp op,
   if (op->hasOneUse()) {
     auto *user = *op->user_begin();
     // Infer the signedness of the mma type from the integer extend.
-    bool isSignedExtend = isa<arith::ExtSIOp>(user);
-    if (isSignedExtend || isa<arith::ExtUIOp>(user)) {
+    if (isa<arith::ExtSIOp, arith::ExtUIOp>(user)) {
       elType = IntegerType::get(
           op.getContext(), cast<IntegerType>(elType).getWidth(),
-          isSignedExtend ? IntegerType::Signed : IntegerType::Unsigned);
+          isa<arith::ExtSIOp>(user) ? IntegerType::Signed
+                                    : IntegerType::Unsigned);
       mappingResult = user->getResult(0);
       fragType = inferFragType(user);
     }
   }
   gpu::MMAMatrixType type =
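
The signedness logic in this hunk is refactored rather than newly added: it
picks the MMA element type from the integer extend that consumes the read. A
hedged sketch of the behavior, modeled on the dialect's existing integer MMA
handling (the si8 fragment type below is an assumption based on that code
path, not part of this diff):

    %a   = vector.transfer_read %mem[%c0, %c0], %c0_i8
             : memref<16x16xi8>, vector<16x16xi8>
    // An arith.extsi user makes elType signed, so the load is expected to
    // lower to roughly:
    //   gpu.subgroup_mma_load_matrix ...
    //     : memref<16x16xi8> -> !gpu.mma_matrix<16x16xsi8, "AOp">
    %a32 = arith.extsi %a : vector<16x16xi8> to vector<16x16xi32>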
27 changes: 27 additions & 0 deletions mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir
@@ -490,3 +490,30 @@ func.func @fold_transpose_into_transfer_read(%alloc: memref<64x128xf16>, %vector
 }

 // -----
+
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
+
+// CHECK-LABEL: func @cast_f16_to_f32_read
+// CHECK: %[[A:.+]] = gpu.subgroup_mma_load_matrix {{.+}} {leadDimension = 16 : index} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "AOp">
+// CHECK: %[[C:.+]] = gpu.subgroup_mma_load_matrix {{.+}} {leadDimension = 16 : index} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "COp">
+// CHECK: %[[AE:.+]] = gpu.subgroup_mma_elementwise extf %[[A]] : (!gpu.mma_matrix<16x16xf16, "AOp">) -> !gpu.mma_matrix<16x16xf32, "AOp">
+// CHECK: %[[CE:.+]] = gpu.subgroup_mma_elementwise extf %[[C]] : (!gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf32, "COp">
+// CHECK: %[[B:.+]] = gpu.subgroup_mma_load_matrix {{.+}} {leadDimension = 16 : index, transpose} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "BOp">
+// CHECK: %[[BE:.+]] = gpu.subgroup_mma_elementwise extf %[[B]] : (!gpu.mma_matrix<16x16xf16, "BOp">) -> !gpu.mma_matrix<16x16xf32, "BOp">
+// CHECK: gpu.subgroup_mma_compute %[[AE]], %[[BE]], %[[CE]]
+func.func @cast_f16_to_f32_read(%arg0: memref<16x16xf16>, %arg1: memref<16x16xf16>, %arg2: memref<16x16xf16>, %arg3: memref<16x16xf32>) {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0.000000e+00 : f16
+  %A = vector.transfer_read %arg0[%c0, %c0], %cst {in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16>
+  %B = vector.transfer_read %arg1[%c0, %c0], %cst {in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16>
+  %C = vector.transfer_read %arg2[%c0, %c0], %cst {in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16>
+  %Aext = arith.extf %A : vector<16x16xf16> to vector<16x16xf32>
+  %Bext = arith.extf %B : vector<16x16xf16> to vector<16x16xf32>
+  %Cext = arith.extf %C : vector<16x16xf16> to vector<16x16xf32>
+  %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>}
+    %Aext, %Bext, %Cext : vector<16x16xf32>, vector<16x16xf32> into vector<16x16xf32>
+  vector.transfer_write %D, %arg3[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf32>, memref<16x16xf32>
+  return
+}
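
The new test runs through mlir-opt with the VectorToGPU conversion in the
usual lit style; a sketch of such a RUN line (the file's actual RUN line may
differ in its exact flags):

    // RUN: mlir-opt %s -convert-vector-to-gpu -split-input-file | FileCheck %s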
