Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][sparse] Support explicit/implicit value for complex type #90771

Merged
merged 2 commits into from
May 2, 2024

Conversation

yinying-lisa-li
Copy link
Contributor

No description provided.

@yinying-lisa-li yinying-lisa-li marked this pull request as ready for review May 1, 2024 20:50
@llvmbot llvmbot added mlir:sparse Sparse compiler in MLIR mlir bazel "Peripheral" support tier build system: utils/bazel labels May 1, 2024
@llvmbot
Copy link
Collaborator

llvmbot commented May 1, 2024

@llvm/pr-subscribers-mlir-sparse

@llvm/pr-subscribers-mlir

Author: Yinying Li (yinying-lisa-li)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/90771.diff

6 Files Affected:

  • (modified) mlir/lib/Dialect/SparseTensor/IR/CMakeLists.txt (+1)
  • (modified) mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp (+7)
  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.h (+6-3)
  • (modified) mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir (+15)
  • (modified) mlir/test/Dialect/SparseTensor/sparse_matmul_one.mlir (+14-6)
  • (modified) utils/bazel/llvm-project-overlay/mlir/BUILD.bazel (+1)
diff --git a/mlir/lib/Dialect/SparseTensor/IR/CMakeLists.txt b/mlir/lib/Dialect/SparseTensor/IR/CMakeLists.txt
index dd6f1037f71b53..6f59b69bddce86 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/SparseTensor/IR/CMakeLists.txt
@@ -45,6 +45,7 @@ add_mlir_dialect_library(MLIRSparseTensorDialect
 
   LINK_LIBS PUBLIC
   MLIRArithDialect
+  MLIRComplexDialect
   MLIRDialect
   MLIRDialectUtils
   MLIRIR
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 028a69da10c1e1..dac028e8d53cb0 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -17,6 +17,7 @@
 
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
+#include "mlir/Dialect/Complex/IR/Complex.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/DialectImplementation.h"
@@ -663,6 +664,9 @@ Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {
         explicitVal = result;
       } else if (auto result = llvm::dyn_cast<IntegerAttr>(attr)) {
         explicitVal = result;
+      } else if (auto result =
+                     llvm::dyn_cast<::mlir::complex::NumberAttr>(attr)) {
+        explicitVal = result;
       } else {
         parser.emitError(parser.getNameLoc(),
                          "expected a numeric value for explicitVal");
@@ -678,6 +682,9 @@ Attribute SparseTensorEncodingAttr::parse(AsmParser &parser, Type type) {
         implicitVal = result;
       } else if (auto result = llvm::dyn_cast<IntegerAttr>(attr)) {
         implicitVal = result;
+      } else if (auto result =
+                     llvm::dyn_cast<::mlir::complex::NumberAttr>(attr)) {
+        implicitVal = result;
       } else {
         parser.emitError(parser.getNameLoc(),
                          "expected a numeric value for implicitVal");
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.h b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.h
index cf3c35f5fa4c78..d0ef8a6860bb2d 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.h
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/CodegenUtils.h
@@ -401,9 +401,12 @@ inline Value constantLevelTypeEncoding(OpBuilder &builder, Location loc,
 
 // Generates a constant from a validated value carrying attribute.
 inline Value genValFromAttr(OpBuilder &builder, Location loc, Attribute attr) {
-  if (auto arrayAttr = dyn_cast<ArrayAttr>(attr)) {
-    Type tp = cast<TypedAttr>(arrayAttr[0]).getType();
-    return builder.create<complex::ConstantOp>(loc, tp, arrayAttr);
+  if (auto complexAttr = dyn_cast<complex::NumberAttr>(attr)) {
+    Type tp = cast<ComplexType>(complexAttr.getType()).getElementType();
+    return builder.create<complex::ConstantOp>(
+        loc, complexAttr.getType(),
+        builder.getArrayAttr({FloatAttr::get(tp, complexAttr.getReal()),
+                              FloatAttr::get(tp, complexAttr.getImag())}));
   }
   return builder.create<arith::ConstantOp>(loc, cast<TypedAttr>(attr));
 }
diff --git a/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir b/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir
index 7eeda9a9880268..7fb1c76c1a1ff6 100644
--- a/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir
+++ b/mlir/test/Dialect/SparseTensor/roundtrip_encoding.mlir
@@ -80,6 +80,21 @@ func.func private @sparse_csr(tensor<?x?xi64, #CSR_OnlyOnes>)
 
 // -----
 
+#CSR_OnlyOnes = #sparse_tensor.encoding<{
+  map = (d0, d1) -> (d0 : dense, d1 : compressed),
+  posWidth = 64,
+  crdWidth = 64,
+  explicitVal = #complex.number<:f32 1.0, 0.0>,
+  implicitVal = #complex.number<:f32 0.0, 0.0>
+}>
+
+// CHECK: #[[$CSR_OnlyOnes:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed), posWidth = 64, crdWidth = 64, explicitVal = #complex.number<:f32 1.000000e+00, 0.000000e+00> : complex<f32>, implicitVal = #complex.number<:f32 0.000000e+00, 0.000000e+00> : complex<f32> }>
+// CHECK-LABEL: func private @sparse_csr(
+// CHECK-SAME: tensor<?x?xcomplex<f32>, #[[$CSR_OnlyOnes]]>)
+func.func private @sparse_csr(tensor<?x?xcomplex<f32>, #CSR_OnlyOnes>)
+
+// -----
+
 #BCSR = #sparse_tensor.encoding<{
   map = (d0, d1, d2) -> (d0 : batch, d1: dense, d2 : compressed),
 }>
diff --git a/mlir/test/Dialect/SparseTensor/sparse_matmul_one.mlir b/mlir/test/Dialect/SparseTensor/sparse_matmul_one.mlir
index 82f3147d3206bd..be2172515d08bf 100755
--- a/mlir/test/Dialect/SparseTensor/sparse_matmul_one.mlir
+++ b/mlir/test/Dialect/SparseTensor/sparse_matmul_one.mlir
@@ -2,9 +2,9 @@
 // RUN:             --sparsification-and-bufferization | FileCheck %s
 
 #CSR_ones_complex = #sparse_tensor.encoding<{
-  map = (d0, d1) -> (d0 : dense, d1 : compressed)
-// explicitVal = (1.0, 0.0) : complex<f32>,
-// implicitVal = (0.0, 0.0) : complex<f32>
+  map = (d0, d1) -> (d0 : dense, d1 : compressed),
+  explicitVal = #complex.number<:f32 1.0, 0.0>,
+  implicitVal = #complex.number<:f32 0.0, 0.0>
 }>
 
 #CSR_ones_fp = #sparse_tensor.encoding<{
@@ -20,9 +20,17 @@
 }>
 
 // CHECK-LABEL:   func.func @matmul_complex
-//
-// TODO: make this work
-//
+// CHECK:         scf.for
+// CHECK:           scf.for
+// CHECK:             %[[X:.*]] = memref.load
+// CHECK:             scf.for
+// CHECK:               %[[I:.*]] = memref.load
+// CHECK:               %[[Y:.*]] = memref.load
+// CHECK:               %[[M:.*]] = complex.add %[[Y]], %[[X]] : complex<f32>
+// CHECK:               memref.store %[[M]]
+// CHECK:             }
+// CHECK:           }
+// CHECK:         }
 func.func @matmul_complex(%a: tensor<10x20xcomplex<f32>>,
                           %b: tensor<20x30xcomplex<f32>, #CSR_ones_complex>,
                           %c: tensor<10x30xcomplex<f32>>) -> tensor<10x30xcomplex<f32>> {
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
index acd2d3a14d7411..13c246a3fec6af 100644
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -3066,6 +3066,7 @@ cc_library(
         ":ArithDialect",
         ":BufferizationInterfaces",
         ":BytecodeOpInterface",
+        ":ComplexDialect",
         ":DialectUtils",
         ":IR",
         ":InferTypeOpInterface",

@yinying-lisa-li yinying-lisa-li merged commit e71eacc into llvm:main May 2, 2024
4 checks passed
@yinying-lisa-li yinying-lisa-li deleted the complex branch May 2, 2024 16:28
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
bazel "Peripheral" support tier build system: utils/bazel mlir:sparse Sparse compiler in MLIR mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants