-
Notifications
You must be signed in to change notification settings - Fork 2.6k
/
apply_vector_layout.cc
4237 lines (4087 loc) · 180 KB
/
apply_vector_layout.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
#include "jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.h"
#include <algorithm>
#include <array>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <functional>
#include <memory>
#include <optional>
#include <tuple>
#include <utility>
#include <vector>
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/SmallVectorExtras.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/ADT/iterator_range.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/MathExtras.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Traits.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributeInterfaces.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/Region.h"
#include "mlir/IR/TypeRange.h"
#include "mlir/IR/Types.h"
#include "mlir/IR/Value.h"
#include "mlir/IR/ValueRange.h"
#include "mlir/IR/Visitors.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
#include "absl/log/check.h"
#include "absl/log/log.h"
#include "absl/status/status.h"
#include "absl/types/span.h"
#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/include/mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/include/mlir/IR/Builders.h"
#include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/include/mlir/IR/OperationSupport.h"
#include "jaxlib/mosaic/dialect/tpu/layout.h"
#include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h"
#include "jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.h"
#include "jaxlib/mosaic/dialect/tpu/util.h"
#include "xla/array.h"
#include "xla/layout.h"
#include "xla/util.h"
// TODO(tlongeri): Prefer returning failure over CHECKs. In particular, be more
// consistent about this for layout null checks in rules.
namespace mlir::tpu {
// TODO(tlongeri): Maybe just roll our own multi-dimensional array instead of
// using XLA's? There's too much glue for going from/to ArrayRef.
#define GEN_PASS_DECL_APPLYVECTORLAYOUTPASS
#define GEN_PASS_DEF_APPLYVECTORLAYOUTPASS
#include "jaxlib/mosaic/dialect/tpu/tpu_passes.h.inc"
LogicalResult applyLayoutBlock(RewriteContext &ctx, Block &block);
namespace {
// Moves the body of each region of `src` into the region at the same position
// in `dst`. llvm::zip_equal requires both ops to have the same number of
// regions; Region::takeBody leaves each source region without a body.
void moveAllRegions(Operation &src, Operation &dst) {
  auto from_regions = src.getRegions();
  auto to_regions = dst.getRegions();
  for (auto [from, to] : llvm::zip_equal(from_regions, to_regions)) {
    to.takeBody(from);
  }
}
// Replaces every element of a single vreg-shaped vector that lies outside of
// `bounds` with `neutral`.
//
// Arguments:
//   value: A rank 2 MLIR vector whose shape must equal ctx.target_shape
//     (CHECKed).
//   bounds: The rectangular subregion of `value` that should be preserved.
//   neutral: A scalar attribute inserted for all entries outside of `bounds`.
//
// Returns:
//   `value` unchanged when `bounds` already covers the whole vreg. Otherwise
//   the result of an arith.select between `value` and a splat constant of
//   `neutral`, driven by the mask from bounds.getVectorMask. Fails if the mask
//   cannot be built or its elements are not 1 bit wide.
FailureOr<Value> maskOOB(RewriteContext &ctx, OpBuilder &builder,
                         TypedValue<VectorType> value,
                         const VRegDataBounds &bounds,
                         const TypedAttr neutral) {
  CHECK(llvm::equal(value.getType().getShape(), ctx.target_shape));
  // Nothing to mask when the bounds span the entire vreg.
  if (bounds.isComplete(ctx.target_shape)) {
    return value;
  }
  const Location loc = value.getLoc();
  FAILUREOR_ASSIGN_OR_RETURN(
      TypedValue<VectorType> mask,
      bounds.getVectorMask(builder, loc, ctx.hardware_generation,
                           ctx.target_shape));
  const auto mask_elem_ty = cast<IntegerType>(mask.getType().getElementType());
  if (mask_elem_ty.getWidth() != 1) {
    return emitError(loc, "Not implemented: Unsupported mask bitwidth");
  }
  const VectorType splat_ty =
      VectorType::get(ctx.target_shape, neutral.getType());
  const Value splat = builder.create<arith::ConstantOp>(
      loc, splat_ty, DenseElementsAttr::get(splat_ty, neutral));
  auto select_op = builder.create<arith::SelectOp>(loc, mask, value, splat);
  return select_op.getResult();
}
// Models NumPy's np.repeat along a single axis: each element of `src` appears
// `repeats` consecutive times along `axis` in the result, and all other
// dimensions are unchanged. For example, if `src` is [1, 2], `axis` is 0 and
// `repeats` is 3, the result is [1, 1, 1, 2, 2, 2].
xla::Array<Value> repeat(const xla::Array<Value> &src, const int repeats,
                         const int64_t axis) {
  SmallVector<int64_t> out_shape(toArrayRef(src.dimensions()));
  out_shape[axis] *= repeats;
  xla::Array<Value> out(out_shape);
  src.Each([&](absl::Span<const int64_t> src_idx, const Value elem) {
    SmallVector<int64_t> dst_idx(toArrayRef(src_idx));
    // Copies of src element k occupy [k * repeats, (k + 1) * repeats).
    const int64_t first = src_idx[axis] * repeats;
    for (int r = 0; r < repeats; ++r) {
      dst_idx[axis] = first + r;
      out(dst_idx) = elem;
    }
  });
  return out;
}
// Models NumPy's np.concatenate: joins `arrays` along `axis`.
//
// All arrays must have the same rank and agree in every dimension except
// `axis`, where their sizes may differ. The result's size along `axis` is the
// sum of the inputs' sizes, and the inputs appear in order along that axis.
// An empty input list, an out-of-range axis, or mismatched shapes are CHECK
// failures.
xla::Array<Value> concatenate(const ArrayRef<xla::Array<Value>> arrays,
                              const int64_t axis) {
  CHECK(!arrays.empty());
  const xla::Array<Value> &first = arrays.front();
  SmallVector<int64_t> dims(toArrayRef(first.dimensions()));
  const int64_t rank = dims.size();
  CHECK(0 <= axis && axis < rank);
  for (size_t i = 1; i < arrays.size(); ++i) {
    const xla::Array<Value> &arr = arrays[i];
    CHECK_EQ(arr.dimensions().size(), first.dimensions().size());
    for (int64_t d = 0; d < rank; ++d) {
      // Like np.concatenate, only the concatenation axis may differ in size.
      if (d != axis) {
        CHECK_EQ(arr.dim(d), first.dim(d));
      }
    }
    dims[axis] += arr.dim(axis);
  }
  xla::Array<Value> res(dims);
  // Position along `axis` at which the current input starts in the result.
  int64_t offset = 0;
  for (const xla::Array<Value> &arr : arrays) {
    arr.Each([&](const absl::Span<const int64_t> idx, const Value v) {
      SmallVector<int64_t> res_idx(toArrayRef(idx));
      res_idx[axis] += offset;
      res(res_idx) = v;
    });
    offset += arr.dim(axis);
  }
  return res;
}
// Returns a flat, non-owning view over all elements of `xla_array`, in the
// array's storage order.
//
// The array is taken by const reference so that the view points into the
// caller's array rather than into a parameter copy that dies on return. The
// returned ArrayRef is valid only while `xla_array` is alive and unmodified
// in size.
template <typename T>
ArrayRef<T> XlaArrayToFlatArrayRef(const xla::Array<T> &xla_array) {
  return ArrayRef<T>(xla_array.data(), xla_array.num_elements());
}
// Builds an xla::Array<T> of shape `sizes` and fills it from `vals` via
// xla::Array::SetValues.
template <typename T, typename Range>
xla::Array<T> XlaArrayFromShapeAndValues(ArrayRef<int64_t> sizes, Range vals) {
  // TODO(tlongeri): is there no way to avoid default initialization in the
  // constructor?
  xla::Array<T> result(sizes);
  result.SetValues(vals);
  return result;
}
// Advances `idx` to the next multi-dimensional index inside the box
// [starts, limits), with the last dimension varying fastest. A dimension that
// reaches its limit is reset to its start and carries into the previous one.
// Returns false once every dimension has wrapped, at which point `idx` equals
// `starts` again.
bool incrementSliceIndex(const MutableArrayRef<int64_t> idx,
                         const absl::Span<const int64_t> starts,
                         const absl::Span<const int64_t> limits) {
  const int64_t rank = idx.size();
  CHECK_EQ(rank, starts.size());
  CHECK_EQ(rank, limits.size());
  int64_t dim = rank - 1;
  while (dim >= 0) {
    if (++idx[dim] < limits[dim]) {
      return true;
    }
    idx[dim] = starts[dim];
    --dim;
  }
  return false;
}
// Advances `idx` to the next index of an array whose shape is `limits`, with
// the last dimension varying fastest. Returns false after the last index, at
// which point `idx` has wrapped back to all zeros.
bool incrementIndex(const MutableArrayRef<int64_t> idx,
                    const absl::Span<const int64_t> limits) {
  const int64_t rank = idx.size();
  CHECK_EQ(rank, limits.size());
  for (int64_t dim = rank; dim-- > 0;) {
    if (++idx[dim] < limits[dim]) {
      return true;
    }
    idx[dim] = 0;
  }
  return false;
}
// Returns true iff the slice [starts, limits) has no elements, i.e. some
// dimension has start == limit. Dimensions are checked in order, and each one
// visited must satisfy start <= limit (CHECKed).
bool sliceIsEmpty(const absl::Span<const int64_t> starts,
                  const absl::Span<const int64_t> limits) {
  for (auto [start, limit] : llvm::zip_equal(starts, limits)) {
    CHECK_LE(start, limit);
    if (start == limit) {
      return true;
    }
  }
  return false;
}
// An alternative to xla::Array::UpdateSlice that takes a single value: assigns
// `value` to every element of `arr` inside the slice [starts, limits). An
// empty slice leaves `arr` untouched.
template <typename T>
void updateSlice(xla::Array<T> &arr, const T &value,
                 const absl::Span<const int64_t> starts,
                 const absl::Span<const int64_t> limits) {
  if (!sliceIsEmpty(starts, limits)) {
    SmallVector<int64_t> pos(toArrayRef(starts));
    do {
      arr(pos) = value;
    } while (incrementSliceIndex(pos, starts, limits));
  }
}
// An alternative to xla::Array::UpdateSlice that takes a range of data.
//
// Writes the elements of `data`, in order, into the slice [starts, limits) of
// `arr`, visiting slice positions with the last dimension varying fastest.
// `data` must hold exactly as many elements as the slice; too few or too many
// is a CHECK failure. An empty slice is a no-op and `data` is not inspected.
template <typename T, typename Range>
void updateSliceFromRange(xla::Array<T> &arr, Range data,
                          const absl::Span<const int64_t> starts,
                          const absl::Span<const int64_t> limits) {
  if (sliceIsEmpty(starts, limits)) {
    return;
  }
  SmallVector<int64_t> idx(toArrayRef(starts));
  auto data_it = data.begin();
  const auto data_end = data.end();
  do {
    // If `data` is shorter than the slice, dereferencing the end iterator
    // would be undefined behavior, so fail loudly instead.
    CHECK(data_it != data_end);
    arr(idx) = *data_it;
    ++data_it;
  } while (incrementSliceIndex(idx, starts, limits));
  // `data` must not be longer than the slice either.
  CHECK(data_it == data_end);
}
// Returns a zero-valued attribute of type `ty`. Only float and integer types
// are supported; any other type emits a "Not implemented" error.
FailureOr<TypedAttr> getZeroIntOrFloatAttr(Type ty) {
  if (!isa<FloatType, IntegerType>(ty)) {
    return emitError(UnknownLoc::get(ty.getContext()), "Not implemented: ")
           << ty;
  }
  if (isa<FloatType>(ty)) {
    return TypedAttr(FloatAttr::get(ty, 0));
  }
  return TypedAttr(IntegerAttr::get(ty, 0));
}
// Returns the sign-extended value of `v` if it is produced by an arith.constant
// holding an IntegerAttr. Otherwise fails, emitting "Expected an integer
// constant" at `v`'s location unless `silent` is set.
FailureOr<int64_t> getIntConst(Value v, bool silent = false) {
  auto constant_op = v.getDefiningOp<arith::ConstantOp>();
  const IntegerAttr integer_attr =
      constant_op ? dyn_cast<IntegerAttr>(constant_op.getValue())
                  : IntegerAttr();
  if (integer_attr) {
    return integer_attr.getValue().getSExtValue();
  }
  if (!silent) {
    return emitError(v.getLoc(), "Expected an integer constant");
  }
  return failure();
}
// Applies getIntConst to every value in `vals` and returns the constants in
// order. Fails on the first value that is not an integer constant, emitting an
// error for it unless `silent` is set.
FailureOr<SmallVector<int64_t>> getIntConstsFromOperandRange(
    ValueRange vals, bool silent = false) {
  SmallVector<int64_t> consts;
  consts.reserve(vals.size());
  for (Value val : vals) {
    FAILUREOR_ASSIGN_OR_RETURN(const int64_t c, getIntConst(val, silent));
    consts.push_back(c);
  }
  return consts;
}
// Emits a tpu.memref_slice of `base_ref` that covers a `slice_shape` window
// starting at `indices`. Returns the sliced ref together with the static
// position of that window inside the slice.
//
// Arguments:
//   builder: Builder used to emit the i32 index casts, the zero constant and
//     the slice op.
//   base_ref: The memref being sliced.
//   slice_shape: Shape of the window to access. Its trailing tiling.size()
//     dimensions are rounded up to multiples of `tiling` to form the slice
//     shape.
//   indices: Start of the window in `base_ref` (presumably one per dimension
//     of base_ref; TODO confirm at call sites).
//   tiling: Tile sizes of the trailing tiling.size() dimensions.
//
// Returns:
//   A pair (sliced_ref, indices_within_slice):
//   - Leading (untiled) dimensions are sliced exactly at `indices`, so their
//     in-slice index is 0.
//   - A tiled dimension whose index is an integer arith.constant is not
//     sliced: the slice starts at 0 and the constant becomes the in-slice
//     index.
//   - A tiled dimension with a dynamic index is sliced at that index and gets
//     in-slice index 0. Divisibility by the tile size is NOT checked yet.
//   The function never returns failure() itself; CHECK_LE aborts on misuse.
//
// NOTE(review): for a constant tiled index the slice starts at 0, but its
// size along that dimension is only round_up(slice_shape, tile). Confirm that
// callers guarantee index + slice_shape fits within that size.
FailureOr<std::pair<Value, SmallVector<int64_t>>> sliceRef(
ImplicitLocOpBuilder &builder, TypedValue<MemRefType> base_ref,
ArrayRef<int64_t> slice_shape, ValueRange indices,
ArrayRef<int64_t> tiling) {
IntegerType i32 = builder.getI32Type();
MemRefType ref_ty = base_ref.getType();
// MemRefSliceOp only allows tile-aligned slices, so the tiled dimensions of
// the slice shape are rounded up to whole tiles. Static tiled indices, which
// may be arbitrary (unaligned), are not folded into the slice base. Dynamic
// tiled indices are, under the assumption that they are divisible by the tile
// size.
SmallVector<int64_t> pad_slice_shape(slice_shape);
CHECK_LE(tiling.size(), slice_shape.size());
// Round each of the trailing tiling.size() dimensions up to a multiple of
// the matching tile size (both counted from the end).
for (int i = 1; i <= tiling.size(); ++i) {
auto &dim = *(pad_slice_shape.end() - i);
dim = xla::RoundUpTo(dim, *(tiling.end() - i));
}
// Untiled (leading) dimensions: slice exactly at the requested index, cast to
// i32.
SmallVector<Value> slice_base_indices;
slice_base_indices.reserve(ref_ty.getRank());
for (auto idx : indices.drop_back(tiling.size())) {
slice_base_indices.push_back(builder.create<arith::IndexCastOp>(i32, idx));
}
// i32 zero, created lazily and shared by all constant tiled indices.
Value c0 = nullptr;
// Leading dimensions start at the beginning of the slice.
SmallVector<int64_t> indices_within_slice(indices.size() - tiling.size(), 0);
for (auto tiled_idx : indices.take_back(tiling.size())) {
if (auto cst = getIntConst(tiled_idx, /*silent=*/true); succeeded(cst)) {
// Static index: slice from 0 and remember the offset inside the slice.
indices_within_slice.push_back(*cst);
if (!c0) {
c0 = builder.create<arith::ConstantOp>(i32,
builder.getI32IntegerAttr(0));
}
slice_base_indices.push_back(c0);
} else {
// Dynamic index: use it as the slice base.
indices_within_slice.push_back(0);
// TODO: Check divisibility!
slice_base_indices.push_back(
builder.create<arith::IndexCastOp>(i32, tiled_idx));
}
}
// TODO(apaszke): Allow tile-aligned dynamic slicing on tiled dimensions.
// The result keeps base_ref's element type, layout and memory space.
Value sliced_ref = builder.create<tpu::MemRefSliceOp>(
MemRefType::get(pad_slice_shape, ref_ty.getElementType(),
ref_ty.getLayout(), ref_ty.getMemorySpace()),
base_ref, slice_base_indices);
return std::make_pair(sliced_ref, indices_within_slice);
}
// Returns the first-level tiling of a (packed and tiled) memref value as a
// 2D {rows, columns} tile.
//
// If `value` is produced by an EraseLayoutOp, the layout of that op's operand
// is inspected instead. Supported TiledLayoutAttr shapes:
//   * A 1D first tile whose size is a multiple of target_shape[1] * packing,
//     returned as {1, tile_size}. For 32-bit element types it must be the
//     only tile. For narrower types it must be followed exactly by the tiles
//     (target_shape[1]) and (packing, 1).
//   * A 2D first tile, returned as-is. For 32-bit element types it must be
//     the only tile. For narrower types it must be followed exactly by
//     (packing, 1).
// Here packing = 32 / bitwidth. A non-tiled memref yields "Expected a tiled
// memref". Everything else yields "Not implemented", including element types
// wider than 32 bits (which would otherwise divide by zero) and layouts with
// no tiles.
FailureOr<std::array<int64_t, 2>> getMemRefTiling(
    TypedValue<MemRefType> value, const std::array<int64_t, 2> target_shape) {
  if (auto erase_layout_op =
          dyn_cast_if_present<EraseLayoutOp>(value.getDefiningOp())) {
    value = erase_layout_op.getOperand();
  }
  const MemRefType memref_ty = value.getType();
  const auto mem_layout = dyn_cast<TiledLayoutAttr>(memref_ty.getLayout());
  if (mem_layout == nullptr) {
    return emitError(value.getLoc(), "Expected a tiled memref");
  }
  FAILUREOR_ASSIGN_OR_RETURN(int8_t bitwidth,
                             getTypeBitwidth(memref_ty.getElementType()));
  // A bitwidth above 32 (or one that overflowed int8_t) would make packing 0,
  // and the 1D-tile modulo below would divide by zero. Reject it up front.
  if (bitwidth <= 0 || bitwidth > 32) {
    return emitError(value.getLoc(), "Not implemented");
  }
  const int packing = 32 / bitwidth;
  const ArrayRef<xla::Tile> tiles = mem_layout.getTiles();
  // tiles.front() below requires at least one tile.
  if (tiles.empty()) {
    return emitError(value.getLoc(), "Not implemented");
  }
  const xla::Tile &first_tile = tiles.front();
  if (first_tile.dimensions().size() == 1) {
    const int64_t tile_size = first_tile.dimension(0);
    if (tile_size % (target_shape[1] * packing) != 0) {
      return emitError(value.getLoc(), "Not implemented");
    }
    if (bitwidth == 32) {
      if (tiles.size() > 1) {
        return emitError(value.getLoc(), "Not implemented");
      }
    } else if (bitwidth < 32) {
      if (tiles.drop_front() !=
          ArrayRef<xla::Tile>{xla::Tile({target_shape[1]}),
                              xla::Tile({packing, 1})}) {
        return emitError(value.getLoc(), "Not implemented");
      }
    }
    return std::array<int64_t, 2>{1, tile_size};
  }
  if (first_tile.dimensions().size() == 2) {
    if (bitwidth == 32) {
      if (tiles.size() > 1) {
        return emitError(value.getLoc(), "Not implemented");
      }
      return std::array<int64_t, 2>{first_tile.dimension(0),
                                    first_tile.dimension(1)};
    }
    if (bitwidth < 32) {
      if (tiles.size() != 2 || tiles[1] != xla::Tile({packing, 1})) {
        return emitError(value.getLoc(), "Not implemented");
      }
      return std::array<int64_t, 2>{first_tile.dimension(0),
                                    first_tile.dimension(1)};
    }
  }
  return emitError(value.getLoc(), "Not implemented");
}
// Hoist a vector constant as an additional argument of the function.
//
// Adds a new memref argument to ctx.func that will carry `value`. Its type is
// the result of inferMemref on a memref with the constant's shape and element
// type. The function's metadata is updated to describe the new argument:
//   - The argument is inserted just before the entry block's current last
//     argument, and the function type's inputs are updated to match.
//   - `value` is appended to the "vector_constants" array attribute, which is
//     created if absent.
//   - If "window_params" is present, a new entry is inserted before its last
//     entry. The new entry's "transform_indices" map returns all zeros over
//     the rank of "iteration_bounds", which must then exist (CHECKed).
//
// Only 32-bit element types are supported. Functions carrying a
// "scratch_operands" attribute are rejected.
//
// Returns the newly inserted block argument. Returns failure (with an error
// emitted on the function) if the constant is unsupported or inferMemref
// fails.
//
// NOTE(review): the insertion points assume the entry block has at least one
// argument and "window_params" is non-empty. Otherwise
// `getNumArguments() - 1` wraps around and `end() - 1` is out of range. Why
// the current last argument must stay last is not visible here; confirm with
// callers.
FailureOr<BlockArgument> appendConstant(RewriteContext &ctx,
DenseElementsAttr value) {
MLIRContext *mlir_ctx = ctx.func.getContext();
Block &entry_block = ctx.func.getBody().front();
auto value_ty = cast<VectorType>(value.getType());
if (value_ty.getElementType().getIntOrFloatBitWidth() != 32) {
return ctx.func.emitOpError(
"Not implemented: Only 32-bit constants supported");
}
if (ctx.func->getAttr("scratch_operands")) {
return ctx.func.emitOpError(
"Not implemented: function has scratch_operands");
}
FAILUREOR_ASSIGN_OR_RETURN(
MemRefType arg_type,
inferMemref(
MemRefType::get(value_ty.getShape(), value_ty.getElementType()),
ctx.hardware_generation));
// Insert the new argument in front of the current last block argument.
const BlockArgument argument =
entry_block.insertArgument(entry_block.getNumArguments() - 1, arg_type,
UnknownLoc::get(ctx.getMLIRContext()));
const FunctionType func_ty = ctx.func.getFunctionType();
// Adjust the function type.
// Keep the signature in sync with the block: same position as above.
SmallVector<Type> new_arg_tys(func_ty.getInputs());
new_arg_tys.insert(new_arg_tys.begin() + (new_arg_tys.size() - 1), arg_type);
const auto new_func_ty =
FunctionType::get(mlir_ctx, new_arg_tys, func_ty.getResults());
ctx.func.setFunctionType(new_func_ty);
// Adjust the constants attribute.
if (auto prev_cst = ctx.func->getAttrOfType<ArrayAttr>("vector_constants")) {
SmallVector<Attribute> vector_constants(prev_cst.getValue());
vector_constants.push_back(value);
ctx.func->setAttr("vector_constants",
ArrayAttr::get(ctx.func.getContext(), vector_constants));
} else {
ctx.func->setAttr("vector_constants",
ArrayAttr::get(ctx.func.getContext(), value));
}
// Adjust window params for the extra operand.
if (auto window_params =
ctx.func->getAttrOfType<ArrayAttr>("window_params")) {
const auto iteration_bounds =
ctx.func->getAttrOfType<DenseI64ArrayAttr>("iteration_bounds");
CHECK(iteration_bounds);
const int64_t iteration_rank = iteration_bounds.getSize();
// Map every grid iteration to block index 0 along each dimension.
const SmallVector<AffineExpr> zeros(
iteration_rank, getAffineConstantExpr(0, ctx.func.getContext()));
const auto transform_indices =
AffineMap::get(iteration_rank, 0, zeros, ctx.func.getContext());
const auto new_param = DictionaryAttr::get(
ctx.func.getContext(),
NamedAttribute(
StringAttr::get(ctx.func.getContext(), "transform_indices"),
AffineMapAttr::get(transform_indices)));
// Same position as the new argument: before the last entry.
SmallVector<Attribute> window_params_values(window_params.getValue());
window_params_values.insert(window_params_values.end() - 1, new_param);
ctx.func->setAttr("window_params", ArrayAttr::get(ctx.func.getContext(),
window_params_values));
}
return argument;
}
// Returns the vector type of one native vreg holding elements of `elem_ty`,
// using the bitwidth reported by getTypeBitwidth<true>. A 32-bit type uses the
// 2D `target_shape` directly. Any other bitwidth gets a third, minor dimension
// of size 32 / bitwidth for the packed sub-elements.
FailureOr<VectorType> getNativeVregType(
    Type elem_ty, const std::array<int64_t, 2> target_shape) {
  FAILUREOR_ASSIGN_OR_RETURN(const int8_t bitwidth,
                             getTypeBitwidth<true>(elem_ty));
  if (bitwidth != 32) {
    const int64_t packing = 32 / bitwidth;
    return VectorType::get({target_shape[0], target_shape[1], packing},
                           elem_ty);
  }
  return VectorType::get(target_shape, elem_ty);
}
// Converts an ArrayAttr of VectorLayoutAttr into a list of layouts.
// A null attribute, or one that is not an ArrayAttr, yields an empty list.
// Fails if any element of the array is not a VectorLayoutAttr.
FailureOr<SmallVector<Layout>> getLayoutArrayFromAttr(const Attribute attr) {
  const auto array_attr = dyn_cast_if_present<ArrayAttr>(attr);
  if (!array_attr) {
    return SmallVector<Layout>{};
  }
  SmallVector<Layout> layouts;
  layouts.reserve(array_attr.size());
  for (const Attribute elem : array_attr) {
    const auto layout_attr = dyn_cast_if_present<VectorLayoutAttr>(elem);
    if (!layout_attr) {
      return failure();
    }
    layouts.push_back(layout_attr.getLayout());
  }
  return layouts;
}
// Returns the layouts in `op`'s "out_layout" attribute, one per result.
// Fails if the attribute is malformed or its length differs from the number of
// results.
// TODO(tlongeri): Unify with infer_vector_layout.cc's getOutLayout.
FailureOr<SmallVector<Layout>> getOutLayout(Operation &op) {
  FAILUREOR_ASSIGN_OR_RETURN(const SmallVector<Layout> out_layout,
                             getLayoutArrayFromAttr(op.getAttr("out_layout")));
  const bool count_matches = out_layout.size() == op.getNumResults();
  if (!count_matches) {
    return failure();
  }
  return out_layout;
}
// Returns the layouts in `op`'s "in_layout" attribute, one per operand.
// Fails if the attribute is malformed or its length differs from the number of
// operands.
FailureOr<SmallVector<Layout>> getInLayout(Operation &op) {
  const Attribute in_layout_attr = op.getAttr("in_layout");
  FAILUREOR_ASSIGN_OR_RETURN(const SmallVector<Layout> in_layout,
                             getLayoutArrayFromAttr(in_layout_attr));
  if (in_layout.size() == op.getNumOperands()) {
    return in_layout;
  }
  return failure();
}
// Rewrites an elementwise-mappable op with a single vector result into one
// copy of the op per output vreg.
//
// Requirements: exactly one result, a non-null layout for every operand and
// for the result, and every operand layout generalizes the result layout for
// the result shape. Operands are disassembled into vreg arrays. Those arrays
// are broadcast against each other, so an operand dimension of size 1 is
// reused for every index (this handles replicated dimensions). The broadcast
// shape must equal the result's tile-array shape. Each per-vreg op is created
// generically from the original op's name and attributes (minus "in_layout"
// and "out_layout"), with the native vreg type of the result element type.
// The vregs are then assembled into the result, all uses of the original op
// are redirected to it, and the original op is erased.
LogicalResult elementwise_op_rule(RewriteContext &ctx, Operation &op,
const ArrayRef<Layout> layouts_in,
const ArrayRef<Layout> layouts_out) {
CHECK(OpTrait::hasElementwiseMappableTraits(&op));
if (op.getNumResults() != 1) {
return op.emitError("Not implemented: Only ops with one result supported");
}
CHECK_EQ(layouts_in.size(), op.getNumOperands());
CHECK_GT(layouts_in.size(), 0);
CHECK_EQ(layouts_out.size(), 1);
OpBuilder builder(&op);
if (!(layouts_out.front().has_value() &&
llvm::all_of(layouts_in,
[&](const Layout &l) { return l.has_value(); }))) {
return op.emitOpError(
"Not implemented: Null layout / non-vector operand in elementwise "
"operation");
}
const auto out_ty = cast<VectorType>(op.getResult(0).getType());
const VectorLayout &layout_out = *layouts_out.front();
// Each operand must be usable as-is under the output layout.
if (!llvm::all_of(layouts_in, [&](const Layout &l) {
return l->generalizes(layout_out, out_ty.getShape(), ctx.target_shape);
})) {
return op.emitOpError(
"Not implemented: Incompatible layouts in elementwise operation");
}
const unsigned num_operands = op.getNumOperands();
SmallVector<xla::Array<Value>> in_vreg_arrays;
in_vreg_arrays.reserve(num_operands);
for (unsigned i = 0; i < num_operands; ++i) {
FAILUREOR_ASSIGN_OR_RETURN(xla::Array<Value> tile_array,
disassemble(builder, *layouts_in[i],
op.getOperand(i), ctx.target_shape));
in_vreg_arrays.emplace_back(std::move(tile_array));
}
FAILUREOR_ASSIGN_OR_RETURN(
const VectorType out_vreg_ty,
getNativeVregType(out_ty.getElementType(), ctx.target_shape));
// Forward the op's attributes to each per-vreg op, minus the layout
// annotations.
NamedAttrList attributes(op.getAttrDictionary());
attributes.erase("in_layout");
attributes.erase("out_layout");
// Note that we have to broadcast to handle replicate dimensions.
SmallVector<int64_t> broadcasted_shape(
toArrayRef(in_vreg_arrays[0].dimensions()));
for (size_t i = 1; i < num_operands; ++i) {
SmallVector<int64_t> new_broadcasted_shape;
CHECK(OpTrait::util::getBroadcastedShape(
broadcasted_shape, toArrayRef(in_vreg_arrays[i].dimensions()),
new_broadcasted_shape));
broadcasted_shape = std::move(new_broadcasted_shape);
}
CHECK(broadcasted_shape ==
layout_out.tileArrayShape(out_ty.getShape(), ctx.target_shape));
// TODO(tlongeri): Can we avoid initializing the array before filling values?
xla::Array<Value> out_vreg_array(broadcasted_shape);
out_vreg_array.Each([&](absl::Span<const int64_t> idx, Value *out_vreg) {
SmallVector<Value> operands(num_operands);
for (unsigned i = 0; i < num_operands; ++i) {
// Handle indices for broadcasted dimensions
SmallVector<int64_t> operand_idx(toArrayRef(idx));
for (unsigned j = 0; j < idx.size(); ++j) {
if (in_vreg_arrays[i].dim(j) == 1) {
operand_idx[j] = 0;
}
}
operands[i] = in_vreg_arrays[i](operand_idx);
}
// Clone the op generically (by name) on the selected vregs.
Operation *vreg_op =
builder.create(op.getLoc(), op.getName().getIdentifier(), operands,
out_vreg_ty, attributes.getAttrs());
CHECK(vreg_op);
CHECK_EQ(vreg_op->getNumResults(), 1);
*out_vreg = vreg_op->getResult(0);
});
op.replaceAllUsesWith(assemble(builder, out_ty, layout_out,
std::move(out_vreg_array), ctx.target_shape));
op.erase();
return success();
}
using rule_type = std::function<LogicalResult(
RewriteContext &, Operation &, ArrayRef<Layout>, ArrayRef<Layout>)>;
// Shared lowering for vector extension casts (used here by arith.extf and
// arith.extsi) to a 32-bit result. Each output vreg is produced by an
// UnpackSubelementsOp that extracts one packed part of an input vreg.
//
// Requirements: a 32-bit output layout, and identical implicit dims and
// offsets for input and output.
//   kNone: identical tiling, with tiling[0] dividing target_shape[0] and
//     tiling[1] == target_shape[1]. The output vreg at last-dimension index j
//     reads input vreg j / packing and unpacks part j % packing (packing is
//     the input layout's).
//   kSecondMinor: input and output must each be a single vreg whose data fits
//     in one tile. The output is part 0 of the input vreg.
//   kMinor: unsupported.
// On success all uses of `op` are replaced by the assembled result and `op`
// is erased.
template <typename OpTy>
LogicalResult ext_op_rule_impl(RewriteContext &ctx, OpTy op,
const VectorLayout &layout_in,
const VectorLayout &layout_out) {
ImplicitLocOpBuilder builder(op.getLoc(), op.getOperation());
auto result_ty = cast<VectorType>(op.getResult().getType());
auto source_ty = cast<VectorType>(op.getIn().getType());
if (layout_out.bitwidth() != 32) {
return op.emitOpError(
"Not implemented: Only extensions to 32-bit supported");
}
FAILUREOR_ASSIGN_OR_RETURN(
const xla::Array<Value> input_vregs,
disassemble(builder, layout_in, op.getIn(), ctx.target_shape));
xla::Array<Value> output_vregs(
layout_out.tileArrayShape(result_ty.getShape(), ctx.target_shape));
FAILUREOR_ASSIGN_OR_RETURN(
const VectorType res_vreg_ty,
getNativeVregType(result_ty.getElementType(), ctx.target_shape));
if (layout_in.implicit_dim() != layout_out.implicit_dim()) {
return op.emitOpError("Not implemented: Change of layout during the cast");
}
if (layout_in.offsets() != layout_out.offsets()) {
return op.emitOpError("Not implemented: Change of offsets during the cast");
}
switch (layout_in.implicit_dim()) {
case VectorLayout::ImplicitDim::kNone: {
if (layout_in.tiling() != layout_out.tiling()) {
return op.emitOpError(
"Not implemented: Changing tiling during the cast");
}
auto tiling = layout_in.tiling();
if (ctx.target_shape[0] % tiling[0] != 0 ||
ctx.target_shape[1] != tiling[1]) {
return op.emitOpError("Not implemented: tiling not supported");
}
const int packing = layout_in.packing();
// Each input vreg feeds `packing` consecutive output vregs along the last
// vreg-array dimension, one packed part each.
output_vregs.Each([&](absl::Span<const int64_t> idxs, Value *v) {
SmallVector<int64_t> input_vreg_idxs(toArrayRef(idxs));
input_vreg_idxs.back() /= packing;
const int64_t vreg_part = idxs.back() % packing;
*v = builder.create<UnpackSubelementsOp>(
res_vreg_ty, input_vregs(input_vreg_idxs), vreg_part);
});
} break;
case VectorLayout::ImplicitDim::kMinor:
return op.emitOpError(
"Not implemented: Only casts of lane-oriented values supported");
case VectorLayout::ImplicitDim::kSecondMinor: {
// True if the (implicit-shape) data plus its offsets fits in one tile.
// A missing offset is treated as 0.
auto is_one_tile = [](VectorType vty, VectorLayout layout) {
auto implicit_shape = layout.implicitShape(vty.getShape());
auto tiled_shape = ArrayRef<int64_t>(implicit_shape).take_back(2);
return (layout.offsets()[0].value_or(0) + tiled_shape[0] <=
layout.tiling()[0]) &&
(layout.offsets()[1].value_or(0) + tiled_shape[1] <=
layout.tiling()[1]);
};
if (input_vregs.dimensions() != absl::Span<const int64_t>{1} ||
output_vregs.dimensions() != absl::Span<const int64_t>{1} ||
!is_one_tile(source_ty, layout_in) ||
!is_one_tile(result_ty, layout_out)) {
return op.emitOpError("Not implemented");
}
if (layout_in.offsets()[0] >= ctx.target_shape[0]) {
return op.emitOpError("Not implemented");
}
// Single vreg in, single vreg out: unpack part 0.
auto unpack_subelements_op = builder.create<UnpackSubelementsOp>(
res_vreg_ty, *input_vregs.begin(), 0);
output_vregs.Fill(unpack_subelements_op.getResult());
}
}
op.replaceAllUsesWith(assemble(builder, result_ty, layout_out,
std::move(output_vregs), ctx.target_shape)
.getResult());
op.erase();
return success();
}
// Lowers arith.extf on vectors by delegating to ext_op_rule_impl. Only the
// 16-bit to 32-bit case is supported. Expects exactly one input layout and one
// output layout, both non-null.
LogicalResult arith_extf_rule(RewriteContext &ctx, Operation &op,
                              const ArrayRef<Layout> layouts_in,
                              const ArrayRef<Layout> layouts_out) {
  CHECK_EQ(layouts_in.size(), 1);
  CHECK(layouts_in.front().has_value());
  // Like the other cast rules, validate the output layout count before
  // dereferencing layouts_out.front().
  CHECK_EQ(layouts_out.size(), 1);
  CHECK(layouts_out.front().has_value());
  auto extf_op = cast<arith::ExtFOp>(op);
  if (layouts_in.front()->bitwidth() != 16 ||
      layouts_out.front()->bitwidth() != 32) {
    return op.emitOpError(
        "Not implemented: Only 16-bit to 32-bit conversion supported");
  }
  return ext_op_rule_impl(ctx, extf_op, *layouts_in.front(),
                          *layouts_out.front());
}
// Lowers arith.extsi on vectors by delegating to ext_op_rule_impl. Expects
// exactly one input layout and one output layout, both non-null.
LogicalResult arith_extsi_rule(RewriteContext &ctx, Operation &op,
                               const ArrayRef<Layout> layouts_in,
                               const ArrayRef<Layout> layouts_out) {
  CHECK_EQ(layouts_in.size(), 1);
  CHECK(layouts_in.front().has_value());
  CHECK_EQ(layouts_out.size(), 1);
  CHECK(layouts_out.front().has_value());
  const VectorLayout &layout_in = *layouts_in.front();
  const VectorLayout &layout_out = *layouts_out.front();
  return ext_op_rule_impl(ctx, cast<arith::ExtSIOp>(op), layout_in,
                          layout_out);
}
// Shared lowering for vector truncation casts (used here by arith.truncf and
// arith.trunci) from a 32-bit input. Each output vreg is a PackSubelementsOp
// combining `packing` input vregs (packing is the output layout's).
//
// Only supported when both layouts have no implicit dim and the input tiling
// equals ctx.target_shape. Two output tilings are handled:
//   - Output tiling == ctx.target_shape: the packed parts are consecutive
//     input vregs along the last vreg-array dimension. Past the end, the last
//     input vreg along that dimension is repeated.
//   - Output has native tiling (hasNativeTiling): the packed parts are
//     consecutive input vregs along the second-minor vreg-array dimension.
//     Past the end, the previous part is repeated.
// On success all uses of `op` are replaced by the assembled result and `op`
// is erased.
//
// NOTE(review): the 32-bit input check runs only after disassemble(). The
// extension rule checks its bitwidth first; consider doing the same here.
template <typename OpTy>
LogicalResult trunc_op_rule_impl(RewriteContext &ctx, OpTy op,
const VectorLayout &layout_in,
const VectorLayout &layout_out) {
ImplicitLocOpBuilder builder(op.getLoc(), op.getOperation());
auto result_ty = cast<VectorType>(op.getResult().getType());
FAILUREOR_ASSIGN_OR_RETURN(
const xla::Array<Value> input_vregs,
disassemble(builder, layout_in, op.getIn(), ctx.target_shape));
xla::Array<Value> output_vregs(
layout_out.tileArrayShape(result_ty.getShape(), ctx.target_shape));
if (layout_in.bitwidth() != 32) {
return op.emitOpError("Not implemented: Only 32-bit truncation supported");
}
FAILUREOR_ASSIGN_OR_RETURN(
VectorType res_vreg_ty,
getNativeVregType(result_ty.getElementType(), ctx.target_shape));
if (layout_in.implicit_dim() == VectorLayout::ImplicitDim::kNone &&
layout_out.implicit_dim() == VectorLayout::ImplicitDim::kNone) {
// The message below assumes an (8, 128) target shape; the check itself
// compares against ctx.target_shape.
if (layout_in.tiling() != ctx.target_shape) {
return op.emitOpError("Not implemented: Only (8,128) tiling supported");
}
if (layout_out.tiling() == ctx.target_shape) {
const int packing = layout_out.packing();
// Output vreg j along the last dimension packs input vregs
// j * packing, j * packing + 1, ... along that dimension.
output_vregs.Each([&](absl::Span<const int64_t> idxs, Value *v) {
SmallVector<Value> parts;
SmallVector<int64_t> idxs_local(toArrayRef(idxs));
idxs_local.back() *= packing;
for (int64_t i = 0; i < packing; ++i) {
parts.push_back(input_vregs(idxs_local));
// Pack any data lying around if OOB
if (idxs_local.back() < input_vregs.dimensions().back() - 1) {
++idxs_local.back();
}
}
*v = builder.create<PackSubelementsOp>(res_vreg_ty, parts);
});
} else if (layout_out.hasNativeTiling(ctx.target_shape)) {
int packing = layout_out.packing();
// Reused across iterations; cleared after each output vreg.
SmallVector<Value> parts;
parts.reserve(packing);
// Output vreg i along the second-minor dimension packs input vregs
// i * packing, i * packing + 1, ... along that dimension.
output_vregs.Each([&](absl::Span<const int64_t> idxs, Value *v) {
CHECK_GE(idxs.size(), 2);
SmallVector<int64_t> idxs_local(toArrayRef(idxs));
idxs_local[idxs.size() - 2] *= packing;
parts.push_back(input_vregs(idxs_local));
idxs_local[idxs.size() - 2]++;
while (parts.size() < packing) {
if (*(idxs_local.end() - 2) < *(input_vregs.dimensions().end() - 2)) {
parts.push_back(input_vregs(idxs_local));
idxs_local[idxs.size() - 2]++;
} else {
// Once we run out of tiles, we can pick any one we like.
parts.push_back(parts.back());
}
}
*v = builder.create<PackSubelementsOp>(res_vreg_ty, parts);
parts.clear();
});
} else {
return op.emitOpError("Not implemented: unsupported output tiling");
}
op.replaceAllUsesWith(assemble(builder, result_ty, layout_out,
std::move(output_vregs), ctx.target_shape)
.getResult());
op.erase();
return success();
}
// Reached when either layout has an implicit dimension.
// TODO(tlongeri): why wasn't this part of the original code?
return op.emitOpError("Not implemented");
}
// Lowers arith.truncf on vectors by delegating to trunc_op_rule_impl. Only the
// 32-bit to 16-bit case is supported. Expects exactly one input layout and one
// output layout, both non-null.
LogicalResult arith_truncf_rule(RewriteContext &ctx, Operation &op,
                                const ArrayRef<Layout> layouts_in,
                                const ArrayRef<Layout> layouts_out) {
  CHECK_EQ(layouts_in.size(), 1);
  CHECK(layouts_in.front().has_value());
  CHECK_EQ(layouts_out.size(), 1);
  CHECK(layouts_out.front().has_value());
  const VectorLayout &layout_in = *layouts_in.front();
  const VectorLayout &layout_out = *layouts_out.front();
  auto truncf_op = cast<arith::TruncFOp>(op);
  const bool is_32_to_16_bit =
      layout_in.bitwidth() == 32 && layout_out.bitwidth() == 16;
  if (!is_32_to_16_bit) {
    return op.emitOpError(
        "Not implemented: Only 32-bit to 16-bit conversion supported");
  }
  return trunc_op_rule_impl(ctx, truncf_op, layout_in, layout_out);
}
// Lowers arith.trunci on vectors by delegating to trunc_op_rule_impl. Expects
// exactly one input layout and one output layout, both non-null.
LogicalResult arith_trunci_rule(RewriteContext &ctx, Operation &op,
                                const ArrayRef<Layout> layouts_in,
                                const ArrayRef<Layout> layouts_out) {
  CHECK_EQ(layouts_in.size(), 1);
  CHECK(layouts_in.front().has_value());
  CHECK_EQ(layouts_out.size(), 1);
  CHECK(layouts_out.front().has_value());
  const VectorLayout &layout_in = *layouts_in.front();
  const VectorLayout &layout_out = *layouts_out.front();
  return trunc_op_rule_impl(ctx, cast<arith::TruncIOp>(op), layout_in,
                            layout_out);
}
// Leaves func.return untouched, but rejects it when any operand carries a
// non-null layout, i.e. returning vector-typed values is unsupported. The op
// must have no result layouts (CHECKed).
LogicalResult func_return_rule(RewriteContext &ctx, Operation &op,
                               const ArrayRef<Layout> layouts_in,
                               const ArrayRef<Layout> layouts_out) {
  CHECK(layouts_out.empty());
  const bool has_layout_operand = llvm::any_of(
      layouts_in, [](const Layout &layout) { return layout.has_value(); });
  if (has_layout_operand) {
    return op.emitOpError("Vector-typed return values are not supported");
  }
  return success();
}
// Layout rule for scf.for.
//
// Operands 0-2 (lower bound, upper bound, step) must carry no layout. The
// remaining operands are the loop-carried initial values; their layouts must
// equal both the op's result layouts and the layouts of the body terminator's
// (scf.yield) operands.
//
// The body is rewritten first with applyLayoutBlock. If the loop has results:
//   1. every vector-typed loop-carried operand is disassembled into its tiles,
//      and a new scf.for is created whose iter args are those tiles (non-vector
//      operands are passed through unchanged);
//   2. the old body is moved into the new op; at the start of the body the
//      tiles are reassembled (RollVectorsOp) so existing uses of the old block
//      arguments still see the original vector values;
//   3. the new op's results are reassembled the same way, replace the old op's
//      results, and the old op is erased.
//
// Returns failure on any layout mismatch, either as an op error or via the
// nested applyLayoutBlock.
LogicalResult scf_for_rule(RewriteContext &ctx, Operation &op,
                           const ArrayRef<Layout> layouts_in,
                           const ArrayRef<Layout> layouts_out) {
  scf::ForOp for_op = cast<scf::ForOp>(op);
  CHECK_EQ(layouts_in.size(), for_op->getNumOperands());
  CHECK_EQ(layouts_out.size(), for_op->getNumResults());
  // Skip lower bound, upper bound and step: only loop-carried values map to
  // results.
  if (!llvm::equal(layouts_in.drop_front(3), layouts_out)) {
    return op.emitOpError(
        "Expected matched layouts in scf.for's inputs and outputs");
  }
  FAILUREOR_ASSIGN_OR_RETURN(const SmallVector<Layout> yield_in_layouts,
                             getInLayout(*for_op.getBody()->getTerminator()));
  if (!llvm::equal(ArrayRef<Layout>(yield_in_layouts), layouts_out)) {
    return op.emitOpError(
        "Expected matched layouts in scf.yield operands and scf.for's results");
  }
  if (failed(applyLayoutBlock(ctx, *for_op.getBody()))) {
    return failure();
  }
  // Without results there are no loop-carried values to unroll.
  if (op.getNumResults() == 0) {
    return success();
  }
  OpBuilder builder(&op);
  // Loop-carried operands flattened to vreg granularity: each vector operand
  // contributes all of its tiles, each scalar operand contributes itself.
  SmallVector<Value> unrolled_args;
  for (size_t i = 0; i < layouts_in.size(); ++i) {
    const Layout &layout = layouts_in[i];
    auto operand = for_op.getOperand(i);
    if (i < 3) {
      if (layout.has_value()) {
        return op.emitOpError("Expected no layout for bounds and step");
      }
      continue;
    }
    if (auto vty = dyn_cast<VectorType>(operand.getType())) {
      if (!layout.has_value()) {
        return op.emitOpError("Expected layout for vector operand");
      }
      FAILUREOR_ASSIGN_OR_RETURN(
          const xla::Array<Value> tiles,
          disassemble(builder, *layout, operand, ctx.target_shape));
      unrolled_args.append(tiles.begin(), tiles.end());
    } else {
      if (layout.has_value()) {
        return op.emitOpError("Expected no layout for scalar operand");
      }
      unrolled_args.push_back(operand);
    }
  }
  // Create a new scf::ForOp with unrolled args.
  auto new_op = builder.create<scf::ForOp>(
      for_op->getLoc(), for_op.getLowerBound(), for_op.getUpperBound(),
      for_op.getStep(), unrolled_args);
  int num_old_args = for_op.getBody()->getNumArguments();
  // Temporarily append the new op's block-argument types to the old body, so
  // that the old arguments can be rewired to them before the body is moved.
  SmallVector<Location> locs(new_op.getBody()->getNumArguments(),
                             for_op.getLoc());
  for_op.getBody()->addArguments(TypeRange(new_op.getBody()->getArguments()),
                                 locs);
  builder.setInsertionPointToStart(for_op.getBody());
  auto arg_idx = num_old_args;
  // Block also has an induction variable that should have no layout,
  // which conveniently matches the in layouts.
  for (auto [old_arg, layout] : llvm::zip_equal(
           for_op.getBody()->getArguments().take_front(num_old_args),
           layouts_in.drop_front(2))) {
    if (const auto vty = dyn_cast<VectorType>(old_arg.getType())) {
      CHECK(layout.has_value());
      const SmallVector<int64_t> tiles_shape =
          layout->tileArrayShape(vty.getShape(), ctx.target_shape);
      const int64_t num_vectors = ShapedType::getNumElements(tiles_shape);
      xla::Array<Value> tiles(tiles_shape);
      CHECK_LE(arg_idx + num_vectors, for_op.getBody()->getNumArguments());
      tiles.SetValues(llvm::make_range(
          for_op.getBody()->getArguments().begin() + arg_idx,
          for_op.getBody()->getArguments().begin() + arg_idx + num_vectors));
      arg_idx += num_vectors;
      RollVectorsOp rolled_op =
          assemble(builder, vty, *layout, tiles, ctx.target_shape);
      // Redirect every use of the old vector argument to the reassembled
      // value; uses owned by rolled_op itself are excluded.
      old_arg.replaceUsesWithIf(rolled_op, [&](OpOperand &operand) {
        return operand.getOwner() != rolled_op;
      });
    } else {
      CHECK(!layout.has_value());
      old_arg.replaceAllUsesWith(for_op.getBody()->getArgument(arg_idx));
      ++arg_idx;
    }
  }
  // All uses now refer to the appended arguments; drop the originals and move
  // the body into the new op.
  for_op.getBody()->eraseArguments(0, num_old_args);
  new_op.getRegion().takeBody(for_op.getRegion());
  // Roll the results back to the original shapes.
  builder.setInsertionPointAfter(new_op);
  int64_t res_idx = 0;
  SmallVector<Value> rolled_results;
  for (auto [result, layout] :
       llvm::zip_equal(for_op.getResults(), layouts_out)) {
    if (const auto vty = dyn_cast<VectorType>(result.getType())) {
      CHECK(layout.has_value());
      const SmallVector<int64_t> tiles_shape =
          layout->tileArrayShape(vty.getShape(), ctx.target_shape);
      const int64_t num_vectors = ShapedType::getNumElements(tiles_shape);
      xla::Array<Value> tiles(tiles_shape);
      CHECK_LE(res_idx + num_vectors, new_op.getResults().size());
      tiles.SetValues(llvm::make_range(
          new_op.getResults().begin() + res_idx,
          new_op.getResults().begin() + res_idx + num_vectors));
      res_idx += num_vectors;
      RollVectorsOp rolled_op =
          assemble(builder, vty, *layout, tiles, ctx.target_shape);
      rolled_results.push_back(rolled_op);
    } else {
      CHECK(!layout.has_value());
      rolled_results.push_back(new_op.getResult(res_idx));
      ++res_idx;
    }
  }
  for_op.replaceAllUsesWith(rolled_results);
  for_op.erase();
  return success();
}
LogicalResult scf_if_rule(RewriteContext &ctx, Operation &op,
const ArrayRef<Layout> layouts_in,
const ArrayRef<Layout> layouts_out) {
CHECK_EQ(layouts_in.size(), 1);
CHECK(!layouts_in.front().has_value());
ImplicitLocOpBuilder builder(op.getLoc(), &op);
scf::IfOp if_op = cast<scf::IfOp>(op);
FAILUREOR_ASSIGN_OR_RETURN(const SmallVector<Layout> then_yield_in_layouts,
getInLayout(*if_op.thenYield()));
// TODO(tlongeri): ArrayRef<Layout> conversion should not be necessary, fix
// after LLVM adds const qualifiers to ==/!= operators. Also
// applies to else_yield_in_layouts comparison below.
if (!layouts_out.empty() &&
ArrayRef<Layout>(then_yield_in_layouts) != layouts_out) {
return op.emitOpError(
"Not implemented: different layouts in then yield's operands and if's "
"results");
}
if (failed(applyLayoutBlock(ctx, *if_op.thenBlock()))) {
return failure();
}
if (if_op.getElseRegion().empty()) {
CHECK_EQ(if_op->getNumResults(), 0)
<< "Expected no results if op does not have an else block";
CHECK_EQ(layouts_out.size(), 0);
return success();
}
FAILUREOR_ASSIGN_OR_RETURN(const SmallVector<Layout> else_yield_in_layouts,
getInLayout(*if_op.elseYield()));
if (!layouts_out.empty() &&
ArrayRef<Layout>(else_yield_in_layouts) != layouts_out) {
return op.emitOpError(
"Not implemented: different layouts in else yield's operands and if's "
"results");
}
if (failed(applyLayoutBlock(ctx, *if_op.elseBlock()))) {
return failure();
}
// Apply layout to results after applying layout in the true and false
// regions.
if (if_op.getNumResults() == 0) {
CHECK_EQ(layouts_out.size(), 0);
return success();
}
CHECK_EQ(if_op.getNumResults(), layouts_out.size());
// If scf.if has results, it should have both non-empty true and false
// regions.
CHECK(!if_op.getThenRegion().empty() && !if_op.getElseRegion().empty());
// Move true and false regions to the new if op whose result has same type and
// layout as yield operand's.
auto new_op = builder.create<scf::IfOp>(
TypeRange(if_op.thenYield().getResults()), if_op.getCondition(),
/*withElseRegion =*/true);
moveAllRegions(*if_op, *new_op);
int64_t index = 0;
SmallVector<Value> rolled_results;
for (auto [result, layout] :
llvm::zip_equal(if_op.getResults(), layouts_out)) {
if (const auto vty = dyn_cast<VectorType>(result.getType())) {
// When the result has a vector type, assemble the result.
CHECK(layout.has_value());
const SmallVector<int64_t> tiles_shape =
layout->tileArrayShape(vty.getShape(), ctx.target_shape);
const int64_t num_vectors = ShapedType::getNumElements(tiles_shape);
xla::Array<Value> tiles(tiles_shape);