-
Notifications
You must be signed in to change notification settings - Fork 11.6k
/
OpDefinition.h
1499 lines (1304 loc) · 56.3 KB
/
OpDefinition.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
//===- OpDefinition.h - Classes for defining concrete Op types --*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements helper classes for implementing the "Op" types. This
// includes the Op type, which is the base class for Op class definitions,
// as well as number of traits in the OpTrait namespace that provide a
// declarative way to specify properties of Ops.
//
// The purpose of these types are to allow light-weight implementation of
// concrete ops (like DimOp) with very little boilerplate.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_IR_OPDEFINITION_H
#define MLIR_IR_OPDEFINITION_H
#include "mlir/IR/Operation.h"
#include "llvm/Support/PointerLikeTypeTraits.h"
#include <type_traits>
namespace mlir {
// Forward declarations of the builder types referenced by the declarations
// below.
class Builder;
class OpBuilder;
namespace OpTrait {
// Forward declaration of the single-result trait, which is defined later in
// this header.
template <typename ConcreteType> class OneResult;
} // namespace OpTrait
/// This class represents success/failure for operation parsing. It is
/// essentially a simple wrapper class around LogicalResult that allows for
/// explicit conversion to bool. This allows for the parser to chain together
/// parse rules without the clutter of "failed/succeeded".
class ParseResult : public LogicalResult {
public:
ParseResult(LogicalResult result = success()) : LogicalResult(result) {}
// Allow diagnostics emitted during parsing to be converted to failure.
ParseResult(const InFlightDiagnostic &) : LogicalResult(failure()) {}
ParseResult(const Diagnostic &) : LogicalResult(failure()) {}
/// Failure is true in a boolean context.
explicit operator bool() const { return failed(*this); }
};
/// Tri-state parse result: "not present", or present together with a
/// ParseResult that says whether parsing it succeeded.
///
/// This deliberately wraps Optional<ParseResult> rather than exposing it,
/// because Optional provides an implicit conversion to 'bool' that we want
/// to avoid. It backs the 'parseOptional' family of functions, where a
/// failure after the construct was recognized must not be reported as
/// "not present".
class OptionalParseResult {
public:
  OptionalParseResult() = default;
  OptionalParseResult(LogicalResult result) : impl(result) {}
  OptionalParseResult(ParseResult result) : impl(result) {}
  OptionalParseResult(const InFlightDiagnostic &)
      : OptionalParseResult(failure()) {}
  // "Not present": same state as a default-constructed result.
  OptionalParseResult(llvm::NoneType) : OptionalParseResult() {}

  /// True when a ParseResult is held (the construct was present).
  bool hasValue() const { return impl.hasValue(); }

