Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][memref][spirv] Add conversion for memref.extract_aligned_point… #86750

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

mshahneo
Copy link
Contributor

…er_as_index to SPIR-V

Converts memref.extract_aligned_pointer_as_index to spirv.ConvertPtrToU. Index conversion is done based on 'use-64bit-index' option.

…er_as_index to SPIR-V

Converts memref.extract_aligned_pointer_as_index to spirv.ConvertPtrToU.
Index conversion is done based on 'use-64bit-index' option.
@llvmbot
Copy link
Collaborator

llvmbot commented Mar 26, 2024

@llvm/pr-subscribers-mlir-spirv

@llvm/pr-subscribers-mlir

Author: Md Abdullah Shahneous Bari (mshahneo)

Changes

…er_as_index to SPIR-V

Converts memref.extract_aligned_pointer_as_index to spirv.ConvertPtrToU. Index conversion is done based on 'use-64bit-index' option.


Full diff: https://github.com/llvm/llvm-project/pull/86750.diff

2 Files Affected:

  • (modified) mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp (+31-5)
  • (modified) mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir (+39-1)
diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
index 81b9f55cac80f7..0ec3ad700fe807 100644
--- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
+++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
@@ -308,6 +308,17 @@ class CastPattern final : public OpConversionPattern<memref::CastOp> {
   }
 };
 
+/// Converts memref.extract_aligned_pointer_as_index to spirv.ConvertPtrToU.
+class ExtractAlignedPointerAsIndexOpPattern
+    : public OpConversionPattern<memref::ExtractAlignedPointerAsIndexOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(memref::ExtractAlignedPointerAsIndexOp extractOp,
+                  OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override;
+};
 } // namespace
 
 //===----------------------------------------------------------------------===//
