forked from apache/tvm
-
Notifications
You must be signed in to change notification settings - Fork 30
/
feature.cc
executable file
·1651 lines (1436 loc) · 62.7 KB
/
feature.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file auto_scheduler/feature.cc
* \brief Feature extraction for the cost model
*/
#include <tvm/arith/analyzer.h>
#include <tvm/auto_scheduler/feature.h>
#include <tvm/auto_scheduler/measure.h>
#include <tvm/auto_scheduler/measure_record.h>
#include <tvm/runtime/registry.h>
#include <tvm/support/parallel_for.h>
#include <tvm/te/operation.h>
#include <tvm/te/schedule_pass.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/op_attr_types.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include <algorithm>
#include <cmath>
#include <numeric>
#include <unordered_map>
#include <vector>
#include "search_policy/utils.h"
#include "utils.h"
namespace tvm {
// Forward declaration of GetBinds. Per the original note, the definition lives in
// driver_api.cc (another translation unit) and is declared here instead of via a header,
// so this signature must be kept in sync with that definition by hand.
void GetBinds(const Array<te::Tensor>& args, bool compact,
const std::unordered_map<te::Tensor, tir::Buffer>& binds,
Map<te::Tensor, tir::Buffer>* out_binds, Array<ObjectRef>* out_arg_list);
} // namespace tvm
namespace tvm {
namespace auto_scheduler {
using namespace tvm::tir;
using arith::Analyzer;
using arith::ConstIntBound;

/*!
 * \brief Hash map keyed by a tir::Buffer, using ObjectHash / ObjectEqual as the
 *        hash and equality functors.
 */
template <typename T>
using BufferMap = std::unordered_map<Buffer, T, ObjectHash, ObjectEqual>;

/*!
 * \brief Number of points sampled from the arithmetic intensity curve; this is the
 *        length of FeatureSet::arith_intensity_curve.
 */
static constexpr int ARITH_INTENSITY_CURVE_SAMPLE_N = 10;
/*!
 * \brief Encodes where an annotated (vectorized / unrolled / parallel) iterator sits
 *        among the spatial and reduction iterators of a computation.
 */
enum class AnnotationPosType : int {
  /*! \brief No annotation of this kind is present. */
  kPosNone = 0,
  /*! \brief Innermost spatial iterator. */
  kPosInnerSpatial = 1,
  /*! \brief A spatial iterator that is neither innermost nor outermost. */
  kPosMiddleSpatial = 2,
  /*! \brief Outermost spatial iterator. */
  kPosOuterSpatial = 3,
  /*! \brief Innermost reduction iterator. */
  kPosInnerReduce = 4,
  /*! \brief A reduction iterator that is neither innermost nor outermost. */
  kPosMiddleReduce = 5,
  /*! \brief Outermost reduction iterator. */
  kPosOuterReduce = 6,
  /*! \brief An iterator fused from both spatial and reduction iterators. */
  kPosMixed = 7,
};
/*! \brief How a buffer is accessed by the statement being analyzed. */
enum class BufferAccessType : int {
  /*! \brief The buffer is only loaded. */
  kRead = 0,
  /*! \brief The buffer is only stored to. */
  kWrite = 1,
  /*! \brief The buffer is both loaded and stored to. */
  kReadWrite = 2,
  /*! \brief The access kind has not been determined yet. */
  kUnknownRW = 3,
};
// All accesses made to one buffer by a single statement.
struct BufferAccess {
  // The access type (read / write / read-write) observed for this buffer.
  // Starts as kUnknownRW until the first access is recorded.
  BufferAccessType acc_type{BufferAccessType::kUnknownRW};
  // Use a two-dimensional array to store multiple multi-dimensional accesses.
  // Each inner vector holds the multi-dimensional indices of one access.
  std::vector<std::vector<PrimExpr>> indices;
};
/*! \brief Kind of data reuse detected for a buffer access. */
enum class ReuseType : int {
  /*! \brief The same data is re-read across iterations of a loop whose variable does not
   *         appear in the access indices. */
  kLoopMultipleRead = 0,
  /*! \brief The buffer is accessed by several statements serially under the same loop. */
  kSerialMultipleReadWrite = 1,
  /*! \brief No reuse detected. */
  kNoReuse = 2,
};
// Features describing how one statement accesses one buffer.
struct BufferAccessFeature {
  std::string buffer_name;        // The name of the buffer
  BufferAccessType acc_type;      // The type of the access
  float bytes;                    // The touched memory in bytes
  float unique_bytes;             // The touched unique memory in bytes
  float lines;                    // The number of touched cache lines
  float unique_lines;             // The number of touched unique cache lines
  ReuseType reuse_type;           // The type of data reuse
  float reuse_dis_iter;           // The reuse distance in iterator number
  float reuse_dis_bytes;          // The reuse distance in total touched bytes
  float reuse_ct;                 // The reuse count (loop extent or number of serial re-accesses)
  float bytes_d_reuse_ct;         // bytes / reuse_ct
  float unique_bytes_d_reuse_ct;  // unique_bytes / reuse_ct
  float lines_d_reuse_ct;         // lines / reuse_ct
  float unique_lines_d_reuse_ct;  // unique_lines / reuse_ct
  float stride;                   // The stride in access
};
// Feature set of a BufferStore statement
struct FeatureSet {
  // Group 1: Computation related features
  float float_mad;         // The number of float MAD (Multiply–add) ops
  float float_addsub;      // The number of float add and sub ops
  float float_mul;         // The number of float multiply ops
  float float_divmod;      // The number of float div and mod ops
  float float_cmp;         // The number of float comparison ops
  float float_math_func;   // The number of float math func calls
  float float_other_func;  // The number of other float func calls
  float int_mad;           // The number of integer MAD (Multiply–add) ops
  float int_addsub;        // The number of integer add and sub ops
  float int_mul;           // The number of integer multiply ops
  float int_divmod;        // The number of integer div and mod ops
  float int_cmp;           // The number of integer comparison ops
  float int_math_func;     // The number of integer math func calls
  float int_other_func;    // The number of other integer func calls
  float bool_op;           // The number of bool ops
  float select_op;         // The number of select ops
  float vec_num;           // The number of vectorized iterators
  float vec_prod;          // The product of the lengths of vectorized iterators
  float vec_len;           // The length of the innermost vectorized iterator
  AnnotationPosType vec_type;  // The type of vectorization position
  float unroll_num;        // The number of unrolled iterators
  float unroll_prod;       // The product of the lengths of unrolled iterators
  float unroll_len;        // The length of the innermost unrolled iterator
  AnnotationPosType unroll_type;  // The type of unroll position
  float parallel_num;      // The number of paralleled iterators
  float parallel_prod;     // The product of the lengths of paralleled iterators
  float parallel_len;      // The length of the innermost paralleled iterator
  AnnotationPosType parallel_type;  // The type of parallel position
  float is_gpu;            // Whether it is a GPU task
  float blockIdx_x_len;    // The length of blockIdx.x
  float blockIdx_y_len;    // The length of blockIdx.y
  float blockIdx_z_len;    // The length of blockIdx.z
  float threadIdx_x_len;   // The length of threadIdx.x
  float threadIdx_y_len;   // The length of threadIdx.y
  float threadIdx_z_len;   // The length of threadIdx.z
  float vthread_len;       // The length of virtual thread

  // Group 2: Buffer access related features (per buffer)
  std::vector<BufferAccessFeature> access_feas;

  // Group 3: Arithmetic intensity related features
  // Points sampled from the arithmetic intensity curve
  float arith_intensity_curve[ARITH_INTENSITY_CURVE_SAMPLE_N];

  // Group 4: Allocation related features
  float alloc_size;        // The size of allocated buffer in bytes
  float alloc_outer_prod;  // The product of lengths of loops outside the scope of the allocation
  float alloc_inner_prod;  // The product of lengths of loops inside the scope of the allocation
  float alloc_prod;        // alloc_outer_prod * alloc_inner_prod

  // Group 5: Outer scope related features
  float outer_prod;            // The product of lengths of outer loops
  float num_loops;             // The number of outer loops
  float auto_unroll_max_step;  // The value of pragma "auto_unroll_max_step"
};
// Return whether a var is in an expr
bool VarInExpr(const Var& var, const PrimExpr& expr) {
bool find = false;
PostOrderVisit(expr, [&find, &var](const ObjectRef& node) {
if (find) {
return;
}
if (const VarNode* op = node.as<VarNode>()) {
if (op == var.get()) {
find = true;
}
}
});
return find;
}
/*!
 * \brief Classify the position of an annotated iterator `var`.
 *
 * - Referenced by more than one spatial arg: kPosMixed.
 * - Referenced by exactly one spatial arg: inner (last arg), outer (first arg) or middle.
 * - Referenced by no spatial arg: `var` is matched against the reduce-axis names by
 *   substring. The first reduce axis maps to inner, the last to outer, others to middle.
 * - Matched nowhere: assumed outer spatial. The stage is presumably compute_at'ed under
 *   this axis, which was simplified out.
 *
 * `axis` is currently unused.
 *
 * NOTE(review): spatial positions treat the LAST index as innermost while reduce positions
 * treat index 0 as innermost; confirm this asymmetry is intended.
 */
AnnotationPosType GetAnnotationPosEncoding(const Var& var, const Array<PrimExpr>& spatial_args,
                                           const Array<IterVar>& axis,
                                           const Array<IterVar>& reduce_axis) {
  size_t spatial_matches = 0;
  size_t last_spatial_match = 0;
  for (size_t pos = 0; pos < spatial_args.size(); ++pos) {
    if (VarInExpr(var, spatial_args[pos])) {
      last_spatial_match = pos;
      ++spatial_matches;
    }
  }

  if (spatial_matches > 1) {
    return AnnotationPosType::kPosMixed;
  }
  if (spatial_matches == 1) {
    if (last_spatial_match == spatial_args.size() - 1) {
      return AnnotationPosType::kPosInnerSpatial;
    }
    return last_spatial_match == 0 ? AnnotationPosType::kPosOuterSpatial
                                   : AnnotationPosType::kPosMiddleSpatial;
  }

  // Not a spatial iterator: try to identify it as a reduce iterator by name.
  const std::string& var_name = var->name_hint;
  bool reduce_matched = false;
  size_t last_reduce_match = 0;
  for (size_t pos = 0; pos < reduce_axis.size(); ++pos) {
    if (var_name.find(reduce_axis[pos]->var->name_hint) != std::string::npos) {
      last_reduce_match = pos;
      reduce_matched = true;
    }
  }

  if (!reduce_matched) {
    return AnnotationPosType::kPosOuterSpatial;
  }
  if (last_reduce_match == 0) {
    return AnnotationPosType::kPosInnerReduce;
  }
  return last_reduce_match == reduce_axis.size() - 1 ? AnnotationPosType::kPosOuterReduce
                                                     : AnnotationPosType::kPosMiddleReduce;
}
/*!
 * \brief Extent of a for loop.
 * \return The constant extent when it is an IntImm, otherwise 1 (a non-constant extent is
 *         treated as a single iteration).
 */
int64_t GetLoopExtent(const ForNode* node) {
  const auto* constant_extent = node->extent.as<IntImmNode>();
  return constant_extent == nullptr ? 1 : constant_extent->value;
}
// Count math ops in an expr.
// Binary arithmetic/comparison ops are classified as float or integer by the dtype of their
// first operand. Calls are classified by the call's result dtype, and as "math func" when the
// op's registered TCallEffectKind is kPure or kExprAnnotation, "other func" otherwise.
// Counts are for a single evaluation of the expression; loop scaling is done by the caller.
class MathOpCounter : public StmtExprVisitor {
 public:
#define VisitBinary(Type, float_ct, int_ct) \
  void VisitExpr_(const Type* op) final {   \
    if (op->a.dtype().is_float()) {         \
      float_ct++;                           \
    } else {                                \
      int_ct++;                             \
    }                                       \
    StmtExprVisitor::VisitExpr_(op);        \
  }

  VisitBinary(AddNode, float_addsub, int_addsub);
  VisitBinary(SubNode, float_addsub, int_addsub);
  VisitBinary(MulNode, float_mul, int_mul);
  VisitBinary(DivNode, float_divmod, int_divmod);
  VisitBinary(ModNode, float_divmod, int_divmod);
  VisitBinary(FloorDivNode, float_divmod, int_divmod);
  VisitBinary(FloorModNode, float_divmod, int_divmod);
  VisitBinary(MaxNode, float_cmp, int_cmp);
  VisitBinary(MinNode, float_cmp, int_cmp);
  VisitBinary(EQNode, float_cmp, int_cmp);
  VisitBinary(NENode, float_cmp, int_cmp);
  VisitBinary(LTNode, float_cmp, int_cmp);
  VisitBinary(LENode, float_cmp, int_cmp);
  VisitBinary(GTNode, float_cmp, int_cmp);
  VisitBinary(GENode, float_cmp, int_cmp);

#undef VisitBinary

  void VisitExpr_(const AndNode* op) final {
    bool_op++;
    StmtExprVisitor::VisitExpr_(op);
  }
  void VisitExpr_(const OrNode* op) final {
    bool_op++;
    StmtExprVisitor::VisitExpr_(op);
  }
  void VisitExpr_(const NotNode* op) final {
    bool_op++;
    StmtExprVisitor::VisitExpr_(op);
  }
  void VisitExpr_(const SelectNode* op) final {
    select_op++;
    StmtExprVisitor::VisitExpr_(op);
  }

  void VisitExpr_(const CallNode* op) final {
    auto* pop = op->op.as<OpNode>();
    ICHECK(pop != nullptr);
    auto effect_kind = op_call_effect_[GetRef<Op>(pop)];
    bool is_pure =
        effect_kind == CallEffectKind::kPure || effect_kind == CallEffectKind::kExprAnnotation;

    if (is_pure) {
      if (op->dtype.is_float()) {
        float_math_func++;
      } else {
        int_math_func++;
      }
    } else {
      if (op->dtype.is_float()) {
        float_other_func++;
      } else {
        int_other_func++;
      }
    }
    StmtExprVisitor::VisitExpr_(op);
  }

  // todo(merrymercy): Detect MAD (Multiply–add); float_mad / int_mad are never incremented.
  size_t float_mad{0};         // The number of float MAD (Multiply–add) ops
  size_t float_addsub{0};      // The number of float add and sub ops
  size_t float_mul{0};         // The number of float multiply ops
  size_t float_divmod{0};      // The number of float div and mod ops
  size_t float_cmp{0};         // The number of float comparison ops
  size_t float_math_func{0};   // The number of float math func calls
  size_t float_other_func{0};  // The number of other float func calls
  size_t int_mad{0};           // The number of integer MAD (Multiply–add) ops
  size_t int_addsub{0};        // The number of integer add and sub ops
  size_t int_mul{0};           // The number of integer multiply ops
  size_t int_divmod{0};        // The number of integer div and mod ops
  size_t int_cmp{0};           // The number of integer comparison ops
  size_t int_math_func{0};     // The number of integer math func calls
  size_t int_other_func{0};    // The number of other integer func calls
  size_t bool_op{0};           // The number of bool ops
  size_t select_op{0};         // The number of select ops