  /// Return the held ParseResult value.
  ParseResult getValue() const { return impl.getValue(); }
  ParseResult operator*() const { return impl.getValue(); }

private:
  Optional<ParseResult> impl;
};
// These functions are out-of-line utilities, which avoids them being template
// instantiated/duplicated.
namespace impl {
/// Insert an operation, generated by `buildTerminatorOp`, at the end of the
/// region's only block if it does not have a terminator already. If the region
/// is empty, insert a new block first. `buildTerminatorOp` should return the
/// terminator operation to insert.
void ensureRegionTerminator(
Region ®ion, OpBuilder &builder, Location loc,
function_ref<Operation *(OpBuilder &, Location)> buildTerminatorOp);
void ensureRegionTerminator(
Region ®ion, Builder &builder, Location loc,
function_ref<Operation *(OpBuilder &, Location)> buildTerminatorOp);
} // namespace impl
/// This is the concrete base class that holds the operation pointer and has
/// non-generic methods that only depend on State (to avoid having them
/// instantiated on template types that don't affect them.
///
/// This also has the fallback implementations of customization hooks for when
/// they aren't customized.
class OpState {
public:
/// Ops are pointer-like, so we allow implicit conversion to bool.
operator bool() { return getOperation() != nullptr; }
/// This implicitly converts to Operation*.
operator Operation *() const { return state; }
/// Return the operation that this refers to.
Operation *getOperation() { return state; }
/// Returns the closest surrounding operation that contains this operation
/// or nullptr if this is a top-level operation.
Operation *getParentOp() { return getOperation()->getParentOp(); }
/// Return the closest surrounding parent operation that is of type 'OpTy'.
template <typename OpTy> OpTy getParentOfType() {
return getOperation()->getParentOfType<OpTy>();
}
/// Returns the closest surrounding parent operation with trait `Trait`.
template <template <typename T> class Trait>
Operation *getParentWithTrait() {
return getOperation()->getParentWithTrait<Trait>();
}
/// Return the context this operation belongs to.
MLIRContext *getContext() { return getOperation()->getContext(); }
/// Print the operation to the given stream.
void print(raw_ostream &os, OpPrintingFlags flags = llvm::None) {
state->print(os, flags);
}
void print(raw_ostream &os, AsmState &asmState,
OpPrintingFlags flags = llvm::None) {
state->print(os, asmState, flags);
}
/// Dump this operation.
void dump() { state->dump(); }
/// The source location the operation was defined or derived from.
Location getLoc() { return state->getLoc(); }
void setLoc(Location loc) { state->setLoc(loc); }
/// Return all of the attributes on this operation.
ArrayRef<NamedAttribute> getAttrs() { return state->getAttrs(); }
/// A utility iterator that filters out non-dialect attributes.
using dialect_attr_iterator = Operation::dialect_attr_iterator;
using dialect_attr_range = Operation::dialect_attr_range;
/// Return a range corresponding to the dialect attributes for this operation.
dialect_attr_range getDialectAttrs() { return state->getDialectAttrs(); }
dialect_attr_iterator dialect_attr_begin() {
return state->dialect_attr_begin();
}
dialect_attr_iterator dialect_attr_end() { return state->dialect_attr_end(); }
/// Return an attribute with the specified name.
Attribute getAttr(StringRef name) { return state->getAttr(name); }
/// If the operation has an attribute of the specified type, return it.
template <typename AttrClass> AttrClass getAttrOfType(StringRef name) {
return getAttr(name).dyn_cast_or_null<AttrClass>();
}
/// If the an attribute exists with the specified name, change it to the new
/// value. Otherwise, add a new attribute with the specified name/value.
void setAttr(Identifier name, Attribute value) {
state->setAttr(name, value);
}
void setAttr(StringRef name, Attribute value) {
setAttr(Identifier::get(name, getContext()), value);
}
/// Set the attributes held by this operation.
void setAttrs(ArrayRef<NamedAttribute> attributes) {
state->setAttrs(attributes);
}
void setAttrs(MutableDictionaryAttr newAttrs) { state->setAttrs(newAttrs); }
/// Set the dialect attributes for this operation, and preserve all dependent.
template <typename DialectAttrs> void setDialectAttrs(DialectAttrs &&attrs) {
state->setDialectAttrs(std::move(attrs));
}
/// Remove the attribute with the specified name if it exists. The return
/// value indicates whether the attribute was present or not.
MutableDictionaryAttr::RemoveResult removeAttr(Identifier name) {
return state->removeAttr(name);
}
MutableDictionaryAttr::RemoveResult removeAttr(StringRef name) {
return state->removeAttr(Identifier::get(name, getContext()));
}
/// Return true if there are no users of any results of this operation.
bool use_empty() { return state->use_empty(); }
/// Remove this operation from its parent block and delete it.
void erase() { state->erase(); }
/// Emit an error with the op name prefixed, like "'dim' op " which is
/// convenient for verifiers.
InFlightDiagnostic emitOpError(const Twine &message = {});
/// Emit an error about fatal conditions with this operation, reporting up to
/// any diagnostic handlers that may be listening.
InFlightDiagnostic emitError(const Twine &message = {});
/// Emit a warning about this operation, reporting up to any diagnostic
/// handlers that may be listening.
InFlightDiagnostic emitWarning(const Twine &message = {});
/// Emit a remark about this operation, reporting up to any diagnostic
/// handlers that may be listening.
InFlightDiagnostic emitRemark(const Twine &message = {});
/// Walk the operation in postorder, calling the callback for each nested
/// operation(including this one).
/// See Operation::walk for more details.
template <typename FnT, typename RetT = detail::walkResultType<FnT>>
RetT walk(FnT &&callback) {
return state->walk(std::forward<FnT>(callback));
}
// These are default implementations of customization hooks.
public:
/// This hook returns any canonicalization pattern rewrites that the operation
/// supports, for use by the canonicalization pass.
static void getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {}
protected:
/// If the concrete type didn't implement a custom verifier hook, just fall
/// back to this one which accepts everything.
LogicalResult verify() { return success(); }
/// Unless overridden, the custom assembly form of an op is always rejected.
/// Op implementations should implement this to return failure.
/// On success, they should fill in result with the fields to use.
static ParseResult parse(OpAsmParser &parser, OperationState &result);
// The fallback for the printer is to print it the generic assembly form.
void print(OpAsmPrinter &p);
/// Mutability management is handled by the OpWrapper/OpConstWrapper classes,
/// so we can cast it away here.
explicit OpState(Operation *state) : state(state) {}
private:
Operation *state;
};
// Allow comparing ops: two op wrappers are equal exactly when they refer to
// the same underlying Operation.
inline bool operator==(OpState lhs, OpState rhs) {
  return lhs.getOperation() == rhs.getOperation();
}
inline bool operator!=(OpState lhs, OpState rhs) { return !(lhs == rhs); }
/// The result of folding an operation: either an Attribute (e.g. a constant)
/// or an existing Value. It inherits all constructors of the underlying
/// PointerUnion.
class OpFoldResult : public PointerUnion<Attribute, Value> {
  using UnionBase = PointerUnion<Attribute, Value>;
  using UnionBase::UnionBase;
};
/// Allow printing an op to a stream, using OpPrintingFlags::useLocalScope().
///
/// OpState only wraps an `Operation *`, so it is taken by value (like the
/// comparison operators above). This also allows printing temporaries, e.g.
/// the op returned by a builder call, which a non-const reference rejected.
inline raw_ostream &operator<<(raw_ostream &os, OpState op) {
  op.print(os, OpPrintingFlags().useLocalScope());
  return os;
}
/// Generic `foldHook` used by AbstractOperation.
///
/// This primary template targets ops that may have any number of results. It
/// forwards to a multi-result `fold` method on the concrete op; ops that do not
/// provide one fall back to the always-failing `fold` defined below.
template <typename ConcreteType, bool isSingleResult, typename = void>
class FoldingHook {
public:
  /// Implementation detail of the constant-folder hook for AbstractOperation:
  /// view `op` as the concrete op class and run its `fold`.
  static LogicalResult foldHook(Operation *op, ArrayRef<Attribute> operands,
                                SmallVectorImpl<OpFoldResult> &results) {
    auto concreteOp = cast<ConcreteType>(op);
    return concreteOp.fold(operands, results);
  }

