Skip to content

Commit

Permalink
[mlir][nvvm] Change MMAShapeAttr to AttrDef
Browse files Browse the repository at this point in the history
MMAShapeAttr was a StructAttr

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D127348
  • Loading branch information
Mogball committed Jun 9, 2022
1 parent f814470 commit ba79bb4
Show file tree
Hide file tree
Showing 7 changed files with 78 additions and 69 deletions.
19 changes: 13 additions & 6 deletions mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
Expand Up @@ -47,6 +47,15 @@ class NVVM_Op<string mnemonic, list<Trait> traits = []> :
LLVM_OpBase<NVVM_Dialect, mnemonic, traits> {
}

//===----------------------------------------------------------------------===//
// NVVM attribute definitions
//===----------------------------------------------------------------------===//

class NVVM_Attr<string attrName, string attrMnemonic, list<Trait> traits = []>
: AttrDef<NVVM_Dialect, attrName, traits> {
let mnemonic = attrMnemonic;
}

//===----------------------------------------------------------------------===//
// NVVM intrinsic operations
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -460,12 +469,10 @@ def MMAIntOverflowAttr : EnumAttr<NVVM_Dialect, MMAIntOverflow, "mma_int_overflo
}

/// Attribute to hold the MMA shape
def NVVM_MMAShapeAttr : StructAttr<"MMAShapeAttr", NVVM_Dialect, [
StructFieldAttr<"m", I32Attr>,
StructFieldAttr<"n", I32Attr>,
StructFieldAttr<"k", I32Attr>
]> {
def NVVM_MMAShapeAttr : NVVM_Attr<"MMAShape", "shape"> {
let summary = "Attribute for MMA operation shape.";
let parameters = (ins "int":$m, "int":$n, "int":$k);
let assemblyFormat = "`<` struct(params) `>`";
}

// Returns true if this combination of layout/satf for MMA ops is supported;
Expand Down Expand Up @@ -983,7 +990,7 @@ def NVVM_MmaOp : NVVM_Op<"mma.sync", [AttrSizedOperandSegments]> {
string llvmBuilder = [{
auto operands = moduleTranslation.lookupValues(opInst.getOperands());
auto intId = mlir::NVVM::MmaOp::getIntrinsicID(
$shape.m().getInt(), $shape.n().getInt(), $shape.k().getInt(),
$shape.getM(), $shape.getN(), $shape.getK(),
$b1Op, $intOverflowBehavior,
$layoutA, $layoutB,
$multiplicandAPtxType.getValue(),
Expand Down
16 changes: 11 additions & 5 deletions mlir/include/mlir/IR/Builders.h
Expand Up @@ -78,10 +78,17 @@ class Builder {
TupleType getTupleType(TypeRange elementTypes);
NoneType getNoneType();

/// Get or construct an instance of the type 'ty' with provided arguments.
/// Get or construct an instance of the type `Ty` with provided arguments.
template <typename Ty, typename... Args>
Ty getType(Args... args) {
return Ty::get(context, args...);
Ty getType(Args &&...args) {
return Ty::get(context, std::forward<Args>(args)...);
}

/// Get or construct an instance of the attribute `Attr` with provided
/// arguments.
template <typename Attr, typename... Args>
Attr getAttr(Args &&...args) {
return Attr::get(context, std::forward<Args>(args)...);
}

// Attributes.
Expand Down Expand Up @@ -510,8 +517,7 @@ class OpBuilder : public Builder {
Operation *cloneWithoutRegions(Operation &op) {
return insert(op.cloneWithoutRegions());
}
template <typename OpT>
OpT cloneWithoutRegions(OpT op) {
template <typename OpT> OpT cloneWithoutRegions(OpT op) {
return cast<OpT>(cloneWithoutRegions(*op.getOperation()));
}

Expand Down
10 changes: 3 additions & 7 deletions mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
Expand Up @@ -196,11 +196,8 @@ void MmaOp::build(OpBuilder &builder, OperationState &result, Type resultType,

assert(shape.size() == 3 && "expected shape to have size 3 (m, n, k)");
MLIRContext *ctx = builder.getContext();
Type i32 = builder.getIntegerType(32);
result.addAttribute(
"shape", MMAShapeAttr::get(builder.getIntegerAttr(i32, shape[0]),
builder.getIntegerAttr(i32, shape[1]),
builder.getIntegerAttr(i32, shape[2]), ctx));
"shape", builder.getAttr<MMAShapeAttr>(shape[0], shape[1], shape[2]));

result.addOperands(operandA);
result.addOperands(operandB);
Expand Down Expand Up @@ -358,9 +355,8 @@ LogicalResult MmaOp::verify() {
auto s32x2StructTy =
LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty});

std::array<int64_t, 3> mmaShape{shapeAttr().m().getInt(),
shapeAttr().n().getInt(),
shapeAttr().k().getInt()};
std::array<int64_t, 3> mmaShape{shapeAttr().getM(), shapeAttr().getN(),
shapeAttr().getK()};

// These variables define the set of allowed data types for matrices A, B, C,
// and result.
Expand Down
16 changes: 8 additions & 8 deletions mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
Expand Up @@ -12,7 +12,7 @@ func.func @m16n8k16_fp16(%arg0: vector<4x2xf16>, %arg1: vector<2x2xf16>, %arg2:
// CHECK: llvm.extractvalue %{{.*}}[1] : !llvm.array<2 x vector<2xf16>>
// CHECK-NOT llvm.extractvalue
// CHECK: [[d:%.+]] = nvvm.mma.sync
// CHECK-SAME: shape = {k = 16 : i32, m = 16 : i32, n = 8 : i32}
// CHECK-SAME: shape = #nvvm.shape<m = 16, n = 8, k = 16>
%d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 16]} : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
// CHECK-DAG: llvm.extractvalue [[d]][0] : !llvm.struct<(vector<2xf16>, vector<2xf16>)>
// CHECK-DAG: llvm.extractvalue [[d]][1] : !llvm.struct<(vector<2xf16>, vector<2xf16>)>
Expand All @@ -30,7 +30,7 @@ func.func @m16n8k16_fp16(%arg0: vector<4x2xf16>, %arg1: vector<2x2xf16>, %arg2:
func.func @m16n8k16_fp16_fp32(%arg0: vector<4x2xf16>, %arg1: vector<2x2xf16>, %arg2: vector<2x2xf32>) -> vector<2x2xf32> {
// We just need to check the mma instruction and the manipulatin of the result.
// CHECK: [[d:%.+]] = nvvm.mma.sync
// CHECK-SAME: shape = {k = 16 : i32, m = 16 : i32, n = 8 : i32}
// CHECK-SAME: shape = #nvvm.shape<m = 16, n = 8, k = 16>
// CHECK-SAME: (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(f32, f32, f32, f32)>
%d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 16]} : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf32>) -> vector<2x2xf32>
// CHECK: [[undef:%.+]] = llvm.mlir.undef : vector<2xf32>
Expand Down Expand Up @@ -61,7 +61,7 @@ func.func @m16n8k8_fp16(%arg0: vector<2x2xf16>, %arg1: vector<1x2xf16>, %arg2: v
// CHECK: llvm.extractvalue %{{.*}}[1] : !llvm.array<2 x vector<2xf16>>
// CHECK-NOT llvm.extractvalue
// CHECK: [[d:%.+]] = nvvm.mma.sync
// CHECK-SAME: shape = {k = 8 : i32, m = 16 : i32, n = 8 : i32}
// CHECK-SAME: shape = #nvvm.shape<m = 16, n = 8, k = 8>
%d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 8]} : (vector<2x2xf16>, vector<1x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
// CHECK-DAG: llvm.extractvalue [[d]][0] : !llvm.struct<(vector<2xf16>, vector<2xf16>)>
// CHECK-DAG: llvm.extractvalue [[d]][1] : !llvm.struct<(vector<2xf16>, vector<2xf16>)>
Expand Down Expand Up @@ -95,7 +95,7 @@ func.func @m16n8k32_int8(%arg0: vector<4x4xi8>, %arg1: vector<2x4xi8>, %arg2: ve
// CHECK-SAME: intOverflowBehavior = #nvvm.mma_int_overflow<satfinite>
// CHECK-SAME: multiplicandAPtxType = #nvvm.mma_type<s8>
// CHECK-SAME: multiplicandBPtxType = #nvvm.mma_type<s8>
// CHECK-SAME: shape = {k = 32 : i32, m = 16 : i32, n = 8 : i32}
// CHECK-SAME: shape = #nvvm.shape<m = 16, n = 8, k = 32>
%d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 32]} : (vector<4x4xi8>, vector<2x4xi8>, vector<2x2xi32>) -> vector<2x2xi32>
return %d : vector<2x2xi32>
}
Expand All @@ -116,7 +116,7 @@ func.func @m16n8k32_i4(%arg0: vector<2x8xi4>, %arg1: vector<1x8xi4>, %arg2: vect
// CHECK-SAME: intOverflowBehavior = #nvvm.mma_int_overflow<satfinite>
// CHECK-SAME: multiplicandAPtxType = #nvvm.mma_type<s4>
// CHECK-SAME: multiplicandBPtxType = #nvvm.mma_type<s4>
// CHECK-SAME: shape = {k = 32 : i32, m = 16 : i32, n = 8 : i32}
// CHECK-SAME: shape = #nvvm.shape<m = 16, n = 8, k = 32>
%d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 32]} : (vector<2x8xi4>, vector<1x8xi4>, vector<2x2xi32>) -> vector<2x2xi32>
return %d : vector<2x2xi32>
}
Expand All @@ -143,7 +143,7 @@ func.func @m16n8k64_i4(%arg0: vector<4x8xi4>, %arg1: vector<2x8xi4>, %arg2: vect
// CHECK-SAME: intOverflowBehavior = #nvvm.mma_int_overflow<satfinite>
// CHECK-SAME: multiplicandAPtxType = #nvvm.mma_type<s4>
// CHECK-SAME: multiplicandBPtxType = #nvvm.mma_type<s4>
// CHECK-SAME: shape = {k = 64 : i32, m = 16 : i32, n = 8 : i32}
// CHECK-SAME: shape = #nvvm.shape<m = 16, n = 8, k = 64>
%d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 64]} : (vector<4x8xi4>, vector<2x8xi4>, vector<2x2xi32>) -> vector<2x2xi32>
return %d : vector<2x2xi32>
}
Expand All @@ -156,7 +156,7 @@ func.func @m8n8k4_f64(%arg0: vector<1x1xf64>, %arg1: vector<1x1xf64>, %arg2: vec
// CHECK: llvm.extractvalue
// CHECK: llvm.extractvalue
// CHECK: [[d:%.+]] = nvvm.mma.sync A[{{%.+}}] B[{{%.+}}] C[{{%.+}}, {{%.+}}]
// CHECK-SAME: shape = {k = 4 : i32, m = 8 : i32, n = 8 : i32}
// CHECK-SAME: shape = #nvvm.shape<m = 8, n = 8, k = 4>
%d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [8, 8, 4]} : (vector<1x1xf64>, vector<1x1xf64>, vector<1x2xf64>) -> vector<1x2xf64>
// CHECK: llvm.mlir.undef : vector<2xf64>
// CHECK-DAG: llvm.extractvalue [[d]][0] : !llvm.struct<(f64, f64)>
Expand Down Expand Up @@ -217,7 +217,7 @@ func.func @m16n8k4_tf32(%arg0: vector<2x1xf32>, %arg1: vector<1x1xf32>, %arg2: v
// CHECK: [[d:%.+]] = nvvm.mma.sync A[{{%.+}}, {{%.+}}] B[{{%.+}}] C[{{%.+}}, {{%.+}}, {{%.+}}, {{%.+}}]
// CHECK-SAME: multiplicandAPtxType = #nvvm.mma_type<tf32>
// CHECK-SAME: multiplicandBPtxType = #nvvm.mma_type<tf32>
// CHECK-SAME: shape = {k = 4 : i32, m = 16 : i32, n = 8 : i32}
// CHECK-SAME: shape = #nvvm.shape<m = 16, n = 8, k = 4>
// CHECK-SAME: -> !llvm.struct<(f32, f32, f32, f32)>
%d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 4]} : (vector<2x1xf32>, vector<1x1xf32>, vector<4x1xf32>) -> vector<4x1xf32>
// CHECK: [[el:%.+]] = llvm.extractvalue [[d]][0]
Expand Down
10 changes: 5 additions & 5 deletions mlir/test/Dialect/LLVMIR/invalid.mlir
Expand Up @@ -541,7 +541,7 @@ func.func @nvvm_invalid_mma_0(%a0 : f16, %a1 : f16,
%c4 : f32, %c5 : f32, %c6 : f32, %c7 : f32) {
// expected-error@+1 {{Could not match types for the A operands; expected one of 2xvector<2xf16> but got f16, f16}}
%0 = nvvm.mma.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7]
{layoutA=#nvvm.mma_layout<row>, layoutB=#nvvm.mma_layout<col>, shape = {k = 4 : i32, m = 8 : i32, n = 8 : i32}} : (f16, vector<2xf16>, f32) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
{layoutA=#nvvm.mma_layout<row>, layoutB=#nvvm.mma_layout<col>, shape = #nvvm.shape<m = 8, n = 8, k = 4>} : (f16, vector<2xf16>, f32) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
llvm.return %0 : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
}

Expand All @@ -553,7 +553,7 @@ func.func @nvvm_invalid_mma_1(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
%c4 : f32, %c5 : f32, %c6 : f32, %c7 : f32) {
// expected-error@+1 {{Could not match allowed types for the result; expected one of !llvm.struct<(vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>)>, !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)> but got !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f16)>}}
%0 = nvvm.mma.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7]
{layoutA=#nvvm.mma_layout<row>, layoutB=#nvvm.mma_layout<col>, shape = {k = 4 : i32, m = 8 : i32, n = 8 : i32}} : (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f16)>
{layoutA=#nvvm.mma_layout<row>, layoutB=#nvvm.mma_layout<col>, shape = #nvvm.shape<m = 8, n = 8, k = 4>} : (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f16)>
llvm.return %0 : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f16)>
}

Expand All @@ -565,7 +565,7 @@ func.func @nvvm_invalid_mma_2(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
%c4 : f32, %c5 : f32, %c6 : f32, %c7 : f32) {
// expected-error@+1 {{op requires attribute 'layoutA'}}
%0 = nvvm.mma.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1, %c2, %c3, %c4, %c5, %c6, %c7]
{shape = {k = 4 : i32, m = 8 : i32, n = 8 : i32}}: (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
{shape = #nvvm.shape<m = 8, n = 8, k = 4>}: (vector<2xf16>, vector<2xf16>, f32) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
llvm.return %0 : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
}

Expand All @@ -575,7 +575,7 @@ func.func @nvvm_invalid_mma_3(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
%b0 : vector<2xf16>, %b1 : vector<2xf16>,
%c0 : vector<2xf16>, %c1 : vector<2xf16>) {
// expected-error@+1 {{unimplemented variant for MMA shape <8, 8, 16>}}
%0 = nvvm.mma.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1] {layoutA=#nvvm.mma_layout<row>, layoutB=#nvvm.mma_layout<col>, shape = {k = 16 : i32, m = 8 : i32, n = 8 : i32}} : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
%0 = nvvm.mma.sync A[%a0, %a1] B[%b0, %b1] C[%c0, %c1] {layoutA=#nvvm.mma_layout<row>, layoutB=#nvvm.mma_layout<col>, shape = #nvvm.shape<m = 8, n = 8, k = 16>} : (vector<2xf16>, vector<2xf16>, vector<2xf16>) -> !llvm.struct<(vector<2xf16>, vector<2xf16>)>
llvm.return %0 : !llvm.struct<(vector<2xf16>, vector<2xf16>)>
}

Expand All @@ -588,7 +588,7 @@ func.func @nvvm_invalid_mma_8(%a0 : i32, %a1 : i32,
%0 = nvvm.mma.sync A[%a0, %a1] B[%b0] C[%c0, %c1, %c2, %c3]
{layoutA = #nvvm.mma_layout<row>, layoutB = #nvvm.mma_layout<col>,
multiplicandAPtxType = #nvvm.mma_type<b1>, multiplicandBPtxType = #nvvm.mma_type<b1>,
shape = {k = 128 : i32, m = 16 : i32, n = 8 : i32}} : (i32, i32, i32) -> !llvm.struct<(i32,i32,i32,i32)>
shape = #nvvm.shape<m = 16, n = 8, k = 128>} : (i32, i32, i32) -> !llvm.struct<(i32,i32,i32,i32)>
llvm.return %0 : !llvm.struct<(i32,i32,i32,i32)>
}

Expand Down

0 comments on commit ba79bb4

Please sign in to comment.