diff --git a/mlir/include/mlir/TableGen/AttrOrTypeDef.h b/mlir/include/mlir/TableGen/AttrOrTypeDef.h index c2aafca0831b0..0a23e0fed56cc 100644 --- a/mlir/include/mlir/TableGen/AttrOrTypeDef.h +++ b/mlir/include/mlir/TableGen/AttrOrTypeDef.h @@ -58,6 +58,9 @@ class AttrOrTypeParameter { /// Get the parameter name. StringRef getName() const; + /// Get the parameter accessor name. + std::string getAccessorName() const; + /// If specified, get the custom allocator code for this parameter. Optional getAllocator() const; diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp index b1d0b696009da..6a14e6a0f4d47 100644 --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp @@ -147,8 +147,8 @@ createLinalgBodyCalculationForElementwiseOp(Operation *op, ValueRange args, cast(op).quantization_info()) { auto quantizationInfo = cast(op).quantization_info(); int32_t inputBitWidth = elementTy.getIntOrFloatBitWidth(); - int64_t inZp = quantizationInfo.getValue().getInput_zp(); - int64_t outZp = quantizationInfo.getValue().getOutput_zp(); + int64_t inZp = quantizationInfo.getValue().getInputZp(); + int64_t outZp = quantizationInfo.getValue().getOutputZp(); // Compute the maximum value that can occur in the intermediate buffer. int64_t zpAdd = inZp + outZp; @@ -1847,7 +1847,7 @@ class PadConverter : public OpRewritePattern { } else if (elementTy.isa() && !padOp.quantization_info()) { constantAttr = rewriter.getIntegerAttr(elementTy, 0); } else if (elementTy.isa() && padOp.quantization_info()) { - int64_t value = padOp.quantization_info().getValue().getInput_zp(); + int64_t value = padOp.quantization_info().getValue().getInputZp(); constantAttr = rewriter.getIntegerAttr(elementTy, value); } if (constantAttr) diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp index 2154cd98204e0..bd3eb0feca647 100644 --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp @@ -202,7 +202,7 @@ class ConvConverter : public OpConversionPattern { if (isQuantized) { auto quantizationInfo = op->getAttr("quantization_info").cast(); - int64_t iZp = quantizationInfo.getInput_zp(); + int64_t iZp = quantizationInfo.getInputZp(); int64_t intMin = APInt::getSignedMinValue(inputETy.getIntOrFloatBitWidth()) @@ -274,8 +274,8 @@ class ConvConverter : public OpConversionPattern { if (isQuantized) { auto quantizationInfo = op->getAttr("quantization_info").cast(); - auto iZp = rewriter.getI32IntegerAttr(quantizationInfo.getInput_zp()); - auto kZp = rewriter.getI32IntegerAttr(quantizationInfo.getWeight_zp()); + auto iZp = rewriter.getI32IntegerAttr(quantizationInfo.getInputZp()); + auto kZp = rewriter.getI32IntegerAttr(quantizationInfo.getWeightZp()); auto iZpVal = rewriter.create(loc, iZp); auto kZpVal = rewriter.create(loc, kZp); @@ -366,8 +366,8 @@ class DepthwiseConvConverter if (isQuantized) { auto quantizationInfo = op->getAttr("quantization_info").cast(); - iZp = rewriter.getI32IntegerAttr(quantizationInfo.getInput_zp()); - kZp = rewriter.getI32IntegerAttr(quantizationInfo.getWeight_zp()); + iZp = rewriter.getI32IntegerAttr(quantizationInfo.getInputZp()); + kZp = rewriter.getI32IntegerAttr(quantizationInfo.getWeightZp()); } auto weightShape = weightTy.getShape(); @@ -378,7 +378,7 @@ class DepthwiseConvConverter if (isQuantized) { auto quantizationInfo = op->getAttr("quantization_info").cast(); - int64_t iZp = quantizationInfo.getInput_zp(); + int64_t iZp = quantizationInfo.getInputZp(); int64_t intMin = APInt::getSignedMinValue(inputETy.getIntOrFloatBitWidth()) @@ -542,9 +542,9 @@ class MatMulConverter : public OpConversionPattern { auto quantizationInfo = op.quantization_info().getValue(); auto aZp = rewriter.create( - loc, rewriter.getI32IntegerAttr(quantizationInfo.getA_zp())); + loc, rewriter.getI32IntegerAttr(quantizationInfo.getAZp())); auto bZp = rewriter.create( - loc, rewriter.getI32IntegerAttr(quantizationInfo.getB_zp())); + loc, rewriter.getI32IntegerAttr(quantizationInfo.getBZp())); rewriter.replaceOpWithNewOp( op, TypeRange{op.getType()}, ValueRange{adaptor.a(), adaptor.b(), aZp, bZp}, zeroTensor); @@ -652,9 +652,9 @@ class FullyConnectedConverter auto quantizationInfo = op.quantization_info().getValue(); auto inputZp = rewriter.create( - loc, rewriter.getI32IntegerAttr(quantizationInfo.getInput_zp())); + loc, rewriter.getI32IntegerAttr(quantizationInfo.getInputZp())); auto outputZp = rewriter.create( - loc, rewriter.getI32IntegerAttr(quantizationInfo.getWeight_zp())); + loc, rewriter.getI32IntegerAttr(quantizationInfo.getWeightZp())); Value matmul = rewriter .create( @@ -892,8 +892,7 @@ class AvgPool2dConverter : public OpRewritePattern { if (op.quantization_info()) { auto quantizationInfo = op.quantization_info().getValue(); auto inputZp = rewriter.create( - loc, - b.getIntegerAttr(accETy, quantizationInfo.getInput_zp())); + loc, b.getIntegerAttr(accETy, quantizationInfo.getInputZp())); Value offset = rewriter.create(loc, accETy, countI, inputZp); poolVal = @@ -930,7 +929,7 @@ class AvgPool2dConverter : public OpRewritePattern { auto quantizationInfo = op.quantization_info().getValue(); auto outputZp = rewriter.create( loc, b.getIntegerAttr(scaled.getType(), - quantizationInfo.getOutput_zp())); + quantizationInfo.getOutputZp())); scaled = rewriter.create(loc, scaled, outputZp) .getResult(); } diff --git a/mlir/lib/Dialect/SPIRV/IR/TargetAndABI.cpp b/mlir/lib/Dialect/SPIRV/IR/TargetAndABI.cpp index 17f3412999f3f..b588afa106476 100644 --- a/mlir/lib/Dialect/SPIRV/IR/TargetAndABI.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/TargetAndABI.cpp @@ -145,7 +145,7 @@ spirv::EntryPointABIAttr spirv::lookupEntryPointABI(Operation *op) { DenseIntElementsAttr spirv::lookupLocalWorkGroupSize(Operation *op) { if (auto entryPoint = spirv::lookupEntryPointABI(op)) - return entryPoint.getLocal_size(); + return entryPoint.getLocalSize(); return {}; } diff --git a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp index 3a01557f5c9de..3ea2224a4408a 100644 --- a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp @@ -135,7 +135,7 @@ static LogicalResult lowerEntryPointABIAttr(spirv::FuncOp funcOp, funcOp.getLoc(), executionModel.getValue(), funcOp, interfaceVars); // Specifies the spv.ExecutionModeOp. - auto localSizeAttr = entryPointAttr.getLocal_size(); + auto localSizeAttr = entryPointAttr.getLocalSize(); if (localSizeAttr) { auto values = localSizeAttr.getValues(); SmallVector localSize(values); diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp index d84e5d218dc2a..d7203c6afbe46 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp @@ -347,7 +347,7 @@ struct MaterializePadValue : public OpRewritePattern { } else if (elementTy.isa() && !op.quantization_info()) { constantAttr = rewriter.getIntegerAttr(elementTy, 0); } else if (elementTy.isa() && op.quantization_info()) { - auto value = op.quantization_info().getValue().getInput_zp(); + auto value = op.quantization_info().getValue().getInputZp(); constantAttr = rewriter.getIntegerAttr(elementTy, value); } diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp index 3389dda46e1b0..1db101280ef23 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp @@ -214,7 +214,7 @@ class TransposeConvStridedConverter weight = createOpAndInfer( rewriter, loc, UnrankedTensorType::get(weightETy), weight, weightPaddingVal, nullptr, - rewriter.getAttr(quantInfo.getWeight_zp())); + rewriter.getAttr(quantInfo.getWeightZp())); } else { weight = createOpAndInfer(rewriter, loc, @@ -278,7 +278,7 @@ class TransposeConvStridedConverter input = createOpAndInfer( rewriter, loc, UnrankedTensorType::get(inputETy), input, inputPaddingVal, nullptr, - rewriter.getAttr(quantInfo.getInput_zp())); + rewriter.getAttr(quantInfo.getInputZp())); } else { input = createOpAndInfer(rewriter, loc, UnrankedTensorType::get(inputETy), diff --git a/mlir/lib/TableGen/AttrOrTypeDef.cpp b/mlir/lib/TableGen/AttrOrTypeDef.cpp index 444db742bd32e..8467af0ee74c0 100644 --- a/mlir/lib/TableGen/AttrOrTypeDef.cpp +++ b/mlir/lib/TableGen/AttrOrTypeDef.cpp @@ -215,6 +215,11 @@ StringRef AttrOrTypeParameter::getName() const { return def->getArgName(index)->getValue(); } +std::string AttrOrTypeParameter::getAccessorName() const { + return "get" + + llvm::convertToCamelFromSnakeCase(getName(), /*capitalizeFirst=*/true); +} + Optional AttrOrTypeParameter::getAllocator() const { return getDefValue("allocator"); } diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp index 6e32b431bf823..6895cb15fcdb5 100644 --- a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp +++ b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp @@ -7,16 +7,12 @@ //===----------------------------------------------------------------------===// #include "AttrOrTypeFormatGen.h" -#include "mlir/Support/LogicalResult.h" #include "mlir/TableGen/AttrOrTypeDef.h" #include "mlir/TableGen/Class.h" #include "mlir/TableGen/CodeGenHelpers.h" #include "mlir/TableGen/Format.h" #include "mlir/TableGen/GenInfo.h" #include "mlir/TableGen/Interfaces.h" -#include "llvm/ADT/Sequence.h" -#include "llvm/ADT/SetVector.h" -#include "llvm/ADT/SmallSet.h" #include "llvm/ADT/StringSet.h" #include "llvm/Support/CommandLine.h" #include "llvm/TableGen/Error.h" @@ -31,13 +27,6 @@ using namespace mlir::tblgen; // Utility Functions //===----------------------------------------------------------------------===// -std::string mlir::tblgen::getParameterAccessorName(StringRef name) { - assert(!name.empty() && "parameter has empty name"); - auto ret = "get" + name.str(); - ret[3] = llvm::toUpper(ret[3]); // uppercase first letter of the name - return ret; -} - /// Find all the AttrOrTypeDef for the specified dialect. If no dialect /// specified and can only find one dialect's defs, use that. static void collectAllDefs(StringRef selectedDialect, @@ -288,7 +277,7 @@ void DefGen::emitParserPrinter() { void DefGen::emitAccessors() { for (auto ¶m : params) { Method *m = defCls.addMethod( - param.getCppAccessorType(), getParameterAccessorName(param.getName()), + param.getCppAccessorType(), param.getAccessorName(), def.genStorageClass() ? Method::Const : Method::ConstDeclaration); // Generate accessor definitions only if we also generate the storage // class. Otherwise, let the user define the exact accessor definition. diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp index e1b0774e2a2b8..a31943790aada 100644 --- a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp +++ b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp @@ -58,7 +58,7 @@ class ParameterElement /// Generate the code to check whether the parameter should be printed. MethodBody &genPrintGuard(FmtContext &ctx, MethodBody &os) const { - std::string self = getParameterAccessorName(getName()) + "()"; + std::string self = param.getAccessorName() + "()"; ctx.withSelf(self); os << tgfmt("($_self", &ctx); if (llvm::Optional defaultValue = getParam().getDefaultValue()) { @@ -718,7 +718,7 @@ void DefFormat::genLiteralPrinter(StringRef value, FmtContext &ctx, void DefFormat::genVariablePrinter(ParameterElement *el, FmtContext &ctx, MethodBody &os, bool skipGuard) { const AttrOrTypeParameter ¶m = el->getParam(); - ctx.withSelf(getParameterAccessorName(param.getName()) + "()"); + ctx.withSelf(param.getAccessorName() + "()"); // Guard the printer on the presence of optional parameters and that they // aren't equal to their default values (if they have one). @@ -812,8 +812,7 @@ void DefFormat::genCustomPrinter(CustomDirective *el, FmtContext &ctx, if (auto *ref = dyn_cast(arg)) param = ref->getArg(); os << ",\n" - << getParameterAccessorName(cast(param)->getName()) - << "()"; + << cast(param)->getParam().getAccessorName() << "()"; } os.unindent() << ");\n"; } diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.h b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.h index c371aee268b42..d4711532a79bb 100644 --- a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.h +++ b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.h @@ -20,11 +20,6 @@ class AttrOrTypeDef; void generateAttrOrTypeFormat(const AttrOrTypeDef &def, MethodBody &parser, MethodBody &printer); -/// From the parameter name, get the name of the accessor function in camelcase. -/// The first letter of the parameter is upper-cased and prefixed with "get". -/// E.g. 'value' -> 'getValue'. -std::string getParameterAccessorName(llvm::StringRef name); - } // namespace tblgen } // namespace mlir