
[MLIR][NVGPU] Introduce nvgpu.wargroup.mma.store Op for Hopper GPUs #65441

Merged
merged 7 commits into llvm:main from nvgpu-store
Oct 5, 2023

Conversation

grypp
Member

@grypp grypp commented Sep 6, 2023

This PR introduces a new Op called warpgroup.mma.store to the NVGPU dialect of MLIR. The purpose of this operation is to facilitate storing the fragmented result(s) of type nvgpu.warpgroup.accumulator, produced by warpgroup.mma, to the given memref.

An example of a fragmented matrix is given here:
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#wgmma-64n16-d

The warpgroup.mma.store op does the following:

  1. Takes one or more nvgpu.warpgroup.accumulator values (fragmented result matrices).
  2. Calculates per-thread indices within the warp group and stores the data into the given memref (see the index sketch after the example below).

Here's an example usage:

// A warpgroup performs GEMM, results in fragmented matrix
%result1, %result2 = nvgpu.warpgroup.mma ...

// Stores the fragmented result to memref
nvgpu.warpgroup.mma.store [%result1, %result2], %matrixD : 
    !nvgpu.warpgroup.accumulator< fragmented = vector<64x128xf32>>,
    !nvgpu.warpgroup.accumulator< fragmented = vector<64x128xf32>> 
    to memref<128x128xf32,3>
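
For intuition, here is a minimal C++ sketch (not part of the PR; names are illustrative) of the per-thread index arithmetic the lowering performs for a single 64x128xf32 accumulator fragment, mirroring the loops in the NVGPUToNVVM.cpp changes below:

```
#include <cstdio>

// Sketch of the per-thread (row, col) mapping used by the lowering.
// Assumes one 64x128xf32 fragment and a 128-thread warp group; `tidx`
// is the thread id within the warp group.
int main() {
  for (int tidx = 0; tidx < 128; tidx += 37) { // sample a few threads
    int laneId = tidx % 32;                 // lane within the warp
    int warpId = tidx / 32;                 // warp within the warp group
    int rowBase = warpId * 16 + laneId / 4; // `ti` in the lowering
    int colBase = (laneId % 4) * 2;         // `tj` in the lowering
    for (int i = 0; i < 2; ++i) {           // two 8-row halves
      for (int j = 0; j < 16; ++j) {        // sixteen 8-column tiles
        int reg = i * 2 + j * 4;            // index into the result struct
        // Registers `reg` and `reg + 1` land in adjacent columns.
        printf("t%-3d regs %2d,%2d -> (%2d, %3d..%3d)\n", tidx, reg, reg + 1,
               rowBase + i * 8, colBase + j * 8, colBase + j * 8 + 1);
      }
    }
  }
  return 0;
}
```

Each thread thus writes 2 x 16 register pairs, i.e. 64 f32 values, matching a 64x128 fragment distributed over the 128 threads of a warp group.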

@qcolombet
Collaborator

The cursed typo "wargroup" is on :).

@llvmbot
Collaborator

llvmbot commented Sep 22, 2023

@llvm/pr-subscribers-mlir-nvgpu

Changes

[MLIR][NVGPU] Introduce nvgpu.wargroup.mma.store Op for Hopper GPUs

This work introduces a new operation called wargroup.mma.store to the NVGPU dialect of MLIR. The purpose of this operation is to facilitate storing fragmented results of WGMMA to the given memref.

An example of fragmentation is given here:
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#wgmma-64n16-d

The wargroup.mma.store op does the following:

  1. Takes one or more fragmented result matrices.
  2. Calculates per-thread indices within the warp group and stores the data into the given memref.

Here's an example usage of the nvgpu.wargroup.mma.store operation:

// Performs matmul, results are fragmented and in registers
%res1, %res2 = nvgpu.wargroup.mma ...

// Stores the fragmented result to the given memory
nvgpu.wargroup.mma.store [%res1, %res2], %matrixD :
                !nvgpu.warpgroup.result<tensor = !llvm.struct<...>>,
                !nvgpu.warpgroup.result<tensor = !llvm.struct<...>>
                to memref<128x128xf32,3>

Depends on #65440


Full diff: https://github.com/llvm/llvm-project/pull/65441.diff

3 Files Affected:

  • (modified) mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td (+20)
  • (modified) mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp (+81-2)
  • (modified) mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp (+32)
diff --git a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
index 90381648dac6acc..4e80c33aec6043d 100644
--- a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
+++ b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
@@ -721,4 +721,24 @@ def NVGPU_WarpgroupMmaOp : NVGPU_Op<"warpgroup.mma"> {
   let hasVerifier = 1;
 }
 
+def NVGPU_WarpgroupMmaStoreOp : NVGPU_Op<"warpgroup.mma.store"> {
+  let description = [{
+    The `nvgpu.warpgroup.mma.store` op performs the store of the fragmented
+    result in $matrixD to the given memref.
+
+    [See the details of register fragment layout for accumulator matrix D]
+    (https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#wgmma-64n16-d) 
+
+    Note that the op must be run by a full warp group.
+  }];
+
+  let arguments = (ins Variadic<NVGPU_WarpgroupAccumulator>:$matrixD,
+                       Arg<AnyMemRef, "", [MemWrite]>:$dstMemref);
+  
+  let assemblyFormat = [{
+    `[` $matrixD `]` `,` $dstMemref attr-dict `:` type($matrixD) `to` type($dstMemref)
+  }];
+  let hasVerifier = 1;
+}
+
 #endif // NVGPU
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index f74aa05c0c4c4ff..006ecbef2546e3e 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -11,6 +11,7 @@
 #include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
 #include "mlir/Conversion/LLVMCommon/Pattern.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
@@ -409,8 +410,8 @@ struct ConvertNVGPUToNVVMPass
   using Base::Base;
 
   void getDependentDialects(DialectRegistry &registry) const override {
-    registry
-        .insert<memref::MemRefDialect, LLVM::LLVMDialect, NVVM::NVVMDialect>();
+    registry.insert<memref::MemRefDialect, LLVM::LLVMDialect, NVVM::NVVMDialect,
+                    arith::ArithDialect>();
   }
 
   void runOnOperation() override {
@@ -451,6 +452,7 @@ struct ConvertNVGPUToNVVMPass
     populateNVGPUToNVVMConversionPatterns(converter, patterns);
     LLVMConversionTarget target(getContext());
     target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
+    target.addLegalDialect<::mlir::arith::ArithDialect>();
     target.addLegalDialect<::mlir::memref::MemRefDialect>();
     target.addLegalDialect<::mlir::NVVM::NVVMDialect>();
     mlir::scf::populateSCFStructuralTypeConversionsAndLegality(
@@ -1299,6 +1301,82 @@ struct NVGPUWarpgroupMmaOpLowering
   }
 };
 
+struct NVGPUWarpgroupMmaStoreOpLowering
+    : public ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaStoreOp> {
+  using ConvertOpToLLVMPattern<
+      nvgpu::WarpgroupMmaStoreOp>::ConvertOpToLLVMPattern;
+
+  void storeFragmentedMatrix(Value wgmmaResult, nvgpu::WarpgroupMmaStoreOp op,
+                             OpAdaptor adaptor,
+                             ConversionPatternRewriter &rewriter,
+                             int offset) const {
+    Location loc = op->getLoc();
+    Type i32 = rewriter.getI32Type();
+
+    auto makeConst = [&](int32_t index) -> Value {
+      return rewriter.create<LLVM::ConstantOp>(
+          loc, i32, rewriter.getI32IntegerAttr(index));
+    };
+    Value c4 = makeConst(4);
+    Value c32 = makeConst(kWarpSize);
+    Value c8 = makeConst(8);
+    Value c2 = makeConst(2);
+    Value c1 = makeConst(1);
+    Value c16 = makeConst(16);
+
+    auto makeMul = [&](Value lhs, Value rhs) -> Value {
+      return rewriter.create<LLVM::MulOp>(loc, lhs.getType(), lhs, rhs);
+    };
+    auto makeAdd = [&](Value lhs, Value rhs) -> Value {
+      return rewriter.create<LLVM::AddOp>(loc, lhs.getType(), lhs, rhs);
+    };
+
+    Value tidx = rewriter.create<NVVM::ThreadIdXOp>(loc, i32);
+    Value laneId = rewriter.create<LLVM::URemOp>(loc, i32, tidx, c32);
+    Value warpId = rewriter.create<LLVM::UDivOp>(loc, i32, tidx, c32);
+    Value lane4Id = rewriter.create<LLVM::UDivOp>(loc, i32, laneId, c4);
+    Value lane4modId = rewriter.create<LLVM::URemOp>(loc, i32, laneId, c4);
+
+    auto makeExtractAndStore = [&](int i, Value wgmmaResult, Value x, Value y,
+                                   TypedValue<::mlir::MemRefType> memref) {
+      Type it = rewriter.getIndexType();
+      Value idx = rewriter.create<arith::IndexCastOp>(loc, it, x);
+      Value idy0 = rewriter.create<arith::IndexCastOp>(loc, it, y);
+      Value idy1 = rewriter.create<arith::IndexCastOp>(loc, it, makeAdd(y, c1));
+      Value d0 = rewriter.create<LLVM::ExtractValueOp>(loc, wgmmaResult, i);
+      Value d1 = rewriter.create<LLVM::ExtractValueOp>(loc, wgmmaResult, i + 1);
+      rewriter.create<memref::StoreOp>(loc, d0, memref, ValueRange{idx, idy0});
+      rewriter.create<memref::StoreOp>(loc, d1, memref, ValueRange{idx, idy1});
+    };
+
+    Value tj = makeMul(lane4modId, c2);
+    Value ti = makeAdd(lane4Id, makeMul(warpId, c16));
+    if (offset)
+      ti = makeAdd(ti, makeConst(offset));
+    for (int i = 0; i < 2; ++i) {
+      Value idx = makeAdd(ti, makeMul(makeConst(i), c8));
+      for (int j = 0; j < 16; ++j) {
+        Value idy = makeAdd(tj, makeMul(makeConst(j), c8));
+        int sIndex = i * 2 + j * 4;
+        makeExtractAndStore(sIndex, wgmmaResult, idx, idy, op.getDstMemref());
+      }
+    }
+  }
+
+  LogicalResult
+  matchAndRewrite(nvgpu::WarpgroupMmaStoreOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    int offset = 0;
+    for (auto result : adaptor.getMatrixD()) {
+      auto stype = result.getType().cast<LLVM::LLVMStructType>();
+      storeFragmentedMatrix(result, op, adaptor, rewriter, offset);
+      offset += stype.getBody().size();
+    }
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::populateNVGPUToNVVMConversionPatterns(LLVMTypeConverter &converter,
@@ -1315,6 +1393,7 @@ void mlir::populateNVGPUToNVVMConversionPatterns(LLVMTypeConverter &converter,
       NVGPUMBarrierArriveExpectTxLowering,   // nvgpu.mbarrier.arrive.expect_tx
       NVGPUGenerateGmmaDescriptorLowering,   // nvgpu.wgmma.generate.descriptor
       NVGPUWarpgroupMmaOpLowering,           // nvgpu.warpgroup.mma
+      NVGPUWarpgroupMmaStoreOpLowering,      // nvgpu.warpgroup.mma.store
       MmaSyncOptoNVVM, MmaLdMatrixOpToNVVM, NVGPUAsyncCopyLowering,
       NVGPUAsyncCreateGroupLowering, NVGPUAsyncWaitLowering,
       NVGPUMmaSparseSyncLowering>(converter);
diff --git a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
index d96ed69982870b4..1486bba5d3e57f6 100644
--- a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
+++ b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
@@ -12,6 +12,7 @@
 
 #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinTypes.h"
@@ -529,6 +530,37 @@ LogicalResult WarpgroupMmaOp::verify() {
   return success();
 }
 
+LogicalResult WarpgroupMmaStoreOp::verify() {
+  Type stype = getMatrixD()
+                   .front()
+                   .getType()
+                   .cast<WarpgroupAccumulatorType>()
+                   .getFragmented();
+
+  for (auto result : getMatrixD()) {
+    auto resultStype = result.getType()
+                           .cast<WarpgroupAccumulatorType>()
+                           .getFragmented()
+                           .dyn_cast<LLVM::LLVMStructType>();
+    if (!resultStype)
+      return emitOpError() << "result is " << result.getType()
+                           << " but must be of LLVM struct type";
+    if (stype != resultStype)
+      return emitOpError() << "all results must be the same type";
+
+    // TODO: Improve this limitation.
+    if (!resultStype.getBody().front().isF32()) {
+      return emitOpError() << "supports only f32 results for the time being";
+    }
+  }
+
+  if (!llvm::all_equal(stype.cast<LLVM::LLVMStructType>().getBody())) {
+    return emitOpError() << "all element types must be equal";
+  }
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // TableGen'd dialect, type, and op definitions
 //===----------------------------------------------------------------------===//
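
A note on the `matchAndRewrite` loop above: `offset` advances by the struct body size of each fragment and is then added to the row index `ti` in `storeFragmentedMatrix`, so successive accumulators fill successive row blocks of the destination. A hedged sketch (assuming two 64x128xf32 fragments, i.e. 64 registers per thread each, as in the example from the description):

```
#include <cstdio>

// Sketch (not from the PR) of how `offset` places successive fragments.
// For a 64x128xf32 accumulator, the per-thread struct holds 64 f32s,
// which happens to equal the 64-row height of one fragment.
int main() {
  const int regsPerThread = 64; // struct body size per fragment (assumed)
  int offset = 0;
  for (int fragment = 0; fragment < 2; ++fragment) {
    // storeFragmentedMatrix adds `offset` to the row index, so fragment 0
    // covers rows 0..63 and fragment 1 covers rows 64..127 of %matrixD.
    printf("fragment %d -> rows %d..%d\n", fragment, offset, offset + 63);
    offset += regsPerThread;
  }
  return 0;
}
```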

Collaborator

@qcolombet qcolombet left a comment

It looks almost good to me.

I'd like to see a few more comments, and I believe there are some missing cmake/blaze changes.

@github-actions

github-actions bot commented Oct 2, 2023

✅ With the latest revision this PR passed the C/C++ code formatter.

Collaborator

@qcolombet qcolombet left a comment

I think the singleton implementation for the warp-size value is broken and, in any case, overkill.
I believe we are still missing a change in a cmake file.

Other than that, a couple of nits, but LGTM.

This work introduces a new operation called `warpgroup.mma.store` to the NVGPU dialect of MLIR. The purpose of this operation is to facilitate storing fragmented results of WGMMA to the given memref.

An example of fragmentation is given here:
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#wgmma-64n16-d

The `warpgroup.mma.store` op does the following:
1) Takes one or more fragmented result matrices.
2) Calculates per-thread indices within the warp group and stores the data into the given memref.

Here's an example usage of the `nvgpu.warpgroup.mma.store` operation:
```
// Performs matmul, results are fragmented and in registers
%res1, %res2 = nvgpu.warpgroup.mma ...

// Stores the fragmented result to the given memory
nvgpu.warpgroup.mma.store [%res1, %res2], %matrixD :
                !nvgpu.warpgroup.result<tensor = !llvm.struct<...>>,
                !nvgpu.warpgroup.result<tensor = !llvm.struct<...>>
                to memref<128x128xf32,3>
```

Depends on llvm#65440
@grypp grypp merged commit d20fbc9 into llvm:main Oct 5, 2023
2 of 3 checks passed
@grypp grypp deleted the nvgpu-store branch October 5, 2023 08:54