  /// Generalized folder hook, used by the Builder::createOrFold API and by the
  /// canonicalization pass to apply simplifications.
  ///
  /// The interface is deliberately narrow. An implementation may only:
  ///
  ///  1. leave the IR untouched and return failure;
  ///  2. update the operation in place, changing nothing else in the IR, and
  ///     return success;
  ///  3. append existing values that can stand in for the operation's results
  ///     to `results` and return success. The caller then removes the
  ///     operation and uses those values instead.
  ///
  /// This covers simple in-place canonicalizations (for example "x+0 -> x",
  /// "min(x,y,x,z) -> min(x,y,z)", "x+y-x -> y") as well as general constant
  /// folding.
  ///
  /// The fallback below, used when the op does not override it, never folds.
  LogicalResult fold(ArrayRef<Attribute> operands,
                     SmallVectorImpl<OpFoldResult> &results) {
    return failure();
  }
};
/// `foldHook` specialization for single-result operations.
///
/// Single-result ops get a simpler `fold` signature that returns one
/// OpFoldResult instead of filling a results vector.
template <typename ConcreteType, bool isSingleResult>
class FoldingHook<ConcreteType, isSingleResult,
                  typename std::enable_if<isSingleResult>::type> {
public:
  /// A single-result op converts implicitly to a Value: the value of its only
  /// result.
  operator Value() {
    return static_cast<ConcreteType *>(this)->getOperation()->getResult(0);
  }

  /// Implementation detail of the constant-folder hook for AbstractOperation.
  static LogicalResult foldHook(Operation *op, ArrayRef<Attribute> operands,
                                SmallVectorImpl<OpFoldResult> &results) {
    auto folded = cast<ConcreteType>(op).fold(operands);
    if (!folded)
      return failure();

    // An in-place fold is reported by returning the op's own result. In that
    // case there is no replacement to hand back to the caller.
    bool hasReplacement =
        folded.template dyn_cast<Value>() != op->getResult(0);
    if (hasReplacement)
      results.push_back(folded);
    return success();
  }

  /// Generalized folder hook, used by the Builder::createOrFold API and by the
  /// canonicalization pass to apply simplifications.
  ///
  /// The interface is deliberately narrow. An implementation may only:
  ///
  ///  1. leave the IR untouched and return nullptr;
  ///  2. update the operation in place, changing nothing else in the IR, and
  ///     return the operation's own result;
  ///  3. return an existing SSA value that can replace the operation. The
  ///     caller then removes the operation and uses that value instead.
  ///
  /// This covers simple in-place canonicalizations (for example "x+0 -> x",
  /// "min(x,y,x,z) -> min(x,y,z)", "x+y-x -> y") as well as general constant
  /// folding.
  ///
  /// The fallback below, used when the op does not override it, never folds.
  OpFoldResult fold(ArrayRef<Attribute> operands) { return {}; }
};
//===----------------------------------------------------------------------===//
// Operation Trait Types
//===----------------------------------------------------------------------===//
namespace OpTrait {
// These functions are out-of-line implementations of the methods in the
// corresponding trait classes. This avoids them being template
// instantiated/duplicated.
namespace impl {
LogicalResult verifyZeroOperands(Operation *op);
LogicalResult verifyOneOperand(Operation *op);
LogicalResult verifyNOperands(Operation *op, unsigned numOperands);
LogicalResult verifyAtLeastNOperands(Operation *op, unsigned numOperands);
LogicalResult verifyOperandsAreFloatLike(Operation *op);
LogicalResult verifyOperandsAreSignlessIntegerLike(Operation *op);
LogicalResult verifySameTypeOperands(Operation *op);
LogicalResult verifyZeroRegion(Operation *op);
LogicalResult verifyOneRegion(Operation *op);
LogicalResult verifyNRegions(Operation *op, unsigned numRegions);
LogicalResult verifyAtLeastNRegions(Operation *op, unsigned numRegions);
LogicalResult verifyZeroResult(Operation *op);
LogicalResult verifyOneResult(Operation *op);
// The count parameter is the expected number of *results*; it was previously
// misnamed `numOperands` (copied from the operand verifiers).
LogicalResult verifyNResults(Operation *op, unsigned numResults);
LogicalResult verifyAtLeastNResults(Operation *op, unsigned numResults);
LogicalResult verifySameOperandsShape(Operation *op);
LogicalResult verifySameOperandsAndResultShape(Operation *op);
LogicalResult verifySameOperandsElementType(Operation *op);
LogicalResult verifySameOperandsAndResultElementType(Operation *op);
LogicalResult verifySameOperandsAndResultType(Operation *op);
LogicalResult verifyResultsAreBoolLike(Operation *op);
LogicalResult verifyResultsAreFloatLike(Operation *op);
LogicalResult verifyResultsAreSignlessIntegerLike(Operation *op);
LogicalResult verifyIsTerminator(Operation *op);
LogicalResult verifyZeroSuccessor(Operation *op);
LogicalResult verifyOneSuccessor(Operation *op);
LogicalResult verifyNSuccessors(Operation *op, unsigned numSuccessors);
LogicalResult verifyAtLeastNSuccessors(Operation *op, unsigned numSuccessors);
LogicalResult verifyOperandSizeAttr(Operation *op, StringRef sizeAttrName);
LogicalResult verifyResultSizeAttr(Operation *op, StringRef sizeAttrName);
} // namespace impl
/// Common base for trait implementations. It is an internal helper rather than
/// client API, which is why all of its members are protected.
template <typename ConcreteType, template <typename> class TraitType>
class TraitBase {
protected:
  /// Return the Operation that the concrete op wraps.
  Operation *getOperation() {
    // A concrete op derives from several (content-free) trait bases, so the
    // path up to OpState has to be spelled out one hop at a time to keep it
    // unambiguous: this trait -> the concrete op -> its OpState base.
    auto *asTrait = static_cast<TraitType<ConcreteType> *>(this);
    auto *asConcrete = static_cast<ConcreteType *>(asTrait);
    return static_cast<OpState *>(asConcrete)->getOperation();
  }

  /// Default trait hooks: accept every op and contribute no properties. A
  /// trait only needs to redefine the hooks it actually cares about.
  static LogicalResult verifyTrait(Operation *op) { return success(); }
  static AbstractOperation::OperationProperties getTraitProperties() {
    return 0;
  }
};
//===----------------------------------------------------------------------===//
// Operand Traits
namespace detail {
/// Shared operand accessors for traits whose ops may have more than one
/// operand. Every accessor forwards to the underlying Operation.
template <typename ConcreteType, template <typename> class TraitType>
struct MultiOperandTraitBase : public TraitBase<ConcreteType, TraitType> {
  using operand_iterator = Operation::operand_iterator;
  using operand_range = Operation::operand_range;
  using operand_type_iterator = Operation::operand_type_iterator;
  using operand_type_range = Operation::operand_type_range;

  /// Return how many operands the op has.
  unsigned getNumOperands() { return this->getOperation()->getNumOperands(); }

  /// Return the operand at position `i`.
  Value getOperand(unsigned i) { return this->getOperation()->getOperand(i); }

  /// Replace the operand at position `i` with `value`.
  void setOperand(unsigned i, Value value) {
    this->getOperation()->setOperand(i, value);
  }

  /// Iteration over the operands.
  operand_range getOperands() { return this->getOperation()->getOperands(); }
  operand_iterator operand_begin() {
    return this->getOperation()->operand_begin();
  }
  operand_iterator operand_end() { return this->getOperation()->operand_end(); }

  /// Iteration over the operand types.
  operand_type_range getOperandTypes() {
    return this->getOperation()->getOperandTypes();
  }
  operand_type_iterator operand_type_begin() {
    return this->getOperation()->operand_type_begin();
  }
  operand_type_iterator operand_type_end() {
    return this->getOperation()->operand_type_end();
  }
};
} // end namespace detail
/// Trait for ops that take no SSA operands.
template <typename ConcreteType>
class ZeroOperands : public TraitBase<ConcreteType, ZeroOperands> {
public:
  /// Check that the op really has zero operands.
  static LogicalResult verifyTrait(Operation *op) {
    return impl::verifyZeroOperands(op);
  }

private:
  // Declared privately so that operand accessors are disabled for ops with
  // this trait.
  void getOperand() {}
  void setOperand() {}
};
/// Trait for ops that take exactly one SSA operand.
template <typename ConcreteType>
class OneOperand : public TraitBase<ConcreteType, OneOperand> {
public:
  /// Check that the op has exactly one operand.
  static LogicalResult verifyTrait(Operation *op) {
    return impl::verifyOneOperand(op);
  }

  /// Return the only operand.
  Value getOperand() { return this->getOperation()->getOperand(0); }

  /// Replace the only operand with `value`.
  void setOperand(Value value) { this->getOperation()->setOperand(0, value); }
};
/// Trait family for ops with exactly N operands (N >= 2). Attach it through
/// the nested `Impl` template:
///
///   class FooOp : public Op<FooOp, OpTrait::NOperands<2>::Impl> {
///
template <unsigned N> class NOperands {
public:
  template <typename ConcreteType>
  class Impl
      : public detail::MultiOperandTraitBase<ConcreteType, NOperands<N>::Impl> {
  public:
    /// Check that the op has exactly N operands.
    static LogicalResult verifyTrait(Operation *op) {
      return impl::verifyNOperands(op, N);
    }
  };

  static_assert(N > 1, "use ZeroOperands/OneOperand for N < 2");
};
/// Trait family for ops with at least N operands. Attach it through the
/// nested `Impl` template:
///
///   class FooOp : public Op<FooOp, OpTrait::AtLeastNOperands<2>::Impl> {
///
template <unsigned N> class AtLeastNOperands {
public:
  template <typename ConcreteType>
  class Impl : public detail::MultiOperandTraitBase<ConcreteType,
                                                    AtLeastNOperands<N>::Impl> {
  public:
    /// Check that the op has N or more operands.
    static LogicalResult verifyTrait(Operation *op) {
      return impl::verifyAtLeastNOperands(op, N);
    }
  };
};
/// Trait for ops whose operand count is not fixed. It provides the
/// multi-operand accessors and keeps the default (accept-all) verifier.
template <typename ConcreteType>
class VariadicOperands
    : public detail::MultiOperandTraitBase<ConcreteType, VariadicOperands> {};
//===----------------------------------------------------------------------===//
// Region Traits
/// Trait for ops without regions. It only contributes a verifier for that
/// property.
template <typename ConcreteType>
class ZeroRegion : public TraitBase<ConcreteType, ZeroRegion> {
public:
  /// Check that the op has no regions.
  static LogicalResult verifyTrait(Operation *op) {
    return impl::verifyZeroRegion(op);
  }
};
namespace detail {
/// Utility trait base that provides accessors for derived traits that have
/// multiple regions.
template <typename ConcreteType, template <typename> class TraitType>
struct MultiRegionTraitBase : public TraitBase<ConcreteType, TraitType> {
  // An *iterator* into the op's region list (a `Region *`). Aliasing the list
  // view `MutableArrayRef<Region>` here instead would make region_begin() and
  // region_end() ill-formed, since a MutableArrayRef cannot be built from an
  // iterator.
  using region_iterator = MutableArrayRef<Region>::iterator;
  using region_range = RegionRange;

  /// Return the number of regions.
  unsigned getNumRegions() { return this->getOperation()->getNumRegions(); }

  /// Return the region at index `i`.
  Region &getRegion(unsigned i) { return this->getOperation()->getRegion(i); }

  /// Region iterator access.
  // NOTE(review): assumes Operation::getRegions() returns
  // MutableArrayRef<Region>; the view is non-owning, so the returned iterators
  // point into the operation's own region storage.
  region_iterator region_begin() {
    return this->getOperation()->getRegions().begin();
  }
  region_iterator region_end() {
    return this->getOperation()->getRegions().end();
  }
  region_range getRegions() { return this->getOperation()->getRegions(); }
};
} // end namespace detail
/// Trait for ops with exactly one region, plus convenience accessors for it.
template <typename ConcreteType>
class OneRegion : public TraitBase<ConcreteType, OneRegion> {
public:
  /// Check that the op has exactly one region.
  static LogicalResult verifyTrait(Operation *op) {
    return impl::verifyOneRegion(op);
  }

  /// Return the op's only region.
  Region &getRegion() { return this->getOperation()->getRegion(0); }

  /// Return a range of the operations within the op's region.
  auto getOps() { return getRegion().getOps(); }

  /// Return a range of the operations of type `OpT` within the op's region.
  template <typename OpT> auto getOps() {
    return getRegion().template getOps<OpT>();
  }
};
/// Trait family for ops with exactly N regions (N >= 2); attach it as
/// `NRegions<N>::Impl`.
template <unsigned N> class NRegions {
public:
  template <typename ConcreteType>
  class Impl
      : public detail::MultiRegionTraitBase<ConcreteType, NRegions<N>::Impl> {
  public:
    /// Check that the op has exactly N regions.
    static LogicalResult verifyTrait(Operation *op) {
      return impl::verifyNRegions(op, N);
    }
  };

  static_assert(N > 1, "use ZeroRegion/OneRegion for N < 2");
};
/// Trait family for ops with N or more regions; attach it as
/// `AtLeastNRegions<N>::Impl`.
template <unsigned N> class AtLeastNRegions {
public:
  template <typename ConcreteType>
  class Impl : public detail::MultiRegionTraitBase<ConcreteType,
                                                   AtLeastNRegions<N>::Impl> {
  public:
    /// Check that the op has at least N regions.
    static LogicalResult verifyTrait(Operation *op) {
      return impl::verifyAtLeastNRegions(op, N);
    }
  };
};
/// Trait for ops whose region count is not fixed. It provides the
/// multi-region accessors and keeps the default (accept-all) verifier.
template <typename ConcreteType>
class VariadicRegions
    : public detail::MultiRegionTraitBase<ConcreteType, VariadicRegions> {};
//===----------------------------------------------------------------------===//
// Result Traits
/// Trait for ops that produce no results. It only contributes a verifier for
/// that property.
template <typename ConcreteType>
class ZeroResult : public TraitBase<ConcreteType, ZeroResult> {
public:
  /// Check that the op has no results.
  static LogicalResult verifyTrait(Operation *op) {
    return impl::verifyZeroResult(op);
  }
};
namespace detail {
/// Shared result accessors for traits whose ops may have more than one
/// result. Every accessor forwards to the underlying Operation.
template <typename ConcreteType, template <typename> class TraitType>
struct MultiResultTraitBase : public TraitBase<ConcreteType, TraitType> {
  using result_iterator = Operation::result_iterator;
  using result_range = Operation::result_range;
  using result_type_iterator = Operation::result_type_iterator;
  using result_type_range = Operation::result_type_range;

  /// Return how many results the op produces.
  unsigned getNumResults() { return this->getOperation()->getNumResults(); }

  /// Return the result at position `i`.
  Value getResult(unsigned i) { return this->getOperation()->getResult(i); }

  /// Return the type of the result at position `i`.
  Type getType(unsigned i) { return getResult(i).getType(); }

  /// Redirect every use of this op's results to `values`, which may be an
  /// existing operation or a range of Values.
  template <typename ValuesT> void replaceAllUsesWith(ValuesT &&values) {
    this->getOperation()->replaceAllUsesWith(std::forward<ValuesT>(values));
  }

  /// Iteration over the results.
  result_range getResults() { return this->getOperation()->getResults(); }
  result_iterator result_begin() {
    return this->getOperation()->result_begin();
  }
  result_iterator result_end() { return this->getOperation()->result_end(); }

  /// Iteration over the result types.
  result_type_range getResultTypes() {
    return this->getOperation()->getResultTypes();
  }
  result_type_iterator result_type_begin() {
    return this->getOperation()->result_type_begin();
  }
  result_type_iterator result_type_end() {
    return this->getOperation()->result_type_end();
  }
};
} // end namespace detail
/// Trait for ops that produce exactly one result, with accessors for it.
template <typename ConcreteType>
class OneResult : public TraitBase<ConcreteType, OneResult> {
public:
  /// Check that the op has exactly one result.
  static LogicalResult verifyTrait(Operation *op) {
    return impl::verifyOneResult(op);
  }

  /// Return the op's only result.
  Value getResult() { return this->getOperation()->getResult(0); }

  /// Return the type of the op's only result.
  Type getType() { return getResult().getType(); }

  /// Redirect every use of the result to `newValue`, so that the result has
  /// no remaining uses afterwards.
  void replaceAllUsesWith(Value newValue) {
    getResult().replaceAllUsesWith(newValue);
  }

  /// Redirect every use of the result to the result of `op`.
  void replaceAllUsesWith(Operation *op) {
    this->getOperation()->replaceAllUsesWith(op);
  }
};
/// Trait family for ops with exactly N results (N >= 2). Attach it through
/// the nested `Impl` template:
///
///   class FooOp : public Op<FooOp, OpTrait::NResults<2>::Impl> {
///
template <unsigned N> class NResults {
public:
  template <typename ConcreteType>
  class Impl
      : public detail::MultiResultTraitBase<ConcreteType, NResults<N>::Impl> {
  public:
    /// Check that the op has exactly N results.
    static LogicalResult verifyTrait(Operation *op) {
      return impl::verifyNResults(op, N);
    }
  };

  static_assert(N > 1, "use ZeroResult/OneResult for N < 2");
};
/// Trait family for ops with N or more results. Attach it through the nested
/// `Impl` template:
///
///   class FooOp : public Op<FooOp, OpTrait::AtLeastNResults<2>::Impl> {
///
template <unsigned N> class AtLeastNResults {
public:
  template <typename ConcreteType>
  class Impl : public detail::MultiResultTraitBase<ConcreteType,
                                                   AtLeastNResults<N>::Impl> {
  public:
    /// Check that the op has at least N results.
    static LogicalResult verifyTrait(Operation *op) {
      return impl::verifyAtLeastNResults(op, N);
    }
  };
};
/// Trait for ops whose result count is not fixed. It provides the
/// multi-result accessors and keeps the default (accept-all) verifier.
template <typename ConcreteType>
class VariadicResults
    : public detail::MultiResultTraitBase<ConcreteType, VariadicResults> {};
//===----------------------------------------------------------------------===//
// Terminator Traits
/// Trait for ops that are known to terminate their enclosing block.
template <typename ConcreteType>
class IsTerminator : public TraitBase<ConcreteType, IsTerminator> {
public:
  /// Report the `Terminator` operation property for ops carrying this trait.
  static auto getTraitProperties() -> AbstractOperation::OperationProperties {
    using Properties = AbstractOperation::OperationProperties;
    return static_cast<Properties>(OperationProperty::Terminator);
  }

  /// Verification is delegated to `impl::verifyIsTerminator`.
  static auto verifyTrait(Operation *op) -> LogicalResult {
    return impl::verifyIsTerminator(op);
  }
};
/// Trait that verifies ops which are known to have no successor blocks.
template <typename ConcreteType>
class ZeroSuccessor : public TraitBase<ConcreteType, ZeroSuccessor> {
public:
  /// Verification is delegated to `impl::verifyZeroSuccessor`.
  static auto verifyTrait(Operation *op) -> LogicalResult {
    return impl::verifyZeroSuccessor(op);
  }
};
namespace detail {
/// Utility trait base that provides successor accessors for derived traits
/// whose ops may have multiple successors. Every accessor forwards to the
/// underlying `Operation`.
template <typename ConcreteType, template <typename> class TraitType>
struct MultiSuccessorTraitBase : public TraitBase<ConcreteType, TraitType> {
  using succ_iterator = Operation::succ_iterator;
  using succ_range = SuccessorRange;

  /// Return the number of successors.
  unsigned getNumSuccessors() {
    return this->getOperation()->getNumSuccessors();
  }

  /// Return the successor at `index`.
  Block *getSuccessor(unsigned index) {
    return this->getOperation()->getSuccessor(index);
  }

  /// Set the successor at `index` to `block`.
  void setSuccessor(Block *block, unsigned index) {
    // `Operation::setSuccessor` yields void. Call it as a statement rather
    // than `return`ing a void expression from this void function.
    this->getOperation()->setSuccessor(block, index);
  }

  /// Successor iterator access.
  succ_iterator succ_begin() { return this->getOperation()->succ_begin(); }
  succ_iterator succ_end() { return this->getOperation()->succ_end(); }
  succ_range getSuccessors() { return this->getOperation()->getSuccessors(); }
};
} // end namespace detail
/// Trait for ops that are known to have exactly one successor block. It adds
/// accessors that need no successor index.
template <typename ConcreteType>
class OneSuccessor : public TraitBase<ConcreteType, OneSuccessor> {
public:
  /// Return the op's only successor (successor #0).
  auto getSuccessor() -> Block * {
    Operation *owner = this->getOperation();
    return owner->getSuccessor(0);
  }

  /// Make `succ` the op's only successor (successor #0).
  void setSuccessor(Block *succ) {
    Operation *owner = this->getOperation();
    owner->setSuccessor(succ, 0);
  }

  /// Verification is delegated to `impl::verifyOneSuccessor`.
  static auto verifyTrait(Operation *op) -> LogicalResult {
    return impl::verifyOneSuccessor(op);
  }
};
/// Trait family for ops with a fixed successor count `N`, where `N` is at
/// least 2.
template <unsigned N>
class NSuccessors {
public:
  static_assert(N > 1, "use ZeroSuccessor/OneSuccessor for N < 2");