@@ -922,6 +933,20 @@ LogicalResult ReinterpretCastPattern::matchAndRewrite(
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// ExtractAlignedPointerAsIndexOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult ExtractAlignedPointerAsIndexOpPattern::matchAndRewrite(
+    memref::ExtractAlignedPointerAsIndexOp extractOp, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
+  Type indexType = typeConverter.getIndexType();
+  rewriter.replaceOpWithNewOp<spirv::ConvertPtrToUOp>(extractOp, indexType,
+                                                      adaptor.getSource());
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // Pattern population
 //===----------------------------------------------------------------------===//
@@ -929,10 +954,11 @@ LogicalResult ReinterpretCastPattern::matchAndRewrite(
 namespace mlir {
 void populateMemRefToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
                                    RewritePatternSet &patterns) {
-  patterns.add<AllocaOpPattern, AllocOpPattern, AtomicRMWOpPattern,
-               DeallocOpPattern, IntLoadOpPattern, IntStoreOpPattern,
-               LoadOpPattern, MemorySpaceCastOpPattern, StoreOpPattern,
-               ReinterpretCastPattern, CastPattern>(typeConverter,
-                                                    patterns.getContext());
+  patterns
+      .add<AllocaOpPattern, AllocOpPattern, AtomicRMWOpPattern,
+           DeallocOpPattern, IntLoadOpPattern, IntStoreOpPattern, LoadOpPattern,
+           MemorySpaceCastOpPattern, StoreOpPattern, ReinterpretCastPattern,
+           CastPattern, ExtractAlignedPointerAsIndexOpPattern>(
+          typeConverter, patterns.getContext());
 }
 } // namespace mlir
diff --git a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
index 10c03a270005f1..bc2af8b6edadcc 100644
--- a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
+++ b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
@@ -1,4 +1,5 @@
-// RUN: mlir-opt --split-input-file --convert-memref-to-spirv="bool-num-bits=8" --cse %s | FileCheck %s
+// RUN: mlir-opt --split-input-file -pass-pipeline="builtin.module(convert-memref-to-spirv{bool-num-bits=8}, cse)" %s | FileCheck %s
+// RUN: mlir-opt --split-input-file -pass-pipeline="builtin.module(convert-memref-to-spirv{bool-num-bits=8 use-64bit-index=true}, cse)" %s | FileCheck --check-prefix=CHECK64 %s
 
 // Check that with proper compute and storage extensions, we don't need to
 // perform special tricks.
@@ -414,6 +415,43 @@ func.func @cast_to_static_zero_elems(%arg: memref<?xf32, #spirv.storage_class<Cr
 
 }
 
+// -----
+
+module attributes {
+  spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Kernel, Int64, Addresses], []>, #spirv.resource_limits<>>
+} {
+// CHECK-LABEL: func @extract_aligned_pointer_as_index_kernel
+func.func @extract_aligned_pointer_as_index_kernel(%m: memref<?xf32, #spirv.storage_class<CrossWorkgroup>>) -> index {
+  %0 = memref.extract_aligned_pointer_as_index %m: memref<?xf32, #spirv.storage_class<CrossWorkgroup>> -> index
+  // CHECK: %[[I32:.*]] = spirv.ConvertPtrToU {{%.*}} : !spirv.ptr<f32, CrossWorkgroup> to i32
+  // CHECK64: %[[I64:.*]] = spirv.ConvertPtrToU {{%.*}} : !spirv.ptr<f32, CrossWorkgroup> to i64
+  // CHECK: %[[R:.*]] = builtin.unrealized_conversion_cast %[[I32]] : i32 to index
+  // CHECK64: %[[R:.*]] = builtin.unrealized_conversion_cast %[[I64]] : i64 to index
+
+  // CHECK: return %[[R:.*]] : index
+  return %0: index
+}
+}
+
+// -----
+
+module attributes {
+  spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Shader, Int64, Addresses], []>, #spirv.resource_limits<>>
+} {
+// CHECK-LABEL: func @extract_aligned_pointer_as_index_shader
+func.func @extract_aligned_pointer_as_index_shader(%m: memref<?xf32, #spirv.storage_class<CrossWorkgroup>>) -> index {
+  %0 = memref.extract_aligned_pointer_as_index %m: memref<?xf32, #spirv.storage_class<CrossWorkgroup>> -> index
+  // CHECK: %[[I32:.*]] = spirv.ConvertPtrToU {{%.*}} : !spirv.ptr<!spirv.struct<(!spirv.rtarray<f32>)>, CrossWorkgroup> to i32
+  // CHECK64: %[[I64:.*]] = spirv.ConvertPtrToU {{%.*}} : !spirv.ptr<!spirv.struct<(!spirv.rtarray<f32>)>, CrossWorkgroup> to i64
+  // CHECK: %[[R:.*]] = builtin.unrealized_conversion_cast %[[I32]] : i32 to index
+  // CHECK64: %[[R:.*]] = builtin.unrealized_conversion_cast %[[I64]] : i64 to index
+
+  // CHECK: return %[[R:.*]] : index
+  return %0: index
+}
+}
+
+
 // -----
 
 // Check nontemporal attribute

@mshahneo
Copy link
Contributor Author

Just a friendly ping, @kuhar, @antiagainst :)

@@ -308,6 +308,17 @@ class CastPattern final : public OpConversionPattern<memref::CastOp> {
}
};

/// Converts memref.extract_aligned_pointer_as_index to spirv.ConvertPtrToU.
class ExtractAlignedPointerAsIndexOpPattern
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
class ExtractAlignedPointerAsIndexOpPattern
class ExtractAlignedPointerAsIndexOpPattern final

Comment on lines +426 to +429
// CHECK: %[[I32:.*]] = spirv.ConvertPtrToU {{%.*}} : !spirv.ptr<f32, CrossWorkgroup> to i32
// CHECK64: %[[I64:.*]] = spirv.ConvertPtrToU {{%.*}} : !spirv.ptr<f32, CrossWorkgroup> to i64
// CHECK: %[[R:.*]] = builtin.unrealized_conversion_cast %[[I32]] : i32 to index
// CHECK64: %[[R:.*]] = builtin.unrealized_conversion_cast %[[I64]] : i64 to index
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you group CHECK and CHECK64 together? It's a bit hard to read when interleaved

