@@ -78,25 +78,25 @@ class TypeParamCommaFormatter : public llvm::detail::format_adapter {
78
78
// / [...]".
79
79
TypeNamePairs,
80
80
81
- // / Emit ", parameter1Type parameter1Name, parameter2Type parameter2Name,
82
- // / [...]".
83
- TypeNamePairsPrependComma,
84
-
85
81
// / Emit "parameter1(parameter1), parameter2(parameter2), [...]".
86
- TypeNameInitializer
82
+ TypeNameInitializer,
83
+
84
+ // / Emit "param1Name, param2Name, [...]".
85
+ JustParams,
87
86
};
88
87
89
- TypeParamCommaFormatter (EmitFormat emitFormat, ArrayRef<TypeParameter> params)
90
- : emitFormat(emitFormat), params(params) {}
88
+ TypeParamCommaFormatter (EmitFormat emitFormat, ArrayRef<TypeParameter> params,
89
+ bool prependComma = true )
90
+ : emitFormat(emitFormat), params(params), prependComma(prependComma) {}
91
91
92
92
// / llvm::formatv will call this function when using an instance as a
93
93
// / replacement value.
94
94
void format (raw_ostream &os, StringRef options) override {
95
- if (params.size () && emitFormat == EmitFormat::TypeNamePairsPrependComma )
95
+ if (params.size () && prependComma )
96
96
os << " , " ;
97
+
97
98
switch (emitFormat) {
98
99
case EmitFormat::TypeNamePairs:
99
- case EmitFormat::TypeNamePairsPrependComma:
100
100
interleaveComma (params, os,
101
101
[&](const TypeParameter &p) { emitTypeNamePair (p, os); });
102
102
break ;
@@ -105,6 +105,10 @@ class TypeParamCommaFormatter : public llvm::detail::format_adapter {
105
105
emitTypeNameInitializer (p, os);
106
106
});
107
107
break ;
108
+ case EmitFormat::JustParams:
109
+ interleaveComma (params, os,
110
+ [&](const TypeParameter &p) { os << p.getName (); });
111
+ break ;
108
112
}
109
113
}
110
114
@@ -120,6 +124,7 @@ class TypeParamCommaFormatter : public llvm::detail::format_adapter {
120
124
121
125
EmitFormat emitFormat;
122
126
ArrayRef<TypeParameter> params;
127
+ bool prependComma;
123
128
};
124
129
125
130
} // end anonymous namespace
@@ -168,10 +173,9 @@ static const char *const typeDefParsePrint = R"(
168
173
// / The code block for the verifyConstructionInvariants and getChecked.
169
174
// /
170
175
// / {0}: List of parameters, parameters style.
171
- // / {1}: C++ type class name.
172
176
static const char *const typeDefDeclVerifyStr = R"(
173
177
static ::mlir::LogicalResult verifyConstructionInvariants(Location loc{0});
174
- static {1} getChecked(Location loc{0});
178
+ static ::mlir::Type getChecked(Location loc{0});
175
179
)" ;
176
180
177
181
// / Generate the declaration for the given typeDef class.
@@ -194,14 +198,13 @@ static void emitTypeDefDecl(const TypeDef &typeDef, raw_ostream &os) {
194
198
os << *extraDecl << " \n " ;
195
199
196
200
TypeParamCommaFormatter emitTypeNamePairsAfterComma (
197
- TypeParamCommaFormatter::EmitFormat::TypeNamePairsPrependComma , params);
201
+ TypeParamCommaFormatter::EmitFormat::TypeNamePairs , params);
198
202
os << llvm::formatv (" static {0} get(::mlir::MLIRContext* ctxt{1});\n " ,
199
203
typeDef.getCppClassName (), emitTypeNamePairsAfterComma);
200
204
201
205
// Emit the verify invariants declaration.
202
206
if (typeDef.genVerifyInvariantsDecl ())
203
- os << llvm::formatv (typeDefDeclVerifyStr, emitTypeNamePairsAfterComma,
204
- typeDef.getCppClassName ());
207
+ os << llvm::formatv (typeDefDeclVerifyStr, emitTypeNamePairsAfterComma);
205
208
206
209
// Emit the mnenomic, if specified.
207
210
if (auto mnenomic = typeDef.getMnemonic ()) {
@@ -317,6 +320,17 @@ static const char *const typeDefStorageClassConstructorReturn = R"(
317
320
}
318
321
)" ;
319
322
323
+ // / The code block for the getChecked definition.
324
+ // /
325
+ // / {0}: List of parameters, parameters style.
326
+ // / {1}: C++ type class name.
327
+ // / {2}: Comma separated list of parameter names.
328
+ static const char *const typeDefDefGetCheckeStr = R"(
329
+ ::mlir::Type {1}::getChecked(Location loc{0}) {{
330
+ return Base::getChecked(loc{2});
331
+ }
332
+ )" ;
333
+
320
334
// / Use tgfmt to emit custom allocation code for each parameter, if necessary.
321
335
static void emitParameterAllocationCode (TypeDef &typeDef, raw_ostream &os) {
322
336
SmallVector<TypeParameter, 4 > parameters;
@@ -355,27 +369,28 @@ static void emitStorageClass(TypeDef typeDef, raw_ostream &os) {
355
369
auto parameterTypeList = join (parameterTypes, " , " );
356
370
357
371
// 1) Emit most of the storage class up until the hashKey body.
358
- os << formatv (
359
- typeDefStorageClassBegin, typeDef.getStorageNamespace (),
360
- typeDef.getStorageClassName (),
361
- TypeParamCommaFormatter (
362
- TypeParamCommaFormatter::EmitFormat::TypeNamePairs, parameters),
363
- TypeParamCommaFormatter (
364
- TypeParamCommaFormatter::EmitFormat::TypeNameInitializer, parameters),
365
- parameterList, parameterTypeList);
372
+ os << formatv (typeDefStorageClassBegin, typeDef.getStorageNamespace (),
373
+ typeDef.getStorageClassName (),
374
+ TypeParamCommaFormatter (
375
+ TypeParamCommaFormatter::EmitFormat::TypeNamePairs,
376
+ parameters, /* prependComma=*/ false ),
377
+ TypeParamCommaFormatter (
378
+ TypeParamCommaFormatter::EmitFormat::TypeNameInitializer,
379
+ parameters, /* prependComma=*/ false ),
380
+ parameterList, parameterTypeList);
366
381
367
382
// 2) Emit the haskKey method.
368
383
os << " static ::llvm::hash_code hashKey(const KeyTy &key) {\n " ;
369
384
// Extract each parameter from the key.
370
385
for (size_t i = 0 , e = parameters.size (); i < e; ++i)
371
- os << formatv (" const auto &{0} = std::get<{1}>(key);\n " ,
372
- parameters[i].getName (), i);
386
+ os << llvm:: formatv (" const auto &{0} = std::get<{1}>(key);\n " ,
387
+ parameters[i].getName (), i);
373
388
// Then combine them all. This requires all the parameters types to have a
374
389
// hash_value defined.
375
- os << " return :: llvm::hash_combine( " ;
376
- interleaveComma (parameterNames, os);
377
- os << " ); \n " ;
378
- os << " } \n " ;
390
+ os << llvm::formatv (
391
+ " return ::llvm::hash_combine({0}); \n } \n " ,
392
+ TypeParamCommaFormatter (TypeParamCommaFormatter::EmitFormat::JustParams,
393
+ parameters, /* prependComma */ false )) ;
379
394
380
395
// 3) Emit the construct method.
381
396
if (typeDef.hasStorageCustomConstructor ())
@@ -462,14 +477,12 @@ static void emitTypeDefDef(TypeDef typeDef, raw_ostream &os) {
462
477
463
478
os << llvm::formatv (
464
479
" {0} {0}::get(::mlir::MLIRContext* ctxt{1}) {{\n "
465
- " return Base::get(ctxt" ,
480
+ " return Base::get(ctxt{2}); \n } \n " ,
466
481
typeDef.getCppClassName (),
467
482
TypeParamCommaFormatter (
468
- TypeParamCommaFormatter::EmitFormat::TypeNamePairsPrependComma,
469
- parameters));
470
- for (TypeParameter ¶m : parameters)
471
- os << " , " << param.getName ();
472
- os << " );\n }\n " ;
483
+ TypeParamCommaFormatter::EmitFormat::TypeNamePairs, parameters),
484
+ TypeParamCommaFormatter (TypeParamCommaFormatter::EmitFormat::JustParams,
485
+ parameters));
473
486
474
487
// Emit the parameter accessors.
475
488
if (typeDef.genAccessors ())
@@ -481,6 +494,17 @@ static void emitTypeDefDef(TypeDef typeDef, raw_ostream &os) {
481
494
typeDef.getCppClassName ());
482
495
}
483
496
497
+ // Generate getChecked() method.
498
+ if (typeDef.genVerifyInvariantsDecl ()) {
499
+ os << llvm::formatv (
500
+ typeDefDefGetCheckeStr,
501
+ TypeParamCommaFormatter (
502
+ TypeParamCommaFormatter::EmitFormat::TypeNamePairs, parameters),
503
+ typeDef.getCppClassName (),
504
+ TypeParamCommaFormatter (TypeParamCommaFormatter::EmitFormat::JustParams,
505
+ parameters));
506
+ }
507
+
484
508
// If mnemonic is specified maybe print definitions for the parser and printer
485
509
// code, if they're specified.
486
510
if (typeDef.getMnemonic ())
0 commit comments