-
Notifications
You must be signed in to change notification settings - Fork 10.9k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][sparse] Support explicit/implicit value for complex type #90771
Merged
Conversation
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
yinying-lisa-li
requested review from
rupprecht,
aartbik,
PeimingLiu and
matthias-springer
as code owners
May 1, 2024 20:50
llvmbot
added
mlir:sparse
Sparse compiler in MLIR
mlir
bazel
"Peripheral" support tier build system: utils/bazel
labels
May 1, 2024
@llvm/pr-subscribers-mlir-sparse @llvm/pr-subscribers-mlir Author: Yinying Li (yinying-lisa-li) ChangesFull diff: https://github.com/llvm/llvm-project/pull/90771.diff 6 Files Affected:
diff --git a/mlir/lib/Dialect/SparseTensor/IR/CMakeLists.txt b/mlir/lib/Dialect/SparseTensor/IR/CMakeLists.txt
index dd6f1037f71b53..6f59b69bddce86 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/SparseTensor/IR/CMakeLists.txt
@@ -45,6 +45,7 @@ add_mlir_dialect_library(MLIRSparseTensorDialect
LINK_LIBS PUBLIC
MLIRArithDialect
+ MLIRComplexDialect
MLIRDialect
MLIRDialectUtils
MLIRIR
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 028a69da10c1e1..dac028e8d53cb0 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -17,6 +17,7 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
+#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/DialectImplementation.h"
@@ -663,6 +664,9 @@ Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {
explicitVal = result;
} else if (auto result = llvm::dyn_cast<IntegerAttr>(attr)) {
explicitVal = result;
+ } else if (auto result =
+ llvm::dyn_cast<::mlir::complex::NumberAttr>(attr)) {
+ explicitVal = result;
} else {
parser.emitError(parser.getNameLoc(),
"expected a numeric value for explicitVal");
@@ -678,6 +682,9 @@ Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {
implicitVal = result;
} else if (auto result = llvm::dyn_cast<IntegerAttr>(attr)) {
implicitVal = result;
+ } else if (auto result =
+ llvm::dyn_cast<::mlir::complex::NumberAttr>(attr)) {
+ implicitVal = result;
} else {
parser.emitError(parser.getNameLoc(),
"expected a numeric value for implicitVal");
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.h
index cf3c35f5fa4c78..d0ef8a6860bb2d 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.h
@@ -401,9 +401,12 @@ inline Value constantLevelTypeEncoding(OpBuilder &builder, Location loc,
// Generates a constant from a validated value carrying attribute.
inline Value genValFromAttr(OpBuilder &builder, Location loc, Attribute attr) {
- if (auto arrayAttr = dyn_cast<ArrayAttr>(attr)) {
- Type tp = cast<TypedAttr>(arrayAttr[0]).getType();
- return builder.create<complex::ConstantOp>(loc, tp, arrayAttr);
+ if (auto complexAttr = dyn_cast<complex::NumberAttr>(attr)) {
+ Type tp = cast<ComplexType>(complexAttr.getType()).getElementType();
+ return builder.create<complex::ConstantOp>(
+ loc, complexAttr.getType(),
+ builder.getArrayAttr({FloatAttr::get(tp, complexAttr.getReal()),
+ FloatAttr::get(tp, complexAttr.getImag())}));
}
return builder.create<arith::ConstantOp>(loc, cast<TypedAttr>(attr));
}
diff --git a/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir b/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir
index 7eeda9a9880268..7fb1c76c1a1ff6 100644
--- a/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir
+++ b/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir
@@ -80,6 +80,21 @@ func.func private @sparse_csr(tensor<?x?xi64, #CSR_OnlyOnes>)
// -----
+#CSR_OnlyOnes = #sparse_tensor.encoding<{
+ map = (d0, d1) -> (d0 : dense, d1 : compressed),
+ posWidth = 64,
+ crdWidth = 64,
+ explicitVal = #complex.number<:f32 1.0, 0.0>,
+ implicitVal = #complex.number<:f32 0.0, 0.0>
+}>
+
+// CHECK: #[[$CSR_OnlyOnes:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed), posWidth = 64, crdWidth = 64, explicitVal = #complex.number<:f32 1.000000e+00, 0.000000e+00> : complex<f32>, implicitVal = #complex.number<:f32 0.000000e+00, 0.000000e+00> : complex<f32> }>
+// CHECK-LABEL: func private @sparse_csr(
+// CHECK-SAME: tensor<?x?xcomplex<f32>, #[[$CSR_OnlyOnes]]>)
+func.func private @sparse_csr(tensor<?x?xcomplex<f32>, #CSR_OnlyOnes>)
+
+// -----
+
#BCSR = #sparse_tensor.encoding<{
map = (d0, d1, d2) -> (d0 : batch, d1: dense, d2 : compressed),
}>
diff --git a/mlir/test/Dialect/SparseTensor/sparse_matmul_one.mlir b/mlir/test/Dialect/SparseTensor/sparse_matmul_one.mlir
index 82f3147d3206bd..be2172515d08bf 100755
--- a/mlir/test/Dialect/SparseTensor/sparse_matmul_one.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_matmul_one.mlir
@@ -2,9 +2,9 @@
// RUN: --sparsification-and-bufferization | FileCheck %s
#CSR_ones_complex = #sparse_tensor.encoding<{
- map = (d0, d1) -> (d0 : dense, d1 : compressed)
-// explicitVal = (1.0, 0.0) : complex<f32>,
-// implicitVal = (0.0, 0.0) : complex<f32>
+ map = (d0, d1) -> (d0 : dense, d1 : compressed),
+ explicitVal = #complex.number<:f32 1.0, 0.0>,
+ implicitVal = #complex.number<:f32 0.0, 0.0>
}>
#CSR_ones_fp = #sparse_tensor.encoding<{
@@ -20,9 +20,17 @@
}>
// CHECK-LABEL: func.func @matmul_complex
-//
-// TODO: make this work
-//
+// CHECK: scf.for
+// CHECK: scf.for
+// CHECK: %[[X:.*]] = memref.load
+// CHECK: scf.for
+// CHECK: %[[I:.*]] = memref.load
+// CHECK: %[[Y:.*]] = memref.load
+// CHECK: %[[M:.*]] = complex.add %[[Y]], %[[X]] : complex<f32>
+// CHECK: memref.store %[[M]]
+// CHECK: }
+// CHECK: }
+// CHECK: }
func.func @matmul_complex(%a: tensor<10x20xcomplex<f32>>,
%b: tensor<20x30xcomplex<f32>, #CSR_ones_complex>,
%c: tensor<10x30xcomplex<f32>>) -> tensor<10x30xcomplex<f32>> {
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index acd2d3a14d7411..13c246a3fec6af 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -3066,6 +3066,7 @@ cc_library(
":ArithDialect",
":BufferizationInterfaces",
":BytecodeOpInterface",
+ ":ComplexDialect",
":DialectUtils",
":IR",
":InferTypeOpInterface",
|
aartbik
reviewed
May 1, 2024
aartbik
approved these changes
May 1, 2024
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Labels
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
No description provided.