-
Notifications
You must be signed in to change notification settings - Fork 11.6k
/
OpBase.td
2969 lines (2518 loc) · 116 KB
/
OpBase.td
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
//===-- OpBase.td - Base op definition file ----------------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This is the base operation definition file.
//
//===----------------------------------------------------------------------===//
#ifndef OP_BASE
#define OP_BASE
//===----------------------------------------------------------------------===//
// Common utilities for defining TableGen mechanisms
//===----------------------------------------------------------------------===//
// A workaround for the inability to define functions in TableGen.
//
// The template parameter `r` is a string that can be read back from an
// instance of this class through the `result` member. Subclasses take their
// own template parameters as function "arguments" and use them to compute the
// string that is forwarded to StrFunc.
// For example, if it didn't already exist, a concat function could be defined
// like:
//
// class StrConcat<list<string> strings> :
// StrFunc<!foldl("", strings, prev, cur, prev # cur)>;
//
// and then called like
//
// StrConcat<["a", "b", "c"]>.result
//
// to get the string "abc".
class StrFunc<string r> {
// The "return value" of the function.
string result = r;
}
//===----------------------------------------------------------------------===//
// Predicate definitions
//===----------------------------------------------------------------------===//
// Base class for logical predicates.
//
// Predicates are used to compose constraints (see next section for details).
// There are two categories of predicates:
//
// 1. CPred: the primitive leaf predicate.
// 2. Compound predicate: a predicate composed from child predicates using
// predicate combiners ("conjunction", "disjunction", "negation",
// "substitution" or "concatenation").
//
// The class itself carries no fields; it only serves as the common base.
class Pred;
// A logical predicate wrapping any C expression.
//
// This is the basis for composing more complex predicates. It is the "atom"
// predicate from the perspective of TableGen and the "interface" between
// TableGen and C++. What is inside is already C++ code, which will be treated
// as opaque strings with special placeholders to be substituted.
//
// ## Special placeholders
//
// Special placeholders can be used to refer to entities in the context where
// this predicate is used. They serve as "hooks" to the enclosing environment.
// The following special placeholders are supported in constraints for an op:
//
// * `$_builder` will be replaced by a mlir::Builder instance.
// * `$_op` will be replaced by the current operation.
// * `$_self` will be replaced with the entity this predicate is attached to.
// E.g., `BoolAttr` is an attribute constraint that wraps a
// `CPred<"$_self.isa<BoolAttr>()">` (see the following sections for details).
// Then for `BoolAttr:$attr`, `$_self` will be replaced by `$attr`.
// Type constraints are slightly special: so that the constraint on each type
// definition reads naturally, and so that a type constraint can be attached
// directly to an operand/result, `$_self` will be replaced by the
// operand/result's type. E.g., for `F32` in `F32:$operand`, its
// `$_self` will be expanded as `getOperand(...).getType()`.
class CPred<code pred> : Pred {
// The C++ expression given as `pred`, wrapped in a pair of parentheses.
code predExpr = "(" # pred # ")";
}
// Kinds of predicate combiners. These must closely match the predicates
// implemented by the C++ backend (tblgen::PredCombinerKind).
class PredCombinerKind;
// Conjunction; used by `And`.
def PredCombinerAnd : PredCombinerKind;
// Disjunction; used by `Or`.
def PredCombinerOr : PredCombinerKind;
// Negation; used by `Neg`.
def PredCombinerNot : PredCombinerKind;
// String substitution in the leaf predicates; used by `SubstLeaves`.
def PredCombinerSubstLeaves : PredCombinerKind;
// Prefix/suffix string concatenation; used by `Concat`.
def PredCombinerConcat : PredCombinerKind;
// A predicate that combines other predicates as defined by PredCombinerKind.
// Instantiated below through the And, Or, Neg, SubstLeaves and Concat classes.
class CombinedPred<PredCombinerKind k, list<Pred> c> : Pred {
// How the children are combined.
PredCombinerKind kind = k;
// The predicates being combined.
list<Pred> children = c;
}
// Predicate combiners: logical conjunction, disjunction and negation.

// A predicate that holds if all of its children hold. Always holds for zero
// children.
class And<list<Pred> children> : CombinedPred<PredCombinerAnd, children>;
// A predicate that holds if any of its children hold. Never holds for zero
// children.
class Or<list<Pred> children> : CombinedPred<PredCombinerOr, children>;
// A predicate that holds if its child does not. The single child is wrapped
// into a one-element children list.
class Neg<Pred child> : CombinedPred<PredCombinerNot, [child]>;
// A predicate that substitutes "pat" with "repl" in predicate calls of the
// leaves of the predicate tree (i.e., not CombinedPred).
//
// This is plain string substitution without regular expressions or captures.
// New predicates with more complex logic can be introduced should the need
// arise.
class SubstLeaves<string pat, string repl, Pred child>
: CombinedPred<PredCombinerSubstLeaves, [child]> {
// The text to look for in the leaf predicates.
string pattern = pat;
// The text that replaces each occurrence of `pattern`.
string replacement = repl;
}
// A predicate that prepends `pre` and appends `suf` to the final predicate
// string composed from `child`. This is plain string concatenation; no
// substitution is applied to `pre` or `suf`.
class Concat<string pre, Pred child, string suf> :
CombinedPred<PredCombinerConcat, [child]> {
// The text placed before the child's predicate string.
string prefix = pre;
// The text placed after the child's predicate string.
string suffix = suf;
}
//===----------------------------------------------------------------------===//
// Constraint definitions
//===----------------------------------------------------------------------===//
// TODO: Merge Constraints into Pred.
// Base class for named constraints.
//
// An op's operands/attributes/results can have various requirements, e.g.,
// having certain types, having values inside a certain range, and so on.
// Besides, for a graph rewrite rule, the source pattern used to match against
// the existing graph has conditions, like the op's operand must be of a more
// constrained subtype, the attribute must have a certain value, and so on.
//
// These requirements and conditions are modeled using this class. Records of
// this class are used to generate verification code in the op verifier, and
// matching code in the pattern matcher.
//
// Constraints are predicates with descriptive names, to facilitate inspection,
// provide nice error messages, etc.
class Constraint<Pred pred, string desc = ""> {
// The predicate that this constraint requires to hold.
Pred predicate = pred;
// User-readable one-line summary used in error reporting messages. If empty,
// a generic message will be used.
string summary = desc;
}
// Subclasses used to differentiate different constraint kinds. These are used
// as markers for the TableGen backend to handle different constraint kinds
// differently if needed. Constraints not deriving from the following subclasses
// are considered as uncategorized constraints.
// Subclass for constraints on a type.
//
// In addition to the predicate and summary of a plain Constraint, a type
// constraint records the C++ class of the types it accepts.
class TypeConstraint<Pred predicate, string summary = "",
string cppClassNameParam = "::mlir::Type"> :
Constraint<predicate, summary> {
// The name of the C++ Type class if known, or ::mlir::Type if not.
string cppClassName = cppClassNameParam;
}
// Subclass for constraints on an attribute. Adds no fields of its own; it
// only marks the kind of entity being constrained.
class AttrConstraint<Pred predicate, string summary = ""> :
Constraint<predicate, summary>;
// Subclass for constraints on a region. Adds no fields of its own; it only
// marks the kind of entity being constrained.
class RegionConstraint<Pred predicate, string summary = ""> :
Constraint<predicate, summary>;
// Subclass for constraints on a successor. Adds no fields of its own; it only
// marks the kind of entity being constrained.
class SuccessorConstraint<Pred predicate, string summary = ""> :
Constraint<predicate, summary>;
// How to use these constraint categories:
//
// * Use TypeConstraint to specify
// * Constraints on an op's operand/result definition
// * Further constraints to match an op's operand/result in source pattern
//
// * Use Attr (a subclass of AttrConstraint) for
// * Constraints on an op's attribute definition
// * Use AttrConstraint to specify
// * Further constraints to match an op's attribute in source pattern
//
// * Use uncategorized constraint to specify
// * Multi-entity constraints in rewrite rules
//===----------------------------------------------------------------------===//
// Common predicates
//===----------------------------------------------------------------------===//
// Each predicate below expects `$_self` to be a type.