Comment on lines +444 to +447
// CHECK: %[[I32:.*]] = spirv.ConvertPtrToU {{%.*}} : !spirv.ptr<!spirv.struct<(!spirv.rtarray<f32>)>, CrossWorkgroup> to i32
// CHECK64: %[[I64:.*]] = spirv.ConvertPtrToU {{%.*}} : !spirv.ptr<!spirv.struct<(!spirv.rtarray<f32>)>, CrossWorkgroup> to i64
// CHECK: %[[R:.*]] = builtin.unrealized_conversion_cast %[[I32]] : i32 to index
// CHECK64: %[[R:.*]] = builtin.unrealized_conversion_cast %[[I64]] : i64 to index
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

also here

// -----

module attributes {
spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Kernel, Int64, Addresses], []>, #spirv.resource_limits<>>
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ConvertPtrToU requires the PhysicalStorageBufferAddresses capability, no?

@@ -1,4 +1,5 @@
// RUN: mlir-opt --split-input-file --convert-memref-to-spirv="bool-num-bits=8" --cse %s | FileCheck %s
// RUN: mlir-opt --split-input-file -pass-pipeline="builtin.module(convert-memref-to-spirv{bool-num-bits=8}, cse)" %s | FileCheck %s
// RUN: mlir-opt --split-input-file -pass-pipeline="builtin.module(convert-memref-to-spirv{bool-num-bits=8 use-64bit-index=true}, cse)" %s | FileCheck --check-prefix=CHECK64 %s
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
// RUN: mlir-opt --split-input-file -pass-pipeline="builtin.module(convert-memref-to-spirv{bool-num-bits=8 use-64bit-index=true}, cse)" %s | FileCheck --check-prefix=CHECK64 %s
// RUN: mlir-opt --split-input-file -pass-pipeline="builtin.module(convert-memref-to-spirv{bool-num-bits=8 use-64bit-index=true}, cse)" %s \
// RUN: | FileCheck --check-prefix=CHECK64 %s

mshahneo added a commit to mshahneo/mlir-extensions that referenced this pull request Apr 18, 2024
…ew SPIR-V pipeline.

The patch has a upstream PR pending review: llvm/llvm-project#86750.
The patch can be removed once the PR gets merged and LLVM version is updated.

The test cases in this PR only provides lowering from func to spirv.
XeGPU test cases using the full pipeline is coming as part of next PR from Dimple (drprajap).
mshahneo added a commit to mshahneo/mlir-extensions that referenced this pull request Apr 18, 2024
…ew SPIR-V pipeline.

The patch has a upstream PR pending review: llvm/llvm-project#86750.
The patch can be removed once the PR gets merged and LLVM version is updated.

The test cases in this PR only provides lowering from func to spirv.
XeGPU test cases using the full pipeline is coming as part of next PR from Dimple (drprajap).
mshahneo added a commit to mshahneo/mlir-extensions that referenced this pull request Apr 18, 2024
…ew SPIR-V pipeline.

The patch has a upstream PR pending review: llvm/llvm-project#86750.
The patch can be removed once the PR gets merged and LLVM version is updated.

The test cases in this PR only provides lowering from func to spirv.
XeGPU test cases using the full pipeline is coming as part of next PR from Dimple (drprajap).
leshikus pushed a commit to leshikus/mlir-extensions that referenced this pull request Apr 30, 2024
…ew SPIR-V pipeline.

The patch has a upstream PR pending review: llvm/llvm-project#86750.
The patch can be removed once the PR gets merged and LLVM version is updated.

The test cases in this PR only provides lowering from func to spirv.
XeGPU test cases using the full pipeline is coming as part of next PR from Dimple (drprajap).
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants