-
Notifications
You must be signed in to change notification settings - Fork 10.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][memref][spirv] Add conversion for memref.extract_aligned_point… #86750
base: main
Are you sure you want to change the base?
[mlir][memref][spirv] Add conversion for memref.extract_aligned_point… #86750
Conversation
…er_as_index to SPIR-V Converts memref.extract_aligned_pointer_as_index to spirv.ConvertPtrToU. Index conversion is done based on 'use-64bit-index' option.
@llvm/pr-subscribers-mlir-spirv @llvm/pr-subscribers-mlir Author: Md Abdullah Shahneous Bari (mshahneo) Changes…er_as_index to SPIR-V Converts memref.extract_aligned_pointer_as_index to spirv.ConvertPtrToU. Index conversion is done based on 'use-64bit-index' option. Full diff: https://github.com/llvm/llvm-project/pull/86750.diff 2 Files Affected:
diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
index 81b9f55cac80f7..0ec3ad700fe807 100644
--- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
+++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
@@ -308,6 +308,17 @@ class CastPattern final : public OpConversionPattern<memref::CastOp> {
}
};
+/// Converts memref.extract_aligned_pointer_as_index to spirv.ConvertPtrToU.
+class ExtractAlignedPointerAsIndexOpPattern
+ : public OpConversionPattern<memref::ExtractAlignedPointerAsIndexOp> {
+public:
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(memref::ExtractAlignedPointerAsIndexOp extractOp,
+ OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override;
+};
} // namespace
//===----------------------------------------------------------------------===//
@@ -922,6 +933,20 @@ LogicalResult ReinterpretCastPattern::matchAndRewrite(
return success();
}
+//===----------------------------------------------------------------------===//
+// ExtractAlignedPointerAsIndexOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult ExtractAlignedPointerAsIndexOpPattern::matchAndRewrite(
+ memref::ExtractAlignedPointerAsIndexOp extractOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const {
+ auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
+ Type indexType = typeConverter.getIndexType();
+ rewriter.replaceOpWithNewOp<spirv::ConvertPtrToUOp>(extractOp, indexType,
+ adaptor.getSource());
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// Pattern population
//===----------------------------------------------------------------------===//
@@ -929,10 +954,11 @@ LogicalResult ReinterpretCastPattern::matchAndRewrite(
namespace mlir {
void populateMemRefToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
RewritePatternSet &patterns) {
- patterns.add<AllocaOpPattern, AllocOpPattern, AtomicRMWOpPattern,
- DeallocOpPattern, IntLoadOpPattern, IntStoreOpPattern,
- LoadOpPattern, MemorySpaceCastOpPattern, StoreOpPattern,
- ReinterpretCastPattern, CastPattern>(typeConverter,
- patterns.getContext());
+ patterns
+ .add<AllocaOpPattern, AllocOpPattern, AtomicRMWOpPattern,
+ DeallocOpPattern, IntLoadOpPattern, IntStoreOpPattern, LoadOpPattern,
+ MemorySpaceCastOpPattern, StoreOpPattern, ReinterpretCastPattern,
+ CastPattern, ExtractAlignedPointerAsIndexOpPattern>(
+ typeConverter, patterns.getContext());
}
} // namespace mlir
diff --git a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
index 10c03a270005f1..bc2af8b6edadcc 100644
--- a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
+++ b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir
@@ -1,4 +1,5 @@
-// RUN: mlir-opt --split-input-file --convert-memref-to-spirv="bool-num-bits=8" --cse %s | FileCheck %s
+// RUN: mlir-opt --split-input-file -pass-pipeline="builtin.module(convert-memref-to-spirv{bool-num-bits=8}, cse)" %s | FileCheck %s
+// RUN: mlir-opt --split-input-file -pass-pipeline="builtin.module(convert-memref-to-spirv{bool-num-bits=8 use-64bit-index=true}, cse)" %s | FileCheck --check-prefix=CHECK64 %s
// Check that with proper compute and storage extensions, we don't need to
// perform special tricks.
@@ -414,6 +415,43 @@ func.func @cast_to_static_zero_elems(%arg: memref<?xf32, #spirv.storage_class<Cr
}
+// -----
+
+module attributes {
+ spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Kernel, Int64, Addresses], []>, #spirv.resource_limits<>>
+} {
+// CHECK-LABEL: func @extract_aligned_pointer_as_index_kernel
+func.func @extract_aligned_pointer_as_index_kernel(%m: memref<?xf32, #spirv.storage_class<CrossWorkgroup>>) -> index {
+ %0 = memref.extract_aligned_pointer_as_index %m: memref<?xf32, #spirv.storage_class<CrossWorkgroup>> -> index
+ // CHECK: %[[I32:.*]] = spirv.ConvertPtrToU {{%.*}} : !spirv.ptr<f32, CrossWorkgroup> to i32
+ // CHECK64: %[[I64:.*]] = spirv.ConvertPtrToU {{%.*}} : !spirv.ptr<f32, CrossWorkgroup> to i64
+ // CHECK: %[[R:.*]] = builtin.unrealized_conversion_cast %[[I32]] : i32 to index
+ // CHECK64: %[[R:.*]] = builtin.unrealized_conversion_cast %[[I64]] : i64 to index
+
+ // CHECK: return %[[R:.*]] : index
+ return %0: index
+}
+}
+
+// -----
+
+module attributes {
+ spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Shader, Int64, Addresses], []>, #spirv.resource_limits<>>
+} {
+// CHECK-LABEL: func @extract_aligned_pointer_as_index_shader
+func.func @extract_aligned_pointer_as_index_shader(%m: memref<?xf32, #spirv.storage_class<CrossWorkgroup>>) -> index {
+ %0 = memref.extract_aligned_pointer_as_index %m: memref<?xf32, #spirv.storage_class<CrossWorkgroup>> -> index
+ // CHECK: %[[I32:.*]] = spirv.ConvertPtrToU {{%.*}} : !spirv.ptr<!spirv.struct<(!spirv.rtarray<f32>)>, CrossWorkgroup> to i32
+ // CHECK64: %[[I64:.*]] = spirv.ConvertPtrToU {{%.*}} : !spirv.ptr<!spirv.struct<(!spirv.rtarray<f32>)>, CrossWorkgroup> to i64
+ // CHECK: %[[R:.*]] = builtin.unrealized_conversion_cast %[[I32]] : i32 to index
+ // CHECK64: %[[R:.*]] = builtin.unrealized_conversion_cast %[[I64]] : i64 to index
+
+ // CHECK: return %[[R:.*]] : index
+ return %0: index
+}
+}
+
+
// -----
// Check nontemporal attribute
|
Just a friendly ping, @kuhar, @antiagainst :) |
@@ -308,6 +308,17 @@ class CastPattern final : public OpConversionPattern<memref::CastOp> { | |||
} | |||
}; | |||
|
|||
/// Converts memref.extract_aligned_pointer_as_index to spirv.ConvertPtrToU. | |||
class ExtractAlignedPointerAsIndexOpPattern |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
class ExtractAlignedPointerAsIndexOpPattern | |
class ExtractAlignedPointerAsIndexOpPattern final |
// CHECK: %[[I32:.*]] = spirv.ConvertPtrToU {{%.*}} : !spirv.ptr<f32, CrossWorkgroup> to i32 | ||
// CHECK64: %[[I64:.*]] = spirv.ConvertPtrToU {{%.*}} : !spirv.ptr<f32, CrossWorkgroup> to i64 | ||
// CHECK: %[[R:.*]] = builtin.unrealized_conversion_cast %[[I32]] : i32 to index | ||
// CHECK64: %[[R:.*]] = builtin.unrealized_conversion_cast %[[I64]] : i64 to index |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you group CHECK and CHECK64 together? It's a bit hard to read when interleaved
// CHECK: %[[I32:.*]] = spirv.ConvertPtrToU {{%.*}} : !spirv.ptr<!spirv.struct<(!spirv.rtarray<f32>)>, CrossWorkgroup> to i32 | ||
// CHECK64: %[[I64:.*]] = spirv.ConvertPtrToU {{%.*}} : !spirv.ptr<!spirv.struct<(!spirv.rtarray<f32>)>, CrossWorkgroup> to i64 | ||
// CHECK: %[[R:.*]] = builtin.unrealized_conversion_cast %[[I32]] : i32 to index | ||
// CHECK64: %[[R:.*]] = builtin.unrealized_conversion_cast %[[I64]] : i64 to index |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
also here
// ----- | ||
|
||
module attributes { | ||
spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Kernel, Int64, Addresses], []>, #spirv.resource_limits<>> |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ConvertPtrToU
requires the PhysicalStorageBufferAddresses
capability, no?
@@ -1,4 +1,5 @@ | |||
// RUN: mlir-opt --split-input-file --convert-memref-to-spirv="bool-num-bits=8" --cse %s | FileCheck %s | |||
// RUN: mlir-opt --split-input-file -pass-pipeline="builtin.module(convert-memref-to-spirv{bool-num-bits=8}, cse)" %s | FileCheck %s | |||
// RUN: mlir-opt --split-input-file -pass-pipeline="builtin.module(convert-memref-to-spirv{bool-num-bits=8 use-64bit-index=true}, cse)" %s | FileCheck --check-prefix=CHECK64 %s |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
// RUN: mlir-opt --split-input-file -pass-pipeline="builtin.module(convert-memref-to-spirv{bool-num-bits=8 use-64bit-index=true}, cse)" %s | FileCheck --check-prefix=CHECK64 %s | |
// RUN: mlir-opt --split-input-file -pass-pipeline="builtin.module(convert-memref-to-spirv{bool-num-bits=8 use-64bit-index=true}, cse)" %s \ | |
// RUN: | FileCheck --check-prefix=CHECK64 %s |
…ew SPIR-V pipeline. The patch has a upstream PR pending review: llvm/llvm-project#86750. The patch can be removed once the PR gets merged and LLVM version is updated. The test cases in this PR only provides lowering from func to spirv. XeGPU test cases using the full pipeline is coming as part of next PR from Dimple (drprajap).
…ew SPIR-V pipeline. The patch has a upstream PR pending review: llvm/llvm-project#86750. The patch can be removed once the PR gets merged and LLVM version is updated. The test cases in this PR only provides lowering from func to spirv. XeGPU test cases using the full pipeline is coming as part of next PR from Dimple (drprajap).
…ew SPIR-V pipeline. The patch has a upstream PR pending review: llvm/llvm-project#86750. The patch can be removed once the PR gets merged and LLVM version is updated. The test cases in this PR only provides lowering from func to spirv. XeGPU test cases using the full pipeline is coming as part of next PR from Dimple (drprajap).
…ew SPIR-V pipeline. The patch has a upstream PR pending review: llvm/llvm-project#86750. The patch can be removed once the PR gets merged and LLVM version is updated. The test cases in this PR only provides lowering from func to spirv. XeGPU test cases using the full pipeline is coming as part of next PR from Dimple (drprajap).
…er_as_index to SPIR-V
Converts memref.extract_aligned_pointer_as_index to spirv.ConvertPtrToU. Index conversion is done based on 'use-64bit-index' option.