// Whether a type is a VectorType.
def IsVectorTypePred : CPred<"$_self.isa<::mlir::VectorType>()">;
// Whether a type is a TensorType.
def IsTensorTypePred : CPred<"$_self.isa<::mlir::TensorType>()">;
// Whether a type is a MemRefType.
def IsMemRefTypePred : CPred<"$_self.isa<::mlir::MemRefType>()">;
// Whether a type is an UnrankedMemRefType.
def IsUnrankedMemRefTypePred
: CPred<"$_self.isa<::mlir::UnrankedMemRefType>()">;
// Whether a type is an UnrankedTensorType.
def IsUnrankedTensorTypePred
: CPred<"$_self.isa<::mlir::UnrankedTensorType>()">;
// Whether a type is a BaseMemRefType.
def IsBaseMemRefTypePred
: CPred<"$_self.isa<::mlir::BaseMemRefType>()">;
// Whether a type is a ShapedType.
def IsShapedTypePred : CPred<"$_self.isa<::mlir::ShapedType>()">;
// For a ShapedType, verify that it has a static shape.
// The predicate casts to ShapedType without checking first, so it is
// presumably meant to be combined with a predicate that establishes the type
// is shaped.
def HasStaticShapePred :
CPred<"$_self.cast<::mlir::ShapedType>().hasStaticShape()">;
// Whether a type is a TupleType.
def IsTupleTypePred : CPred<"$_self.isa<::mlir::TupleType>()">;
//===----------------------------------------------------------------------===//
// Dialect definitions
//===----------------------------------------------------------------------===//
// Describes a dialect: its name and documentation, the dialects it depends
// on, the C++ namespace of its ops, and which optional hooks it overrides.
// `name`, `summary` and `description` are left unset (`?`) here; every other
// field has a default.
class Dialect {
// The name of the dialect.
string name = ?;
// Short summary of the dialect.
string summary = ?;
// The description of the dialect.
string description = ?;
// A list of dialects this dialect will load on construction as dependencies.
// These are dialects that this dialect may involve in canonicalization
// patterns or interfaces.
list<string> dependentDialects = [];
// The C++ namespace that ops of this dialect should be placed into.
//
// By default, uses the name of the dialect as the only namespace. To avoid
// placing in any namespace, use "". To specify nested namespaces, use "::"
// as the delimiter, e.g., given "A::B", ops will be placed in
// `namespace A { namespace B { <ops> } }`.
//
// Note that this works in conjunction with dialect C++ code. Depending on how
// the generated files are included into the dialect, you may want to specify
// a full namespace path or a partial one.
string cppNamespace = name;
// An optional code block containing extra declarations to place in the
// dialect declaration.
code extraClassDeclaration = "";
// If this dialect overrides the hook for materializing constants.
bit hasConstantMaterializer = 0;
// If the dialect definition provides a non-default destructor.
// If false, a default destructor implementation will be generated.
bit hasNonDefaultDestructor = 0;
// If this dialect overrides the hook for verifying operation attributes.
bit hasOperationAttrVerify = 0;
// If this dialect overrides the hook for verifying region argument
// attributes.
bit hasRegionArgAttrVerify = 0;
// If this dialect overrides the hook for verifying region result attributes.
bit hasRegionResultAttrVerify = 0;
// If this dialect overrides the hook for op interface fallback.
bit hasOperationInterfaceFallback = 0;
// If this dialect overrides the hook for canonicalization patterns.
bit hasCanonicalizer = 0;
}
//===----------------------------------------------------------------------===//
// Type definitions
//===----------------------------------------------------------------------===//
// A type; carries a type constraint (predicate, summary and C++ class name)
// plus an optional description and builder call.
class Type<Pred condition, string descr = "",
string cppClassName = "::mlir::Type"> :
TypeConstraint<condition, descr, cppClassName> {
// Description of the type; empty by default.
string description = "";
// Call used to build the type; empty by default. SameBuildabilityAs treats
// an empty value as "not buildable".
string builderCall = "";
}
// Allows providing an alternative name and summary to an existing type def.
//
// The alias reuses the predicate of `t`, so it accepts exactly the same types
// as `t`. The C++ class name, description and builder call of `t` are
// therefore forwarded as well; only the summary may differ.
class TypeAlias<Type t, string summary = t.summary> :
    // Forward `t.cppClassName` so the alias does not silently fall back to the
    // generic `::mlir::Type` default of `Type`.
    Type<t.predicate, summary, t.cppClassName> {
  let description = t.description;
  let builderCall = t.builderCall;
}
// A type of a specific dialect. Identical to Type except that it also records
// the dialect the type belongs to.
class DialectType<Dialect d, Pred condition, string descr = "",
string cppClassName = "::mlir::Type"> :
Type<condition, descr, cppClassName> {
// The dialect that owns this type.
Dialect dialect = d;
}
// A variadic type constraint. It expands to zero or more of the base type. This
// class is used for supporting variadic operands/results.
//
// The predicate and summary are taken from the base type; the C++ class name
// is not forwarded, so the TypeConstraint default applies.
class Variadic<Type type> : TypeConstraint<type.predicate, type.summary> {
// The type of each individual element.
Type baseType = type;
}
// A nested variadic type constraint. It expands to zero or more variadic ranges
// of the base type. This class is used for supporting variadic operands and
// results. `variadicSegmentAttrName` should correspond to the name of an
// I32ElementsAttr argument that provides the sizes of the inner variadic
// operand groups.
class VariadicOfVariadic<Type type, string variadicSegmentAttrName>
: Variadic<type> {
// Name of the attribute holding the sizes of the inner variadic groups.
string segmentAttrName = variadicSegmentAttrName;
}
// An optional type constraint. It expands to either zero or one of the base
// type. This class is used for supporting optional operands/results.
//
// The predicate and summary are taken from the base type; the C++ class name
// is not forwarded, so the TypeConstraint default applies.
class Optional<Type type> : TypeConstraint<type.predicate, type.summary> {
// The type of the element when it is present.
Type baseType = type;
}
// A type that can be constructed using mlir::Builder.
// Note that this does not "inherit" from Type because it would require
// duplicating Type subclasses for buildable and non-buildable cases to avoid
// diamond "inheritance". It is instead listed as an additional parent class
// next to Type (see, e.g., NoneType or Index).
// TODO: we may extend this to a more general 'Buildable' trait, making some
// Types and some Attrs buildable.
class BuildableType<code builder> {
// The builder call to invoke (if specified) to construct the BuildableType.
code builderCall = builder;
}
// A type that's buildable iff the type passed as an argument is buildable.
// This is intended for use by types like container types, which are only
// buildable if the type of their elements is buildable.
class SameBuildabilityAs<Type type, code builder> {
// `builder` when `type` has a non-empty builder call, otherwise the empty
// string (i.e. not buildable).
code builderCall = !if(!empty(type.builderCall), "", builder);
}
// Any type at all. The predicate is the constant `true`.
def AnyType : Type<CPred<"true">, "any type">;
// The none type. Buildable through `$_builder.getType<::mlir::NoneType>()`.
def NoneType : Type<CPred<"$_self.isa<::mlir::NoneType>()">, "none type",
"::mlir::NoneType">,
BuildableType<"$_builder.getType<::mlir::NoneType>()">;
// A type constraint satisfied by any one of the types in `allowedTypes`.
// When no explicit summary is given, the summaries of the allowed types are
// joined with " or ".
class AnyTypeOf<list<Type> allowedTypes, string summary = "",
                string cppClassName = "::mlir::Type"> :
    Type<
      // The disjunction of the predicates of all allowed types.
      Or<!foreach(ty, allowedTypes, ty.predicate)>,
      // Fall back to a generated summary when none was supplied.
      !if(!empty(summary),
          !interleave(!foreach(ty, allowedTypes, ty.summary), " or "),
          summary),
      cppClassName>;
// Integer types.
// Any integer type irrespective of its width and signedness semantics, i.e.
// anything that is an ::mlir::IntegerType.
def AnyInteger : Type<CPred<"$_self.isa<::mlir::IntegerType>()">, "integer",
"::mlir::IntegerType">;
// Any integer type (regardless of signedness semantics) of a specific width.
class AnyI<int width>
: Type<CPred<"$_self.isInteger(" # width # ")">, width # "-bit integer"> {
// The bit width required by this constraint.
int bitwidth = width;
}
// Any integer type (regardless of signedness semantics) whose width is one of
// `widths`. The summary lists the widths separated by "/".
class AnyIntOfWidths<list<int> widths> :
AnyTypeOf<!foreach(w, widths, AnyI<w>),
!interleave(widths, "/") # "-bit integer",
"::mlir::IntegerType">;
// Any-signedness integer constraints for the common bit widths.
def AnyI1 : AnyI<1>;
def AnyI8 : AnyI<8>;
def AnyI16 : AnyI<16>;
def AnyI32 : AnyI<32>;
def AnyI64 : AnyI<64>;
// Any signless integer type irrespective of its width.
def AnySignlessInteger : Type<
CPred<"$_self.isSignlessInteger()">, "signless integer",
"::mlir::IntegerType">;
// Signless integer type of a specific width. Buildable through
// `$_builder.getIntegerType(<width>)`.
class I<int width>
: Type<CPred<"$_self.isSignlessInteger(" # width # ")">,
width # "-bit signless integer", "::mlir::IntegerType">,
BuildableType<"$_builder.getIntegerType(" # width # ")"> {
// The bit width required by this constraint.
int bitwidth = width;
}
// Any signless integer type whose width is one of `widths`. The summary lists
// the widths separated by "/".
// NOTE(review): unlike AnyIntOfWidths, no C++ class name is passed here, so
// the AnyTypeOf default of ::mlir::Type applies - confirm this is intended.
class SignlessIntOfWidths<list<int> widths> :
AnyTypeOf<!foreach(w, widths, I<w>),
!interleave(widths, "/") # "-bit signless integer">;
// Signless integer types for the common bit widths.
def I1 : I<1>;
def I8 : I<8>;
def I16 : I<16>;
def I32 : I<32>;
def I64 : I<64>;
// Any signed integer type irrespective of its width.
// NOTE(review): unlike AnySignlessInteger, no C++ class name is given, so the
// Type default of ::mlir::Type applies - confirm this is intended.
def AnySignedInteger : Type<
CPred<"$_self.isSignedInteger()">, "signed integer">;
// Signed integer type of a specific width. Buildable through
// `$_builder.getIntegerType(<width>, /*isSigned=*/true)`.
class SI<int width>
: Type<CPred<"$_self.isSignedInteger(" # width # ")">,
width # "-bit signed integer", "::mlir::IntegerType">,
BuildableType<
"$_builder.getIntegerType(" # width # ", /*isSigned=*/true)"> {
// The bit width required by this constraint.
int bitwidth = width;
}
// Any signed integer type whose width is one of `widths`. The summary lists
// the widths separated by "/".
class SignedIntOfWidths<list<int> widths> :
AnyTypeOf<!foreach(w, widths, SI<w>),
!interleave(widths, "/") # "-bit signed integer">;
// Signed integer types for the common bit widths.
def SI1 : SI<1>;
def SI8 : SI<8>;
def SI16 : SI<16>;
def SI32 : SI<32>;
def SI64 : SI<64>;
// Any unsigned integer type irrespective of its width.
// NOTE(review): unlike AnySignlessInteger, no C++ class name is given, so the
// Type default of ::mlir::Type applies - confirm this is intended.
def AnyUnsignedInteger : Type<
CPred<"$_self.isUnsignedInteger()">, "unsigned integer">;
// Unsigned integer type of a specific width. Buildable through
// `$_builder.getIntegerType(<width>, /*isSigned=*/false)`.
class UI<int width>
: Type<CPred<"$_self.isUnsignedInteger(" # width # ")">,
width # "-bit unsigned integer", "::mlir::IntegerType">,
BuildableType<
"$_builder.getIntegerType(" # width # ", /*isSigned=*/false)"> {
// The bit width required by this constraint.
int bitwidth = width;
}
// Any unsigned integer type whose width is one of `widths`. The summary lists
// the widths separated by "/".
class UnsignedIntOfWidths<list<int> widths> :
AnyTypeOf<!foreach(w, widths, UI<w>),
!interleave(widths, "/") # "-bit unsigned integer">;
// Unsigned integer types for the common bit widths.
def UI1 : UI<1>;
def UI8 : UI<8>;
def UI16 : UI<16>;
def UI32 : UI<32>;
def UI64 : UI<64>;
// Index type. Buildable through `$_builder.getIndexType()`.
def Index : Type<CPred<"$_self.isa<::mlir::IndexType>()">, "index",
"::mlir::IndexType">,
BuildableType<"$_builder.getIndexType()">;
// Floating point types.
// Any float type irrespective of its width, i.e. anything that is an
// ::mlir::FloatType.
def AnyFloat : Type<CPred<"$_self.isa<::mlir::FloatType>()">, "floating-point",
"::mlir::FloatType">;
// Float type of a specific width.
//
// Both the predicate and the builder call are formed by pasting the width
// into a method name: `$_self.isF<width>()` and `$_builder.getF<width>Type()`.
class F<int width>
: Type<CPred<"$_self.isF" # width # "()">,
width # "-bit float", "::mlir::FloatType">,
BuildableType<"$_builder.getF" # width # "Type()"> {
// The bit width required by this constraint.
int bitwidth = width;
}
// Any float type whose bit width is one of `widths`. The summary lists the
// widths separated by "/".
class FloatOfWidths<list<int> widths> :
    AnyTypeOf<
      !foreach(bits, widths, F<bits>),
      !interleave(widths, "/") # "-bit float">;
