-
Notifications
You must be signed in to change notification settings - Fork 2.6k
/
infer_vector_layout.cc
1753 lines (1665 loc) · 70.5 KB
/
infer_vector_layout.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
/* Copyright 2023 The JAX Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <algorithm>
#include <array>
#include <cstdint>
#include <functional>
#include <memory>
#include <optional>
#include <ostream>
#include <string>
#include <type_traits>
#include <utility>
#include <variant>
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
#include "absl/log/check.h"
#include "absl/log/log.h"
#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/include/mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/include/mlir/IR/Attributes.h"
#include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h"
#include "jaxlib/mosaic/dialect/tpu/layout.h"
#include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h"
#include "xla/layout.h"
namespace mlir::tpu {
#define GEN_PASS_DECL_INFERVECTORLAYOUTPASS
#define GEN_PASS_DEF_INFERVECTORLAYOUTPASS
#include "jaxlib/mosaic/dialect/tpu/tpu_passes.h.inc"
namespace {
// Shorthand for VectorLayout's implicit-dimension enum, used throughout.
using ImplicitDim = VectorLayout::ImplicitDim;
// VLOG verbosity level at which inferBlock prints every op it visits.
static constexpr int kLayoutLog = 10;
class Print {
public:
explicit Print(Operation *t) : payload_(t) {}
Operation *payload_;
private:
friend std::ostream &operator<<(std::ostream &, Print);
};
std::ostream &operator<<(std::ostream &os, Print p) {
std::string s;
llvm::raw_string_ostream tmp_os(s);
p.payload_->print(tmp_os);
os << tmp_os.str();
return os;
}
bool is_fully_replicated(const Layout &layout) {
static LayoutOffsets replicated_offsets = {std::nullopt, std::nullopt};
return layout.has_value() && layout->offsets() == replicated_offsets;
}
// Returns the TiledLayoutAttr of memref `ref`. When `ref` is produced by a
// tpu.erase_layout op, the layout of that op's operand is returned instead.
// The casts assert that the type is a MemRefType carrying a TiledLayoutAttr.
TiledLayoutAttr getMemRefLayout(Value ref) {
  Value source = ref;
  if (auto erase_op = ref.getDefiningOp<tpu::EraseLayoutOp>()) {
    source = erase_op.getOperand();
  }
  auto memref_ty = cast<MemRefType>(source.getType());
  return cast<TiledLayoutAttr>(memref_ty.getLayout());
}
// Succeeds when `tiled_index` is statically known to be a multiple of
// `tiling`; otherwise emits an error on `op` (mentioning dimension `dim`)
// and fails.
LogicalResult verifyDivisibleIndex(Value tiled_index, int64_t tiling, int dim,
                                   Operation *op) {
  if (isGuaranteedDivisible(tiled_index, tiling)) {
    return success();
  }
  return op->emitOpError("cannot statically prove that index in dimension ")
         << dim << " is a multiple of " << tiling;
}
// TODO(apaszke): Test that this pass fills in NoLayout for all operations that
// have corresponding native instructions.
class VectorLayoutInferer {
public:
// Creates an inferer for a target whose vreg shape is `target_shape`
// (presumably {sublanes, lanes} — TODO confirm). The same shape is also used
// as `default_tiling_`, the tiling requested by many rules below.
explicit VectorLayoutInferer(std::array<int64_t, 2> target_shape)
    : target_shape_({target_shape[0], target_shape[1]}),
      default_tiling_(target_shape) {}
// If `cond` is false, emits `msg` as an op error on `op` and returns
// failure() from the enclosing function. NOTE: expands to a bare `if` and
// refers to a variable literally named `op`, which must be in scope and
// non-null at every expansion site.
#define TPU_CHECK_OP(cond, msg) \
  if (!(cond)) { \
    op->emitOpError(msg); \
    return failure(); \
  }
// Emits "not implemented: <msg>" on `op` and returns failure(). Expands to two
// statements, so it must only be used inside braces; `msg` must be a string
// literal because it is concatenated with the prefix at compile time.
#define NYI(msg) \
  op->emitOpError("not implemented: " msg); \
  return failure();
// Infers vector layouts for every non-terminator op of `block` by dispatching
// to the matching infer() overload (which attaches "in_layout"/"out_layout"),
// then returns the result of `match_terminator` applied to the terminator.
//
// Ops that already carry layout attributes are rejected, except for
// tpu.assume_layout, which must carry both and is skipped. Ops without vector
// operands/results and without regions get kNoLayout for every operand and
// for their (at most one) result.
LogicalResult inferBlock(
    Block &block,
    const std::function<LogicalResult(Operation *)> &match_terminator) {
  for (Operation &any_op : block.without_terminator()) {
    VLOG(kLayoutLog) << Print(&any_op);
    if (any_op.hasAttr("in_layout") || any_op.hasAttr("out_layout")) {
      if (auto op = dyn_cast<tpu::AssumeLayoutOp>(any_op)) {
        TPU_CHECK_OP(
            any_op.hasAttr("in_layout") && any_op.hasAttr("out_layout"),
            "expect layout attributes in tpu::AssumeLayoutOp");
        continue;
      } else {
        any_op.emitOpError("layout attributes already attached");
        return failure();
      }
    }
    bool has_vector_io = false;
    for (Value operand : any_op.getOperands()) {
      has_vector_io |= isa<VectorType>(operand.getType());
    }
    for (Value result : any_op.getResults()) {
      has_vector_io |= isa<VectorType>(result.getType());
    }
    if (!has_vector_io && any_op.getRegions().empty()) {
      llvm::SmallVector<Layout, 4> in_layout(any_op.getNumOperands(),
                                             kNoLayout);
      if (any_op.getNumResults() == 0) {
        setInLayout(&any_op, in_layout);
      } else if (any_op.getNumResults() == 1) {
        setLayout(&any_op, in_layout, kNoLayout);
      } else {
        any_op.emitOpError("Multi-result ops not supported");
        return failure();
      }
    } else if (auto op = dyn_cast<arith::ExtFOp>(any_op)) {
      if (infer(op).failed()) {
        return failure();
      }
    } else if (auto op = dyn_cast<arith::TruncFOp>(any_op)) {
      if (infer(op).failed()) {
        return failure();
      }
    } else if (auto op = dyn_cast<arith::ExtSIOp>(any_op)) {
      if (infer(op).failed()) {
        return failure();
      }
    } else if (auto op = dyn_cast<arith::TruncIOp>(any_op)) {
      if (infer(op).failed()) {
        return failure();
      }
    } else if (auto op = dyn_cast<arith::SelectOp>(any_op)) {
      auto true_ty = dyn_cast<VectorType>(op.getTrueValue().getType());
      auto false_ty = dyn_cast<VectorType>(op.getFalseValue().getType());
      TPU_CHECK_OP(static_cast<bool>(true_ty) == static_cast<bool>(false_ty),
                   "Only one side of arith is a vector?");
      if (true_ty) {
        TPU_CHECK_OP(true_ty.getElementTypeBitWidth() == kNativeBitwidth &&
                         false_ty.getElementTypeBitWidth() == kNativeBitwidth,
                     "Only 32-bit select supported");
      }
      if (inferElementwise(&any_op, /*check_bitwidth=*/false).failed()) {
        return failure();
      }
    } else if (auto op = dyn_cast<arith::ExtUIOp>(any_op)) {
      auto in_ty = dyn_cast<VectorType>(op.getIn().getType());
      auto out_ty = dyn_cast<VectorType>(op.getType());
      TPU_CHECK_OP(static_cast<bool>(in_ty) == static_cast<bool>(out_ty),
                   "Input and output are not both vectors?");
      if (in_ty) {
        TPU_CHECK_OP(in_ty.getElementTypeBitWidth() == 1 &&
                         out_ty.getElementTypeBitWidth() == 32,
                     "Only 1 bit -> 32 bit extensison supported");
      }
      if (inferElementwise(&any_op, /*check_bitwidth=*/false).failed()) {
        return failure();
      }
    } else if (isa<arith::CmpIOp>(any_op) || isa<arith::CmpFOp>(any_op)) {
      // TPU_CHECK_OP must not be used in this branch: the only `op` in scope
      // here is the condition variable of the failed arith::ExtUIOp dyn_cast
      // above, which is null, so emitting a diagnostic through it would
      // dereference a null operation. Report errors on `any_op` instead.
      auto lhs_ty = dyn_cast<VectorType>(any_op.getOperand(0).getType());
      auto rhs_ty = dyn_cast<VectorType>(any_op.getOperand(1).getType());
      if (static_cast<bool>(lhs_ty) != static_cast<bool>(rhs_ty)) {
        any_op.emitOpError("Only one side of cmp is a vector?");
        return failure();
      }
      if (lhs_ty && !(lhs_ty.getElementTypeBitWidth() == kNativeBitwidth &&
                      rhs_ty.getElementTypeBitWidth() == kNativeBitwidth)) {
        any_op.emitOpError("Only 32-bit cmp supported");
        return failure();
      }
      if (inferElementwise(&any_op, /*check_bitwidth=*/false).failed()) {
        return failure();
      }
    } else if (OpTrait::hasElementwiseMappableTraits(&any_op)) {
      if (inferElementwise(&any_op).failed()) {
        return failure();
      }
    } else if (auto op = dyn_cast<arith::ConstantOp>(any_op)) {
      if (infer(op).failed()) {
        return failure();
      }
    } else if (auto op = dyn_cast<cf::AssertOp>(any_op)) {
      if (infer(op).failed()) {
        return failure();
      }
    } else if (auto op = dyn_cast<memref::LoadOp>(any_op)) {
      if (infer(op).failed()) {
        return failure();
      }
    } else if (auto op = dyn_cast<scf::IfOp>(any_op)) {
      if (infer(op).failed()) {
        return failure();
      }
    } else if (auto op = dyn_cast<scf::ForOp>(any_op)) {
      if (infer(op).failed()) {
        return failure();
      }
    } else if (auto op = dyn_cast<tpu::RotateOp>(any_op)) {
      if (infer(op).failed()) {
        return failure();
      }
    } else if (auto op = dyn_cast<tpu::ConcatenateOp>(any_op)) {
      if (infer(op).failed()) {
        return failure();
      }
    } else if (auto op = dyn_cast<tpu::LoadOp>(any_op)) {
      if (infer(op).failed()) {
        return failure();
      }
    } else if (auto op = dyn_cast<tpu::MatmulOp>(any_op)) {
      if (infer(op).failed()) {
        return failure();
      }
    } else if (auto op = dyn_cast<tpu::StoreOp>(any_op)) {
      if (infer(op).failed()) {
        return failure();
      }
    } else if (auto op = dyn_cast<tpu::EraseLayoutOp>(any_op)) {
      if (infer(op).failed()) {
        return failure();
      }
    } else if (auto op = dyn_cast<tpu::IotaOp>(any_op)) {
      if (infer(op).failed()) {
        return failure();
      }
    } else if (auto op = dyn_cast<tpu::GatherOp>(any_op)) {
      if (infer(op).failed()) {
        return failure();
      }
    } else if (auto op = dyn_cast<tpu::BitcastOp>(any_op)) {
      if (infer(op).failed()) {
        return failure();
      }
    } else if (auto op = dyn_cast<tpu::RepeatOp>(any_op)) {
      if (infer(op).failed()) {
        return failure();
      }
    } else if (auto op = dyn_cast<tpu::TraceOp>(any_op)) {
      if (infer(op).failed()) {
        return failure();
      }
    } else if (auto op = dyn_cast<tpu::RegionOp>(any_op)) {
      if (infer(op).failed()) {
        return failure();
      }
    } else if (auto op = dyn_cast<vector::BroadcastOp>(any_op)) {
      if (infer(op).failed()) {
        return failure();
      }
    } else if (auto op = dyn_cast<vector::ContractionOp>(any_op)) {
      if (infer(op).failed()) {
        return failure();
      }
    } else if (auto op = dyn_cast<vector::ExtractOp>(any_op)) {
      if (infer(op).failed()) {
        return failure();
      }
    } else if (auto op = dyn_cast<vector::LoadOp>(any_op)) {
      if (infer(op).failed()) {
        return failure();
      }
    } else if (auto op = dyn_cast<vector::MultiDimReductionOp>(any_op)) {
      if (infer(op).failed()) {
        return failure();
      }
    } else if (auto op = dyn_cast<vector::ShapeCastOp>(any_op)) {
      if (infer(op).failed()) {
        return failure();
      }
    } else if (auto op = dyn_cast<vector::StoreOp>(any_op)) {
      if (infer(op).failed()) {
        return failure();
      }
    } else if (auto op = dyn_cast<vector::TransposeOp>(any_op)) {
      if (infer(op).failed()) {
        return failure();
      }
    } else if (auto op =
                   llvm::dyn_cast<vector::ExtractStridedSliceOp>(any_op)) {
      if (infer(op).failed()) {
        return failure();
      }
    } else {
      any_op.emitOpError("unsupported in vector layout inference");
      return failure();
    }
    // Every successful rule must have attached the layouts it owns.
    CHECK(any_op.getNumResults() == 0 || any_op.hasAttr("out_layout"));
    CHECK(any_op.getNumOperands() == 0 || any_op.hasAttr("in_layout"));
  }
  return match_terminator(block.getTerminator());
}
// arith.constant: scalars get no layout. Vector constants must have a scalar
// element type, rank > 0 and a DenseElementsAttr value. Splats get a fully
// replicated layout with native tiling (rank-1 splats use an implicit
// second-minor dim). Non-splats must be 32-bit; rank >= 2 gets offsets (0, 0)
// with the default tiling, and rank 1 is not implemented yet.
LogicalResult infer(arith::ConstantOp op) {
  if (op.getType().isSignlessIntOrIndexOrFloat()) {
    setOutLayout(op, kNoLayout);
    return success();
  }
  auto ty = dyn_cast<VectorType>(op.getType());
  if (!ty) {
    op.emitOpError("unsupported constant type");
    return failure();
  }
  auto elems = dyn_cast<DenseElementsAttr>(op.getValue());
  TPU_CHECK_OP(ty.getElementType().isSignlessIntOrIndexOrFloat(),
               "expected scalar element type in vector");
  TPU_CHECK_OP(ty.getRank() > 0, "rank 0 vectors unsupported");
  TPU_CHECK_OP(elems, "expected vector constants to use DenseElementsAttr");
  const auto bitwidth = ty.getElementTypeBitWidth();
  const bool is_1d = ty.getRank() == 1;
  if (elems.isSplat()) {
    // For 1D splats we arbitrarily lay the value out along lanes; sublanes
    // would be equally valid and a splat is cheap to relayout anyway.
    const ImplicitDim implicit_dim =
        is_1d ? ImplicitDim::kSecondMinor : ImplicitDim::kNone;
    setOutLayout(op, VectorLayout(bitwidth, {std::nullopt, std::nullopt},
                                  nativeTiling(bitwidth), implicit_dim));
    return success();
  }
  TPU_CHECK_OP(bitwidth == kNativeBitwidth,
               "Only 32-bit non-splat constants supported");
  if (is_1d) {
    if (ty.getDimSize(0) <= target_shape_[0]) {
      // Intended approach: a 2D layout with replication.
      NYI("small 1D constants");
    }
    NYI("large 1D constants");
  }
  setOutLayout(op, VectorLayout(kNativeBitwidth, {0, 0}, default_tiling_,
                                ImplicitDim::kNone));
  return success();
}
// arith.extf (16-bit -> 32-bit only). Scalars get no layout.
//  * No implicit dim: the source tiling is kept if it evenly subdivides the
//    rows of the default tiling and matches its lane dim; otherwise the
//    default tiling is requested. The result reuses that tiling.
//  * Implicit second-minor dim: source must use native 16-bit tiling; the
//    result uses the default tiling with the same implicit dim.
LogicalResult infer(arith::ExtFOp op) {
  auto src_ty = dyn_cast<VectorType>(op.getIn().getType());
  if (!src_ty) {
    setLayout(op, kNoLayout, kNoLayout);
    return success();
  }
  auto dst_ty = cast<VectorType>(op.getOut().getType());
  auto some_layout = getLayout(op.getIn());
  TPU_CHECK_OP(src_ty.getElementTypeBitWidth() == 16 &&
                   dst_ty.getElementTypeBitWidth() == 32,
               "Only 16-bit to 32-bit extensions supported");
  TPU_CHECK_OP(some_layout.has_value(), "missing vector layout");
  const VectorLayout &in = *some_layout;
  const ImplicitDim implicit_dim = in.implicit_dim();
  if (implicit_dim == ImplicitDim::kSecondMinor) {
    TPU_CHECK_OP(in.tiling() == nativeTiling(16), "unsupported tiling");
    setLayout(op, some_layout,
              VectorLayout(32, in.offsets(), default_tiling_, implicit_dim));
    return success();
  }
  if (implicit_dim != ImplicitDim::kNone) {
    op.emitOpError("unsupported extension layout");
    return failure();
  }
  // TODO(apaszke): Support (16,128) too.
  const bool keep_tiling = default_tiling_[0] % in.tiling()[0] == 0 &&
                           default_tiling_[1] == in.tiling()[1];
  Layout src_layout = in;
  if (!keep_tiling) {
    src_layout = VectorLayout(16, in.offsets(), default_tiling_,
                              ImplicitDim::kNone);
  }
  setLayout(op, src_layout,
            VectorLayout(32, in.offsets(), src_layout->tiling(),
                         ImplicitDim::kNone));
  return success();
}
// arith.truncf (32-bit -> 16-bit only). Scalars get no layout. Only sources
// without an implicit dim are supported: the source is requested in the
// default tiling, and the result uses native 16-bit tiling when
// allUsersRequireNativeTiling(result) holds, else the default tiling.
LogicalResult infer(arith::TruncFOp op) {
  auto src_ty = dyn_cast<VectorType>(op.getIn().getType());
  if (!src_ty) {
    setLayout(op, kNoLayout, kNoLayout);
    return success();
  }
  auto dst_ty = cast<VectorType>(op.getOut().getType());
  auto some_layout = getLayout(op.getIn());
  TPU_CHECK_OP(src_ty.getElementTypeBitWidth() == 32 &&
                   dst_ty.getElementTypeBitWidth() == 16,
               "Only 32-bit to 16-bit truncation supported");
  // Without this check `*some_layout` below dereferences an empty optional
  // (UB); the other cast rules perform the same check.
  TPU_CHECK_OP(some_layout.has_value(), "missing vector layout");
  auto &layout = *some_layout;
  if (layout.implicit_dim() == ImplicitDim::kNone) {
    bool select_native = allUsersRequireNativeTiling(op.getResult());
    auto src_layout = VectorLayout(32, layout.offsets(), default_tiling_,
                                   ImplicitDim::kNone);
    auto dst_layout =
        VectorLayout(16, layout.offsets(),
                     select_native ? nativeTiling(16) : default_tiling_,
                     ImplicitDim::kNone);
    setLayout(op, src_layout, dst_layout);
    return success();
  }
  op.emitOpError("unsupported truncation layout");
  return failure();
}
// arith.extsi (to 32-bit only). Scalars get no layout.
//  * No implicit dim: source keeps its bitwidth/offsets but is requested in
//    the default tiling; the 32-bit result uses the default tiling.
//  * Implicit second-minor dim: source tiling must equal nativeTiling(16);
//    the result uses the default tiling with the same implicit dim.
LogicalResult infer(arith::ExtSIOp op) {
  auto src_ty = dyn_cast<VectorType>(op.getIn().getType());
  if (!src_ty) {
    setLayout(op, kNoLayout, kNoLayout);
    return success();
  }
  auto dst_ty = cast<VectorType>(op.getOut().getType());
  auto some_layout = getLayout(op.getIn());
  TPU_CHECK_OP(dst_ty.getElementTypeBitWidth() == 32,
               "Only extensions to 32-bit supported");
  TPU_CHECK_OP(some_layout.has_value(), "missing vector layout");
  const VectorLayout &in = *some_layout;
  if (in.implicit_dim() == ImplicitDim::kSecondMinor) {
    TPU_CHECK_OP(in.tiling() == nativeTiling(16), "unsupported tiling");
    setLayout(op, some_layout,
              VectorLayout(32, in.offsets(), default_tiling_,
                           in.implicit_dim()));
    return success();
  }
  if (in.implicit_dim() != ImplicitDim::kNone) {
    op.emitOpError("unsupported extension layout");
    return failure();
  }
  // TODO(apaszke): Support native layouts here.
  setLayout(op,
            VectorLayout(in.bitwidth(), in.offsets(), default_tiling_,
                         ImplicitDim::kNone),
            VectorLayout(32, in.offsets(), default_tiling_,
                         ImplicitDim::kNone));
  return success();
}
// arith.trunci (from 32-bit only). Scalars get no layout. Only sources without
// an implicit dim are supported: the source is requested in the default
// tiling; the result uses native tiling for its bitwidth when
// allUsersRequireNativeTiling(result) holds, else the default tiling.
LogicalResult infer(arith::TruncIOp op) {
  auto src_ty = dyn_cast<VectorType>(op.getIn().getType());
  if (!src_ty) {
    setLayout(op, kNoLayout, kNoLayout);
    return success();
  }
  auto dst_ty = cast<VectorType>(op.getOut().getType());
  auto some_layout = getLayout(op.getIn());
  TPU_CHECK_OP(src_ty.getElementTypeBitWidth() == 32,
               "Only 32-bit truncation supported");
  TPU_CHECK_OP(some_layout.has_value(), "missing vector layout");
  const VectorLayout &in = *some_layout;
  if (in.implicit_dim() != ImplicitDim::kNone) {
    op.emitOpError("unsupported truncation layout");
    return failure();
  }
  const auto dst_bitwidth = dst_ty.getElementTypeBitWidth();
  const auto dst_tiling = allUsersRequireNativeTiling(op.getResult())
                              ? nativeTiling(dst_bitwidth)
                              : default_tiling_;
  setLayout(op,
            VectorLayout(32, in.offsets(), default_tiling_,
                         ImplicitDim::kNone),
            VectorLayout(dst_bitwidth, in.offsets(), dst_tiling,
                         ImplicitDim::kNone));
  return success();
}
// cf.assert: sets a single kNoLayout input layout (for its condition operand)
// and always succeeds.
LogicalResult infer(cf::AssertOp op) {
  setInLayout(op, {kNoLayout});
  return success();
}
// func.func: only single-block bodies are supported. Infers layouts for the
// body. The terminator must be func.return, may not return vectors, and gets
// kNoLayout for every operand.
LogicalResult infer(func::FuncOp op) {
  if (!op.getBody().hasOneBlock()) {
    op.emitOpError("Only one block functions supported");
    return failure();
  }
  auto match_return = [this](Operation *op) -> LogicalResult {
    TPU_CHECK_OP(isa<func::ReturnOp>(op), "Expected func.return terminator");
    for (Type returned_ty : op->getOperandTypes()) {
      TPU_CHECK_OP(!isa<VectorType>(returned_ty),
                   "vector returns unsupported");
    }
    SmallVector<Layout, 4> operand_layouts(op->getNumOperands(), kNoLayout);
    setInLayout(op, operand_layouts);
    return success();
  };
  return inferBlock(op.getBody().front(), match_return);
}
// memref.load: must produce a scalar; all operands and the result get
// kNoLayout.
LogicalResult infer(memref::LoadOp op) {
  TPU_CHECK_OP(op.getType().isSignlessIntOrIndexOrFloat(),
               "memref.load with non-scalar result");
  SmallVector<Layout, 5> no_layouts(op.getNumOperands(), kNoLayout);
  setLayout(op, no_layouts, kNoLayout);
  return success();
}
// scf.if: the condition gets kNoLayout and both branches are inferred
// recursively (each must end in scf.yield). For each vector result, the
// layout yielded by the then-branch is joined with the one yielded by the
// else-branch via VectorLayout::join (e.g. offsets (*, *) and (*, 0) give
// (*, 0)). Both yields and the op's results then use the joined layouts.
LogicalResult infer(scf::IfOp op) {
  static LogicalResult (*match_yield)(Operation *) = [](Operation *op) {
    TPU_CHECK_OP(isa<scf::YieldOp>(op), "expected yield terminator");
    return success();
  };
  TPU_CHECK_OP(op->getNumOperands() == 1, "expected one operand");
  setInLayout(op, {kNoLayout});
  if (inferBlock(*op.thenBlock(), match_yield).failed()) {
    op.emitOpError("failed to infer layout for then branch");
    return failure();
  }
  auto then_yield = op.thenBlock()->getTerminator();
  TPU_CHECK_OP(then_yield->getOperandTypes() == op->getResultTypes(),
               "scf if results and then branch yield operands do not match");
  llvm::SmallVector<Layout, 4> result_layout;
  result_layout.reserve(then_yield->getNumOperands());
  for (const auto &operand : then_yield->getOperands()) {
    if (operand.getType().isSignlessIntOrIndexOrFloat()) {
      result_layout.push_back(kNoLayout);
    } else if (isa<VectorType>(operand.getType())) {
      result_layout.push_back(getLayout(operand));
    } else {
      op.emitOpError("unsupported scf.yield type");
      return failure();
    }
  }
  if (auto else_block = op.elseBlock()) {
    if (inferBlock(*else_block, match_yield).failed()) {
      op.emitOpError("failed to infer layout for else branch");
      return failure();
    }
  }
  if (op->getNumResults() == 0) {
    return success();
  }
  // If the if op has results, it should have both then and else regions with
  // yield op.
  auto else_yield = op.elseBlock()->getTerminator();
  TPU_CHECK_OP(else_yield->getOperandTypes() == op->getResultTypes(),
               "scf if results and else branch yield operands do not match");
  // Narrow each vector result layout by joining it with the else-branch's
  // yielded layout.
  for (unsigned i = 0, e = else_yield->getNumOperands(); i < e; ++i) {
    Value operand = else_yield->getOperand(i);
    auto vty = dyn_cast<VectorType>(operand.getType());
    if (!vty) {
      continue;
    }
    auto layout = getLayout(operand);
    // Report a diagnostic instead of aborting the whole process via CHECK.
    TPU_CHECK_OP(result_layout[i].has_value() && layout.has_value(),
                 "missing vector layout");
    result_layout[i] =
        VectorLayout::join(*result_layout[i], *layout, vty.getShape());
    if (!result_layout[i].has_value()) {
      op.emitOpError(
          "failed to find a compatible layout in then and else branch for "
          "output ")
          << i;
      return failure();
    }
  }
  setInLayout(then_yield, result_layout);
  setInLayout(else_yield, result_layout);
  setOutLayout(op, result_layout);
  return success();
}
// scf.for: lower bound, upper bound and step get kNoLayout. Each init arg gets
// kNoLayout (scalar) or its current layout (vector); any other type is
// rejected. The loop results, and the body's scf.yield operands, get the same
// layouts as the corresponding init args. Before the body is inferred, a
// tpu.assume_layout op is inserted at the start of the body for every vector
// iter arg, and all other uses of that iter arg are redirected to it. This
// lets the body see the carried layout.
LogicalResult infer(scf::ForOp op) {
  static LogicalResult (*match_yield)(Operation *) = [](Operation *op) {
    TPU_CHECK_OP(isa<scf::YieldOp>(op), "expected yield terminator");
    return success();
  };
  TPU_CHECK_OP(op.getRegion().hasOneBlock(),
               "expected one block for scf.for");
  TPU_CHECK_OP(
      op.getNumRegionIterArgs() == op.getNumResults(),
      "expected num_region_iter_args is equal to num_results in scf.for");
  TPU_CHECK_OP(
      op->getNumOperands() == 3 + op.getNumResults(),
      "expected num_operands is equal to 3 + num_results in scf.for");
  llvm::SmallVector<Layout, 4> in_layouts;
  in_layouts.reserve(op->getNumOperands());
  in_layouts.push_back(kNoLayout); // Lower bound.
  in_layouts.push_back(kNoLayout); // Upper bound.
  in_layouts.push_back(kNoLayout); // Step.
  for (const auto &arg : op.getInitArgs()) {
    if (arg.getType().isSignlessIntOrIndexOrFloat()) {
      in_layouts.push_back(kNoLayout);
    } else if (isa<VectorType>(arg.getType())) {
      auto layout = getLayout(arg);
      in_layouts.push_back(layout);
    } else {
      op.emitOpError() << "unsupported arg type " << arg.getType()
                       << " in scf::for";
      return failure();
    }
  }
  // Result layouts are the init-arg layouts (everything after bounds+step).
  ArrayRef<Layout> out_layouts = ArrayRef<Layout>(in_layouts).drop_front(3);
  // Use tpu.assume_layout to annotate every block argument with the layout of
  // the corresponding operand in forOp and replace all uses of the block
  // argument with the result of tpu.assume_layout.
  ImplicitLocOpBuilder builder =
      ImplicitLocOpBuilder::atBlockBegin(op.getLoc(), op.getBody());
  // Drop the induction_variable and layouts of bounds+step (respectively).
  for (auto [iter_arg, layout] : llvm::zip_equal(
           op.getBody()->getArguments().drop_front(1), out_layouts)) {
    if (!dyn_cast<VectorType>(iter_arg.getType())) {
      continue;
    }
    auto assume_layout_op =
        builder.create<AssumeLayoutOp>(iter_arg.getType(), iter_arg);
    setLayout(assume_layout_op, layout, layout);
    // Keep the assume_layout op's own operand pointing at the block argument.
    iter_arg.replaceUsesWithIf(assume_layout_op, [&](OpOperand &operand) {
      return operand.getOwner() != assume_layout_op;
    });
  }
  if (inferBlock(*op.getBody(), match_yield).failed()) {
    return failure();
  }
  auto yield_op = op.getBody()->getTerminator();
  setInLayout(yield_op, out_layouts);
  setLayout(op, in_layouts, out_layouts);
  return success();
}
// tpu.rotate: only 32-bit vectors of rank >= 2 are supported. Operand and
// result both use offsets (0, 0) with native tiling.
LogicalResult infer(tpu::RotateOp op) {
  auto vty = op.getType();
  const auto bitwidth = vty.getElementTypeBitWidth();
  if (bitwidth != 32) {
    NYI("Rotate with non-32-bit data");
  }
  if (vty.getRank() < 2) {
    NYI("Unsupported 1D shape");
  }
  const VectorLayout native_layout(bitwidth, {0, 0}, nativeTiling(bitwidth),
                                   ImplicitDim::kNone);
  setLayout(op, native_layout, native_layout);
  return success();
}
// tpu.concatenate: 32-bit data of rank >= 2 only. When concatenating along one
// of the two minor dims, every operand and the result use offsets (0, 0) with
// native tiling; otherwise all of them reuse the first source's layout.
LogicalResult infer(tpu::ConcatenateOp op) {
  TPU_CHECK_OP(!op.getSources().empty(),
               "Need at least one vector to concatenate");
  const auto res_rank = op.getType().getRank();
  const auto dimension = op.getDimension();
  TPU_CHECK_OP(0 <= dimension && dimension < res_rank,
               "Expect a valid concatenate dimension");
  if (res_rank == 1) {
    NYI("Support concatenation with 1D vectors");
  }
  const int8_t bitwidth = op.getResult().getType().getElementTypeBitWidth();
  if (bitwidth != 32) {
    NYI("Support concatenation with non 32-bit data");
  }
  Layout layout;
  if (dimension >= res_rank - 2) {
    layout = VectorLayout(bitwidth, {0, 0}, nativeTiling(bitwidth),
                          ImplicitDim::kNone);
  } else {
    layout = getLayout(op.getSources().front());
  }
  SmallVector<Layout> operand_layouts(op->getNumOperands(), layout);
  setLayout(op, operand_layouts, layout);
  return success();
}
// tpu.load: the result must be 32-bit with its first two dims equal to
// target_shape_ (a single native-sized vreg). All operands get kNoLayout;
// the result uses offsets (0, 0) with native tiling.
LogicalResult infer(tpu::LoadOp op) {
  auto res_ty = op.getResult().getType();
  int8_t bitwidth = res_ty.getElementTypeBitWidth();
  // We expect the result is already a native-sized vreg. The rank is checked
  // first so that a rank-0/1 result is rejected with a diagnostic instead of
  // indexing getShape() out of bounds.
  TPU_CHECK_OP(bitwidth == 32 && res_ty.getRank() >= 2 &&
                   res_ty.getShape()[0] == target_shape_[0] &&
                   res_ty.getShape()[1] == target_shape_[1],
               "Only 32-bit loads supported");
  SmallVector<Layout, 4> in_layout(op->getNumOperands(), kNoLayout);
  auto out_layout = VectorLayout(bitwidth, {0, 0}, nativeTiling(bitwidth),
                                 ImplicitDim::kNone);
  setLayout(op, in_layout, out_layout);
  return success();
}
// tpu.matmul: layout inference is delegated entirely to inferMatmul.
LogicalResult infer(tpu::MatmulOp op) { return inferMatmul(op); }
// tpu.store: the stored value must be 32-bit with its first two dims equal to
// target_shape_ (a single native-sized vreg). It is requested with offsets
// (0, 0) and native tiling. The remaining operands (the indices plus one more,
// presumably the base memref — TODO confirm operand order) get kNoLayout.
LogicalResult infer(tpu::StoreOp op) {
  auto store_ty = op.getValueToStore().getType();
  int8_t bitwidth = store_ty.getElementTypeBitWidth();
  // We expect the value to store is already a native-sized vreg. The rank is
  // checked first so that getShape() is never indexed out of bounds.
  TPU_CHECK_OP(bitwidth == 32 && store_ty.getRank() >= 2 &&
                   store_ty.getShape()[0] == target_shape_[0] &&
                   store_ty.getShape()[1] == target_shape_[1],
               "Only 32-bit stores supported");
  auto store_layout = VectorLayout(bitwidth, {0, 0}, nativeTiling(bitwidth),
                                   ImplicitDim::kNone);
  SmallVector<Layout, 5> in_layout{store_layout};
  in_layout.insert(in_layout.end(), op.getIndices().size() + 1, kNoLayout);
  setInLayout(op, in_layout);
  return success();
}
// tpu.erase_layout: neither the operand nor the result gets a vector layout.
// (getMemRefLayout looks through this op to the original memref.)
LogicalResult infer(tpu::EraseLayoutOp op) {
  setLayout(op, kNoLayout, kNoLayout);
  return success();
}
// tpu.gather: the source's current layout is used unchanged for both the
// operand and the result.
LogicalResult infer(tpu::GatherOp op) {
  const Layout layout = getLayout(op.getSource());
  setLayout(op, layout, layout);
  return success();
}
// tpu.bitcast: the source layout must exist, have zero (or replicated) offsets
// and no implicit dim. Input and output must be vectors of equal rank >= 2
// whose dims agree, except that the second-minor dim is scaled by the ratio of
// element bitwidths. Each side uses native tiling for its own bitwidth and
// keeps the source offsets.
LogicalResult infer(tpu::BitcastOp op) {
  auto src_layout = getLayout(op.getInput());
  // Guard the optional before dereferencing it below.
  TPU_CHECK_OP(src_layout.has_value(), "missing vector layout");
  LayoutOffsets src_offsets = src_layout->offsets();
  if (src_offsets[0].value_or(0) || src_offsets[1].value_or(0)) {
    NYI("unsupported bitcast with offsets");
  }
  if (src_layout->implicit_dim() != ImplicitDim::kNone) {
    NYI("unsupported bitcast with an implicit dim");
  }
  // Check if input and output have same bit size. Validate both dyn_cast
  // results *before* calling any method on them; previously the element
  // bitwidths were read first, making the null check ineffective.
  auto in_ty = dyn_cast<VectorType>(op.getInput().getType());
  auto out_ty = dyn_cast<VectorType>(op.getOutput().getType());
  TPU_CHECK_OP(in_ty && out_ty && in_ty.getRank() == out_ty.getRank(),
               "Input and output have different rank");
  auto in_bitwidth = in_ty.getElementTypeBitWidth();
  auto out_bitwidth = out_ty.getElementTypeBitWidth();
  if (out_ty.getRank() < 2) {
    NYI("Support bitcast with 1D vector");
  }
  for (int64_t i = 0; i < in_ty.getRank(); ++i) {
    auto in_dim = in_ty.getDimSize(i);
    auto out_dim = out_ty.getDimSize(i);
    // The sublane dimension is scaled down by the ratio of input element
    // bitwidth to output element bitwidth when bitcasting. For example,
    // bitcasting a vector<16x128xbf16> to a vector<8x128xi32> packs every 2
    // rows in the bf16 vector into 1 row in the i32 vector. This means the
    // bit representation of one i32 element vector[i,j] is equal to
    // concatenating bf16 elements vector[2*i+1,j] and vector[2*i,j].
    if (i == in_ty.getRank() - 2) {
      in_dim *= in_bitwidth;
      out_dim *= out_bitwidth;
    }
    TPU_CHECK_OP(in_dim == out_dim,
                 "Input and output have incompatible shape");
  }
  setLayout(op,
            VectorLayout(in_bitwidth, src_offsets, nativeTiling(in_bitwidth),
                         ImplicitDim::kNone),
            VectorLayout(out_bitwidth, src_offsets,
                         nativeTiling(out_bitwidth), ImplicitDim::kNone));
  return success();
}
// tpu.repeat: the source's current layout is used unchanged for both the
// operand and the result.
LogicalResult infer(tpu::RepeatOp op) {
  const Layout layout = getLayout(op.getSource());
  setLayout(op, layout, layout);
  return success();
}
// tpu.trace: must have no operands and no results. Infers layouts for its body,
// whose terminator must be tpu.yield.
LogicalResult infer(tpu::TraceOp op) {
  TPU_CHECK_OP(op->getNumOperands() == 0, "expected no operands");
  TPU_CHECK_OP(op->getNumResults() == 0, "results unsupported");
  return inferBlock(*op.getBody(), [](Operation *op) -> LogicalResult {
    TPU_CHECK_OP(isa<tpu::YieldOp>(op), "expected yield terminator");
    return success();
  });
}
// tpu.region: must have no operands and no results. Infers layouts for the
// first block of its first region, whose terminator must be tpu.yield.
LogicalResult infer(tpu::RegionOp op) {
  TPU_CHECK_OP(op->getNumOperands() == 0, "expected no operands");
  TPU_CHECK_OP(op->getNumResults() == 0, "results unsupported");
  Block &body = op->getRegion(0).front();
  return inferBlock(body, [](Operation *op) -> LogicalResult {
    TPU_CHECK_OP(isa<tpu::YieldOp>(op), "expected yield terminator");
    return success();
  });
}
// tpu.iota: the result must be an i32 vector of rank >= 2. It uses the
// default tiling with offsets (0, 0). When the iota dimension is the minor
// dim, the second-minor offset is replicated; when it is the second-minor
// dim, the minor offset is replicated (presumably because values are
// constant along the other dim — TODO confirm).
LogicalResult infer(tpu::IotaOp op) {
  auto ty = op.getResult().getType();
  TPU_CHECK_OP(ty.getElementType().isSignlessInteger(32),
               "Only 32-bit integer iota supported");
  TPU_CHECK_OP(ty.getRank() >= 2, "iota rank below 2D unsupported");
  const auto minor_dim = ty.getRank() - 1;
  const auto second_minor_dim = ty.getRank() - 2;
  LayoutOffsets offsets = {0, 0};
  if (op.getDimension() == minor_dim) {
    offsets[0] = std::nullopt;
  }
  if (op.getDimension() == second_minor_dim) {
    offsets[1] = std::nullopt;
  }
  setOutLayout(op, VectorLayout(kNativeBitwidth, offsets, default_tiling_,
                                ImplicitDim::kNone));
  return success();
}
// Infers layouts for vector.broadcast.
//
// Scalar source: only 32-bit signless integer or float scalars are accepted.
// The operand gets no layout. The result gets a native-bitwidth layout with
// both offsets replicated and the default tiling. A 1D result uses an implicit
// second-minor dim.
//
// Vector source: source and result must both be at least 2D. The source
// layout may first be rewritten (forced to the default tiling, implicit dim
// normalized away) and is then requested for the operand. The result reuses
// its bitwidth and tiling. For native-bitwidth, default-tiled layouts, each of
// the two minor dims whose size changes gets a replicated offset.
LogicalResult infer(vector::BroadcastOp op) {
  auto some_src_ty = op.getSourceType();
  auto res_ty = op.getResultVectorType();
  TPU_CHECK_OP(res_ty.getRank() > 0, "rank 0 vectors unsupported");
  if (some_src_ty.isSignlessIntOrIndexOrFloat()) {
    // `index` satisfies isSignlessIntOrIndexOrFloat() but is not an
    // int-or-float type, and Type::getIntOrFloatBitWidth() requires
    // isIntOrFloat(). Check that first so an index source is rejected through
    // the regular diagnostic instead of hitting an assertion (or undefined
    // behavior when asserts are compiled out).
    const bool is_native_scalar =
        some_src_ty.isIntOrFloat() &&
        some_src_ty.getIntOrFloatBitWidth() == kNativeBitwidth;
    TPU_CHECK_OP(is_native_scalar, "Only 32-bit broadcasts supported");
    // We use a full vreg tile, because only then its layout can be changed
    // for free. A 1D result is modelled with an implicit second-minor dim.
    const auto implicit_dim = res_ty.getRank() == 1
                                  ? ImplicitDim::kSecondMinor
                                  : ImplicitDim::kNone;
    setLayout(op, kNoLayout,
              VectorLayout(kNativeBitwidth, {std::nullopt, std::nullopt},
                           default_tiling_, implicit_dim));
    return success();
  }
  if (auto src_ty = dyn_cast<VectorType>(some_src_ty)) {
    TPU_CHECK_OP(src_ty.getRank() >= 2, "source rank below 2D unsupported");
    TPU_CHECK_OP(res_ty.getRank() >= 2, "result rank below 2D unsupported");
    auto some_layout = getLayout(op.getSource());
    TPU_CHECK_OP(some_layout.has_value(), "missing vector layout");
    // We want to force the layout to be (8, 128) instead of (1, 128) if we
    // are broadcasting sublane dim from 1 to at least 8.
    if (some_layout->bitwidth() == kNativeBitwidth &&
        some_layout->implicit_dim() == ImplicitDim::kNone &&
        some_layout->tiling()[0] == 1 &&
        some_layout->tiling()[1] == default_tiling_[1] &&
        src_ty.getDimSize(src_ty.getRank() - 2) == 1 &&
        res_ty.getDimSize(res_ty.getRank() - 2) >= 8) {
      *some_layout = VectorLayout(
          some_layout->bitwidth(), {std::nullopt, some_layout->offsets()[1]},
          default_tiling_, some_layout->implicit_dim());
    }
    // `layout` aliases *some_layout, so rewrites below also change the
    // layout requested for the operand.
    auto &layout = *some_layout;
    if (layout.implicit_dim() != ImplicitDim::kNone) {
      // Drop the implicit dim only when doing so does not change how the
      // source is laid out; otherwise this rule does not support it.
      VectorLayout layout_2d(layout.bitwidth(), layout.offsets(),
                             layout.tiling(), ImplicitDim::kNone);
      if (layout_2d.equivalentTo(layout, src_ty.getShape(), target_shape_)) {
        layout = layout_2d;
      } else {
        op.emitOpError() << "Only 2D layouts supported";
        return failure();
      }
    }
    auto src_tiled_shape = src_ty.getShape().take_back(2);
    auto dst_tiled_shape = res_ty.getShape().take_back(2);
    LayoutOffsets offsets = layout.offsets();
    if (layout.bitwidth() == kNativeBitwidth &&
        layout.tiling() == default_tiling_) {
      // A minor dim whose size changes is broadcast: its offset is replicated.
      for (int i = 0; i < 2; ++i) {
        if (src_tiled_shape[i] != dst_tiled_shape[i]) {
          offsets[i] = std::nullopt;
        }
      }
    }
    setLayout(op, some_layout,
              VectorLayout(layout.bitwidth(), offsets, layout.tiling(),
                           ImplicitDim::kNone));
    return success();
  }
  op.emitOpError("unsupported broadcast source type");
  return failure();
}
// Infers layouts for vector.contract. The op is accepted only when it is a
// 2D matmul with ADD accumulation:
//   iterator types: (parallel, parallel, reduction) over dims (d0, d1, d2)
//   indexing maps:  lhs (d0, d2), rhs (d2, d1) or transposed rhs (d1, d2),
//                   acc (d0, d1)
// Layout assignment itself is delegated to inferMatmul.
LogicalResult infer(vector::ContractionOp op) {
  // TODO(apaszke): Support layout here, at least on batch dimensions.
  TPU_CHECK_OP(op.getKind() == vector::CombiningKind::ADD,
               "Only ADD supported");
  auto *ctx = op.getContext();

  auto iterator = [&](vector::IteratorType kind) -> mlir::Attribute {
    return vector::IteratorTypeAttr::get(ctx, kind);
  };
  const mlir::Attribute parallel = iterator(vector::IteratorType::parallel);
  const mlir::Attribute reduction = iterator(vector::IteratorType::reduction);
  const auto expected_iterators =
      mlir::ArrayAttr::get(ctx, {parallel, parallel, reduction});
  TPU_CHECK_OP(op.getIteratorTypes() == expected_iterators, "Not a matmul");

  // Attribute for the 3-dim affine map (d0, d1, d2) -> (d<first>, d<second>).
  auto dims = [&](unsigned first, unsigned second) -> mlir::Attribute {
    return AffineMapAttr::get(AffineMap::get(
        3, 0, {getAffineDimExpr(first, ctx), getAffineDimExpr(second, ctx)},
        ctx));
  };
  const auto plain_maps =
      mlir::ArrayAttr::get(ctx, {dims(0, 2), dims(2, 1), dims(0, 1)});
  const auto transposed_rhs_maps =
      mlir::ArrayAttr::get(ctx, {dims(0, 2), dims(1, 2), dims(0, 1)});
  const bool is_plain = op.getIndexingMaps() == plain_maps;
  const bool is_transposed_rhs =
      !is_plain && op.getIndexingMaps() == transposed_rhs_maps;
  TPU_CHECK_OP(is_plain || is_transposed_rhs, "Not a matmul");
  return inferMatmul(op);
}
// Infers layouts for vector.extract. Only static positions and 32-bit source
// element types are supported. The source operand is requested with offsets
// (0, 0), keeping the tiling and implicit dim of its existing layout; the
// result is assigned no layout.
LogicalResult infer(vector::ExtractOp op) {
  TPU_CHECK_OP(!op.hasDynamicPosition(), "dynamic indices not supported");
  const auto src_bitwidth =
      op.getSourceVectorType().getElementTypeBitWidth();
  TPU_CHECK_OP(src_bitwidth == kNativeBitwidth, "Only 32-bit types supported");
  auto src_layout = getLayout(op.getVector());
  TPU_CHECK_OP(src_layout.has_value(), "missing vector layout");
  VectorLayout operand_layout(kNativeBitwidth, {0, 0}, src_layout->tiling(),
                              src_layout->implicit_dim());
  setLayout(op, operand_layout, kNoLayout);
  return success();
}
LogicalResult infer(vector::LoadOp op) {
auto src_ty = op.getMemRefType();
auto res_ty = op.getVectorType();
TPU_CHECK_OP(src_ty.getRank() == res_ty.getRank(),
"memref and vector rank mismatch");
int64_t rank = res_ty.getRank();
int8_t bitwidth = res_ty.getElementTypeBitWidth();
auto maybe_tiling =
verifyMemoryTiling(op, getMemRefLayout(op.getBase()).getTiles(),
src_ty.getRank(), src_ty.getElementTypeBitWidth());
if (!maybe_tiling) {
return failure();
}
auto tiling = *maybe_tiling;
SmallVector<Layout, 4> in_layout(op->getNumOperands(), kNoLayout);
CHECK_EQ(op->getNumOperands(), op.getIndices().size() + 1);
SmallVector<int64_t, 2> tile_indices;
for (int i = 0; i < tiling.size(); ++i) {
int dim = rank - tiling.size() + i;
Value tiled_index = op.getIndices()[dim];
if (auto cst_op = tiled_index.getDefiningOp<arith::ConstantOp>()) {
tile_indices.push_back(cast<IntegerAttr>(cst_op.getValue()).getInt());
} else {
if (failed(verifyDivisibleIndex(tiled_index, tiling[0], dim, op))) {
return failure();
}
tile_indices.push_back(0);
}
}
if (rank == 0) {
op.emitOpError("rank 0 vectors unsupported");
return failure();
}
if (rank == 1) {
TPU_CHECK_OP(tiling.size() == 1, "Expected 1D tiling in 1D loads");
auto tile = tiling.front();
TPU_CHECK_OP(tile % target_shape_[1] == 0,
"Unsupported tiling for 1D load");
CHECK_EQ(tile_indices.size(), 1);
int64_t idx = tile_indices.front();
int64_t offset = idx % kVmemAlignment32;
// TODO(apaszke): We could generate replicated loads for short values.
setLayout(op, in_layout,