-
Notifications
You must be signed in to change notification settings - Fork 2.6k
/
apply_vector_layout.cc
2073 lines (1996 loc) · 87.3 KB
/
apply_vector_layout.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
#include <algorithm>
#include <array>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <functional>
#include <memory>
#include <optional>
#include <tuple>
#include <utility>
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/SmallVectorExtras.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/Support/ErrorHandling.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/Traits.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributeInterfaces.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/Region.h"
#include "mlir/IR/Types.h"
#include "mlir/IR/Value.h"
#include "mlir/IR/ValueRange.h"
#include "mlir/IR/Visitors.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
#include "absl/log/check.h"
#include "absl/status/status.h"
#include "absl/types/span.h"
#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/include/mlir/Dialect/Vector/IR/VectorOps.h"
#include "jaxlib/mosaic/dialect/tpu/layout.h"
#include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h"
#include "jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.h"
#include "jaxlib/mosaic/dialect/tpu/util.h"
#include "xla/array.h"
#include "xla/layout.h"
#include "xla/util.h"
// TODO(tlongeri): Prefer returning failure over CHECKs. In particular, be more
// consistent about this for layout null checks in rules.
// Emits a "not implemented: <msg>" op error on the `op` pointer that is in
// scope at the expansion site, then returns failure() from the enclosing
// function. `msg` must be a string literal (it is concatenated at compile
// time).
// NOTE(review): this expands to two statements, so it is not safe as the body
// of an unbraced `if`/`else`; only use it inside braces.
#define NYI(msg) \
op->emitOpError("not implemented: " msg); \
return failure();
namespace mlir::tpu {
// TODO(tlongeri): Maybe just roll our own multi-dimensional array instead of
// using XLA's? There's too much glue for going from/to ArrayRef.
#define GEN_PASS_DECL_APPLYVECTORLAYOUTPASS
#define GEN_PASS_DEF_APPLYVECTORLAYOUTPASS
#include "jaxlib/mosaic/dialect/tpu/tpu_passes.h.inc"
// State threaded through every rewrite rule in this file.
struct RewriteContext {
// Function being rewritten; appendConstant() also edits its signature and
// attributes.
func::FuncOp func;
// Builder the rules use to create replacement ops.
OpBuilder &builder;
// TODO(tlongeri): target_shape should be determined from hardware_generation
// Forwarded to inferMemref() when hoisting constants.
const int hardware_generation;
// Shape used for native 32-bit vreg types (see getNativeVregType) and as the
// reference tiling in the rules.
const std::array<int64_t, 2> target_shape;
};
// Forward declarations; the definitions live outside the anonymous namespace
// below.
// Applies the layout rewrite to `block` (used by tpu_trace_rule to recurse into
// a region body).
LogicalResult applyLayoutBlock(RewriteContext &ctx, Block &block);
// Builds a RollVectorsOp of type `vty` from an array of values `vals` laid out
// according to `layout`.
// NOTE(review): presumably the inverse of disassemble() - confirm against the
// definition.
RollVectorsOp assemble(RewriteContext &ctx, VectorType vty,
const VectorLayout &layout, xla::Array<Value> vals);
// Splits `val`, interpreted with `layout`, into an array of values; may fail.
FailureOr<xla::Array<Value>> disassemble(RewriteContext &ctx,
const VectorLayout &layout, Value val);
namespace {
// Equivalent of NumPy's np.repeat: every element of `src` is duplicated
// `repeats` times, back to back, along `axis`. E.g. src = [1, 2], axis = 0,
// repeats = 3 yields [1, 1, 1, 2, 2, 2].
xla::Array<Value> repeat(const xla::Array<Value> &src, const int repeats,
                         const int64_t axis) {
  SmallVector<int64_t> out_dims(toArrayRef(src.dimensions()));
  out_dims[axis] *= repeats;
  xla::Array<Value> out(out_dims);
  src.Each([&](absl::Span<const int64_t> src_idx, const Value elem) {
    SmallVector<int64_t> dst_idx(toArrayRef(src_idx));
    // First destination slot (along `axis`) owned by this source element.
    const int64_t first = dst_idx[axis] * repeats;
    for (int k = 0; k < repeats; ++k) {
      dst_idx[axis] = first + k;
      out(dst_idx) = elem;
    }
  });
  return out;
}
// Returns a zero-valued attribute of type `ty` for integer and float types;
// any other type produces a "Not implemented" error at an unknown location.
FailureOr<TypedAttr> getZeroIntOrFloatAttr(Type ty) {
  const bool is_int = isa<IntegerType>(ty);
  const bool is_float = isa<FloatType>(ty);
  if (!is_int && !is_float) {
    return emitError(UnknownLoc::get(ty.getContext()), "Not implemented: ")
           << ty;
  }
  if (is_float) {
    return TypedAttr(FloatAttr::get(ty, 0));
  }
  return TypedAttr(IntegerAttr::get(ty, 0));
}
// Extracts the sign-extended integer held by `v` when `v` is produced by an
// arith.constant with an integer attribute; otherwise emits an error at `v`'s
// location.
FailureOr<int64_t> getIntConst(Value v) {
  auto constant_op = v.getDefiningOp<arith::ConstantOp>();
  if (!constant_op) {
    return emitError(v.getLoc(), "Expected an integer constant");
  }
  auto integer_attr = dyn_cast<IntegerAttr>(constant_op.getValue());
  if (!integer_attr) {
    return emitError(v.getLoc(), "Expected an integer constant");
  }
  return integer_attr.getValue().getSExtValue();
}
// Applies getIntConst to every operand in `vals`, in order. Fails as soon as
// one operand is not an integer constant.
FailureOr<SmallVector<int64_t>> getIntConstsFromOperandRange(
    OperandRange vals) {
  SmallVector<int64_t> consts;
  consts.reserve(vals.size());
  for (Value operand : vals) {
    FAILUREOR_ASSIGN_OR_RETURN(const int64_t cst, getIntConst(operand));
    consts.push_back(cst);
  }
  return consts;
}
// Returns the first-level tiling of a (packed and tiled) memref value.
//
// If `value` is produced by an EraseLayoutOp, the op's operand is inspected
// instead. The memref must carry a TiledLayoutAttr. A 1D first tile of size N
// is reported as {1, N}; a 2D first tile is reported as-is. Every tile
// configuration that is not explicitly recognized yields a "Not implemented"
// error.
//
// Element types wider than 32 bits (or with a non-positive bitwidth) and
// layouts without any tile are rejected up front: `32 / bitwidth` would
// otherwise be zero (or divide by zero) and feed a modulo-by-zero below, and
// `tiles.front()` would read past an empty array.
FailureOr<std::array<int64_t, 2>> getMemRefTiling(
    TypedValue<MemRefType> value, const std::array<int64_t, 2> target_shape) {
  if (auto erase_layout_op =
          dyn_cast_if_present<EraseLayoutOp>(value.getDefiningOp())) {
    value = erase_layout_op.getOperand();
  }
  const MemRefType memref_ty = value.getType();
  const auto mem_layout = dyn_cast<TiledLayoutAttr>(memref_ty.getLayout());
  if (mem_layout == nullptr) {
    return emitError(value.getLoc(), "Expected a tiled memref");
  }
  FAILUREOR_ASSIGN_OR_RETURN(int8_t bitwidth,
                             getTypeBitwidth(memref_ty.getElementType()));
  const ArrayRef<xla::Tile> tiles = mem_layout.getTiles();
  if (bitwidth <= 0 || bitwidth > 32 || tiles.empty()) {
    return emitError(value.getLoc(), "Not implemented");
  }
  // Number of elements packed into one 32-bit word; >= 1 thanks to the guard.
  const int packing = 32 / bitwidth;
  const xla::Tile &first_tile = tiles.front();
  const auto first_tile_rank = first_tile.dimensions().size();
  if (first_tile_rank == 1) {
    const int64_t tile_size = first_tile.dimension(0);
    if (tile_size % (target_shape[1] * packing) != 0) {
      return emitError(value.getLoc(), "Not implemented");
    }
    if (bitwidth == 32) {
      // 32-bit data must not have any further tiling level.
      if (tiles.size() > 1) {
        return emitError(value.getLoc(), "Not implemented");
      }
    } else {
      // Packed data must be followed by exactly (lanes) and (packing, 1).
      if (tiles.drop_front() !=
          ArrayRef<xla::Tile>{xla::Tile({target_shape[1]}),
                              xla::Tile({packing, 1})}) {
        return emitError(value.getLoc(), "Not implemented");
      }
    }
    return std::array<int64_t, 2>{1, tile_size};
  }
  if (first_tile_rank == 2) {
    if (bitwidth == 32) {
      if (tiles.size() > 1) {
        return emitError(value.getLoc(), "Not implemented");
      }
    } else if (tiles.size() != 2 || tiles[1] != xla::Tile({packing, 1})) {
      // Packed data must be followed by exactly one (packing, 1) tile.
      return emitError(value.getLoc(), "Not implemented");
    }
    return std::array<int64_t, 2>{first_tile.dimension(0),
                                  first_tile.dimension(1)};
  }
  return emitError(value.getLoc(), "Not implemented");
}
// Hoist a vector constant as an additional argument of the function.
//
// The new argument is inserted immediately before the last entry-block
// argument, and the function type is updated to match. `value` is appended to
// the function's "vector_constants" array attribute, which is created if it
// does not exist yet so that the first hoisted constant is recorded too. If
// the function has "window_params", a parameter with an all-zero
// "transform_indices" map is appended for the new operand.
//
// Fails if the constant is not 32-bit, if the function has "scratch_operands",
// if the function has no arguments to insert before, or if inferMemref fails.
// Returns the newly created block argument.
//
// NOTE(review): the new window param is appended at the end while the argument
// is inserted before the last one - confirm that this ordering is intended.
FailureOr<BlockArgument> appendConstant(RewriteContext &ctx,
                                        DenseElementsAttr value) {
  MLIRContext *mlir_ctx = ctx.func.getContext();
  Block &entry_block = ctx.func.getBody().front();
  auto value_ty = cast<VectorType>(value.getType());
  if (value_ty.getElementType().getIntOrFloatBitWidth() != 32) {
    return ctx.func.emitOpError("Only 32-bit constants supported");
  }
  if (ctx.func->getAttr("scratch_operands")) {
    return ctx.func.emitOpError(
        "Not implemented: function has scratch_operands");
  }
  // The insertion index below is `num_arguments - 1`; guard the underflow.
  if (entry_block.getNumArguments() == 0) {
    return ctx.func.emitOpError("Not implemented: function has no arguments");
  }
  FAILUREOR_ASSIGN_OR_RETURN(
      MemRefType arg_type,
      inferMemref(
          MemRefType::get(value_ty.getShape(), value_ty.getElementType()),
          ctx.hardware_generation));
  const BlockArgument argument = entry_block.insertArgument(
      entry_block.getNumArguments() - 1, arg_type, ctx.builder.getUnknownLoc());
  const FunctionType func_ty = ctx.func.getFunctionType();
  // Adjust the function type.
  SmallVector<Type> new_arg_tys(func_ty.getInputs());
  new_arg_tys.insert(new_arg_tys.begin() + (new_arg_tys.size() - 1), arg_type);
  const auto new_func_ty =
      FunctionType::get(mlir_ctx, new_arg_tys, func_ty.getResults());
  ctx.func.setFunctionType(new_func_ty);
  // Adjust the constants attribute, creating it for the first constant.
  SmallVector<Attribute> vector_constants;
  if (auto prev_cst = ctx.func->getAttrOfType<ArrayAttr>("vector_constants")) {
    const ArrayRef<Attribute> prev_values = prev_cst.getValue();
    vector_constants.append(prev_values.begin(), prev_values.end());
  }
  vector_constants.push_back(value);
  ctx.func->setAttr("vector_constants",
                    ArrayAttr::get(mlir_ctx, vector_constants));
  // Adjust window params for the extra operand.
  if (auto window_params =
          ctx.func->getAttrOfType<ArrayAttr>("window_params")) {
    const auto iteration_bounds =
        ctx.func->getAttrOfType<DenseI64ArrayAttr>("iteration_bounds");
    CHECK(iteration_bounds);
    const int64_t iteration_rank = iteration_bounds.getSize();
    // The constant is addressed at index 0 along every iteration dimension.
    const SmallVector<AffineExpr> zeros(iteration_rank,
                                        getAffineConstantExpr(0, mlir_ctx));
    const auto transform_indices =
        AffineMap::get(iteration_rank, 0, zeros, mlir_ctx);
    const auto new_param = DictionaryAttr::get(
        mlir_ctx,
        NamedAttribute(StringAttr::get(mlir_ctx, "transform_indices"),
                       AffineMapAttr::get(transform_indices)));
    SmallVector<Attribute> window_params_values(window_params.getValue());
    window_params_values.push_back(new_param);
    ctx.func->setAttr("window_params",
                      ArrayAttr::get(mlir_ctx, window_params_values));
  }
  return argument;
}
// Builds the vector type used for a single vreg holding `elem_ty` elements.
// 32-bit elements use `target_shape` directly; any other bitwidth gets a third,
// minor dimension of size 32 / bitwidth. Fails if the bitwidth of `elem_ty`
// cannot be determined.
FailureOr<VectorType> getNativeVregType(
    Type elem_ty, const std::array<int64_t, 2> target_shape) {
  FAILUREOR_ASSIGN_OR_RETURN(const int8_t bitwidth,
                             getTypeBitwidth<true>(elem_ty));
  if (bitwidth != 32) {
    const int64_t packing = 32 / bitwidth;
    return VectorType::get({target_shape[0], target_shape[1], packing},
                           elem_ty);
  }
  return VectorType::get(target_shape, elem_ty);
}
// Generic rewrite for single-result elementwise ops.
//
// Every operand is disassembled with its input layout, the resulting arrays
// are broadcast against each other, and `factory` is invoked once per position
// of the broadcast array with the operand values for that position. The
// single result of each op returned by `factory` becomes the output value at
// that position. Finally `op` is replaced by the assembled result and erased.
//
// Emits an op error if any layout is null or if some input layout does not
// generalize the output layout. Returns failure if disassembly or `factory`
// fails.
LogicalResult elementwise_op_rule(
RewriteContext &ctx, Operation &op, const ArrayRef<Layout> layouts_in,
const ArrayRef<Layout> layouts_out,
std::function<FailureOr<Operation *>(RewriteContext &, ArrayRef<Value>)>
factory) {
CHECK_EQ(layouts_in.size(), op.getNumOperands());
CHECK_GT(layouts_in.size(), 0);
CHECK_EQ(layouts_out.size(), 1);
// All layouts (inputs and the single output) must be present.
if (!(layouts_out.front().has_value() &&
llvm::all_of(layouts_in,
[&](const Layout &l) { return l.has_value(); }))) {
return op.emitOpError("null layout in elementwise operation");
}
const auto vty = cast<VectorType>(op.getResult(0).getType());
const VectorLayout &layout_out = *layouts_out.front();
// Every input layout has to generalize the output layout for this shape.
if (!llvm::all_of(layouts_in, [&](const Layout &l) {
return l->generalizes(layout_out, vty.getShape(), ctx.target_shape);
})) {
return op.emitOpError("incompatible layouts in elementwise operation");
}
const unsigned num_operands = op.getNumOperands();
SmallVector<xla::Array<Value>> in_tile_arrays;
in_tile_arrays.reserve(num_operands);
for (unsigned i = 0; i < num_operands; ++i) {
FAILUREOR_ASSIGN_OR_RETURN(
xla::Array<Value> tile_array,
disassemble(ctx, *layouts_in[i], op.getOperand(i)));
in_tile_arrays.emplace_back(std::move(tile_array));
}
// Note that we have to broadcast to handle replicate dimensions.
SmallVector<int64_t> broadcasted_shape(
toArrayRef(in_tile_arrays[0].dimensions()));
for (size_t i = 1; i < num_operands; ++i) {
SmallVector<int64_t> new_broadcasted_shape;
CHECK(OpTrait::util::getBroadcastedShape(
broadcasted_shape, toArrayRef(in_tile_arrays[i].dimensions()),
new_broadcasted_shape));
broadcasted_shape = std::move(new_broadcasted_shape);
}
// TODO(tlongeri): Can we avoid initializing the array before filling values?
xla::Array<Value> out_tile_array(broadcasted_shape);
absl::Status status =
out_tile_array.EachStatus([&](absl::Span<const int64_t> idx, Value *v) {
SmallVector<Value> operands(num_operands);
for (unsigned i = 0; i < num_operands; ++i) {
// Handle indices for broadcasted dimensions
// (an operand dimension of size 1 is always read at index 0).
SmallVector<int64_t> operand_idx(toArrayRef(idx));
for (unsigned j = 0; j < idx.size(); ++j) {
if (in_tile_arrays[i].dim(j) == 1) {
operand_idx[j] = 0;
}
}
operands[i] = in_tile_arrays[i](operand_idx);
}
FailureOr<Operation *> failure_or_tile_op = factory(ctx, operands);
// A factory failure is tunnelled out of EachStatus as a non-OK status.
if (failed(failure_or_tile_op)) {
return absl::InvalidArgumentError("");
}
Operation *tile_op = *failure_or_tile_op;
CHECK(tile_op);
CHECK_EQ(tile_op->getNumResults(), 1);
*v = tile_op->getResult(0);
return absl::OkStatus();
});
if (!status.ok()) {
return failure();
}
op.replaceAllUsesWith(
assemble(ctx, vty, layout_out, std::move(out_tile_array)));
op.erase();
return success();
}
// Alias that discards its index argument. Expanding it over an index pack
// produces a pack of identical types `T`.
template <typename T, std::size_t>
using Wrapper = T;
// Adapts a factory taking one `Value` per operand to the ArrayRef-based
// factory expected by elementwise_op_rule. The adapter fails when the number
// of operands it receives differs from the size of the index pack.
template <std::size_t... Idx>
LogicalResult elementwise_op_rule_unpacked_impl(
    RewriteContext &ctx, Operation &op, const ArrayRef<Layout> layout_in,
    const ArrayRef<Layout> layout_out,
    std::function<FailureOr<Operation *>(RewriteContext &ctx,
                                         Wrapper<Value, Idx>...)>
        factory,
    std::index_sequence<Idx...>) {
  auto packed_factory =
      [&](RewriteContext &rule_ctx,
          ArrayRef<Value> operands) -> FailureOr<Operation *> {
    if (operands.size() == sizeof...(Idx)) {
      return factory(rule_ctx, operands[Idx]...);
    }
    return failure();
  };
  return elementwise_op_rule(ctx, op, layout_in, layout_out, packed_factory);
}
// Variant of elementwise_op_rule whose factory receives each operand as its
// own argument instead of a single ArrayRef.
// Fails when the op does not have exactly NumOperands operands.
template <std::size_t NumOperands, typename Func>
LogicalResult elementwise_op_rule_unpacked(RewriteContext &ctx, Operation &op,
                                           const ArrayRef<Layout> layouts_in,
                                           const ArrayRef<Layout> layouts_out,
                                           Func factory) {
  using OperandIndices = std::make_index_sequence<NumOperands>;
  return elementwise_op_rule_unpacked_impl(ctx, op, layouts_in, layouts_out,
                                           std::move(factory),
                                           OperandIndices{});
}
// Signature shared by the per-op rewrite rules below:
// (context, op to rewrite, operand layouts, result layouts) -> success/failure.
using rule_type = std::function<LogicalResult(
RewriteContext &, Operation &, ArrayRef<Layout>, ArrayRef<Layout>)>;
// Rewrite rule for arith.cmpf: applies the same predicate to each pair of
// operand values via the elementwise rule.
LogicalResult arith_cmpf_rule(RewriteContext &ctx, Operation &op,
                              ArrayRef<Layout> layouts_in,
                              ArrayRef<Layout> layouts_out) {
  auto cmpf_op = cast<arith::CmpFOp>(op);
  const auto loc = cmpf_op.getLoc();
  const auto predicate = cmpf_op.getPredicateAttr();
  auto make_vreg_cmp = [&](RewriteContext &rule_ctx, const Value lhs,
                           const Value rhs) -> FailureOr<Operation *> {
    auto vreg_cmp =
        rule_ctx.builder.create<arith::CmpFOp>(loc, predicate, lhs, rhs);
    return vreg_cmp.getOperation();
  };
  return elementwise_op_rule_unpacked<2>(ctx, op, layouts_in, layouts_out,
                                         make_vreg_cmp);
}
// Rewrite rule for arith.cmpi: applies the same predicate to each pair of
// operand values via the elementwise rule.
LogicalResult arith_cmpi_rule(RewriteContext &ctx, Operation &op,
                              const ArrayRef<Layout> layouts_in,
                              const ArrayRef<Layout> layouts_out) {
  auto cmpi_op = cast<arith::CmpIOp>(op);
  const auto loc = cmpi_op.getLoc();
  const auto predicate = cmpi_op.getPredicateAttr();
  auto make_vreg_cmp = [&](RewriteContext &rule_ctx, const Value lhs,
                           const Value rhs) -> FailureOr<Operation *> {
    auto vreg_cmp =
        rule_ctx.builder.create<arith::CmpIOp>(loc, predicate, lhs, rhs);
    return vreg_cmp.getOperation();
  };
  return elementwise_op_rule_unpacked<2>(ctx, op, layouts_in, layouts_out,
                                         make_vreg_cmp);
}
// Rewrite rule for arith.extui: extends each operand value to a vector with
// the same shape as that value and the element type of the op's result.
LogicalResult arith_extui_rule(RewriteContext &ctx, Operation &op,
                               const ArrayRef<Layout> layouts_in,
                               const ArrayRef<Layout> layouts_out) {
  auto extui_op = cast<arith::ExtUIOp>(op);
  const auto result_vty = cast<VectorType>(extui_op.getResult().getType());
  const Type elem_ty = result_vty.getElementType();
  const auto loc = extui_op.getLoc();
  auto make_vreg_ext = [&](RewriteContext &rule_ctx,
                           const Value x) -> FailureOr<Operation *> {
    const VectorType in_ty = cast<VectorType>(x.getType());
    const VectorType out_ty = VectorType::get(in_ty.getShape(), elem_ty);
    auto vreg_ext = rule_ctx.builder.create<arith::ExtUIOp>(loc, out_ty, x);
    return vreg_ext.getOperation();
  };
  return elementwise_op_rule_unpacked<1>(ctx, op, layouts_in, layouts_out,
                                         make_vreg_ext);
}
// Shared implementation of the extension rules (used by arith_extf_rule and
// arith_extsi_rule).
//
// Disassembles the input with `layout_in`, creates one UnpackSubelementsOp per
// output value, then replaces `op` with the assembled result and erases it.
// Only extensions to 32 bits with identical input/output implicit dims are
// handled; everything else is reported as an op error.
template <typename OpTy>
LogicalResult ext_op_rule_impl(RewriteContext &ctx, OpTy op,
const VectorLayout &layout_in,
const VectorLayout &layout_out) {
auto result_ty = cast<VectorType>(op.getResult().getType());
if (layout_out.bitwidth() != 32) {
return op.emitOpError("Only extensions to 32-bit supported");
}
FAILUREOR_ASSIGN_OR_RETURN(const xla::Array<Value> input_vregs,
disassemble(ctx, layout_in, op.getIn()));
xla::Array<Value> output_vregs(
layout_out.tileArrayShape(result_ty.getShape(), ctx.target_shape));
FAILUREOR_ASSIGN_OR_RETURN(
const VectorType res_vreg_ty,
getNativeVregType(result_ty.getElementType(), ctx.target_shape));
if (layout_in.implicit_dim() != layout_out.implicit_dim()) {
return op.emitOpError("Not implemented: Change of layout during the cast");
}
switch (layout_in.implicit_dim()) {
case VectorLayout::ImplicitDim::kNone: {
// Both layouts must use the target shape as their tiling.
if (layout_in.tiling() != ctx.target_shape ||
layout_out.tiling() != ctx.target_shape) {
return op.emitOpError("Not implemented: tiling not supported");
}
const int packing = layout_in.packing();
// Output index (..., j) reads input value (..., j / packing) and unpacks
// its subelement part j % packing.
output_vregs.Each([&](absl::Span<const int64_t> idxs, Value *v) {
SmallVector<int64_t> input_vreg_idxs(toArrayRef(idxs));
input_vreg_idxs.back() /= packing;
const int64_t vreg_part = idxs.back() % packing;
*v = ctx.builder.create<UnpackSubelementsOp>(
op.getLoc(), res_vreg_ty, input_vregs(input_vreg_idxs), vreg_part);
});
} break;
case VectorLayout::ImplicitDim::kMinor:
return op.emitOpError(
"Not implemented: Only casts of lane-oriented values supported");
case VectorLayout::ImplicitDim::kSecondMinor: {
// Only the case where both value arrays have dimensions {1} is handled.
if (input_vregs.dimensions() != absl::Span<const int64_t>{1} ||
output_vregs.dimensions() != absl::Span<const int64_t>{1}) {
return op.emitOpError("Not implemented");
}
if (layout_in.offsets()[0] >= ctx.target_shape[0]) {
return op.emitOpError("Not implemented");
}
// Unpack subelement part 0 of the single input value.
auto unpack_subelements_op = ctx.builder.create<UnpackSubelementsOp>(
op.getLoc(), res_vreg_ty, *input_vregs.begin(), 0);
output_vregs.Fill(unpack_subelements_op.getResult());
}
}
op.replaceAllUsesWith(
assemble(ctx, result_ty, layout_out, std::move(output_vregs))
.getResult());
op.erase();
return success();
}
// Rewrite rule for arith.extf.
//
// Expects exactly one non-null input layout and one non-null output layout
// (CHECK-enforced). Only 16-bit to 32-bit extensions are supported; any other
// bitwidth combination is reported as an op error. The rewrite itself is
// delegated to ext_op_rule_impl.
LogicalResult arith_extf_rule(RewriteContext &ctx, Operation &op,
                              const ArrayRef<Layout> layouts_in,
                              const ArrayRef<Layout> layouts_out) {
  CHECK_EQ(layouts_in.size(), 1);
  CHECK(layouts_in.front().has_value());
  CHECK_EQ(layouts_out.size(), 1);
  CHECK(layouts_out.front().has_value());
  auto extf_op = cast<arith::ExtFOp>(op);
  // The source must be 16-bit: requiring a 32-bit input here would reject
  // every extension this rule is meant to handle.
  if (layouts_in.front()->bitwidth() != 16 ||
      layouts_out.front()->bitwidth() != 32) {
    return op.emitOpError("Only 16-bit to 32-bit conversion supported");
  }
  return ext_op_rule_impl(ctx, extf_op, *layouts_in.front(),
                          *layouts_out.front());
}
// Rewrite rule for arith.extsi. Requires exactly one non-null input layout and
// one non-null output layout (CHECK-enforced), then defers to
// ext_op_rule_impl.
LogicalResult arith_extsi_rule(RewriteContext &ctx, Operation &op,
                               const ArrayRef<Layout> layouts_in,
                               const ArrayRef<Layout> layouts_out) {
  CHECK_EQ(layouts_in.size(), 1);
  CHECK(layouts_in.front().has_value());
  CHECK_EQ(layouts_out.size(), 1);
  CHECK(layouts_out.front().has_value());
  const VectorLayout &layout_in = *layouts_in.front();
  const VectorLayout &layout_out = *layouts_out.front();
  return ext_op_rule_impl(ctx, cast<arith::ExtSIOp>(op), layout_in,
                          layout_out);
}
// Shared implementation of the truncation rules (used by arith_truncf_rule and
// arith_trunci_rule).
//
// Disassembles the 32-bit input and creates one PackSubelementsOp per output
// value, then replaces `op` with the assembled result and erases it. Only the
// case where both layouts have no implicit dim and the input is tiled with the
// target shape is handled; everything else is reported as an op error.
template <typename OpTy>
LogicalResult trunc_op_rule_impl(RewriteContext &ctx, OpTy op,
const VectorLayout &layout_in,
const VectorLayout &layout_out) {
auto result_ty = cast<VectorType>(op.getResult().getType());
FAILUREOR_ASSIGN_OR_RETURN(const xla::Array<Value> input_vregs,
disassemble(ctx, layout_in, op.getIn()));
xla::Array<Value> output_vregs(
layout_out.tileArrayShape(result_ty.getShape(), ctx.target_shape));
if (layout_in.bitwidth() != 32) {
return op.emitOpError("Only 32-bit truncation supported");
}
FAILUREOR_ASSIGN_OR_RETURN(
VectorType res_vreg_ty,
getNativeVregType(result_ty.getElementType(), ctx.target_shape));
if (layout_in.implicit_dim() == VectorLayout::ImplicitDim::kNone &&
layout_out.implicit_dim() == VectorLayout::ImplicitDim::kNone) {
if (layout_in.tiling() != ctx.target_shape) {
return op.emitOpError("Not implemented: Only (8,128) tiling supported");
}
if (layout_out.tiling() == ctx.target_shape) {
// Output tiled with the target shape: each output value packs `packing`
// consecutive input values along the last array dimension.
const int packing = layout_out.packing();
output_vregs.Each([&](absl::Span<const int64_t> idxs, Value *v) {
SmallVector<Value> parts;
SmallVector<int64_t> idxs_local(toArrayRef(idxs));
idxs_local.back() *= packing;
for (int64_t i = 0; i < packing; ++i) {
parts.push_back(input_vregs(idxs_local));
// Pack any data lying around if OOB
// (the index stops advancing at the last input value, which is then
// reused for the remaining parts).
if (idxs_local.back() < input_vregs.dimensions().back() - 1) {
++idxs_local.back();
}
}
*v = ctx.builder.create<PackSubelementsOp>(op.getLoc(), res_vreg_ty,
parts);
});
} else if (layout_out.bitwidth() == 16 &&
layout_out.tiling() ==
std::array<int64_t, 2>{2 * ctx.target_shape[0],
ctx.target_shape[1]}) {
// 16-bit output whose tiling has twice the target rows: each output value
// packs two input values adjacent along the second-to-last array dimension.
output_vregs.Each([&](absl::Span<const int64_t> idxs, Value *v) {
// TODO(tlongeri): should probably express as a multiple of target_shape
// instead of (16, 128)
CHECK_GE(idxs.size(), 2);
SmallVector<int64_t> idxs_local(toArrayRef(idxs));
idxs_local[idxs.size() - 2] *= 2;
const Value first = input_vregs(idxs_local);
Value second;
// When the second input would fall just past the end of that dimension,
// the first value is packed twice.
if (idxs[idxs.size() - 2] * 2 + 1 ==
input_vregs.dim(input_vregs.num_dimensions() - 2)) {
second = first;
} else {
idxs_local[idxs.size() - 2] += 1;
second = input_vregs(idxs_local);
}
*v = ctx.builder.create<PackSubelementsOp>(
op.getLoc(), res_vreg_ty, ArrayRef<Value>{first, second});
});
} else {
return op.emitOpError("Not implemented");
}
op.replaceAllUsesWith(
assemble(ctx, result_ty, layout_out, std::move(output_vregs))
.getResult());
op.erase();
return success();
}
// TODO(tlongeri): why wasn't this part of the original code?
return op.emitOpError("Not implemented");
}
// Rewrite rule for arith.truncf. Requires exactly one non-null input layout and
// one non-null output layout (CHECK-enforced). Only 32-bit to 16-bit is
// accepted; the rewrite itself is done by trunc_op_rule_impl.
LogicalResult arith_truncf_rule(RewriteContext &ctx, Operation &op,
                                const ArrayRef<Layout> layouts_in,
                                const ArrayRef<Layout> layouts_out) {
  CHECK_EQ(layouts_in.size(), 1);
  CHECK(layouts_in.front().has_value());
  CHECK_EQ(layouts_out.size(), 1);
  CHECK(layouts_out.front().has_value());
  auto truncf_op = cast<arith::TruncFOp>(op);
  const VectorLayout &layout_in = *layouts_in.front();
  const VectorLayout &layout_out = *layouts_out.front();
  const bool is_32_to_16 =
      layout_in.bitwidth() == 32 && layout_out.bitwidth() == 16;
  if (!is_32_to_16) {
    return op.emitOpError(
        "Not implemented: Only 32-bit to 16-bit conversion supported");
  }
  return trunc_op_rule_impl(ctx, truncf_op, layout_in, layout_out);
}
// Rewrite rule for arith.trunci. Requires exactly one non-null input layout and
// one non-null output layout (CHECK-enforced), then defers to
// trunc_op_rule_impl.
LogicalResult arith_trunci_rule(RewriteContext &ctx, Operation &op,
                                const ArrayRef<Layout> layouts_in,
                                const ArrayRef<Layout> layouts_out) {
  CHECK_EQ(layouts_in.size(), 1);
  CHECK(layouts_in.front().has_value());
  CHECK_EQ(layouts_out.size(), 1);
  CHECK(layouts_out.front().has_value());
  const VectorLayout &layout_in = *layouts_in.front();
  const VectorLayout &layout_out = *layouts_out.front();
  return trunc_op_rule_impl(ctx, cast<arith::TruncIOp>(op), layout_in,
                            layout_out);
}
// Rewrite rule for tpu.load.
//
// All input layouts must be null and the output layout must be the 32-bit
// layout with zero offsets, target-shape tiling and no implicit dim. The
// indices must be integer constants and the second one a multiple of
// target_shape[1]. The load itself is kept; its result is wrapped in a
// RollVectorsOp and every other use of the load is redirected to that op.
LogicalResult tpu_load_rule(RewriteContext &ctx, Operation &op,
                            const ArrayRef<Layout> layouts_in,
                            const ArrayRef<Layout> layouts_out) {
  CHECK_EQ(layouts_out.size(), 1);
  for (const Layout &layout : layouts_in) {
    if (layout.has_value()) {
      return op.emitOpError("Expected null input layouts");
    }
  }
  const Layout &maybe_layout_out = layouts_out.front();
  if (!maybe_layout_out.has_value()) {
    return op.emitOpError("Expected non-null output layout");
  }
  const VectorLayout &layout_out = *maybe_layout_out;
  // We expect the result is already a native-sized vreg.
  // TODO(b/300493694): Support other bitwidths
  if (layout_out.bitwidth() != 32) {
    return op.emitOpError("Not implemented: Only 32-bit loads supported");
  }
  tpu::LoadOp load_op = cast<tpu::LoadOp>(op);
  const VectorLayout expected_layout(32, {0, 0}, ctx.target_shape,
                                     VectorLayout::ImplicitDim::kNone);
  if (layout_out != expected_layout) {
    return op.emitOpError("Invalid output layout for ") << load_op->getName();
  }
  FAILUREOR_ASSIGN_OR_RETURN(
      const SmallVector<int64_t> indices,
      getIntConstsFromOperandRange(load_op.getIndices()));
  CHECK_EQ(indices.size(), 2);
  const int64_t lane_count = ctx.target_shape[1];
  if (indices[1] % lane_count != 0) {
    return op.emitOpError("Not implemented: Lane index is not a multiple of ")
           << lane_count;
  }
  const RollVectorsOp roll_vectors_op = assemble(
      ctx, load_op.getResult().getType(), layout_out, {{load_op.getResult()}});
  // Keep the use inside the new RollVectorsOp itself pointing at the load.
  load_op->replaceUsesWithIf(roll_vectors_op, [&](OpOperand &use) {
    return use.getOwner() != roll_vectors_op;
  });
  return success();
}
// Rewrite rule for tpu.store.
//
// The first input layout (the value to store) must be the 32-bit layout with
// zero offsets, target-shape tiling and no implicit dim; all remaining input
// layouts (the indices) must be null. The indices must be integer constants
// and the second one a multiple of target_shape[1]. The op is updated in
// place: its value-to-store operand is replaced by the single value obtained
// from disassembling it.
//
// An unsupported layout is reported as an op error rather than a CHECK
// failure, matching tpu_load_rule.
LogicalResult tpu_store_rule(RewriteContext &ctx, Operation &op,
                             const ArrayRef<Layout> layouts_in,
                             const ArrayRef<Layout> layouts_out) {
  CHECK_EQ(layouts_out.size(), 0);
  if (llvm::any_of(layouts_in.drop_front(),
                   [&](const Layout &l) { return l.has_value(); })) {
    return op.emitOpError("Expected null layouts for tpu.store indices");
  }
  if (!layouts_in.front().has_value()) {
    return op.emitOpError("Expected non-null layout for tpu.store base");
  }
  const VectorLayout &to_store_layout = *layouts_in.front();
  // We expect the value to store is already a native-sized vreg.
  if (to_store_layout.bitwidth() != 32) {
    return op.emitOpError("Not implemented: Only 32-bit stores supported");
  }
  if (to_store_layout != VectorLayout(32, {0, 0}, ctx.target_shape,
                                      VectorLayout::ImplicitDim::kNone)) {
    return op.emitOpError("Invalid layout for the value to store");
  }
  tpu::StoreOp store_op = cast<tpu::StoreOp>(op);
  FAILUREOR_ASSIGN_OR_RETURN(
      const SmallVector<int64_t> indices,
      getIntConstsFromOperandRange(store_op.getIndices()));
  CHECK_EQ(indices.size(), 2);
  if (indices[1] % ctx.target_shape[1] != 0) {
    return op.emitOpError("Not implemented: Lane index is not a multiple of ")
           << ctx.target_shape[1];
  }
  FAILUREOR_ASSIGN_OR_RETURN(
      xla::Array<Value> tiles,
      disassemble(ctx, to_store_layout, store_op.getValueToStore()));
  // With the layout checked above, exactly one value is expected.
  CHECK((tiles.dimensions() == xla::DimensionVector{1, 1}));
  store_op.getValueToStoreMutable().assign(tiles({0, 0}));
  return success();
}
// Rewrite rule for traced blocks.
//
// Only ops without operands and results are supported. The op itself is left
// untouched; the layout rewrite is applied recursively to the single block of
// its single region.
LogicalResult tpu_trace_rule(RewriteContext &ctx, Operation &op,
                             const ArrayRef<Layout> layouts_in,
                             const ArrayRef<Layout> layouts_out) {
  if (op.getNumOperands() != 0 || op.getNumResults() != 0) {
    return op.emitOpError(
        "Not implemented: tpu.traced_block with inputs or outputs");
  }
  CHECK_EQ(layouts_in.size(), 0);
  CHECK_EQ(layouts_out.size(), 0);
  // We don't modify the op, but we do rewrite the region body.
  CHECK_EQ(op.getNumRegions(), 1);
  Region &region = op.getRegion(0);
  CHECK(region.hasOneBlock());
  Block &block = region.front();
  return applyLayoutBlock(ctx, block);
}
// Rewrite rule for tpu.iota.
//
// Supports 32-bit integer iotas along one of the two minor dimensions, with no
// implicit dim in the output layout and a zero layout offset along the iota
// dimension. A single native-vreg iota is created and, for every tile index i
// along the iota dimension, a constant offset of i * (vreg size along that
// dimension) is added to it. The resulting values are repeated across all
// other tile-array dimensions and assembled into the replacement for `op`.
LogicalResult tpu_iota_rule(RewriteContext &ctx, Operation &op,
const ArrayRef<Layout> layouts_in,
const ArrayRef<Layout> layouts_out) {
CHECK_EQ(layouts_in.size(), 0);
CHECK_EQ(layouts_out.size(), 1);
if (!layouts_out.front().has_value()) {
return op.emitOpError("Expected non-null output layout");
}
const VectorLayout &layout_out = *layouts_out.front();
tpu::IotaOp iota_op = cast<tpu::IotaOp>(op);
VectorType vty = iota_op.getResult().getType();
if (const auto int_ty = dyn_cast<IntegerType>(vty.getElementType());
int_ty == nullptr || int_ty.getWidth() != 32) {
return iota_op.emitOpError("Not implemented: Only 32-bit Iota supported");
}
FAILUREOR_ASSIGN_OR_RETURN(
const auto native_vreg_ty,
getNativeVregType(vty.getElementType(), ctx.target_shape));
if (layout_out.implicit_dim() != VectorLayout::ImplicitDim::kNone) {
return op.emitOpError("Not implemented: Only 2D layouts supported");
}
const SmallVector<int64_t> tile_array_shape =
layout_out.tileArrayShape(vty.getShape(), ctx.target_shape);
const std::optional<int32_t> dimension = iota_op.getDimension();
if (!dimension.has_value()) {
return op.emitOpError("Not implemented: null dimension");
}
// Iota along the minor-most dimension.
if (*dimension == vty.getRank() - 1) {
if (layout_out.offsets()[1] != 0) {
return op.emitOpError("Not implemented: Unsupported offset");
}
const int64_t num_tiles = tile_array_shape[tile_array_shape.size() - 1];
SmallVector<Value> tiles(num_tiles);
auto vreg_iota = ctx.builder.create<tpu::IotaOp>(
op.getLoc(), native_vreg_ty,
/*dimension =*/ctx.builder.getI32IntegerAttr(1));
// Tile i holds the base iota shifted by i times the last vreg dimension.
for (int64_t i = 0; i < num_tiles; ++i) {
auto offset = ctx.builder.create<arith::ConstantOp>(
op.getLoc(), native_vreg_ty,
DenseElementsAttr::get(
native_vreg_ty,
IntegerAttr::get(vty.getElementType(),
i * *(native_vreg_ty.getShape().end() - 1))));
tiles[i] =
ctx.builder.create<arith::AddIOp>(op.getLoc(), vreg_iota, offset);
}
// Every position of the tile array picks the tile matching its last index.
xla::Array<Value> broadcasted_tiles(tile_array_shape);
broadcasted_tiles.Each([&](absl::Span<const int64_t> idxs, Value *v) {
*v = tiles[*(idxs.end() - 1)];
});
op.replaceAllUsesWith(assemble(ctx, vty, layout_out, broadcasted_tiles));
op.erase();
return success();
}
// Iota along the second-minor dimension; mirrors the branch above using the
// second-to-last dimension throughout.
if (*dimension == vty.getRank() - 2) {
if (layout_out.offsets()[0] != 0) {
return op.emitOpError("Not implemented: Unsupported offset");
}
const int64_t num_tiles = tile_array_shape[tile_array_shape.size() - 2];
SmallVector<Value> tiles(num_tiles);
auto vreg_iota = ctx.builder.create<tpu::IotaOp>(
op.getLoc(), native_vreg_ty,
/*dimension =*/ctx.builder.getI32IntegerAttr(0));
for (int64_t i = 0; i < num_tiles; ++i) {
auto offset = ctx.builder.create<arith::ConstantOp>(
op.getLoc(), native_vreg_ty,
DenseElementsAttr::get(
native_vreg_ty,
IntegerAttr::get(vty.getElementType(),
i * *(native_vreg_ty.getShape().end() - 2))));
tiles[i] =
ctx.builder.create<arith::AddIOp>(op.getLoc(), vreg_iota, offset);
}
xla::Array<Value> broadcasted_tiles(tile_array_shape);
broadcasted_tiles.Each([&](absl::Span<const int64_t> idxs, Value *v) {
*v = tiles[*(idxs.end() - 2)];
});
op.replaceAllUsesWith(assemble(ctx, vty, layout_out, broadcasted_tiles));
op.erase();
return success();
}
return op.emitOpError("Not implemented: Unsupported dimension");
}
// Layout-application rule for tpu.gather.
//
// Disassembles the source into its tile array, emits one gather per tile and
// assembles the results into a value of the original result type.
//
// Supported configurations (anything else yields an op error):
//   * input and output layouts are present, have no implicit dimension, carry
//     identical offsets, and every non-replicated offset is zero;
//   * the gather dimension is one of the two minormost dimensions;
//   * the index list either spans a whole number of `width`-sized segments
//     that all apply the same in-segment permutation, or is shorter than a
//     single segment (in which case it is zero-padded up to `width`).
//
// On success the original op is replaced and erased.
LogicalResult tpu_gather_rule(RewriteContext &ctx, Operation &op,
                              const ArrayRef<Layout> layouts_in,
                              const ArrayRef<Layout> layouts_out) {
  CHECK_EQ(layouts_in.size(), 1);
  CHECK_EQ(layouts_out.size(), 1);
  if (!layouts_in.front().has_value()) {
    return op.emitOpError("Expected non-null input layout");
  }
  if (!layouts_out.front().has_value()) {
    return op.emitOpError("Expected non-null output layout");
  }
  const VectorLayout &layout_in = *layouts_in.front();
  const VectorLayout &layout_out = *layouts_out.front();
  if (layout_in.implicit_dim() != VectorLayout::ImplicitDim::kNone ||
      layout_out.implicit_dim() != VectorLayout::ImplicitDim::kNone ||
      layout_in.offsets() != layout_out.offsets() ||
      llvm::any_of(layout_in.offsets(), [&](const LayoutOffset o) {
        return o.has_value() && o != 0;
      })) {
    return op.emitOpError("Not implemented: Only 2D layouts supported");
  }
  auto gather_op = cast<tpu::GatherOp>(op);
  const VectorType vty = gather_op.getResult().getType();
  const uint32_t dimension = gather_op.getDimension();
  const int64_t rank = vty.getRank();
  // Only the two minormost dimensions are handled. The upper bound is checked
  // as well so that the ctx.target_shape lookup below can never be indexed
  // out of range by a dimension that is not smaller than the rank.
  if (static_cast<int64_t>(dimension) + 2 < rank ||
      static_cast<int64_t>(dimension) >= rank) {
    return op.emitOpError("Not implemented: Unsupported dimension");
  }
  FAILUREOR_ASSIGN_OR_RETURN(
      const xla::Array<Value> in_tiles,
      disassemble(ctx, layout_in, gather_op.getSource()));
  // Extent of a single tile along the gathered dimension: target_shape[1] for
  // the minormost dimension, target_shape[0] for the second minormost one.
  const int64_t width = ctx.target_shape[2 - (rank - dimension)];
  const ArrayRef<int32_t> indices(gather_op.getIndices());
  auto [num_sections, rem] = std::div(indices.size(), width);
  // Indices applied inside every tile; always exactly `width` entries long.
  SmallVector<int32_t> segment_indices;
  if (rem == 0) {
    // The first segment must only read from itself ...
    for (int64_t i = 0; i < width; ++i) {
      const int64_t offset = i - i % width;
      if (!(offset <= indices[i] && indices[i] < offset + width)) {
        return op.emitOpError("Not implemented: Cross-segment gather");
      }
    }
    // ... and every later segment must repeat the first one, shifted by the
    // start of that segment.
    for (int64_t i = width; i < indices.size(); ++i) {
      const int64_t offset = i - i % width;
      if (indices[i] != indices[i % width] + offset) {
        return op.emitOpError(
            "Not implemented: Indices varying between segments");
      }
    }
    segment_indices.assign(indices.begin(), indices.begin() + width);
  } else if (num_sections == 0) {  // Only one vreg.
    segment_indices.assign(indices.begin(), indices.end());
    // Pad the unused tail so that the segment covers the whole tile.
    segment_indices.append(width - indices.size(), 0);
  } else {
    return op.emitOpError("Not implemented: Not a multiple of target length");
  }
  xla::Array<Value> out_tiles(in_tiles.dimensions());
  if (dimension == rank - 1) {
    // TODO(b/265133497): Remove the broadcast once 2nd minor works.
    const auto dyn_ix_ty =
        VectorType::get(ctx.target_shape, ctx.builder.getI32Type());
    // Broadcast indices to target_shape: one copy of the segment per row
    // along target_shape[0].
    SmallVector<int32_t> dyn_ix_val;
    dyn_ix_val.reserve(ctx.target_shape[0] * segment_indices.size());
    for (int64_t i = 0; i < ctx.target_shape[0]; ++i) {  // Broadcast
      dyn_ix_val.append(segment_indices);
    }
    FAILUREOR_ASSIGN_OR_RETURN(
        const BlockArgument dyn_ix_ref,
        appendConstant(ctx, DenseIntElementsAttr::get(dyn_ix_ty, dyn_ix_val)));
    // One mask bit per sublane. The sublane count is target_shape[0];
    // target_shape[1] is the lane count and must not be used to size it.
    auto all_sublanes = ctx.builder.getAttr<DenseBoolArrayAttr>(
        SmallVector<bool>(ctx.target_shape[0], true));
    auto dyn_ix = ctx.builder.create<tpu::LoadOp>(
        op.getLoc(), dyn_ix_ty, dyn_ix_ref,
        SmallVector<Value>(2, IdxConst(0, ctx.builder, op.getLoc())),
        /*sublane_mask=*/all_sublanes, /*sublane_stride=*/nullptr);
    out_tiles.Each([&](absl::Span<const int64_t> idxs, Value *v) {
      const Value in_tile = in_tiles(idxs);
      *v = ctx.builder.create<tpu::DynamicGatherOp>(
          op.getLoc(), in_tile.getType(), in_tile, dyn_ix, 1);
    });
  } else {
    CHECK_EQ(dimension, rank - 2);
    // The static per-tile gather takes the segment indices as an attribute.
    const auto segment_indices_attr =
        ctx.builder.getAttr<DenseI32ArrayAttr>(segment_indices);
    out_tiles.Each([&](absl::Span<const int64_t> idxs, Value *v) {
      const Value in_tile = in_tiles(idxs);
      *v = ctx.builder.create<tpu::GatherOp>(op.getLoc(), in_tile.getType(),
                                             in_tile, segment_indices_attr, 0);
    });
  }
  gather_op.replaceAllUsesWith(
      assemble(ctx, vty, layout_out, out_tiles).getOperation());
  gather_op.erase();
  return success();
}
// Layout-application rule for tpu.repeat.
//
// Only the case that needs no new per-vreg ops is handled: the input and
// output layouts must be identical, have no implicit dimension, a natural
// topology and zero offsets; the repeat must be along the last dimension and
// the source's last dimension must be a multiple of the last target-shape
// extent. The result is built by passing the disassembled source vregs through
// the `repeat` helper and reassembling. Any other configuration yields an op
// error.
LogicalResult tpu_repeat_rule(RewriteContext &ctx, Operation &op,
                              const ArrayRef<Layout> layouts_in,
                              const ArrayRef<Layout> layouts_out) {
  CHECK_EQ(layouts_in.size(), 1);
  CHECK_EQ(layouts_out.size(), 1);

  const Layout &maybe_src_layout = layouts_in.front();
  const Layout &maybe_dst_layout = layouts_out.front();
  if (!maybe_src_layout.has_value()) {
    return op.emitOpError("Expected non-null input layout");
  }
  if (!maybe_dst_layout.has_value()) {
    return op.emitOpError("Expected non-null output layout");
  }
  const VectorLayout &src_layout = *maybe_src_layout;
  const VectorLayout &dst_layout = *maybe_dst_layout;

  // Layout restrictions, checked in a fixed order so that the first failing
  // one determines the diagnostic.
  if (src_layout.implicit_dim() != VectorLayout::ImplicitDim::kNone) {
    return op.emitOpError("Not implemented: Only 2D layouts supported");
  }
  if (src_layout != dst_layout) {
    return op.emitOpError("Not implemented: Changing layout mid-repeat");
  }
  const bool nontrivial_layout =
      !src_layout.hasNaturalTopology(ctx.target_shape) ||
      src_layout.offsets() != LayoutOffsets{0, 0};
  if (nontrivial_layout) {
    return op.emitOpError("Not implemented: Non-trivial layouts unsupported");
  }

  // Op-specific restrictions.
  tpu::RepeatOp repeat_op = cast<tpu::RepeatOp>(op);
  VectorType source_type = repeat_op.getSource().getType();
  const uint32_t repeat_dim = repeat_op.getDimension();
  if (repeat_dim != source_type.getRank() - 1) {
    return op.emitOpError(
        "Not implemented: Only repeats along the last dim supported");
  }
  const int64_t source_minor_extent = source_type.getShape().back();
  const int64_t target_minor_extent = ctx.target_shape.back();
  if (source_minor_extent % target_minor_extent != 0) {
    return op.emitOpError("Not implemented: Only free repeats are suppported");
  }

  // Repeat at the granularity of whole vregs and rebuild the result value.
  FAILUREOR_ASSIGN_OR_RETURN(
      const xla::Array<Value> &source_vregs,
      disassemble(ctx, src_layout, repeat_op.getSource()));
  xla::Array<Value> result_vregs =
      repeat(source_vregs, repeat_op.getTimes(), repeat_dim);
  const auto result_type = repeat_op.getResult().getType();
  repeat_op->replaceAllUsesWith(
      assemble(ctx, result_type, dst_layout, result_vregs));
  repeat_op->erase();
  return success();
}
LogicalResult vector_load_rule(RewriteContext &ctx, Operation &op,
const ArrayRef<Layout> layouts_in,
const ArrayRef<Layout> layouts_out) {
CHECK_EQ(layouts_out.size(), 1);
MLIRContext *const mlir_ctx = op.getContext();
if (llvm::any_of(layouts_in,
[&](const Layout &l) { return l.has_value(); })) {
return op.emitOpError("Expected null input layouts");
}
if (!layouts_out.front().has_value()) {
return op.emitOpError("Expected non-null output layout");
}
const VectorLayout &layout_out = *layouts_out.front();
auto load_op = cast<vector::LoadOp>(op);
const auto memref_ty = cast<MemRefType>(load_op.getBase().getType());
const auto vty = cast<VectorType>(load_op.getResult().getType());
FAILUREOR_ASSIGN_OR_RETURN(
VectorType target_ty,
getNativeVregType(vty.getElementType(), ctx.target_shape));
if (layout_out.implicit_dim() == VectorLayout::ImplicitDim::kMinor) {
return op.emitOpError("Not implemented");
}
const bool is_1d =
layout_out.implicit_dim() != VectorLayout::ImplicitDim::kNone;
using Tiling = std::array<int64_t, 2>; // To avoid comma in macro
FAILUREOR_ASSIGN_OR_RETURN(
Tiling memref_tiling,
getMemRefTiling(load_op.getBase(), ctx.target_shape));
if (layout_out.tiling() != memref_tiling) {
// Now we can handle the case when tiling is (1, TARGET_SHAPE.lanes).
// TODO(b/295393167): need to support strided load for bitwidth < 32.
if (layout_out.bitwidth() != 32 ||
layout_out.tiling() != std::array<int64_t, 2>{1, ctx.target_shape[1]}) {
return op.emitOpError("Not implemented");
}
}
// TODO(apaszke): Check that loads are from vmem!
FAILUREOR_ASSIGN_OR_RETURN(
const SmallVector<int64_t> indices,
getIntConstsFromOperandRange(load_op.getIndices()));
if (llvm::any_of(
llvm::zip_equal(indices, vty.getShape(), memref_ty.getShape()),
[](auto tup) {
auto [idx, n, extent] = tup;
return idx + n > extent;
})) {
return op.emitOpError("Reading out of bounds");
}
const SmallVector<int64_t> implicit_shape =
layout_out.implicitShape(vty.getShape());
const int64_t ss = implicit_shape[implicit_shape.size() - 2];
int64_t sublane_stride = 1;
if (layout_out.bitwidth() == 32 &&
layout_out.tiling() == std::array<int64_t, 2>{1, ctx.target_shape[1]} &&
ss == 1) {
sublane_stride = memref_tiling[0];
}
const LayoutOffsets offsets = layout_out.offsets();
AffineMap load_map;
arith::ConstantOp padding;
if (offsets[1] == std::nullopt) {
return op.emitOpError("Load replicated along lanes is unsupported");
}
if (offsets[0] == std::nullopt) {
if (ss != 1) {
return op.emitOpError(
"Sublane-replicated load with size > 1 is unsupported");
}
if (!layout_out.hasNativeTiling(ctx.target_shape)) {
return op.emitOpError("Not implemented");
}
// affine_map<(..., j) -> (0, j)
load_map =
AffineMap::get(memref_ty.getRank(), 0,
{getAffineConstantExpr(0, mlir_ctx),
getAffineDimExpr(memref_ty.getRank() - 1, mlir_ctx)},
mlir_ctx);
FAILUREOR_ASSIGN_OR_RETURN(const TypedAttr zero_attr,
getZeroIntOrFloatAttr(vty.getElementType()));
padding = ctx.builder.create<arith::ConstantOp>(
load_op.getLoc(), vty.getElementType(), zero_attr);
}
xla::Array<Value> tiles(
layout_out.tileArrayShape(vty.getShape(), ctx.target_shape));
const std::array<int64_t, 2> vreg_slice =
layout_out.vregSlice(ctx.target_shape);