Skip to content

Commit f402e68

Browse files
committed
[MLIR] ODS TypeDefs: getChecked() and internal enhancements
Have the ODS TypeDef generator write the getChecked() definition. Also add to TypeParamCommaFormatter a `JustParams` format and refactor around that. Reviewed By: rriddle Differential Revision: https://reviews.llvm.org/D89438
1 parent 03c267b commit f402e68

File tree

3 files changed

+61
-36
lines changed

3 files changed

+61
-36
lines changed

mlir/test/lib/Dialect/Test/TestTypeDefs.td

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,8 @@ def IntegerType : Test_Type<"TestInteger"> {
7575
int width;
7676
if ($_parser.parseInteger(width)) return Type();
7777
if ($_parser.parseGreater()) return Type();
78-
return get(ctxt, signedness, width);
78+
Location loc = $_parser.getEncodedSourceLoc($_parser.getNameLoc());
79+
return getChecked(loc, signedness, width);
7980
}];
8081

8182
// Any extra code one wants in the type's class declaration.

mlir/test/mlir-tblgen/typedefs.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def B_CompoundTypeA : TestType<"CompoundA"> {
5151

5252
// DECL-LABEL: class CompoundAType: public ::mlir::Type
5353
// DECL: static ::mlir::LogicalResult verifyConstructionInvariants(Location loc, int widthOfSomething, ::mlir::test::SimpleTypeA exampleTdType, SomeCppStruct exampleCppType, ::llvm::ArrayRef<int> dims);
54-
// DECL: static CompoundAType getChecked(Location loc, int widthOfSomething, ::mlir::test::SimpleTypeA exampleTdType, SomeCppStruct exampleCppType, ::llvm::ArrayRef<int> dims);
54+
// DECL: static ::mlir::Type getChecked(Location loc, int widthOfSomething, ::mlir::test::SimpleTypeA exampleTdType, SomeCppStruct exampleCppType, ::llvm::ArrayRef<int> dims);
5555
// DECL: static ::llvm::StringRef getMnemonic() { return "cmpnd_a"; }
5656
// DECL: static ::mlir::Type parse(::mlir::MLIRContext* ctxt, ::mlir::DialectAsmParser& parser);
5757
// DECL: void print(::mlir::DialectAsmPrinter& printer) const;

mlir/tools/mlir-tblgen/TypeDefGen.cpp

Lines changed: 58 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -78,25 +78,25 @@ class TypeParamCommaFormatter : public llvm::detail::format_adapter {
7878
/// [...]".
7979
TypeNamePairs,
8080

81-
/// Emit ", parameter1Type parameter1Name, parameter2Type parameter2Name,
82-
/// [...]".
83-
TypeNamePairsPrependComma,
84-
8581
/// Emit "parameter1(parameter1), parameter2(parameter2), [...]".
86-
TypeNameInitializer
82+
TypeNameInitializer,
83+
84+
/// Emit "param1Name, param2Name, [...]".
85+
JustParams,
8786
};
8887

89-
TypeParamCommaFormatter(EmitFormat emitFormat, ArrayRef<TypeParameter> params)
90-
: emitFormat(emitFormat), params(params) {}
88+
TypeParamCommaFormatter(EmitFormat emitFormat, ArrayRef<TypeParameter> params,
89+
bool prependComma = true)
90+
: emitFormat(emitFormat), params(params), prependComma(prependComma) {}
9191

9292
/// llvm::formatv will call this function when using an instance as a
9393
/// replacement value.
9494
void format(raw_ostream &os, StringRef options) override {
95-
if (params.size() && emitFormat == EmitFormat::TypeNamePairsPrependComma)
95+
if (params.size() && prependComma)
9696
os << ", ";
97+
9798
switch (emitFormat) {
9899
case EmitFormat::TypeNamePairs:
99-
case EmitFormat::TypeNamePairsPrependComma:
100100
interleaveComma(params, os,
101101
[&](const TypeParameter &p) { emitTypeNamePair(p, os); });
102102
break;
@@ -105,6 +105,10 @@ class TypeParamCommaFormatter : public llvm::detail::format_adapter {
105105
emitTypeNameInitializer(p, os);
106106
});
107107
break;
108+
case EmitFormat::JustParams:
109+
interleaveComma(params, os,
110+
[&](const TypeParameter &p) { os << p.getName(); });
111+
break;
108112
}
109113
}
110114

@@ -120,6 +124,7 @@ class TypeParamCommaFormatter : public llvm::detail::format_adapter {
120124

121125
EmitFormat emitFormat;
122126
ArrayRef<TypeParameter> params;
127+
bool prependComma;
123128
};
124129

125130
} // end anonymous namespace
@@ -168,10 +173,9 @@ static const char *const typeDefParsePrint = R"(
168173
/// The code block for the verifyConstructionInvariants and getChecked.
169174
///
170175
/// {0}: List of parameters, parameters style.
171-
/// {1}: C++ type class name.
172176
static const char *const typeDefDeclVerifyStr = R"(
173177
static ::mlir::LogicalResult verifyConstructionInvariants(Location loc{0});
174-
static {1} getChecked(Location loc{0});
178+
static ::mlir::Type getChecked(Location loc{0});
175179
)";
176180

177181
/// Generate the declaration for the given typeDef class.
@@ -194,14 +198,13 @@ static void emitTypeDefDecl(const TypeDef &typeDef, raw_ostream &os) {
194198
os << *extraDecl << "\n";
195199

196200
TypeParamCommaFormatter emitTypeNamePairsAfterComma(
197-
TypeParamCommaFormatter::EmitFormat::TypeNamePairsPrependComma, params);
201+
TypeParamCommaFormatter::EmitFormat::TypeNamePairs, params);
198202
os << llvm::formatv(" static {0} get(::mlir::MLIRContext* ctxt{1});\n",
199203
typeDef.getCppClassName(), emitTypeNamePairsAfterComma);
200204

201205
// Emit the verify invariants declaration.
202206
if (typeDef.genVerifyInvariantsDecl())
203-
os << llvm::formatv(typeDefDeclVerifyStr, emitTypeNamePairsAfterComma,
204-
typeDef.getCppClassName());
207+
os << llvm::formatv(typeDefDeclVerifyStr, emitTypeNamePairsAfterComma);
205208

206209
// Emit the mnenomic, if specified.
207210
if (auto mnenomic = typeDef.getMnemonic()) {
@@ -317,6 +320,17 @@ static const char *const typeDefStorageClassConstructorReturn = R"(
317320
}
318321
)";
319322

323+
/// The code block for the getChecked definition.
324+
///
325+
/// {0}: List of parameters, parameters style.
326+
/// {1}: C++ type class name.
327+
/// {2}: Comma separated list of parameter names.
328+
static const char *const typeDefDefGetCheckeStr = R"(
329+
::mlir::Type {1}::getChecked(Location loc{0}) {{
330+
return Base::getChecked(loc{2});
331+
}
332+
)";
333+
320334
/// Use tgfmt to emit custom allocation code for each parameter, if necessary.
321335
static void emitParameterAllocationCode(TypeDef &typeDef, raw_ostream &os) {
322336
SmallVector<TypeParameter, 4> parameters;
@@ -355,27 +369,28 @@ static void emitStorageClass(TypeDef typeDef, raw_ostream &os) {
355369
auto parameterTypeList = join(parameterTypes, ", ");
356370

357371
// 1) Emit most of the storage class up until the hashKey body.
358-
os << formatv(
359-
typeDefStorageClassBegin, typeDef.getStorageNamespace(),
360-
typeDef.getStorageClassName(),
361-
TypeParamCommaFormatter(
362-
TypeParamCommaFormatter::EmitFormat::TypeNamePairs, parameters),
363-
TypeParamCommaFormatter(
364-
TypeParamCommaFormatter::EmitFormat::TypeNameInitializer, parameters),
365-
parameterList, parameterTypeList);
372+
os << formatv(typeDefStorageClassBegin, typeDef.getStorageNamespace(),
373+
typeDef.getStorageClassName(),
374+
TypeParamCommaFormatter(
375+
TypeParamCommaFormatter::EmitFormat::TypeNamePairs,
376+
parameters, /*prependComma=*/false),
377+
TypeParamCommaFormatter(
378+
TypeParamCommaFormatter::EmitFormat::TypeNameInitializer,
379+
parameters, /*prependComma=*/false),
380+
parameterList, parameterTypeList);
366381

367382
// 2) Emit the haskKey method.
368383
os << " static ::llvm::hash_code hashKey(const KeyTy &key) {\n";
369384
// Extract each parameter from the key.
370385
for (size_t i = 0, e = parameters.size(); i < e; ++i)
371-
os << formatv(" const auto &{0} = std::get<{1}>(key);\n",
372-
parameters[i].getName(), i);
386+
os << llvm::formatv(" const auto &{0} = std::get<{1}>(key);\n",
387+
parameters[i].getName(), i);
373388
// Then combine them all. This requires all the parameters types to have a
374389
// hash_value defined.
375-
os << " return ::llvm::hash_combine(";
376-
interleaveComma(parameterNames, os);
377-
os << ");\n";
378-
os << " }\n";
390+
os << llvm::formatv(
391+
" return ::llvm::hash_combine({0});\n }\n",
392+
TypeParamCommaFormatter(TypeParamCommaFormatter::EmitFormat::JustParams,
393+
parameters, /* prependComma */ false));
379394

380395
// 3) Emit the construct method.
381396
if (typeDef.hasStorageCustomConstructor())
@@ -462,14 +477,12 @@ static void emitTypeDefDef(TypeDef typeDef, raw_ostream &os) {
462477

463478
os << llvm::formatv(
464479
"{0} {0}::get(::mlir::MLIRContext* ctxt{1}) {{\n"
465-
" return Base::get(ctxt",
480+
" return Base::get(ctxt{2});\n}\n",
466481
typeDef.getCppClassName(),
467482
TypeParamCommaFormatter(
468-
TypeParamCommaFormatter::EmitFormat::TypeNamePairsPrependComma,
469-
parameters));
470-
for (TypeParameter &param : parameters)
471-
os << ", " << param.getName();
472-
os << ");\n}\n";
483+
TypeParamCommaFormatter::EmitFormat::TypeNamePairs, parameters),
484+
TypeParamCommaFormatter(TypeParamCommaFormatter::EmitFormat::JustParams,
485+
parameters));
473486

474487
// Emit the parameter accessors.
475488
if (typeDef.genAccessors())
@@ -481,6 +494,17 @@ static void emitTypeDefDef(TypeDef typeDef, raw_ostream &os) {
481494
typeDef.getCppClassName());
482495
}
483496

497+
// Generate getChecked() method.
498+
if (typeDef.genVerifyInvariantsDecl()) {
499+
os << llvm::formatv(
500+
typeDefDefGetCheckeStr,
501+
TypeParamCommaFormatter(
502+
TypeParamCommaFormatter::EmitFormat::TypeNamePairs, parameters),
503+
typeDef.getCppClassName(),
504+
TypeParamCommaFormatter(TypeParamCommaFormatter::EmitFormat::JustParams,
505+
parameters));
506+
}
507+
484508
// If mnemonic is specified maybe print definitions for the parser and printer
485509
// code, if they're specified.
486510
if (typeDef.getMnemonic())

0 commit comments

Comments
 (0)