Skip to content

Commit

Permalink
Implement miopen.mfma_v2 to gpu.mfma lowering logic and test.
Browse files Browse the repository at this point in the history
  • Loading branch information
whchung committed Aug 17, 2020
1 parent 205cba9 commit cb55fb7
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 3 deletions.
14 changes: 14 additions & 0 deletions mlir/lib/Conversion/MIOpenToGPU/MIOpenToGPU.cpp
Expand Up @@ -545,6 +545,20 @@ void LowerMIOpenOpsToGPUPass::runOnOperation() {

op.erase();
});

gpuFunc.walk([&](miopen::MFMAV2Op op) {
auto loc = op.getLoc();
OpBuilder b(op.getContext());
b.setInsertionPoint(op);

auto gpuMfmaOp = b.create<gpu::MFMAOp>(loc, op.getType(), op.sourceA(), op.sourceB(), op.destC());
gpuMfmaOp.setAttr("instr", op.getAttr("instr"));
gpuMfmaOp.setAttr("imm", op.getAttr("imm"));

op.replaceAllUsesWith(gpuMfmaOp.destD());
op.erase();
});

});
}
}
Expand Down
36 changes: 36 additions & 0 deletions mlir/test/Conversion/MIOpenToGPU/mfma_v2.mlir
@@ -0,0 +1,36 @@
// RUN: mlir-opt -convert-miopen-to-gpu="kernel-name=mfma_f32" %s | FileCheck %s --check-prefix=MFMA_F32
// RUN: mlir-opt -convert-miopen-to-gpu="kernel-name=mfma_f16" %s | FileCheck %s --check-prefix=MFMA_F16
// RUN: mlir-opt -convert-miopen-to-gpu="kernel-name=mfma_bf16" %s | FileCheck %s --check-prefix=MFMA_BF16

module {
func @mfma_f32(%a : f32, %b : f32, %c : vector<32xf32>) {
%d0 = miopen.mfma_v2(%a, %b, %c) { instr = "mfma_f32_32x32x1f32", imm = [1, 0, 0]}: f32, vector<32xf32>
// MFMA_F32: %[[D0:.*]] = gpu.mfma(%{{.*}}, %{{.*}}, %{{.*}}) {imm = [1, 0, 0], instr = "mfma_f32_32x32x1f32"} : f32, vector<32xf32>
%d1 = miopen.mfma_v2(%a, %b, %d0) { instr = "mfma_f32_32x32x1f32", imm = [1, 0, 0]}: f32, vector<32xf32>
// MFMA_F32-NEXT: %[[D1:.*]] = gpu.mfma(%{{.*}}, %{{.*}}, %[[D0]]) {imm = [1, 0, 0], instr = "mfma_f32_32x32x1f32"} : f32, vector<32xf32>

return
}

// ----

func @mfma_f16(%a : vector<4xf16>, %b : vector<4xf16>, %c : vector<32xf32>) {
%d0 = miopen.mfma_v2(%a, %b, %c) { instr = "mfma_f32_32x32x4f16", imm = [1, 0, 0]}: vector<4xf16>, vector<32xf32>
// MFMA_F16: %[[D0:.*]] = gpu.mfma(%{{.*}}, %{{.*}}, %{{.*}}) {imm = [1, 0, 0], instr = "mfma_f32_32x32x4f16"} : vector<4xf16>, vector<32xf32>
%d1 = miopen.mfma_v2(%a, %b, %d0) { instr = "mfma_f32_32x32x4f16", imm = [1, 0, 0]}: vector<4xf16>, vector<32xf32>
// MFMA_F16: %[[D1:.*]] = gpu.mfma(%{{.*}}, %{{.*}}, %[[D0]]) {imm = [1, 0, 0], instr = "mfma_f32_32x32x4f16"} : vector<4xf16>, vector<32xf32>

return
}

// ----

func @mfma_bf16(%a : vector<2xbf16>, %b : vector<2xbf16>, %c : vector<32xf32>) {
%d0 = miopen.mfma_v2(%a, %b, %c) { instr = "mfma_f32_32x32x2bf16", imm = [1, 0, 0]}: vector<2xbf16>, vector<32xf32>
// MFMA_BF16: %[[D0:.*]] = gpu.mfma(%{{.*}}, %{{.*}}, %{{.*}}) {imm = [1, 0, 0], instr = "mfma_f32_32x32x2bf16"} : vector<2xbf16>, vector<32xf32>
%d1 = miopen.mfma_v2(%a, %b, %d0) { instr = "mfma_f32_32x32x2bf16", imm = [1, 0, 0]}: vector<2xbf16>, vector<32xf32>
// MFMA_BF16: %[[D1:.*]] = gpu.mfma(%{{.*}}, %{{.*}}, %[[D0]]) {imm = [1, 0, 0], instr = "mfma_f32_32x32x2bf16"} : vector<2xbf16>, vector<32xf32>

return
}
}
6 changes: 3 additions & 3 deletions mlir/test/Dialect/MIOpen/ops_2.mlir
Expand Up @@ -248,23 +248,23 @@ func @miopen_mfma_bf16(%a : vector<2xbf16>, %b : vector<2xbf16>, %c : memref<64x
// ----

func @miopen_mfma_v2_f32(%a : f32, %b : f32, %c : vector<32xf32>) -> vector<32xf32> {
%d = miopen.mfma_v2(%a, %b, %c) { m_per_wave = 64, n_per_wave = 64, instr = "mfma_f32_32x32x1f32", imm = [1, 0, 0] } : f32, vector<32xf32>
%d = miopen.mfma_v2(%a, %b, %c) { instr = "mfma_f32_32x32x1f32", imm = [1, 0, 0] } : f32, vector<32xf32>
return %d : vector<32xf32>
}

// CHECK-LABEL: func @miopen_mfma_v2_f32
// CHECK: miopen.mfma_v2

func @miopen_mfma_v2_f16(%a : vector<4xf16>, %b : vector<4xf16>, %c : vector<32xf32>) -> vector<32xf32> {
%d = miopen.mfma_v2(%a, %b, %c) { m_per_wave = 64, n_per_wave = 64, instr = "mfma_f32_32x32x4f16", imm = [1, 0, 0] } : vector<4xf16>, vector<32xf32>
%d = miopen.mfma_v2(%a, %b, %c) { instr = "mfma_f32_32x32x4f16", imm = [1, 0, 0] } : vector<4xf16>, vector<32xf32>
return %d : vector<32xf32>
}

// CHECK-LABEL: func @miopen_mfma_v2_f16
// CHECK: miopen.mfma_v2

func @miopen_mfma_v2_bf16(%a : vector<2xbf16>, %b : vector<2xbf16>, %c : vector<32xf32>) -> vector<32xf32> {
%d = miopen.mfma_v2(%a, %b, %c) { m_per_wave = 64, n_per_wave = 64, instr = "mfma_f32_32x32x2bf16", imm = [1, 0, 0] } : vector<2xbf16>, vector<32xf32>
%d = miopen.mfma_v2(%a, %b, %c) { instr = "mfma_f32_32x32x2bf16", imm = [1, 0, 0] } : vector<2xbf16>, vector<32xf32>
return %d : vector<32xf32>
}

Expand Down

0 comments on commit cb55fb7

Please sign in to comment.