|
| 1 | +#include "mlir/Dialect/SparseTensor/Utils/Merger.h" |
| 2 | +#include "gmock/gmock.h" |
| 3 | +#include "gtest/gtest.h" |
| 4 | +#include <memory> |
| 5 | + |
| 6 | +using namespace mlir::sparse_tensor; |
| 7 | + |
| 8 | +namespace { |
| 9 | + |
| 10 | +/// Simple recursive data structure used to match expressions in Mergers. |
| 11 | +struct Pattern { |
| 12 | + Kind kind; |
| 13 | + |
| 14 | + /// Expressions representing tensors simply have a tensor number. |
| 15 | + unsigned tensorNum; |
| 16 | + |
| 17 | + /// Tensor operations point to their children. |
| 18 | + std::shared_ptr<Pattern> e0; |
| 19 | + std::shared_ptr<Pattern> e1; |
| 20 | + |
| 21 | + /// Constructors. |
| 22 | + /// Rather than using these, please use the readable helper constructor |
| 23 | + /// functions below to make tests more readable. |
| 24 | + Pattern(unsigned tensorNum) : kind(Kind::kTensor), tensorNum(tensorNum) {} |
| 25 | + Pattern(Kind kind, std::shared_ptr<Pattern> e0, std::shared_ptr<Pattern> e1) |
| 26 | + : kind(kind), e0(e0), e1(e1) { |
| 27 | + assert(kind >= Kind::kMulF); |
| 28 | + assert(e0 && e1); |
| 29 | + } |
| 30 | +}; |
| 31 | + |
| 32 | +/// |
| 33 | +/// Readable Pattern builder functions. |
| 34 | +/// These should be preferred over the actual constructors. |
| 35 | +/// |
| 36 | + |
| 37 | +static std::shared_ptr<Pattern> tensorPattern(unsigned tensorNum) { |
| 38 | + return std::make_shared<Pattern>(tensorNum); |
| 39 | +} |
| 40 | + |
| 41 | +static std::shared_ptr<Pattern> addfPattern(std::shared_ptr<Pattern> e0, |
| 42 | + std::shared_ptr<Pattern> e1) { |
| 43 | + return std::make_shared<Pattern>(Kind::kAddF, e0, e1); |
| 44 | +} |
| 45 | + |
| 46 | +static std::shared_ptr<Pattern> mulfPattern(std::shared_ptr<Pattern> e0, |
| 47 | + std::shared_ptr<Pattern> e1) { |
| 48 | + return std::make_shared<Pattern>(Kind::kMulF, e0, e1); |
| 49 | +} |
| 50 | + |
| 51 | +class MergerTestBase : public ::testing::Test { |
| 52 | +protected: |
| 53 | + MergerTestBase(unsigned numTensors, unsigned numLoops) |
| 54 | + : numTensors(numTensors), numLoops(numLoops), |
| 55 | + merger(numTensors, numLoops) {} |
| 56 | + |
| 57 | + /// |
| 58 | + /// Expression construction helpers. |
| 59 | + /// |
| 60 | + |
| 61 | + unsigned tensor(unsigned tensor) { |
| 62 | + return merger.addExp(Kind::kTensor, tensor); |
| 63 | + } |
| 64 | + |
| 65 | + unsigned addf(unsigned e0, unsigned e1) { |
| 66 | + return merger.addExp(Kind::kAddF, e0, e1); |
| 67 | + } |
| 68 | + |
| 69 | + unsigned mulf(unsigned e0, unsigned e1) { |
| 70 | + return merger.addExp(Kind::kMulF, e0, e1); |
| 71 | + } |
| 72 | + |
| 73 | + /// |
| 74 | + /// Comparison helpers. |
| 75 | + /// |
| 76 | + |
| 77 | + /// For readability of tests. |
| 78 | + unsigned lat(unsigned lat) { return lat; } |
| 79 | + |
| 80 | + /// Returns true if a lattice point with an expression matching the given |
| 81 | + /// pattern and bits matching the given bits is present in lattice points |
| 82 | + /// [p, p+n) of lattice set s. This is useful for testing partial ordering |
| 83 | + /// constraints between lattice points. We generally know how contiguous |
| 84 | + /// groups of lattice points should be ordered with respect to other groups, |
| 85 | + /// but there is no required ordering within groups. |
| 86 | + bool latPointWithinRange(unsigned s, unsigned p, unsigned n, |
| 87 | + std::shared_ptr<Pattern> pattern, |
| 88 | + llvm::BitVector bits) { |
| 89 | + for (unsigned i = p; i < p + n; ++i) { |
| 90 | + if (compareExpression(merger.lat(merger.set(s)[i]).exp, pattern) && |
| 91 | + compareBits(s, i, bits)) |
| 92 | + return true; |
| 93 | + } |
| 94 | + return false; |
| 95 | + } |
| 96 | + |
| 97 | + /// Wrapper over latPointWithinRange for readability of tests. |
| 98 | + void expectLatPointWithinRange(unsigned s, unsigned p, unsigned n, |
| 99 | + std::shared_ptr<Pattern> pattern, |
| 100 | + llvm::BitVector bits) { |
| 101 | + EXPECT_TRUE(latPointWithinRange(s, p, n, pattern, bits)); |
| 102 | + } |
| 103 | + |
| 104 | + /// Wrapper over expectLatPointWithinRange for a single lat point. |
| 105 | + void expectLatPoint(unsigned s, unsigned p, std::shared_ptr<Pattern> pattern, |
| 106 | + llvm::BitVector bits) { |
| 107 | + EXPECT_TRUE(latPointWithinRange(s, p, 1, pattern, bits)); |
| 108 | + } |
| 109 | + |
| 110 | + /// Converts a vector of (loop, tensor) pairs to a bitvector with the |
| 111 | + /// corresponding bits set. |
| 112 | + llvm::BitVector |
| 113 | + loopsToBits(std::vector<std::pair<unsigned, unsigned>> loops) { |
| 114 | + llvm::BitVector testBits = llvm::BitVector(numTensors + 1, false); |
| 115 | + for (auto l : loops) { |
| 116 | + auto loop = std::get<0>(l); |
| 117 | + auto tensor = std::get<1>(l); |
| 118 | + testBits.set(numTensors * loop + tensor); |
| 119 | + } |
| 120 | + return testBits; |
| 121 | + } |
| 122 | + |
| 123 | + /// Returns true if the bits of lattice point p in set s match the given bits. |
| 124 | + bool compareBits(unsigned s, unsigned p, llvm::BitVector bits) { |
| 125 | + return merger.lat(merger.set(s)[p]).bits == bits; |
| 126 | + } |
| 127 | + |
| 128 | + /// Check that there are n lattice points in set s. |
| 129 | + void expectNumLatPoints(unsigned s, unsigned n) { |
| 130 | + EXPECT_THAT(merger.set(s).size(), n); |
| 131 | + } |
| 132 | + |
| 133 | + /// Compares expressions for equality. Equality is defined recursively as: |
| 134 | + /// - Two expressions can only be equal if they have the same Kind. |
| 135 | + /// - Two binary expressions are equal if they have the same Kind and their |
| 136 | + /// children are equal. |
| 137 | + /// - Expressions with Kind invariant or tensor are equal if they have the |
| 138 | + /// same expression id. |
| 139 | + bool compareExpression(unsigned e, std::shared_ptr<Pattern> pattern) { |
| 140 | + auto tensorExp = merger.exp(e); |
| 141 | + if (tensorExp.kind != pattern->kind) |
| 142 | + return false; |
| 143 | + assert(tensorExp.kind != Kind::kInvariant && |
| 144 | + "Invariant comparison not yet supported"); |
| 145 | + switch (tensorExp.kind) { |
| 146 | + case Kind::kTensor: |
| 147 | + return tensorExp.tensor == pattern->tensorNum; |
| 148 | + case Kind::kZero: |
| 149 | + return true; |
| 150 | + case Kind::kMulF: |
| 151 | + case Kind::kMulI: |
| 152 | + case Kind::kAddF: |
| 153 | + case Kind::kAddI: |
| 154 | + case Kind::kSubF: |
| 155 | + case Kind::kSubI: |
| 156 | + return compareExpression(tensorExp.children.e0, pattern->e0) && |
| 157 | + compareExpression(tensorExp.children.e1, pattern->e1); |
| 158 | + default: |
| 159 | + llvm_unreachable("Unhandled Kind"); |
| 160 | + } |
| 161 | + } |
| 162 | + |
| 163 | + unsigned numTensors; |
| 164 | + unsigned numLoops; |
| 165 | + Merger merger; |
| 166 | +}; |
| 167 | + |
| 168 | +class MergerTest3T1L : public MergerTestBase { |
| 169 | +protected: |
| 170 | + // Our three tensors (two inputs, one output). |
| 171 | + const unsigned t0 = 0, t1 = 1, t2 = 2; |
| 172 | + |
| 173 | + // Our single loop. |
| 174 | + const unsigned l0 = 0; |
| 175 | + |
| 176 | + MergerTest3T1L() : MergerTestBase(3, 1) { |
| 177 | + // Tensor 0: sparse input vector. |
| 178 | + merger.addExp(Kind::kTensor, t0, -1u); |
| 179 | + merger.setDim(t0, l0, Dim::kSparse); |
| 180 | + |
| 181 | + // Tensor 1: sparse input vector. |
| 182 | + merger.addExp(Kind::kTensor, t1, -1u); |
| 183 | + merger.setDim(t1, l0, Dim::kSparse); |
| 184 | + |
| 185 | + // Tensor 2: dense output vector. |
| 186 | + merger.addExp(Kind::kTensor, t2, -1u); |
| 187 | + merger.setDim(t2, l0, Dim::kDense); |
| 188 | + } |
| 189 | +}; |
| 190 | + |
| 191 | +} // anonymous namespace |
| 192 | + |
| 193 | +/// Vector addition of 2 vectors, i.e.: |
| 194 | +/// a(i) = b(i) + c(i) |
| 195 | +/// which should form the 3 lattice points |
| 196 | +/// { |
| 197 | +/// lat( i_00 i_01 / (tensor_0 + tensor_1) ) |
| 198 | +/// lat( i_00 / tensor_0 ) |
| 199 | +/// lat( i_01 / tensor_1 ) |
| 200 | +/// } |
| 201 | +/// and after optimization, will reduce to the 2 lattice points |
| 202 | +/// { |
| 203 | +/// lat( i_00 i_01 / (tensor_0 + tensor_1) ) |
| 204 | +/// lat( i_00 / tensor_0 ) |
| 205 | +/// } |
| 206 | +TEST_F(MergerTest3T1L, VectorAdd2) { |
| 207 | + // Construct expression. |
| 208 | + auto e = addf(tensor(t0), tensor(t1)); |
| 209 | + |
| 210 | + // Build lattices and check. |
| 211 | + auto s = merger.buildLattices(e, l0); |
| 212 | + expectNumLatPoints(s, 3); |
| 213 | + expectLatPoint(s, lat(0), addfPattern(tensorPattern(t0), tensorPattern(t1)), |
| 214 | + loopsToBits({{l0, t0}, {l0, t1}})); |
| 215 | + expectLatPointWithinRange(s, lat(1), 2, tensorPattern(t0), |
| 216 | + loopsToBits({{l0, t0}})); |
| 217 | + expectLatPointWithinRange(s, lat(1), 2, tensorPattern(t1), |
| 218 | + loopsToBits({{l0, t1}})); |
| 219 | + |
| 220 | + // Optimize lattices and check. |
| 221 | + s = merger.optimizeSet(s); |
| 222 | + expectNumLatPoints(s, 3); |
| 223 | + expectLatPoint(s, lat(0), addfPattern(tensorPattern(t0), tensorPattern(t1)), |
| 224 | + loopsToBits({{l0, t0}, {l0, t1}})); |
| 225 | + expectLatPointWithinRange(s, lat(1), 2, tensorPattern(t0), |
| 226 | + loopsToBits({{l0, t0}})); |
| 227 | + expectLatPointWithinRange(s, lat(1), 2, tensorPattern(t1), |
| 228 | + loopsToBits({{l0, t1}})); |
| 229 | +} |
| 230 | + |
| 231 | +/// Vector multiplication of 2 vectors, i.e.: |
| 232 | +/// a(i) = b(i) * c(i) |
| 233 | +/// which should form the single lattice point |
| 234 | +/// { |
| 235 | +/// lat( i_00 i_01 / (tensor_0 * tensor_1) ) |
| 236 | +/// } |
| 237 | +TEST_F(MergerTest3T1L, VectorMul2) { |
| 238 | + // Construct expression. |
| 239 | + auto e = mulf(t0, t1); |
| 240 | + |
| 241 | + // Build lattices and check. |
| 242 | + auto s = merger.buildLattices(e, l0); |
| 243 | + expectNumLatPoints(s, 1); |
| 244 | + expectLatPoint(s, lat(0), mulfPattern(tensorPattern(t0), tensorPattern(t1)), |
| 245 | + loopsToBits({{l0, t0}, {l0, t1}})); |
| 246 | + |
| 247 | + // Optimize lattices and check. |
| 248 | + s = merger.optimizeSet(s); |
| 249 | + expectNumLatPoints(s, 1); |
| 250 | + expectLatPoint(s, lat(0), mulfPattern(tensorPattern(t0), tensorPattern(t1)), |
| 251 | + loopsToBits({{l0, t0}, {l0, t1}})); |
| 252 | +} |
0 commit comments