  template <typename ConcreteType>
  class Impl : public detail::MultiSuccessorTraitBase<ConcreteType,
                                                      NSuccessors<N>::Impl> {
  public:
    /// Verification is delegated to `impl::verifyNSuccessors` with the
    /// compile-time count `N`.
    static auto verifyTrait(Operation *op) -> LogicalResult {
      return impl::verifyNSuccessors(op, N);
    }
  };
};
/// Trait family for ops that have at least `N` successor blocks.
template <unsigned N>
class AtLeastNSuccessors {
public:
  template <typename ConcreteType>
  class Impl
      : public detail::MultiSuccessorTraitBase<ConcreteType,
                                               AtLeastNSuccessors<N>::Impl> {
  public:
    /// Verification is delegated to `impl::verifyAtLeastNSuccessors` with the
    /// compile-time lower bound `N`.
    static auto verifyTrait(Operation *op) -> LogicalResult {
      return impl::verifyAtLeastNSuccessors(op, N);
    }
  };
};
/// Trait for ops whose number of successor blocks is not known statically.
template <typename ConcreteType>
class VariadicSuccessors
    : public detail::MultiSuccessorTraitBase<ConcreteType, VariadicSuccessors> {
  // Intentionally empty: every member is inherited from
  // detail::MultiSuccessorTraitBase.
};
//===----------------------------------------------------------------------===//
// Misc Traits
/// Trait that verifies all operands share one shape: either all are scalars,
/// or all are vectors/tensors of the same shape.
template <typename ConcreteType>
class SameOperandsShape : public TraitBase<ConcreteType, SameOperandsShape> {
public:
  /// Verification is delegated to `impl::verifySameOperandsShape`.
  static auto verifyTrait(Operation *op) -> LogicalResult {
    return impl::verifySameOperandsShape(op);
  }
};
/// Trait that verifies operands and results share one shape: either all are
/// scalars, or all are vectors/tensors of the same shape.
template <typename ConcreteType>
class SameOperandsAndResultShape
    : public TraitBase<ConcreteType, SameOperandsAndResultShape> {
public:
  /// Verification is delegated to `impl::verifySameOperandsAndResultShape`.
  static auto verifyTrait(Operation *op) -> LogicalResult {
    return impl::verifySameOperandsAndResultShape(op);
  }
};
/// Trait that verifies all operands share one element type. For a scalar
/// operand, its own type is treated as the element type.
template <typename ConcreteType>
class SameOperandsElementType
    : public TraitBase<ConcreteType, SameOperandsElementType> {
public:
  /// Verification is delegated to `impl::verifySameOperandsElementType`.
  static auto verifyTrait(Operation *op) -> LogicalResult {
    return impl::verifySameOperandsElementType(op);
  }
};
/// Trait that verifies operands and results share one element type. For a
/// scalar value, its own type is treated as the element type.
template <typename ConcreteType>
class SameOperandsAndResultElementType
    : public TraitBase<ConcreteType, SameOperandsAndResultElementType> {
public:
  /// Verification is delegated to
  /// `impl::verifySameOperandsAndResultElementType`.
  static auto verifyTrait(Operation *op) -> LogicalResult {
    return impl::verifySameOperandsAndResultElementType(op);
  }
};
/// Trait that verifies operands and results all have the same type.
///
/// Note: this trait subsumes both SameOperandsAndResultShape and
/// SameOperandsAndResultElementType.
template <typename ConcreteType>
class SameOperandsAndResultType
    : public TraitBase<ConcreteType, SameOperandsAndResultType> {
public:
  /// Verification is delegated to `impl::verifySameOperandsAndResultType`.
  static auto verifyTrait(Operation *op) -> LogicalResult {
    return impl::verifySameOperandsAndResultType(op);
  }
};
/// Trait that verifies every result of the op is a boolean type, or a
/// vector/tensor of booleans.
template <typename ConcreteType>
class ResultsAreBoolLike : public TraitBase<ConcreteType, ResultsAreBoolLike> {
public:
  /// Verification is delegated to `impl::verifyResultsAreBoolLike`.
  static auto verifyTrait(Operation *op) -> LogicalResult {
    return impl::verifyResultsAreBoolLike(op);
  }
};
/// Trait that verifies every result of the op is a floating-point type, or a
/// vector/tensor of floating-point elements.
template <typename ConcreteType>
class ResultsAreFloatLike
    : public TraitBase<ConcreteType, ResultsAreFloatLike> {
public:
  /// Verification is delegated to `impl::verifyResultsAreFloatLike`.
  static auto verifyTrait(Operation *op) -> LogicalResult {
    return impl::verifyResultsAreFloatLike(op);
  }
};
/// Trait that verifies every result of the op is a signless integer or index
/// type, or a vector/tensor of such elements.
template <typename ConcreteType>
class ResultsAreSignlessIntegerLike
    : public TraitBase<ConcreteType, ResultsAreSignlessIntegerLike> {
public:
  /// Verification is delegated to `impl::verifyResultsAreSignlessIntegerLike`.
  static auto verifyTrait(Operation *op) -> LogicalResult {
    return impl::verifyResultsAreSignlessIntegerLike(op);
  }
};
/// Trait that tags the op with the `Commutative` operation property.
template <typename ConcreteType>
class IsCommutative : public TraitBase<ConcreteType, IsCommutative> {
public:
  /// Report the `Commutative` operation property for ops carrying this trait.
  static auto getTraitProperties() -> AbstractOperation::OperationProperties {
    using Properties = AbstractOperation::OperationProperties;
    return static_cast<Properties>(OperationProperty::Commutative);
  }
};
/// Trait that verifies every operand of the op is a floating-point type, or a
/// vector/tensor of floating-point elements.
template <typename ConcreteType>
class OperandsAreFloatLike
    : public TraitBase<ConcreteType, OperandsAreFloatLike> {
public:
  /// Verification is delegated to `impl::verifyOperandsAreFloatLike`.
  static auto verifyTrait(Operation *op) -> LogicalResult {
    return impl::verifyOperandsAreFloatLike(op);
  }
};
/// Trait that verifies every operand of the op is a signless integer or index
/// type, or a vector/tensor of such elements.
template <typename ConcreteType>
class OperandsAreSignlessIntegerLike
    : public TraitBase<ConcreteType, OperandsAreSignlessIntegerLike> {
public:
  /// Verification is delegated to
  /// `impl::verifyOperandsAreSignlessIntegerLike`.
  static auto verifyTrait(Operation *op) -> LogicalResult {
    return impl::verifyOperandsAreSignlessIntegerLike(op);
  }
};
/// This class verifies that all operands of the specified op have the same
/// type.
template <typename ConcreteType>