// Float types for specific bit widths.
def F16 : F<16>;
def F32 : F<32>;
def F64 : F<64>;
def F80 : F<80>;
def F128 : F<128>;
// The bfloat16 type. Defined separately from F<...> because it is checked
// with `isBF16()` and built with `getBF16Type()`. No C++ class name is given,
// so the Type default of ::mlir::Type applies.
def BF16 : Type<CPred<"$_self.isBF16()">, "bfloat16 type">,
BuildableType<"$_builder.getBF16Type()">;
// A complex type whose element type satisfies `type`.
//
// The predicate first checks that the type is a ComplexType and then applies
// the predicate of `type` to the complex type's element type (by substituting
// `$_self` in the element predicate).
//
// The complex type is buildable only if `type` is. The builder call pastes
// `type` itself between `$_builder.get` and `Type()`; presumably this yields
// the def name of the element type (e.g. F32 giving `getF32Type()`) - confirm
// before using this with element types that have no matching builder method.
class Complex<Type type>
: Type<And<[
CPred<"$_self.isa<::mlir::ComplexType>()">,
SubstLeaves<"$_self",
"$_self.cast<::mlir::ComplexType>().getElementType()",
type.predicate>]>,
"complex type with " # type.summary # " elements",
"::mlir::ComplexType">,
SameBuildabilityAs<type, "::mlir::ComplexType::get($_builder.get" # type #
"Type())"> {
// The type of the elements of the complex type.
Type elementType = type;
}
// Any complex type, with no constraint on its element type.
def AnyComplex : Type<CPred<"$_self.isa<::mlir::ComplexType>()">,
"complex-type", "::mlir::ComplexType">;
// An opaque type identified by a dialect name and a type name.
//
// The predicate calls `isOpaqueTypeWithName($_self, "<dialect>", "<name>")`.
// The type is built through `::mlir::OpaqueType::get`, passing the dialect
// name wrapped in `$_builder.getIdentifier(...)` and the type name as a string
// literal.
class OpaqueType<string dialect, string name, string summary>
: Type<CPred<"isOpaqueTypeWithName($_self, \""#dialect#"\", \""#name#"\")">,
summary, "::mlir::OpaqueType">,
BuildableType<"::mlir::OpaqueType::get("
"$_builder.getIdentifier(\"" # dialect # "\"), \""
# name # "\")">;
// Function type.
// Any function type, i.e. anything that is an ::mlir::FunctionType.
def FunctionType : Type<CPred<"$_self.isa<::mlir::FunctionType>()">,
"function type", "::mlir::FunctionType">;
// A container type is a type that has another type embedded within it.
//
// `etype` constrains the embedded element type, `containerPred` checks the
// container itself, `elementTypeCall` is the call that extracts the element
// type from `$_self`, and `descr` names the container in the summary
// ("<descr> of <element summary> values").
class ContainerType<Type etype, Pred containerPred, code elementTypeCall,
string descr, string cppClassName = "::mlir::Type"> :
// First, check the container predicate. Then, substitute the extracted
// element into the element type checker.
Type<And<[containerPred,
SubstLeaves<"$_self", !cast<string>(elementTypeCall),
etype.predicate>]>,
descr # " of " # etype.summary # " values", cppClassName> {
// The type of elements in the container.
Type elementType = etype;
// Call to retrieve the element type from the container type.
code getElementTypeCall = elementTypeCall;
}
// A container type whose element type is obtained through
// `$_self.cast<::mlir::ShapedType>().getElementType()` and must be one of
// `allowedTypes`.
class ShapedContainerType<list<Type> allowedTypes,
Pred containerPred, string descr,
string cppClassName = "::mlir::Type"> :
ContainerType<AnyTypeOf<allowedTypes>, containerPred,
"$_self.cast<::mlir::ShapedType>().getElementType()", descr,
cppClassName>;
// Whether a shaped type is ranked. The predicate casts to ShapedType without
// checking first, so it is presumably meant to be combined with a predicate
// that establishes the type is shaped.
def HasRankPred : CPred<"$_self.cast<::mlir::ShapedType>().hasRank()">;
// Whether a shaped type has one of the specified ranks.
//
// The type must be ranked (HasRankPred) and its rank must equal at least one
// entry of `ranks`; one equality test is generated per entry.
class HasAnyRankOfPred<list<int> ranks> : And<[
HasRankPred,
Or<!foreach(rank, ranks,
CPred<[{$_self.cast<::mlir::ShapedType>().getRank()
== }]
# rank>)>]>;
// Vector types.
// Any vector type whose element type is from the given `allowedTypes` list.
class VectorOf<list<Type> allowedTypes> :
ShapedContainerType<allowedTypes, IsVectorTypePred, "vector",
"::mlir::VectorType">;
// Whether a type is a vector whose rank is from the given `allowedRanks`
// list. One equality test against `getRank()` is generated per allowed rank.
class IsVectorOfRankPred<list<int> allowedRanks> :
    And<[IsVectorTypePred,
         Or<!foreach(allowedRank, allowedRanks,
                     CPred<[{$_self.cast<::mlir::VectorType>().getRank()
== }]
                           # allowedRank>)>]>;
// Any vector where the rank is from the given `allowedRanks` list.
//
// The summary starts with a space and has no subject (" of ranks 1/2"); it is
// appended to a VectorOf summary in VectorOfRankAndType.
class VectorOfRank<list<int> allowedRanks> : Type<
IsVectorOfRankPred<allowedRanks>,
" of ranks " # !interleave(allowedRanks, "/"), "::mlir::VectorType">;
// Any vector where the rank is from the given `allowedRanks` list and the
// element type is from the given `allowedTypes` list.
//
// The predicate is the conjunction of the VectorOf and VectorOfRank
// predicates; the summary is the concatenation of their summaries.
class VectorOfRankAndType<list<int> allowedRanks,
list<Type> allowedTypes> : Type<
And<[VectorOf<allowedTypes>.predicate,
VectorOfRank<allowedRanks>.predicate]>,
VectorOf<allowedTypes>.summary # VectorOfRank<allowedRanks>.summary,
"::mlir::VectorType">;
// Whether a type is a vector whose number of elements is from the given
// `allowedLengths` list. One equality test against `getNumElements()` is
// generated per allowed length.
class IsVectorOfLengthPred<list<int> allowedLengths> :
And<[IsVectorTypePred,
Or<!foreach(allowedlength, allowedLengths,
CPred<[{$_self.cast<::mlir::VectorType>().getNumElements()
== }]
# allowedlength>)>]>;
// Any vector where the number of elements is from the given
// `allowedLengths` list.
//
// The summary starts with a space and has no subject (" of length 4/8"); it
// is appended to a VectorOf summary in VectorOfLengthAndType.
class VectorOfLength<list<int> allowedLengths> : Type<
IsVectorOfLengthPred<allowedLengths>,
" of length " # !interleave(allowedLengths, "/"),
"::mlir::VectorType">;
// Any vector where the number of elements is from the given
// `allowedLengths` list and the element type is from the given
// `allowedTypes` list.
//
// The predicate is the conjunction of the VectorOf and VectorOfLength
// predicates; the summary is the concatenation of their summaries.
class VectorOfLengthAndType<list<int> allowedLengths,
list<Type> allowedTypes> : Type<
And<[VectorOf<allowedTypes>.predicate,
VectorOfLength<allowedLengths>.predicate]>,
VectorOf<allowedTypes>.summary # VectorOfLength<allowedLengths>.summary,
"::mlir::VectorType">;
// Any vector type, with no constraint on its element type.
def AnyVector : VectorOf<[AnyType]>;
// Shaped types.
// Any shaped type, with no constraint on its element type.
def AnyShaped: ShapedContainerType<[AnyType], IsShapedTypePred, "shaped",
"::mlir::ShapedType">;
// Tensor types.
// Any tensor type whose element type is from the given `allowedTypes` list.
// Only IsTensorTypePred is required of the container; no rank check is made.
class TensorOf<list<Type> allowedTypes> :
ShapedContainerType<allowedTypes, IsTensorTypePred, "tensor",
"::mlir::TensorType">;
// Any ranked tensor type whose element type is from the given `allowedTypes`
// list. The container must be a tensor and have a rank.
class RankedTensorOf<list<Type> allowedTypes> :
ShapedContainerType<allowedTypes, And<[IsTensorTypePred, HasRankPred]>,
"ranked tensor", "::mlir::TensorType">;
// Any tensor type, with no constraint on its element type.
def AnyTensor : TensorOf<[AnyType]>;
// Unranked tensor type.
// Any unranked tensor whose element type is from the given `allowedTypes`
// list.
class UnrankedTensorOf<list<Type> allowedTypes> :
ShapedContainerType<allowedTypes,
IsUnrankedTensorTypePred,
"unranked.tensor", "::mlir::UnrankedTensorType">;
// Any ranked tensor type, with no constraint on its element type.
def AnyRankedTensor : RankedTensorOf<[AnyType]>;
// TODO: Have an easy way to add another constraint to a type.
// Any tensor type with a static shape whose element type is from the given
// `allowedTypes` list. The summary is the TensorOf summary prefixed with
// "statically shaped ".
class StaticShapeTensorOf<list<Type> allowedTypes>
: Type<And<[TensorOf<allowedTypes>.predicate, HasStaticShapePred]>,
"statically shaped " # TensorOf<allowedTypes>.summary,
"::mlir::TensorType">;
// Any statically shaped tensor, with no constraint on its element type.
def AnyStaticShapeTensor : StaticShapeTensorOf<[AnyType]>;
// Tensors with a single allowed element type. Like TensorOf, these place no
// constraint on rank or shape.
def I1Tensor : TensorOf<[I1]>;
def I8Tensor : TensorOf<[I8]>;
def I16Tensor : TensorOf<[I16]>;
def I32Tensor : TensorOf<[I32]>;
def I64Tensor : TensorOf<[I64]>;
def IndexTensor: TensorOf<[Index]>;
def BF16Tensor : TensorOf<[BF16]>;
def F16Tensor : TensorOf<[F16]>;
def F32Tensor : TensorOf<[F32]>;
def F64Tensor : TensorOf<[F64]>;
// Ranked tensor type with one of the specified types and ranks.
//
// The summary prefixes the TensorOf summary with the allowed ranks, each
// suffixed with "D" and separated by "/" (e.g. "1D/2D tensor of ...").
class TensorRankOf<list<Type> allowedTypes, list<int> ranks> :
Type<And<[TensorOf<allowedTypes>.predicate, HasAnyRankOfPred<ranks>]>,
!interleave(!foreach(rank, ranks, rank # "D"), "/") # " " #
TensorOf<allowedTypes>.summary, "::mlir::TensorType">;
// Shorthands for tensors of one fixed rank (0 through 4) whose element type is
// from the given `allowedTypes` list.
class 0DTensorOf<list<Type> allowedTypes> : TensorRankOf<allowedTypes, [0]>;
class 1DTensorOf<list<Type> allowedTypes> : TensorRankOf<allowedTypes, [1]>;
class 2DTensorOf<list<Type> allowedTypes> : TensorRankOf<allowedTypes, [2]>;
class 3DTensorOf<list<Type> allowedTypes> : TensorRankOf<allowedTypes, [3]>;
class 4DTensorOf<list<Type> allowedTypes> : TensorRankOf<allowedTypes, [4]>;
// Unranked memref type.
// Any unranked memref whose element type is from the given `allowedTypes`
// list.
class UnrankedMemRefOf<list<Type> allowedTypes> :
ShapedContainerType<allowedTypes,
IsUnrankedMemRefTypePred, "unranked.memref",
"::mlir::UnrankedMemRefType">;
// Any unranked memref, with no constraint on its element type.
def AnyUnrankedMemRef : UnrankedMemRefOf<[AnyType]>;
// Memref type.
// Memrefs are blocks of data with fixed type and rank.
// Any MemRefType whose element type is from the given `allowedTypes` list.
class MemRefOf<list<Type> allowedTypes> :
ShapedContainerType<allowedTypes, IsMemRefTypePred, "memref",
"::mlir::MemRefType">;
// Any MemRefType, with no constraint on its element type.
def AnyMemRef : MemRefOf<[AnyType]>;
// Either an unranked memref or a (ranked) memref whose element type is from
// the given `allowedTypes` list. No explicit summary or C++ class name is
// passed, so the AnyTypeOf defaults apply.
class RankedOrUnrankedMemRefOf<list<Type> allowedTypes>:
AnyTypeOf<[UnrankedMemRefOf<allowedTypes>, MemRefOf<allowedTypes>]>;
// Either an unranked memref or a (ranked) memref, with no constraint on the
// element type.
def AnyRankedOrUnrankedMemRef: AnyTypeOf<[AnyUnrankedMemRef, AnyMemRef]>;
// Memref declarations handle any memref, independent of rank, size (static or
// dynamic), layout, or memory space. Each def below constrains only the
// element type.
def I1MemRef : MemRefOf<[I1]>;
def I8MemRef : MemRefOf<[I8]>;
def I16MemRef : MemRefOf<[I16]>;
def I32MemRef : MemRefOf<[I32]>;
def I64MemRef : MemRefOf<[I64]>;
def BF16MemRef : MemRefOf<[BF16]>;
def F16MemRef : MemRefOf<[F16]>;
def F32MemRef : MemRefOf<[F32]>;
def F64MemRef : MemRefOf<[F64]>;
// TODO: Have an easy way to add another constraint to a type.
// Memref type with one of the specified element types and ranks.
//
// The summary prefixes the MemRefOf summary with the allowed ranks, each
// suffixed with "D" and separated by "/" (e.g. "1D/2D memref of ...").
class MemRefRankOf<list<Type> allowedTypes, list<int> ranks> :
    Type<
      And<[MemRefOf<allowedTypes>.predicate, HasAnyRankOfPred<ranks>]>,
      !interleave(!foreach(r, ranks, r # "D"), "/") # " "
        # MemRefOf<allowedTypes>.summary,
      "::mlir::MemRefType">;
// Any memref with a static shape whose element type is from the given
// `allowedTypes` list. The summary is the MemRefOf summary prefixed with
// "statically shaped ".
class StaticShapeMemRefOf<list<Type> allowedTypes>
: Type<And<[MemRefOf<allowedTypes>.predicate, HasStaticShapePred]>,
"statically shaped " # MemRefOf<allowedTypes>.summary,
"::mlir::MemRefType">;
// Any statically shaped memref, with no constraint on its element type.
def AnyStaticShapeMemRef : StaticShapeMemRefOf<[AnyType]>;
// For a MemRefType, verify that it has strides, by calling `isStrided` on it.
// The predicate casts to MemRefType without checking first, so it is
// presumably meant to be combined with a predicate that establishes the type
// is a memref.
def HasStridesPred : CPred<[{ isStrided($_self.cast<::mlir::MemRefType>()) }]>;
// Any strided memref whose element type is from the given `allowedTypes`
// list. The summary is the MemRefOf summary prefixed with "strided ". No C++
// class name is given, so the Type default of ::mlir::Type applies.
class StridedMemRefOf<list<Type> allowedTypes>
: Type<And<[MemRefOf<allowedTypes>.predicate, HasStridesPred]>,
"strided " # MemRefOf<allowedTypes>.summary>;
// Any strided memref, with no constraint on its element type.
def AnyStridedMemRef : StridedMemRefOf<[AnyType]>;
// Any strided memref of exactly the given rank, with no constraint on its
// element type. Combines the AnyStridedMemRef predicate with a single-rank
// MemRefRankOf predicate; the summary is "<AnyStridedMemRef summary> of rank
// <rank>".
class AnyStridedMemRefOfRank<int rank> :
Type<And<[AnyStridedMemRef.predicate,
MemRefRankOf<[AnyType], [rank]>.predicate]>,
AnyStridedMemRef.summary # " of rank " # rank>;
// Memref type with one of the specified element types and ranks.
//
// NOTE(review): despite the name, neither the predicate nor the summary
// mentions strides: the predicate is the same conjunction as in MemRefRankOf
// (no HasStridesPred, unlike StridedMemRefOf and AnyStridedMemRefOfRank), and
// the only difference from MemRefRankOf is that no C++ class name is passed.
// Confirm whether a strides check is intended before relying on it.
class StridedMemRefRankOf<list<Type> allowedTypes, list<int> ranks> :
Type<And<[MemRefOf<allowedTypes>.predicate, HasAnyRankOfPred<ranks>]>,
!interleave(!foreach(rank, ranks, rank # "D"), "/") # " " #
MemRefOf<allowedTypes>.summary>;
// A generic tuple without any constraints on its element types; only checks
// that the type is a TupleType.
def AnyTuple : Type<IsTupleTypePred, "tuple", "::mlir::TupleType">;
// A container type that has other types embedded in it, but (unlike
// ContainerType) can hold elements with a mix of types. Requires a call that
// produces a list of all elements' types.
//
// The predicate first checks `containerPred` and then requires the predicate
// of `etype` to hold for every type produced by `elementTypesCall`. This is
// done by wrapping the element predicate, with `$_self` replaced by `t`, in a
// lambda that is handed to `::llvm::all_of`. The summary is
// "<descr> with any combination of <element summary> values".
class MixedContainerType<Type etype, Pred containerPred, code elementTypesCall,
                         string descr> :
    Type<
      And<[
        containerPred,
        Concat<
          // The lambda parameter type is fully qualified, like every other
          // MLIR name in this file, so the generated code does not depend on
          // `mlir::Type` being reachable as an unqualified `Type`.
          "::llvm::all_of(" # elementTypesCall
            # ", [](::mlir::Type t) { return ",
          SubstLeaves<"$_self", "t", etype.predicate>,
          "; })"
        >
      ]>,
      descr # " with any combination of " # etype.summary # " values"> {
  // The type of elements in the container.
  Type elementType = etype;
  // Call to retrieve the types of all elements.
  code getElementTypesCall = elementTypesCall;
}
// A Tuple that holds a mix of elements of the allowed types. The element
// types are obtained through `getTypes()` on the tuple type, so only the
// tuple's direct elements are checked.
class TupleOf<list<Type> allowedTypes>
: MixedContainerType<AnyTypeOf<allowedTypes>, IsTupleTypePred,
"$_self.cast<::mlir::TupleType>().getTypes()",
"tuple">;
// A Tuple with arbitrary nesting, where all elements are a mix of the allowed
// types. The element types are obtained by calling `getFlattenedTypes` on the
// tuple type.
class NestedTupleOf<list<Type> allowedTypes> :
MixedContainerType<AnyTypeOf<allowedTypes>, IsTupleTypePred,
"getFlattenedTypes($_self.cast<::mlir::TupleType>())",
"nested tuple">;
//===----------------------------------------------------------------------===//
// Common type constraints
//===----------------------------------------------------------------------===//
// Type constraint for bool-like types: bools (I1), vectors of bools, tensors
// of bools.
def BoolLike : TypeConstraint<Or<[I1.predicate, VectorOf<[I1]>.predicate,
TensorOf<[I1]>.predicate]>,
"bool-like">;
// Signless-integer-like type constraint. Accepts signless integers, the index
// type, and vectors or tensors whose allowed types are signless integers or
// index.
def SignlessIntegerLike
    : TypeConstraint<
        Or<[
          AnySignlessInteger.predicate,
          Index.predicate,
          VectorOf<[AnySignlessInteger, Index]>.predicate,
          TensorOf<[AnySignlessInteger, Index]>.predicate
        ]>,
        "signless-integer-like">;
// Float-like type constraint. Accepts any float, as well as anything matching
// VectorOf<[AnyFloat]> or TensorOf<[AnyFloat]>.
def FloatLike
    : TypeConstraint<
        Or<[
          AnyFloat.predicate,
          VectorOf<[AnyFloat]>.predicate,
          TensorOf<[AnyFloat]>.predicate
        ]>,
        "floating-point-like">;
// Union of the SignlessIntegerLike and FloatLike constraints: a type is
// accepted when it satisfies either predicate.
def SignlessIntegerOrFloatLike
    : TypeConstraint<
        Or<[
          SignlessIntegerLike.predicate,
          FloatLike.predicate
        ]>,
        "signless-integer-like or floating-point-like">;
//===----------------------------------------------------------------------===//
// Attribute definitions
//===----------------------------------------------------------------------===//
//===----------------------------------------------------------------------===//
// Base attribute definition
// Root class of every attribute definition. `condition` and `descr` are
// forwarded unchanged to AttrConstraint.
class Attr<Pred condition, string descr = "">
    : AttrConstraint<condition, descr> {
  // The backing mlir::Attribute type.
  code storageType = ?;

  // The underlying C++ value type.
  code returnType = ?;

  // Call expression converting from the storage type to the return type. For
  // instance, an enum may be stored as an int yet returned as an enum class.
  //
  // Format: $_self will be expanded to the attribute.
  //
  // Example: with `IntegerAttr val`, `$_self.getValue().getSExtValue()`
  // expands to `getAttrOfType<IntegerAttr>("val").getValue().getSExtValue()`.
  code convertFromStorage = "$_self.getValue()";

  // Call expression building an attribute from a constant value.
  //
  // Format: $0 will be expanded to the constant value of the attribute.
  //
  // Example: with `StringAttr:"foo"`, `$_builder.getStringAttr("$0")` expands
  // to `builder.getStringAttr("foo")`.
  string constBuilderCall = ?;

  // Default value for the attribute. Requires constBuilderCall to be defined.
  string defaultValue = ?;

  // The value type of this attribute, i.e. the mlir::Type this attribute
  // returns via `getType()`.
  Type valueType = ?;

  // Whether the attribute is optional. Typically this needs a custom
  // convertFromStorage that copes with the attribute being absent.
  bit isOptional = 0;

  // The base-level Attr instantiation this Attr is built upon. Left unset when
  // this is itself a base-level Attr.
  //
  // Attribute wrapper classes (DefaultValuedAttr, OptionalAttr, etc.) use this
  // field to get back to the base-level attribute definition, e.g. to obtain
  // its name; without it the def name would show up as "anonymous_<number>"
  // because of template instantiation.
  // TODO(b/132458159): deduplicate the fields in attribute wrapper classes.
  Attr baseAttr = ?;

  // The fully-qualified C++ namespace where the generated class lives.
  string cppNamespace = "";
}
// An attribute that belongs to dialect `d`. The dialect is recorded in the
// `dialect` field and its cppNamespace is reused as this attribute's
// cppNamespace.
class DialectAttr<Dialect d, Pred condition, string descr = "">
    : Attr<condition, descr> {
  Dialect dialect = d;

  let cppNamespace = d.cppNamespace;
}
//===----------------------------------------------------------------------===//
// Attribute modifier definition
// Wraps `attr` so that it carries the (unvalidated) default value `val` when
// the attribute is not present.
class DefaultValuedAttr<Attr attr, string val>
    : Attr<attr.predicate, attr.summary> {
  // Mirror the wrapped attribute field by field; only defaultValue differs.
  // Note: this has to be kept up to date with Attr above.
  let storageType = attr.storageType;
  let returnType = attr.returnType;
  let convertFromStorage = attr.convertFromStorage;
  let constBuilderCall = attr.constBuilderCall;
  let valueType = attr.valueType;

  // The one field that is overridden by this wrapper.
  let defaultValue = val;

  // Keep a handle on the attribute being wrapped.
  let baseAttr = attr;
}
// Wraps `attr` to mark it optional. The generated accessor for the attribute
// returns an Optional<> of the wrapped attribute's return type.
class OptionalAttr<Attr attr>
    : Attr<attr.predicate, attr.summary> {
  // Mirror the wrapped attribute, adjusting the fields affected by optionality.
  // Note: this has to be kept up to date with Attr above.
  let storageType = attr.storageType;
  let valueType = attr.valueType;

  // Wrap the original return type in ::llvm::Optional.
  let returnType = "::llvm::Optional<" # attr.returnType #">";

  // Convert only when the attribute is present; yield ::llvm::None otherwise.
  let convertFromStorage =
      "$_self ? " # returnType # "(" # attr.convertFromStorage #
      ") : (::llvm::None)";

  let isOptional = 1;

  // Keep a handle on the attribute being wrapped.
  let baseAttr = attr;
}
//===----------------------------------------------------------------------===//
// Primitive attribute kinds
// Generic attribute built around the specific buildable type `attrValType` and
// backed by the MLIR attribute kind `attrKind`.
//
// The storage type is `::mlir::<attrKind>`, and constants are built with
// `$_builder.get<attrKind>(<attrValType.builderCall>, $0)`.
class TypedAttrBase<Type attrValType, string attrKind, Pred condition,
                    string descr>
    : Attr<condition, descr> {
  let storageType = "::mlir::" # attrKind;
  let valueType = attrValType;
  let constBuilderCall =
      "$_builder.get" # attrKind # "(" # attrValType.builderCall # ", $0)";
}
// Accepts any attribute: the predicate is the constant `true`. Storage and
// return types are both ::mlir::Attribute, so conversion and constant building
// pass the value through untouched.
def AnyAttr
    : Attr<CPred<"true">, "any attribute"> {
  let storageType = "::mlir::Attribute";
  let returnType = "::mlir::Attribute";

  let convertFromStorage = "$_self";
  let constBuilderCall = "$0";
}
// Bool attribute: stored as ::mlir::BoolAttr, returned as a C++ bool, with I1
// as its value type. Constants are built with $_builder.getBoolAttr.
def BoolAttr
    : Attr<CPred<"$_self.isa<::mlir::BoolAttr>()">, "bool attribute"> {
  let storageType = [{ ::mlir::BoolAttr }];
  let returnType = [{ bool }];

  let valueType = I1;

  let constBuilderCall = "$_builder.getBoolAttr($0)";
}
// Index attribute: an ::mlir::IntegerAttr whose type is ::mlir::IndexType.
// The accessor returns the value as an ::llvm::APInt.
def IndexAttr
    : TypedAttrBase<
        Index,
        "IntegerAttr",
        And<[
          CPred<"$_self.isa<::mlir::IntegerAttr>()">,
          CPred<"$_self.cast<::mlir::IntegerAttr>()"
                ".getType().isa<::mlir::IndexType>()">
        ]>,
        "index attribute"> {
  let returnType = [{ ::llvm::APInt }];
}
// Base class for fixed-width integer attributes, regardless of signedness
// semantics. The attribute must be an ::mlir::IntegerAttr whose type passes
// `isInteger(<attrValType.bitwidth>)`.
class AnyIntegerAttrBase<AnyI attrValType, string descr>
    : TypedAttrBase<
        attrValType,
        "IntegerAttr",
        And<[
          CPred<"$_self.isa<::mlir::IntegerAttr>()">,
          CPred<"$_self.cast<::mlir::IntegerAttr>()"
                ".getType().isInteger(" # attrValType.bitwidth # ")">
        ]>,
        descr> {
  // Reset the constant builder call inherited from TypedAttrBase to unset.
  let constBuilderCall = ?;

  let returnType = [{ ::llvm::APInt }];
}
// Fixed-width instantiations of AnyIntegerAttrBase for 1, 8, 16, 32 and 64
// bits.
def AnyI1Attr
    : AnyIntegerAttrBase<AnyI1, "1-bit integer attribute">;
def AnyI8Attr
    : AnyIntegerAttrBase<AnyI8, "8-bit integer attribute">;
def AnyI16Attr
    : AnyIntegerAttrBase<AnyI16, "16-bit integer attribute">;
def AnyI32Attr
    : AnyIntegerAttrBase<AnyI32, "32-bit integer attribute">;
def AnyI64Attr
    : AnyIntegerAttrBase<AnyI64, "64-bit integer attribute">;
// Integer attribute of arbitrary width: any ::mlir::IntegerAttr is accepted
// and no check is made on its type.
def APIntAttr : Attr<CPred<"$_self.isa<::mlir::IntegerAttr>()">,
                     "arbitrary integer attribute"> {
  let storageType = [{ ::mlir::IntegerAttr }];
  // Spelled ::llvm::APInt to match the other integer attribute definitions in
  // this file.
  let returnType = [{ ::llvm::APInt }];
}
// Base class for fixed-width signless integer attributes. The attribute must
// be an ::mlir::IntegerAttr whose type passes
// `isSignlessInteger(<attrValType.bitwidth>)`. The accessor returns an
// ::llvm::APInt.
class SignlessIntegerAttrBase<I attrValType, string descr>
    : TypedAttrBase<
        attrValType,
        "IntegerAttr",
        And<[
          CPred<"$_self.isa<::mlir::IntegerAttr>()">,
          CPred<"$_self.cast<::mlir::IntegerAttr>()"
                ".getType().isSignlessInteger(" # attrValType.bitwidth # ")">
        ]>,
        descr> {
  let returnType = [{ ::llvm::APInt }];
}
// Base class for fixed-width signless integer attributes that have a
// corresponding C++ type. `retType` names that C++ type; the stored value is
// converted to it by calling getValue().getZExtValue().
class TypedSignlessIntegerAttrBase<I attrValType, string retType,
                                   string descr> :
    SignlessIntegerAttrBase<attrValType, descr> {
  let convertFromStorage = "$_self.getValue().getZExtValue()";
  let returnType = retType;
}
// 1-bit signless integer attribute whose accessor returns a C++ bool.
def I1Attr
    : TypedSignlessIntegerAttrBase<I1, "bool",
                                   "1-bit signless integer attribute">;
def I8Attr : TypedSignlessIntegerAttrBase<