/
AffineStructures.cpp
1783 lines (1564 loc) · 68.2 KB
/
AffineStructures.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
//===- AffineStructures.cpp - MLIR Affine Structures Class-----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Structures for affine/polyhedral analysis of affine dialect ops.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Affine/Analysis/AffineStructures.h"
#include "mlir/Analysis/Presburger/LinearTransform.h"
#include "mlir/Analysis/Presburger/Simplex.h"
#include "mlir/Analysis/Presburger/Utils.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/IR/AffineValueMap.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/MathExtras.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#define DEBUG_TYPE "affine-structures"
using namespace mlir;
using namespace presburger;
namespace {

/// Flattener that extends SimpleAffineExprFlattener (see the comments there)
/// by additionally recording, in `localVarCst`, the constraints associated
/// with every mod, floordiv and ceildiv it flattens.
struct AffineExprFlattener : public SimpleAffineExprFlattener {
public:
  /// Constraints tying the local variables introduced for mod's and div's to
  /// the existing dimensional and symbolic variables. These are always
  /// inequalities.
  IntegerPolyhedron localVarCst;

  AffineExprFlattener(unsigned nDims, unsigned nSymbols)
      : SimpleAffineExprFlattener(nDims, nSymbols),
        localVarCst(PresburgerSpace::getSetSpace(nDims, nSymbols)) {}

private:
  /// Adds the local identifier needed to flatten a mod, floordiv or ceildiv
  /// expression. The new local is always a floordiv of a pure add/mul affine
  /// function of the other identifiers: `dividend` holds that function's
  /// coefficients and `divisor` is the positive constant it is divided by.
  /// `localExpr` is the simplified tree expression (AffineExpr) corresponding
  /// to the quantifier.
  void addLocalFloorDivId(ArrayRef<int64_t> dividend, int64_t divisor,
                          AffineExpr localExpr) override {
    // Let the base flattener register the local first ...
    SimpleAffineExprFlattener::addLocalFloorDivId(dividend, divisor, localExpr);
    // ... then record the same floordiv in the constraint system.
    localVarCst.addLocalFloorDiv(dividend, divisor);
  }
};

} // namespace
// Flattens every expression in 'exprs' using one shared flattener, so that
// local identifiers / expressions introduced along the way are shared between
// the results.
//
// On success, 'flattenedExprs' receives one flattened row per expression and,
// if 'localVarCst' is non-null, it receives the constraints on the locals
// introduced by flattening. If 'exprs' is empty, 'flattenedExprs' is left as
// is and 'localVarCst' (when provided) is reset to 'numDims' dims and
// 'numSymbols' symbols.
//
// Returns failure if any expression is not pure affine (semi-affine
// expressions are not handled yet); neither output is written in that case.
static LogicalResult
getFlattenedAffineExprs(ArrayRef<AffineExpr> exprs, unsigned numDims,
                        unsigned numSymbols,
                        std::vector<SmallVector<int64_t, 8>> *flattenedExprs,
                        FlatAffineValueConstraints *localVarCst) {
  if (exprs.empty()) {
    // 'localVarCst' is optional (the non-empty path below checks it as well),
    // so it must not be dereferenced unconditionally here.
    if (localVarCst)
      localVarCst->reset(numDims, numSymbols);
    return success();
  }

  AffineExprFlattener flattener(numDims, numSymbols);
  // Use the same flattener to simplify each expression successively. This way
  // local identifiers / expressions are shared.
  for (auto expr : exprs) {
    if (!expr.isPureAffine())
      return failure();
    flattener.walkPostOrder(expr);
  }

  assert(flattener.operandExprStack.size() == exprs.size());
  // assign() replaces the previous contents, so no separate clear() is needed.
  flattenedExprs->assign(flattener.operandExprStack.begin(),
                         flattener.operandExprStack.end());

  if (localVarCst)
    localVarCst->clearAndCopyFrom(flattener.localVarCst);

  return success();
}
// Flattens 'expr' into 'flattenedExpr'. If 'localVarCst' is non-null it
// receives the constraints on any local identifiers introduced by flattening.
// Returns failure if 'expr' was unable to be flattened (semi-affine
// expressions not handled yet); 'flattenedExpr' is not written in that case.
LogicalResult
mlir::getFlattenedAffineExpr(AffineExpr expr, unsigned numDims,
                             unsigned numSymbols,
                             SmallVectorImpl<int64_t> *flattenedExpr,
                             FlatAffineValueConstraints *localVarCst) {
  std::vector<SmallVector<int64_t, 8>> flattenedExprs;
  // On failure the helper leaves 'flattenedExprs' empty, so bail out before
  // indexing into it.
  if (failed(::getFlattenedAffineExprs({expr}, numDims, numSymbols,
                                       &flattenedExprs, localVarCst)))
    return failure();

  assert(flattenedExprs.size() == 1 && "expected exactly one flattened expr");
  *flattenedExpr = flattenedExprs[0];
  return success();
}
/// Flattens the result expressions of `map` into `flattenedExprs`. If
/// `localVarCst` is non-null it receives the constraints on the local
/// identifiers introduced by flattening; for a map without results it is reset
/// to the map's dims and symbols. Returns failure if an expression was unable
/// to be flattened (i.e., semi-affine expressions not handled yet).
LogicalResult mlir::getFlattenedAffineExprs(
    AffineMap map, std::vector<SmallVector<int64_t, 8>> *flattenedExprs,
    FlatAffineValueConstraints *localVarCst) {
  if (map.getNumResults() == 0) {
    // `localVarCst` is an optional output; only touch it when provided.
    if (localVarCst)
      localVarCst->reset(map.getNumDims(), map.getNumSymbols());
    return success();
  }
  return ::getFlattenedAffineExprs(map.getResults(), map.getNumDims(),
                                   map.getNumSymbols(), flattenedExprs,
                                   localVarCst);
}
/// Flattens the constraint expressions of `set` into `flattenedExprs`. If
/// `localVarCst` is non-null it receives the constraints on the local
/// identifiers introduced by flattening; for a set without constraints it is
/// reset to the set's dims and symbols. Returns failure if an expression was
/// unable to be flattened (i.e., semi-affine expressions not handled yet).
LogicalResult mlir::getFlattenedAffineExprs(
    IntegerSet set, std::vector<SmallVector<int64_t, 8>> *flattenedExprs,
    FlatAffineValueConstraints *localVarCst) {
  if (set.getNumConstraints() == 0) {
    // `localVarCst` is an optional output; only touch it when provided.
    if (localVarCst)
      localVarCst->reset(set.getNumDims(), set.getNumSymbols());
    return success();
  }
  return ::getFlattenedAffineExprs(set.getConstraints(), set.getNumDims(),
                                   set.getNumSymbols(), flattenedExprs,
                                   localVarCst);
}
//===----------------------------------------------------------------------===//
// FlatAffineConstraints / FlatAffineValueConstraints.
//===----------------------------------------------------------------------===//
std::unique_ptr<FlatAffineValueConstraints>
FlatAffineValueConstraints::clone() const {
return std::make_unique<FlatAffineValueConstraints>(*this);
}
// Builds the constraint system described by an IntegerSet: one row per
// constraint of the set, plus the constraints on any local ids that
// flattening the set's expressions introduces.
FlatAffineValueConstraints::FlatAffineValueConstraints(IntegerSet set)
    : IntegerPolyhedron(set.getNumInequalities(), set.getNumEqualities(),
                        set.getNumDims() + set.getNumSymbols() + 1,
                        PresburgerSpace::getSetSpace(set.getNumDims(),
                                                     set.getNumSymbols(),
                                                     /*numLocals=*/0)) {
  // No identifier has an associated Value yet.
  values.resize(getNumIds(), None);

  // Flatten the set's expressions.
  std::vector<SmallVector<int64_t, 8>> flatExprs;
  FlatAffineValueConstraints localVarCst;
  LogicalResult flattened =
      getFlattenedAffineExprs(set, &flatExprs, &localVarCst);
  if (failed(flattened)) {
    assert(false && "flattening unimplemented for semi-affine integer sets");
    return;
  }
  assert(flatExprs.size() == set.getNumConstraints());

  // Make room for the locals created by flattening before adding any row.
  insertId(IdKind::Local, getNumIdKind(IdKind::Local),
           /*num=*/localVarCst.getNumLocalIds());

  // Each flattened expression becomes an equality or an inequality, as
  // dictated by the set's equality flags.
  unsigned row = 0;
  for (const auto &flatExpr : flatExprs) {
    assert(flatExpr.size() == getNumCols());
    const bool isEquality = set.getEqFlags()[row];
    if (isEquality)
      addEquality(flatExpr);
    else
      addInequality(flatExpr);
    ++row;
  }

  // Finally bring in the constraints that define the flattening locals.
  append(localVarCst);
}
// Builds a hyperrectangular constraint set from ValueRanges holding induction
// variables, lower bounds and upper bounds. `ivs`, `lbs` and `ubs` must match
// one to one. Variables and constraints are laid out as:
//
// ivs | lbs | ubs | eq/ineq
// ----+-----+-----+---------
//   1   -1     0      >= 0
// ----+-----+-----+---------
//  -1    0     1      >= 0
//
// All dimensions are set as DimId.
FlatAffineValueConstraints
FlatAffineValueConstraints::getHyperrectangular(ValueRange ivs, ValueRange lbs,
                                                ValueRange ubs) {
  FlatAffineValueConstraints res;
  unsigned nIvs = ivs.size();
  assert(nIvs == lbs.size() && "expected as many lower bounds as ivs");
  assert(nIvs == ubs.size() && "expected as many upper bounds as ivs");

  // Nothing to constrain.
  if (nIvs == 0)
    return res;

  // Dims are appended in the order ivs, lbs, ubs.
  res.appendDimId(ivs);
  unsigned lbsStart = res.appendDimId(lbs);
  unsigned ubsStart = res.appendDimId(ubs);

  MLIRContext *ctx = ivs.front().getContext();
  // Single-result map over all 3 * nIvs dims that selects the dim at `dimPos`.
  auto selectDim = [&](unsigned dimPos) {
    return AffineMap::get(/*dimCount=*/3 * nIvs, /*symbolCount=*/0,
                          getAffineDimExpr(dimPos, ctx));
  };

  for (unsigned iv = 0; iv < nIvs; ++iv) {
    // iv - lb >= 0
    AffineMap lowerMap = selectDim(lbsStart + iv);
    if (failed(res.addBound(BoundType::LB, iv, lowerMap)))
      llvm_unreachable("Unexpected FlatAffineValueConstraints creation error");

    // -iv + ub >= 0
    AffineMap upperMap = selectDim(ubsStart + iv);
    if (failed(res.addBound(BoundType::UB, iv, upperMap)))
      llvm_unreachable("Unexpected FlatAffineValueConstraints creation error");
  }
  return res;
}
/// Replaces this system with a freshly constructed one that has the given
/// reserved sizes and identifier counts. The reserved column count must cover
/// every identifier plus the constant column.
void FlatAffineValueConstraints::reset(unsigned numReservedInequalities,
                                       unsigned numReservedEqualities,
                                       unsigned newNumReservedCols,
                                       unsigned newNumDims,
                                       unsigned newNumSymbols,
                                       unsigned newNumLocals) {
  assert(newNumReservedCols >= newNumDims + newNumSymbols + newNumLocals + 1 &&
         "minimum 1 column");

  // Overwrite the whole object with a new, empty system.
  *this = FlatAffineValueConstraints(
      numReservedInequalities, numReservedEqualities, newNumReservedCols,
      newNumDims, newNumSymbols, newNumLocals);
}
/// Convenience overload: resets to the given identifier counts, reserving no
/// constraint rows and exactly as many columns as there are identifiers plus
/// the constant column.
void FlatAffineValueConstraints::reset(unsigned newNumDims,
                                       unsigned newNumSymbols,
                                       unsigned newNumLocals) {
  const unsigned numCols = newNumDims + newNumSymbols + newNumLocals + 1;
  reset(/*numReservedInequalities=*/0, /*numReservedEqualities=*/0,
        /*numReservedCols=*/numCols, newNumDims, newNumSymbols, newNumLocals);
}
void FlatAffineValueConstraints::reset(
unsigned numReservedInequalities, unsigned numReservedEqualities,
unsigned newNumReservedCols, unsigned newNumDims, unsigned newNumSymbols,
unsigned newNumLocals, ArrayRef<Value> valArgs) {
assert(newNumReservedCols >= newNumDims + newNumSymbols + newNumLocals + 1 &&
"minimum 1 column");
SmallVector<Optional<Value>, 8> newVals;
if (!valArgs.empty())
newVals.assign(valArgs.begin(), valArgs.end());
*this = FlatAffineValueConstraints(
numReservedInequalities, numReservedEqualities, newNumReservedCols,
newNumDims, newNumSymbols, newNumLocals, newVals);
}
/// Convenience overload: resets to the given identifier counts and Values,
/// reserving no constraint rows and exactly as many columns as there are
/// identifiers plus the constant column.
void FlatAffineValueConstraints::reset(unsigned newNumDims,
                                       unsigned newNumSymbols,
                                       unsigned newNumLocals,
                                       ArrayRef<Value> valArgs) {
  const unsigned numCols = newNumDims + newNumSymbols + newNumLocals + 1;
  reset(/*numReservedInequalities=*/0, /*numReservedEqualities=*/0, numCols,
        newNumDims, newNumSymbols, newNumLocals, valArgs);
}
/// Inserts `vals` as new dimension ids after the existing dimensions and
/// returns the position at which the first of them was inserted.
unsigned FlatAffineValueConstraints::appendDimId(ValueRange vals) {
  const unsigned firstNewPos = getNumDimIds();
  insertId(IdKind::SetDim, firstNewPos, vals);
  return firstNewPos;
}
/// Inserts `vals` as new symbol ids after the existing symbols and returns
/// the symbol-relative position at which the first of them was inserted.
unsigned FlatAffineValueConstraints::appendSymbolId(ValueRange vals) {
  const unsigned firstNewPos = getNumSymbolIds();
  insertId(IdKind::Symbol, firstNewPos, vals);
  return firstNewPos;
}
/// Inserts `vals` as dimension ids at dimension position `pos`; forwards to
/// insertId and returns its result.
unsigned FlatAffineValueConstraints::insertDimId(unsigned pos,
                                                 ValueRange vals) {
  unsigned absolutePos = insertId(IdKind::SetDim, pos, vals);
  return absolutePos;
}
/// Inserts `vals` as symbol ids at symbol position `pos`; forwards to
/// insertId and returns its result.
unsigned FlatAffineValueConstraints::insertSymbolId(unsigned pos,
                                                    ValueRange vals) {
  unsigned absolutePos = insertId(IdKind::Symbol, pos, vals);
  return absolutePos;
}
/// Inserts `num` identifiers of kind `kind` at position `pos` within that
/// kind. The new identifiers have no associated Value. Returns the position
/// reported by IntegerPolyhedron::insertId, which is also where the matching
/// `values` entries are placed.
unsigned FlatAffineValueConstraints::insertId(IdKind kind, unsigned pos,
                                              unsigned num) {
  const unsigned absolutePos = IntegerPolyhedron::insertId(kind, pos, num);

  // Keep `values` in lock-step with the columns of the polyhedron.
  auto insertAt = values.begin() + absolutePos;
  values.insert(insertAt, num, None);
  assert(values.size() == getNumIds());

  return absolutePos;
}
/// Inserts one identifier of kind `kind` per element of `vals`, starting at
/// position `pos` within that kind. A null Value in `vals` yields an
/// identifier without an associated Value. Returns the position reported by
/// IntegerPolyhedron::insertId for the first inserted identifier.
unsigned FlatAffineValueConstraints::insertId(IdKind kind, unsigned pos,
                                              ValueRange vals) {
  assert(!vals.empty() && "expected ValueRange with Values");
  const unsigned num = vals.size();
  const unsigned absolutePos = IntegerPolyhedron::insertId(kind, pos, num);

  // Mirror the new columns in `values`, one entry at a time.
  for (unsigned i = 0; i < num; ++i) {
    Optional<Value> entry = None;
    if (vals[i])
      entry = vals[i];
    values.insert(values.begin() + absolutePos + i, entry);
  }
  assert(values.size() == getNumIds());

  return absolutePos;
}
bool FlatAffineValueConstraints::hasValues() const {
return llvm::find_if(values, [](Optional<Value> id) {
return id.hasValue();
}) != values.end();
}
/// Returns true if `a` and `b` live in the same space: identical numbers of
/// dimension ids, symbol ids and total ids, and the same associated Values
/// (or absence thereof) in the same order.
static bool areIdsAligned(const FlatAffineValueConstraints &a,
                          const FlatAffineValueConstraints &b) {
  // Cheap count comparisons first.
  if (a.getNumDimIds() != b.getNumDimIds())
    return false;
  if (a.getNumSymbolIds() != b.getNumSymbolIds())
    return false;
  if (a.getNumIds() != b.getNumIds())
    return false;
  // Then the element-wise comparison of the associated Values.
  return a.getMaybeValues().equals(b.getMaybeValues());
}
/// Calls areIdsAligned to check if two constraint systems have the same set
/// of identifiers in the same order.
bool FlatAffineValueConstraints::areIdsAlignedWithOther(
const FlatAffineValueConstraints &other) {
return areIdsAligned(*this, other);
}
/// Checks if the SSA values associated with `cst`'s identifiers in range
/// [start, end) are unique. Identifiers without an associated Value are
/// ignored. An empty range (start >= end) is trivially unique.
static bool LLVM_ATTRIBUTE_UNUSED areIdsUnique(
    const FlatAffineValueConstraints &cst, unsigned start, unsigned end) {
  assert(start <= cst.getNumIds() && "Start position out of bounds");
  assert(end <= cst.getNumIds() && "End position out of bounds");

  if (start >= end)
    return true;

  SmallPtrSet<Value, 8> uniqueIds;
  // Restrict the scan to the requested range; looking at every identifier
  // would make the per-kind checks report duplicates that lie outside the
  // kind being queried.
  ArrayRef<Optional<Value>> maybeValues =
      cst.getMaybeValues().slice(start, end - start);
  for (Optional<Value> val : maybeValues) {
    // insert() reports through `.second` whether the value was newly added.
    if (val.hasValue() && !uniqueIds.insert(val.getValue()).second)
      return false;
  }
  return true;
}
/// Checks uniqueness of the SSA values associated with all of `cst`'s
/// identifiers by delegating to the range-based overload.
static bool LLVM_ATTRIBUTE_UNUSED
areIdsUnique(const FlatAffineValueConstraints &cst) {
  const unsigned numIds = cst.getNumIds();
  return areIdsUnique(cst, /*start=*/0, /*end=*/numIds);
}
/// Checks uniqueness of the SSA values associated with `cst`'s identifiers of
/// kind `kind` by delegating to the range-based overload with that kind's
/// position range. Only SetDim, Symbol and Local are supported.
static bool LLVM_ATTRIBUTE_UNUSED
areIdsUnique(const FlatAffineValueConstraints &cst, IdKind kind) {
  switch (kind) {
  case IdKind::SetDim:
    // Dims occupy [0, numDims).
    return areIdsUnique(cst, 0, cst.getNumDimIds());
  case IdKind::Symbol:
    // Symbols occupy [numDims, numDims + numSymbols).
    return areIdsUnique(cst, cst.getNumDimIds(), cst.getNumDimAndSymbolIds());
  case IdKind::Local:
    // Locals occupy [numDims + numSymbols, numIds).
    return areIdsUnique(cst, cst.getNumDimAndSymbolIds(), cst.getNumIds());
  default:
    break;
  }
  llvm_unreachable("Unexpected IdKind");
}
/// Merge and align the identifiers of A and B starting at 'offset', so that
/// both constraint systems get the union of the contained identifiers that is
/// dimension-wise and symbol-wise unique; both constraint systems are updated
/// so that they have the union of all identifiers, with A's original
/// identifiers appearing first followed by any of B's identifiers that didn't
/// appear in A. Local identifiers in B that have the same division
/// representation as local identifiers in A are merged into one.
///
/// The dimension merge below only looks at dimension positions >= 'offset';
/// symbols and locals are handled by mergeSymbolIds and mergeLocalIds.
///
/// E.g.: Input:  A has ((%i, %j) [%M, %N]) and B has ((%k, %j) [%P, %N, %M])
///       Output: both A, B have (%i, %j, %k) [%M, %N, %P]
static void mergeAndAlignIds(unsigned offset, FlatAffineValueConstraints *a,
FlatAffineValueConstraints *b) {
assert(offset <= a->getNumDimIds() && offset <= b->getNumDimIds());
// A merge/align isn't meaningful if a cst's ids aren't distinct.
assert(areIdsUnique(*a) && "A's values aren't unique");
assert(areIdsUnique(*b) && "B's values aren't unique");
// Every dim and symbol id from 'offset' onwards must have an associated
// Value in both systems.
assert(std::all_of(a->getMaybeValues().begin() + offset,
a->getMaybeValues().begin() + a->getNumDimAndSymbolIds(),
[](Optional<Value> id) { return id.hasValue(); }));
assert(std::all_of(b->getMaybeValues().begin() + offset,
b->getMaybeValues().begin() + b->getNumDimAndSymbolIds(),
[](Optional<Value> id) { return id.hasValue(); }));
// Collect the Values of A's dims in [offset, numDims) up front; A is only
// appended to after the loop over them has finished.
SmallVector<Value, 4> aDimValues;
a->getValues(offset, a->getNumDimIds(), &aDimValues);
{
// Merge dims from A into B.
// 'd' is the position in B that the current dim of A has to end up at.
unsigned d = offset;
for (auto aDimValue : aDimValues) {
unsigned loc;
if (b->findId(aDimValue, &loc)) {
// B already has this Value: move it to position 'd'.
assert(loc >= offset && "A's dim appears in B's aligned range");
assert(loc < b->getNumDimIds() &&
"A's dim appears in B's non-dim position");
b->swapId(d, loc);
} else {
// B doesn't have this Value: add it as a new dim at position 'd'.
b->insertDimId(d, aDimValue);
}
d++;
}
// Dimensions that are in B, but not in A, are added at the end.
// 't' starts at A's current dim count; B's dims from there on are the ones
// that were not matched above.
for (unsigned t = a->getNumDimIds(), e = b->getNumDimIds(); t < e; t++) {
a->appendDimId(b->getValue(t));
}
assert(a->getNumDimIds() == b->getNumDimIds() &&
"expected same number of dims");
}
// Merge and align symbols of A and B
a->mergeSymbolIds(*b);
// Merge and align local ids of A and B
a->mergeLocalIds(*b);
assert(areIdsAligned(*a, *b) && "IDs expected to be aligned");
}
// Call 'mergeAndAlignIds' to align constraint systems of 'this' and 'other'.
void FlatAffineValueConstraints::mergeAndAlignIdsWithOther(
unsigned offset, FlatAffineValueConstraints *other) {
mergeAndAlignIds(offset, this, other);
}
// Composes the affine map of `vMap` with this constraint system: the map is
// first aligned to this system's identifiers using its operands and then
// handed to `composeMatchingMap`, whose result is returned.
LogicalResult
FlatAffineValueConstraints::composeMap(const AffineValueMap *vMap) {
  AffineMap alignedMap =
      computeAlignedMap(vMap->getAffineMap(), vMap->getOperands());
  return composeMatchingMap(alignedMap);
}
// Like `composeMap`, but no Values need to be associated with the constraint
// system and none are consulted: the dimensions and symbols of `other` are
// expected to correspond 1:1 to those of `this` system.
//
// One new dimension is inserted at the front for every result of `other`, and
// one equality is added per result tying that dimension to the flattened
// result expression. Returns failure if the map could not be flattened.
LogicalResult FlatAffineValueConstraints::composeMatchingMap(AffineMap other) {
  assert(other.getNumDims() == getNumDimIds() && "dim mismatch");
  assert(other.getNumSymbols() == getNumSymbolIds() && "symbol mismatch");

  std::vector<SmallVector<int64_t, 8>> flatExprs;
  if (failed(flattenAlignedMapAndMergeLocals(other, &flatExprs)))
    return failure();
  assert(flatExprs.size() == other.getNumResults());

  // Add dimensions corresponding to the map's results.
  insertDimId(/*pos=*/0, /*num=*/other.getNumResults());

  // For the r^th result, with flattened expression e.g. 16*i0 + i1, the
  // equality added is d_r - 16*i0 - i1 = 0. Likewise, (i0 + 1, i0 + 8*i2)
  // yields d_0 - i0 - 1 == 0 and d_1 - i0 - 8*i2 == 0.
  const unsigned numResults = flatExprs.size();
  for (unsigned resultIdx = 0; resultIdx < numResults; ++resultIdx) {
    const auto &flatExpr = flatExprs[resultIdx];
    assert(flatExpr.size() >= other.getNumInputs() + 1);

    SmallVector<int64_t, 8> eqToAdd(getNumCols(), 0);

    // Coefficient one for the dimension standing for this result.
    eqToAdd[resultIdx] = 1;

    // Dims and symbols of the map sit right after the result dimensions; the
    // coefficients are negated because they move to the other side of the
    // equality.
    const unsigned numInputs = other.getNumInputs();
    for (unsigned inputIdx = 0; inputIdx < numInputs; ++inputIdx)
      eqToAdd[numResults + inputIdx] = -flatExpr[inputIdx];

    // Local coefficients of the flattened expression go to the columns that
    // follow all dims and symbols of this system.
    const unsigned constantIdx = flatExpr.size() - 1;
    unsigned dstCol = getNumDimIds() + getNumSymbolIds();
    for (unsigned srcCol = numInputs; srcCol < constantIdx; ++srcCol, ++dstCol)
      eqToAdd[dstCol] = -flatExpr[srcCol];

    // Constant term.
    eqToAdd[getNumCols() - 1] = -flatExpr[constantIdx];

    // Connect the result of the map to this constraint set.
    addEquality(eqToAdd);
  }

  return success();
}
// Turns the symbol associated with `id` into a dimension. Does nothing if
// `id` is not found in `cst` or is found at a non-symbol position.
static void turnSymbolIntoDim(FlatAffineValueConstraints *cst, Value id) {
  unsigned pos;
  if (!cst->findId(id, &pos))
    return;

  const bool isSymbol =
      pos >= cst->getNumDimIds() && pos < cst->getNumDimAndSymbolIds();
  if (!isSymbol)
    return;

  // Move the symbol to the first symbol position, then shift the dim/symbol
  // boundary by one so that this position counts as a dimension.
  cst->swapId(pos, cst->getNumDimIds());
  cst->setDimSymbolSeparation(cst->getNumSymbolIds() - 1);
}
/// Merge and align symbols of `this` and `other` such that both get the union
/// of symbols that are unique. Symbols in `this` and `other` should be
/// unique. Symbols with Value as `None` are considered to be unequal to all
/// other symbols.
///
/// After the call both systems have the same number of symbols: the symbols
/// of `this` come first, in their original order, followed by the symbols of
/// `other` that were not matched to a symbol of `this`.
void FlatAffineValueConstraints::mergeSymbolIds(
FlatAffineValueConstraints &other) {
assert(areIdsUnique(*this, IdKind::Symbol) && "Symbol ids are not unique");
assert(areIdsUnique(other, IdKind::Symbol) && "Symbol ids are not unique");
// Collect the Values of this system's symbols before anything is modified.
SmallVector<Value, 4> aSymValues;
getValues(getNumDimIds(), getNumDimAndSymbolIds(), &aSymValues);
// Merge symbols: merge symbols into `other` first from `this`.
// `s` is the position in `other` that the current symbol of `this` has to
// end up at; it starts at `other`'s first symbol position.
unsigned s = other.getNumDimIds();
for (Value aSymValue : aSymValues) {
unsigned loc;
// If the id is a symbol in `other`, then align it, otherwise assume that
// it is a new symbol
if (other.findId(aSymValue, &loc) && loc >= other.getNumDimIds() &&
loc < other.getNumDimAndSymbolIds())
other.swapId(s, loc);
else
// insertSymbolId takes a position relative to the first symbol, hence the
// subtraction of the dim count.
other.insertSymbolId(s - other.getNumDimIds(), aSymValue);
s++;
}
// Symbols that are in other, but not in this, are added at the end.
// `t` starts right after the symbols of `other` that were aligned with the
// symbols of `this` above.
for (unsigned t = other.getNumDimIds() + getNumSymbolIds(),
e = other.getNumDimAndSymbolIds();
t < e; t++)
insertSymbolId(getNumSymbolIds(), other.getValue(t));
assert(getNumSymbolIds() == other.getNumSymbolIds() &&
"expected same number of symbols");
assert(areIdsUnique(*this, IdKind::Symbol) && "Symbol ids are not unique");
assert(areIdsUnique(other, IdKind::Symbol) && "Symbol ids are not unique");
}
// Changes all symbol identifiers which are loop IVs to dim identifiers.
void FlatAffineValueConstraints::convertLoopIVSymbolsToDims() {
// Gather all symbols which are loop IVs.
SmallVector<Value, 4> loopIVs;
for (unsigned i = getNumDimIds(), e = getNumDimAndSymbolIds(); i < e; i++) {
if (hasValue(i) && getForInductionVarOwner(getValue(i)))
loopIVs.push_back(getValue(i));
}
// Turn each symbol in 'loopIVs' into a dim identifier.
for (auto iv : loopIVs) {
turnSymbolIntoDim(this, iv);
}
}
// Adds `val` to the system unless it is already present: a loop IV is
// appended as a dim together with the domain of its loop, anything else is
// appended as a symbol, with an equality bound if it is a constant index.
void FlatAffineValueConstraints::addInductionVarOrTerminalSymbol(Value val) {
  // Nothing to do for a Value the system already knows about.
  if (containsId(val))
    return;

  // Caller is expected to fully compose map/operands if necessary.
  assert((isTopLevelValue(val) || isForInductionVar(val)) &&
         "non-terminal symbol / loop IV expected");

  // Outer loop IVs could be used in forOp's bounds.
  auto loop = getForInductionVarOwner(val);
  if (loop) {
    appendDimId(val);
    if (failed(this->addAffineForOpDomain(loop))) {
      LLVM_DEBUG(
          loop.emitWarning("failed to add domain info to constraint system"));
    }
    return;
  }

  // Otherwise `val` is a top level symbol.
  appendSymbolId(val);

  // Pin the symbol to its value if it is defined by a constant index op.
  auto constOp = val.getDefiningOp<arith::ConstantIndexOp>();
  if (constOp)
    addBound(BoundType::EQ, val, constOp.value());
}
// Adds the constraints describing the iteration domain of `forOp` (stride,
// lower bound, upper bound) for its induction variable, which must already be
// an identifier of this system. Returns failure if the induction variable is
// not found or if a non-constant bound could not be added.
LogicalResult
FlatAffineValueConstraints::addAffineForOpDomain(AffineForOp forOp) {
  // Pre-condition for this method.
  unsigned pos;
  if (!findId(forOp.getInductionVar(), &pos)) {
    assert(false && "Value not found");
    return failure();
  }

  // Stride constraints, only needed for non-unit steps.
  int64_t step = forOp.getStep();
  if (step != 1) {
    if (forOp.hasConstantLowerBound()) {
      // (iv - lb) % step = 0 can be written as:
      //   (iv - lb) - step * q = 0 where q = (iv - lb) / step.
      // A local variable 'q' is added along with the above equality.
      int64_t lb = forOp.getConstantLowerBound();

      // First constraint: q = (iv - lb) floordiv step.
      SmallVector<int64_t, 8> dividend(getNumCols(), 0);
      dividend[pos] = 1;
      dividend.back() -= lb;
      addLocalFloorDiv(dividend, step);

      // Second constraint: (iv - lb) - step * q = 0. The row is sized after
      // the local was added, so 'q' is the column just before the constant.
      SmallVector<int64_t, 8> eq(getNumCols(), 0);
      eq[pos] = 1;
      eq.back() -= lb;
      eq[getNumCols() - 2] = -step;
      addEquality(eq);
    } else {
      LLVM_DEBUG(forOp.emitWarning("domain conservatively approximated"));
    }
  }

  // Lower bound: constant or given by a map.
  if (forOp.hasConstantLowerBound()) {
    addBound(BoundType::LB, pos, forOp.getConstantLowerBound());
  } else if (failed(addBound(BoundType::LB, pos, forOp.getLowerBoundMap(),
                             forOp.getLowerBoundOperands()))) {
    return failure();
  }

  // Upper bound: the non-constant case is given by a map.
  if (!forOp.hasConstantUpperBound())
    return addBound(BoundType::UB, pos, forOp.getUpperBoundMap(),
                    forOp.getUpperBoundOperands());

  // Constant case: the bound added is the constant upper bound minus one.
  addBound(BoundType::UB, pos, forOp.getConstantUpperBound() - 1);
  return success();
}
// Adds domain constraints for the slice described by `lbMaps` / `ubMaps`
// (one pair per dimension, both taking `operands` as inputs). A null map means
// no bound of that kind is added for the dimension. Returns failure if a bound
// cannot be added or an unsupported slice form is encountered.
LogicalResult
FlatAffineValueConstraints::addDomainFromSliceMaps(ArrayRef<AffineMap> lbMaps,
                                                   ArrayRef<AffineMap> ubMaps,
                                                   ArrayRef<Value> operands) {
  assert(lbMaps.size() == ubMaps.size());
  assert(lbMaps.size() <= getNumDimIds());

  const unsigned numSliceDims = lbMaps.size();
  for (unsigned dim = 0; dim < numSliceDims; ++dim) {
    AffineMap lbMap = lbMaps[dim];
    AffineMap ubMap = ubMaps[dim];
    assert(!lbMap || lbMap.getNumInputs() == operands.size());
    assert(!ubMap || ubMap.getNumInputs() == operands.size());

    // Is this slice just an equality along this dimension, i.e. ub == lb + 1
    // with single-result maps? That also holds for maps describing a single
    // iteration (e.g., lb result 0 and ub result 1), so constant lb results
    // are excluded.
    const bool isLoopEquality =
        lbMap && ubMap && lbMap.getNumResults() == 1 &&
        ubMap.getNumResults() == 1 &&
        lbMap.getResult(0) + 1 == ubMap.getResult(0) &&
        !lbMap.getResult(0).isa<AffineConstantExpr>();

    if (!isLoopEquality) {
      // This slice refers to a loop that doesn't exist in the IR yet. Add its
      // bounds to the system assuming its dimension identifier position is
      // the same as the position of the loop in the loop nest.
      if (lbMap && failed(addBound(BoundType::LB, dim, lbMap, operands)))
        return failure();
      if (ubMap && failed(addBound(BoundType::UB, dim, ubMap, operands)))
        return failure();
      continue;
    }

    // Retrieve the existing loop the slice equates to and add its domain.
    // Limited support: the lb result is expected to be just a loop dimension;
    // anything else is not supported for now.
    AffineDimExpr result = lbMap.getResult(0).dyn_cast<AffineDimExpr>();
    if (!result)
      return failure();

    AffineForOp loop = getForInductionVarOwner(operands[result.getPosition()]);
    if (!loop)
      return failure();

    if (failed(addAffineForOpDomain(loop)))
      return failure();
  }
  return success();
}
// Adds the constraints imposed by the condition of `ifOp` to this system.
void FlatAffineValueConstraints::addAffineIfOpDomain(AffineIfOp ifOp) {
  // Start from the constraints of the integer set attached to ifOp.
  FlatAffineValueConstraints ifDomain(ifOp.getIntegerSet());

  // Bind the dim and symbol ids of those constraints to the ifOp operands.
  SmallVector<Value, 4> ifOperands = ifOp.getOperands();
  const unsigned numBoundIds = ifDomain.getNumDimAndSymbolIds();
  ifDomain.setValues(0, numBoundIds, ifOperands);

  // Both systems must share the same identifiers before the rows of one can
  // be appended to the other, so merge and align them first.
  mergeAndAlignIdsWithOther(/*offset=*/0, &ifDomain);
  append(ifDomain);
}
// Returns true if the underlying polyhedron is in a consistent state and
// there is exactly one `values` entry per identifier.
bool FlatAffineValueConstraints::hasConsistentState() const {
  if (!IntegerPolyhedron::hasConsistentState())
    return false;
  return values.size() == getNumIds();
}
// Removes the identifiers of kind `kind` at kind-relative positions
// [idStart, idLimit), together with their `values` entries.
void FlatAffineValueConstraints::removeIdRange(IdKind kind, unsigned idStart,
                                               unsigned idLimit) {
  IntegerPolyhedron::removeIdRange(kind, idStart, idLimit);

  // `values` is indexed by absolute position, so shift the kind-relative
  // range by the offset of `kind`.
  const unsigned offset = getIdKindOffset(kind);
  auto first = values.begin() + idStart + offset;
  auto last = values.begin() + idLimit + offset;
  values.erase(first, last);
}
// Determine whether the identifier at 'pos' (say id_r) can be expressed as
// modulo of another known identifier (say id_n) w.r.t a constant. For example,
// if the following constraints hold true:
// ```
// 0 <= id_r <= divisor - 1
// id_n - (divisor * q_expr) = id_r
// ```
// where `id_n` is a known identifier (called dividend), and `q_expr` is an
// `AffineExpr` (called the quotient expression), `id_r` can be written as:
//
// `id_r = id_n mod divisor`.
//
// Additionally, in a special case of the above constraints where `q_expr` is an
// identifier itself that is not yet known (say `id_q`), it can be written as a
// floordiv in the following way:
//
// `id_q = id_n floordiv divisor`.
//
// Parameters:
//   cst      - constraint system whose equalities are inspected (read only).
//   pos      - position of `id_r` in `cst`.
//   lbConst  - constant lower bound of `id_r`; detection requires it to be 0.
//   ubConst  - constant upper bound of `id_r`; detection requires it to be at
//              least 1, and the divisor is taken to be `ubConst + 1`.
//   memo     - indexed by identifier position; a null entry means no
//              expression is known yet for that identifier.
//   context  - MLIR context used to build new affine expressions.
//
// Only dimension and symbol columns of an equality are examined; local
// columns are not looked at.
//
// Returns true if the above mod or floordiv are detected, updating 'memo' with
// these new expressions. Returns false otherwise.
static bool detectAsMod(const FlatAffineValueConstraints &cst, unsigned pos,
int64_t lbConst, int64_t ubConst,
SmallVectorImpl<AffineExpr> &memo,
MLIRContext *context) {
assert(pos < cst.getNumIds() && "invalid position");
// Check if a divisor satisfying the condition `0 <= id_r <= divisor - 1` can
// be determined.
if (lbConst != 0 || ubConst < 1)
return false;
int64_t divisor = ubConst + 1;
// Check for the aforementioned conditions in each equality.
for (unsigned curEquality = 0, numEqualities = cst.getNumEqualities();
curEquality < numEqualities; curEquality++) {
int64_t coefficientAtPos = cst.atEq(curEquality, pos);
// If current equality does not involve `id_r`, continue to the next
// equality.
if (coefficientAtPos == 0)
continue;
// Constant term should be 0 in this equality.
if (cst.atEq(curEquality, cst.getNumCols() - 1) != 0)
continue;
// Traverse through the equality and construct the dividend expression
// `dividendExpr`, to contain all the identifiers which are known and are
// not divisible by `(coefficientAtPos * divisor)`. Hope here is that the
// `dividendExpr` gets simplified into a single identifier `id_n` discussed
// above.
auto dividendExpr = getAffineConstantExpr(0, context);
// Track the terms that go into quotient expression, later used to detect
// additional floordiv.
unsigned quotientCount = 0;
int quotientPosition = -1;
int quotientSign = 1;
// Consider each term in the current equality.
unsigned curId, e;
for (curId = 0, e = cst.getNumDimAndSymbolIds(); curId < e; ++curId) {
// Ignore id_r.
if (curId == pos)
continue;
int64_t coefficientOfCurId = cst.atEq(curEquality, curId);
// Ignore ids that do not contribute to the current equality.
if (coefficientOfCurId == 0)
continue;
// Check if the current id goes into the quotient expression.
if (coefficientOfCurId % (divisor * coefficientAtPos) == 0) {
quotientCount++;
quotientPosition = curId;
quotientSign = (coefficientOfCurId * coefficientAtPos) > 0 ? 1 : -1;
continue;
}
// Identifiers that are part of dividendExpr should be known.
if (!memo[curId])
break;
// Append the current identifier to the dividend expression.
dividendExpr = dividendExpr + memo[curId] * coefficientOfCurId;
}
// Can't construct expression as it depends on a yet uncomputed id.
// (The loop above was left early through `break`.)
if (curId < e)
continue;
// Express `id_r` in terms of the other ids collected so far.
if (coefficientAtPos > 0)
dividendExpr = (-dividendExpr).floorDiv(coefficientAtPos);
else
dividendExpr = dividendExpr.floorDiv(-coefficientAtPos);
// Simplify the expression.
dividendExpr = simplifyAffineExpr(dividendExpr, cst.getNumDimIds(),
cst.getNumSymbolIds());
// Only if the final dividend expression is just a single id (which we call
// `id_n`), we can proceed.
// TODO: Handle AffineSymbolExpr as well. There is no reason to restrict it
// to dims themselves.
auto dimExpr = dividendExpr.dyn_cast<AffineDimExpr>();
if (!dimExpr)
continue;
// Express `id_r` as `id_n % divisor` and store the expression in `memo`.
// This requires at least one term to have gone into the quotient expression;
// otherwise the next equality is tried.
if (quotientCount >= 1) {
auto ub = cst.getConstantBound(FlatAffineValueConstraints::BoundType::UB,
dimExpr.getPosition());
// If `id_n` has an upperbound that is less than the divisor, mod can be
// eliminated altogether.
if (ub.hasValue() && ub.getValue() < divisor)
memo[pos] = dimExpr;
else
memo[pos] = dimExpr % divisor;
// If a unique quotient `id_q` was seen, it can be expressed as
// `id_n floordiv divisor`.
if (quotientCount == 1 && !memo[quotientPosition])
memo[quotientPosition] = dimExpr.floorDiv(divisor) * quotientSign;
return true;
}
}
return false;
}
/// Check if the pos^th identifier can be expressed as a floordiv of an affine
/// function of other identifiers (where the divisor is a positive constant)
/// given the initial set of expressions in `exprs`. If it can be, the
/// corresponding position in `exprs` is set as the detected affine expr. For
/// eg: 4q <= i + j <= 4q + 3 <=> q = (i + j) floordiv 4. An equality can
/// also yield a floordiv: eg. 4q = i + j <=> q = (i + j) floordiv 4. 32q + 28
/// <= i <= 32q + 31 => q = i floordiv 32.
static bool detectAsFloorDiv(const FlatAffineValueConstraints &cst,
unsigned pos, MLIRContext *context,
SmallVectorImpl<AffineExpr> &exprs) {
assert(pos < cst.getNumIds() && "invalid position");
// Get upper-lower bound pair for this variable.
SmallVector<bool, 8> foundRepr(cst.getNumIds(), false);
for (unsigned i = 0, e = cst.getNumIds(); i < e; ++i)
if (exprs[i])
foundRepr[i] = true;
SmallVector<int64_t, 8> dividend;
unsigned divisor;
auto ulPair = computeSingleVarRepr(cst, foundRepr, pos, dividend, divisor);
// No upper-lower bound pair found for this var.
if (ulPair.kind == ReprKind::None || ulPair.kind == ReprKind::Equality)
return false;
// Construct the dividend expression.
auto dividendExpr = getAffineConstantExpr(dividend.back(), context);
for (unsigned c = 0, f = cst.getNumIds(); c < f; c++)
if (dividend[c] != 0)
dividendExpr = dividendExpr + dividend[c] * exprs[c];
// Successfully detected the floordiv.
exprs[pos] = dividendExpr.floorDiv(divisor);
return true;
}
std::pair<AffineMap, AffineMap>
FlatAffineValueConstraints::getLowerAndUpperBound(
unsigned pos, unsigned offset, unsigned num, unsigned symStartPos,
ArrayRef<AffineExpr> localExprs, MLIRContext *context) const {
assert(pos + offset < getNumDimIds() && "invalid dim start pos");
assert(symStartPos >= (pos + offset) && "invalid sym start pos");
assert(getNumLocalIds() == localExprs.size() &&
"incorrect local exprs count");
SmallVector<unsigned, 4> lbIndices, ubIndices, eqIndices;
getLowerAndUpperBoundIndices(pos + offset, &lbIndices, &ubIndices, &eqIndices,
offset, num);
/// Add to 'b' from 'a' in set [0, offset) U [offset + num, symbStartPos).
auto addCoeffs = [&](ArrayRef<int64_t> a, SmallVectorImpl<int64_t> &b) {
b.clear();
for (unsigned i = 0, e = a.size(); i < e; ++i) {
if (i < offset || i >= offset + num)
b.push_back(a[i]);
}
};
SmallVector<int64_t, 8> lb, ub;
SmallVector<AffineExpr, 4> lbExprs;
unsigned dimCount = symStartPos - num;
unsigned symCount = getNumDimAndSymbolIds() - symStartPos;
lbExprs.reserve(lbIndices.size() + eqIndices.size());
// Lower bound expressions.
for (auto idx : lbIndices) {
auto ineq = getInequality(idx);
// Extract the lower bound (in terms of other coeff's + const), i.e., if
// i - j + 1 >= 0 is the constraint, 'pos' is for i the lower bound is j
// - 1.
addCoeffs(ineq, lb);
std::transform(lb.begin(), lb.end(), lb.begin(), std::negate<int64_t>());
auto expr =
getAffineExprFromFlatForm(lb, dimCount, symCount, localExprs, context);
// expr ceildiv divisor is (expr + divisor - 1) floordiv divisor
int64_t divisor = std::abs(ineq[pos + offset]);
expr = (expr + divisor - 1).floorDiv(divisor);
lbExprs.push_back(expr);
}
SmallVector<AffineExpr, 4> ubExprs;
ubExprs.reserve(ubIndices.size() + eqIndices.size());
// Upper bound expressions.
for (auto idx : ubIndices) {
auto ineq = getInequality(idx);
// Extract the upper bound (in terms of other coeff's + const).
addCoeffs(ineq, ub);
auto expr =
getAffineExprFromFlatForm(ub, dimCount, symCount, localExprs, context);
expr = expr.floorDiv(std::abs(ineq[pos + offset]));
// Upper bound is exclusive.
ubExprs.push_back(expr + 1);
}
// Equalities. It's both a lower and a upper bound.
SmallVector<int64_t, 4> b;
for (auto idx : eqIndices) {
auto eq = getEquality(idx);
addCoeffs(eq, b);
if (eq[pos + offset] > 0)
std::transform(b.begin(), b.end(), b.begin(), std::negate<int64_t>());
// Extract the upper bound (in terms of other coeff's + const).
auto expr =
getAffineExprFromFlatForm(b, dimCount, symCount, localExprs, context);
expr = expr.floorDiv(std::abs(eq[pos + offset]));
// Upper bound is exclusive.
ubExprs.push_back(expr + 1);
// Lower bound.
expr =
getAffineExprFromFlatForm(b, dimCount, symCount, localExprs, context);
expr = expr.ceilDiv(std::abs(eq[pos + offset]));
lbExprs.push_back(expr);
}
auto lbMap = AffineMap::get(dimCount, symCount, lbExprs, context);
auto ubMap = AffineMap::get(dimCount, symCount, ubExprs, context);
return {lbMap, ubMap};
}
/// Computes the lower and upper bounds of the first 'num' dimensional
/// identifiers (starting at 'offset') as affine maps of the remaining
/// identifiers (dimensional and symbolic identifiers). Local identifiers are
/// themselves explicitly computed as affine functions of other identifiers in
/// this process if needed.
void FlatAffineValueConstraints::getSliceBounds(
unsigned offset, unsigned num, MLIRContext *context,
SmallVectorImpl<AffineMap> *lbMaps, SmallVectorImpl<AffineMap> *ubMaps,
bool getClosedUB) {
assert(num < getNumDimIds() && "invalid range");
// Basic simplification.
normalizeConstraintsByGCD();
LLVM_DEBUG(llvm::dbgs() << "getSliceBounds for first " << num
<< " identifiers\n");
LLVM_DEBUG(dump());
// Record computed/detected identifiers.
SmallVector<AffineExpr, 8> memo(getNumIds());
// Initialize dimensional and symbolic identifiers.
for (unsigned i = 0, e = getNumDimIds(); i < e; i++) {
if (i < offset)
memo[i] = getAffineDimExpr(i, context);
else if (i >= offset + num)
memo[i] = getAffineDimExpr(i - num, context);
}
for (unsigned i = getNumDimIds(), e = getNumDimAndSymbolIds(); i < e; i++)
memo[i] = getAffineSymbolExpr(i - getNumDimIds(), context);
bool changed;
do {
changed = false;
// Identify yet unknown identifiers as constants or mod's / floordiv's of
// other identifiers if possible.
for (unsigned pos = 0; pos < getNumIds(); pos++) {
if (memo[pos])
continue;
auto lbConst = getConstantBound(BoundType::LB, pos);