[MLIR][AMDGPU]Add refactoring for shared-mem optimization #81791

Merged: 3 commits, Feb 15, 2024

Conversation

erman-gurses
Contributor

Addressing the issues raised in PR #81550.


github-actions bot commented Feb 14, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

@erman-gurses erman-gurses changed the title [MLIR][AMDGPU]Add refactor for shared-mem optimization [MLIR][AMDGPU]Add refactoring for shared-mem optimization Feb 14, 2024
@erman-gurses erman-gurses force-pushed the eg_refactor_shmem_opt branch 2 times, most recently from c5c69a9 to da8930d Compare February 14, 2024 21:47
@llvmbot
Collaborator

llvmbot commented Feb 14, 2024

@llvm/pr-subscribers-mlir-gpu
@llvm/pr-subscribers-mlir-amdgpu

@llvm/pr-subscribers-mlir

Author: None (erman-gurses)

Changes

Addressing the issues raised in PR #81550.


Full diff: https://github.com/llvm/llvm-project/pull/81791.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Dialect/AMDGPU/Transforms/Transforms.h (+2-1)
  • (modified) mlir/lib/Dialect/AMDGPU/CMakeLists.txt (+1-1)
  • (modified) mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp (+20-18)
  • (modified) mlir/test/Dialect/AMDGPU/transform_optimize_shmem_reads_writes.mlir (-10)
diff --git a/mlir/include/mlir/Dialect/AMDGPU/Transforms/Transforms.h b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Transforms.h
index b4e9ad27003db1..22bc9b9e0cf842 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Transforms.h
@@ -49,7 +49,8 @@ namespace amdgpu {
 mlir::LogicalResult optimizeSharedMemoryReadsAndWrites(Operation *parentOp,
                                                        Value memrefValue);
 
-void optimizeSharedMemoryReadsAndWritesOp(mlir::func::FuncOp funcOp);
+std::optional<mlir::LogicalResult>
+optimizeSharedMemoryReadsAndWritesOp(mlir::func::FuncOp funcOp);
 
 } // namespace amdgpu
 } // namespace mlir
diff --git a/mlir/lib/Dialect/AMDGPU/CMakeLists.txt b/mlir/lib/Dialect/AMDGPU/CMakeLists.txt
index 63b4d8b99f53fd..c47e4c5495c17b 100644
--- a/mlir/lib/Dialect/AMDGPU/CMakeLists.txt
+++ b/mlir/lib/Dialect/AMDGPU/CMakeLists.txt
@@ -1,4 +1,4 @@
 add_subdirectory(IR)
-add_subdirectory(Utils)
 add_subdirectory(TransformOps)
 add_subdirectory(Transforms)
+add_subdirectory(Utils)
diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp
index 7c50a876e78f45..c33608a496470e 100644
--- a/mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp
@@ -50,12 +50,12 @@ constexpr int64_t kDefaultVectorSizeBits = 64;
 static Value permuteVectorOffset(OpBuilder &b, Location loc,
                                  ArrayRef<Value> indices, MemRefType memrefTy,
                                  int64_t srcDim, int64_t tgtDim) {
-  /// Adjust the src index to change how often the permutation changes
-  /// if necessary.
+  // Adjust the src index to change how often the permutation changes
+  // if necessary.
   Value src = indices[srcDim];
 
-  /// We only want to permute every N iterations of the target dim where N is
-  /// ceil(sharedMemoryLineSizeBytes / dimSizeBytes(tgtDim)).
+  // We only want to permute every N iterations of the target dim where N is
+  // ceil(sharedMemoryLineSizeBytes / dimSizeBytes(tgtDim)).
   const int64_t permuteEveryN = std::max<int64_t>(
       1, kSharedMemoryLineSizeBytes / ((memrefTy.getDimSize(tgtDim) *
                                         memrefTy.getElementTypeBitWidth()) /
@@ -110,8 +110,8 @@ static void transformIndices(OpBuilder &builder, Location loc,
       permuteVectorOffset(builder, loc, indices, memrefTy, srcDim, tgtDim);
 }
 
-/// Return all operations within `parentOp` that read from or write to
-/// `shmMemRef`.
+// Return all operations within `parentOp` that read from or write to
+// `shmMemRef`.
 static LogicalResult
 getShmReadAndWriteOps(Operation *parentOp, Value shmMemRef,
                       SmallVector<Operation *, 16> &readOps,
@@ -131,8 +131,8 @@ getShmReadAndWriteOps(Operation *parentOp, Value shmMemRef,
       writeOps.push_back(op);
   });
 
-  /// Restrict to a supported set of ops. We also require at least 2D access,
-  /// although this could be relaxed.
+  // Restrict to a supported set of ops. We also require at least 2D access,
+  // although this could be relaxed.
   if (llvm::any_of(readOps, [](Operation *op) {
         return !isa<memref::LoadOp, vector::LoadOp, vector::TransferReadOp>(
                    op) ||
@@ -157,15 +157,15 @@ mlir::amdgpu::optimizeSharedMemoryReadsAndWrites(Operation *parentOp,
       !amdgpu::AMDGPUDialect::hasSharedMemoryAddressSpace(memRefType))
     return failure();
 
-  /// Abort if the given value has any sub-views; we do not do any alias
-  /// analysis.
+  // Abort if the given value has any sub-views; we do not do any alias
+  // analysis.
   bool hasSubView = false;
   parentOp->walk([&](memref::SubViewOp subView) { hasSubView = true; });
   if (hasSubView)
     return failure();
 
-  /// Check if this is necessary given the assumption of 128b accesses:
-  /// If dim[rank-1] is small enough to fit 8 rows in a 128B line.
+  // Check if this is necessary given the assumption of 128b accesses:
+  // If dim[rank-1] is small enough to fit 8 rows in a 128B line.
   const int64_t rowSize = memRefType.getDimSize(memRefType.getRank() - 1);
   const int64_t rowsPerLine =
       (8 * kSharedMemoryLineSizeBytes / memRefType.getElementTypeBitWidth()) /
@@ -175,8 +175,8 @@ mlir::amdgpu::optimizeSharedMemoryReadsAndWrites(Operation *parentOp,
   if (rowsPerLine >= threadGroupSize)
     return failure();
 
-  /// Get sets of operations within the function that read/write to shared
-  /// memory.
+  // Get sets of operations within the function that read/write to shared
+  // memory.
   SmallVector<Operation *, 16> shmReadOps;
   SmallVector<Operation *, 16> shmWriteOps;
   if (failed(getShmReadAndWriteOps(parentOp, memrefValue, shmReadOps,
@@ -191,7 +191,7 @@ mlir::amdgpu::optimizeSharedMemoryReadsAndWrites(Operation *parentOp,
   int64_t tgtDim = memRefType.getRank() - 1;
   int64_t srcDim = memRefType.getRank() - 2;
 
-  /// Transform indices for the ops writing to shared memory.
+  // Transform indices for the ops writing to shared memory.
   while (!shmWriteOps.empty()) {
     Operation *shmWriteOp = shmWriteOps.pop_back_val();
     builder.setInsertionPoint(shmWriteOp);
@@ -203,7 +203,7 @@ mlir::amdgpu::optimizeSharedMemoryReadsAndWrites(Operation *parentOp,
     amdgpu::setIndices(shmWriteOp, transformedIndices);
   }
 
-  /// Transform indices for the ops reading from shared memory.
+  // Transform indices for the ops reading from shared memory.
   while (!shmReadOps.empty()) {
     Operation *shmReadOp = shmReadOps.pop_back_val();
     builder.setInsertionPoint(shmReadOp);
@@ -218,7 +218,8 @@ mlir::amdgpu::optimizeSharedMemoryReadsAndWrites(Operation *parentOp,
   return success();
 }
 
-void amdgpu::optimizeSharedMemoryReadsAndWritesOp(func::FuncOp funcOp) {
+std::optional<mlir::LogicalResult>
+amdgpu::optimizeSharedMemoryReadsAndWritesOp(func::FuncOp funcOp) {
   SmallVector<memref::AllocOp> shmAllocOps;
   funcOp.walk([&](memref::AllocOp allocOp) {
     if (!amdgpu::AMDGPUDialect::hasSharedMemoryAddressSpace(allocOp.getType()))
@@ -228,8 +229,9 @@ void amdgpu::optimizeSharedMemoryReadsAndWritesOp(func::FuncOp funcOp) {
   for (auto allocOp : shmAllocOps) {
     if (failed(amdgpu::optimizeSharedMemoryReadsAndWrites(funcOp,
                                                           allocOp.getMemref())))
-      return;
+      return failure();
   }
+  return success();
 }
 
 struct OptimizeSharedMemoryPass
diff --git a/mlir/test/Dialect/AMDGPU/transform_optimize_shmem_reads_writes.mlir b/mlir/test/Dialect/AMDGPU/transform_optimize_shmem_reads_writes.mlir
index dfdd1b17e244e3..143e7c2d270952 100644
--- a/mlir/test/Dialect/AMDGPU/transform_optimize_shmem_reads_writes.mlir
+++ b/mlir/test/Dialect/AMDGPU/transform_optimize_shmem_reads_writes.mlir
@@ -7,22 +7,17 @@
                     %fragRow: index, %fragCol: index, 
                     %fragColPerm: index,
                     %stRow: index, %stCol: index) {
-    // CHECK:    %[[cst:.+]] = arith.constant 0.000000e+00 : f16                  
     %cst = arith.constant 0.000000e+00 : f16
 
-    // CHECK: [[shmA:%.+]] = memref.alloc
-    // CHECK: [[shmB:%.+]] = memref.alloc
     %shmA = memref.alloc() {alignment = 64 : i64} : memref<128x32xf16, 3>
     %shmB = memref.alloc() {alignment = 64 : i64} : memref<256x32xf16, 3>
 
-    // CHECK: %[[D0:.+]] = vector.transfer_read [[arg0:%.+]][[[readRow:%.+]], [[readCol:%.+]]], [[cst:.+]] {in_bounds = [true, true]} : memref<4096x4096xf16>, vector<1x8xf16>
     %0 = vector.transfer_read %arg0[%readRow, %readCol], %cst {in_bounds = [true, true]} : memref<4096x4096xf16>, vector<1x8xf16>
     // CHECK: [[c7:%.+]] = arith.constant 7 : index                  
     // CHECK: [[srcBits:%.+]] = arith.andi [[stRow:%.+]], [[c7]]       
     // CHECK: [[c2:%.+]] = arith.constant 2 : index                 
     // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]     
     // CHECK: [[stColPerm:%.+]] = arith.xori [[stCol:%.+]], [[xorBits]]  
-    // CHECK: vector.transfer_write %[[D0:.+]], [[shmB]][[[writeRow:%.+]], [[writeCol:%.+]]] {in_bounds = [true, true]} : vector<1x8xf16>, memref<256x32xf16, 3>
     vector.transfer_write %0, %shmB[%writeRow, %writeCol] {in_bounds = [true, true]} : vector<1x8xf16>, memref<256x32xf16, 3>
     gpu.barrier
     gpu.barrier
@@ -31,17 +26,13 @@
     // CHECK: [[c2:%.+]] = arith.constant 2 : index                 
     // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]       
     // CHECK: [[fragColPerm:%.+]] = arith.xori [[fragCol:%.+]], [[xorBits]] 
-    // CHECK:  vector.load [[shmB:%.+]][[[fragRow:%.+]], [[fragColPerm]]] : memref<256x32xf16, 3>, vector<8xf16>
     %1 = vector.load %shmB[%fragRow, %fragColPerm] : memref<256x32xf16, 3>, vector<8xf16>
-
-    // CHECK: %[[D2:.+]] = vector.transfer_read [[arg0:%.+]][[[readRow:%.+]], [[readCol:%.+]]], [[cst:.+]] {in_bounds = [true, true]} : memref<4096x4096xf16>, vector<1x8xf16>
     %2 = vector.transfer_read %arg0[%readRow, %readCol], %cst {in_bounds = [true, true]} : memref<4096x4096xf16>, vector<1x8xf16>
     // CHECK: [[c7:%.+]] = arith.constant 7 : index                  
     // CHECK: [[srcBits:%.+]] = arith.andi [[stRow:%.+]], [[c7]]       
     // CHECK: [[c2:%.+]] = arith.constant 2 : index                 
     // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]     
     // CHECK: [[stColPerm:%.+]] = arith.xori [[stCol:%.+]], [[xorBits]]  
-    // CHECK: vector.transfer_write %[[D2:.+]], [[shmA:%.+]][[[writeRow:%.+]], [[writeCol:%.+]]] {in_bounds = [true, true]} : vector<1x8xf16>, memref<128x32xf16, 3>
     vector.transfer_write %2, %shmA[%writeRow, %writeCol] {in_bounds = [true, true]} : vector<1x8xf16>, memref<128x32xf16, 3>
     gpu.barrier
     gpu.barrier
@@ -50,7 +41,6 @@
     // CHECK: [[c2:%.+]] = arith.constant 2 : index                     
     // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]] 
     // CHECK: [[fragColPerm:%.+]] = arith.xori [[fragCol:%.+]], [[xorBits]]
-    // CHECK:  vector.load [[shmA:%.+]][[[fragRow:%.+]], [[fragColPerm]]] : memref<128x32xf16, 3>, vector<8xf16>
     %3 = vector.load %shmA[%fragRow, %fragColPerm] : memref<128x32xf16, 3>, vector<8xf16>
     return
   }
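
For reference, a minimal sketch of how a caller (for example, the transform op or pass driving this rewrite) might consume the new std::optional<mlir::LogicalResult> return value introduced in this PR. The runShmemOptimization wrapper name is hypothetical and not part of the change; it only illustrates that the refactored entry point now reports failure instead of returning void.

    #include <optional>

    #include "mlir/Dialect/AMDGPU/Transforms/Transforms.h"
    #include "mlir/Dialect/Func/IR/FuncOps.h"

    // Hypothetical helper: forward the result of the refactored entry point
    // so the caller can react to a failed shared-memory rewrite.
    static mlir::LogicalResult runShmemOptimization(mlir::func::FuncOp funcOp) {
      std::optional<mlir::LogicalResult> result =
          mlir::amdgpu::optimizeSharedMemoryReadsAndWritesOp(funcOp);
      if (result && mlir::failed(*result))
        return mlir::failure();
      return mlir::success();
    }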

@erman-gurses
Contributor Author

erman-gurses commented Feb 14, 2024

Hi @ftynse, please let me know if I am missing something.

@erman-gurses erman-gurses self-assigned this Feb 14, 2024