  OpAttrMap<TCallEffectKind> op_call_effect_ = Op::GetAttrMap<TCallEffectKind>("TCallEffectKind");
};
// Extract all buffer accesses in an expr
class BufferAccessExtractor : public StmtExprVisitor {
public:
void ExtractReads(const PrimExpr& expr) { this->VisitExpr(expr); }
void InsertAccess(const Buffer& buf, BufferAccessType acc_type, const Array<PrimExpr>& indices) {
BufferAccess& acc = buf_accesses[buf];
acc.acc_type = acc_type;
acc.indices.push_back(std::vector<PrimExpr>(indices.begin(), indices.end()));
}
void VisitExpr_(const BufferLoadNode* op) final {
BufferAccess& acc = buf_accesses[op->buffer];
switch (acc.acc_type) {
case BufferAccessType::kRead:
break;
case BufferAccessType::kWrite:
acc.acc_type = BufferAccessType::kReadWrite;
break;
case BufferAccessType::kReadWrite:
break;
case BufferAccessType::kUnknownRW:
default:
acc.acc_type = BufferAccessType::kRead;
break;
}
if (acc.acc_type != BufferAccessType::kReadWrite) {
// If a buffer is both read and written, in the tvm DSL, it must be a update,
// so the indices should be the same. Then we can skip appending indices for it.
// Otherwise we do the following.
buf_accesses[op->buffer].indices.push_back(
std::vector<PrimExpr>(op->indices.begin(), op->indices.end()));
}
StmtExprVisitor::VisitExpr_(op);
}
BufferMap<BufferAccess> buf_accesses;
};
// Compute the coefficient for a loop iterator in an expression
// Note: we use an approximation strategy to find coefficient.
// Hopefully, it is faster than DetectLinearEquation and can handle more cases (non-linear)
//
// How it works (the children of a node are visited before the node itself):
//  * Visiting the target var sets visited_var and a fallback stride of 2.
//  * A Mul seen after the var (and before any Add) with an IntImm operand records that
//    constant as the stride and sets visited_mul.
//  * An Add seen after the var (and before any Mul) sets stride = 1 and visited_add.
// ExtractCoefficient resets the three flags but NOT `stride`. When the var is never visited,
// the returned value is stale from a previous call, so callers must check `visited_var`
// (ComputeStride does).
class CoefficientExtractor : public StmtExprVisitor {
public:
// Post-order: children first, then try to take a constant multiplier as the stride.
void VisitExpr_(const MulNode* node) final {
StmtExprVisitor::VisitExpr_(node);
if (visited_var) {
if (!visited_add) {
if (auto a = node->a.as<IntImmNode>()) {
visited_mul = true;
stride = a->value;
} else if (auto b = node->b.as<IntImmNode>()) {
visited_mul = true;
stride = b->value;
}
}
}
}
// Post-order: an Add above the var (with no Mul in between) means a unit coefficient.
void VisitExpr_(const AddNode* node) final {
StmtExprVisitor::VisitExpr_(node);
if (visited_var) {
if (!visited_mul) {
visited_add = true;
stride = 1;
}
}
}
// Marks the target var as found; other vars are ignored.
void VisitExpr_(const VarNode* node) final {
if (node == var_) {
visited_var = true;
// This is a magic default stride in case our approximation strategy fails
stride = 2;
}
}
// Returns the approximate coefficient of `var` in `expr`: 1 if the var appears with no
// enclosing Mul/Add, otherwise the stride chosen by the visitors above.
// Only meaningful when `visited_var` is true after the call.
int ExtractCoefficient(const PrimExpr& expr, const VarNode* var) {
visited_var = visited_mul = visited_add = false;
var_ = var;
this->VisitExpr(expr);
if (visited_var && !visited_mul && !visited_add) {
return 1;
} else {
return stride;
}
}
bool visited_var{false};  // Whether the target var was found in the last expression
bool visited_mul{false};  // Whether a constant multiplier was applied to the var
bool visited_add{false};  // Whether an Add enclosed the var before any Mul
int stride{0};            // Last computed coefficient (not reset between calls)
private:
const VarNode* var_{nullptr};  // The var whose coefficient is being extracted
};
// Compute stride for the accesses to a buffer
int64_t ComputeStride(const std::vector<std::vector<PrimExpr>>& indices,
const std::vector<int>& shape, const VarNode* stride_var) {
int64_t min_stride = std::numeric_limits<int64_t>::max();
bool find = false;
CoefficientExtractor extractor;
for (const auto& index : indices) {
int64_t shape_stride = 1;
for (int i = static_cast<int>(index.size()) - 1; i >= 0; i--) {
int coefficient = extractor.ExtractCoefficient(index[i], stride_var);
if (extractor.visited_var) {
find = true;
min_stride = std::min(min_stride, std::abs(coefficient) * shape_stride);
break;
}
shape_stride *= shape[i];
}
}
return find ? min_stride : 0;
}
// Compute touched bytes and cache lines for accesses to a buffer
void ComputeRegion(const std::vector<std::vector<PrimExpr>>& indices, arith::Analyzer* ana,
std::vector<int>* region) {
region->clear();
if (indices.empty()) {
return;
}
region->reserve(indices[0].size());
if (indices.size() == 1) {
for (const auto& index : indices[0]) {
ConstIntBound bound = ana->const_int_bound(index);
region->push_back(bound->max_value - bound->min_value + 1);
}
} else {
// future(lmzheng): implement a more accurate IntSet?
for (size_t i = 0; i < indices[0].size(); ++i) {
int64_t minimum = ConstIntBound::kPosInf, maximum = ConstIntBound::kNegInf;
for (size_t j = 0; j < indices.size(); ++j) {
ConstIntBound bound = ana->const_int_bound(indices[j][i]);
minimum = std::min(minimum, bound->min_value);
maximum = std::max(maximum, bound->max_value);
}
region->push_back(maximum - minimum + 1);
}
}
}
/*!
 * \brief Detect data reuse for the accesses to `buf`, scanning loops from innermost outward.
 *
 * - The first loop (innermost first) whose variable does not appear in any access index
 *   gives kLoopMultipleRead reuse. The reuse count is that loop's extent.
 * - Otherwise, if under a loop `buf` has more than one recorded access in
 *   `for_touch_regions`, that gives kSerialMultipleReadWrite reuse. The reuse count is the
 *   number of extra accesses.
 * - Otherwise: kNoReuse with zeros.
 *
 * \return (reuse_type, reuse_dis_iter, reuse_dis_bytes, reuse_ct)
 */
std::tuple<ReuseType, float, float, float> ComputeReuse(
    const Buffer& buf, const std::vector<std::vector<PrimExpr>>& indices,
    const std::vector<const ForNode*>& for_loop_stack,
    const std::unordered_map<const ForNode*,
                             BufferMap<std::vector<std::tuple<BufferAccessType, int64_t, int>>>>&
        for_touch_regions) {
  // True if any index of any access references `loop_var`.
  auto indices_use_var = [&indices](const Var& loop_var) {
    for (const auto& access : indices) {
      for (const auto& idx : access) {
        if (VarInExpr(loop_var, idx)) {
          return true;
        }
      }
    }
    return false;
  };

  // Bytes touched by all accesses recorded under `loop`. When `weight_by_size` is false
  // every access counts as a single element.
  auto touched_bytes_under = [&for_touch_regions](const ForNode* loop, bool weight_by_size) {
    float total = 0.0f;
    for (const auto& buf_and_accesses : for_touch_regions.at(loop)) {
      for (const auto& access : buf_and_accesses.second) {
        const int64_t elems = weight_by_size ? std::get<1>(access) : 1;
        total += elems * std::get<2>(access);
      }
    }
    return total;
  };

  float reuse_dis_iter = 1.0f;
  float reuse_dis_bytes = -1.0f;

  for (auto it = for_loop_stack.rbegin(); it != for_loop_stack.rend(); ++it) {
    const ForNode* cur_for = *it;
    const int64_t extent = GetLoopExtent(cur_for);

    if (!indices_use_var(cur_for->loop_var)) {
      // The indices do not depend on this loop: the same data is re-read on each iteration.
      if (reuse_dis_bytes < 0) {
        // Reuse at the innermost loop: no enclosing iteration has set the byte distance yet.
        reuse_dis_bytes = touched_bytes_under(cur_for, false);
      }
      return std::make_tuple(ReuseType::kLoopMultipleRead, reuse_dis_iter, reuse_dis_bytes,
                             extent);
    }

    // The loop variable is used by the access: grow the reuse distance.
    reuse_dis_iter *= extent;
    reuse_dis_bytes = touched_bytes_under(cur_for, true);

    const auto& accesses_of_buf = for_touch_regions.at(cur_for).at(buf);
    const int serial_reuse = static_cast<int>(accesses_of_buf.size()) - 1;
    if (serial_reuse > 0) {
      float min_touched = std::numeric_limits<float>::max();
      for (const auto& acc_info : accesses_of_buf) {
        min_touched = std::min(min_touched, static_cast<float>(std::get<1>(acc_info)));
      }
      return std::make_tuple(ReuseType::kSerialMultipleReadWrite, min_touched / extent,
                             reuse_dis_bytes / extent, serial_reuse);
    }
  }

  return std::make_tuple(ReuseType::kNoReuse, 0, 0, 0);
}
// Extract features for every BufferStore statement
class PerStoreFeatureExtractor : public StmtExprVisitor {
public:
// `cache_line_size` is the cache line size in bytes. It is used when estimating the number
// of touched cache lines in ExtractBufferAccessFeature.
explicit PerStoreFeatureExtractor(int cache_line_size) : cache_line_size_(cache_line_size) {}
void VisitStmt_(const AttrStmtNode* node) final {
if (node->attr_key == tir::attr::thread_extent || node->attr_key == tir::attr::virtual_thread) {
const Var& var = node->node.as<IterVarNode>()->var;
int extent = GetIntImm(node->value);
int* plen = nullptr;
const std::string& name = var.get()->name_hint;
if (node->attr_key == tir::attr::thread_extent) {
if (name == "blockIdx.x") {
plen = &blockIdx_x_len_;
} else if (name == "blockIdx.y") {
plen = &block_idx_y_len_;
} else if (name == "blockIdx.z") {
plen = &block_idx_z_len_;
} else if (name == "threadIdx.x") {
plen = &threadIdx_x_len_;
} else if (name == "threadIdx.y") {
plen = &thread_idx_y_len_;
} else if (name == "threadIdx.z") {
plen = &thread_idx_z_len_;
} else {
LOG(FATAL) << "invalid thread itervar " + name;
}
} else {
plen = &vthread_len_;
}
int extent_before = *plen;
if (node->attr_key == tir::attr::thread_extent) {
*plen = extent;
} else {
*plen *= extent;
}
is_gpu_ = true;
// make a fake for node for blockIdx.x or threadIdx.x
Stmt fake_for_node = For(var, 0, extent, ForType::Parallel, DeviceAPI::None, node->body);
outer_loop_prod_ *= extent;
for_loop_stack_.push_back(fake_for_node.as<ForNode>());
StmtExprVisitor::VisitStmt_(node);
for_loop_stack_.pop_back();
outer_loop_prod_ /= extent;
*plen = extent_before;
} else if (node->attr_key == "pragma_auto_unroll_max_step") {
int value = GetIntImm(node->value);
int16_t old_value = cur_auto_unroll_max_step_;
cur_auto_unroll_max_step_ = value;
StmtExprVisitor::VisitStmt_(node);
cur_auto_unroll_max_step_ = old_value;
} else {
StmtExprVisitor::VisitStmt_(node);
}
}
void VisitStmt_(const ForNode* node) final {
int64_t loop_extent = GetLoopExtent(node);
if (node->for_type == ForType::Vectorized) {
vec_for_stack_.push_back(node);
} else if (node->for_type == ForType::Unrolled) {
unroll_for_stack_.push_back(node);
} else if (node->for_type == ForType::Parallel) {
parallel_for_stack_.push_back(node);
}
outer_loop_prod_ *= loop_extent;
for_loop_stack_.push_back(node);
StmtExprVisitor::VisitStmt_(node);
for_loop_stack_.pop_back();
outer_loop_prod_ /= loop_extent;
if (node->for_type == ForType::Vectorized) {
vec_for_stack_.pop_back();
} else if (node->for_type == ForType::Unrolled) {
unroll_for_stack_.pop_back();
} else if (node->for_type == ForType::Parallel) {
parallel_for_stack_.pop_back();
}
}
// Extracts the per-store feature groups 1, 2, 3 and 5 for the buffer written by `node`.
// Group 4 (allocation) is extracted when visiting BufferRealize nodes.
// The stored value is not visited further here.
void VisitStmt_(const BufferStoreNode* node) final {
  MathOpCounter math_op_counter;
  math_op_counter(node->value);
  std::vector<float> mem_bytes_list;
  std::vector<float> compute_ops_list;
  // Filled in by ExtractBufferAccessFeature; initialized so it is never read uninitialized.
  double cur_compute_ops = 0.0;

  // Group 1: Computation related features
  ExtractComputationFeature(node, math_op_counter);

  // Group 2: Buffer access related features (per buffer)
  ExtractBufferAccessFeature(node, math_op_counter, &cur_compute_ops, &compute_ops_list,
                             &mem_bytes_list);

  // Group 3: Arithmetic intensity related features
  ExtractArithmeticIntensityFeature(node, cur_compute_ops, compute_ops_list, mem_bytes_list);

  // Group 5: Outer scope related features
  ExtractOuterScopeFeature(node);
}
// Visits the realize body first, then records allocation features for the realized buffer.
void VisitStmt_(const BufferRealizeNode* node) final {
  StmtExprVisitor::VisitStmt_(node);

  // Group 4: Allocation related features
  ExtractAllocationFeature(node);
}
// Extract computation related features (group 1)
void ExtractComputationFeature(const BufferStoreNode* node,
const MathOpCounter& math_op_counter) {
FeatureSet& fea = buffer_features[node->buffer];
// Computation related features
fea.float_mad = outer_loop_prod_ * math_op_counter.float_mad;
fea.float_addsub = outer_loop_prod_ * math_op_counter.float_addsub;
fea.float_mul = outer_loop_prod_ * math_op_counter.float_mul;
fea.float_divmod = outer_loop_prod_ * math_op_counter.float_divmod;
fea.float_cmp = outer_loop_prod_ * math_op_counter.float_cmp;
fea.float_math_func = outer_loop_prod_ * math_op_counter.float_math_func;
fea.float_other_func = outer_loop_prod_ * math_op_counter.float_other_func;
fea.int_mad = outer_loop_prod_ * math_op_counter.int_mad;
fea.int_addsub = outer_loop_prod_ * math_op_counter.int_addsub;
fea.int_mul = outer_loop_prod_ * math_op_counter.int_mul;
fea.int_divmod = outer_loop_prod_ * math_op_counter.int_divmod;
fea.int_math_func = outer_loop_prod_ * math_op_counter.int_math_func;
fea.int_cmp = outer_loop_prod_ * math_op_counter.int_cmp;
fea.int_other_func = outer_loop_prod_ * math_op_counter.int_other_func;
fea.bool_op = outer_loop_prod_ * math_op_counter.bool_op;
fea.select_op = outer_loop_prod_ * math_op_counter.select_op;
fea.vec_len = fea.unroll_len = fea.parallel_len = 0.0f;
fea.vec_type = fea.unroll_type = fea.parallel_type = AnnotationPosType::kPosNone;
fea.vec_num = vec_for_stack_.size();
if (!vec_for_stack_.empty()) {
fea.vec_len = GetLoopExtent(vec_for_stack_.back());
fea.vec_prod = 1.0;
for (const ForNode* pfor : vec_for_stack_) {
fea.vec_prod *= GetLoopExtent(pfor);
}
fea.vec_type = AnnotationPosType::kPosMixed;
// todo(merrymercy): this feature requires operation (tvm.compute) information
// GetAnnotationPosEncoding(vec_for_stack_.back()->loop_var,
// node->args, pcompute->axis, pcompute->reduce_axis);
}
fea.unroll_num = unroll_for_stack_.size();
if (!unroll_for_stack_.empty()) {
fea.unroll_len = GetLoopExtent(unroll_for_stack_.back());
fea.unroll_prod = 1.0;
for (const ForNode* pfor : unroll_for_stack_) {
fea.unroll_prod *= GetLoopExtent(pfor);
}
fea.unroll_type = AnnotationPosType::kPosMixed;
// GetAnnotationPosEncoding(unroll_for_stack_.back()->loop_var,
// node->args, pcompute->axis, pcompute->reduce_axis);
}
fea.parallel_num = parallel_for_stack_.size();
if (!parallel_for_stack_.empty()) {
fea.parallel_len = GetLoopExtent(parallel_for_stack_.back());
fea.parallel_prod = 1.0;
for (const ForNode* pfor : parallel_for_stack_) {
fea.parallel_prod *= GetLoopExtent(pfor);
}
fea.parallel_type = AnnotationPosType::kPosMixed;
// GetAnnotationPosEncoding(parallel_for_stack_.back()->loop_var,
// node->args, pcompute->axis, pcompute->reduce_axis);
}
// GPU threads
fea.is_gpu = is_gpu_;
fea.blockIdx_x_len = blockIdx_x_len_;
fea.blockIdx_y_len = block_idx_y_len_;
fea.blockIdx_z_len = block_idx_z_len_;
fea.threadIdx_x_len = threadIdx_x_len_;
fea.threadIdx_y_len = thread_idx_y_len_;
fea.threadIdx_z_len = thread_idx_z_len_;
fea.vthread_len = vthread_len_;
}
// Extract buffer access related features (group 2)
void ExtractBufferAccessFeature(const BufferStoreNode* node, const MathOpCounter& math_op_counter,
double* cur_compute_ops, std::vector<float>* compute_ops_list,
std::vector<float>* mem_bytes_list) {
FeatureSet& fea = buffer_features[node->buffer];
// Extract all buffer accesses
std::vector<BufferAccessFeature> acc_feas;
BufferAccessExtractor buf_extractor;
buf_extractor.InsertAccess(node->buffer, BufferAccessType::kWrite, node->indices);
buf_extractor.ExtractReads(node->value);
// Compute touched region for all outer loops
for (auto x : for_loop_stack_) {
ana_.Bind(x->loop_var, Range::FromMinExtent(x->min, 1), true);
}
mem_bytes_list->reserve(for_loop_stack_.size());
compute_ops_list->reserve(for_loop_stack_.size());
*cur_compute_ops = math_op_counter.float_mad + math_op_counter.float_addsub +
math_op_counter.float_mul + math_op_counter.float_divmod +
math_op_counter.float_cmp + math_op_counter.float_math_func +
math_op_counter.float_other_func;
std::vector<int> tmp_region;
for (int i = static_cast<int>(for_loop_stack_.size()) - 1; i >= 0; i--) {
const ForNode* p_for = for_loop_stack_[i];
ana_.Bind(p_for->loop_var,
Range::FromMinExtent(for_loop_stack_[i]->min, for_loop_stack_[i]->extent), true);
// Note, here we do overwrite.
// So if there are multiple BufferStoreNode, the last one will overwrite the first few.
// e.g. The update part in gemm will overwrite the init part.
BufferMap<std::vector<std::tuple<BufferAccessType, int64_t, int>>>& buffer_regions_map =
for_touch_regions_[p_for];
int64_t mem_bytes = 0;
for (const auto& x : buf_extractor.buf_accesses) {
const Buffer& t = x.first;
const BufferAccess& acc = x.second;
ComputeRegion(acc.indices, &ana_, &tmp_region);
int64_t touched_size = ElementProduct(tmp_region);
buffer_regions_map[t].push_back(
std::make_tuple(acc.acc_type, touched_size, t->dtype.bytes()));
mem_bytes += touched_size * t->dtype.bytes();
}
mem_bytes_list->push_back(std::log2(mem_bytes));
*cur_compute_ops *= GetLoopExtent(for_loop_stack_[i]);
compute_ops_list->push_back(std::log2(*cur_compute_ops));
}
// Buffer access related features (per buffer)
for (const auto& x : buf_extractor.buf_accesses) {
const Buffer& t = x.first;
const BufferAccess& acc = x.second;
std::vector<int> int_shape;
for (const auto& dim : t->shape) {
int_shape.push_back(GetIntImm(dim));
}
size_t ele_bytes = t->dtype.bytes();
// calculate bytes
float bytes = outer_loop_prod_ * ele_bytes;
float unique_bytes;
// calculate cache lines
int64_t stride;
float lines;
float unique_lines;
if (for_loop_stack_.empty()) {
unique_bytes = ele_bytes;
stride = 0;
lines = 1.0f;
unique_lines = 1.0f;
} else {
unique_bytes =
std::get<1>(for_touch_regions_[for_loop_stack_.front()][t].front()) * ele_bytes;
stride = 0;
int64_t reduce_ratio = 1;
int i;
for (i = static_cast<int>(for_loop_stack_.size()) - 1; i >= 0; i--) {
stride = ComputeStride(acc.indices, int_shape, for_loop_stack_[i]->loop_var.get());
if (stride != 0) {
break;
}
reduce_ratio *= GetLoopExtent(for_loop_stack_.back());
}
lines = outer_loop_prod_ / reduce_ratio *
std::min(1.0f, 1.0f * stride * ele_bytes / cache_line_size_);
lines = std::max(lines, 1.0f);
// convert `stride` back to the stride of the innermost iterator
stride = (i == static_cast<int>(for_loop_stack_.size()) - 1 ? stride : 0);
float n_continuous = ele_bytes;
for (int i = std::min(static_cast<int>(tmp_region.size()) - 1,
static_cast<int>(int_shape.size()) - 1);
i >= 0; i--) {
if (tmp_region[i] == int_shape[i]) {
n_continuous *= tmp_region[i];
break;
}
}
unique_lines = unique_bytes / std::min(n_continuous, static_cast<float>(cache_line_size_));
unique_lines = std::max(unique_lines, 1.0f);
}
ReuseType reuse_type;
float reuse_dis_iter, reuse_dis_bytes, reuse_ct;
std::tie(reuse_type, reuse_dis_iter, reuse_dis_bytes, reuse_ct) =
ComputeReuse(t, acc.indices, for_loop_stack_, for_touch_regions_);
acc_feas.emplace_back();
BufferAccessFeature& acc_fea = acc_feas.back();
acc_fea.buffer_name = t->name;
acc_fea.acc_type = acc.acc_type;
acc_fea.stride = stride;
acc_fea.bytes = bytes;
acc_fea.unique_bytes = unique_bytes;
acc_fea.lines = lines;
acc_fea.unique_lines = unique_lines;
acc_fea.reuse_type = reuse_type;
acc_fea.reuse_dis_iter = reuse_dis_iter;
acc_fea.reuse_dis_bytes = reuse_dis_bytes;
acc_fea.reuse_ct = reuse_ct;
if (acc_fea.reuse_ct > 0.5) {
acc_fea.bytes_d_reuse_ct = bytes / reuse_ct;
acc_fea.unique_bytes_d_reuse_ct = unique_bytes / reuse_ct;
acc_fea.lines_d_reuse_ct = lines / reuse_ct;
acc_fea.unique_lines_d_reuse_ct = unique_lines / reuse_ct;
} else {
// no reuse, multiply by a magic number '2'
acc_fea.bytes_d_reuse_ct = bytes * 2;
acc_fea.unique_bytes_d_reuse_ct = unique_bytes * 2;
acc_fea.lines_d_reuse_ct = lines * 2;
acc_fea.unique_lines_d_reuse_ct = unique_lines * 2;
}
}
fea.access_feas = acc_feas;
}
// Extract arithmetic intensity related feature (group 3).
//
// Samples the arithmetic-intensity curve (y axis: compute / memory bytes, x axis: compute)
// at ARITH_INTENSITY_CURVE_SAMPLE_N evenly spaced x positions in (0, compute_ops_list.back()]
// using piecewise linear interpolation between the given points, and stores the samples in
// buffer_features[node->buffer].arith_intensity_curve.
//
// \param node The store whose buffer's FeatureSet is updated.
// \param cur_compute_ops Compute-op count of the store; when it is <= 0 (or when there are
//        no points) the whole curve is zero-filled.
// \param compute_ops_list x coordinates of the curve points. The scan below only moves
//        forward, so these are assumed to be non-decreasing — TODO confirm with the caller.
// \param mem_bytes_list Memory-byte values parallel to compute_ops_list; indexed with the
//        same index, so it must be at least as long (checked below).
void ExtractArithmeticIntensityFeature(const BufferStoreNode* node, double cur_compute_ops,
                                       const std::vector<float>& compute_ops_list,
                                       const std::vector<float>& mem_bytes_list) {
  FeatureSet& fea = buffer_features[node->buffer];
  if (cur_compute_ops <= 0 || compute_ops_list.empty()) {
    std::fill(fea.arith_intensity_curve,
              fea.arith_intensity_curve + ARITH_INTENSITY_CURVE_SAMPLE_N, 0.0);
    return;
  }
  const size_t n_points = compute_ops_list.size();
  // Index of the first curve point at or beyond the current sample position. It is carried
  // across iterations and only ever advances.
  size_t pt = 0;
  for (size_t i = 0; i < ARITH_INTENSITY_CURVE_SAMPLE_N; ++i) {
    // x position of the i-th sample (deliberately not named `cur_compute_ops`, which would
    // shadow the parameter).
    float sample_ops = compute_ops_list.back() * (i + 1) / ARITH_INTENSITY_CURVE_SAMPLE_N;
    // The 1e-4 slack presumably absorbs float rounding in sample_ops. The `pt < n_points`
    // bound keeps the scan inside the vector so the ICHECK below reports a miss instead of
    // the loop reading past the end.
    while (pt < n_points && compute_ops_list[pt] < sample_ops - 1e-4) {
      ++pt;
    }
    ICHECK_LT(pt, n_points);
    ICHECK_LT(pt, mem_bytes_list.size());
    float value;
    if (pt == 0) {
      value = compute_ops_list[pt] / mem_bytes_list[pt];
    } else {
      // Linear interpolation between curve points pt - 1 and pt.
      float base = compute_ops_list[pt - 1] / mem_bytes_list[pt - 1];
      float slope = (compute_ops_list[pt] / mem_bytes_list[pt] - base) /
                    (compute_ops_list[pt] - compute_ops_list[pt - 1]);
      value = base + slope * (sample_ops - compute_ops_list[pt - 1]);
    }
    fea.arith_intensity_curve[i] = value;
  }
}
// Extract allocation related features (group 4)
void ExtractAllocationFeature(const BufferRealizeNode* node) {
FeatureSet& fea = buffer_features[node->buffer];
float allocation_size = 1.0f;
for (const auto& x : node->bounds) {
allocation_size *= GetIntImm(x->extent);
}
// allocation feature
fea.alloc_size = allocation_size * node->buffer->dtype.bytes();
fea.alloc_prod = allocation_size * outer_loop_prod_;
fea.alloc_outer_prod = outer_loop_prod_;
fea.alloc_inner_prod = fea.outer_prod / outer_loop_prod_;
}
// Extract outer scope related features (group 5)
void ExtractOuterScopeFeature(const BufferStoreNode* node) {
FeatureSet& fea = buffer_features[node->buffer];
fea.outer_prod = outer_loop_prod_;
fea.num_loops = for_loop_stack_.size();
fea.auto_unroll_max_step = cur_auto_unroll_max_step_;
}
// Stores FeatureSet for every buffer
BufferMap<FeatureSet> buffer_features;
private:
// The shared arithmetic analyzer
Analyzer ana_;
// The product of outer loop
float outer_loop_prod_ = 1.0f;
// The stacks to store parent loops during DFS
std::vector<const ForNode*> for_loop_stack_;
std::vector<const ForNode*> parallel_for_stack_;
std::vector<const ForNode*> vec_for_stack_;
std::vector<const ForNode*> unroll_for_stack_;
// GPU-related features
bool is_gpu_{false};
int blockIdx_x_len_{1};