diff --git a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td index ec0c18bd74824..a25b9aa068077 100644 --- a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td +++ b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td @@ -53,8 +53,10 @@ def NVGPU_DeviceAsyncToken : DialectType< class NVGPU_Op traits = []> : Op {} -def NVGPU_LdMatrixOp : NVGPU_Op<"ldmatrix", - [MemoryEffects<[MemRead]>]> { +def NVGPU_LdMatrixOp : NVGPU_Op<"ldmatrix", [ + MemoryEffects<[MemRead]>, + PredOpTrait<"srcMemref and res have same element type", + TCresVTEtIsSameAsOp<0, 0>>]> { let description = [{ The `nvgpu.ldmatrix` op represents loading a matrix fragment from memory. The load source and result type must be compatible with lowering @@ -79,12 +81,14 @@ def NVGPU_LdMatrixOp : NVGPU_Op<"ldmatrix", let assemblyFormat = [{ $srcMemref`[` $indices `]` attr-dict `:` type($srcMemref) `->` type($res) }]; + + let hasVerifier = 1; } def NVGPU_MmaSyncOp : NVGPU_Op<"mma.sync", [ - NoSideEffect, - PredOpTrait<"matrixA and matrixB have same element type", TCopVTEtIsSameAs<0, 1>>, - ]> { + NoSideEffect, + PredOpTrait<"matrixA and matrixB have same element type", + TCopVTEtIsSameAs<0, 1>>]> { let description = [{ The `nvgpu.mma.sync` op represents the distributed form of a collective matrix-multiply-and-accumulate (mma) operation that is compatible with @@ -120,8 +124,8 @@ def NVGPU_MmaSyncOp : NVGPU_Op<"mma.sync", [ } -def NVGPU_DeviceAsyncCopyOp : NVGPU_Op<"device_async_copy", - [AttrSizedOperandSegments]> { +def NVGPU_DeviceAsyncCopyOp : NVGPU_Op<"device_async_copy", [ + AttrSizedOperandSegments]> { let summary = "device-side asynchronous copy"; let description = [{ The `gpu.device_async_copy` op initiates an asynchronous copy operation of diff --git a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp index ac937e0fea0eb..1ced01179dd82 100644 --- a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp +++ b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp @@ -88,6 +88,10 @@ LogicalResult DeviceAsyncCopyOp::verify() { return success(); } +//===----------------------------------------------------------------------===// +// NVGPU_MmaSyncOp +//===----------------------------------------------------------------------===// + LogicalResult MmaSyncOp::verify() { // Fundamental tensor core mma.sync op @@ -186,5 +190,56 @@ LogicalResult MmaSyncOp::verify() { return success(); } +//===----------------------------------------------------------------------===// +// NVGPU_LdMatrixOp +//===----------------------------------------------------------------------===// +LogicalResult LdMatrixOp::verify() { + + // ldmatrix reads data from source in shared memory + auto srcMemref = getSrcMemref().getType().cast(); + + // ldmatrix writes data to result/destination in vector registers + auto resVector = getRes().getType().cast(); + + // vector register shape, element type, and bitwidth + ArrayRef resShape = resVector.getShape(); + Type resType = resVector.getElementType(); + int64_t elementBitWidth = resType.getIntOrFloatBitWidth(); + + // ldmatrix loads 32 bits into vector registers per 8-by-8 tile per thread + int64_t numElementsPer32b = 32 / elementBitWidth; + + // number of 8-by-8 tiles + int64_t numTiles = getNumTiles(); + + // transpose elements in vector registers at 16b granularity when true + bool isTranspose = getTranspose(); + + // address space id for shared memory + unsigned smemAddressSpace = gpu::GPUDialect::getWorkgroupAddressSpace(); + + // + // verification + // + + if (!(srcMemref.getMemorySpaceAsInt() == smemAddressSpace)) + return emitError() + << "expected nvgpu.ldmatrix srcMemref must have memory space " + << smemAddressSpace; + if (elementBitWidth > 32) + return emitError() << "nvgpu.ldmatrix works for 32b or lower"; + if (isTranspose && !(elementBitWidth == 16)) + return emitError() + << "nvgpu.ldmatrix transpose works only at 16b granularity"; + if (!(resShape[1] == numElementsPer32b)) + return emitError() << "expected vector register shape[1] = " + << numElementsPer32b; + if (!(resShape[0] == numTiles)) + return emitError() + << "expected vector register shape[0] and numTiles to match"; + + return success(); +} + #define GET_OP_CLASSES #include "mlir/Dialect/NVGPU/IR/NVGPU.cpp.inc" diff --git a/mlir/test/Dialect/NVGPU/invalid.mlir b/mlir/test/Dialect/NVGPU/invalid.mlir index 6be9cda42ccb3..5f1894faeb709 100644 --- a/mlir/test/Dialect/NVGPU/invalid.mlir +++ b/mlir/test/Dialect/NVGPU/invalid.mlir @@ -1,4 +1,53 @@ // RUN: mlir-opt -split-input-file -verify-diagnostics %s + +func.func @ldmatrix_address_space_f16_x4(%arg0: memref<128x128xf16, 2>) -> vector<4x1xf16> { + %c0 = arith.constant 0 : index + // expected-error @+1 {{expected nvgpu.ldmatrix srcMemref must have memory space 3}} + %a = nvgpu.ldmatrix %arg0[%c0, %c0] {transpose = false, numTiles = 4 : i32} : memref<128x128xf16, 2> -> vector<4x1xf16> + return %a : vector<4x1xf16> +} +// ----- + +func.func @ldmatrix_num_elements_f16_x4(%arg0: memref<128x128xf16, 3>) -> vector<4x1xf16> { + %c0 = arith.constant 0 : index + // expected-error @+1 {{expected vector register shape[1] = 2}} + %a = nvgpu.ldmatrix %arg0[%c0, %c0] {transpose = false, numTiles = 4 : i32} : memref<128x128xf16, 3> -> vector<4x1xf16> + return %a : vector<4x1xf16> +} +// ----- + +func.func @ldmatrix_num_tiles_f16_x4(%arg0: memref<128x128xf16, 3>) -> vector<2x2xf16> { + %c0 = arith.constant 0 : index + // expected-error @+1 {{expected vector register shape[0] and numTiles to match}} + %a = nvgpu.ldmatrix %arg0[%c0, %c0] {transpose = false, numTiles = 4 : i32} : memref<128x128xf16, 3> -> vector<2x2xf16> + return %a : vector<2x2xf16> +} +// ----- + +func.func @ldmatrix_num_tiles_f32_x4(%arg0: memref<128x128xf32, 3>) -> vector<4x2xf32> { + %c0 = arith.constant 0 : index + // expected-error @+1 {{expected vector register shape[1] = 1}} + %a = nvgpu.ldmatrix %arg0[%c0, %c0] {transpose = false, numTiles = 4 : i32} : memref<128x128xf32, 3> -> vector<4x2xf32> + return %a : vector<4x2xf32> +} +// ----- + +func.func @ldmatrix_trans_f32_x4(%arg0: memref<128x128xf32, 3>) -> vector<4x1xf32> { + %c0 = arith.constant 0 : index + // expected-error @+1 {{nvgpu.ldmatrix transpose works only at 16b granularity}} + %a = nvgpu.ldmatrix %arg0[%c0, %c0] {transpose = true, numTiles = 4 : i32} : memref<128x128xf32, 3> -> vector<4x1xf32> + return %a : vector<4x1xf32> +} +// ----- + +func.func @ldmatrix_type_x4(%arg0: memref<128x128xf32, 3>) -> vector<4x2xf16> { + %c0 = arith.constant 0 : index + // expected-error @+1 {{'nvgpu.ldmatrix' op failed to verify that srcMemref and res have same element type}} + %a = nvgpu.ldmatrix %arg0[%c0, %c0] {transpose = false, numTiles = 4 : i32} : memref<128x128xf32, 3> -> vector<4x2xf16> + return %a : vector<4x2xf16> +} +// ----- + func.func @m16n8k16_fp16_vector_shape_a(%arg0: vector<4x4xf16>, %arg1: vector<2x2xf16>, %arg2: vector<2x2xf16>) -> vector<2x2xf16> { // expected-error @+1 {{expected 256 warp-wide matrix A elements}} %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 16]} : (vector<4x4xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>