Skip to content

Commit 4084334

Browse files
committed
[mlir][sparse] Add Merger unit tests (with gcc5 build fix)
This is a fix of https://reviews.llvm.org/D104956, which broke the gcc5 build. We opt to use unit tests rather than check tests as the lattice/merger code is a small C++ component with a well-defined API. Testing this API via check tests would be far less direct and readable. In addition, as the check tests will only be able to test the API indirectly, the tests may break based on unrelated changes; e.g. changes in linalg. Reviewed By: aartbik Differential Revision: https://reviews.llvm.org/D105828
1 parent d5d4777 commit 4084334

File tree

3 files changed

+260
-0
lines changed

3 files changed

+260
-0
lines changed

mlir/unittests/Dialect/CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -7,4 +7,5 @@ target_link_libraries(MLIRDialectTests
77
MLIRDialect)
88

99
add_subdirectory(Quant)
10+
add_subdirectory(SparseTensor)
1011
add_subdirectory(SPIRV)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
add_mlir_unittest(MLIRSparseTensorTests
2+
MergerTest.cpp
3+
)
4+
target_link_libraries(MLIRSparseTensorTests
5+
PRIVATE
6+
MLIRSparseTensorUtils
7+
)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,252 @@
1+
#include "mlir/Dialect/SparseTensor/Utils/Merger.h"
2+
#include "gmock/gmock.h"
3+
#include "gtest/gtest.h"
4+
#include <memory>
5+
6+
using namespace mlir::sparse_tensor;
7+
8+
namespace {
9+
10+
/// Simple recursive data structure used to match expressions in Mergers.
11+
struct Pattern {
12+
Kind kind;
13+
14+
/// Expressions representing tensors simply have a tensor number.
15+
unsigned tensorNum;
16+
17+
/// Tensor operations point to their children.
18+
std::shared_ptr<Pattern> e0;
19+
std::shared_ptr<Pattern> e1;
20+
21+
/// Constructors.
22+
/// Rather than using these, please use the readable helper constructor
23+
/// functions below to make tests more readable.
24+
Pattern(unsigned tensorNum) : kind(Kind::kTensor), tensorNum(tensorNum) {}
25+
Pattern(Kind kind, std::shared_ptr<Pattern> e0, std::shared_ptr<Pattern> e1)
26+
: kind(kind), e0(e0), e1(e1) {
27+
assert(kind >= Kind::kMulF);
28+
assert(e0 && e1);
29+
}
30+
};
31+
32+
///
33+
/// Readable Pattern builder functions.
34+
/// These should be preferred over the actual constructors.
35+
///
36+
37+
static std::shared_ptr<Pattern> tensorPattern(unsigned tensorNum) {
38+
return std::make_shared<Pattern>(tensorNum);
39+
}
40+
41+
static std::shared_ptr<Pattern> addfPattern(std::shared_ptr<Pattern> e0,
42+
std::shared_ptr<Pattern> e1) {
43+
return std::make_shared<Pattern>(Kind::kAddF, e0, e1);
44+
}
45+
46+
static std::shared_ptr<Pattern> mulfPattern(std::shared_ptr<Pattern> e0,
47+
std::shared_ptr<Pattern> e1) {
48+
return std::make_shared<Pattern>(Kind::kMulF, e0, e1);
49+
}
50+
51+
class MergerTestBase : public ::testing::Test {
52+
protected:
53+
MergerTestBase(unsigned numTensors, unsigned numLoops)
54+
: numTensors(numTensors), numLoops(numLoops),
55+
merger(numTensors, numLoops) {}
56+
57+
///
58+
/// Expression construction helpers.
59+
///
60+
61+
unsigned tensor(unsigned tensor) {
62+
return merger.addExp(Kind::kTensor, tensor);
63+
}
64+
65+
unsigned addf(unsigned e0, unsigned e1) {
66+
return merger.addExp(Kind::kAddF, e0, e1);
67+
}
68+
69+
unsigned mulf(unsigned e0, unsigned e1) {
70+
return merger.addExp(Kind::kMulF, e0, e1);
71+
}
72+
73+
///
74+
/// Comparison helpers.
75+
///
76+
77+
/// For readability of tests.
78+
unsigned lat(unsigned lat) { return lat; }
79+
80+
/// Returns true if a lattice point with an expression matching the given
81+
/// pattern and bits matching the given bits is present in lattice points
82+
/// [p, p+n) of lattice set s. This is useful for testing partial ordering
83+
/// constraints between lattice points. We generally know how contiguous
84+
/// groups of lattice points should be ordered with respect to other groups,
85+
/// but there is no required ordering within groups.
86+
bool latPointWithinRange(unsigned s, unsigned p, unsigned n,
87+
std::shared_ptr<Pattern> pattern,
88+
llvm::BitVector bits) {
89+
for (unsigned i = p; i < p + n; ++i) {
90+
if (compareExpression(merger.lat(merger.set(s)[i]).exp, pattern) &&
91+
compareBits(s, i, bits))
92+
return true;
93+
}
94+
return false;
95+
}
96+
97+
/// Wrapper over latPointWithinRange for readability of tests.
98+
void expectLatPointWithinRange(unsigned s, unsigned p, unsigned n,
99+
std::shared_ptr<Pattern> pattern,
100+
llvm::BitVector bits) {
101+
EXPECT_TRUE(latPointWithinRange(s, p, n, pattern, bits));
102+
}
103+
104+
/// Wrapper over expectLatPointWithinRange for a single lat point.
105+
void expectLatPoint(unsigned s, unsigned p, std::shared_ptr<Pattern> pattern,
106+
llvm::BitVector bits) {
107+
EXPECT_TRUE(latPointWithinRange(s, p, 1, pattern, bits));
108+
}
109+
110+
/// Converts a vector of (loop, tensor) pairs to a bitvector with the
111+
/// corresponding bits set.
112+
llvm::BitVector
113+
loopsToBits(std::vector<std::pair<unsigned, unsigned>> loops) {
114+
llvm::BitVector testBits = llvm::BitVector(numTensors + 1, false);
115+
for (auto l : loops) {
116+
auto loop = std::get<0>(l);
117+
auto tensor = std::get<1>(l);
118+
testBits.set(numTensors * loop + tensor);
119+
}
120+
return testBits;
121+
}
122+
123+
/// Returns true if the bits of lattice point p in set s match the given bits.
124+
bool compareBits(unsigned s, unsigned p, llvm::BitVector bits) {
125+
return merger.lat(merger.set(s)[p]).bits == bits;
126+
}
127+
128+
/// Check that there are n lattice points in set s.
129+
void expectNumLatPoints(unsigned s, unsigned n) {
130+
EXPECT_THAT(merger.set(s).size(), n);
131+
}
132+
133+
/// Compares expressions for equality. Equality is defined recursively as:
134+
/// - Two expressions can only be equal if they have the same Kind.
135+
/// - Two binary expressions are equal if they have the same Kind and their
136+
/// children are equal.
137+
/// - Expressions with Kind invariant or tensor are equal if they have the
138+
/// same expression id.
139+
bool compareExpression(unsigned e, std::shared_ptr<Pattern> pattern) {
140+
auto tensorExp = merger.exp(e);
141+
if (tensorExp.kind != pattern->kind)
142+
return false;
143+
assert(tensorExp.kind != Kind::kInvariant &&
144+
"Invariant comparison not yet supported");
145+
switch (tensorExp.kind) {
146+
case Kind::kTensor:
147+
return tensorExp.tensor == pattern->tensorNum;
148+
case Kind::kZero:
149+
return true;
150+
case Kind::kMulF:
151+
case Kind::kMulI:
152+
case Kind::kAddF:
153+
case Kind::kAddI:
154+
case Kind::kSubF:
155+
case Kind::kSubI:
156+
return compareExpression(tensorExp.children.e0, pattern->e0) &&
157+
compareExpression(tensorExp.children.e1, pattern->e1);
158+
default:
159+
llvm_unreachable("Unhandled Kind");
160+
}
161+
}
162+
163+
unsigned numTensors;
164+
unsigned numLoops;
165+
Merger merger;
166+
};
167+
168+
class MergerTest3T1L : public MergerTestBase {
169+
protected:
170+
// Our three tensors (two inputs, one output).
171+
const unsigned t0 = 0, t1 = 1, t2 = 2;
172+
173+
// Our single loop.
174+
const unsigned l0 = 0;
175+
176+
MergerTest3T1L() : MergerTestBase(3, 1) {
177+
// Tensor 0: sparse input vector.
178+
merger.addExp(Kind::kTensor, t0, -1u);
179+
merger.setDim(t0, l0, Dim::kSparse);
180+
181+
// Tensor 1: sparse input vector.
182+
merger.addExp(Kind::kTensor, t1, -1u);
183+
merger.setDim(t1, l0, Dim::kSparse);
184+
185+
// Tensor 2: dense output vector.
186+
merger.addExp(Kind::kTensor, t2, -1u);
187+
merger.setDim(t2, l0, Dim::kDense);
188+
}
189+
};
190+
191+
} // anonymous namespace
192+
193+
/// Vector addition of 2 vectors, i.e.:
194+
/// a(i) = b(i) + c(i)
195+
/// which should form the 3 lattice points
196+
/// {
197+
/// lat( i_00 i_01 / (tensor_0 + tensor_1) )
198+
/// lat( i_00 / tensor_0 )
199+
/// lat( i_01 / tensor_1 )
200+
/// }
201+
/// and after optimization, will reduce to the 2 lattice points
202+
/// {
203+
/// lat( i_00 i_01 / (tensor_0 + tensor_1) )
204+
/// lat( i_00 / tensor_0 )
205+
/// }
206+
TEST_F(MergerTest3T1L, VectorAdd2) {
207+
// Construct expression.
208+
auto e = addf(tensor(t0), tensor(t1));
209+
210+
// Build lattices and check.
211+
auto s = merger.buildLattices(e, l0);
212+
expectNumLatPoints(s, 3);
213+
expectLatPoint(s, lat(0), addfPattern(tensorPattern(t0), tensorPattern(t1)),
214+
loopsToBits({{l0, t0}, {l0, t1}}));
215+
expectLatPointWithinRange(s, lat(1), 2, tensorPattern(t0),
216+
loopsToBits({{l0, t0}}));
217+
expectLatPointWithinRange(s, lat(1), 2, tensorPattern(t1),
218+
loopsToBits({{l0, t1}}));
219+
220+
// Optimize lattices and check.
221+
s = merger.optimizeSet(s);
222+
expectNumLatPoints(s, 3);
223+
expectLatPoint(s, lat(0), addfPattern(tensorPattern(t0), tensorPattern(t1)),
224+
loopsToBits({{l0, t0}, {l0, t1}}));
225+
expectLatPointWithinRange(s, lat(1), 2, tensorPattern(t0),
226+
loopsToBits({{l0, t0}}));
227+
expectLatPointWithinRange(s, lat(1), 2, tensorPattern(t1),
228+
loopsToBits({{l0, t1}}));
229+
}
230+
231+
/// Vector multiplication of 2 vectors, i.e.:
232+
/// a(i) = b(i) * c(i)
233+
/// which should form the single lattice point
234+
/// {
235+
/// lat( i_00 i_01 / (tensor_0 * tensor_1) )
236+
/// }
237+
TEST_F(MergerTest3T1L, VectorMul2) {
238+
// Construct expression.
239+
auto e = mulf(t0, t1);
240+
241+
// Build lattices and check.
242+
auto s = merger.buildLattices(e, l0);
243+
expectNumLatPoints(s, 1);
244+
expectLatPoint(s, lat(0), mulfPattern(tensorPattern(t0), tensorPattern(t1)),
245+
loopsToBits({{l0, t0}, {l0, t1}}));
246+
247+
// Optimize lattices and check.
248+
s = merger.optimizeSet(s);
249+
expectNumLatPoints(s, 1);
250+
expectLatPoint(s, lat(0), mulfPattern(tensorPattern(t0), tensorPattern(t1)),
251+
loopsToBits({{l0, t0}, {l0, t1}}));
252+
}

0 commit comments

Comments
 (0)