[mlir][sparse] Add support for complex.im and complex.re to the sparse compiler.

Add a test.

Reviewed By: aartbik

Differential Revision: https://reviews.llvm.org/D125834
bixia1 committed May 18, 2022
1 parent 66dfa36 commit 69edacb
Showing 3 changed files with 118 additions and 0 deletions.
2 changes: 2 additions & 0 deletions mlir/include/mlir/Dialect/SparseTensor/Utils/Merger.h
@@ -50,6 +50,8 @@ enum Kind {
kCastU, // unsigned
kCastIdx,
kTruncI,
kCIm, // complex.im
kCRe, // complex.re
kBitCast,
kBinaryBranch, // semiring unary branch created from a binary op
kUnary, // semiring unary op
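The two new kinds join the unary section of the enum. For orientation, a minimal sketch (not part of the patch, and assuming the Kind enum above is in scope) of the classification pattern the switches in Merger.cpp apply to them:

// Hypothetical helper, not in the patch: identifies the kinds that extract
// a floating-point component from a complex operand.
static bool isComplexPartExtraction(Kind kind) {
  switch (kind) {
  case kCIm: // imaginary component
  case kCRe: // real component
    return true;
  default:
    return false;
  }
}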
23 changes: 23 additions & 0 deletions mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp
@@ -46,6 +46,8 @@ TensorExp::TensorExp(Kind k, unsigned x, unsigned y, Value v, Operation *o)
case kTanhF:
case kNegF:
case kNegI:
case kCIm:
case kCRe:
assert(x != -1u && y == -1u && !v && !o);
children.e0 = x;
children.e1 = y;
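The constructor places kCIm and kCRe in its unary arm: exactly one child expression is recorded, and no result Value or Operation is attached. A toy illustration of that invariant, using a simplified stand-in for the real TensorExp:

#include <cassert>

// Toy model, not the real TensorExp: unary kinds store the operand index
// in e0 and leave e1 unset (-1u), matching the assertion above.
struct UnaryChildren {
  unsigned e0; // operand expression index
  unsigned e1; // unused for unary kinds
};

static UnaryChildren makeUnaryChildren(unsigned x) {
  assert(x != -1u && "unary expression requires an operand");
  return {x, -1u};
}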
@@ -291,6 +293,8 @@ bool Merger::isSingleCondition(unsigned t, unsigned e) const {
case kCastU:
case kCastIdx:
case kTruncI:
case kCIm:
case kCRe:
case kBitCast:
return isSingleCondition(t, tensorExps[e].children.e0);
case kDivF: // note: x / c only
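Classifying complex.im and complex.re as single conditions is sound because both map the implicit zero of a sparse tensor to zero, so a kernel only has to visit stored entries. A standalone check of that zero-preservation property:

#include <cassert>
#include <complex>

int main() {
  // re(0 + 0i) == 0 and im(0 + 0i) == 0, so entries that are absent
  // (implicitly zero) in the sparse operand stay zero in the result.
  std::complex<float> zero(0.0f, 0.0f);
  assert(zero.real() == 0.0f && zero.imag() == 0.0f);
  return 0;
}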
@@ -367,6 +371,10 @@ static const char *kindToOpSymbol(Kind kind) {
case kCastU:
case kCastIdx:
case kTruncI:
case kCIm:
return "complex.im";
case kCRe:
return "complex.re";
case kBitCast:
return "cast";
case kBinaryBranch:
@@ -526,6 +534,8 @@ unsigned Merger::buildLattices(unsigned e, unsigned i) {
}
case kAbsF:
case kCeilF:
case kCIm:
case kCRe:
case kFloorF:
case kSqrtF:
case kExpm1F:
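Grouping kCIm and kCRe with kAbsF, kCeilF, and the other zero-preserving unary kinds lets them reuse the generic unary lattice rule: the operand's iteration lattice is kept as-is, and only the expression at each lattice point is wrapped in the new operation. A toy model of that rule (not the Merger API):

#include <vector>

// Toy model: a lattice point is reduced to the id of its expression; a
// zero-preserving unary op maps each point to one point with the same
// loop conditions and a wrapped expression.
template <typename WrapFn>
std::vector<unsigned> mapUnaryLattice(const std::vector<unsigned> &points,
                                      WrapFn wrap) {
  std::vector<unsigned> result;
  result.reserve(points.size());
  for (unsigned p : points)
    result.push_back(wrap(p));
  return result;
}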
@@ -776,6 +786,10 @@ Optional<unsigned> Merger::buildTensorExp(linalg::GenericOp op, Value v) {
return addExp(kCastIdx, e, v);
if (isa<arith::TruncIOp>(def))
return addExp(kTruncI, e, v);
if (isa<complex::ImOp>(def))
return addExp(kCIm, e);
if (isa<complex::ReOp>(def))
return addExp(kCRe, e);
if (isa<arith::BitcastOp>(def))
return addExp(kBitCast, e, v);
if (isa<sparse_tensor::UnaryOp>(def))
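Note that, unlike the casts right above (e.g. addExp(kTruncI, e, v)), the new kinds are added without the op's result Value: a cast keeps v so the destination type can be inferred from it, whereas the result type of a complex part extraction is recovered from the operand's ComplexType when code is emitted. A sketch of that recovery, assuming the MLIR headers already used by Merger.cpp:

// Mirrors the logic added to buildExp below: the float result type comes
// from the complex operand, so no Value has to be cached on the expression.
static FloatType complexPartType(Value operand) {
  auto complexType = operand.getType().cast<ComplexType>();
  return complexType.getElementType().cast<FloatType>();
}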
@@ -930,6 +944,15 @@ Value Merger::buildExp(RewriterBase &rewriter, Location loc, unsigned e,
return rewriter.create<arith::IndexCastOp>(loc, inferType(e, v0), v0);
case kTruncI:
return rewriter.create<arith::TruncIOp>(loc, inferType(e, v0), v0);
case kCIm:
case kCRe: {
auto type = v0.getType().template cast<ComplexType>();
auto eltType = type.getElementType().template cast<FloatType>();
if (tensorExps[e].kind == kCIm)
return rewriter.create<complex::ImOp>(loc, eltType, v0);

return rewriter.create<complex::ReOp>(loc, eltType, v0);
}
case kBitCast:
return rewriter.create<arith::BitcastOp>(loc, inferType(e, v0), v0);
// Binary ops.
93 changes: 93 additions & 0 deletions mlir/test/Integration/Dialect/SparseTensor/CPU/sparse_re_im.mlir
@@ -0,0 +1,93 @@
// RUN: mlir-opt %s --sparse-compiler | \
// RUN: mlir-cpu-runner \
// RUN: -e entry -entry-point-result=void \
// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
// RUN: FileCheck %s

#SparseVector = #sparse_tensor.encoding<{dimLevelType = ["compressed"]}>

#trait_op = {
indexing_maps = [
affine_map<(i) -> (i)>, // a (in)
affine_map<(i) -> (i)> // x (out)
],
iterator_types = ["parallel"],
doc = "x(i) = OP a(i)"
}

module {
func.func @cre(%arga: tensor<?xcomplex<f32>, #SparseVector>)
-> tensor<?xf32, #SparseVector> {
%c = arith.constant 0 : index
%d = tensor.dim %arga, %c : tensor<?xcomplex<f32>, #SparseVector>
%xv = sparse_tensor.init [%d] : tensor<?xf32, #SparseVector>
%0 = linalg.generic #trait_op
ins(%arga: tensor<?xcomplex<f32>, #SparseVector>)
outs(%xv: tensor<?xf32, #SparseVector>) {
^bb(%a: complex<f32>, %x: f32):
%1 = complex.re %a : complex<f32>
linalg.yield %1 : f32
} -> tensor<?xf32, #SparseVector>
return %0 : tensor<?xf32, #SparseVector>
}

func.func @cim(%arga: tensor<?xcomplex<f32>, #SparseVector>)
-> tensor<?xf32, #SparseVector> {
%c = arith.constant 0 : index
%d = tensor.dim %arga, %c : tensor<?xcomplex<f32>, #SparseVector>
%xv = sparse_tensor.init [%d] : tensor<?xf32, #SparseVector>
%0 = linalg.generic #trait_op
ins(%arga: tensor<?xcomplex<f32>, #SparseVector>)
outs(%xv: tensor<?xf32, #SparseVector>) {
^bb(%a: complex<f32>, %x: f32):
%1 = complex.im %a : complex<f32>
linalg.yield %1 : f32
} -> tensor<?xf32, #SparseVector>
return %0 : tensor<?xf32, #SparseVector>
}

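// Dump the values and indices arrays of a sparse vector; vector.transfer_read
// pads the lanes past the stored entries (-1 for values, 0 for indices).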
func.func @dump(%arg0: tensor<?xf32, #SparseVector>) {
%c0 = arith.constant 0 : index
%d0 = arith.constant -1.0 : f32
%values = sparse_tensor.values %arg0 : tensor<?xf32, #SparseVector> to memref<?xf32>
%0 = vector.transfer_read %values[%c0], %d0: memref<?xf32>, vector<4xf32>
vector.print %0 : vector<4xf32>
%indices = sparse_tensor.indices %arg0, %c0 : tensor<?xf32, #SparseVector> to memref<?xindex>
%1 = vector.transfer_read %indices[%c0], %c0: memref<?xindex>, vector<4xindex>
vector.print %1 : vector<4xindex>
return
}

// Driver method to call and verify functions cim and cre.
func.func @entry() {
// Setup sparse vectors.
%v1 = arith.constant sparse<
[ [0], [20], [31] ],
[ (5.13, 2.0), (3.0, 4.0), (5.0, 6.0) ] > : tensor<32xcomplex<f32>>
%sv1 = sparse_tensor.convert %v1 : tensor<32xcomplex<f32>> to tensor<?xcomplex<f32>, #SparseVector>

// Call sparse vector kernels.
%0 = call @cre(%sv1)
: (tensor<?xcomplex<f32>, #SparseVector>) -> tensor<?xf32, #SparseVector>

%1 = call @cim(%sv1)
: (tensor<?xcomplex<f32>, #SparseVector>) -> tensor<?xf32, #SparseVector>

//
// Verify the results.
//
// CHECK: ( 5.13, 3, 5, -1 )
// CHECK-NEXT: ( 0, 20, 31, 0 )
// CHECK-NEXT: ( 2, 4, 6, -1 )
// CHECK-NEXT: ( 0, 20, 31, 0 )
//
call @dump(%0) : (tensor<?xf32, #SparseVector>) -> ()
call @dump(%1) : (tensor<?xf32, #SparseVector>) -> ()

// Release the resources.
sparse_tensor.release %sv1 : tensor<?xcomplex<f32>, #SparseVector>
sparse_tensor.release %0 : tensor<?xf32, #SparseVector>
sparse_tensor.release %1 : tensor<?xf32, #SparseVector>
return
}
}
