Skip to content

Commit 14d79af

Browse files
Manish GuptaThomasRaoux
authored andcommitted
[mlir][NVGPU] nvgpu.mmasync on F32 through TF32
Adds optional attribute to support tensor cores on F32 datatype by lowering to `mma.sync` with TF32 operands. Since, TF32 is not a native datatype in LLVM we are adding `tf32Enabled` as an attribute to allow the IR to be aware of `MmaSyncOp` datatype. Additionally, this patch adds placeholders for nvgpu-to-nvgpu transformation targeting higher precision tf32x3. For mma.sync on f32 input using tensor cores there are two possibilites: (a) tf32 (1 `mma.sync` per warp-level matrix-multiply-accumulate) (b) tf32x3 (3 `mma.sync` per warp-level matrix-multiply-accumulate) Typically, tf32 tensor core acceleration comes at a cost of accuracy from missing precision bits. While f32 has 23 precision bits, tf32 has only 10 precision bits. tf32x3 aims to recover the precision bits by splitting each operand into two tf32 values and issue three `mma.sync` tensor core operations. Reviewed By: ThomasRaoux Differential Revision: https://reviews.llvm.org/D130294
1 parent bcef4d2 commit 14d79af

File tree

16 files changed

+283
-8
lines changed

16 files changed

+283
-8
lines changed

mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -110,11 +110,22 @@ def NVGPU_MmaSyncOp : NVGPU_Op<"mma.sync", [
110110
(vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
111111
```
112112
}];
113-
let arguments = (ins AnyVector:$matrixA, AnyVector:$matrixB,
114-
AnyVector:$matrixC, I64ArrayAttr:$mmaShape);
113+
let arguments = (ins AnyVector:$matrixA,
114+
AnyVector:$matrixB,
115+
AnyVector:$matrixC,
116+
I64ArrayAttr:$mmaShape,
117+
OptionalAttr<UnitAttr>:$tf32Enabled
118+
);
115119

116120
let results = (outs AnyVector:$res);
117121

122+
let builders = [
123+
OpBuilder<(ins "Value":$matrixA,
124+
"Value":$matrixB,
125+
"Value":$matrixC,
126+
"ArrayAttr":$mmaShape)>
127+
];
128+
118129
let assemblyFormat = [{
119130
`(` $matrixA`,` $matrixB`,` $matrixC `)` attr-dict
120131
`:` `(` type($matrixA) `,` type($matrixB) `,` type($matrixC) `)` `->` type($res)

mlir/include/mlir/Dialect/NVGPU/Transforms/Transforms.h

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,10 @@
1919
namespace mlir {
2020
namespace nvgpu {
2121

22+
///
23+
/// Passes
24+
///
25+
2226
/// Optimizes vectorized accesses to a shared memory buffer specified by
2327
/// memrefValue. This transformation assumes the following:
2428
/// 1) All relevant accesses to `memrefValue` are contained with `parentOp`.
@@ -41,6 +45,29 @@ namespace nvgpu {
4145
mlir::LogicalResult optimizeSharedMemoryReadsAndWrites(Operation *parentOp,
4246
Value memrefValue);
4347

48+
///
49+
/// Rewrites patterns
50+
///
51+
52+
//===----------------------------------------------------------------------===//
53+
// NVGPU transformation options exposed as auxiliary structs.
54+
//===----------------------------------------------------------------------===//
55+
/// Enum to control the lowering of `nvgpu.mmasync`.
56+
enum class MmaSyncF32Lowering { TF32 = 0, TF32x3 = 1, Unkown = 2 };
57+
58+
/// Collect patterns to convert mma.sync on f32 input and rewrite
59+
/// to use tensor cores with user provided level of accuracy:
60+
/// (a) tf32 (1 mma.sync per warp-level matrix-multiply-accumulate)
61+
/// (b) tf32x3 (3 mma.sync per warp-level matrix-multiply-accumulate)
62+
/// Typically, tf32 tensor core acceleration comes at a cost
63+
/// of accuracy from missing precision bits. While f32 has 23 precision
64+
/// bits, tf32 has only 10 precision bits. tf32x3 aims to recover the
65+
/// precision bits by spliting each operand into two tf32 values
66+
/// and issue three mma.sync tensor core operations.
67+
void populateMmaSyncF32ToTF32Patterns(
68+
RewritePatternSet &patterns,
69+
nvgpu::MmaSyncF32Lowering precision = nvgpu::MmaSyncF32Lowering::TF32);
70+
4471
} // namespace nvgpu
4572
} // namespace mlir
4673

mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -275,10 +275,14 @@ struct MmaSyncOptoNVVM : public ConvertOpToLLVMPattern<nvgpu::MmaSyncOp> {
275275
NVVM::MMATypes ptxTypeB;
276276
Optional<NVVM::MMATypes> ptxTypeC = NVVM::MmaOp::inferOperandMMAType(
277277
cType.getElementType(), /*isAccumulator=*/true);
278-
if (!ptxTypeC) {
278+
if (!ptxTypeC)
279279
return op->emitError(
280280
"could not infer the PTX type for the accumulator/result");
281-
}
281+
282+
// Tensor Cores (mma.sync) on F32 works only with TensorFloat32 (TF32).
283+
bool tf32Enabled = op->hasAttr(op.getTf32EnabledAttrName());
284+
if (aType.getElementType().isF32() && !tf32Enabled)
285+
return failure();
282286

283287
Optional<NVVM::MMAIntOverflow> overflow(llvm::None);
284288
if (aType.getElementType().isInteger(8)) {

mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -687,8 +687,8 @@ convertContractOpToMmaSync(vector::ContractionOp op,
687687
int64_t m = op.getLhs().getType().cast<VectorType>().getShape()[0];
688688
int64_t n = op.getRhs().getType().cast<VectorType>().getShape()[0];
689689
int64_t k = op.getLhs().getType().cast<VectorType>().getShape()[1];
690-
Value matmul = b.create<nvgpu::MmaSyncOp>(
691-
op.getLoc(), opC.getType(), opA, opB, opC, b.getI64ArrayAttr({m, n, k}));
690+
Value matmul = b.create<nvgpu::MmaSyncOp>(op.getLoc(), opA, opB, opC,
691+
b.getI64ArrayAttr({m, n, k}));
692692
valueMapping[op.getResult()] = matmul;
693693
return success();
694694
}

mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,12 @@ LogicalResult DeviceAsyncCopyOp::verify() {
9191
//===----------------------------------------------------------------------===//
9292
// NVGPU_MmaSyncOp
9393
//===----------------------------------------------------------------------===//
94+
void MmaSyncOp::build(::mlir::OpBuilder &odsBuilder,
95+
::mlir::OperationState &odsState, Value matrixA,
96+
Value matrixB, Value matrixC, ArrayAttr mmaShape) {
97+
build(odsBuilder, odsState, matrixC.getType(), matrixA, matrixB, matrixC,
98+
mmaShape, UnitAttr());
99+
}
94100

95101
LogicalResult MmaSyncOp::verify() {
96102

@@ -122,6 +128,9 @@ LogicalResult MmaSyncOp::verify() {
122128
// vector element type
123129
Type aType = aVector.getElementType();
124130

131+
// tensor float32 (TF32) enabled
132+
bool tf32Enabled = getOperation()->hasAttr(getTf32EnabledAttrName());
133+
125134
// nvgpu.mma.sync shape (per 32 threads or per warp)
126135
int64_t m = getMmaShape()[0].cast<IntegerAttr>().getInt();
127136
int64_t n = getMmaShape()[1].cast<IntegerAttr>().getInt();
@@ -163,6 +172,10 @@ LogicalResult MmaSyncOp::verify() {
163172
return emitOpError() << "expected " << m * n
164173
<< " warp-wide matrix C elements";
165174

175+
// verify tf32 tensor cores are enabled for only F32 datatype
176+
if (tf32Enabled && !(aType.isF32()))
177+
return emitOpError() << "expected tf32 tensor cores only for F32 operands";
178+
166179
//
167180
// Extended verification
168181
//

mlir/lib/Dialect/NVGPU/Transforms/CMakeLists.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
add_mlir_dialect_library(MLIRNVGPUTransforms
2-
OptimizeSharedMemory.cpp
2+
OptimizeSharedMemory.cpp
3+
MmaSyncTF32Transform.cpp
34

45
ADDITIONAL_HEADER_DIRS
56
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/NVGPU
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
//===- OptimizeSharedMemory.cpp - MLIR NVGPU pass implementation ----------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file implements transforms to enable 1xtf32 and 3xtf32 nvgpu.mma sync
10+
// operations on f32 input datatype
11+
//
12+
//===----------------------------------------------------------------------===//
13+
14+
#include "PassDetail.h"
15+
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
16+
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
17+
#include "mlir/Dialect/MemRef/IR/MemRef.h"
18+
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
19+
#include "mlir/Dialect/NVGPU/Passes.h"
20+
#include "mlir/Dialect/NVGPU/Transforms/Transforms.h"
21+
#include "mlir/Dialect/Vector/IR/VectorOps.h"
22+
#include "mlir/Interfaces/SideEffectInterfaces.h"
23+
#include "mlir/Support/LogicalResult.h"
24+
#include "llvm/ADT/STLExtras.h"
25+
#include "llvm/Support/MathExtras.h"
26+
27+
using namespace mlir;
28+
using namespace mlir::nvgpu;
29+
30+
namespace {
31+
32+
struct MmaSyncF32ToTF32Pattern : public OpRewritePattern<nvgpu::MmaSyncOp> {
33+
34+
using OpRewritePattern<nvgpu::MmaSyncOp>::OpRewritePattern;
35+
36+
MmaSyncF32ToTF32Pattern(MLIRContext *context,
37+
nvgpu::MmaSyncF32Lowering precision)
38+
: OpRewritePattern<nvgpu::MmaSyncOp>(context, /*benifit*/ 1),
39+
precision(precision) {}
40+
41+
LogicalResult matchAndRewrite(nvgpu::MmaSyncOp op,
42+
PatternRewriter &rewrite) const override {
43+
Location location = op->getLoc();
44+
45+
if (op->hasAttr(op.getTf32EnabledAttrName()))
46+
return failure();
47+
48+
if (precision == MmaSyncF32Lowering::Unkown)
49+
return emitError(location, "MmaSync F32-to-TF32 cannot be lowered with "
50+
"unknown precision level");
51+
52+
if (precision == MmaSyncF32Lowering::TF32x3)
53+
return emitError(location, "TF32x3 is not supported at the moment "
54+
"for nvgpu.mma.sync on f32 datatype");
55+
56+
if (precision == MmaSyncF32Lowering::TF32)
57+
op.setTf32EnabledAttr(rewrite.getUnitAttr());
58+
59+
return success();
60+
}
61+
62+
private:
63+
/// Precision for F32 Tensor Cores (TF32 or TF32x3)
64+
nvgpu::MmaSyncF32Lowering precision;
65+
};
66+
67+
} // namespace
68+
69+
void mlir::nvgpu::populateMmaSyncF32ToTF32Patterns(
70+
RewritePatternSet &patterns, nvgpu::MmaSyncF32Lowering precision) {
71+
72+
patterns.add<MmaSyncF32ToTF32Pattern>(patterns.getContext(), precision);
73+
}

mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,7 @@ func.func @m16n8k4_tf32(%arg0: vector<2x1xf32>, %arg1: vector<1x1xf32>, %arg2: v
219219
// CHECK-SAME: multiplicandBPtxType = #nvvm.mma_type<tf32>
220220
// CHECK-SAME: shape = #nvvm.shape<m = 16, n = 8, k = 4>
221221
// CHECK-SAME: -> !llvm.struct<(f32, f32, f32, f32)>
222-
%d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 4]} : (vector<2x1xf32>, vector<1x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
222+
%d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 4], tf32Enabled} : (vector<2x1xf32>, vector<1x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
223223
// CHECK: [[undef:%.+]] = llvm.mlir.undef : vector<2xf32>
224224
// CHECK-DAG: llvm.extractvalue [[d]][0] : !llvm.struct<(f32, f32, f32, f32)>
225225
// CHECK-DAG: llvm.extractvalue [[d]][1] : !llvm.struct<(f32, f32, f32, f32)>

mlir/test/Dialect/NVGPU/invalid.mlir

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,13 @@ func.func @m16n8k16_fp16_vector_shape_a_extended(%arg0: vector<2x4xf16>, %arg1:
7676
}
7777
// -----
7878

79+
func.func @m16n8k16_fp16_tf32Enabled(%arg0: vector<4x2xf16>, %arg1: vector<2x2xf16>, %arg2: vector<2x2xf16>) -> vector<2x2xf16> {
80+
// expected-error @+1 {{expected tf32 tensor cores only for F32 operands}}
81+
%d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 16], tf32Enabled} : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
82+
return %d : vector<2x2xf16>
83+
}
84+
// -----
85+
7986
func.func @m16n8k8_fp32_vector_shape_a(%arg0: vector<4x2xf32>, %arg1: vector<2x1xf32>, %arg2: vector<2x2xf32>) -> vector<2x2xf32> {
8087
// expected-error @+1 {{expected 128 warp-wide matrix A elements}}
8188
%d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 8]} : (vector<4x2xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
// RUN: mlir-opt %s -test-nvgpu-mmasync-f32-to-tf32-patterns="precision=tf32" -split-input-file | FileCheck %s
2+
3+
// CHECK-LABEL: m16n8k4_tf32
4+
func.func @m16n8k4_tf32(%arg0: vector<2x1xf32>, %arg1: vector<1x1xf32>, %arg2: vector<2x2xf32>) -> vector<2x2xf32> {
5+
// CHECK: nvgpu.mma.sync
6+
// CHECK-SAME: tf32Enabled
7+
%d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 4]} : (vector<2x1xf32>, vector<1x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
8+
return %d : vector<2x2xf32>
9+
}
10+
11+
// -----
12+
13+
// CHECK-LABEL: m16n8k8_tf32
14+
func.func @m16n8k8_tf32(%arg0: vector<4x1xf32>, %arg1: vector<2x1xf32>, %arg2: vector<2x2xf32>) -> vector<2x2xf32> {
15+
// CHECK: nvgpu.mma.sync
16+
// CHECK-SAME: tf32Enabled
17+
%d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 8]} : (vector<4x1xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
18+
return %d : vector<2x2xf32>
19+
}
20+
// -----

0 commit comments

Comments
 (0)